Введение

Вероятностные модели диффузионного шумоподавления (DDPM) — это глубокие генеративные модели, которым в последнее время уделяется много внимания благодаря их впечатляющим характеристикам. Совершенно новые модели, такие как генераторы OpenAI DALL-E 2 и Google Imagen, основаны на DDPM. Они обуславливают генератор текстом таким образом, что становится возможным генерировать фотореалистичные изображения на основе произвольной строки текста.

Например, введите «Фотография собаки сиба-ину с рюкзаком на велосипеде. На нем солнцезащитные очки и пляжная шляпа» для новой модели Imagen и «голова корги в виде взрыва туманности» для Модель DALL-E 2 выдает следующие изображения:

Эти модели просто умопомрачительны, но понимание того, как они работают, требует понимания оригинальной работы Хо и др. др. «Вероятностные модели шумоподавления».

В этом коротком посте я сосредоточусь на создании с нуля (в PyTorch) простой версии DDPM. В частности, я буду повторно реализовывать «оригинальную статью Хо. И другие". Мы будем работать с классическими и нетребовательными к ресурсам наборами данных MNIST и Fashion-MNIST и попытаемся создать изображения из воздуха. Начнем с небольшой теории.

Вероятностные модели диффузии шумоподавления

Вероятностные модели диффузии шумоподавления (DDPM) впервые появились в этой статье.

Идея довольно проста: к набору данных изображений мы добавляем немного шума шаг за шагом. С каждым шагом изображение становится все менее и менее четким, пока не останется только шум. Это называется «прямой процесс». Затем мы изучаем модель машинного обучения, которая может отменить каждый из таких шагов, и называем это «обратным процессом». Если мы сможем успешно изучить обратный процесс, у нас будет модель, которая может генерировать изображения из чистого случайного шума.

Шаг в прямом процессе состоит в том, чтобы сделать входное изображение более зашумленным (x на шаге t) путем выборки из многомерного гауссовского распределения, среднее значение которого является уменьшенной версией предыдущего изображения (x на шаге t-1), и какая ковариационная матрица является диагональным и фиксированным. Другими словами, мы возмущаем каждый пиксель изображения независимо, добавляя некоторое нормально распределенное значение.

Для каждого шага существует свой бета-коэффициент, который говорит, насколько сильно мы искажаем изображение на этом шаге. Чем выше бета, тем больше шума добавляется к изображению. Мы можем выбирать бета-коэффициенты, но мы должны стараться, чтобы не было шагов, на которых сразу добавляется слишком много шума, а общий процесс продвижения вперед должен быть «плавным». В оригинальной работе Ho et. др., бета расположены в линейном пространстве от 0,0001 до 0,02.

Замечательным свойством гауссовского распределения является то, что мы можем делать выборку из него, добавляя к среднему вектору нормально распределенный вектор шума, масштабированный по стандартному отклонению. Это приводит к:

Теперь мы знаем, как получить следующую выборку в прямом процессе, просто масштабируя то, что у нас уже есть, и добавляя немного масштабированного шума. Если мы теперь учтем, что формула рекурсивна, мы можем написать:

Если мы продолжим делать это и сделаем некоторые упрощения, мы можем вернуться назад и получить формулу для получения зашумленной выборки на шаге t, начиная с исходного незашумленного изображения x0:

Большой. Теперь независимо от того, сколько шагов будет иметь наш прямой процесс, у нас всегда будет способ напрямую получить зашумленное изображение на шаге t непосредственно из исходного изображения.

Мы знаем, что для обратного процесса наша модель также должна работать как распределение Гаусса, поэтому нам просто нужна модель для прогнозирования среднего значения распределения и стандартного отклонения с учетом зашумленного изображения и временного шага. На практике в этой первой статье о DDPM ковариационная матрица остается фиксированной, поэтому мы действительно хотим предсказать только среднее значение гауссова (учитывая зашумленное изображение и временной шаг, на котором мы находимся в настоящее время):

Теперь оказывается, что оптимальное среднее значение, которое нужно предсказать, — это просто функция условий, с которыми мы уже знакомы:

Итак, мы можем еще больше упростить нашу модель и просто предсказать эпсилон шума с помощью функции зашумленного изображения и временного шага.

И наша функция потерь будет просто масштабированной версией среднеквадратичной ошибки (MSE) между реальным шумом, который был добавлен, и шумом, предсказанным нашей моделью.

