{ "cells": [ { "cell_type": "markdown", "id": "2c33caea-527f-4bc7-a337-5642943f4249", "metadata": {}, "source": [ "# Image classification for the MNIST dataset" ] }, { "cell_type": "code", "execution_count": 1, "id": "279aabec-4a30-4708-ae14-5232495a6f1d", "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# To use full width of screen in the notebook\n", "from IPython.display import display, HTML\n", "display(HTML(\"\"))" ] }, { "cell_type": "markdown", "id": "0f3dd6dd-ff9a-4e84-a3cd-9e747fc17cbc", "metadata": {}, "source": [ "## Import Libraries and test GPUs\n", "\n", "We import the required libraries and check if the GPU cards are available (cuda)." ] }, { "cell_type": "code", "execution_count": 2, "id": "ed0cffe6-5633-4d94-abde-b6adcc18c5fe", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using device: cuda\n" ] } ], "source": [ "# Import libraries\n", "import torch\n", "import torch.nn as nn\n", "from torch.utils.data import DataLoader\n", "from torch.optim import Adam\n", "from torchvision import datasets, transforms\n", "import matplotlib.pyplot as plt\n", "\n", "# Check if GPU is available\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "print(f\"Using device: {device}\")" ] }, { "cell_type": "markdown", "id": "49427fca-cec1-4b67-826e-ca6e345a93a4", "metadata": {}, "source": [ "## Dataset\n", "\n", "The first time the notebook runs, we download the MNIST datasets and transform them into tensors so that they can be used by PyTorch. On later runs, the datasets are already stored locally."
] }, { "cell_type": "code", "execution_count": 3, "id": "6661e5ae-68e6-45a4-a3d3-18a92808468f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dataset MNIST\n", " Number of datapoints: 60000\n", " Root location: ./data\n", " Split: Train\n", " StandardTransform\n", "Transform: Compose(\n", " ToTensor()\n", " )\n", "Dataset MNIST\n", " Number of datapoints: 10000\n", " Root location: ./data\n", " Split: Test\n", " StandardTransform\n", "Transform: Compose(\n", " ToTensor()\n", " )\n", "Image batch shape: torch.Size([64, 1, 28, 28])\n", "Label batch shape: torch.Size([64])\n" ] } ], "source": [ "# Define transformations\n", "my_transform = transforms.Compose([transforms.ToTensor()])\n", "\n", "# Load datasets\n", "train_dataset = datasets.MNIST(root='./data', train=True, transform=my_transform, download=True)\n", "test_dataset = datasets.MNIST(root='./data', train=False, transform=my_transform, download=True)\n", "print(train_dataset)\n", "print(test_dataset)\n", "\n", "# Data loaders\n", "train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)\n", "test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)\n", "\n", "# Inspect data shape\n", "images, labels = next(iter(train_loader))\n", "print(f\"Image batch shape: {images.shape}\")\n", "print(f\"Label batch shape: {labels.shape}\")" ] }, { "cell_type": "markdown", "id": "88d1036e-4026-416d-9013-aa5c79a4ee51", "metadata": {}, "source": [ "## Data visualization\n", "\n", "Let's print an element of the training dataset as a tensor." 
] }, { "cell_type": "code", "execution_count": 4, "id": "b47f2c6e-fdff-407b-909f-9663a57cb156", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n", " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n", " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n", " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n", " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n", " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.1, 0.1, 0.5, 0.5, 0.7, 0.1, 0.7, 1.0, 1.0, 0.5, 0.0, 0.0, 0.0, 0.0],\n", " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.1, 0.4, 0.6, 0.7, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9, 0.7, 1.0, 0.9, 0.8, 0.3, 0.0, 0.0, 0.0, 0.0],\n", " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2, 0.9, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.4, 0.3, 0.3, 0.2, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0],\n", " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.9, 1.0, 1.0, 1.0, 1.0, 1.0, 0.8, 0.7, 1.0, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n", " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3, 0.6, 0.4, 1.0, 1.0, 0.8, 0.0, 0.0, 0.2, 0.6, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n", " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.0, 0.6, 1.0, 0.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n", " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 1.0, 0.7, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n", " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.7, 1.0, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n", " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.9, 0.9, 0.6, 0.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n", " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3, 0.9, 1.0, 1.0, 0.5, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n", " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2, 0.7, 1.0, 1.0, 0.6, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n", " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.4, 1.0, 1.0, 0.7, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n", " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n", " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2, 0.5, 0.7, 1.0, 1.0, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n", " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2, 0.6, 0.9, 1.0, 1.0, 1.0, 1.0, 0.7, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n", " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.4, 0.9, 1.0, 1.0, 1.0, 1.0, 0.8, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n", " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.3, 0.8, 1.0, 1.0, 1.0, 1.0, 0.8, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n", " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.7, 0.9, 1.0, 1.0, 1.0, 1.0, 0.8, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n", " [0.0, 0.0, 0.0, 0.0, 0.2, 0.7, 0.9, 1.0, 1.0, 1.0, 1.0, 1.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n", " [0.0, 0.0, 0.0, 0.0, 0.5, 1.0, 1.0, 1.0, 0.8, 0.5, 0.5, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n", " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0],\n", " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n", " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Set the print precision to 1 (or 2) decimals for display purposes only, so that the printed matrix (tensor) resembles the actual image\n", "torch.set_printoptions(precision=1, linewidth=300)\n", "\n", "# Let's check the first element as a tensor\n", "train_dataset[0][0][0]" ] }, { "cell_type": "markdown", "id": "25d2b649-b7c4-4a85-81b6-dddc75939814", "metadata": {}, "source": [ "We can barely distinguish the element with the naked eye. Let's look at a few more elements (the one above is the first), this time as images." ] }, { "cell_type": "code", "execution_count": 5, "id": "c78b2a40-e4fa-4866-969c-091daa4fd7db", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Show the first training images side by side, titled with their labels\n", "n_imgs = 10\n", "fig, axes = plt.subplots(1, n_imgs, figsize=(15, 4))\n", "for i in range(n_imgs):\n", "    ax = axes[i]\n", "    x, y = train_dataset[i]\n", "    image = transforms.functional.to_pil_image(x)\n", "    ax.imshow(image, cmap='gray')\n", "    ax.set_title(f\"Label: {y}\")\n", "    ax.axis('off')\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "1eaa46f7-d7ca-4da3-a5e6-086084061339", "metadata": {}, "source": [ "## Neural network and training\n", "\n", "Now we define our neural network and start the training process." ] }, { "cell_type": "code", "execution_count": 6, "id": "1139183a-1ea0-4ff8-9c97-3131201a8831", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MNISTModel(\n", " (network): Sequential(\n", " (0): Flatten(start_dim=1, end_dim=-1)\n", " (1): Linear(in_features=784, out_features=128, bias=True)\n", " (2): ReLU()\n", " (3): Linear(in_features=128, out_features=64, bias=True)\n", " (4): ReLU()\n", " (5): Linear(in_features=64, out_features=10, bias=True)\n", " )\n", ")\n" ] } ], "source": [ "# Define the neural network\n", "class MNISTModel(nn.Module):\n", "    \"\"\"Fully connected classifier mapping 28x28 images to 10 class scores.\"\"\"\n", "\n", "    def __init__(self):\n", "        super().__init__()\n", "        # Flatten -> 784 -> 128 -> 64 -> 10, with a ReLU after each hidden layer\n", "        self.network = nn.Sequential(\n", "            nn.Flatten(),\n", "            nn.Linear(28*28, 128),\n", "            nn.ReLU(),\n", "            nn.Linear(128, 64),\n", "            nn.ReLU(),\n", "            nn.Linear(64, 10)\n", "        )\n", "\n", "    def forward(self, x):\n", "        \"\"\"Return the output of the final linear layer (no softmax is applied).\"\"\"\n", "        return self.network(x)\n", "\n", "# Initialize model\n", "model = MNISTModel().to(device)\n", "print(model)" ] }, { "cell_type": "code", "execution_count": 7, "id": "1cc351e5-c643-4f35-a041-b777c2bb3a27", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch [1/5], Loss: 0.3365\n", "Epoch [2/5], Loss: 0.1377\n", "Epoch [3/5], Loss: 0.0945\n", "Epoch [4/5], Loss: 0.0699\n", "Epoch [5/5], Loss: 0.0563\n" ] } ], "source": [ "# Define loss function and optimizer\n", 
"criterion = nn.CrossEntropyLoss()\n", "optimizer = Adam(model.parameters(), lr=0.001)\n", "\n", "# Training loop\n", "epochs = 5\n", "for epoch in range(epochs):\n", "    model.train()\n", "    running_loss = 0.0\n", "    for images, labels in train_loader:\n", "        images, labels = images.to(device), labels.to(device)\n", "\n", "        # Forward pass\n", "        outputs = model(images)\n", "        loss = criterion(outputs, labels)\n", "\n", "        # Backward pass\n", "        optimizer.zero_grad()\n", "        loss.backward()\n", "        optimizer.step()\n", "\n", "        running_loss += loss.item()\n", "\n", "    # Report the mean of the per-batch losses for this epoch\n", "    print(f\"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader):.4f}\")" ] }, { "cell_type": "markdown", "id": "645153a8-a79b-44a3-9ebf-df45a1109957", "metadata": {}, "source": [ "## Predictions and inference\n", "\n", "Now that the model is trained we can test it." ] }, { "cell_type": "code", "execution_count": 8, "id": "40374e27-c391-4a74-b4e7-0261965b716d", "metadata": {}, "outputs": [], "source": [ "def test_model(model, test_loader, device, n_imgs=10):\n", "    \"\"\"Plot images from the first test batch with true and predicted labels.\n", "\n", "    Args:\n", "        model: Classifier called on a batch of images; the prediction is the\n", "            index of the largest output along dimension 1.\n", "        test_loader: Iterable yielding (images, labels) batches; only the\n", "            first batch is used.\n", "        device: Device the batch is moved to before calling the model.\n", "        n_imgs: Maximum number of images to plot (capped at the batch size).\n", "    \"\"\"\n", "    model.eval()  # Set model to evaluation mode\n", "\n", "    # Get a batch of test data\n", "    images, labels = next(iter(test_loader))\n", "    images, labels = images.to(device), labels.to(device)\n", "\n", "    # Predict; gradients are not needed for inference\n", "    with torch.no_grad():\n", "        outputs = model(images)\n", "    predictions = outputs.argmax(dim=1)\n", "\n", "    # Never try to plot more images than the batch contains\n", "    n_imgs = min(n_imgs, len(images))\n", "\n", "    # Visualize results; squeeze=False keeps axes 2-D so that indexing also\n", "    # works when only a single image is plotted\n", "    fig, axes = plt.subplots(1, n_imgs, figsize=(15, 4), squeeze=False)\n", "    for i, ax in enumerate(axes[0]):\n", "        ax.imshow(images[i].cpu().squeeze(), cmap='gray')\n", "        ax.set_title(f\"Label: {labels[i].item()}\\nPredicted: {predictions[i].item()}\")\n", "        ax.axis('off')\n", "    plt.tight_layout()\n", "    plt.show()" ] }, { "cell_type": "code", "execution_count": 9, "id": "d0ce9ca0-7101-474a-a321-1d313d6d2f23", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "test_model(model, test_loader, device)" ] }, { "cell_type": "code", "execution_count": null, "id": "5d4b4111-1485-4636-91ef-c07f1e6172af", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.7" } }, "nbformat": 4, "nbformat_minor": 5 }