{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Unit 1.7 - Professional training script\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/culurciello/deep-learning-course-source/blob/main/source/lectures/17-full-train_script_pytorch.ipynb)\n", "\n", "Here we will look at a professional training script in PyTorch.\n", "\n", "You can use this as a REFERENCE for all your projects!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "DATASET:\n", "\n", "This unit uses Fashion MNIST, a small dataset commonly used for learning neural networks." ] }, { "cell_type": "code", "execution_count": 6, "id": "b4984da9", "metadata": {}, "outputs": [], "source": [ "# PyTorch train script\n", "# https://pytorch.org/tutorials/beginner/basics/optimization_tutorial.html\n", "\n", "# We will use the Fashion MNIST dataset: https://www.kaggle.com/datasets/zalando-research/fashionmnist\n", "\n", "import torch\n", "from torch import nn\n", "import torch.optim as optim\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets\n", "from torchvision.transforms import ToTensor\n", "\n", "training_data = datasets.FashionMNIST(\n", "    root=\"data\",\n", "    train=True,\n", "    download=True,\n", "    transform=ToTensor()\n", ")\n", "\n", "test_data = datasets.FashionMNIST(\n", "    root=\"data\",\n", "    train=False,\n", "    download=True,\n", "    transform=ToTensor()\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## DATA:\n", "\n", "If we want to categorize images into classes, we will need a list of possible classes. The number of output neurons will equal the number of classes.\n", "\n", "Before doing anything else, it is always a good idea to take a look at the data in the dataset!" 
] }, { "cell_type": "code", "execution_count": 7, "id": "f1a81a12", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Inputs sample - image size: torch.Size([1, 28, 28])\n", "Label: 1 \n", "\n", "Inputs sample - min,max,mean,std: 0.0 1.0 0.28586435317993164 0.392581582069397\n", "Inputs sample normalized - min,max,mean,std: -0.7281654477119446 1.8190758228302002 1.2468318821845514e-08 1.0\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAeoElEQVR4nO3df2yV9d3/8dc5h/ZAoZxSSn9JYQV/sAl09xh0RGU4GqB+bwNKFn99EzAGoitmyJymi4q6Jd0wcUbD8J8N5vcr/rojEM3ComhL3IAJSpB7rqHc3Sg3tMy6/qDQUno+3z/4enYfKbLPxel5t+X5SK6EnnO9e7399Kqvc/VcfTfknHMCACDNwtYNAACuTAQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATIywbuDL4vG4jh8/ruzsbIVCIet2AACenHPq7OxUcXGxwuGLX+cMugA6fvy4SkpKrNsAAFympqYmTZw48aLPD7oAys7OliTdqFs0QhnG3eCrhLPHeNfEO09514T+7eveNe7AX7xrzhcymSqw2df713z4n6nvA+bOqVcf6HeJ/59fzIAF0IYNG/TMM8+oublZZWVleuGFFzRnzpxL1n3xY7cRytCIEAE0mIVDmd418QBf01Ak6l3jAp87BFBgI0b61/A9Pjz9/2+jS72NMiA3Ibz22mtau3at1q1bp48++khlZWVatGiRTp48ORCHAwAMQQMSQM8++6xWrlype++9V9/4xjf04osvKisrS7/5zW8G4nAAgCEo5QF09uxZ7d+/XxUVFf88SDisiooK7d69+4L9e3p61NHRkbQBAIa/lAfQZ599pr6+PhUUFCQ9XlBQoObm5gv2r6mpUSwWS2zcAQcAVwbzX0Strq5We3t7YmtqarJuCQCQBim/Cy4vL0+RSEQtLS1Jj7e0tKiwsPCC/aPRqKJR/7ucAABDW8qvgDIzMzVr1izt3Lkz8Vg8HtfOnTs1d+7cVB8OADBEDcjvAa1du1bLly/Xt7/9bc2ZM0fPPfecurq6dO+99w7E4QAAQ9CABNAdd9yhv//973riiSfU3Nysb37zm9qxY8cFNyYAAK5cIecG1+yRjo4OxWIxzdcSJiFAkjRxj//In++N+zTQsRp7JnjX/FvWX71ruuL+73t2x/2/H/adKvWukaRzLuJdc8PYw941/+d/3+Jd4z78xLsG6XXO9apW29Xe3q6xY8dedD/zu+AAAFcmAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgZkGjaQSpNGfe5dMzrcE+hYEzP9j9XWN9q7Jkh/mZE+75qK2H9610hSW1+Wd831mce9a7om+h8n60PvEgxSXAEBAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEA
DABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwwDRtpFRk3zrtmYuan3jUnz431rpGCTakOMjm6s2+kd017gOMUZfzDu0aSPu8bE6jO16mrIt41/quAwYorIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYYRoq0cpOLvGvGj/jQu+a/e/2HnkpSRuicd83peKZ3TSxyxrsmHIp714wM93rXSJL6gpX56rrKpedAGJS4AgIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCYaRIq66vjfGuyQmf9q75r3i+d40k9Tr/b4n2c1neNcUZbd41n/Vme9fEXbDXmBkh/2mkbfFR3jWhKV3eNRg+uAICAJgggAAAJlIeQE8++aRCoVDSNm3atFQfBgAwxA3Ie0DXX3+93n333X8eZARvNQEAkg1IMowYMUKFhYUD8akBAMPEgLwHdPjwYRUXF2vKlCm65557dPTo0Yvu29PTo46OjqQNADD8pTyAysvLtXnzZu3YsUMbN25UY2OjbrrpJnV2dva7f01NjWKxWGIrKSlJdUsAgEEo5QFUWVmp73//+5o5c6YWLVqk3/3ud2pra9Prr7/e7/7V1dVqb29PbE1NTaluCQAwCA343QE5OTm69tpr1dDQ0O/z0WhU0Wh0oNsAAAwyA/57QKdOndKRI0dUVFQ00IcCAAwhKQ+ghx9+WHV1dfrrX/+qP/7xj7rtttsUiUR01113pfpQAIAhLOU/gjt27Jjuuusutba2asKECbrxxhu1Z88eTZgwIdWHAgAMYSkPoFdffTXVnxLDSHdOxLumy2V61/Qp5F2TTleN+Id3TUN3gXdNXxqnbbXF/YeyFo3j1y6uZMyCAwCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYGLA/yAd8D915/oPCe1z6Xud1Ov8h6VmhPq8a8KhuHdNT9z/2zXIf48kReTf39kAx8rP6vSuafeuwGDFFRAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwATTsJFW3XnOuyYSYHJ0UG19Wd412ZFu75pM+U/QDqI7npGW4wQ9Vn70lHcN07CHD66AAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmGAYKdLqXMx/sGif83+d1BsPdmpHwv7DUidmtnrXfHhmindNNHzOu6bXBVuHjJD/sU7Ho9414zP9h5EeVvoGrGJgcQUEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABMNIkV6j/YdcDnbXZzZ71zz1yb971yyZ8ol3Ta+LeNdIUjjkPzQ2HmBobH5Gh3eNND5ADQYjroAAACYIIACACe8A2rVrl2699VYVFxcrFApp27ZtSc875/TEE0+oqKhIo0aNUkVFhQ4fPpyqfgEAw4R3AHV1damsrEwbNmzo9/n169fr+eef14svvqi9e/dq9OjRWrRokbq7uy+7WQDA8OF9E0JlZaUqKyv7fc45p+eee06PPfaYlixZIkl66aWXVFBQoG3btunOO++8vG4BAMNGSt8DamxsVHNzsyoqKhKPxWIxlZeXa/fu3f3W9PT0qKOjI2kDAAx/KQ2g5ubzt6MWFBQkPV5QUJB47stqamoUi8USW0lJSSpbAgAMUuZ3wVVXV6u9vT2xNTU1WbcEAEiDlAZQYWGhJKmlpSXp8ZaWlsRzXxaNRjV27NikDQAw/KU0gEpLS1VYWKidO3cmHuvo6NDevXs1d+7cVB4KADDEed8Fd+rUKTU0NCQ+bmxs1IEDB5
Sbm6tJkyZpzZo1+tnPfqZrrrlGpaWlevzxx1VcXKylS5emsm8AwBDnHUD79u3TzTffnPh47dq1kqTly5dr8+bNeuSRR9TV1aVVq1apra1NN954o3bs2KGRI0emrmsAwJDnHUDz58+Xc+6iz4dCIT399NN6+umnL6sxDE+ZWWe9a+IBflIcCTBMU5I6+/xfKE0eEfKuCe+J+ddMvfj33cVEw73eNVKwwaJBBp9OzviHdw3DSIcP87vgAABXJgIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACe9p2MDlyM7q8a45G2DKcnc8w7smqHCA13H5+/3Xoecu/2/X7PAZ7xpJ+rxvjHdNn/ynggftD8MDV0AAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMMIwUaTU+q8u6ha90Op6ZluNEP2rwruk4N9K7Jn9Ep3eNJLWciwWq85UTOZ2W42Bw4goIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACYaRIq1yomfScpzsSHdajhNUX1u7d81/ny7yrskYf867RpJ64/7/a8gI+x9rQrjHuyY0wr83dy7YOmBgcQUEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABMNIkVb50U7vmrjzf50UDsW9ayQpI9TnXXPa9QY6lq8TndneNYWRU4GOFVfIvybA1ynD/zAKZWZ61zCMdHDiCggAYIIAAgCY8A6gXbt26dZbb1VxcbFCoZC2bduW9PyKFSsUCoWStsWLF6eqXwDAMOEdQF1dXSorK9OGDRsuus/ixYt14sSJxPbKK69cVpMAgOHH+yaEyspKVVZWfuU+0WhUhYWFgZsCAAx/A/IeUG1trfLz83XdddfpgQceUGtr60X37enpUUdHR9IGABj+Uh5Aixcv1ksvvaSdO3fqF7/4herq6lRZWam+vv5vb62pqVEsFktsJSUlqW4JADAIpfz3gO68887Ev2fMmKGZM2dq6tSpqq2t1YIFCy7Yv7q6WmvXrk183NHRQQgBwBVgwG/DnjJlivLy8tTQ0NDv89FoVGPHjk3aAADD34AH0LFjx9Ta2qqioqKBPhQAYAjx/hHcqVOnkq5mGhsbdeDAAeXm5io3N1dPPfWUli1bpsLCQh05ckSPPPKIrr76ai1atCiljQMAhjbvANq3b59uvvnmxMdfvH+zfPlybdy4UQcPHtRvf/tbtbW1qbi4WAsXLtRPf/pTRaPR1HUNABjyvANo/vz5cs5d9Pnf//73l9UQhrf8TP9hpGddxLsmyFBRScqOdHvXdMYv/v2QSu2dWd41sXCwdQiyftGw/1DWzFCAaaQR//MBgxOz4AAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJlL+J7mBr1KU0eZdE0/j66RYpMu75nDvuAHo5ELxf/j/SZNA06YlhUNx75q48/86BfnKhkaO9C/q9J/CjoHHFRAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATDCNFWhUGGEba3JvjXdPrIt41kjQ6fNa75s89VwU6lq8RHf6vF/d0Twh0rIicd02QAaYjQ/5fp1BWgGGkGJS4AgIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCYaRIq/HhLu+a1tCYAeikfyMDDCP9pHNigCOd8q4Y+feQf02o17smncIBXgP3TYj5H+hvTf41GHBcAQEATBBAAAATBBAAwAQBBAAwQQABAEwQQA
AAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADDBMFKkVXaAYZ/xAK+Tgg7hDFL3SWuRd804HfauGfV3510TZO0kKSPU513TJ/9hqUH05I3yrskcgD5w+bgCAgCYIIAAACa8AqimpkazZ89Wdna28vPztXTpUtXX1yft093draqqKo0fP15jxozRsmXL1NLSktKmAQBDn1cA1dXVqaqqSnv27NE777yj3t5eLVy4UF1d//wjYw899JDeeustvfHGG6qrq9Px48d1++23p7xxAMDQ5nUTwo4dO5I+3rx5s/Lz87V//37NmzdP7e3t+vWvf60tW7boe9/7niRp06ZN+vrXv649e/boO9/5Tuo6BwAMaZf1HlB7e7skKTc3V5K0f/9+9fb2qqKiIrHPtGnTNGnSJO3evbvfz9HT06OOjo6kDQAw/AUOoHg8rjVr1uiGG27Q9OnTJUnNzc3KzMxUTk5O0r4FBQVqbm7u9/PU1NQoFosltpKSkqAtAQCGkMABVFVVpUOHDunVV1+9rAaqq6vV3t6e2Jqami7r8wEAhoZAv4i6evVqvf3229q1a5cmTpyYeLywsFBnz55VW1tb0lVQS0uLCgsL+/1c0WhU0Wg0SBsAgCHM6wrIOafVq1dr69ateu+991RaWpr0/KxZs5SRkaGdO3cmHquvr9fRo0c1d+7c1HQMABgWvK6AqqqqtGXLFm3fvl3Z2dmJ93VisZhGjRqlWCym++67T2vXrlVubq7Gjh2rBx98UHPnzuUOOABAEq8A2rhxoyRp/vz5SY9v2rRJK1askCT98pe/VDgc1rJly9TT06NFixbpV7/6VUqaBQAMH14B5NylhyGOHDlSGzZs0IYNGwI3heErFg4w5NL53yvT64LN2Q0yhLP18zHeNeO8K6RRree8a4IOCA2H4v5FAb5OQZwdG/GuYRjp4MQsOACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACAiWAjgwFJkXH+M51Hh9LzmifQNGdJI0O9/kWt6fmLvtHPe7xruuLBeosowPoFGLzdK//p4z0x/wP5zytHOnAFBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwATDSBFYaFzMuyYa8j/lIkEGizr/kqCirel5HTfiZId3Ta8L9i0+Muw/lLU7nuFdE3f+X6ggw0gxOHEFBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwATDSBFYfGyWd00k5D9IMiz/YaSRADWSlBmgLrM90KG8uc8+964JMiBUkjJC5wLU9AU6lq9zY9JyGKQBV0AAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMMIwUgfVl+Q+6/I9Thd413S7Tu6Yv4GurswHqslqCDT71FT/T7V3T0hsLdKyJma3eNb0u4l3zx55c75p4hvOuweDEFRAAwAQBBAAw4RVANTU1mj17trKzs5Wfn6+lS5eqvr4+aZ/58+crFAolbffff39KmwYADH1eAVRXV6eqqirt2bNH77zzjnp7e7Vw4UJ1dXUl7bdy5UqdOHEisa1fvz6lTQMAhj6vmxB27NiR9PHmzZuVn5+v/fv3a968eYnHs7KyVFjo/2YzAODKcVnvAbW3n/9bxLm5yXeyvPzyy8rLy9P06dNVXV2t06dPX/Rz9PT0qKOjI2kDAAx/gW/DjsfjWrNmjW644QZNnz498fjdd9+tyZMnq7i4WAcPHtSjjz6q+vp6vfnmm/1+npqaGj311FNB2wAADFGBA6iqqkqHDh3SBx98kPT4qlWrEv+eMWOGioqKtGDBAh05ckRTp0694PNUV1dr7dq1iY87OjpUUlIStC0AwBARKIBWr1
6tt99+W7t27dLEiRO/ct/y8nJJUkNDQ78BFI1GFY1Gg7QBABjCvALIOacHH3xQW7duVW1trUpLSy9Zc+DAAUlSUVFRoAYBAMOTVwBVVVVpy5Yt2r59u7Kzs9Xc3CxJisViGjVqlI4cOaItW7bolltu0fjx43Xw4EE99NBDmjdvnmbOnDkg/wEAgKHJK4A2btwo6fwvm/5PmzZt0ooVK5SZmal3331Xzz33nLq6ulRSUqJly5bpscceS1nDAIDhwftHcF+lpKREdXV1l9UQAODKwDRsBNZ2bZZ3zeyRR71rPunxf/8wyGRmScoJn/WuCZ9Lz3Rm1+vf24/GHwp0rI97/H9FsLnPf/L2/8ryn/D9ZJn/pG4MTgwjBQCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIJhpAgs78PPvWtu+UOV/3FyTnnXxKL+Qy4l6f9G+rxrsuvbvWvi3hXB/Ptfbg9UFwn5d/hfJ8d71zx0JsO7ZvJ/8Lp5uOArCQAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATg24WnHNOknROvZIzbgZfyfX1eNfET4e8a/oy/I9z7px/jSQp4j8DLR5kHVyvd00Q8a5g6+ACzIKLn/afvxfv9p+9d67X/3VzJE3rjfPO6fx6f/H/84sJuUvtkWbHjh1TSUmJdRsAgMvU1NSkiRMnXvT5QRdA8Xhcx48fV3Z2tkKh5FfLHR0dKikpUVNTk8aOHWvUoT3W4TzW4TzW4TzW4bzBsA7OOXV2dqq4uFjh8MWvWAfdj+DC4fBXJqYkjR079oo+wb7AOpzHOpzHOpzHOpxnvQ6xWOyS+3ATAgDABAEEADAxpAIoGo1q3bp1ikaj1q2YYh3OYx3OYx3OYx3OG0rrMOhuQgAAXBmG1BUQAGD4IIAAACYIIACACQIIAGBiyATQhg0b9LWvfU0jR45UeXm5/vSnP1m3lHZPPvmkQqFQ0jZt2jTrtgbcrl27dOutt6q4uFihUEjbtm1Let45pyeeeEJFRUUaNWqUKioqdPjwYZtmB9Cl1mHFihUXnB+LFy+2aXaA1NTUaPbs2crOzlZ+fr6WLl2q+vr6pH26u7tVVVWl8ePHa8yYMVq2bJlaWlqMOh4Y/8o6zJ8//4Lz4f777zfquH9DIoBee+01rV27VuvWrdNHH32ksrIyLVq0SCdPnrRuLe2uv/56nThxIrF98MEH1i0NuK6uLpWVlWnDhg39Pr9+/Xo9//zzevHFF7V3716NHj1aixYtUne3/3DMwexS6yBJixcvTjo/XnnllTR2OPDq6upUVVWlPXv26J133lFvb68WLlyorq6uxD4PPfSQ3nrrLb3xxhuqq6vT8ePHdfvttxt2nXr/yjpI0sqVK5POh/Xr1xt1fBFuCJgzZ46rqqpKfNzX1+eKi4tdTU2NYVfpt27dOldWVmbdhilJbuvWrYmP4/G4KywsdM8880zisba2NheNRt0rr7xi0GF6fHkdnHNu+fLlbsmSJSb9WDl58qST5Orq6pxz57/2GRkZ7o033kjs8+mnnzpJbvfu3VZtDrgvr4Nzzn33u991P/zhD+2a+hcM+iugs2fPav/+/aqoqEg8Fg6HVVFRod27dxt2ZuPw4cMqLi7WlClTdM899+jo0aPWLZlqbGxUc3Nz0vkRi8VUXl5+RZ4ftbW1ys/P13XXXacHHnhAra2t1i0NqPb2dklSbm6uJGn//v3q7e1NOh+mTZumSZMmDevz4cvr8IWXX35ZeXl5mj59uqqrq3X69GmL9i5q0A0j/bLPPvtMfX19KigoSHq8oKBAf/nLX4y6slFeXq7Nmzfruuuu04kTJ/TUU0/ppptu0qFDh5SdnW3dnonm5mZJ6vf8+OK5K8XixYt1++23q7S0VEeOHNFPfvITVVZWavfu3YpEItbtpVw8HteaNWt0ww
03aPr06ZLOnw+ZmZnKyclJ2nc4nw/9rYMk3X333Zo8ebKKi4t18OBBPfroo6qvr9ebb75p2G2yQR9A+KfKysrEv2fOnKny8nJNnjxZr7/+uu677z7DzjAY3HnnnYl/z5gxQzNnztTUqVNVW1urBQsWGHY2MKqqqnTo0KEr4n3Qr3KxdVi1alXi3zNmzFBRUZEWLFigI0eOaOrUqelus1+D/kdweXl5ikQiF9zF0tLSosLCQqOuBoecnBxde+21amhosG7FzBfnAOfHhaZMmaK8vLxheX6sXr1ab7/9tt5///2kP99SWFios2fPqq2tLWn/4Xo+XGwd+lNeXi5Jg+p8GPQBlJmZqVmzZmnnzp2Jx+LxuHbu3Km5c+cadmbv1KlTOnLkiIqKiqxbMVNaWqrCwsKk86Ojo0N79+694s+PY8eOqbW1dVidH845rV69Wlu3btV7772n0tLSpOdnzZqljIyMpPOhvr5eR48eHVbnw6XWoT8HDhyQpMF1PljfBfGvePXVV100GnWbN292f/7zn92qVatcTk6Oa25utm4trX70ox+52tpa19jY6P7whz+4iooKl5eX506ePGnd2oDq7Ox0H3/8sfv444+dJPfss8+6jz/+2P3tb39zzjn385//3OXk5Ljt27e7gwcPuiVLlrjS0lJ35swZ485T66vWobOz0z388MNu9+7drrGx0b377rvuW9/6lrvmmmtcd3e3desp88ADD7hYLOZqa2vdiRMnEtvp06cT+9x///1u0qRJ7r333nP79u1zc+fOdXPnzjXsOvUutQ4NDQ3u6aefdvv27XONjY1u+/btbsqUKW7evHnGnScbEgHknHMvvPCCmzRpksvMzHRz5sxxe/bssW4p7e644w5XVFTkMjMz3VVXXeXuuOMO19DQYN3WgHv//fedpAu25cuXO+fO34r9+OOPu4KCAheNRt2CBQtcfX29bdMD4KvW4fTp027hwoVuwoQJLiMjw02ePNmtXLly2L1I6++/X5LbtGlTYp8zZ864H/zgB27cuHEuKyvL3Xbbbe7EiRN2TQ+AS63D0aNH3bx581xubq6LRqPu6quvdj/+8Y9de3u7beNfwp9jAACYGPTvAQEAhicCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAm/h+y9WWFx1BbhgAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# let us print some data:\n", "\n", "categories = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', \n", " 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']\n", "\n", "# select a random sample from the training set\n", "sample_num = 143\n", "# print(training_data[sample_num])\n", "print('Inputs sample - image size:', training_data[sample_num][0].shape)\n", "print('Label:', training_data[sample_num][1], '\\n')\n", "\n", "import matplotlib.pyplot as plt\n", "\n", "ima = training_data[sample_num][0]\n", "print('Inputs sample - min,max,mean,std:', ima.min().item(), ima.max().item(), ima.mean().item(), ima.std().item())\n", "ima = (ima - ima.mean())/ ima.std()\n", "print('Inputs sample normalized - min,max,mean,std:', ima.min().item(), ima.max().item(), ima.mean().item(), ima.std().item())\n", "iman = ima.permute(1, 2, 0) # needed to be able to plot\n", "plt.imshow(iman)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Looks like a pair of pants!" 
] }, { "cell_type": "code", "execution_count": 8, "id": "06c62e13", "metadata": {}, "outputs": [], "source": [ "class Net(nn.Module):\n", "    def __init__(self):\n", "        super(Net, self).__init__()\n", "        self.flatten = nn.Flatten()\n", "        self.l1 = nn.Linear(28*28, 512)\n", "        self.l2 = nn.Linear(512, 512)\n", "        self.l3 = nn.Linear(512, 10)\n", "\n", "    def forward(self, x):\n", "        x = self.flatten(x)\n", "        x = F.relu(self.l1(x))\n", "        x = F.relu(self.l2(x))\n", "        output = self.l3(x)\n", "        return output\n", "\n", "# Can also be written as:\n", "\n", "# class Net(nn.Module):\n", "#     def __init__(self):\n", "#         super(Net, self).__init__()\n", "#         self.flatten = nn.Flatten()\n", "#         self.linear_relu_stack = nn.Sequential(\n", "#             nn.Linear(28*28, 512),\n", "#             nn.ReLU(),\n", "#             nn.Linear(512, 512),\n", "#             nn.ReLU(),\n", "#             nn.Linear(512, 10),\n", "#         )\n", "\n", "#     def forward(self, x):\n", "#         x = self.flatten(x)\n", "#         output = self.linear_relu_stack(x)\n", "#         return output\n", "\n", "\n", "def train_loop(dataloader, model, loss_fn, optimizer):\n", "    size = len(dataloader.dataset)\n", "    for batch, (X, y) in enumerate(dataloader):\n", "        # Compute prediction and loss\n", "        pred = model(X)\n", "        loss = loss_fn(pred, y)\n", "\n", "        # Backpropagation\n", "        optimizer.zero_grad()\n", "        loss.backward()\n", "        optimizer.step()\n", "\n", "        if batch % 100 == 0:\n", "            loss, current = loss.item(), (batch + 1) * len(X)\n", "            print(f\"loss: {loss:>7f} [{current:>5d}/{size:>5d}]\")\n", "\n", "\n", "def test_loop(dataloader, model, loss_fn):\n", "    size = len(dataloader.dataset)\n", "    num_batches = len(dataloader)\n", "    test_loss, correct = 0, 0\n", "\n", "    # no gradients are needed for evaluation\n", "    with torch.no_grad():\n", "        for X, y in dataloader:\n", "            pred = model(X)\n", "            test_loss += loss_fn(pred, y).item()\n", "            correct += (pred.argmax(1) == y).type(torch.float).sum().item()\n", "\n", "    test_loss /= num_batches\n", "    correct /= size\n", "    print(f\"Test Error: \\n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} 
\\n\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Notice our output layer has 10 neurons, just like the number of classes in the dataset.\n", "\n", "Let us now train the network" ] }, { "cell_type": "code", "execution_count": 9, "id": "d8684362", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1\n", "-------------------------------\n", "loss: 2.299345 [ 64/60000]\n", "loss: 0.571365 [ 6464/60000]\n", "loss: 0.411667 [12864/60000]\n", "loss: 0.506400 [19264/60000]\n", "loss: 0.440894 [25664/60000]\n", "loss: 0.442545 [32064/60000]\n", "loss: 0.380788 [38464/60000]\n", "loss: 0.523417 [44864/60000]\n", "loss: 0.476528 [51264/60000]\n", "loss: 0.513048 [57664/60000]\n", "Test Error: \n", " Accuracy: 83.4%, Avg loss: 0.439552 \n", "\n", "Epoch 2\n", "-------------------------------\n", "loss: 0.296889 [ 64/60000]\n", "loss: 0.359523 [ 6464/60000]\n", "loss: 0.287938 [12864/60000]\n", "loss: 0.388232 [19264/60000]\n", "loss: 0.377390 [25664/60000]\n", "loss: 0.385652 [32064/60000]\n", "loss: 0.314416 [38464/60000]\n", "loss: 0.529878 [44864/60000]\n", "loss: 0.416916 [51264/60000]\n", "loss: 0.453392 [57664/60000]\n", "Test Error: \n", " Accuracy: 85.7%, Avg loss: 0.390238 \n", "\n", "Epoch 3\n", "-------------------------------\n", "loss: 0.223048 [ 64/60000]\n", "loss: 0.333038 [ 6464/60000]\n", "loss: 0.229556 [12864/60000]\n", "loss: 0.330350 [19264/60000]\n", "loss: 0.448680 [25664/60000]\n", "loss: 0.359768 [32064/60000]\n", "loss: 0.277115 [38464/60000]\n", "loss: 0.442255 [44864/60000]\n", "loss: 0.326743 [51264/60000]\n", "loss: 0.414322 [57664/60000]\n", "Test Error: \n", " Accuracy: 86.2%, Avg loss: 0.378768 \n", "\n", "Epoch 4\n", "-------------------------------\n", "loss: 0.189927 [ 64/60000]\n", "loss: 0.306457 [ 6464/60000]\n", "loss: 0.215652 [12864/60000]\n", "loss: 0.290086 [19264/60000]\n", "loss: 0.388842 [25664/60000]\n", "loss: 0.369259 [32064/60000]\n", "loss: 0.241898 
[38464/60000]\n", "loss: 0.397346 [44864/60000]\n", "loss: 0.275215 [51264/60000]\n", "loss: 0.327559 [57664/60000]\n", "Test Error: \n", " Accuracy: 86.8%, Avg loss: 0.355445 \n", "\n", "Epoch 5\n", "-------------------------------\n", "loss: 0.182914 [ 64/60000]\n", "loss: 0.246371 [ 6464/60000]\n", "loss: 0.205581 [12864/60000]\n", "loss: 0.242901 [19264/60000]\n", "loss: 0.413763 [25664/60000]\n", "loss: 0.321556 [32064/60000]\n", "loss: 0.241755 [38464/60000]\n", "loss: 0.402537 [44864/60000]\n", "loss: 0.241768 [51264/60000]\n", "loss: 0.302564 [57664/60000]\n", "Test Error: \n", " Accuracy: 87.3%, Avg loss: 0.347425 \n", "\n", "Epoch 6\n", "-------------------------------\n", "loss: 0.195791 [ 64/60000]\n", "loss: 0.234772 [ 6464/60000]\n", "loss: 0.166920 [12864/60000]\n", "loss: 0.222736 [19264/60000]\n", "loss: 0.472416 [25664/60000]\n", "loss: 0.304280 [32064/60000]\n", "loss: 0.209779 [38464/60000]\n", "loss: 0.325730 [44864/60000]\n", "loss: 0.265666 [51264/60000]\n", "loss: 0.296183 [57664/60000]\n", "Test Error: \n", " Accuracy: 88.1%, Avg loss: 0.333879 \n", "\n", "Epoch 7\n", "-------------------------------\n", "loss: 0.193307 [ 64/60000]\n", "loss: 0.188094 [ 6464/60000]\n", "loss: 0.164318 [12864/60000]\n", "loss: 0.242136 [19264/60000]\n", "loss: 0.304813 [25664/60000]\n", "loss: 0.280470 [32064/60000]\n", "loss: 0.204909 [38464/60000]\n", "loss: 0.313631 [44864/60000]\n", "loss: 0.283201 [51264/60000]\n", "loss: 0.260384 [57664/60000]\n", "Test Error: \n", " Accuracy: 87.7%, Avg loss: 0.349270 \n", "\n", "Epoch 8\n", "-------------------------------\n", "loss: 0.185134 [ 64/60000]\n", "loss: 0.178030 [ 6464/60000]\n", "loss: 0.186955 [12864/60000]\n", "loss: 0.224317 [19264/60000]\n", "loss: 0.263922 [25664/60000]\n", "loss: 0.295104 [32064/60000]\n", "loss: 0.174908 [38464/60000]\n", "loss: 0.330540 [44864/60000]\n", "loss: 0.241252 [51264/60000]\n", "loss: 0.283601 [57664/60000]\n", "Test Error: \n", " Accuracy: 87.7%, Avg loss: 0.353108 
\n", "\n", "Epoch 9\n", "-------------------------------\n", "loss: 0.149028 [ 64/60000]\n", "loss: 0.191820 [ 6464/60000]\n", "loss: 0.167478 [12864/60000]\n", "loss: 0.198771 [19264/60000]\n", "loss: 0.286330 [25664/60000]\n", "loss: 0.239577 [32064/60000]\n", "loss: 0.167874 [38464/60000]\n", "loss: 0.296653 [44864/60000]\n", "loss: 0.224247 [51264/60000]\n", "loss: 0.287964 [57664/60000]\n", "Test Error: \n", " Accuracy: 87.7%, Avg loss: 0.358775 \n", "\n", "Epoch 10\n", "-------------------------------\n", "loss: 0.148260 [ 64/60000]\n", "loss: 0.138894 [ 6464/60000]\n", "loss: 0.152476 [12864/60000]\n", "loss: 0.188317 [19264/60000]\n", "loss: 0.292156 [25664/60000]\n", "loss: 0.267396 [32064/60000]\n", "loss: 0.193608 [38464/60000]\n", "loss: 0.318536 [44864/60000]\n", "loss: 0.225123 [51264/60000]\n", "loss: 0.240415 [57664/60000]\n", "Test Error: \n", " Accuracy: 87.2%, Avg loss: 0.388099 \n", "\n", "Done!\n" ] } ], "source": [ "# training!\n", "\n", "model = Net()\n", "\n", "batch_size = 64\n", "train_dataloader = DataLoader(training_data, batch_size=batch_size)\n", "test_dataloader = DataLoader(test_data, batch_size=batch_size)\n", "\n", "loss_fn = nn.CrossEntropyLoss() # used for categorization\n", "learning_rate = 1e-3\n", "# note: the optimizer is Adam, a widely used adaptive optimizer\n", "# it adapts the step size of each parameter, but we still choose the base learning rate\n", "optimizer = optim.Adam(model.parameters(), lr=learning_rate)\n", "\n", "epochs = 10\n", "for t in range(epochs):\n", "    print(f\"Epoch {t+1}\\n-------------------------------\")\n", "    train_loop(train_dataloader, model, loss_fn, optimizer)\n", "    test_loop(test_dataloader, model, loss_fn)\n", "print(\"Done!\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can now test if the network was trained correctly:" ] }, { "cell_type": "code", "execution_count": 11, "id": "4c11e9cf", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "neural network output 
pseudo-probabilities: tensor([[-11.9542, 31.1532, -30.3186, -11.9923, -10.9189, -27.4956, -11.2437,\n", " -48.6321, -12.0049, -27.4024]])\n", "neural network output class number: 1\n", "neural network output, predicted class: Trouser\n" ] } ], "source": [ "sample_num = 143 # the same sample we looked at above\n", "\n", "with torch.no_grad():\n", "    r = model(training_data[sample_num][0])\n", "\n", "# note: these are raw scores (logits), not probabilities; apply softmax to get probabilities\n", "print('neural network output pseudo-probabilities:', r)\n", "print('neural network output class number:', torch.argmax(r).item())\n", "print('neural network output, predicted class:', categories[torch.argmax(r).item()])\n" ] }, { "cell_type": "markdown", "id": "0b6b3ba2", "metadata": {}, "source": [ "## Important Details\n", "\n", "In this script we covered some important details that you may need to study further. Please explore and read the following pages: \n", "\n", "- [Torch datasets](http://pytorch.org/vision/stable/datasets.html#)\n", "- [Torch optimization algorithms](https://pytorch.org/docs/stable/optim.html)\n", "- [Torch loss functions, CrossEntropyLoss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html)\n", "- [Torch data transformations](https://pytorch.org/vision/stable/transforms.html#)\n", "\n", "The scope of these topics is quite vast, but please at least read about the functions and routines you use, try to remember them, and get a sense of what else is available.\n", "\n", "### Note about the loss function used\n", "\n", "In this example we use the [cross entropy loss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) for \"categorization tasks\", as in predicting one class out of N. It expects the raw, unnormalized scores (logits) of the network as input, which is why `Net` ends with a plain linear layer and no softmax.\n", "\n", "### Note about the optimizer used\n", "\n", "In this example we use [Adam](https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#adam) as an optimizer. Adam is an adaptive variant of SGD: it keeps running estimates of the mean and variance of each gradient and uses them to scale the step taken for each parameter. The base learning rate is still a hyper-parameter we choose (here `1e-3`)." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## HOMEWORK:\n", "\n", "Train a neural network with the same architecture on your own data. Get a few images of 3-4 categories of objects from the internet. Resize the images to a square size of 28x28, just like the example above. You can also modify your network to accept images of different sizes. An example of training data is [here](data/my_data.zip). You may want to have a `train/` folder and a `test/` folder, just like in that example. The train and test sets contain the same categories but different images; the train set is usually much larger than the test set. Use the [data loader](https://pytorch.org/vision/main/generated/torchvision.datasets.ImageFolder.html) `torchvision.datasets.ImageFolder` instead of the `torchvision.datasets.FashionMNIST` one used above. Note that `ImageFolder` loads images as RGB by default, so either convert them to grayscale (for example with `torchvision.transforms.Grayscale`) or change the input size of the first layer. " ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.6" } }, "nbformat": 4, "nbformat_minor": 5 }