Введение
Вероятностные модели диффузионного шумоподавления (DDPM) — это глубокие генеративные модели, которым в последнее время уделяется много внимания благодаря их впечатляющим характеристикам. Совершенно новые модели, такие как генераторы OpenAI DALL-E 2 и Google Imagen, основаны на DDPM. Они обуславливают генератор текстом таким образом, что становится возможным генерировать фотореалистичные изображения на основе произвольной строки текста.
Например, введите «Фотография собаки сиба-ину с рюкзаком на велосипеде. На нем солнцезащитные очки и пляжная шляпа» для новой модели Imagen и «голова корги в виде взрыва туманности» для Модель DALL-E 2 выдает следующие изображения:
Эти модели просто умопомрачительны, но понимание того, как они работают, требует понимания оригинальной работы Хо и др. др. «Вероятностные модели шумоподавления».
В этом коротком посте я сосредоточусь на создании с нуля (в PyTorch) простой версии DDPM. В частности, я буду повторно реализовывать «оригинальную статью Хо. И другие". Мы будем работать с классическими и нетребовательными к ресурсам наборами данных MNIST и Fashion-MNIST и попытаемся создать изображения из воздуха. Начнем с небольшой теории.
Вероятностные модели диффузии шумоподавления
Вероятностные модели диффузии шумоподавления (DDPM) впервые появились в этой статье.
Идея довольно проста: к набору данных изображений мы добавляем немного шума шаг за шагом. С каждым шагом изображение становится все менее и менее четким, пока не останется только шум. Это называется «прямой процесс». Затем мы изучаем модель машинного обучения, которая может отменить каждый из таких шагов, и называем это «обратным процессом». Если мы сможем успешно изучить обратный процесс, у нас будет модель, которая может генерировать изображения из чистого случайного шума.
Шаг в прямом процессе состоит в том, чтобы сделать входное изображение более зашумленным (x на шаге t) путем выборки из многомерного гауссовского распределения, среднее значение которого является уменьшенной версией предыдущего изображения (x на шаге t-1), и какая ковариационная матрица является диагональным и фиксированным. Другими словами, мы возмущаем каждый пиксель изображения независимо, добавляя некоторое нормально распределенное значение.
Для каждого шага существует свой бета-коэффициент, который говорит, насколько сильно мы искажаем изображение на этом шаге. Чем выше бета, тем больше шума добавляется к изображению. Мы можем выбирать бета-коэффициенты, но мы должны стараться, чтобы не было шагов, на которых сразу добавляется слишком много шума, а общий процесс продвижения вперед должен быть «плавным». В оригинальной работе Ho et. др., бета расположены в линейном пространстве от 0,0001 до 0,02.
Замечательным свойством гауссовского распределения является то, что мы можем делать выборку из него, добавляя к среднему вектору нормально распределенный вектор шума, масштабированный по стандартному отклонению. Это приводит к:
Теперь мы знаем, как получить следующую выборку в прямом процессе, просто масштабируя то, что у нас уже есть, и добавляя немного масштабированного шума. Если мы теперь учтем, что формула рекурсивна, мы можем написать:
Если мы продолжим делать это и сделаем некоторые упрощения, мы можем вернуться назад и получить формулу для получения зашумленной выборки на шаге t, начиная с исходного незашумленного изображения x0:
Большой. Теперь независимо от того, сколько шагов будет иметь наш прямой процесс, у нас всегда будет способ напрямую получить зашумленное изображение на шаге t непосредственно из исходного изображения.
Мы знаем, что для обратного процесса наша модель также должна работать как распределение Гаусса, поэтому нам просто нужна модель для прогнозирования среднего значения распределения и стандартного отклонения с учетом зашумленного изображения и временного шага. На практике в этой первой статье о DDPM ковариационная матрица остается фиксированной, поэтому мы действительно хотим предсказать только среднее значение гауссова (учитывая зашумленное изображение и временной шаг, на котором мы находимся в настоящее время):
Теперь оказывается, что оптимальное среднее значение, которое нужно предсказать, — это просто функция условий, с которыми мы уже знакомы:
Итак, мы можем еще больше упростить нашу модель и просто предсказать эпсилон шума с помощью функции зашумленного изображения и временного шага.
И наша функция потерь будет просто масштабированной версией среднеквадратичной ошибки (MSE) между реальным шумом, который был добавлен, и шумом, предсказанным нашей моделью.
После обучения модели (алгоритм 1) мы можем использовать модель шумоподавления для выборки новых изображений (алгоритм 2).
Приступаем к кодированию
Теперь, когда у нас есть общее представление о том, как работают модели распространения, пришло время реализовать что-то свое. Вы можете самостоятельно запустить следующий код в этом Блокноте Google Colab или в этом репозитории GitHub.
Как обычно, импорт — это наш первый шаг.
# Import of libraries import random import imageio import numpy as np from argparse import ArgumentParser from tqdm.auto import tqdm import matplotlib.pyplot as plt import einops import torch import torch.nn as nn from torch.optim import Adam from torch.utils.data import DataLoader from torchvision.transforms import Compose, ToTensor, Lambda from torchvision.datasets.mnist import MNIST, FashionMNIST # Setting reproducibility SEED = 0 random.seed(SEED) np.random.seed(SEED) torch.manual_seed(SEED) # Definitions STORE_PATH_MNIST = f"ddpm_model_mnist.pt" STORE_PATH_FASHION = f"ddpm_model_fashion.pt"
Далее мы определяем несколько параметров для нашего эксперимента. В частности, мы решаем, хотим ли мы запустить цикл обучения, хотим ли мы использовать набор данных Fashion-MNIST и некоторые гиперпараметры обучения.
no_train = False fashion = True batch_size = 128 n_epochs = 20 lr = 0.001 store_path = "ddpm_fashion.pt" if fashion else "ddpm_mnist.pt"
Далее нам очень хотелось бы отображать изображения. Нас интересуют как обучающие изображения, так и сгенерированные моделью. Мы напишем служебную функцию, которая для некоторых изображений будет отображать квадратную (или как можно более близкую) сетку субфигурок:
def show_images(images, title=""): """Shows the provided images as sub-pictures in a square""" # Converting images to CPU numpy arrays if type(images) is torch.Tensor: images = images.detach().cpu().numpy() # Defining number of rows and columns fig = plt.figure(figsize=(8, 8)) rows = int(len(images) ** (1 / 2)) cols = round(len(images) / rows) # Populating figure with sub-plots idx = 0 for r in range(rows): for c in range(cols): fig.add_subplot(rows, cols, idx + 1) if idx < len(images): plt.imshow(images[idx][0], cmap="gray") idx += 1 fig.suptitle(title, fontsize=30) # Showing the figure plt.show()
Чтобы протестировать эту служебную функцию, мы загружаем наш набор данных и показываем первую партию. Важно! Изображения должны быть нормализованы в диапазоне [-1, 1], так как наша сеть должна будет предсказывать значения шума, которые распределяются нормально:
# Shows the first batch of images def show_first_batch(loader): for batch in loader: show_images(batch[0], "Images in the first batch") break # Loading the data (converting each image into a tensor and normalizing between [-1, 1]) transform = Compose([ ToTensor(), Lambda(lambda x: (x - 0.5) * 2)] ) ds_fn = FashionMNIST if fashion else MNIST dataset = ds_fn("./datasets", download=True, train=True, transform=transform) loader = DataLoader(dataset, batch_size, shuffle=True)
Большой! Теперь, когда у нас есть эта замечательная служебная функция, мы будем использовать ее позже и для изображений, сгенерированных нашей моделью. Прежде чем мы начнем фактически работать с моделью DDPM, мы просто получим устройство с графическим процессором от colab (обычно это Tesla T4 для пользователей, не являющихся пользователями colab-pro):
Модель ДДПМ
Теперь, когда мы убрали тривиальные вещи, пришло время поработать над DDPM. Мы создадим модуль PyTorch MyDDPM, который будет отвечать за хранение бета- и альфа-значений и применение процесса пересылки. Вместо этого для обратного процесса модуль MyDDPM будет просто полагаться на сеть, используемую для создания DDPM:
# DDPM class class MyDDPM(nn.Module): def __init__(self, network, n_steps=200, min_beta=10 ** -4, max_beta=0.02, device=None, image_chw=(1, 28, 28)): super(MyDDPM, self).__init__() self.n_steps = n_steps self.device = device self.image_chw = image_chw self.network = network.to(device) self.betas = torch.linspace(min_beta, max_beta, n_steps).to( device) # Number of steps is typically in the order of thousands self.alphas = 1 - self.betas self.alpha_bars = torch.tensor([torch.prod(self.alphas[:i + 1]) for i in range(len(self.alphas))]).to(device) def forward(self, x0, t, eta=None): # Make input image more noisy (we can directly skip to the desired step) n, c, h, w = x0.shape a_bar = self.alpha_bars[t] if eta is None: eta = torch.randn(n, c, h, w).to(self.device) noisy = a_bar.sqrt().reshape(n, 1, 1, 1) * x0 + (1 - a_bar).sqrt().reshape(n, 1, 1, 1) * eta return noisy def backward(self, x, t): # Run each image through the network for each timestep t in the vector t. # The network returns its estimation of the noise that was added. return self.network(x, t)
Обратите внимание, что прямой процесс не зависит от сети, используемой для шумоподавления, поэтому технически мы уже можем визуализировать его эффект. В то же время мы также можем создать вспомогательную функцию, которая применяет алгоритм 2 (процедура выборки) для создания новых изображений. Мы делаем это с помощью двух специальных служебных функций DDPM:
def show_forward(ddpm, loader, device): # Showing the forward process for batch in loader: imgs = batch[0] show_images(imgs, "Original images") for percent in [0.25, 0.5, 0.75, 1]: show_images( ddpm(imgs.to(device), [int(percent * ddpm.n_steps) - 1 for _ in range(len(imgs))]), f"DDPM Noisy images {int(percent * 100)}%" ) break
Чтобы сгенерировать изображения, мы начинаем со случайного шума и возвращаем t от T обратно к 0. На каждом шаге мы оцениваем шум как eta_theta и применяем функцию шумоподавления. Наконец, добавляется дополнительный шум, как в динамике Ланжевена.
def generate_new_images(ddpm, n_samples=16, device=None, frames_per_gif=100, gif_name="sampling.gif", c=1, h=28, w=28): """Given a DDPM model, a number of samples to be generated and a device, returns some newly generated samples""" frame_idxs = np.linspace(0, ddpm.n_steps, frames_per_gif).astype(np.uint) frames = [] with torch.no_grad(): if device is None: device = ddpm.device # Starting from random noise x = torch.randn(n_samples, c, h, w).to(device) for idx, t in enumerate(list(range(ddpm.n_steps))[::-1]): # Estimating noise to be removed time_tensor = (torch.ones(n_samples, 1) * t).to(device).long() eta_theta = ddpm.backward(x, time_tensor) alpha_t = ddpm.alphas[t] alpha_t_bar = ddpm.alpha_bars[t] # Partially denoising the image x = (1 / alpha_t.sqrt()) * (x - (1 - alpha_t) / (1 - alpha_t_bar).sqrt() * eta_theta) if t > 0: z = torch.randn(n_samples, c, h, w).to(device) # Option 1: sigma_t squared = beta_t beta_t = ddpm.betas[t] sigma_t = beta_t.sqrt() # Option 2: sigma_t squared = beta_tilda_t # prev_alpha_t_bar = ddpm.alpha_bars[t-1] if t > 0 else ddpm.alphas[0] # beta_tilda_t = ((1 - prev_alpha_t_bar)/(1 - alpha_t_bar)) * beta_t # sigma_t = beta_tilda_t.sqrt() # Adding some more noise like in Langevin Dynamics fashion x = x + sigma_t * z # Adding frames to the GIF if idx in frame_idxs or t == 0: # Putting digits in range [0, 255] normalized = x.clone() for i in range(len(normalized)): normalized[i] -= torch.min(normalized[i]) normalized[i] *= 255 / torch.max(normalized[i]) # Reshaping batch (n, c, h, w) to be a (as much as it gets) square frame frame = einops.rearrange(normalized, "(b1 b2) c h w -> (b1 h) (b2 w) c", b1=int(n_samples ** 0.5)) frame = frame.cpu().numpy().astype(np.uint8) # Rendering frame frames.append(frame) # Storing the gif with imageio.get_writer(gif_name, mode="I") as writer: for idx, frame in enumerate(frames): writer.append_data(frame) if idx == len(frames) - 1: for _ in range(frames_per_gif // 3): writer.append_data(frames[-1]) return x
Все, что касается DDPM, сейчас на столе. Нам просто нужно определить модель, которая фактически будет выполнять работу по прогнозированию шума на изображении с учетом изображения и текущего временного шага. Для этого мы создадим пользовательскую модель U-Net. Само собой разумеется, что вы можете использовать любую другую модель по вашему выбору.
U-Net
Мы начинаем создание нашей U-Net с создания блока, который будет сохранять пространственную размерность неизменной. Этот блок будет использоваться на каждом уровне нашей U-Net.
class MyBlock(nn.Module): def __init__(self, shape, in_c, out_c, kernel_size=3, stride=1, padding=1, activation=None, normalize=True): super(MyBlock, self).__init__() self.ln = nn.LayerNorm(shape) self.conv1 = nn.Conv2d(in_c, out_c, kernel_size, stride, padding) self.conv2 = nn.Conv2d(out_c, out_c, kernel_size, stride, padding) self.activation = nn.SiLU() if activation is None else activation self.normalize = normalize def forward(self, x): out = self.ln(x) if self.normalize else x out = self.conv1(out) out = self.activation(out) out = self.conv2(out) out = self.activation(out) return out
Сложность DDPM заключается в том, что наша модель преобразования изображения должна быть обусловлена текущим временным шагом. Чтобы сделать это на практике, мы используем синусоидальное вложение и однослойные MLP. Результирующие тензоры будут добавлены по каналам ко входу сети через каждый уровень U-Net.
def sinusoidal_embedding(n, d): # Returns the standard positional embedding embedding = torch.zeros(n, d) wk = torch.tensor([1 / 10_000 ** (2 * j / d) for j in range(d)]) wk = wk.reshape((1, d)) t = torch.arange(n).reshape((n, 1)) embedding[:,::2] = torch.sin(t * wk[:,::2]) embedding[:,1::2] = torch.cos(t * wk[:,::2]) return embedding
Мы создаем небольшую служебную функцию, которая создает однослойный MLP, который будет использоваться для отображения позиционных вложений.
def _make_te(self, dim_in, dim_out): return nn.Sequential( nn.Linear(dim_in, dim_out), nn.SiLU(), nn.Linear(dim_out, dim_out) )
Теперь, когда мы знаем, как работать с информацией о времени, мы можем создать пользовательскую сеть U-Net. У нас будет 3 части с пониженной выборкой, узкое место в середине сети и 3 шага с повышенной выборкой с обычными остаточными соединениями U-Net (конкатенациями).
class MyUNet(nn.Module): def __init__(self, n_steps=1000, time_emb_dim=100): super(MyUNet, self).__init__() # Sinusoidal embedding self.time_embed = nn.Embedding(n_steps, time_emb_dim) self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim) self.time_embed.requires_grad_(False) # First half self.te1 = self._make_te(time_emb_dim, 1) self.b1 = nn.Sequential( MyBlock((1, 28, 28), 1, 10), MyBlock((10, 28, 28), 10, 10), MyBlock((10, 28, 28), 10, 10) ) self.down1 = nn.Conv2d(10, 10, 4, 2, 1) self.te2 = self._make_te(time_emb_dim, 10) self.b2 = nn.Sequential( MyBlock((10, 14, 14), 10, 20), MyBlock((20, 14, 14), 20, 20), MyBlock((20, 14, 14), 20, 20) ) self.down2 = nn.Conv2d(20, 20, 4, 2, 1) self.te3 = self._make_te(time_emb_dim, 20) self.b3 = nn.Sequential( MyBlock((20, 7, 7), 20, 40), MyBlock((40, 7, 7), 40, 40), MyBlock((40, 7, 7), 40, 40) ) self.down3 = nn.Sequential( nn.Conv2d(40, 40, 2, 1), nn.SiLU(), nn.Conv2d(40, 40, 4, 2, 1) ) # Bottleneck self.te_mid = self._make_te(time_emb_dim, 40) self.b_mid = nn.Sequential( MyBlock((40, 3, 3), 40, 20), MyBlock((20, 3, 3), 20, 20), MyBlock((20, 3, 3), 20, 40) ) # Second half self.up1 = nn.Sequential( nn.ConvTranspose2d(40, 40, 4, 2, 1), nn.SiLU(), nn.ConvTranspose2d(40, 40, 2, 1) ) self.te4 = self._make_te(time_emb_dim, 80) self.b4 = nn.Sequential( MyBlock((80, 7, 7), 80, 40), MyBlock((40, 7, 7), 40, 20), MyBlock((20, 7, 7), 20, 20) ) self.up2 = nn.ConvTranspose2d(20, 20, 4, 2, 1) self.te5 = self._make_te(time_emb_dim, 40) self.b5 = nn.Sequential( MyBlock((40, 14, 14), 40, 20), MyBlock((20, 14, 14), 20, 10), MyBlock((10, 14, 14), 10, 10) ) self.up3 = nn.ConvTranspose2d(10, 10, 4, 2, 1) self.te_out = self._make_te(time_emb_dim, 20) self.b_out = nn.Sequential( MyBlock((20, 28, 28), 20, 10), MyBlock((10, 28, 28), 10, 10), MyBlock((10, 28, 28), 10, 10, normalize=False) ) self.conv_out = nn.Conv2d(10, 1, 3, 1, 1) def forward(self, x, t): # x is (N, 2, 28, 28) (image with positional embedding stacked on channel dimension) t = self.time_embed(t) n = len(x) out1 = self.b1(x + self.te1(t).reshape(n, -1, 1, 1)) # (N, 10, 28, 28) out2 = self.b2(self.down1(out1) + self.te2(t).reshape(n, -1, 1, 1)) # (N, 20, 14, 14) out3 = self.b3(self.down2(out2) + self.te3(t).reshape(n, -1, 1, 1)) # (N, 40, 7, 7) out_mid = self.b_mid(self.down3(out3) + self.te_mid(t).reshape(n, -1, 1, 1)) # (N, 40, 3, 3) out4 = torch.cat((out3, self.up1(out_mid)), dim=1) # (N, 80, 7, 7) out4 = self.b4(out4 + self.te4(t).reshape(n, -1, 1, 1)) # (N, 20, 7, 7) out5 = torch.cat((out2, self.up2(out4)), dim=1) # (N, 40, 14, 14) out5 = self.b5(out5 + self.te5(t).reshape(n, -1, 1, 1)) # (N, 10, 14, 14) out = torch.cat((out1, self.up3(out5)), dim=1) # (N, 20, 28, 28) out = self.b_out(out + self.te_out(t).reshape(n, -1, 1, 1)) # (N, 1, 28, 28) out = self.conv_out(out) return out def _make_te(self, dim_in, dim_out): return nn.Sequential( nn.Linear(dim_in, dim_out), nn.SiLU(), nn.Linear(dim_out, dim_out) )
Теперь, когда мы определили нашу сеть шумоподавления, мы можем приступить к созданию экземпляра модели DDPM и поиграть с некоторыми визуализациями.
Некоторые визуализации
Мы создаем модель DDPM, используя нашу пользовательскую сеть U-Net, следующим образом.
# Defining model n_steps, min_beta, max_beta = 1000, 10 ** -4, 0.02 # Originally used by the authors ddpm = MyDDPM(MyUNet(n_steps), n_steps=n_steps, min_beta=min_beta, max_beta=max_beta, device=device)
Давайте проверим, как выглядит процесс пересылки:
# Optionally, show the diffusion (forward) process show_forward(ddpm, loader, device)
Мы еще не обучили модель, но уже можем использовать функцию, которая позволяет генерировать новые изображения и смотреть, что получится:
Неудивительно, что при этом ничего не происходит. Однако мы будем повторно использовать этот же метод позже, когда модель завершит обучение.
Тренировочный цикл
Теперь мы реализуем Алгоритм 1, чтобы изучить модель, которая будет знать, как очищать изображения от шума. Это соответствует нашему тренировочному циклу.
def training_loop(ddpm, loader, n_epochs, optim, device, display=False, store_path="ddpm_model.pt"): mse = nn.MSELoss() best_loss = float("inf") n_steps = ddpm.n_steps for epoch in tqdm(range(n_epochs), desc=f"Training progress", colour="#00ff00"): epoch_loss = 0.0 for step, batch in enumerate(tqdm(loader, leave=False, desc=f"Epoch {epoch + 1}/{n_epochs}", colour="#005500")): # Loading data x0 = batch[0].to(device) n = len(x0) # Picking some noise for each of the images in the batch, a timestep and the respective alpha_bars eta = torch.randn_like(x0).to(device) t = torch.randint(0, n_steps, (n,)).to(device) # Computing the noisy image based on x0 and the time-step (forward process) noisy_imgs = ddpm(x0, t, eta) # Getting model estimation of noise based on the images and the time-step eta_theta = ddpm.backward(noisy_imgs, t.reshape(n, -1)) # Optimizing the MSE between the noise plugged and the predicted noise loss = mse(eta_theta, eta) optim.zero_grad() loss.backward() optim.step() epoch_loss += loss.item() * len(x0) / len(loader.dataset) # Display images generated at this epoch if display: show_images(generate_new_images(ddpm, device=device), f"Images generated at epoch {epoch + 1}") log_string = f"Loss at epoch {epoch + 1}: {epoch_loss:.3f}" # Storing the model if best_loss > epoch_loss: best_loss = epoch_loss torch.save(ddpm.state_dict(), store_path) log_string += " --> Best model ever (stored)" print(log_string)
Как видите, в нашем обучающем цикле мы просто отбираем несколько изображений и несколько случайных временных шагов для каждого из них. Затем мы делаем их зашумленными с помощью прямого процесса и запускаем обратный процесс на этих зашумленных изображениях. СКО между добавленным фактическим шумом и шумом, предсказанным моделью, оптимизируется.
По умолчанию я установил тренировочные эпохи на 20, так как каждая эпоха занимает 24 секунды (всего на обучение уходит примерно 8 минут). Обратите внимание, что можно получить еще лучшие характеристики с большим количеством эпох, лучшим U-Net и другими приемами. В этом посте я опускаю их для простоты.
Тестирование модели
Теперь, когда работа сделана, мы можем просто наслаждаться результатами. Мы загружаем лучшую модель, полученную во время обучения, по функции потерь MSE, переводим ее в режим оценки и используем для генерации новых выборок.
# Loading the trained model best_model = MyDDPM(MyUNet(), n_steps=n_steps, device=device) best_model.load_state_dict(torch.load(store_path, map_location=device)) best_model.eval() print("Model loaded") print("Generating new images") generated = generate_new_images( best_model, n_samples=100, device=device, gif_name="fashion.gif" if fashion else "mnist.gif" ) show_images(generated, "Final result")
Вишенкой на торте является тот факт, что наша функция генерации автоматически создает красивую картинку процесса распространения. Мы визуализируем этот gif в Colab с помощью следующей команды:
И мы закончили! Наконец-то наша модель DDPM заработала!
Дальнейшие улучшения
Были внесены дальнейшие улучшения, позволяющие генерировать изображения с более высоким разрешением, ускорять выборку или получать лучшее качество выборки и вероятность. Модели Imagen и DALL-E 2 основаны на улучшенных версиях оригинальных модулей DDPM.
Больше ссылок
Чтобы узнать больше о DDPM, я настоятельно рекомендую прочитать выдающийся пост Лилиан Венг и Нильса Рогге, а также удивительный Блог Hugging Face Кашифа Расула. Другие авторы также упоминаются в конце блокнота Colab.
Заключение
Диффузионные модели — это генеративные модели, которые учатся итеративно очищать изображения от шума. Затем, начиная с некоторого шума, можно попросить модель удалить шум сэмпла до тех пор, пока не будет получено какое-то реалистичное изображение.
Мы создали DDPM с нуля в PyTorch и научили его очищать изображения MNIST/Fashion-MNIST. Модель после обучения наконец смогла генерировать новые изображения из случайного шума. Довольно волшебно, правда?
Блокнот Colab с показанной реализацией находится в свободном доступе по этой ссылке, а репозиторий GitHub содержит файлы .py. Если вы нашли эту историю полезной, похлопайте ей 👏. Если вы чувствуете, что что-то неясно, не стесняйтесь обращаться ко мне напрямую! Я был бы рад обсудить это с вами.