def conv(in_f, out_f, kernel_size=3, stride=1, actn=True, pad=None, bn=True):
    if pad is None: pad = kernel_size//2
    layers = [nn.Conv2d(in_f, out_f, kernel_size, stride, pad, bias=not bn)]
    if actn: layers.append(nn.ReLU(inplace=True))
    if bn: layers.append(nn.BatchNorm2d(out_f))
    return nn.Sequential(*layers)

Создайте блок Conv-ReLU-BN, который дополняет входной тензор таким образом, чтобы выходной тензор блока имел ту же форму (высоту/ширину), что и входной тензор.

class ResSequentialCenter(nn.Module):
    def __init__(self, layers):
        super().__init__()
        self.m = nn.Sequential(*layers)
    def forward(self, x):
        return x[:, :, 2:-2, 2:-2] + self.m(x)

Модуль Residual Sequential Center принимает на вход массив слоев (блоки conv-relu-bn) и создает последовательный блок, используя слои.

В прямой функции модуль возвращает выходной тензор входного тензора (без заполнения) + вывод последовательного блока с входным тензором.

def res_block(num_f):
    return ResSequentialCenter([conv(num_f, num_f, pad=0), conv(num_f, num_f, pad=0)])

Остаточный блок создает модуль ResSequentialCenter с 2 блоками conv-relu-bn.

def upsample(in_f, out_f):
    return nn.Sequential(nn.Upsample(scale_factor=2), conv(in_f, out_f))

Блок Upsample создает модуль Sequential слоя Upsample с блоком conv-relu-bn.

class StyleResnet(nn.Module):
    def __init__(self):
        super().__init__()
        layers = [nn.ReflectionPad2d(40), nn.Conv2d(3, 32, 9), 
                  nn.Conv2d(32, 64, 3, stride=2), nn.Conv2d(64, 128, 3, stride=2)]
        for i in range(5): layers.append(res_block(128))
        layers += [upsample(128, 64), upsample(64, 32), conv(32, 3, 9, actn=False)]
        self.features = nn.Sequential(*layers)
    
    def forward(self, x): return self.features(x)

Модуль Style Resnet инициализирует последовательный блок, который начинается со слоя заполнения Reflection 40, 32-мерного слоя Conv с размером ядра 9 и 2 слоев Conv с размером ядра 3, оба с шагом 2 (уменьшает форму тензора на 2 каждый раз ).

Затем добавляются 5 128-дневных остаточных блоков, заканчивающихся 2 блоками повышающей дискретизации и конвекторным блоком размера ядра 9.

Функция Forward возвращает выходной тензор блока Sequential из входного тензора.

class SaveFeatures():
    features=None
    def __init__(self, m): self.features = m.register_forward_hook(self.hook_fn)
    def hook_fn(self, module, input, output): self.features = output
    def close(self): self.hook.remove()

Объект SaveFeatures сохраняет выходные данные входного слоя для использования в функции потерь.

def ct_loss(input, targ): return F.mse_loss(input, targ)

Потеря содержимого возвращает среднеквадратичную ошибку ввода и цели.

def gram(x):
    b,c,h,w = x.size()
    x = x.view(b, c, -1)
    return torch.bmm(x, x.transpose(1, 2))/(c*h*w)*1e6

Функция Gram возвращает матрицу Gram входных данных (умножение матрицы тензора на себя с последующим делением количества элементов для нахождения среднего значения)

def gram_loss(input, targ):
    return F.mse_loss(gram(input), gram(targ[:input.size(0)]))

Потеря грамма возвращает среднеквадратичную ошибку потери матрицы Грама входного тензора и матрицы Грама цели.

class CombinedLoss(nn.Module):
    def __init__(self, m, layer_ids, style_im, ct_weight, style_weights):
        self.m,self.ct_weight,self.style_weights=m,ct_weight,style_weights
        self.sfs = SaveFeatures(self.m[i] for i in layer_ids)
        m(VV(style_im))
        self.style_feats = [V(sf.features.data.clone()) for sf in self.sfs]

Параметры комбинированных потерь: m (модель для определения потери стиля/контента), layer_ids (идентификаторы слоев стилей), style_im (целевое изображение для получения стиля), ct_weight (вес потери контента), style_weights (веса потерь каждого слоя)

Сохраните функции стиля изображения целевого стиля, которые будут использоваться при потере стиля.

def forward(self, input, targ):
        self.m(VV(targ))
        targ_feat = self.sfs[2].features.data.clone()
        self.m(input)
        inp_feats = [sf.features for sf in self.sfs]
        ct_loss = [ct_loss(inp_feats[2], V(targ_feat))*self.ct_weight]
        style_loss = [gram_loss(inp, sty)*weight for inp, sty, weight in zip(inp_feats, self.style_feats, self.style_weights)]
        loss = sum(ct_loss + style_loss)
        return loss

В прямой функции получите признаки из изображения содержимого (признаки из 2-го слоя модели для потерь) и получите признаки из входного тензора (признаки от каждого объекта SaveFeatures модели для потерь)

ct_loss = [ct_loss(inp_feats[2], V(targ_feat))*self.ct_weight]

Потеря контента рассчитывается с использованием функции ct_loss со вторым входным параметром и функциями изображения контента в качестве входных данных, умноженных на вес.

[gram_loss(inp, sty)*weight for inp, sty, weight in zip(inp_feats, self.style_feats, self.style_weights)]
        loss = sum(ct_loss + style_loss)

Потеря стиля рассчитывается путем суммирования потерь в граммах каждого входного элемента и элемента стиля, умноженных на вес:

РЕЗЮМЕ

Модель вынуждена создавать изображение, которое будет давать такие же функции контента, как и обучающее изображение, и функции стиля, что и целевое изображение стиля, в результате чего получается модель, которая может брать изображение из обучающего набора и создавать изображение с аналогичным контентом. тренировочное изображение, но со стилем изображения целевого стиля