Недавно я играл с Free Monads в Scala и обнаружил, что это может быть идеальным способом выполнения градиентных вычислений в стиле функционального программирования. По-видимому, вычисление градиентов с использованием Free Monads - не лучшая идея с точки зрения производительности, но она может быть очень полезной в образовательных и экспериментальных целях. Начать создание простых нейронных сетей не составляет большого труда, если у вас есть способ вычисления градиентов для произвольного выражения в вашем распоряжении.

Бесплатная монада - идеальный способ построить любой вид абстрактного синтаксического дерева (AST), представляющего вычисление, и в то же время сохранить AST вычисления независимо от способа его интерпретации.

Моя цель - продемонстрировать, как можно построить простой движок вычисления градиента с использованием Free Monads. Во-первых, мы собираемся определить модель предметной области для представления AST. Затем может быть определена свободная монада, представляющая вычисление. Наконец, мы сможем вычислять градиенты аналитически и численно, используя разные интерпретаторы, и сравнивать результаты (должны быть одинаковыми). Кроме того, мы сможем определить простой оптимизатор градиентного спуска, способный решать простые уравнения, определенные в терминах вычислений Free Monad. Вот репо с продемонстрированным здесь кодом.

Представление вычислений AST

Нам нужен способ представить вычисление как AST. Мы можем представить его как граф, в котором ребра представляют собой тензоры, входящие и исходящие в / из вершин, представленных операциями. Есть два вида ребер: переменные и константы:

И несколько видов операций, представляющих вершины графа AST:

Определив ребра (тензоры) и вершины (операции) вычислительного графа, мы можем представить произвольное вычисление, построенное на основе набора предопределенных примитивных операций.

Я использую здесь термин тензор. Это просто математическая абстракция набора значений различной формы. Скаляр - это 0-мерный тензор. Вектор - это одномерный тензор. Матрица - это двумерный тензор. А все, что имеет более высокую размерность, называется просто n-мерным тензором. В этом примере я использую 0-мерные тензоры, также известные как скаляры.

Кроме того, удобно определить два дополнительных типа:

Бесплатная монада вычислений

Следующим шагом является определение бесплатной монады вычислений. Я использую здесь библиотеку Scala для кошек:

Я не собираюсь здесь углубляться в то, что такое бесплатные монады. Есть классные статьи здесь и здесь. По сути, просто имея Op [A], мы можем поднять его в монадический контекст, используя свободную монаду. В результате у нас есть способ объединить Op [A] в монадическом стиле. Эта монадическая композиция построена так, что она безопасна для стека и может быть интерпретирована отдельно от места, где она определена. Это, в свою очередь, означает, что мы можем применять несколько интерпретаторов к одному и тому же вычислительному выражению.

Теперь можно определить вычисление в терминах монады, свободной от вычислений, используя синтаксис Scala для понимания:

Приведенная выше функция принимает карту, где ключ - это имя переменной или константы, а значение - сама переменная / константа. Выражение принимает три переменные (x1, x2 и x3) и одну константу (c1).

Вот визуальное представление вычисления:

Числовой интерпретатор градиента

Числовой градиент здесь используется в смысле, описанном в этой статье в Википедии. Когда у нас есть вычислительные выражения, которые зависят от нескольких входных переменных, тогда могут быть вычислены частные производные для каждой входной переменной. Самый простой способ вычислить такую ​​частную производную - просто передать начальный набор значений в вычисление и получить выходной результат, а затем передать тот же начальный набор значений, но со значением, увеличенным на небольшую дельту для переменной, для которой мы вычисляем производную. . Рассчитав два выходных значения, мы можем вычесть их и разделить на дельту, использованную на предыдущем шаге. По определению, это будет частная производная.

Мы можем сразу же попробовать это для простого выражения:

Кажется, это работает нормально, но с точки зрения производительности это не очень хорошо. Если у нас есть огромные входные тензоры, мы должны выполнить интерпретацию дважды для каждого отдельного скаляра (элемента) в тензоре. Было бы намного лучше, если бы мы могли сначала вычислить производные аналитически, а затем вычислить производные за один проход, используя векторизованные тензорные операции (подробнее об этом здесь).

Аналитический интерпретатор градиента

Похоже, что мы можем применить цепное правило к вычислительному графу и вычислить частные производные для всех входных переменных за один прогон интерпретатора (это известно как обратное распространение в приложениях для нейронных сетей). Вот что делает интерпретатор:

Как видите, он имеет более сложную реализацию, но он способен вычислять все частные производные за один прогон интерпретации. Давайте попробуем и посмотрим, соответствует ли он тому, что было вычислено с помощью 6 различных выполнений числового интерпретатора:

Как видите, не только результат совпадает с результатом, полученным с помощью числового интерпретатора градиента, но и точность намного выше, и она была рассчитана за один запуск интерпретатора.

Оптимизатор спуска градиента

Имея в своем распоряжении интерпретаторы вычисления градиента, мы можем легко создать оптимизатор, который использует спуск градиента для минимизации значения функции. Это очень полезно при решении уравнений или обучении модели машинного обучения для данной функции стоимости.

Давайте немедленно попробуем это для того же выражения и посмотрим, работает ли оно:

Как видите, оптимизатор нашел такой набор входных переменных, поэтому значение выражения почти равно нулю, что является минимально возможным значением для неотрицательного по определению выражения.

Заключение

Я надеюсь, что вам также покажется забавным поиграть с частными производными и градиентным спуском, выраженным функционально с использованием свободных монад. Раньше я реализовывал тот же подход с использованием Python, но реализация была более громоздкой.