Как запустить CartPole со скоростью 1,25 миллиарда шагов в секунду

JAX — это относительно новая и захватывающая среда машинного обучения с открытым исходным кодом. Вот некоторые из замечательных особенностей:

  • Скомпилирован с использованием XLA, поэтому он может поддерживать процессоры, графические процессоры и TPU.
  • С помощью функции jit он может точно в срок компилировать несколько операций и оптимизировать график вычислений.
  • Автоматическая векторизация через vmap.
  • Автоматическое распараллеливание с помощью pmap.
  • Дифференциация высшего порядка
  • Обеспечивает функциональное программирование.

Все эти функции полезны, но первые 3 особенно полезны при написании сред RL. В настоящее время самым мощным ускорителем, доступным бесплатно через ноутбуки Colab или Kaggle, является TPU, но его часто сложно использовать, а JAX упрощает его. Ускорение отдельных операций помогает, но любая нетривиальная среда RL будет включать множество операций для каждого временного шага. Оптимизация всего графа с помощью компилятора XLA имеет большое значение. И я не думаю, что мог бы указать на что-то, что доставляло мне больше головной боли в глубоком RL, чем ручное управление пакетным измерением! Намного проще написать код для одного экземпляра и позволить JAX позаботиться о векторизации с помощью vmap. В качестве примера того, насколько эффективным может быть JAX при реализации среды RL, не смотрите дальше Бракс. Это симулятор твердого тела, написанный на JAX, способный ускорить классические среды Ant или Humanoid в 100–1000 раз по сравнению с эквивалентом Mujoco.

Для начала давайте рассмотрим, как выглядит простая среда JAX. В качестве поучительного примера можно попробовать конвертировать классическую среду Cartpole-v1 из openai gym. Наверное знакомый API выглядит так

env = gym.make("CartPole-v1")
obsv = env.reset()
obsv, reward, done, info = env.step(action)

После создания среды мы сбрасываем ее, чтобы получить начальное состояние, а затем применяем действие с функцией шага. Функция env.step имеет состояние, поскольку выходные данные зависят от внутреннего состояния среды. Функция env.reset также имеет состояние, поскольку она зависит от внутреннего состояния генератора случайных чисел (каждый раз при сбросе мы получаем другое состояние). Но мы уже говорили, что JAX позволяет использовать только функциональный код без каких-либо состояний или побочных эффектов 🤔

К счастью, в превосходной документации JAX есть решение, описанное в Вычисления с отслеживанием состояния в JAX. Примените это к каркасной среде с тренажерным залом, мы можем прийти к примерной реализации.

Среда предназначена для автоматического сброса, поскольку это упрощает развертывание в пакетных средах и не снижает удобство использования. Это эквивалентно использованию оберток gym.vector.

Функции

  • __init__ — это конструктор, используемый для установки значений в экземпляре среды, например self.random_limit. Это нормально, потому что мы никогда не меняли и не использовали as как константы, а не как состояния, поэтому вызовы среды по-прежнему работают.
  • _get_obsv — это частная функция, которая сопоставляет состояние с наблюдением, поскольку мы можем захотеть использовать наблюдения, отличные от необработанного состояния. Например, мы можем захотеть растрировать изображение окружающей среды и скрыть сжатое внутреннее состояние.
  • _reset — это закрытая функция, которая фактически генерирует новое случайное состояние. Это используется как функцией общедоступного сброса, так и функцией шага для условного сброса среды.
  • _maybe_reset используется для условного сброса среды. Это делается с помощью функции jax.lax.cond, которая позволяет компилировать ветвление с помощью jit (подробнее об этом позже).
jax.lax.cond(done, self._reset, lambda key: env_state, key)

эквивалентно

  if done:
    return self._reset(key)
  else:
    return env_state

Проблемы использования потока управления python в JAX хорошо документированы, но версия tl; dr заключается в том, что JAX использует трассировку и не знает, какой путь выбрать, когда используется поток управления python. Эта реализация кажется более сложной, но производительность того стоит.