После обучения модели (алгоритм 1) мы можем использовать модель шумоподавления для выборки новых изображений (алгоритм 2).

Приступаем к кодированию

Теперь, когда у нас есть общее представление о том, как работают модели распространения, пришло время реализовать что-то свое. Вы можете самостоятельно запустить следующий код в этом Блокноте Google Colab или в этом репозитории GitHub.

Как обычно, импорт — это наш первый шаг.

# Import of libraries
import random
import imageio
import numpy as np
from argparse import ArgumentParser

from tqdm.auto import tqdm
import matplotlib.pyplot as plt

import einops
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader

from torchvision.transforms import Compose, ToTensor, Lambda
from torchvision.datasets.mnist import MNIST, FashionMNIST

# Setting reproducibility
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Definitions
STORE_PATH_MNIST = f"ddpm_model_mnist.pt"
STORE_PATH_FASHION = f"ddpm_model_fashion.pt"

Далее мы определяем несколько параметров для нашего эксперимента. В частности, мы решаем, хотим ли мы запустить цикл обучения, хотим ли мы использовать набор данных Fashion-MNIST и некоторые гиперпараметры обучения.

no_train = False
fashion = True
batch_size = 128
n_epochs = 20
lr = 0.001
store_path = "ddpm_fashion.pt" if fashion else "ddpm_mnist.pt"

Далее нам очень хотелось бы отображать изображения. Нас интересуют как обучающие изображения, так и сгенерированные моделью. Мы напишем служебную функцию, которая для некоторых изображений будет отображать квадратную (или как можно более близкую) сетку субфигурок:

def show_images(images, title=""):
    """Shows the provided images as sub-pictures in a square"""

    # Converting images to CPU numpy arrays
    if type(images) is torch.Tensor:
        images = images.detach().cpu().numpy()

    # Defining number of rows and columns
    fig = plt.figure(figsize=(8, 8))
    rows = int(len(images) ** (1 / 2))
    cols = round(len(images) / rows)

    # Populating figure with sub-plots
    idx = 0
    for r in range(rows):
        for c in range(cols):
            fig.add_subplot(rows, cols, idx + 1)

            if idx < len(images):
                plt.imshow(images[idx][0], cmap="gray")
                idx += 1
    fig.suptitle(title, fontsize=30)

    # Showing the figure
    plt.show()

Чтобы протестировать эту служебную функцию, мы загружаем наш набор данных и показываем первую партию. Важно! Изображения должны быть нормализованы в диапазоне [-1, 1], так как наша сеть должна будет предсказывать значения шума, которые распределяются нормально:

# Shows the first batch of images
def show_first_batch(loader):
    for batch in loader:
        show_images(batch[0], "Images in the first batch")
        break
# Loading the data (converting each image into a tensor and normalizing between [-1, 1])
transform = Compose([
    ToTensor(),
    Lambda(lambda x: (x - 0.5) * 2)]
)
ds_fn = FashionMNIST if fashion else MNIST
dataset = ds_fn("./datasets", download=True, train=True, transform=transform)
loader = DataLoader(dataset, batch_size, shuffle=True)

Большой! Теперь, когда у нас есть эта замечательная служебная функция, мы будем использовать ее позже и для изображений, сгенерированных нашей моделью. Прежде чем мы начнем фактически работать с моделью DDPM, мы просто получим устройство с графическим процессором от colab (обычно это Tesla T4 для пользователей, не являющихся пользователями colab-pro):

Модель ДДПМ

Теперь, когда мы убрали тривиальные вещи, пришло время поработать над DDPM. Мы создадим модуль PyTorch MyDDPM, который будет отвечать за хранение бета- и альфа-значений и применение процесса пересылки. Вместо этого для обратного процесса модуль MyDDPM будет просто полагаться на сеть, используемую для создания DDPM:

# DDPM class
class MyDDPM(nn.Module):
    def __init__(self, network, n_steps=200, min_beta=10 ** -4, max_beta=0.02, device=None, image_chw=(1, 28, 28)):
        super(MyDDPM, self).__init__()
        self.n_steps = n_steps
        self.device = device
        self.image_chw = image_chw
        self.network = network.to(device)
        self.betas = torch.linspace(min_beta, max_beta, n_steps).to(
            device)  # Number of steps is typically in the order of thousands
        self.alphas = 1 - self.betas
        self.alpha_bars = torch.tensor([torch.prod(self.alphas[:i + 1]) for i in range(len(self.alphas))]).to(device)

    def forward(self, x0, t, eta=None):
        # Make input image more noisy (we can directly skip to the desired step)
        n, c, h, w = x0.shape
        a_bar = self.alpha_bars[t]

        if eta is None:
            eta = torch.randn(n, c, h, w).to(self.device)

        noisy = a_bar.sqrt().reshape(n, 1, 1, 1) * x0 + (1 - a_bar).sqrt().reshape(n, 1, 1, 1) * eta
        return noisy

    def backward(self, x, t):
        # Run each image through the network for each timestep t in the vector t.
        # The network returns its estimation of the noise that was added.
        return self.network(x, t)

Обратите внимание, что прямой процесс не зависит от сети, используемой для шумоподавления, поэтому технически мы уже можем визуализировать его эффект. В то же время мы также можем создать вспомогательную функцию, которая применяет алгоритм 2 (процедура выборки) для создания новых изображений. Мы делаем это с помощью двух специальных служебных функций DDPM:

def show_forward(ddpm, loader, device):
    # Showing the forward process
    for batch in loader:
        imgs = batch[0]

        show_images(imgs, "Original images")

        for percent in [0.25, 0.5, 0.75, 1]:
            show_images(
                ddpm(imgs.to(device),
                     [int(percent * ddpm.n_steps) - 1 for _ in range(len(imgs))]),
                f"DDPM Noisy images {int(percent * 100)}%"
            )
        break

Чтобы сгенерировать изображения, мы начинаем со случайного шума и возвращаем t от T обратно к 0. На каждом шаге мы оцениваем шум как eta_theta и применяем функцию шумоподавления. Наконец, добавляется дополнительный шум, как в динамике Ланжевена.

def generate_new_images(ddpm, n_samples=16, device=None, frames_per_gif=100, gif_name="sampling.gif", c=1, h=28, w=28):
    """Given a DDPM model, a number of samples to be generated and a device, returns some newly generated samples"""
    frame_idxs = np.linspace(0, ddpm.n_steps, frames_per_gif).astype(np.uint)
    frames = []

    with torch.no_grad():
        if device is None:
            device = ddpm.device

        # Starting from random noise
        x = torch.randn(n_samples, c, h, w).to(device)

        for idx, t in enumerate(list(range(ddpm.n_steps))[::-1]):
            # Estimating noise to be removed
            time_tensor = (torch.ones(n_samples, 1) * t).to(device).long()
            eta_theta = ddpm.backward(x, time_tensor)

            alpha_t = ddpm.alphas[t]
            alpha_t_bar = ddpm.alpha_bars[t]

            # Partially denoising the image
            x = (1 / alpha_t.sqrt()) * (x - (1 - alpha_t) / (1 - alpha_t_bar).sqrt() * eta_theta)

            if t > 0:
                z = torch.randn(n_samples, c, h, w).to(device)

                # Option 1: sigma_t squared = beta_t
                beta_t = ddpm.betas[t]
                sigma_t = beta_t.sqrt()

                # Option 2: sigma_t squared = beta_tilda_t
                # prev_alpha_t_bar = ddpm.alpha_bars[t-1] if t > 0 else ddpm.alphas[0]
                # beta_tilda_t = ((1 - prev_alpha_t_bar)/(1 - alpha_t_bar)) * beta_t
                # sigma_t = beta_tilda_t.sqrt()

                # Adding some more noise like in Langevin Dynamics fashion
                x = x + sigma_t * z

            # Adding frames to the GIF
            if idx in frame_idxs or t == 0:
                # Putting digits in range [0, 255]
                normalized = x.clone()
                for i in range(len(normalized)):
                    normalized[i] -= torch.min(normalized[i])
                    normalized[i] *= 255 / torch.max(normalized[i])

                # Reshaping batch (n, c, h, w) to be a (as much as it gets) square frame
                frame = einops.rearrange(normalized, "(b1 b2) c h w -> (b1 h) (b2 w) c", b1=int(n_samples ** 0.5))
                frame = frame.cpu().numpy().astype(np.uint8)

                # Rendering frame
                frames.append(frame)

    # Storing the gif
    with imageio.get_writer(gif_name, mode="I") as writer:
        for idx, frame in enumerate(frames):
            writer.append_data(frame)
            if idx == len(frames) - 1:
                for _ in range(frames_per_gif // 3):
                    writer.append_data(frames[-1])
    return x

Все, что касается DDPM, сейчас на столе. Нам просто нужно определить модель, которая фактически будет выполнять работу по прогнозированию шума на изображении с учетом изображения и текущего временного шага. Для этого мы создадим пользовательскую модель U-Net. Само собой разумеется, что вы можете использовать любую другую модель по вашему выбору.

U-Net

Мы начинаем создание нашей U-Net с создания блока, который будет сохранять пространственную размерность неизменной. Этот блок будет использоваться на каждом уровне нашей U-Net.

class MyBlock(nn.Module):
    def __init__(self, shape, in_c, out_c, kernel_size=3, stride=1, padding=1, activation=None, normalize=True):
        super(MyBlock, self).__init__()
        self.ln = nn.LayerNorm(shape)
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size, stride, padding)
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size, stride, padding)
        self.activation = nn.SiLU() if activation is None else activation
        self.normalize = normalize

    def forward(self, x):
        out = self.ln(x) if self.normalize else x
        out = self.conv1(out)
        out = self.activation(out)
        out = self.conv2(out)
        out = self.activation(out)
        return out

