WedX - журнал о программировании и компьютерных науках

Cython: Как ускорить рекурсивные функции?

Я реализую дерево сегментов в cython и сравниваю его с реализацией python.

Версия на cython кажется только в 1,5 раза быстрее, и я хочу сделать ее еще быстрее.

Обе реализации можно считать правильными.

Вот код cython:

# distutils: language = c++
from libcpp.vector cimport vector

cdef struct Result:
    int range_sum  
    int range_min 
    int range_max



cdef class SegmentTree:
    cdef vector[int] nums
    cdef vector[Result] tree 

    def __init__(self, vector[int] nums):
        self.nums = nums
        self.tree.resize(4 * len(nums)) #just a safe upper bound 
        self._build(1, 0, len(nums)-1)

    cdef Result _build(self, int index, int left, int right):
        cdef Result result

        if left == right:
            value = self.nums[left]
            result.range_max, result.range_min, result.range_sum = value, value, value 
            self.tree[index] = result
            return self.tree[index]
        else:
            mid = (left+right)//2
            left_range_result = self._build(index*2, left, mid)
            right_range_result = self._build(index*2+1, mid+1, right)
            self.tree[index] = self.combine_range_results(left_range_result, right_range_result)
            return self.tree[index]

    cdef Result range_query(self, int query_i, int query_j):
        return self._range_query(query_i, query_j, 0, len(self.nums)-1, 1)

    cdef Result _range_query(self, int query_i, int query_j, int current_i, int current_j, int index):
        if current_i == query_i and current_j == query_j:
            return self.tree[index]
        else:
            mid = (current_i + current_j)//2 
            if query_j <= mid:
                return self._range_query(query_i, query_j, current_i, mid, index*2)
            elif mid < query_i:
                return self._range_query(query_i, query_j, mid+1, current_j, index*2+1 )  
            else:
                left_range_result = self._range_query(query_i, mid, current_i, mid, index*2)
                right_range_result = self._range_query(mid+1, query_j, mid+1, current_j, index*2+1)
                return self.combine_range_results(left_range_result, right_range_result)


    cpdef int range_sum(self, int query_i, int query_j):
        return self.range_query(query_i, query_j).range_sum 
    cpdef int range_min(self, int query_i, int query_j):
        return self.range_query(query_i, query_j).range_min
    cpdef int range_max(self, int query_i, int query_j):
        return self.range_query(query_i, query_j).range_max

    cpdef void  update(self, int i, int new_value):
        self._update(i, new_value, 1, 0, len(self.nums)-1)

    cdef Result _update(self, int i, int new_value, int index, int left, int right):
        if left == right == i:
            self.tree[index] = [new_value, new_value, new_value]
            return self.tree[index]
        if left == right:
            return self.tree[index]
        mid = (left+right)//2 
        left_range_result = self._update(i, new_value, index*2, left, mid)
        right_range_result = self._update(i, new_value, index*2+1, mid+1, right)
        self.tree[index] = self.combine_range_results(left_range_result, right_range_result)
        return self.tree[index]

    cdef Result combine_range_results(self, Result r1, Result r2):
        cdef Result result;
        result.range_min = min(r1.range_min, r2.range_min)
        result.range_max = max(r1.range_max, r2.range_max)
        result.range_sum = r1.range_sum + r2.range_sum
        return result 
        

Вот версия для питона:




class PurePythonSegmentTree:
    def __init__(self, nums):
        self.nums = nums
        self.tree = [0] * (len(nums) * 4)
        self._build(1, 0, len(nums) - 1)

    def _build(self, index, left, right):
        if left == right:
            value = self.nums[left]
            self.tree[index] = (value, value, value)
            return self.tree[index]
        else:
            mid = (left + right) // 2
            left_range_result = self._build(index * 2, left, mid)
            right_range_result = self._build(index * 2 + 1, mid + 1, right)
            self.tree[index] = self._combine_range_results(
                left_range_result, right_range_result)
            return self.tree[index]

    def range_query(self, query_i, query_j):
        return self._range_query(query_i, query_j, 0, len(self.nums) - 1, 1)

    def _range_query(self, query_i, query_j, current_i, current_j, index):
        if current_i == query_i and current_j == query_j:
            return self.tree[index]
        else:
            mid = (current_i + current_j) // 2
            if query_j <= mid:
                return self._range_query(query_i, query_j, current_i, mid,
                                         index * 2)
            elif mid < query_i:
                return self._range_query(query_i, query_j, mid + 1, current_j,
                                         index * 2 + 1)
            else:
                left_range_result = self._range_query(query_i, mid, current_i,
                                                      mid, index * 2)
                right_range_result = self._range_query(mid + 1, query_j,
                                                       mid + 1, current_j,
                                                       index * 2 + 1)
                return self._combine_range_results(left_range_result,
                                                   right_range_result)

    def range_sum(self, query_i, query_j):
        return self.range_query(query_i, query_j)[0]

    def range_min(self, query_i, query_j):
        return self.range_query(query_i, query_j)[1]

    def range_max(self, query_i, query_j):
        return self.range_query(query_i, query_j)[2]

    def _combine_range_results(self, r1, r2):
        return (r1[0] + r2[0], min(r1[1], r2[1]), max(r1[2], r2[2]))


