Unit 4.3 - Data loaders
PyTorch data loaders (torch.utils.data.DataLoader) are a convenient tool for loading data in batches and shuffling it, which is exactly what we need when training neural networks.
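As a minimal illustration of batching and shuffling, the sketch below uses a small in-memory TensorDataset instead of images (an assumption made only to keep it self-contained). With 10 samples and batch_size=4, one pass over the loader (one epoch) yields batches of 4, 4 and 2, and every sample appears exactly once in a shuffled order.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 10 toy samples with 2 features each, labelled 0..9
features = torch.arange(20, dtype=torch.float32).reshape(10, 2)
toy_labels = torch.arange(10)
dataset = TensorDataset(features, toy_labels)

# shuffle=True draws a new random sample order at the start of every epoch
loader = DataLoader(dataset, batch_size=4, shuffle=True)

batch_sizes = []
seen = []
for _, y in loader:
    batch_sizes.append(len(y))
    seen.extend(y.tolist())

print(batch_sizes)   # [4, 4, 2]: the last batch is smaller unless drop_last=True
print(sorted(seen))  # each sample index appears exactly once per epoch
```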
Much of a data scientist's time is spent on data preparation; by some estimates as much as 90-95% of it. It is therefore important to have a good data loader that handles the data in a way that is easy to use and efficient.
Here we build a data loader for our own folder of images, using torchvision's ImageFolder dataset together with PyTorch's DataLoader.
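For reference, a folder-of-images dataset can also be written by hand by subclassing torch.utils.data.Dataset and implementing `__len__` and `__getitem__`. The class below, `SimpleImageFolder`, is a hypothetical and simplified stand-in for what ImageFolder does (it is not ImageFolder's actual code). It assumes one subfolder per class and builds a tiny throwaway folder so that it runs anywhere.

```python
import os
import tempfile

from PIL import Image
from torch.utils.data import Dataset


class SimpleImageFolder(Dataset):
    """Hypothetical minimal dataset for the layout root/<class_name>/<image file>."""

    def __init__(self, root, transform=None):
        # class names are the sorted subfolder names; labels are their positions
        self.classes = sorted(d.name for d in os.scandir(root) if d.is_dir())
        self.class_to_idx = {name: i for i, name in enumerate(self.classes)}
        self.samples = [
            (os.path.join(root, name, fname), self.class_to_idx[name])
            for name in self.classes
            for fname in sorted(os.listdir(os.path.join(root, name)))
        ]
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        with Image.open(path) as img:
            img = img.convert("RGB")  # force 3 channels; returns an in-memory copy
        if self.transform is not None:
            img = self.transform(img)
        return img, label


# Throwaway folder with two classes and one 40x30 image each
root = tempfile.mkdtemp()
for name, color in [("cat1", (255, 0, 0)), ("cat2", (0, 0, 255))]:
    os.makedirs(os.path.join(root, name))
    Image.new("RGB", (40, 30), color).save(os.path.join(root, name, "a.png"))

ds = SimpleImageFolder(root)
print(len(ds), ds.classes)  # 2 ['cat1', 'cat2']
img, label = ds[1]
print(img.size, label)      # (40, 30) 1
```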
[12]:
# imports:
import torch
import torchvision
import torchvision.transforms as transforms
Transforms in PyTorch are used to preprocess data. Here we use transforms.Compose to chain a series of transforms that are applied in order: transforms.Resize resizes each image to a fixed size, transforms.ToTensor converts it to a tensor with values in [0, 1], and transforms.Normalize normalizes each input channel with the given mean and std, computing (x - mean) / std. Many other transforms are available in PyTorch, such as transforms.CenterCrop, which crops the central region of the image to a fixed size.
[15]:
transform = transforms.Compose([
    transforms.Resize((32, 32)),  # resize every image to 32x32 pixels
    transforms.ToTensor(),        # PIL image -> float tensor [C, H, W] in [0, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # (x - 0.5) / 0.5 -> [-1, 1]; imshow below undoes this
])
batch_size = 3
trainset = torchvision.datasets.ImageFolder(root='./data/my_data/train',
                                            transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)
classes = ('cat1', 'cat2', 'cat3')
Here we plot a few images from the dataset to check that the data loader is working correctly. Always look at what the neural network actually receives as input.
[16]:
import matplotlib.pyplot as plt
import numpy as np
# function to show an image
def imshow(img):
    img = img / 2 + 0.5  # un-normalize: inverts Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # [C, H, W] -> [H, W, C] for matplotlib
    plt.show()
# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)
# print(images[0].shape, images[0].mean(), images[0].std())
# show images
imshow(torchvision.utils.make_grid(images))
# print labels (len(labels) also works for a smaller final batch)
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(len(labels))))
cat1 cat3 cat1
Issues
Sometimes, especially with images, the input can be very large. In that case we can use transforms.CenterCrop to crop the central region of the image to a fixed size, or transforms.RandomCrop to crop a fixed-size region at a random position. Cropping reduces the input size and speeds up training.
[ ]:
# same loader as the first example, but with RandomCrop instead of Resize
transform = transforms.Compose([
    # transforms.Resize((32,32)),
    transforms.RandomCrop(32, padding=4),  # pad 4 px on each side, then take a random 32x32 patch
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
batch_size = 3
trainset = torchvision.datasets.ImageFolder(root='./data/my_data/train',
                                            transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)
classes = ('cat1', 'cat2', 'cat3')
# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)
# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(len(labels))))
cat2 cat3 cat1
As the output shows, this may or may not work: a random 32x32 crop of a large photo can easily miss the content of interest, such as the cat itself.
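Next we try to load all images at their original size. This fails, because the DataLoader's default collate function combines the images of a batch into one tensor with torch.stack, which requires every image to have the same size.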
[11]:
# try to load the images at their original size (expected to fail)
transform = transforms.Compose([
    # transforms.Resize((32,32)),
    # transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
batch_size = 3
trainset = torchvision.datasets.ImageFolder(root='./data/my_data/train',
                                            transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)
classes = ('cat1', 'cat2', 'cat3')
# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)
# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(len(labels))))
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[11], line 19
17 # get some random training images
18 dataiter = iter(trainloader)
---> 19 images, labels = next(dataiter)
21 # show images
22 imshow(torchvision.utils.make_grid(images))
File /opt/homebrew/lib/python3.12/site-packages/torch/utils/data/dataloader.py:630, in _BaseDataLoaderIter.__next__(self)
627 if self._sampler_iter is None:
628 # TODO(https://github.com/pytorch/pytorch/issues/76750)
629 self._reset() # type: ignore[call-arg]
--> 630 data = self._next_data()
631 self._num_yielded += 1
632 if self._dataset_kind == _DatasetKind.Iterable and \
633 self._IterableDataset_len_called is not None and \
634 self._num_yielded > self._IterableDataset_len_called:
File /opt/homebrew/lib/python3.12/site-packages/torch/utils/data/dataloader.py:1344, in _MultiProcessingDataLoaderIter._next_data(self)
1342 else:
1343 del self._task_info[idx]
-> 1344 return self._process_data(data)
File /opt/homebrew/lib/python3.12/site-packages/torch/utils/data/dataloader.py:1370, in _MultiProcessingDataLoaderIter._process_data(self, data)
1368 self._try_put_index()
1369 if isinstance(data, ExceptionWrapper):
-> 1370 data.reraise()
1371 return data
File /opt/homebrew/lib/python3.12/site-packages/torch/_utils.py:706, in ExceptionWrapper.reraise(self)
702 except TypeError:
703 # If the exception takes multiple arguments, don't try to
704 # instantiate since we don't know how to
705 raise RuntimeError(msg) from None
--> 706 raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/opt/homebrew/lib/python3.12/site-packages/torch/utils/data/_utils/worker.py", line 309, in _worker_loop
data = fetcher.fetch(index) # type: ignore[possibly-undefined]
^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
return self.collate_fn(data)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/lib/python3.12/site-packages/torch/utils/data/_utils/collate.py", line 317, in default_collate
return collate(batch, collate_fn_map=default_collate_fn_map)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/lib/python3.12/site-packages/torch/utils/data/_utils/collate.py", line 174, in collate
return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed] # Backwards compatibility.
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/lib/python3.12/site-packages/torch/utils/data/_utils/collate.py", line 142, in collate
return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/lib/python3.12/site-packages/torch/utils/data/_utils/collate.py", line 214, in collate_tensor_fn
return torch.stack(batch, 0, out=out)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: stack expects each tensor to be equal size, but got [3, 154, 220] at entry 0 and [3, 132, 126] at entry 1
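The last line of the traceback shows the cause: default_collate calls torch.stack on images of shapes [3, 154, 220] and [3, 132, 126]. If variable-size images are genuinely needed in one batch, one workaround (not used in this notebook, and only a sketch) is a custom collate_fn that keeps the images as a list. `VariableSizeImages` and `list_collate` are hypothetical names; the dataset returns random tensors using the two sizes from the error message plus one more.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class VariableSizeImages(Dataset):
    """Hypothetical dataset of random 'images' with different heights and widths."""

    def __init__(self):
        self.sizes = [(154, 220), (132, 126), (64, 80)]

    def __len__(self):
        return len(self.sizes)

    def __getitem__(self, idx):
        h, w = self.sizes[idx]
        return torch.rand(3, h, w), idx  # (image, label)


def list_collate(batch):
    # batch is a list of (image, label) pairs: keep images as a list, stack only labels
    images = [img for img, _ in batch]
    labels = torch.tensor([label for _, label in batch])
    return images, labels


var_loader = DataLoader(VariableSizeImages(), batch_size=3, collate_fn=list_collate)
batch_images, batch_labels = next(iter(var_loader))
print([tuple(t.shape) for t in batch_images])  # [(3, 154, 220), (3, 132, 126), (3, 64, 80)]
print(batch_labels)                            # tensor([0, 1, 2])
```

Note that a model with fixed-size inputs still cannot consume such a list directly, which is why the notebook resizes or crops instead.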
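Neither random crops nor original-size images work here, so we choose the crop region ourselves. For example, if we wanted to identify animal species from their eyes, we would need either high-resolution input images or images cropped to the eye region. Here we crop the images manually, store the crops in ./data/my_data_crops/train, and resize them to 32x32 in the loader.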
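The notebook does not show how the cropped folder was produced. One way to do it, sketched here with a hypothetical `save_crops` helper and a made-up crop box, is to crop each image with PIL and save it under the same class subfolder, so that ImageFolder can read the result without changes.

```python
import os
import tempfile

from PIL import Image


def save_crops(src_root, dst_root, boxes):
    """Crop images and mirror the <class>/<file> layout under dst_root.

    boxes maps "class_name/file_name" -> (left, upper, right, lower) in pixels.
    """
    for rel_path, box in boxes.items():
        dst = os.path.join(dst_root, rel_path)
        os.makedirs(os.path.dirname(dst), exist_ok=True)
        with Image.open(os.path.join(src_root, rel_path)) as img:
            img.crop(box).save(dst)


# Demo on a throwaway folder: one 220x154 image in class "cat1"
src, dst = tempfile.mkdtemp(), tempfile.mkdtemp()
os.makedirs(os.path.join(src, "cat1"))
Image.new("RGB", (220, 154), (128, 128, 128)).save(os.path.join(src, "cat1", "img0.png"))

save_crops(src, dst, {"cat1/img0.png": (60, 40, 124, 104)})  # a made-up 64x64 region
with Image.open(os.path.join(dst, "cat1", "img0.png")) as out:
    print(out.size)  # (64, 64)
```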
[ ]:
# load the manually cropped data:
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
batch_size = 3
trainset = torchvision.datasets.ImageFolder(root='./data/my_data_crops/train',
                                            transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)
classes = ('cat1', 'cat2', 'cat3')
# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)
# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(len(labels))))
Notes
A lot can go wrong at this stage alone, while preparing data for neural network training. Always visualize the data and check each step to make sure it is being loaded correctly; this will save you a lot of debugging time later.
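Alongside looking at images, a few numbers collected over the loader catch many loading bugs, such as a wrong shape, a wrong value range after Normalize, or labels that all come from one class. `inspect_loader` below is a hypothetical helper, demonstrated on random data; with the notebook's loader you would pass trainloader instead.

```python
from collections import Counter

import torch
from torch.utils.data import DataLoader, TensorDataset


def inspect_loader(loader):
    """Collect per-sample shapes, the overall value range and label counts."""
    shapes, counts = set(), Counter()
    lo, hi = float("inf"), float("-inf")
    for images, labels in loader:
        shapes.add(tuple(images.shape[1:]))  # drop the batch dimension
        lo = min(lo, images.min().item())
        hi = max(hi, images.max().item())
        counts.update(labels.tolist())
    return {"sample_shapes": shapes, "min": lo, "max": hi, "label_counts": dict(counts)}


# Demo: 9 random "images" already scaled to [-1, 1], three per class
demo_images = torch.rand(9, 3, 32, 32) * 2 - 1
demo_labels = torch.tensor([0, 1, 2] * 3)
demo_loader = DataLoader(TensorDataset(demo_images, demo_labels), batch_size=3, shuffle=True)

report = inspect_loader(demo_loader)
print(report["sample_shapes"])  # {(3, 32, 32)}
print(report["label_counts"])   # three samples per label (key order may vary)
```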