{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "# unit 4.2 - Training a CNN on CIFAR\n", "\n", "[Open in Colab](https://githubtocolab.com/culurciello/deep-learning-course-source/blob/main/source/lectures/42-conv-cifar10-tutorial.ipynb)\n", "\n", "This is it. You have seen how to define neural networks, compute the loss, and update\n", "the weights of the network.\n", "\n", "Now you might be thinking,\n", "\n", "## What about data?\n", "\n", "Generally, when you have to deal with image, text, audio or video data,\n", "you can use standard Python packages that load the data into a NumPy array.\n", "Then you can convert this array into a ``torch.*Tensor``.\n", "\n", "- For images, packages such as Pillow and OpenCV are useful\n", "- For audio, packages such as SciPy and librosa are useful\n", "- For text, either raw Python or Cython-based loading, or NLTK and\n", " spaCy are useful\n", "\n", "Specifically for vision, there is a package called\n", "``torchvision`` that provides data loaders for common datasets such as\n", "ImageNet, CIFAR10 and MNIST (``torchvision.datasets``, used together with\n", "``torch.utils.data.DataLoader``) and image transforms (``torchvision.transforms``).\n", "\n", "This is a huge convenience and saves you from writing boilerplate code.\n", "\n", "For this tutorial, we will use the CIFAR10 dataset.\n", "It has ten classes: ‘airplane’, ‘automobile’, ‘bird’, ‘cat’, ‘deer’,\n", "‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’. The images in CIFAR10 have\n", "size 3x32x32, i.e. they are 3-channel color images of 32x32 pixels.\n", "\n", "![cifar10](/_static/img/cifar10.png)\n", "\n", "\n", "## Training an image classification convolutional neural network\n", "\n", "We will do the following steps in order:\n", "\n", "1. Load and normalize the CIFAR10 training and test datasets using\n", " ``torchvision``\n", "2. Define a Convolutional Neural Network\n", "3. Define a loss function\n", "4. Train the network on the training data\n", "5. Test the network on the test data\n", "\n", "### 1. Load and normalize CIFAR10\n", "\n", "Using ``torchvision``, it’s extremely easy to load CIFAR10.\n" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torchvision\n", "import torchvision.transforms as transforms" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The torchvision datasets return PIL images; after ``transforms.ToTensor()`` their pixel values lie in the range [0, 1].\n", "We then normalize these tensors to the range [-1, 1]. A per-channel mean and standard deviation of 0.5 achieves this,\n", "since (x - 0.5) / 0.5 maps 0 to -1 and 1 to 1.\n", "\n" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "**Note:** If running on Windows and you get a BrokenPipeError, try setting\n", " the ``num_workers`` argument of ``torch.utils.data.DataLoader()`` to 0.