Код бенчмаркинга:

import pytest
from segment_tree import SegmentTree

def _test_all_ranges(nums, correct_fn, test_fn, threshold=float("inf")):
    count = 0
    for i in range(len(nums)):
        for j in range(i + 1, len(nums)):
            if count > threshold:
                break
            expected = correct_fn(nums[i:j + 1])
            actual = test_fn(i, j)
            assert actual == expected
            count += 1


def test_cython_tree_speed(benchmark):
    nums = [i for i in range(1000)]

    @benchmark
    def foo():
        s = SegmentTree(nums)
        _test_all_ranges(nums, max, s.range_max, 20)


def test_python_tree_speed(benchmark):
    nums = [i for i in range(1000)]

    @benchmark
    def foo():
        s = PurePythonSegmentTree(nums)
        _test_all_ranges(nums, max, s.range_max, 20)

Статистика:

-------------------------------------------------------------------------------------------- benchmark: 2 tests --------------------------------------------------------------------------------------------
Name (time in us)                 Min                   Max                  Mean              StdDev                Median                IQR            Outliers         OPS            Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_cython_tree_speed       708.0450 (1.0)      1,534.6150 (1.0)        739.7052 (1.0)       59.9436 (1.0)        717.7565 (1.0)      21.0070 (1.0)       116;200  1,351.8900 (1.0)        1290           1
test_python_tree_speed     1,625.1940 (2.30)     2,676.9020 (1.74)     1,696.8420 (2.29)     135.9121 (2.27)     1,644.7810 (2.29)     79.6613 (3.79)        36;37    589.3300 (0.44)        391           1
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Как сделать cythonized версию быстрее?


  • (это может принадлежать codereview.stackexchange.com). Даже не глядя на ваш код, я знаю, что это часто проблема в python. Вы пытались заменить рекурсивные вызовы стеком? Это будет обрабатывать логику рекурсии без фактического вызова функций. Вот лучший существующий ответ, который я нашел в кратком поиске: stackoverflow .com/questions/13591970/ 20.06.2020
  • Я предполагаю, что вы пытались сгенерировать аннотированный HTML-код Cython, чтобы увидеть, что, по его мнению, может быть медленным. Интересно, являются ли рекурсивные функции чем-то вроде отвлекающего маневра и на самом деле не являются вашей проблемой. 20.06.2020
  • @KennyOstrom Я не уверен, что узким местом являются рекурсивные вызовы. Версия фибоначчи для cython на порядок быстрее, чем версия для python noreferrer">blog.nelsonliu.me/2016/04/29/gsoc-week-0-cython-vs-python Само собой разумеется, что это должно быть более или менее верно и для этого, учитывая, что это тяжелая рекурсия 20.06.2020
  • Я бы беспокоился, что ваши вызовы len для вектора заставляют его преобразовываться в список, а затем, например, len для этого. 20.06.2020
  • @DavidW, хорошее место. Я изменил его на nums.size() (постоянное время), но статистика осталась почти такой же. 20.06.2020
  • Чтобы действительно сиять, cython нуждается в вводе в виде массива memoryview/(numpy). Я не проверял, но предполагаю, что узким местом версии cython является преобразование list в std::vector. 20.06.2020
  • @ead Спасибо за предложение - есть ли простые способы проверить наличие узких мест? 20.06.2020
  • Кроме того, я только что изменил тестовый тест, чтобы он проверял только вызовы запросов, а не вызов конструктора. Результат показывает лишь очень небольшое улучшение 20.06.2020
  • Предложение Дэвида по сборке с аннотациями — хороший первый шаг. 20.06.2020
  • Вы пробовали использовать директивы компилятора? В вашем случае, когда вы выполняете разделение, я рекомендую добавить @cython.cdivision(False) перед определением вашего класса. См. более подробную информацию здесь in-cython 21.06.2020