Сложность DDPM заключается в том, что наша модель преобразования изображения должна быть обусловлена ​​текущим временным шагом. Чтобы сделать это на практике, мы используем синусоидальное вложение и однослойные MLP. Результирующие тензоры будут добавлены по каналам ко входу сети через каждый уровень U-Net.

def sinusoidal_embedding(n, d):
    # Returns the standard positional embedding
    embedding = torch.zeros(n, d)
    wk = torch.tensor([1 / 10_000 ** (2 * j / d) for j in range(d)])
    wk = wk.reshape((1, d))
    t = torch.arange(n).reshape((n, 1))
    embedding[:,::2] = torch.sin(t * wk[:,::2])
    embedding[:,1::2] = torch.cos(t * wk[:,::2])

    return embedding

Мы создаем небольшую служебную функцию, которая создает однослойный MLP, который будет использоваться для отображения позиционных вложений.

def _make_te(self, dim_in, dim_out):
  return nn.Sequential(
    nn.Linear(dim_in, dim_out),
    nn.SiLU(),
    nn.Linear(dim_out, dim_out)
  )

Теперь, когда мы знаем, как работать с информацией о времени, мы можем создать пользовательскую сеть U-Net. У нас будет 3 части с пониженной выборкой, узкое место в середине сети и 3 шага с повышенной выборкой с обычными остаточными соединениями U-Net (конкатенациями).

