Pytorch — одна из наиболее широко используемых библиотек для задач, связанных с машинным обучением или глубоким обучением. В любом приложении ML есть одна часть, которую нельзя игнорировать, как бы вы ни старались, и эта часть загружает данные. Загрузка пользовательского набора данных иногда становится слишком сложной, если вы не привыкли к различным функциям, которые предоставляет нам Pytorch.
Нам доступны два волшебных инструмента, которые облегчают всю задачу загрузки данных.
- Набор данных
- Загрузчик данных
Начнем с набора данных
torch.utils.data.Dataset — это основной класс, который нам нужно наследовать, если мы хотим загрузить пользовательский набор данных, который соответствует нашим требованиям. Множественные предварительно загруженные наборы данных гораздо проще загружать и использовать для обучения с использованием класса Dataset и Dataloader. Соответствующие руководства можно легко найти на Официальном сайте Pytorch (набор данных и загрузчик данных).
Здесь мы собираемся рассмотреть набор данных, который состоит из 2 частей.
- Изображений
2. Ярлык для изображений (в формате CSV)
Нам нужно написать собственный набор данных, прочитать CSV и соответствующие изображения и предоставить входные данные для класса Dataloader во время обучения/тестирования.
Наконец код. ура
Мы пишем класс CustomDataset, который наследует torch.utils.data.Dataset.
import torch import pandas as pd import skimage.io as sk from torchvision import transforms import matplotlib.pyplot as plt from PIL import ImageTk, Image
# Need to override __init__, __len__, __getitem__ # as per datasets requirement class CustomDataset(torch.utils.data.Dataset): def __init__(self, labelsFile, rootDir, sourceTransform): self.data = pd.read_csv(labelsFile) self.rootDir = rootDir self.sourceTransform = sourceTransform return def __len__(self): return len(self.data) def __getitem__(self, idx): print("getitem") if torch.is_tensor(idx): idx = idx.tolist() imagePath = self.rootDir + "/" + self.data['Image_path'][idx] image = sk.imread(imagePath) label = self.data['Condition'][idx] image = Image.fromarray(image) if self.sourceTransform: image = self.sourceTransform(image) return image, label
__init__
: этот перегруженный метод инициализирует параметры класса, читая CSV-файл в self. data
и устанавливая self.root
с корнем папки изображения. Также можно установить несколько transforms
, которые можно использовать в соответствии с требованиями.
__getitem__
: Это основной метод, который выполняет всю фактическую загрузку и преобразование изображения, прежде чем передать его загрузчику данных. Например, в приведенном выше коде мы загружаем изображение в image = sk.imread(imagePath)
и применяем к нему преобразования image = self.sourceTransform(image)
.
Теперь очередь за Dataloader
Как только набор данных готов, Dataloader не требует какой-либо специальной обработки для работы. Он работает из коробки.
dataloader = torch.utils.data.DataLoader(transformed_dataset, batch_size=4,
shuffle=True, num_workers=0)
for i_batch, image in enumerate(dataloader):
print(image[1])
batch_size
: количество изображений в одном пакете.
shuffle
: он будет выбирать случайным образом из набора данных
Существует множество других настроек, которые можно выполнить с помощью Dataloader. Он может быть найден здесь".
Здесь внутри for loop
добавляется единственный оператор печати, но в реальном сценарии будет много обучающего кода, который будет проходить внутри этого цикла for, который будет обучать большое количество сложных моделей.
Теперь, когда мы отсортировали загрузку данных, мы готовы больше узнать об обучении различных моделей в Pytorch.