Параллелизм данных в JAX достигается с помощью функции pmap
.
TL;DR
1. params
и opt_state
должны быть реплицированы на всех устройствах:
replicated_params = jax.device_put_replicated(params, jax.devices())
2. data
и labels
нужно разделить на устройства:
n_devices = jax.device_count() batch_size, *data_shapes = data.shape assert batch_size % n_devices == 0, 'The data cannot be split evenly to the devices' data = data.reshape(n_devices, batch_size // n_devices, *data_shapes)
3. Украсьте целевую функцию jax.pmap
:
@partial(jax.pmap, axis_name='num_devices')
4. В функции loss
используйте jax.lax.pmean
для вычисления среднего значения по устройствам:
grads = jax.lax.pmean(grads, axis_name='num_devices') # calculate mean across devices
Объяснение
Напомним, что когда мы занимаемся глубоким обучением, нам нужно сделать следующее:
- потеря:
(params, data, labels) -> (grads)
- обновление:
(grads, opt_state, params) -> (updates, opt_state')
- применить_обновления:
(params, updates) -> (params')
Добавление pmean
к функции loss
означает, что мы вычисляем среднее значение all-reduce для grads
по устройствам:
Если вы хотите узнать больше о TPU и JAX, не забудьте ознакомиться с поддержанным сообществом подробным руководством по TPU: https://github.com/ayaka14732/tpu-starter.