Ответы:


1

При попытке оптимизировать код cython первым шагом является сборка с аннотациями (см., например, эту часть Cython-документация), т.е.

 cython -a xxx.pyx

или похожие. Он генерирует html, в котором можно увидеть, какие части кода используют функциональность Python.

В вашем случае видно, что mid = (current_i + current_j)//2 это проблема.

Он генерирует следующий C-код:

  /*else*/ {
    __pyx_t_3 = __Pyx_PyInt_From_long(__Pyx_div_long((__pyx_v_current_i + __pyx_v_current_j), 2)); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 42, __pyx_L1_error)
    __Pyx_GOTREF(__pyx_t_3);
    __pyx_v_mid = __pyx_t_3;
    __pyx_t_3 = 0;

т.е. mid является целым числом Python (из-за __Pyx_PyInt_From_long), и все операции с ним приведут к большему преобразованию в целое число Python и медленным операциям.

Сделайте mid cdef int. Изучите другие желтые строки (взаимодействие с Python) в аннотированном коде.

20.06.2020
  • Спасибо за указатель! Я исправил это и заметил, что функции range_sum, range_min и range_max окрашены в темно-желтый цвет, что указывает на интенсивное взаимодействие с Python. Есть ли способ исправить это? 20.06.2020
  • Сделайте их cpdef, которые объявляют, что они возвращают тип (int), то же самое для комбинированной функции. Cython выберет функцию cdef без преобразования результата в целое число Python, что является дорогостоящим. 20.06.2020
  • Извините, я искал не ту версию. Тогда не беспокойтесь: поскольку они частично определены, требуется взаимодействие с Python. Сначала сконцентрируйтесь на коде. Как выглядит бенчмарк? 20.06.2020
  • они почти одинаковы, cython примерно в 1,18 раза быстрее 20.06.2020
  • Получение некоторого улучшения! Вместо того, чтобы хранить входной список в виде вектора, я изменил код так, чтобы он просто сохранял свою длину. Для того же теста это примерно в 4,5 раза быстрее. 20.06.2020
  • @ nz_21 Для некоторых ваших расчетов вы записываете промежуточные переменные (например, left_range_result). Им также должны быть присвоены типы cdef int, иначе вы собираетесь создавать объекты python в своих вызовах функций. 28.06.2020
  • Новые материалы

    Объяснение документов 02: BERT
    BERT представил двухступенчатую структуру обучения: предварительное обучение и тонкая настройка. Во время предварительного обучения модель обучается на неразмеченных данных с помощью..

    Как проанализировать работу вашего классификатора?
    Не всегда просто знать, какие показатели использовать С развитием глубокого обучения все больше и больше людей учатся обучать свой первый классификатор. Но как только вы закончите..

    Работа с цепями Маркова, часть 4 (Машинное обучение)
    Нелинейные цепи Маркова с агрегатором и их приложения (arXiv) Автор : Бар Лайт Аннотация: Изучаются свойства подкласса случайных процессов, называемых дискретными нелинейными цепями Маркова..

    Crazy Laravel Livewire упростил мне создание электронной коммерции (панель администратора и API) [Часть 3]
    Как вы сегодня, ребята? В этой части мы создадим CRUD для данных о продукте. Думаю, в этой части я не буду слишком много делиться теорией, но чаще буду делиться своим кодом. Потому что..

    Использование машинного обучения и Python для классификации 1000 сезонов новичков MLB Hitter
    Чему может научиться машина, глядя на сезоны новичков 1000 игроков MLB? Это то, что исследует это приложение. В этом процессе мы будем использовать неконтролируемое обучение, чтобы..

    Учебные заметки: создание моего первого пакета Node.js
    Это мои обучающие заметки, когда я научился создавать свой самый первый пакет Node.js, распространяемый через npm. Оглавление Глоссарий I. Новый пакет 1.1 советы по инициализации..

    Забудьте о Matplotlib: улучшите визуализацию данных с помощью умопомрачительных функций Seaborn!
    Примечание. Эта запись в блоге предполагает базовое знакомство с Python и концепциями анализа данных. Привет, энтузиасты данных! Добро пожаловать в мой блог, где я расскажу о невероятных..


    Для любых предложений по сайту: [email protected]