После использования ключа jax.random.split функционально генерирует новый ключ, который мы возвращаем для последующих случайных операций.

  • reset — это общедоступная функция, которая принимает ключ и создает новое случайное состояние с помощью функции _reset. состояние и ключ объединяются в кортеж с именем env_state. env_state — это полное состояние среды. Когда я впервые начал этот проект, я думал, что состояния достаточно, но по мере продвижения я понял, что сам ключ также является частью состояния, поскольку, когда среда достигает конечного состояния, она сбрасывается и создается новый ключ. генерируется. Это также означает, что для данного начального env_state мы производим одно и то же развертывание, даже если среда перезагружается несколько раз. Возвращаются env_state и initial_observation, определенные с помощью функции _get_obsv.
  • step — это общедоступная функция, которая принимает env_state и действие. Он применяет преобразование для улучшения внутреннего состояния. Затем он вычисляет значения reward и done и использует функцию _maybe_reset для условного сброса env_state на основе done значение. Наконец, обновленные значения env_state, obsv, reward и done возвращаются с помощью obsv снова получается из функции _get_obsv.

Применение

Использование можно увидеть в Gist выше под определением класса. Давайте пройдемся по нему. Сначала мы создаем ключ с помощью функции jax.random.PRNGKey. Причина, по которой нам нужно сделать это явно, заключается в том, что для того, чтобы процесс генерации случайных чисел работал (в данном случае random.uniform), его вывод должен зависеть от ключа, т. е. один и тот же ключ = один и тот же случайный вывод.

Затем мы создаем среду, которая устанавливает константы.

Функция reset используется с ранее созданным ключом для определения начального env_state (содержащего как состояние, так и обновленный ключ). strong>) и получите initial_obsv.

Функция step теперь может быть вызвана с использованием env_state и произвольно выбранного действия (1) для улучшения среды.

Основное изменение, которое мы внесли в API тренажерного зала, заключается в том, что все функции в среде теперь требуют, чтобы состояние управлялось извне и передавалось или возвращалось, чтобы обеспечить сохранение функциональной парадигмы. Это может показаться обременительным, но у него есть то преимущество, что поведение функции детерминировано в зависимости от ее входных данных.

ДЖАКС Картпол

Теперь у нас есть каркасное окружение, но оно не очень интересное. Преобразовав среду Cartpole в JAX, используя среду скелета в качестве шаблона, мы можем прийти к следующему:

Разберем изменения.

  1. Мы заменяем поток управления python на эквивалентный набор операций, когда можем, поскольку поток управления python имеет только ограниченную поддержку.
force = self.force_mag if action == 1 else -self.force_mag
---
force = self.force_mag * (2 * action - 1)

2. Убираем побочные эффекты из функций reset и step, чтобы они были строго функциональны.

# Removed
assert self.action_space.contains(action), err_msg
log.warn

операторы печати являются побочными эффектами, поскольку они не возвращаются как часть функции.

3. Мы заменяем несовместимые математические операции эквивалентами jax.numpy.

costheta = math.cos(theta)
sintheta = math.sin(theta)
---
costheta = jnp.cos(theta)
sintheta = jnp.sin(theta)

Кажется, это необходимо, потому что JAX отслеживает ShapedArray, а математическиефункции применяют функцию float, которая требует конкретного значения.

4. Замените элемент управления python для вычисления done эквивалентным набором операций.

done = bool( x < -self.x_threshold 
             or x > self.x_threshold
             or theta < -self.theta_threshold_radians
             or theta > self.theta_threshold_radians)
---
done = ((x < -self.x_threshold) 
         | (x > self.x_threshold)
         | (theta > self.theta_threshold_radians)
         | (theta < -self.theta_threshold_radians))

5. Установите вознаграждение всегда равным 1, так как среда автоматически сбрасывается.

6. jit сброс и пошаговая функция с помощью @partial(jit, static_argnums=(0,)). Нам нужно указать static_argnums=(0,), потому что первый аргумент self считается статическим (постоянным). Код можно запустить без jit, но он очень медленный. В частности, jax.lax.cond работает медленно, поэтому без jit поток управления python работает быстрее. Это компилирует код в статически типизированный язык выражений под названием jaxpr, который затем компилируется в граф выполнения с помощью XLA.