class MyUNet(nn.Module):
    def __init__(self, n_steps=1000, time_emb_dim=100):
        super(MyUNet, self).__init__()

        # Sinusoidal embedding
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)

        # First half
        self.te1 = self._make_te(time_emb_dim, 1)
        self.b1 = nn.Sequential(
            MyBlock((1, 28, 28), 1, 10),
            MyBlock((10, 28, 28), 10, 10),
            MyBlock((10, 28, 28), 10, 10)
        )
        self.down1 = nn.Conv2d(10, 10, 4, 2, 1)

        self.te2 = self._make_te(time_emb_dim, 10)
        self.b2 = nn.Sequential(
            MyBlock((10, 14, 14), 10, 20),
            MyBlock((20, 14, 14), 20, 20),
            MyBlock((20, 14, 14), 20, 20)
        )
        self.down2 = nn.Conv2d(20, 20, 4, 2, 1)

        self.te3 = self._make_te(time_emb_dim, 20)
        self.b3 = nn.Sequential(
            MyBlock((20, 7, 7), 20, 40),
            MyBlock((40, 7, 7), 40, 40),
            MyBlock((40, 7, 7), 40, 40)
        )
        self.down3 = nn.Sequential(
            nn.Conv2d(40, 40, 2, 1),
            nn.SiLU(),
            nn.Conv2d(40, 40, 4, 2, 1)
        )

        # Bottleneck
        self.te_mid = self._make_te(time_emb_dim, 40)
        self.b_mid = nn.Sequential(
            MyBlock((40, 3, 3), 40, 20),
            MyBlock((20, 3, 3), 20, 20),
            MyBlock((20, 3, 3), 20, 40)
        )

        # Second half
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(40, 40, 4, 2, 1),
            nn.SiLU(),
            nn.ConvTranspose2d(40, 40, 2, 1)
        )

        self.te4 = self._make_te(time_emb_dim, 80)
        self.b4 = nn.Sequential(
            MyBlock((80, 7, 7), 80, 40),
            MyBlock((40, 7, 7), 40, 20),
            MyBlock((20, 7, 7), 20, 20)
        )

        self.up2 = nn.ConvTranspose2d(20, 20, 4, 2, 1)
        self.te5 = self._make_te(time_emb_dim, 40)
        self.b5 = nn.Sequential(
            MyBlock((40, 14, 14), 40, 20),
            MyBlock((20, 14, 14), 20, 10),
            MyBlock((10, 14, 14), 10, 10)
        )

        self.up3 = nn.ConvTranspose2d(10, 10, 4, 2, 1)
        self.te_out = self._make_te(time_emb_dim, 20)
        self.b_out = nn.Sequential(
            MyBlock((20, 28, 28), 20, 10),
            MyBlock((10, 28, 28), 10, 10),
            MyBlock((10, 28, 28), 10, 10, normalize=False)
        )

        self.conv_out = nn.Conv2d(10, 1, 3, 1, 1)

    def forward(self, x, t):
        # x is (N, 2, 28, 28) (image with positional embedding stacked on channel dimension)
        t = self.time_embed(t)
        n = len(x)
        out1 = self.b1(x + self.te1(t).reshape(n, -1, 1, 1))  # (N, 10, 28, 28)
        out2 = self.b2(self.down1(out1) + self.te2(t).reshape(n, -1, 1, 1))  # (N, 20, 14, 14)
        out3 = self.b3(self.down2(out2) + self.te3(t).reshape(n, -1, 1, 1))  # (N, 40, 7, 7)

        out_mid = self.b_mid(self.down3(out3) + self.te_mid(t).reshape(n, -1, 1, 1))  # (N, 40, 3, 3)

        out4 = torch.cat((out3, self.up1(out_mid)), dim=1)  # (N, 80, 7, 7)
        out4 = self.b4(out4 + self.te4(t).reshape(n, -1, 1, 1))  # (N, 20, 7, 7)

        out5 = torch.cat((out2, self.up2(out4)), dim=1)  # (N, 40, 14, 14)
        out5 = self.b5(out5 + self.te5(t).reshape(n, -1, 1, 1))  # (N, 10, 14, 14)

        out = torch.cat((out1, self.up3(out5)), dim=1)  # (N, 20, 28, 28)
        out = self.b_out(out + self.te_out(t).reshape(n, -1, 1, 1))  # (N, 1, 28, 28)

        out = self.conv_out(out)

        return out

    def _make_te(self, dim_in, dim_out):
        return nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.SiLU(),
            nn.Linear(dim_out, dim_out)
        )

Теперь, когда мы определили нашу сеть шумоподавления, мы можем приступить к созданию экземпляра модели DDPM и поиграть с некоторыми визуализациями.

Некоторые визуализации

Мы создаем модель DDPM, используя нашу пользовательскую сеть U-Net, следующим образом.

# Defining model
n_steps, min_beta, max_beta = 1000, 10 ** -4, 0.02  # Originally used by the authors
ddpm = MyDDPM(MyUNet(n_steps), n_steps=n_steps, min_beta=min_beta, max_beta=max_beta, device=device)

Давайте проверим, как выглядит процесс пересылки:

# Optionally, show the diffusion (forward) process
show_forward(ddpm, loader, device)

Мы еще не обучили модель, но уже можем использовать функцию, которая позволяет генерировать новые изображения и смотреть, что получится:

Неудивительно, что при этом ничего не происходит. Однако мы будем повторно использовать этот же метод позже, когда модель завершит обучение.

Тренировочный цикл

Теперь мы реализуем Алгоритм 1, чтобы изучить модель, которая будет знать, как очищать изображения от шума. Это соответствует нашему тренировочному циклу.

