TL;DR: models converge faster and to a lower loss when batch normalization is used.

Batch normalization is a technique for making artificial neural networks faster and more stable by normalizing the layers' inputs via re-centering and re-scaling. It was proposed by Sergey Ioffe and Christian Szegedy in their 2015 paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.
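To make "re-centering and re-scaling" concrete, here is a minimal sketch of the transform a BatchNorm1d layer applies in training mode; the eps value and the gamma/beta initialization below match PyTorch's defaults:

import torch

# A batch of 32 samples with 100 features each
x = torch.randn(32, 100)
eps = 1e-5

mean = x.mean(dim=0)                        # per-feature batch mean
var = x.var(dim=0, unbiased=False)          # per-feature (biased) batch variance
x_hat = (x - mean) / torch.sqrt(var + eps)  # re-center and re-scale

gamma = torch.ones(100)   # learnable scale, initialized to 1
beta = torch.zeros(100)   # learnable shift, initialized to 0
y = gamma * x_hat + beta  # the layer's output

# Sanity check against the built-in layer (in its default training mode)
bn = torch.nn.BatchNorm1d(100)
print(torch.allclose(y, bn(x), atol=1e-6))  # True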

Three examples below:

  • A minimal example using plain PyTorch on random data
  • A more thorough example across different model depths and widths, tracking both train and validation loss
  • An example on MNIST

Spoiler: in every example, batch normalization converges faster and reaches a lower loss.

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# Set seed for reproducibility
torch.manual_seed(0)

# Create dummy data for binary classification
input_dim = 10
n_samples = 1000
X = torch.randn(n_samples, input_dim)
y = torch.randint(0, 2, (n_samples,))

# Define the number of epochs and learning rate
epochs = 200
lr = 0.01

# Define a simple model without batch normalization
class ModelWithoutBN(nn.Module):
    def __init__(self):
        super(ModelWithoutBN, self).__init__()
        self.layer = nn.Sequential(
            nn.Linear(input_dim, 100),
            nn.ReLU(),
            nn.Linear(100, 1),
        )

    def forward(self, x):
        return self.layer(x)

# Define a simple model with batch normalization
class ModelWithBN(nn.Module):
    def __init__(self):
        super(ModelWithBN, self).__init__()
        self.layer = nn.Sequential(
            nn.Linear(input_dim, 100),
            nn.BatchNorm1d(100),
            nn.ReLU(),
            nn.Linear(100, 1),
        )

    def forward(self, x):
        return self.layer(x)

# Initialize the models
model_without_bn = ModelWithoutBN()
model_with_bn = ModelWithBN()

# Define the loss function and the optimizers
criterion = nn.BCEWithLogitsLoss()
optimizer_without_bn = optim.SGD(model_without_bn.parameters(), lr=lr)
optimizer_with_bn = optim.SGD(model_with_bn.parameters(), lr=lr)

# Placeholders for losses
losses_without_bn = []
losses_with_bn = []

# Training loop
for epoch in range(epochs):
    for model, optimizer, losses in [(model_without_bn, optimizer_without_bn, losses_without_bn), 
                                     (model_with_bn, optimizer_with_bn, losses_with_bn)]:
        optimizer.zero_grad()
        outputs = model(X).view(-1)
        loss = criterion(outputs, y.float())
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

# Plot the losses
plt.figure(figsize=(10, 6))
plt.plot(losses_without_bn, label='Without Batch Normalization')
plt.plot(losses_with_bn, label='With Batch Normalization')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training Loss Over Time')
plt.legend()
plt.grid(True)
plt.show()
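One caveat before the bigger benchmark: a BatchNorm layer normalizes with the current batch's statistics in training mode and with its accumulated running averages in eval mode, so the same input produces different outputs depending on the mode. A minimal, self-contained sketch:

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)
x = torch.randn(8, 4)

bn.train()   # normalize with the current batch's mean/variance
out_train = bn(x)

bn.eval()    # normalize with the accumulated running averages
with torch.no_grad():
    out_eval = bn(x)

# The outputs differ because different statistics are used in each mode
print(torch.allclose(out_train, out_eval))  # typically False

This is why the benchmark below switches the model to eval mode before computing the validation loss, and back to train mode afterwards.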

Let's run a more thorough test across different model depths and widths, using data that actually contains a pattern, and track both training and validation loss:

import torch
import torch.optim as optim
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import torch.nn as nn

# Set seed for reproducibility
torch.manual_seed(0)

# Define the characteristics of the data
input_dim = 10
n_samples = 1000

# Define the depths and widths to benchmark
depths = [2, 10]
widths = [100, 500]

# Define the number of epochs and learning rate
epochs = 200
lr = 0.01
# Define a function to create a model with a specified depth and width
def create_model(depth, width, batch_norm):
    layers = []
    for i in range(depth):
        if i == 0:
            layers.append(nn.Linear(input_dim, width))
        else:
            layers.append(nn.Linear(width, width))
        if batch_norm:
            layers.append(nn.BatchNorm1d(width))
        layers.append(nn.ReLU())
    layers.append(nn.Linear(width, 1))
    return nn.Sequential(*layers)

# Define the loss function and the optimizers
criterion = nn.BCEWithLogitsLoss()


# Create dummy data for binary classification with a pattern
X_train = torch.randn(n_samples, input_dim)
y_train = (X_train.sum(dim=1) > 0).float()

# Create validation data with the same pattern
X_val = torch.randn(n_samples, input_dim)
y_val = (X_val.sum(dim=1) > 0).float()

# Initialize the models, optimizers, and loss placeholders
results = []
for depth in depths:
    for width in widths:
        for bn in [False, True]:
            key = f'depth={depth}, width={width}, BN={bn}'
            model = create_model(depth, width, bn)
            optimizer = optim.SGD(model.parameters(), lr=lr)
            loss_train = []
            loss_val = []
            results.append({
                'key': key,
                'model': model,
                'optimizer': optimizer,
                'loss_train': loss_train,
                'loss_val': loss_val
            })

# Training loop
for epoch in range(epochs):
    for result in results:
        optimizer = result['optimizer']
        model = result['model']
        optimizer.zero_grad()
        outputs_train = model(X_train).view(-1)
        loss_train = criterion(outputs_train, y_train.float())
        loss_train.backward()
        optimizer.step()
        result['loss_train'].append(loss_train.item())

        # Compute the validation loss in eval mode so that BatchNorm
        # normalizes with its running statistics instead of the current batch's
        model.eval()
        with torch.no_grad():
            outputs_val = model(X_val).view(-1)
            loss_val = criterion(outputs_val, y_val.float())
            result['loss_val'].append(loss_val.item())
        model.train()

# Create subplots for training and validation losses
fig, axs = plt.subplots(2, 1, figsize=(14, 10))

# Line styles and labels for models with and without batch normalization
styles = ['-', '--']
bn_labels = ['Without BN', 'With BN']

# Define labels for depths and widths
depth_width_labels = [f'depth={depth}, width={width}' for depth in depths for width in widths]

# Colors for each depth and width combination
colors = ['blue', 'green', 'red', 'purple']

# Plot the training losses
for i, result in enumerate(results):
    linestyle = styles[int('BN=True' in result['key'])]
    color = colors[i // 2]
    axs[0].plot(result['loss_train'], linestyle=linestyle, color=color)

# Plot the validation losses
for i, result in enumerate(results):
    linestyle = styles[int('BN=True' in result['key'])]
    color = colors[i // 2]
    axs[1].plot(result['loss_val'], linestyle=linestyle, color=color)

custom_lines = [Line2D([0], [0], color="black", linestyle=styles[0]),
                Line2D([0], [0], color="black", linestyle=styles[1])] + \
               [Line2D([0], [0], color=colors[i], linestyle='-') for i in range(len(depth_width_labels))]

# Set labels and titles
for ax, title in zip(axs, ['Training Loss Over Time', 'Validation Loss Over Time']):
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.set_title(title)
    ax.grid(True)

# Add the legend
axs[0].legend(custom_lines, bn_labels + depth_width_labels)

# Adjust the layout
plt.tight_layout()
plt.show()

Let's try this on a real dataset, the one and only MNIST (a 10k-example subset for train, 1k for val). This code runs in Google Colab without a GPU :)

import pytorch_lightning as pl
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split, Subset
from torchvision.datasets import MNIST
from torchvision import transforms
import os
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import CSVLogger


# Set seed for reproducibility
pl.seed_everything(0)

# Prepare MNIST dataset
dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
indices = torch.arange(0, 11000)
dataset = Subset(dataset, indices)
train, val = random_split(dataset, [10000, 1000])

train_loader = DataLoader(train, batch_size=64)
val_loader = DataLoader(val, batch_size=64)


class LitModel(pl.LightningModule):
    def __init__(self, batch_normalization=False):
        super().__init__()

        # Simple MLP
        if batch_normalization:
            self.layer = torch.nn.Sequential(
                torch.nn.Linear(28 * 28, 128),
                torch.nn.BatchNorm1d(128),
                torch.nn.ReLU(),
                torch.nn.Linear(128, 256),
                torch.nn.BatchNorm1d(256),
                torch.nn.ReLU(),
                torch.nn.Linear(256, 10),
            )
        else:
            self.layer = torch.nn.Sequential(
                torch.nn.Linear(28 * 28, 128),
                torch.nn.ReLU(),
                torch.nn.Linear(128, 256),
                torch.nn.ReLU(),
                torch.nn.Linear(256, 10),
            )

    def forward(self, x):
        # flatten image input
        x = x.view(x.size(0), -1)
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log(
            "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("val_loss", loss, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)


def train_model(model, name):
    # Train the model
    early_stop_callback = EarlyStopping(
        monitor="val_loss", patience=3, verbose=True, mode="min"
    )
    csv_logger = CSVLogger("logs_folder", name=name)
    trainer = pl.Trainer(
        max_epochs=200,
        accelerator="cpu",
        callbacks=[early_stop_callback],
        logger=csv_logger,
    )
    trainer.fit(model, train_loader, val_loader)


# Train model without batch normalization
model_without_bn = LitModel(batch_normalization=False)
train_model(model_without_bn, "without_batch_norm")

# Train model with batch normalization
model_with_bn = LitModel(batch_normalization=True)
train_model(model_with_bn, "with_batch_norm")

Visualize the losses over time:

import pandas as pd
import matplotlib.pyplot as plt

# Load the metrics from the CSV file into a pandas DataFrame
df_without_bn = pd.read_csv('logs_folder/without_batch_norm/version_0/metrics.csv')
df_with_bn = pd.read_csv('logs_folder/with_batch_norm/version_0/metrics.csv')

# Forward fill NaN values in the 'val_loss' and 'train_loss_epoch' columns
df_without_bn[['val_loss', 'train_loss_epoch']] = df_without_bn[['val_loss', 'train_loss_epoch']].ffill()
df_with_bn[['val_loss', 'train_loss_epoch']] = df_with_bn[['val_loss', 'train_loss_epoch']].ffill()

# Plot the validation and training losses for each model
plt.figure(figsize=(14, 8))

# Plot training loss
plt.subplot(2, 1, 1)
plt.plot(df_without_bn['epoch'], df_without_bn['train_loss_epoch'], label='Without BN', color='blue')
plt.plot(df_with_bn['epoch'], df_with_bn['train_loss_epoch'], label='With BN', color='orange')
plt.xlabel('Epoch')
plt.ylabel('Training Loss')
plt.title('Training Loss Over Time')
plt.legend()
plt.grid(True)

# Plot validation loss
plt.subplot(2, 1, 2)
plt.plot(df_without_bn['epoch'], df_without_bn['val_loss'], label='Without BN', color='blue')
plt.plot(df_with_bn['epoch'], df_with_bn['val_loss'], label='With BN', color='orange')
plt.xlabel('Epoch')
plt.ylabel('Validation Loss')
plt.title('Validation Loss Over Time')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()
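If you want to poke at what the normalization actually learned, the running statistics and affine parameters of every BatchNorm1d layer are ordinary tensors on the module. A small sketch, assuming model_with_bn from the MNIST example is still in scope:

# The first BatchNorm1d sits at index 1 of the Sequential:
# Linear -> BatchNorm1d -> ReLU -> ...
bn = model_with_bn.layer[1]

print(bn.running_mean[:5])  # running mean of the pre-activations
print(bn.running_var[:5])   # running variance of the pre-activations
print(bn.weight[:5])        # learned per-feature scale (gamma)
print(bn.bias[:5])          # learned per-feature shift (beta)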