Используя make_jaxpr, мы можем визуализировать, как выглядят скомпилированные операции. Ниже приведен пример функции env.step. Для контекста константы берутся из аргумента self.

y = mul p w
z = sub x y
ba = integer_pow[ y=2 ] p
bb = mul ba 0.10000000149011612
bc = div bb 1.100000023841858
bd = sub 1.3333333730697632 bc
be = mul bd 0.5
bf = div z be
bg = mul bf 0.05000000074505806

Эта реализация позволяет нам запускать среду для одной среды так же, как и скелетную среду.

Пакетная среда

Пропускная способность ограничена одним экземпляром среды, поэтому мы хотим запускать несколько экземпляров параллельно. Один из способов добиться этого — использовать многопроцессорность и запускать каждый новый экземпляр в новом процессе, однако это имеет тенденцию к плохому масштабированию, и гораздо эффективнее, если мы можем векторизовать реализацию. Векторизация кода может оказаться непростой задачей, потому что вам нужно учитывать размер пакета во всех выполняемых вами операциях, к счастью, JAX делает это намного проще с помощью vmap. С помощью этой функции мы можем автоматически векторизовать функции reset и step. Теперь мы можем пройтись по множеству экземпляров среды без изменения кода!

Вот как это выглядит. Сначала мы берем функции окружения и обертываем их vmap и jit. Функция vmapped выигрывает от джиттинга, потому что это позволяет xla оптимизировать операцию в пакетном измерении.

vstep = jit(jax.vmap(env.step, in_axes=((0, 0), 0), out_axes=((0, 0), 0, 0, 0, 0), axis_name="batch_axis"))
vreset = jit(jax.vmap(env.reset, out_axes=((0, 0), 0), axis_name="batch_axis"))

in_axes и out_axes определяют пакетное измерение каждого из массивов jnp в аргументах.

Теперь вместо того, чтобы передавать один ключ в reset, мы передаем столько ключей, сколько нам нужно для сред.

NUM_ENV = 10
seed = 0
key = jax.random.PRNGKey(seed)
keys = random.split(key, NUM_ENV)
env_state, obsv = vreset(keys)

Затем мы получаем векторизованное env_state из сброса, которое мы можем использовать вместе с векторизованным действием, которое аналогичным образом имеет векторизованные возвращаемые значения.

action = random.randint(keys[0], (NUM_ENV,), 0, 2)
new_env_state, obsv, reward, done, info = vstep(env_state, action)

Пакетная среда на нескольких устройствах

Полное использование TPU или нескольких GPU означает работу с несколькими устройствами. Нам также нужен способ векторизации по размеру устройства. JAX предоставляет эту функциональность через pmap. Как и в случае с vmap, мы можем просто обернуть наши jit-функции vmapped с помощью pmap. В этом случае нам не нужно снова применять jit, потому что pmap автоматически выполняет jit для функции.

pvreset = jax.pmap(vreset, out_axes=((0, 0), 0), axis_name="device_axis")
pvstep = jax.pmap(vstep, in_axes=((0, 0), 0), out_axes=((0, 0), 0, 0, 0, 0), axis_name="device_axis")

Здесь может показаться запутанным то, что аргументы для pmap выглядят так же, как и vmap, из-за чего может показаться, что мы используем одно и то же измерение как для устройства, так и для пакета. В этом случае это работает так: pmapped reset и step функции принимают jnp.array с размерами [device_axis, batch_axis, …] первое измерение в качестве оси устройства, которое удаляется pmap, поэтому vmap видит только 2- размерный jnp.array [batch_axis, …].

Скомпилированный выпуск

До этого момента мы компилировали шаг окружения и функции сброса. Это означает, что мы выполняем каждый шаг на процессоре или ускорителе, прежде чем вернуть управление ядру Python. Существует стоимость задержки, связанная с диспетчеризацией выполнения. Я заметил, что это очень мало для процессора, меньше миллисекунды для графического процессора, но порядка нескольких миллисекунд для TPU. Это вынуждает очень большие размеры пакетов обеспечивать хорошую пропускную способность ускорителей, особенно TPU.

