NumPy на ускорителях, компиляция JIT, XLA, оптимизированные ядра и автоматическая дифференциация.

Здравствуйте, любители НЛП! Сегодня я наконец попробовал JAX, библиотеку, которая позволяет очень быстро выполнять операции линейной алгебры благодаря поддержке ускорителей, таких как GPU и TPU, которые являются обычными операциями для проектов машинного обучения. Наслаждаться! 😄

Что такое ДЖАКС

JAX — это библиотека Python, созданная Google для оптимизации научных вычислений:

  • Его можно рассматривать как альтернативу NumPy, предоставляя очень похожий интерфейс, который также работает на GPU и TPU. JAX предоставляет jax.numpy, который очень похож на NumPy API, что упрощает вход в библиотеку. Почти все, что можно сделать с помощью numpy, можно сделать и с помощью jax.numpy.
  • Он работает на ускорителях (например, GPU и TPU) благодаря JIT (Just-In-Time) компиляции кода Python и JAX с XLA (Accelerated Linear Algebra, компилятор) в оптимизированные ядра. Ядро — это процедура, скомпилированная для высокопроизводительных ускорителей (например, GPU и TPU), отдельная от основной программы, но используемая ею. Компиляции JIT можно запускать с помощью jax.jit() .
  • Он отлично поддерживает автоматическое дифференцирование, полезное для исследований в области машинного обучения. Автоматическое дифференцирование можно запустить с помощью jax.grad() .
  • JAX поощряет функциональное программирование, поскольку его функции являются чистыми. В отличие от массивов NumPy, массивы JAX всегда неизменяемы.
  • JAX поставляется с несколькими программными преобразованиями, полезными при написании числового кода, такими как jit.jax() для JIT-компиляции и ускорения кода, jit.grad() для получения производных и jit.vmap() для автоматической векторизации или пакетной обработки.
  • JAX имеет асинхронную отправку. Это означает, что вам нужно вызвать .block_until_ready(), чтобы убедиться, что вычисление действительно произошло. Асинхронная диспетчеризация полезна, поскольку позволяет коду Python бежать впереди устройства-ускорителя, не допуская попадания кода Python на критический путь.

Есть два способа, которыми JAX использует JIT-компиляцию:

  • Автоматически: JIT-компиляция по умолчанию происходит «под капотом» при выполнении библиотечных вызовов функций JAX.
  • Вручную: вы можете вручную запросить JIT-компиляцию ваших собственных функций Python, используя jax.jit().

Давайте посмотрим несколько примеров кода!

Кодирование с помощью JAX

Мы можем установить библиотеку с помощью pip.

Затем давайте также импортируем его с помощью NumPy, чтобы мы могли выполнить некоторые тесты.

Подобно тому, как мы обычно делаем import numpy as np, мы можем сделать import jax.numpy as jnp и заменить все np в нашем коде на jnp. Если код NumPy был написан в стиле функционального программирования, новый код JAX будет работать сразу после установки. Однако, если ускорители доступны, они будут использоваться для запуска кода JAX: приятная разница!

Случайные числа в JAX генерируются иначе, чем в NumPy. С JAX нам нужно создать файл jax.random.PRNGKey . Позже мы увидим, как его использовать.

Давайте проведем простой бенчмарк в Google Colab, чтобы у нас был легкий доступ к GPU и TPU. Начнем с инициализации случайной квадратной матрицы с 25 млн элементов и умножения ее на ее транспонирование. С NumPy, оптимизированным для ЦП, умножение матриц заняло в среднем 1,61 секунды.

Выполнение той же операции с JAX на ЦП заняло в среднем около 3,49 секунды.

JAX часто медленнее, чем NumPy, при работе на ЦП, поскольку NumPy очень оптимизирован для этого. Однако это меняется при использовании ускорителей, поэтому давайте попробуем матричное умножение с помощью графического процессора.

Как видно из примера, для добросовестного сравнения нам нужно измерять разные шаги с помощью JAX:

  • Время передачи устройства: время, затраченное на передачу матрицы на графический процессор. Это заняло 0,155 миллисекунды.
  • Время компиляции: время, прошедшее для JIT-компиляции. Это заняло 2,16 секунды.
  • Время выполнения: эффективное время выполнения кода. Это заняло 68,9 миллисекунды.

Общее затраченное время для умножения одной матрицы с помощью JAX на графическом процессоре составляет около 2,23 секунды, что больше, чем общее время NumPy, составляющее 1,61 секунды. Однако для каждого дополнительного матричного умножения JAX потребовалось бы всего 68,9 миллисекунды по сравнению с 1,61 секундой NumPy, что более чем в 22 раза быстрее! По этой причине имеет смысл использовать JAX, если вы можете многократно использовать ускорители для выполнения операций линейной алгебры.

Давайте также попробуем матричное умножение с TPU.

Игнорируя время передачи устройства и время компиляции, каждое матричное умножение занимает в среднем 16,5 миллисекунд: ускорение x4 по сравнению с графическими процессорами и ускорение x88 по сравнению с процессорами с NumPy. Обратите внимание, что вы не всегда будете получать одинаковое ускорение при перемножении матриц разных размеров: чем больше перемножаемые матрицы, тем больше ускорители могут оптимизировать операции и тем больше ускорение.

Чтобы воспроизвести приведенные выше тесты в Google Colab, вам нужно запустить следующий код, чтобы сообщить JAX о наличии доступного TPU.

