Pytorch — одна из наиболее широко используемых библиотек для задач, связанных с машинным обучением или глубоким обучением. В любом приложении ML есть одна часть, которую нельзя игнорировать, как бы вы ни старались, и эта часть загружает данные. Загрузка пользовательского набора данных иногда становится слишком сложной, если вы не привыкли к различным функциям, которые предоставляет нам Pytorch.

Нам доступны два волшебных инструмента, которые облегчают всю задачу загрузки данных.

  1. Набор данных
  2. Загрузчик данных

Начнем с набора данных

torch.utils.data.Dataset — это основной класс, который нам нужно наследовать, если мы хотим загрузить пользовательский набор данных, который соответствует нашим требованиям. Множественные предварительно загруженные наборы данных гораздо проще загружать и использовать для обучения с использованием класса Dataset и Dataloader. Соответствующие руководства можно легко найти на Официальном сайте Pytorch (набор данных и загрузчик данных).

Здесь мы собираемся рассмотреть набор данных, который состоит из 2 частей.

  1. Изображений

2. Ярлык для изображений (в формате CSV)

Нам нужно написать собственный набор данных, прочитать CSV и соответствующие изображения и предоставить входные данные для класса Dataloader во время обучения/тестирования.

Наконец код. ура

Мы пишем класс CustomDataset, который наследует torch.utils.data.Dataset.

import torch
import pandas as pd
import skimage.io as sk
from torchvision import transforms
import matplotlib.pyplot as plt
from PIL import ImageTk, Image
# Need to override __init__, __len__, __getitem__
# as per datasets requirement
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, labelsFile, rootDir, sourceTransform):
        self.data = pd.read_csv(labelsFile)
        self.rootDir = rootDir
        self.sourceTransform = sourceTransform
        return

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        print("getitem")
        if torch.is_tensor(idx):
            idx = idx.tolist()
        imagePath = self.rootDir + "/" + self.data['Image_path'][idx]
        image = sk.imread(imagePath)
        label = self.data['Condition'][idx]
        image = Image.fromarray(image)

        if self.sourceTransform:
            image = self.sourceTransform(image)

        return image, label

__init__: этот перегруженный метод инициализирует параметры класса, читая CSV-файл в self. data и устанавливая self.root с корнем папки изображения. Также можно установить несколько transforms, которые можно использовать в соответствии с требованиями.

__getitem__: Это основной метод, который выполняет всю фактическую загрузку и преобразование изображения, прежде чем передать его загрузчику данных. Например, в приведенном выше коде мы загружаем изображение в image = sk.imread(imagePath) и применяем к нему преобразования image = self.sourceTransform(image).

Теперь очередь за Dataloader

Как только набор данных готов, Dataloader не требует какой-либо специальной обработки для работы. Он работает из коробки.

dataloader = torch.utils.data.DataLoader(transformed_dataset, batch_size=4,
                        shuffle=True, num_workers=0)

for i_batch, image in enumerate(dataloader):
    print(image[1])

batch_size: количество изображений в одном пакете.

shuffle: он будет выбирать случайным образом из набора данных

Существует множество других настроек, которые можно выполнить с помощью Dataloader. Он может быть найден здесь".

Здесь внутри for loop добавляется единственный оператор печати, но в реальном сценарии будет много обучающего кода, который будет проходить внутри этого цикла for, который будет обучать большое количество сложных моделей.

Теперь, когда мы отсортировали загрузку данных, мы готовы больше узнать об обучении различных моделей в Pytorch.