Чтобы решить эту проблему, мы можем попытаться выполнить jit-компиляцию всего развертывания. Таким образом, ядро ​​Python просто запускает развертывание среды и возвращает данные для всего развертывания. Этот подход должен быть жизнеспособным для алгоритмов RL, которые имеют этап сбора данных и этап обучения, например. Проксимальная оптимизация политики (PPO) или Q-Learning.

В этом примере мы случайным образом выбираем действия вместо модели, но не должно быть проблем с использованием нейронной сети на основе Flax в той же функции.

См. пример реализации ниже.

Мы определяем функцию rollout, которая сбрасывает среду, выполняет фиксированное количество итераций и возвращает данные. Обратите внимание, что вместо того, чтобы использовать цикл python for, как обычно, мы используем эквивалент потока управления JAX jax.lax.fori. На самом деле здесь можно использовать поток управления python, но компиляция занимает очень много времени, поскольку необходимо отслеживать каждый шаг в цикле. Основной недостаток, который я вижу в использовании потока управления JAX, заключается в том, что его сложнее читать.

jax.lax.fori хорошо документирован, так что это должно помочь понять реализацию цикла. Другие вещи, которые стоит отметить здесь:

  1. Мы предварительно выделяем массивы obsv, вознаграждения и выполненных работ и на каждом этапе функционально создаем новый массив. Это звучит неэффективно, но думайте об этом как о предоставлении функциональных операций, которые вы хотите получить, а затем позволяющих компилятору XLA реализовать это производительным способом.
  2. Мы можем напрямую использовать pmap и vmap в качестве декораторов без аргументов, потому что по умолчанию они не предполагают статических аргументов и первое измерение в pmapped или vmapped соответственно, которое для аргумента keys является правильным.

Ориентиры

На приведенной ниже диаграмме показано количество шагов в секунду, достигнутое для ряда сред (NUM_ENV), начиная с 1.

Базовый уровень — Openai Gym Cartpole @ 94,2 тыс. шагов в секунду.

Здесь стоит подчеркнуть, что все тесты выполнялись в Colab, включая тесты ЦП (то есть только одно физическое ядро). Как упоминалось ранее, среды тренажерного зала можно запускать параллельно с использованием многопроцессорной обработки или аналогичной, но это работает только в том случае, если у вас есть циклы ЦП или ядра для запасной, чего, к сожалению, нет в Colab.

Версия графического процессора была k80, а версия TPU — V2.

Блокнот Colab, используемый для сбора результатов теста, доступен здесь.

Обсуждение

Для эквивалентной версии JAX с одной средой на ЦП мы видим 95 тыс. шагов в секунду, что в основном эквивалентно базовому уровню. На GPU и TPU это намного медленнее. Увеличение размера пакета до 100 насыщает производительность ЦП при 4,7 млн ​​шагов в секунду. GPU и TPU могут достигать гораздо более высокой пропускной способности (> 400 млн шагов в секунду) при очень большом NUM_ENV (> 1 млн).

Результаты для скомпилированного развертывания более интересны, когда развертывание ЦП достигает 1 млн шагов в секунду. Это в 10 раз быстрее, чем в среде Gym или векторизованной среде JAX при пошаговом ручном управлении. Графический процессор имеет только «незначительное» улучшение в 2 раза при том же размере пакета. Производительность TPU значительно улучшается (в 400 раз), что, вероятно, связано с тем, что задержка диспетчеризации теперь является второстепенным фактором, в то время как она была доминирующим фактором при переходе в среду. Это масштабируется до NUM_ENV=800k на 8 устройствах, прежде чем OOM достигает невероятных 1,25G шагов в секунду.

Заключение

Мы увидели, как вы можете написать среду RL в JAX и как ее можно ускорить как последовательно, скомпилировав развертывание, так и эффективно распараллеливая с помощью vmap и pmap. JAX может ускорить глубокие исследования в области RL, буквально ускоряя среду, чтобы сократить время итерации, и упрощая использование ускорителей, таких как TPU.

Желаем удачи в ваших экспериментах!