Вернемся к теории и поговорим подробнее о компиляторе XLA.

XLA

XLA — это предметно-ориентированный компилятор для линейной алгебры, используемый JAX (и другими библиотеками или фреймворками, такими как TensorFlow), который создает настраиваемые оптимизированные ядра для максимально быстрого выполнения операций линейной алгебры в вашей программе. Лучшее, что вы можете сделать с XLA, — это позволить ему создать собственное ядро ​​для большей части вашей программы, которая использует операции линейной алгебры, чтобы она могла выполнять наибольшую оптимизацию. Позже мы увидим, как это сделать.

Наиболее важной оптимизацией XLA является слияние, то есть выполнение нескольких операций линейной алгебры в одном и том же ядре с сохранением промежуточных выходных данных в регистрах графического процессора без их материализации в памяти. Это может резко увеличить нашу «арифметическую интенсивность», то есть соотношение выполненной работы и количества загрузок и запасов, которые мы делаем. Слияние также может позволить нам полностью исключить операции, которые просто перетасовывают элементы в памяти (например, изменяют форму).

Давайте посмотрим, как вручную запустить JIT-компиляцию с помощью XLA с jax.jit.

Своевременная компиляция с помощью jax.jit

Давайте посмотрим на новые тесты, чтобы узнать, как использовать jax.jit и как он работает. Мы определяем две функции, которые реализуют SELU (масштабируемую экспоненциальную линейную единицу): одну с NumPy и одну с JAX. Давайте забудем о jax.jit на данный момент.

Затем мы запускаем его на векторе из 1 млн элементов, используя NumPy.

В среднем это занимает 7,6 миллисекунды. Давайте попробуем теперь с JAX на процессоре.

Теперь это занимает в среднем 4,8 миллисекунды, что в данном случае оказывается быстрее, чем NumPy. Следующий тест — с JAX на графическом процессоре.

Время работы функции составляет 1,21 миллисекунды, что даже быстрее. Теперь давайте проверим это с помощью jax.jit, то есть мы запустим JIT-компилятор для компиляции нашей функции SELU с использованием XLA в оптимизированное ядро ​​графического процессора, оптимизируя все операции внутри функции вместе.

При использовании нового ядра время работы функции составляет 0,13 миллисекунды!

Давайте подытожим достигнутое время работы:

  • NumPy на ЦП: 7,6 мс.
  • JAX на ЦП: 4,8 миллисекунды (ускорение x1,58 по сравнению с NumPy).
  • JAX на графическом процессоре без JIT: 1,21 миллисекунды (ускорение x6,28 по сравнению с NumPy).
  • JAX на графическом процессоре с JIT: 0,13 миллисекунды (ускорение x58,46 по сравнению с NumPy).

Использование JIT-компиляции дало нам действительно большое ускорение, основная причина которого в том, что мы избегаем перемещения данных из регистров GPU. В общем, перемещение данных между разными типами памяти происходит очень медленно по сравнению с выполнением кода, поэтому при оптимизации его следует избегать!

Опять же, обратите внимание, что вы можете получить разное ускорение при применении функции SELU к векторам разных размеров. Чем больше вектор, тем больше ускорители могут оптимизировать операции и тем больше ускорение.

Вместо выполнения selu_jax_jit = jit(selu_jax) функции могут быть JIT-компилированы также с использованием декоратора @jit следующим образом.

Если JIT-компиляция дает нам такое ускорение, почему бы нам не делать это всегда? Проблема в том, что не весь код может быть скомпилирован JIT, так как он требует, чтобы формы массива были статическими и известными во время компиляции. Более того, jax.jit сам по себе вводит некоторые накладные расходы. Поэтому обычно это экономит время только в том случае, если скомпилированная функция сложна и вы будете запускать ее много раз. К счастью, это распространено в машинном обучении, где мы склонны компилировать большую сложную модель, а затем запускать ее в течение миллионов итераций.

Давайте также посмотрим на автоматическую дифференциацию.

Автоматическое дифференцирование с помощью jax.grad

Еще одно преобразование JAX — автоматическое дифференцирование с функцией jit.grad().

Благодаря обновленной версии Autograd JAX может автоматически различать собственный код Python и NumPy. Он отличается большим подмножеством функций Python, включая циклы, операторы if, рекурсию и замыкания.

Давайте посмотрим на пример кода с jit.grad(), где мы вычисляем производную пользовательской функции Python с некоторыми функциями JAX внутри.

JAX поддерживает дифференциацию как обратного, так и прямого режима, и они могут быть составлены произвольно в любом порядке.

Выводы и дальнейшие шаги

В этой статье мы узнали, что такое JAX, и увидели некоторые из его основных концепций: интерфейс NumPy, компиляцию JIT, XLA, оптимизированные ядра, преобразования программ, автоматическое дифференцирование и функциональное программирование. Помимо JAX, сообщество разработчиков открытого исходного кода создало более высокоуровневые библиотеки для машинного обучения, такие как Flax и Haiku.

Возможные следующие шаги:

  • Следуйте Учебнику по JAX на его страницах документации.
  • Обучите модель машинного обучения с помощью Flax.
  • Обучите модель машинного обучения с помощью Haiku.

Спасибо за чтение! Если вы хотите узнать больше о НЛП, не забудьте подписаться на NLPlanet на Medium, LinkedIn, Twitter и присоединяйтесь к нашему новому Discord серверу!