{
"cells": [
{
"cell_type": "markdown",
"id": "2c33caea-527f-4bc7-a337-5642943f4249",
"metadata": {},
"source": [
"# Image classification for the MNIST dataset"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "279aabec-4a30-4708-ae14-5232495a6f1d",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
""
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Use the full width of the screen in the notebook by injecting a CSS <style> element.\n",
"# An empty HTML string (as before) renders nothing, so the width would not change.\n",
"# NOTE(review): `.container` targets the classic Notebook UI; JupyterLab uses different classes - confirm for your frontend.\n",
"from IPython.display import display, HTML\n",
"display(HTML(\"<style>.container { width:100% !important; }</style>\"))"
]
},
{
"cell_type": "markdown",
"id": "0f3dd6dd-ff9a-4e84-a3cd-9e747fc17cbc",
"metadata": {},
"source": [
"## Import libraries and check for a GPU\n",
"\n",
"We import the required libraries and check whether a CUDA-capable GPU is available; if it is not, computations run on the CPU."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "ed0cffe6-5633-4d94-abde-b6adcc18c5fe",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using device: cuda\n"
]
}
],
"source": [
"# Import libraries\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import DataLoader\n",
"from torch.optim import Adam\n",
"from torchvision import datasets, transforms\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Fix the random seed so that data shuffling and weight initialization are repeatable\n",
"# across Restart & Run All (torch.manual_seed seeds the CPU and all CUDA devices).\n",
"SEED = 42\n",
"torch.manual_seed(SEED)\n",
"\n",
"# Use the GPU when CUDA is available, otherwise fall back to the CPU\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"print(f\"Using device: {device}\")"
]
},
{
"cell_type": "markdown",
"id": "49427fca-cec1-4b67-826e-ca6e345a93a4",
"metadata": {},
"source": [
"## Dataset\n",
"\n",
"The first time the notebook runs, we download the MNIST training and test datasets and transform the images into tensors so that PyTorch can use them. On later runs, the datasets are already stored locally and are not downloaded again."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6661e5ae-68e6-45a4-a3d3-18a92808468f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dataset MNIST\n",
" Number of datapoints: 60000\n",
" Root location: ./data\n",
" Split: Train\n",
" StandardTransform\n",
"Transform: Compose(\n",
" ToTensor()\n",
" )\n",
"Dataset MNIST\n",
" Number of datapoints: 10000\n",
" Root location: ./data\n",
" Split: Test\n",
" StandardTransform\n",
"Transform: Compose(\n",
" ToTensor()\n",
" )\n",
"Image batch shape: torch.Size([64, 1, 28, 28])\n",
"Label batch shape: torch.Size([64])\n"
]
}
],
"source": [
"# Configuration\n",
"DATA_DIR = './data'  # MNIST is downloaded here on the first run and reused afterwards\n",
"BATCH_SIZE = 64\n",
"\n",
"# Define transformations: ToTensor converts each image to a float tensor with values in [0, 1]\n",
"my_transform = transforms.Compose([transforms.ToTensor()])\n",
"\n",
"# Load datasets (download=True only fetches the files when they are not already in DATA_DIR)\n",
"train_dataset = datasets.MNIST(root=DATA_DIR, train=True, transform=my_transform, download=True)\n",
"test_dataset = datasets.MNIST(root=DATA_DIR, train=False, transform=my_transform, download=True)\n",
"print(train_dataset)\n",
"print(test_dataset)\n",
"\n",
"# Data loaders: reshuffle the training set on every pass; keep the test set in a fixed order\n",
"train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
"test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)\n",
"\n",
"# Inspect data shape and sanity-check the split sizes and per-image shape (1 channel, 28x28 pixels)\n",
"images, labels = next(iter(train_loader))\n",
"print(f\"Image batch shape: {images.shape}\")\n",
"print(f\"Label batch shape: {labels.shape}\")\n",
"assert len(train_dataset) == 60000 and len(test_dataset) == 10000, \"unexpected MNIST split sizes\"\n",
"assert images.shape[1:] == (1, 28, 28), f\"unexpected image shape {tuple(images.shape)}\""
]
},
{
"cell_type": "markdown",
"id": "88d1036e-4026-416d-9013-aa5c79a4ee51",
"metadata": {},
"source": [
"## Data visualization\n",
"\n",
"Let's print an element of the training dataset as a tensor."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b47f2c6e-fdff-407b-909f-9663a57cb156",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n",
" [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n",
" [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n",
" [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n",
" [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n",
" [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.1, 0.1, 0.5, 0.5, 0.7, 0.1, 0.7, 1.0, 1.0, 0.5, 0.0, 0.0, 0.0, 0.0],\n",
" [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.1, 0.4, 0.6, 0.7, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9, 0.7, 1.0, 0.9, 0.8, 0.3, 0.0, 0.0, 0.0, 0.0],\n",
" [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2, 0.9, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.4, 0.3, 0.3, 0.2, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0],\n",
" [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.9, 1.0, 1.0, 1.0, 1.0, 1.0, 0.8, 0.7, 1.0, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n",
" [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3, 0.6, 0.4, 1.0, 1.0, 0.8, 0.0, 0.0, 0.2, 0.6, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n",
" [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.0, 0.6, 1.0, 0.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n",
" [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 1.0, 0.7, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n",
" [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7, 1.0, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n",
" [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.9, 0.9, 0.6, 0.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n",
" [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3, 0.9, 1.0, 1.0, 0.5, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n",
" [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2, 0.7, 1.0, 1.0, 0.6, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n",
" [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.4, 1.0, 1.0, 0.7, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n",
" [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n",
" [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2, 0.5, 0.7, 1.0, 1.0, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n",
" [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2, 0.6, 0.9, 1.0, 1.0, 1.0, 1.0, 0.7, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n",
" [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.4, 0.9, 1.0, 1.0, 1.0, 1.0, 0.8, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n",
" [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.3, 0.8, 1.0, 1.0, 1.0, 1.0, 0.8, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n",
" [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.7, 0.9, 1.0, 1.0, 1.0, 1.0, 0.8, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n",
" [0.0, 0.0, 0.0, 0.0, 0.2, 0.7, 0.9, 1.0, 1.0, 1.0, 1.0, 1.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n",
" [0.0, 0.0, 0.0, 0.0, 0.5, 1.0, 1.0, 1.0, 0.8, 0.5, 0.5, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n",
" [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n",
" [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n",
" [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Temporarily lower the decimal precision (and widen the lines) so that each row of the 28x28 tensor\n",
"# prints on one line and the matrix can be compared visually with the actual image.\n",
"torch.set_printoptions(precision=1, linewidth=300)\n",
"\n",
"# train_dataset[0] is an (image, label) pair; the image has shape (1, 28, 28), so [0] drops the channel\n",
"first_image, first_label = train_dataset[0]\n",
"display(first_image[0])\n",
"\n",
"# set_printoptions is global: reset to PyTorch's default profile so later cells print tensors normally\n",
"torch.set_printoptions(profile='default')"
]
},
{
"cell_type": "markdown",
"id": "25d2b649-b7c4-4a85-81b6-dddc75939814",
"metadata": {},
"source": [
"It is hard to recognize the digit from the raw values with the naked eye. Let's look at a few more elements (starting with this first one), this time displayed as images."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "c78b2a40-e4fa-4866-969c-091daa4fd7db",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"