🔥 Custom Datasets and Transforms
Custom Dataset
To use a custom dataset, we need to
- inherit from torch.utils.data.Dataset, an abstract class representing a dataset;
- override __len__ so that len(dataset) returns the size of the dataset;
- override __getitem__ to support indexing, so that dataset[i] returns the i-th sample.
The skeleton is as follows:
from torch.utils.data import Dataset

class MyCustomDataset(Dataset):
    def __init__(self, *args, **kwargs):
        # initialization logic, e.g.
        #   - read a csv file
        #   - assign data transformations
        #   - ...
        pass

    def __getitem__(self, index):
        """Get the index-th sample."""
        # Note: the return value can be customized depending on the application
        img, label = ..., ...  # placeholder: load the sample at `index` here
        return img, label

    def __len__(self):
        count = ...  # placeholder: number of examples (e.g. images) you have
        return count
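To make the skeleton concrete, here is a minimal sketch of a dataset that can actually be instantiated. The class name SquaresDataset and its contents are hypothetical and chosen only for illustration; the point is that len(dataset) and dataset[i] work once the two methods are overridden.

```python
import torch
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the pair (i, i**2)."""

    def __init__(self, size):
        self.size = size

    def __getitem__(self, index):
        # Raise IndexError for out-of-range indices, like a normal sequence
        if not 0 <= index < self.size:
            raise IndexError(index)
        x = torch.tensor(float(index))
        return x, x ** 2

    def __len__(self):
        return self.size


dataset = SquaresDataset(5)
n = len(dataset)   # size of the dataset, via __len__
x, y = dataset[3]  # the sample at index 3, via __getitem__
```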
Example
Let’s take the MNIST dataset as an example. Assume we have a csv file located at CSV_PATH. The csv file is structured as follows:
One instance/sample per line
- The first column is the digit label (0 - 9)
- The remaining 784 columns hold the pixel values of the 28x28 image ($28 \times 28 = 784$)
- I.e. each sample consists of an image of a digit and the label of that digit
There are 5000 lines in total, i.e. 5000 samples
- We want to use the first 4000 samples for training and validation,
- and the remaining 1000 samples for testing.
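The layout and split described above can be sketched with plain NumPy slicing. The array below is randomly generated as a stand-in for the real csv contents (an assumption for illustration only); the column and row indexing is what matters.

```python
import numpy as np

# Synthetic stand-in for the csv contents: 5000 rows, 1 label + 784 pixels each
rng = np.random.default_rng(0)
labels = rng.integers(0, 10, size=(5000, 1))
pixels = rng.integers(0, 256, size=(5000, 784))
all_data = np.hstack([labels, pixels]).astype(np.uint8)

# First 4000 rows for training and validation, remaining 1000 rows for testing
train_val, test = all_data[:4000], all_data[4000:]

# Column 0 is the label; columns 1..784 are the flattened 28x28 image
first_label = train_val[0, 0]
first_image = train_val[0, 1:].reshape(28, 28)
```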
Let’s implement our custom MNIST dataset:
import numpy as np
import torch
from torch.utils.data import Dataset

class MyMNIST(Dataset):
    TRAIN, VALID, TEST = 0, 1, 2

    def __init__(self, csv_file, usage=TRAIN, transform=None, label_transform=None):
        """
        Args:
            csv_file (string): Path to the csv file.
            usage (int): Usage of the dataset (TRAIN, VALID or TEST).
            transform (callable, optional): Optional transform to be applied to the image.
            label_transform (callable, optional): Optional transform to be applied to the label.
        """
        self.transform = transform  # image preprocessing
        self.label_transform = label_transform  # label preprocessing

        # load from csv file
        all_data = np.genfromtxt(csv_file, delimiter=',', dtype='uint8')

        # 5000 lines in csv file --> 5000 instances
        #   training set:   lines 0 - 2999
        #   validation set: lines 3000 - 3999
        #   test set:       lines 4000 - 4999
        train, test = all_data[:4000], all_data[4000:]
        train, val = train[:3000], train[3000:]

        # choose lines based on the specified usage
        if usage == self.TRAIN:
            self.images = train[:, 1:]
            self.labels = train[:, 0]  # first column is the label of the digit
        elif usage == self.VALID:
            self.images = val[:, 1:]
            self.labels = val[:, 0]
        else:
            self.images = test[:, 1:]
            self.labels = test[:, 0]

    def __getitem__(self, index):
        image, label = self.images[index], self.labels[index]
        if self.transform is not None:
            image = self.transform(image)
        if self.label_transform is not None:
            label = self.label_transform(label)
        # convert label to a Tensor of dtype long
        label = torch.as_tensor(label, dtype=torch.long)
        return image, label

    def __len__(self):
        return len(self.labels)
Use our custom MNIST dataset:
from torchvision import transforms

# Convert to a Tensor and normalize before using the dataset.
# Each sample is a flat row of 784 values, while ToTensor expects a 2-D or 3-D
# array, so reshape the row to 28x28 first.
preprocess_transform = transforms.Compose([
    transforms.Lambda(lambda row: row.reshape(28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.1,), (0.4,)),
])

# let's say we use the dataset for testing
my_mnist = MyMNIST(csv_file=CSV_PATH,
                   usage=MyMNIST.TEST,
                   transform=preprocess_transform)
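Once a dataset implements __len__ and __getitem__, it can be wrapped in torch.utils.data.DataLoader to iterate over it in batches. The sketch below uses a TensorDataset filled with random values as a stand-in for the test split (an assumption, so that the snippet runs without the csv file); the batch size of 64 is likewise arbitrary.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for the test split: 1000 random "images" of shape (1, 28, 28)
images = torch.rand(1000, 1, 28, 28)
labels = torch.randint(0, 10, (1000,))
stand_in = TensorDataset(images, labels)

# The DataLoader calls __getitem__ repeatedly and stacks samples into batches
loader = DataLoader(stand_in, batch_size=64, shuffle=False)
batch_images, batch_labels = next(iter(loader))
num_batches = len(loader)  # the last batch is smaller than 64
```

Any Dataset subclass, such as the custom MNIST dataset above, could be passed in place of the stand-in.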
Custom transform and augmentation
The example code above makes use of the transforms provided by torchvision.transforms. We can also implement custom transforms ourselves.
To do this, we write them as callable classes:
- inherit from the object class (implicit in Python 3, so this can be omitted)
- implement __init__ if needed
- define the desired transformation in the __call__(self, image) method
Example
For example, let’s implement two custom transforms:
import torch

class MyNormalizer(object):
    """Normalize image"""

    def __call__(self, image):
        """
        Only works for our custom MNIST dataset: divide the pixel values by 255.
        Generally, normalization should work as follows:
            data_normalized = (data - data.mean()) / data.std()
        """
        image = image * 1.0 / 255
        return image

class MyToTensor(object):
    """Convert image (a numpy array) to a PyTorch Tensor"""

    def __call__(self, image):
        image = torch.from_numpy(image).float()
        return image
Use the custom transforms in our custom MNIST dataset:
preprocess_transform = transforms.Compose([MyToTensor(),
                                           MyNormalizer()])
my_mnist = MyMNIST(csv_file=CSV_PATH,
                   usage=MyMNIST.TEST,
                   transform=preprocess_transform)
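Neither of the two transforms above needs __init__. As a sketch of a transform that does, the hypothetical Standardize class below stores a mean and a standard deviation at construction time and applies the general formula from the MyNormalizer docstring, (data - mean) / std. The class name and the synthetic input are assumptions made for illustration.

```python
import numpy as np
import torch


class Standardize(object):
    """Hypothetical transform: (image - mean) / std with fixed statistics."""

    def __init__(self, mean, std):
        # Parameters are stored once and reused on every call
        self.mean = mean
        self.std = std

    def __call__(self, image):
        return (image - self.mean) / self.std


# One fake flattened image with pixel values already scaled to [0, 1]
image = torch.from_numpy(np.linspace(0.0, 1.0, 784, dtype=np.float32))

# Here the statistics come from the image itself; in practice they would
# typically be computed once over the training set
standardize = Standardize(mean=image.mean().item(), std=image.std().item())
result = standardize(image)
```

An instance of such a class could be placed in transforms.Compose after MyToTensor(), in the same way as MyNormalizer().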