def training_loop(ddpm, loader, n_epochs, optim, device, display=False, store_path="ddpm_model.pt"):
    mse = nn.MSELoss()
    best_loss = float("inf")
    n_steps = ddpm.n_steps

    for epoch in tqdm(range(n_epochs), desc=f"Training progress", colour="#00ff00"):
        epoch_loss = 0.0
        for step, batch in enumerate(tqdm(loader, leave=False, desc=f"Epoch {epoch + 1}/{n_epochs}", colour="#005500")):
            # Loading data
            x0 = batch[0].to(device)
            n = len(x0)

            # Picking some noise for each of the images in the batch, a timestep and the respective alpha_bars
            eta = torch.randn_like(x0).to(device)
            t = torch.randint(0, n_steps, (n,)).to(device)

            # Computing the noisy image based on x0 and the time-step (forward process)
            noisy_imgs = ddpm(x0, t, eta)

            # Getting model estimation of noise based on the images and the time-step
            eta_theta = ddpm.backward(noisy_imgs, t.reshape(n, -1))

            # Optimizing the MSE between the noise plugged and the predicted noise
            loss = mse(eta_theta, eta)
            optim.zero_grad()
            loss.backward()
            optim.step()

            epoch_loss += loss.item() * len(x0) / len(loader.dataset)

        # Display images generated at this epoch
        if display:
            show_images(generate_new_images(ddpm, device=device), f"Images generated at epoch {epoch + 1}")

        log_string = f"Loss at epoch {epoch + 1}: {epoch_loss:.3f}"

        # Storing the model
        if best_loss > epoch_loss:
            best_loss = epoch_loss
            torch.save(ddpm.state_dict(), store_path)
            log_string += " --> Best model ever (stored)"

        print(log_string)

Как видите, в нашем обучающем цикле мы просто отбираем несколько изображений и несколько случайных временных шагов для каждого из них. Затем мы делаем их зашумленными с помощью прямого процесса и запускаем обратный процесс на этих зашумленных изображениях. СКО между добавленным фактическим шумом и шумом, предсказанным моделью, оптимизируется.

По умолчанию я установил тренировочные эпохи на 20, так как каждая эпоха занимает 24 секунды (всего на обучение уходит примерно 8 минут). Обратите внимание, что можно получить еще лучшие характеристики с большим количеством эпох, лучшим U-Net и другими приемами. В этом посте я опускаю их для простоты.

Тестирование модели

Теперь, когда работа сделана, мы можем просто наслаждаться результатами. Мы загружаем лучшую модель, полученную во время обучения, по функции потерь MSE, переводим ее в режим оценки и используем для генерации новых выборок.

# Loading the trained model
best_model = MyDDPM(MyUNet(), n_steps=n_steps, device=device)
best_model.load_state_dict(torch.load(store_path, map_location=device))
best_model.eval()
print("Model loaded")
print("Generating new images")
generated = generate_new_images(
        best_model,
        n_samples=100,
        device=device,
        gif_name="fashion.gif" if fashion else "mnist.gif"
    )
show_images(generated, "Final result")

Вишенкой на торте является тот факт, что наша функция генерации автоматически создает красивую картинку процесса распространения. Мы визуализируем этот gif в Colab с помощью следующей команды:

И мы закончили! Наконец-то наша модель DDPM заработала!

Дальнейшие улучшения

Были внесены дальнейшие улучшения, позволяющие генерировать изображения с более высоким разрешением, ускорять выборку или получать лучшее качество выборки и вероятность. Модели Imagen и DALL-E 2 основаны на улучшенных версиях оригинальных модулей DDPM.

Больше ссылок

Чтобы узнать больше о DDPM, я настоятельно рекомендую прочитать выдающийся пост Лилиан Венг и Нильса Рогге, а также удивительный Блог Hugging Face Кашифа Расула. Другие авторы также упоминаются в конце блокнота Colab.

Заключение

Диффузионные модели — это генеративные модели, которые учатся итеративно очищать изображения от шума. Затем, начиная с некоторого шума, можно попросить модель удалить шум сэмпла до тех пор, пока не будет получено какое-то реалистичное изображение.

Мы создали DDPM с нуля в PyTorch и научили его очищать изображения MNIST/Fashion-MNIST. Модель после обучения наконец смогла генерировать новые изображения из случайного шума. Довольно волшебно, правда?

Блокнот Colab с показанной реализацией находится в свободном доступе по этой ссылке, а репозиторий GitHub содержит файлы .py. Если вы нашли эту историю полезной, похлопайте ей 👏. Если вы чувствуете, что что-то неясно, не стесняйтесь обращаться ко мне напрямую! Я был бы рад обсудить это с вами.