Параллелизм данных в JAX достигается с помощью функции pmap.

TL;DR

1. params и opt_state должны быть реплицированы на всех устройствах:

replicated_params = jax.device_put_replicated(params, jax.devices())

2. data и labels нужно разделить на устройства:

n_devices = jax.device_count()
batch_size, *data_shapes = data.shape
assert batch_size % n_devices == 0, 'The data cannot be split evenly to the devices'
data = data.reshape(n_devices, batch_size // n_devices, *data_shapes)

3. Украсьте целевую функцию jax.pmap:

@partial(jax.pmap, axis_name='num_devices')

4. В функции loss используйте jax.lax.pmean для вычисления среднего значения по устройствам:

grads = jax.lax.pmean(grads, axis_name='num_devices')  # calculate mean across devices

Объяснение

Напомним, что когда мы занимаемся глубоким обучением, нам нужно сделать следующее:

  1. потеря: (params, data, labels) -> (grads)
  2. обновление: (grads, opt_state, params) -> (updates, opt_state')
  3. применить_обновления: (params, updates) -> (params')

Добавление pmean к функции loss означает, что мы вычисляем среднее значение all-reduce для grads по устройствам:

Если вы хотите узнать больше о TPU и JAX, не забудьте ознакомиться с поддержанным сообществом подробным руководством по TPU: https://github.com/ayaka14732/tpu-starter.