{ "cells": [ { "cell_type": "markdown", "id": "751f4b48", "metadata": {}, "source": [ "\n", " \"Open\n", "" ] }, { "cell_type": "markdown", "id": "4415ef93c445cad1", "metadata": {}, "source": [ "# Image models\n", "\n", "
\n", " Table of Contents \n", "\n", "1. [Image models](#im_mod)\n", "2. [Transfer learning](#tra_lr)\n", " \n" ] }, { "cell_type": "markdown", "id": "3c44ee0e", "metadata": {}, "source": [ "In the previous notebook, we discovered how to **handle image data efficiently** using PyTorch's `Dataset` and `DataLoader` tools. We learned to organize image data in folder structures, apply transformations on-the-fly, and load batches of images for training.\n", "\n", "Now that we know how to prepare and load image data, the next step is to learn about **machine learning models specifically designed for images**. In this notebook, we will explore **Convolutional Neural Networks (CNNs)**, which exploit the spatial structure of images, and **transfer learning**, which reuses models pre-trained on large datasets. Together, these techniques achieve state-of-the-art results in many computer vision tasks." ] }, { "cell_type": "markdown", "id": "b1abff328d518772", "metadata": {}, "source": [ "\n", "
\n", " I. Image models\n", " \n", "
\n", "\n", "Until now we have used MLPs (Multi-Layer Perceptrons). MLPs are **fully connected networks** that expect their input as a **1D tabular dataset**, meaning that each sample must be flattened into a single vector. When we flatten 2D image data into a 1D array, we lose the spatial structure of the image. In the previous example, we had to flatten the data. The code below loads a batch of CIFAR-10 images and plots the first one next to its flattened version, a vector of 3 x 32 x 32 = 3072 values:" ] }, { "cell_type": "code", "execution_count": 1, "id": "0d60373f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([10, 3, 32, 32])\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/miquelmn/Desenvolupament/01 - Docencia/AppOC/.venv/lib/python3.12/site-packages/torchvision/datasets/cifar.py:83: VisibleDeprecationWarning: dtype(): align should be passed as Python or NumPy boolean but got `align=0`. Did you mean to pass a tuple to create a subarray type? (Deprecated NumPy 2.4)\n", " entry = pickle.load(f, encoding=\"latin1\")\n" ] } ], "source": [ "import torch\n", "from torch import nn\n", "import torchvision\n", "import torchvision.transforms as transforms\n", "\n", "BATCH_SIZE = 10\n", "\n", "# Load CIFAR-10 dataset\n", "transform = transforms.ToTensor()\n", "trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)  # download=True fetches the data the first time this notebook is run\n", "\n", "# Wrap the dataset in a DataLoader and take the first batch of images\n", "trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=False)\n", "images, labels = next(iter(trainloader))\n", "print(images.shape)" ] }, { "cell_type": "code", "execution_count": 2, "id": "796b9a5511d618a8", "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAhAAAAELCAYAAACI+b9MAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAKWNJREFUeJzt3QmUpHV57/F/rV3V+949KwMzwzAgi4AouLGIgCHCBY8oOTJoNIlXTTRqIDGGBHOIkSzocTtKLuQeQQjkiqBXIQiICiJrQGCEcQZmn+meXquruvZ7nlfftnuY31/eudNSM/P9nDOnuut51/9bU/X0W+/zPrF6vV53AAAAEcSjTAwAAEACAQAA9gpnIAAAQGQkEAAAIDISCAAAEBkJBAAAiIwEAgAAREYCAQAAIiOBAAAAkZFAAACAyEggGsT111/vYrGYe+SRR17pTQEA4LcigQAAAJGRQAAAgMhIIBrUpZde6lpbW93GjRvdueeeG/y8aNEi96UvfSmIP/XUU+700093LS0t7pBDDnE33njjnPlHRkbcJz7xCXf00UcH87a3t7tzzjnH/fd///dL1vXiiy+6t7/97cGy+vv73cc+9jF35513Bl+p3HfffXOmfeihh9zZZ5/tOjo6XHNzs3vzm9/sfvKTn8zzaAAAGg0JRAOrVqvBh/6SJUvc5z73Obds2TL34Q9/OLhewj7ETzzxRPeP//iPrq2tzV1yySVuw4YNM/OuX7/e3XbbbUHy8S//8i/uk5/8ZJB02Af+1q1bZ6abmpoKEpG7777b/emf/qn71Kc+5R544AF32WWXvWR77rnnHvemN73JTUxMuCuuuMJdddVVbmxsLJj/Zz/72e9sXAAADaCOhnDdddfV7XA8/PDDwe9r1qwJfr/qqqtmphkdHa1ns9l6LBar33TTTTPPr127Npj2iiuumHluenq6Xq1W56xjw4YN9aampvqVV14589w///M/B/PedtttM88VCoX6EUccETx/7733Bs/VarX6ypUr62eddVbwcyifz9cPPfTQ+plnnrnPxwQA0Lg4A9Hg3v/+98/83NnZ6VatWhV81fDOd75z5nl7zmJ21iHU1NTk4vH4zJmMXbt2BV9l2LSPPfbYzHTf//73g69G7CuMUCaTcR/4wAfmbMcTTzzhnn/+eXfxxRcHyxoeHg7+2RmMM844w91///2uVqvN2zgAABpL8pXeAGj2Qd7X1zfnObv2YPHixcH1Cbs/Pzo6OvO7fZh//vOfd1/+8peDrzYsiQj19PTMuf5h+fLlL1neihUr5vxuyYNZs2aN3N7x8XHX1dXFIQWAgwAJRANLJBKRnq/X7VuHX7HrEz796U+7973vfe4zn/mM6+7uDs5IfPSjH92rMwXhPFdffbU77rjj9jiNneEAABwcSCAOULfeeqs77bTT3L/927/Ned4ueuzt7Z353So4nnnmmSD5mH0WYt26dXPms7MUxqo53vKWt8z79gMAGhvXQByg7CzF7DMS5pZbbnFbtmyZ89xZZ50VPHf77bfPPDc9Pe2+/vWvz5nuhBNOCJKIf/qnf3K5XO4l6xsaGtrn+wAAaFycgThAWfnmlVde6d773ve6U045JSjhvOGGG9xhhx02Z7o//uM/dl/84hfdu9/9bvdnf/ZnbsGCBcF0dv2FCc9K2Ncf1157bVBWetRRRwXLtYsvLfm49957gzMTd9xxxyuyrwCA3z0SiAPUX/3VXwUVEnaDqZtvvtkdf/zx7rvf/a67/PLLX3Ldgt3f4SMf+Uhw0aX9bveUsKTjwgsvnEkkzKmnnuoefPDB4JoKSzrsTMTg4KB77WtfGyQiAICDR8xqOV/pjUDjueaaa4I7Um7evDk40wAAwGwkEHCFQsFls9k510C8+tWvDko/n3vuOUYIAPASfIUBd8EFF7ilS5cG5Zl2L4dvfOMbbu3atcG1EAAA7AkJBIJKDLtA0hIGO+tw5JFHuptuuslddNFFjA4AYI/4CgM
AAETGfSAAAEBkJBAAAIAEAgAANNBFlG9486kyNjY2ImNNcd24qTutb0GxtKdZxvq6W2Sst1M3dEonUjKWbPpNGeMeJfRQjYyOyVipovexq7NDxuLVsowVi0UZsxJMJZP9zU2hdld1v+nWubt84aW3rg51dLbLmKvrZZaKJRlLOH2cVCMx0/ZbmnlZG3QlldJjU/Bsaz3mOYkXT+7V/lfqczujzvahz3xVrw8Afof4CgMAAERGAgEAACIjgQAAAJGRQAAAgMhIIAAAwPxVYTz9zNMyNjY8LGPd+uJ2F+vRwd5qm54v2y9jUzVdEZKr6oqIeiztfPLT+qr5fEFXRZSrugplOKGvts8k9bZWKnqZCc+V/01NTTKWn57S66vpfY9N98hYXBdMuLKnkiSb1K+LnKd6YaRa0St0zjU36yqMWFxXfsQ81TsurnPw/LSupKmUdSyR1McJABoFZyAAAEBkJBAAACAyEggAABAZCQQAAIiMBAIAAERGAgEAAOavjDOb1CWHzlN1doinVHPZgG4m1d/XrbfFV44X09tZKOpGU9NlXVZo6p7lprOeRlyeZlr1ml5nR7duJlYp62WmU3pbqrq3lUuk9UEslvS4lSt6XJo9y0y26O3MeOarxHS5abyuy1uDeZ3eVk9FrWtt0cciN5WXsXJFl2rGPeubnBjXQQBoEJyBAAAAkZFAAACAyEggAABAZCQQAAAgMhIIAAAQGQkEAACYvzLOTEx3Omxr04s5fFGXjPVkdbvGVE2XDuZGdEfGak3nRIW83oe4vxmna+9slbGkp+xwbHxSz+cZ/e42XTo4OaFLGUuerpoFT3fIuqfEsbVFl82WSwUZi1f1DqY8nUGrVb2dSU+9ZbGo5zPplD7I8Zp+bRRzo3qhng6vTZ5upJWaLjkdn/KXFANAI+AMBAAAiIwEAgAAREYCAQAAIiOBAAAAkZFAAACAyEggAADA/JVxdjXpSbOekrwOT9fFvvaUjFVrunWkp6mkSyQ9tXNxnS8Va/4SwKSn5jLp6QJZLeoyx3pCb8/OnWN6mWU9ApN53R0yX9Xlr63ZdhlzRb2+hNP7Ho/pEsdEk+7SWpjSJbzNKb2dybpen5me1vtfKOsyzprTyx3L6W0dy+vXVM5TUjxdJq8H0Ph4pwIAAJGRQAAAABIIAAAw/zgDAQAAIiOBAAAAkZFAAACA+Svj7OvUZXdtKV06mcnoWDyhy+OyWV3+Wa7ossKap6tkva7L+EoVfwlgtaRL8mp1T5dLT+lkPam7Q06WdFfNalWPab6qyyorntjklN6HLSN6W1Jxvcz2nD4W5e3DMlYY16WoS3tXyFh//2LnE2sbl7Hi6C4Zy+X0/o9P6jLO4XFdwvvCJr0t1cTL/m8JAK8YzkAAAIDISCAAAEBkJBAAACAyEggAABAZCQQAAIiMBAIAAET2suvFFva1yFh7WncWbG3WpYoxT/mj83RAjHm6XxYLugQw7inx7Gnr8GyLcy0tuox1YlyXJHa06+6Rk9N6/1/copeZK+oyzrQeGreo2dNRNOUpOdylO4MW63pbUp5unB3tbTJ2ypEnytjENl3CW8/7S3E7enX312Jej00up/PsppRe5pJBvY/9/QMytmNCl4YCQKPgDAQAAIiMBAIAAERGAgEAACIjgQAAAJGRQAAAgMhIIAAAwPyVcXa36e6YyZIu82tK6VU0NzXLWLGgSxzLNV022tnZJWP1ui7zK1X9uVS5rEvrmltbZWzrUFHGfvmi7sg4NKn3Ma9D7pCsLqs8/43HydjiBXofbn10vYw9uG67jFVquhNpMq6PxeTYkIzlc3o829p0SWWgqst4Mxk9b9rTUbY5puerVPWBWrpkoYy1jUzKGAA0Cs5AAACAyEggAABAZCQQAAAgMhIIAAAQGQkEAACIjAQCAADMXxlnf3ePjBVGdIljPObpcpjXpZqFki6BS8Z0WV2+XN2rbKlQ1iWHprNLd9UsVXVJ4vrNW2VsZMLTWTKpu5g
mEnpP2jN6mf1JXR6YGdHlkSvbB2VsW7felh1jO2WsmNfj/fhzz8lYvKLbjZZb9DEKdOgOmC6uX6cdHbrcuK2mj/10Sb++66UJGVvm6XwLAI2CMxAAACAyEggAABAZCQQAAIiMBAIAAERGAgEAAEggAADA/OMMBAAAmL/7QHT19ulYq271HY/rdsdjE6MyVp7K6WVW9b0Oak7fJ6DuaS3e2pqRsWB7nI4/u17ft2CqOCVjmUyTjqX1tmZb9H0JuhL6/hmPrtshY5WSXl+xQ98Hoq9Lj0vM6fsylCv63iH5UkHGpvKeluwVT59z2x7fvT50p2+XiutgPa7vSZJK6jGtFPV9N+qe+4oAQKPgDAQAAIiMBAIAAERGAgEAACIjgQAAAJGRQAAAgMhIIAAAwPyVcTpPOWYspWM+TRk9X7PTLY2TnrwnHtexsqfEsynb4XyGt+tW2PlhXY56WLcucyzqSkaX8ZRqrlq+SMbinoVWEnq8JzwltcnEuIy1pfVx6ulaLmPLVy6VsQ0bH5axtc9tkbF0UpdGmnpdlwZXKvq/QtzTWj2V1mNaq+nXW81TNxqLkdcDaHy8UwEAgMhIIAAAQGQkEAAAIDISCAAAEBkJBAAAiIwEAgAAzF8ZZ2G6LGOxsu6e6JzukDg1NSFjpbLObSpxXRqZy+tyywlPbNES/1DUK3reQ3p1Sd7yhbrMLz+t51t0+LEylq7rUs3RcX2csp09MuZ26a6SSwYXyNjYlO42etgRK2WsvUuXqbZ3rZax0SF9HEbHdbmpSXlKTuN13Rm1XPN0f9WVmq5a1q99T4NPV6/TjRNA4+MMBAAAiIwEAgAAREYCAQAAIiOBAAAAkZFAAACAyEggAADA/JVxVmO6lK1erexVSVo2k5Wx1jZd5rd1SJeNbtg8JGPJlN6W9I6tzmd6h17uyn5dqnnGqbqU8ZdbRmSsbVGfjPX2DMrYzqEdMtbZ6SljrOl9SMd1iefOId0dM5kZk7GhsW0ytmWb7pqZSunXRWd7zV+KXNDHv57UuXTMU3NZ85R4xmOejpuerrFVqjgB7Ac4AwEAACIjgQAAAJGRQAAAgMhIIAAAQGQkEAAAIDISCAAAMH9lnJ2drTJWSeoyzlxOd46sl3UJ3Pik7qz44kZdqpjL6RLAbEbnS9s26M6gZiCTlrFFiw6Rsc6Fh8pYatJTdpjRZZWLjz1Jz7Zdl1VmK7oUter0cZqa0rEFzbrctFTV+xdr0a+nxS0LZaytU5ewTu7a7nx27tglY+WYHu/pUlEvNK5rLluadNfYUsFTqprW2wIAjYIzEAAAIDISCAAAEBkJBAAAiIwEAgAAREYCAQAAIiOBAAAA81fGOTmmS+CSpUkZS8U8OYpu8uiSCR3M53SJZ1eb7jjZ2aLL6gqj/jLO/oU9MrbomDfL2M83l2TsuXU6dsqCbhkbG9PzDSw/VsbiLi9jpaIu8eys63LMiZ36dZEtlWVsQbdn/6pNMpY6pkvGCp4On+Yn//d2Gdu8Se9/wltWqTtuepp/urInd4+X9bgBQKPgDAQAAJi/MxAAEFWtVnNbt251bW1tLhbTZ2sAzI96ve4mJyfdwoULXTy+b88ZkEAAmDeWPCxZsoQRBl5hmzZtcosXL96nyySBADBv7MyDufzyy93q1atdqVRyyWTSVatVl8/nXVdXl9u5c2fwWC6XXVNTU3DWorW1NbgtfXt7uxsZGXGJRCKY1x6N/UXV19fntmzZ4pqbm4NlmR07drgFCxa4xx57zB1zzDHuF7/4hctms8G0d911l7vwwgvdLbfc4k499VT3wgsvBPP09va69evXuxNPPNE9+eST7tBDD3XDw8Ous7MzOGti6xwYGHDbt28Plm1JUX9/f7BdLS0twfJtWUcccYTbvHmz6+joCLYnnU67sbExNzg46EZHR10qlQr2y+azaex3e97W96Mf/citWrUqGCMbG4vZm/309LTLZDJuYmIimGdqairYHhvDSqUSjFc4JjZ
PoVAIxs/+0iwWi8H42Xbb9tsybAwtZo+2bXZ8wm225dj+n3zyye6mm25yZ555ZjAutr3HH3+8u/32293555/v7r77bnfSSScFH0i2fNtP2xbb3qGhoWB/7fhYzI6psX2yn228bJpw32xe22Y7/rt27XLd3d3BsmydixYtCrbdjoONsW3nypUr3eOPPx4s2+L2aPuybNmymTNc4evE/oWvKXu0dRkbH4vZ+NhY2F/otn0Wt22y52x8x8fHg/XavDaPbZc9b8esp6fHPfvss2758uXBa8ymsW34z//8T7dmzZrgNWbj9/TTTwfjYdtm89lrZOPGjW7p0qXBONi+2XGz17Btp01j09u+2v7Z69C2zV6/Nia2v7Z94bG3mP2z7bJtDJ8Lj4ntj63vvPPOm/m/uC+RQACYN+Gbur2x2ZukvdGFCYS9cdtz9iZtj/bGbx+W9kZov9ujfeDYB4zNE35wGntztGlsepvfpg3XY7/bG6jF7NGes59t3vDRnrM37nAe+9li9mi/h/OFCYTFwuXs/hiux9YbPmf7Zz/PnjZcR/hcuK7Z22v7FCYQlmzYui1uz9nvwZv2r8dwTwmE/RwmEOEybAztMVxuGLPxDhMgW6/Na8uz52x7wvENj134Qbf7voYfVuGxtH/h72ECEX74zj7eth02r41xeJzt0eYJfw6nDecLnwt/tv2yRGD2V2S+BMKmCT9ww2RhdgIRjk/4Gg33wcbUHu15e7T1NTc3B+sPp7GfbV3h4+ztt/Xa+sLp7dGSwd3HOly2PW+PNo9th/1s+2WPtn02Xfg68yUQtj+2rNn/F/clLqIEAOwzXOty8HjZZyASnuSl6uksWPeUucWd7uJZjekyzlFPldvEhK6dqxd1+eOCDl3+aV5z2mkytnjV62Ts/1z3v2Rs0NORMlEqyNiW9b/UyzzsSBnL9KyQsZa6LsXNj+yUsWxNl1WWCrpsdHhSxzr7dAfTnsFlMlbItTufuCdcTeuOo7G4fg2Xy/o1FavobrOxuo5VKpwYBND4OAMBANhn7CsBHBxIIAAA+wxfYRw8SCAAAPsMZyAOHiQQAAAgMhIIAMA+w1cYBw8SCADAPsNXGAePl10vFvNcWFv1dA+Mee69nfSkL/WCZ5m6OaTr7mmWscFmXTZ6/ImH64U651afoks1R3fqMtamiu4cepjntqI1z04O9vfJWGVa72Pe08WzVNHzlQv6ZVJ1uhT1l1s2y9hTP39Exk55nd7OnkHdFXViUpebmpR+abjeZbqMt+Z5DVdLnnJMT9nw+NCYjBUnPRsKNDjOQBw8OAMBAAAiI4EAAACRkUAAAIDISCAAAEBkJBAAACAyEggAADB/ZZw1T2fBQlGXHKY9HSeTyZSMJeK6BG7FoO4AmcnqnGjZIUtk7Ng36G6bZsGqY2TsiQevk7GlS/S2Dh51tIyl+5bLWLK5Q8by07qktDChO27u2LpJxkZ36HLMall31cy2ZWSst1cf+01bH5exgQWLZKyS1/tu6oWijMWmRmWsWtedUeue+uZsk97H9KCOTTR5Wt8CQIPgDAQAYJ/hRlIHDxIIAAAwf19hAMDe/jVaLBZdPp93pVLJJZNJV61WXaFQCJ4LH8vlsqvVasG/eDwePGfTWjyRSATz2qMJ55meng6mtcdwPRazae25cB77uVKpzDzadLa+cB772WL2aL/bfPZod1UM57ffd3+07bNpbHpb7+yY7fvs322/UqnUzHM2Brtvr+2TPW/T5XK5memmpqaCddmjbY/9bPth2xuOic1jywrHz5ZtjzaPLcsebT32nD3aumxemyedTgc/2zzhdOH4hsfO1hUew9n7auuzbQmPZfjPticcY9sH+9nWPfvY2Xz2e1NTU/BcJpOZec62I4zZWM5+zpYdjoVtz+Tk5MwdMMNjbv9snfa7Pdo+GtvecF9te2zZtn0WD19PNr42ZuF2h/sYPh9uby6Xmxkb+9mWGT6G22px2zb7OZw+fLR12vbZ8sN
l2L7YvPYYbkM2mw1+tm2zf+Gxt5j9s+2yZdiywn0JX++2rNn/F/clEggA82bXrl3B42c/+9nf+SjffPPNL3nuZz/7WfD4yCMvvZX6XXfd9TvZrv3FTTfd9JLn7rjjjuDxm9/85iuwRfuHL33pS8Hj1Vdf7RqJJSQdHfr6ub1BAgFg3nR3dwePGzdu3OdvXgeKiYkJt2TJErdp0ybX3t7+Sm9OQ2KM9n6c7MyDJQ8LFy50+xoJBIB5Y6dbjSUPfDj62fgwRozRfLyW5it5f9kJRCqhJx2d1KV81WldkpZtzspYIq6/r+n3dNzctE13OVx+/NkytvhoHfsVXY5ZnvzVd0x70tGmD1zf4cfJ2FTyV3+57cnTjz8sY8WC3paJCT02w1s2yliiqktqMxn9ulh0qC65PObwFTJWSejOmKlEp46ldQdXk/z19+R7kn9xy16VMFc8lyHnfv3d9J409+h9HFioO44CQKOgCgMAAERGAgFg3tjV6ldccUXwCMaI19GB9f+NayAAzBt7I/vbv/1bRpgx4nV0AP5/4wwEAACIjAQCAABERgIBAADm7xqIYkGXwDU36cXEMrqULRWvyFi9qmPZVr3Mt1/0dhk75ZwzZKy9d8D57Fj/rIwlPPsxNjkuY0Mv/ELGtk7q0sH7brtNxlqzusvjdFF3qxwc0OWm7W265HDDZt3Fs+QZl+6Fy2Ts8KNPkDFX1RcHjYzprqEm7ykpHi3obY3V9et7uqA70eY8t46t5/T/p9W6UhUAGgZnIAAAQGQkEADmrSfAsmXLggZJr33ta2f6UByI7r//fvf7v//7we2CrXHSbbudJbTbCf/N3/yNW7BgQdAY6S1veYt7/vnn50wzMjLi/uAP/iC4g2BnZ6f7wz/8w6CB0mxPPvmke+Mb3xiMqd2y+HOf+5zbX/zDP/yDe81rXuPa2tpcf3+/O//8890vfjH3LKw1f/rQhz7kenp6XGtrq7vwwgvdjh075kxjt0X/vd/7Pdfc3Bws55Of/GTQOGq2++67zx1//PFBVcKKFSvc9ddf7/YHX/nKV9wxxxwzcyfJk08+2X3ve99r2PEhgQAwL42s/vzP/zyoSX/sscfcscce68466yy3c+fOA3K0reOh7WPYSGl39kH/hS98wX31q191Dz30kGtpaQnGI+wiaix5ePrpp91//dd/ue985ztBUvJHf/RHc/ocvPWtb3WHHHKIe/TRR4NmTVay97Wvfc3tD374wx8GH34//elPg320bpK2P2G3SPOxj30saNh1yy23BNNv3brVXXDBBTNx6zRpH47WdfKBBx5w//7v/x58+FlyFtqwYUMwzWmnneaeeOIJ99GPftS9//3vd3feeadrdIsXLw4az9nxtYZvp59+ujvvvPOC10Ujjk+s/jJ7fH7hf54jYyM7XtQryOjbTqd+fZ/8qNdAJNJ6vhXHv17GTr9gjYx19C51e3sNxE/v+N8yVhnXt0g+4Y2n7dU1EHc21DUQ22WsZ9ESGTvh5DfL2OFHv3HvroHYpI+RuefbX5Gx5599aXfG//9rIPQ1F/WsXubqY5fL2MevesjtD+yMg/21+cUvfjH43VoN21/MH/nIR9zll1/uDmR2BuJb3/pW8Be2sbdYOzPx8Y9/3H3iE58InhsfH3cDAwPBm/u73vUu9+yzz7ojjzzSPfzww+7EE08Mpvn+97/v3va2t7nNmzcH89tfp5/61Kfc9u3bZ1pT21ja2Y61a9e6/c3Q0FDwF7J9EL7pTW8KxqSvr8/deOON7h3veEcwje3X6tWr3YMPPuhe97rXBX+Nn3vuucEHp42fsaTssssuC5Zn42I/f/e733U///nPZ9ZlYzw2NhaM6f7YkO7qq68OxqTRxoczEAD2Kfvrx/6CstP0M2808Xjwu73RHWzsLz770J89HtbcyJKscDzs0b62CJMHY9PbuNkZi3Aa+6ANkwdjZzHsa4DR0VG3v7GEYXbHVnvN2FmJ2eN0xBFHuKVLl84
Zp6OPPnrmwzEcAzs7E/6VbtPMXkY4zf722qtWq0FLdTtDY19lNOL4kEAA2KeGh4eDN7/Zb2LGfrcP0oNNuM++8bBH+2t8tmQyGXy4zp5mT8uYvY79hZ2RslPnr3/9692rXvWq4LnwzIolUr5x+m1joKaxD9FCoeAa3VNPPRVc32DXJ/zJn/xJcDbLzk414vi87DLOWl13ZHQ1fbo9VtGneCt13T0xFtPfrGSaftOmdHfHnaBLAJtS+vT+M0887nxGt/5SxopFXZI3OToiY5vWPSNjubruVJqq6vW1JnWJa3tGfxXR16W/wti2Q785Vcr6GOYn9Vcmmzbo7p/O/SpT3pNcblLGMkn/t3GVprlv0LPtqujXVDabkbHmNn2cskn9dctkfkLGKjX99R1wILBrIewU+o9//ONXelMazqpVq4JrE+wMza233urWrFkTfM3TiDgDAWCf6u3tdYlE4iVXh9vvg4ODB91oh/vsGw973P0CU7ty3iozZk+zp2XMXsf+4MMf/nBwkei9994bXDQYsn2wr7/su3jfOP22MVDTWFWDVcA0unQ6HVRGnHDCCUHlil2c+/nPf74hx4cEAsA+fwO0N78f/OAHc05Z2+/2Xe7B5tBDDw3etGePh50utmsbwvGwR/tgsO+5Q/fcc08wbnatRDiNVWbY9+Ahq2awv1i7urpco7OLSS15sFPytm82LrPZayaVSs0ZJ7u+w8oSZ4+TneKfnWzZGNiHn53mD6eZvYxwmv31tVer1VyxWGzI8aEbJ4B9zko47dSrXRR40kknuWuuuSa4GOy9733vATnadr+GdevWzblw0k5D2zUMdpGbfd//93//927lypXBB+enP/3poLIirNSwK+nPPvts94EPfCC4at6SBPuwtavjbTpz8cUXu7/7u78L7g9hV9LbVwD2l+m//uu/uv3lawurIPj2t78d3Asi/E7eLii1v3zt0fbNXjs2bvahZ1U79sFmFQbGyj7tg/A973lPUBpry/jrv/7rYNlhC2u7bsCqf/7iL/7Cve997wuSlf/4j/8IKg8a3V/+5V+6c845J3jNTE5OBuNl92ywEstGHB8SCAD73EUXXRSUjVn9ub2JHXfccUGJ2O4Xbx0orGbf6upD9iZvLImyUk17s7YEyu7rYGca3vCGNwTjYTeECt1www1B0nDGGWcE1Rd2kyC7d0TIPkDuuuuu4MPA/hq1r4psfGffK6KRWRmqOfXUU+c8f91117lLL700+NmSoXDf7a9uqw748pe/PDOtfTVmX3988IMfDD447X4aNsZXXnnlzDSWoNmHod0zwRIs+5rk2muvDZbV6Hbu3OkuueQSt23btuB4202lLHk488wzG3J8XvZ9IK75oO4jMbZd9yCIp/WFe65e3auLKJs79QVv512i/zMtWPFqGVu/YfteX0S55ec/kbHJbXPvNjfb4Ueu3quLKB/9yQMy1tOpxzue1PclGFjQs1cXUe6a0FfttvXoixaXrTxaxpYcety8XET55KP6gq1HHvnx3l1E2aSPU3wvL6JcvFK/vi/77G9OcQPAK4lrIAAAwHx+haHLMWsVXeKZTOk7UVYr+gxEyelStoEOfcHQnbd/R8a6B3R5YP8CfdfEYHvyuqtmKqX/0mxt0X9NJuO65LLFU3I62K/PFhQm9Q1lsgm9nbuGhmWsXNLHqS2j/wIv7XYf/9mef1zf+XHb2udkrFjx1Cmn9Hiaqm+8F3vOlLXo13e8SZfUZjzlmF1Oj9vqo+ZeXAYAjYgzEAAAIDISCAAAEBkJBAAAiIwEAgAAREYCAQAAIiOBAAAAkb38bpw1fROitKcDZCapyz9dXC+zntBldbWS7gA5PKxvepQb0rFsWd/YJ1in0/vY3aXLKjsX9slYpVqUsS1b9bbWnb5hUjyuD2mpossKEzFdNtqS0aW4nmarLuELem4UVi3pktm453U4kdclrKbU5Lnp1UJ9LKayc5vXzDZZ0yWe01M6P+9pP0zGej1
lugDQKDgDAQAAIiOBAAAAkZFAAACAyEggAABAZCQQAAAgMhIIAAAwf2Wc8Zju5Jhp0p0F656umi1ZXR7Y0tYrY/my7oDY05aWsaRnW0rjO5xPLa6Xm0/pcsWBAd1ZsVbSJYCrjlksYw/c+wMZK9XzMpaK6RLIQk7P196mO4qmk/ollIjpcclN62O4YZsuxxwb08ewGJtyPn2H63x5Uaenq2hdH/vRYT1u6WlPaewiT0fVvO5+CgCNgjMQAAAgMhIIAAAQGQkEAACIjAQCAABERgIBAAAiI4EAAADzV8aZTupcI1/UnQwTGU9XzYQuDc2XdefEREp3cmxK63K8VEpvS7q5w/l0tOt5tw/pEtD8Il2O2b9khYxt2TksY0e95vUylhvaKmPrn3taxqZyuuNkMqGPRUeHLvGMOV3GuW2L3s6NL3q6cTbp49A+oMuCTV+3Z1s9ZaWxEb3OrlH9X2hRf7eMLe7Ur4t1z+hOrKf9DxkCgN8pzkAAAIDISCAAAEBkJBAAACAyEggAABAZCQQAAIiMBAIAAERGAgEAAObvPhADfTrXKO/aJWOFqr4XwJSn+3I9rlsaJz0tpNvbdZvkdEq3Vy5MTeiNcc5lU56hKunYIw88IGOHrdL3j9i8Wd8LIB7Xbbmbm/Q+Jjz33chm9b0OpnL6PhCFgo5VKrpdeWtWb8sprz5cxjKe1uKVhG71bapl3Xq7sEnfByI+mZGx/uY2GXv14Ufp+ToHZOzRbRtkDAAaBWcgAABAZCQQAAAgMhIIAAAQGQkEAACIjAQCAABERgIBAADmr4xz6ZK0jHXEdJnbuk26dG7HkG7LXarqMr/WVr3ZU3ndCrpay8lY4rfkUiNDulR1MqfLB6fLensSdR1ra+2SsR3bR2Rs85QuR6zVdfnnQJ8uf43VyjI2OjYqY00t+hh2dujyx3RCH4tiSZf3uqQuYTVTRb3cUk7P21LT861YMihjCwf1mG7arEt4dw3p/zMA0Cg4AwEAACIjgQAAAJGRQAAAgMhIIAAAQGQkEAAAIDISCAAAMH9lnO1dnk6WnrKzrv6EXmhLswwN7yjK2HRJd3lMpnW3Rs9srlb2lAdax9Gq3p7xgi5lbPF0nZzO65LLwvSwjJU821r1xOp1fSxyE/oYtrdnPbEOGSsU9DKHd+kxa23VnUFjcZ3zxiq6LNikk3o/mnQlskun9bgtW7FMxgp5vT333/+MjD353E69MQDQIDgDAQAAIiOBAAAAkZFAAACAyEggAABAZCQQAAAgMhIIAAAwf2WcyYyeNNOuO3V2t+ocJVnQpZGpbE3GJkY9m13V68tm+vVsKb2+IF4ck7F0s96eVFKPTSKhy1iLdb09pbKuR617Om7GPFWO9ZIuKa3qkEv5OmCmdQnr2Kgu4yyUdPfPjk5dppv0lHiauOdY5J3uqLpjeFLGRj2dWCendLfVu+9bq9dHM04A+wHOQAAAgMhIIAAAQGQkEAAAIDISCAAAEBkJBAAAiIwEAgAAzF8ZZy7nKddLtMpQa4uuAUxldV1hi6c9YkeHLnHMTRQ8sR06lv8t3Tindbwt3SNjmZQet0pRl7Emkzq3S3vSvlST7hwZi+kZm1v1SyHueZVUqrqMMZ3VM7Z36hLWkRFdNjnpKW9t79bHweQruvz1+Rd2ydjapzbJ2EC3LisdWKz30cX1fvR2tOn5AKBBcAYCAABERgIBAAAiI4EAAACRkUAAAIDISCAAAEBkJBAAAGD+yjg3v6hjxTFdctnWp8v8MllP10VdGeq6u/Vm56Z0K8OxMR0b3aU7Nf4qrmOJmi6drNV1qWq16ikdrVX3KuuLxXU3zkRSj1vB08W0rg+hS9X0MazkR2SsWtDHourp8DmW0/OV/JW4bsRT4vvCOn2Ax3ZN6XVO6ZUOdgzK2OpDFsmYZzMBoGFwBgIAAERGAgEAACIjgQAAAJGRQAAAgMhIIAAAQGQkEAAAYP7KOKupXhkrp0+UsWJNd5yMV4ZlLNOhyxE
7+3TZaFdc1xx253UHxLGRrIwF8WFdqlmY0sNYrXjKQ+s6f6tV9LZOF3SH03Rary+R1PswOa3XV8h5OqrWdYfLtrjuKlmLT8hYuazHs6lFl8VmUk3OpzOtt/Uw1yljRx/bImOrjjlWxpatWCFjJ71Ol6Nu3pqTMQBoFJyBAAAAkZFAAACAyEggAABAZCQQAAAgMhIIAAAQGQkEAACILFave9pFAgAA7AFnIAAAQGQkEAAAIDISCAAAEBkJBAAAiIwEAgAAREYCAQAAIiOBAAAAkZFAAACAyEggAACAi+r/AW75lQ/qbHA6AAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from matplotlib import pyplot as plt\n", "import numpy as np\n", "\n", "\n", "plt.subplot(1, 2, 1)\n", "plt.title(\"Image\")\n", "plt.imshow(images[0].permute((1, 2, 0)))\n", "plt.axis('off')\n", "plt.subplot(1, 2, 2)\n", "plt.imshow(np.hstack(([images[0].flatten()] * 300)).reshape(300, -1), cmap=\"Greys\");  # repeat the flattened vector 300 times so the 1-pixel-high strip is visible\n", "plt.yticks([]);" ] }, { "cell_type": "markdown", "id": "55a3886927dc5416", "metadata": {}, "source": [ "As you can see, flattening turns a structured image into a long, unstructured vector, which makes an already complex problem even harder. Now that we are working with images, we need models that can understand **spatial patterns** in images, such as edges, textures, and shapes. This is where **Convolutional Neural Networks (CNNs)** come in.\n", "\n", "---\n", "\n", "CNNs are a special class of neural networks designed specifically for processing **grid-like data**, such as images. Unlike traditional fully connected networks, CNNs take advantage of the **2D structure** of images. They can recognize patterns that occur in small regions of an image and reuse that knowledge across the whole image. They were popularized in the late 1980s by [LeCun *et al.*](https://ieeexplore.ieee.org/abstract/document/6795724), who trained them with backpropagation to recognize handwritten digits (the `MNIST` benchmark came later, in 1998). A CNN has two main parts: a **feature extractor**, built from convolutional and pooling layers, and a **predictor** on top of it.\n", "\n", "\n", "\n", "\n", "### Key Building Blocks of a CNN\n", "\n", "1. **Convolutional Layers**: These layers use learnable **filters (kernels)** that slide over the input (e.g., an image) to extract local features such as edges, textures, or shapes. Each filter generates a corresponding **feature map** that highlights the presence of specific patterns across the spatial dimensions.\n", " \n", "2. **Pooling Layers**: These layers **downsample** the feature maps by summarizing small regions (e.g., taking the max value). Pooling has no learnable parameters; it reduces the spatial size of the feature maps, which lowers the computation and the number of parameters needed by later layers, and makes the model more robust to small shifts of the input.\n", "\n", "3. **Predictor**: 
Any machine learning model can be used on top of the convolutional part; usually it is an MLP.\n", "\n", "**The most important thing to take into account**: in general, the bigger the better! Deeper CNNs with more filters, trained on more data, tend to learn better features, at the cost of more computation.\n", "\n", "Let's load our first CNN:\n" ] }, { "cell_type": "code", "execution_count": null, "id": "8bd5ef13", "metadata": {}, "outputs": [], "source": [ "import torchvision.models as models\n", "\n", "# Load pretrained ResNet50 model\n", "model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)\n", "model.eval()  # Evaluation mode: disables dropout and makes batch norm use its running statistics" ] }, { "cell_type": "markdown", "id": "a152d77c", "metadata": {}, "source": [ "**TorchVision** is an official library that is part of the **PyTorch ecosystem**. It provides:\n", "\n", "1. **Pre-trained Models** (`torchvision.models`): Access to state-of-the-art architectures pre-trained on ImageNet, ready to use out-of-the-box:\n", "   - `ResNet`, `VGG`, `Inception`, `DenseNet`, `EfficientNet`, `MobileNet`, `Vision Transformers`, and many more\n", "   - Pre-trained weights are **downloaded automatically the first time you request them**, then cached and loaded into the model\n", "   - You can easily switch between different architectures: `models.resnet50()`, `models.vgg16()`, `models.efficientnet_b0()`, etc.\n", "\n", "2. **Image Transforms** (`torchvision.transforms`): A complete pipeline for image preprocessing and data augmentation\n", "   - Resizing, cropping, rotating, flipping, color adjustments, etc.\n", "   - Composition of multiple transforms for efficient data pipelines\n", "\n", "3. 
**Datasets** (`torchvision.datasets`): Easy access to popular computer vision datasets\n", "   - `MNIST`, `CIFAR10`, `CIFAR100`, `ImageNet`, `COCO`, and more\n", "\n", "In the code above, we used `torchvision.models.resnet50(weights=...)`, which:\n", "- Builds the ResNet50 architecture\n", "- Downloads (on first use) and loads the weights trained on ImageNet-1k (1000 classes, about 1.28 million training images)\n", "- Gives us a model ready for **immediate inference or fine-tuning** on our own tasks\n", "\n", "This is the power of torchvision: it abstracts away the complexity of downloading, managing, and loading pre-trained models, making transfer learning accessible to everyone.\n", "\n", "\n", "
We must adapt the images to the model!
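Concretely, `Normalize` subtracts a per-channel mean and divides by a per-channel standard deviation, both computed on ImageNet. This puts our inputs on the same scale as the images the model was trained on. Here is a NumPy sketch of that arithmetic, applied to a hypothetical uniform grey image that has already been resized and cropped to 3 x 224 x 224:

```python
import numpy as np

# ImageNet per-channel statistics (R, G, B), same values as in the transform below
mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

def imagenet_normalize(img):
    # img: float array of shape (3, H, W) with values in [0, 1], as produced by ToTensor
    return (img - mean) / std  # broadcasting applies each channel's statistics

img = np.full((3, 224, 224), 0.5)  # hypothetical grey image
out = imagenet_normalize(img)

print(out.shape)               # (3, 224, 224)
print(out[:, 0, 0].round(3))   # normalized value of each channel
```

Each channel gets its own shift and scale. That is why the statistics are reshaped to (3, 1, 1) and broadcast over the height and width.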
" ] }, { "cell_type": "code", "execution_count": null, "id": "6e48ef6e", "metadata": {}, "outputs": [], "source": [ "# ImageNet preprocessing: resize, center-crop to 224x224 and normalize with ImageNet statistics\n", "imagenet_transform = transforms.Compose([\n", "    transforms.Resize(256),\n", "    transforms.CenterCrop(224),\n", "    transforms.Normalize(mean=[0.485, 0.456, 0.406],\n", "                         std=[0.229, 0.224, 0.225])\n", "])\n", "\n", "# Get a sample image and prepare it\n", "sample_image = images[0].clone()  # Take first image from batch (already a tensor)\n", "sample_image_normalized = imagenet_transform(sample_image)\n", "sample_input = sample_image_normalized.unsqueeze(0)  # Add batch dimension\n" ] }, { "cell_type": "code", "execution_count": null, "id": "04270f44", "metadata": {}, "outputs": [], "source": [ "# Run inference without tracking gradients\n", "with torch.no_grad():\n", "    output = model(sample_input)\n", "\n", "# Convert the 1000 logits into probabilities and keep the 5 most likely classes\n", "probs = torch.softmax(output, dim=1)\n", "top_probs, top_classes = torch.topk(probs, 5)\n", "\n", "print(top_classes)" ] }, { "cell_type": "markdown", "id": "ebd976f1", "metadata": {}, "source": [ "To understand the output, we must look up the predicted indices in the ImageNet class [list](https://deeplearning.cms.waikato.ac.nz/user-guide/class-maps/IMAGENET/)." ] }, { "cell_type": "markdown", "id": "68ecd95175bec246", "metadata": {}, "source": [ "### Exercise\n", "\n", "1. Load an image from your computer, apply the transform, and make a prediction with the model. Is the prediction correct?\n", "\n", "**Note:** To load an image you can use the PIL library. Convert the image to a tensor (e.g., with `transforms.ToTensor()`) before applying `imagenet_transform`, because `Normalize` only accepts tensors:\n", "\n", "```python\n", "from PIL import Image\n", "\n", "img = Image.open('path/to/your_image.jpg').convert('RGB')  # hypothetical path: use your own file\n", "```" ] }, { "cell_type": "code", "execution_count": null, "id": "d0d4b4b4", "metadata": {}, "outputs": [], "source": [ "#TODO " ] }, { "cell_type": "markdown", "id": "be430d95", "metadata": {}, "source": [ "
\n", "\n", "
\n", " II. Transfer learning\n", "
\n", "\n", "\n", "## Transfer Learning: Plug and Play\n", "\n", "The beauty of pre-trained models is **transfer learning**: you can reuse a model trained on ImageNet for your own task." ] }, { "cell_type": "markdown", "id": "857da548", "metadata": {}, "source": [ "### The Superpower of Transfer Learning\n", "\n", "1. **Speed**: Train in minutes or hours instead of weeks\n", "2. **Accuracy**: Start from learned features, not random weights\n", "3. **Small Datasets**: Works well even with limited labeled data\n", "4. **Easy to Use**: Load, adapt and train with just a few lines of code\n", "\n", "This is why pre-trained models are a game-changer in computer vision. Whether you're classifying animals, detecting objects, or segmenting images, pre-trained models provide an incredible starting point.\n", "\n", "Let's see how to do it. We will adapt our ResNet50 to CIFAR10 in a few steps:\n", "\n", "1. Load the model and replace its last layer, so that it predicts the 10 CIFAR10 classes instead of the 1000 ImageNet classes:" ] }, { "cell_type": "code", "execution_count": null, "id": "fc0db35b", "metadata": {}, "outputs": [], "source": [ "# Fine-tune ResNet50 on CIFAR10\n", "\n", "# Load pretrained ResNet50\n", "resnet_finetuned = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)\n", "\n", "# Modify the final classification layer for CIFAR10 (10 classes instead of 1000)\n", "num_classes_cifar10 = 10\n", "resnet_finetuned.fc = torch.nn.Linear(resnet_finetuned.fc.in_features, num_classes_cifar10)" ] }, { "cell_type": "markdown", "id": "bb1d6433", "metadata": {}, "source": [ "2. Freeze the pre-trained layers. During training we want them to stay unchanged, because they already detect general patterns (lines, curves, colors, etc.) that are also useful for CIFAR10. Only the new last layer will be trained." 
] }, { "cell_type": "code", "execution_count": null, "id": "1ccef37e", "metadata": {}, "outputs": [], "source": [ "# Freeze every pre-trained parameter; only the new fc layer stays trainable\n", "for name, param in resnet_finetuned.named_parameters():\n", "    if \"fc\" not in name:\n", "        param.requires_grad = False  # keep the ImageNet features fixed" ] }, { "cell_type": "markdown", "id": "db515d24", "metadata": {}, "source": [ "3. Train the model. Since this is the first time we train a CNN, there are some differences:\n", "\n", " - First, we want to use a GPU if one is available, because training a CNN is much more computationally intensive than training an MLP." ] }, { "cell_type": "code", "execution_count": null, "id": "5f9c2887", "metadata": {}, "outputs": [], "source": [ "# Move to device and set to training mode\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "resnet_finetuned = resnet_finetuned.to(device)\n", "resnet_finetuned.train()" ] }, { "cell_type": "markdown", "id": "3da83e6c", "metadata": {}, "source": [ "## Summary: Loss Functions in Machine Learning\n", "\n", "A **loss function** (or cost function) is a mathematical tool used to measure the \"distance\" between a model's predictions and the actual target values. The goal of training is to minimize this value to improve accuracy.\n", "\n", "### Core Functions and Use Cases\n", "\n", "| Category | Loss Function | Best For... |\n", "| :--- | :--- | :--- |\n", "| **Classification** | **Cross Entropy** | Multi-class tasks (most common). |\n", "| **Classification** | **Binary Cross Entropy** | Two-class tasks (Yes/No). |\n", "| **Regression** | **Mean Squared Error (MSE)** | General regression; penalizes large errors. |\n", "| **Regression** | **Mean Absolute Error (L1)** | Data with outliers; more robust. |\n", "\n", "\n", "### Key Takeaways\n", "* **Optimization**: Optimizers use the gradients of the loss function to adjust model weights. 
\n", "* **Numerical Stability**: In PyTorch, `BCEWithLogitsLoss` is preferred over a manual Sigmoid followed by `BCELoss`, because it combines both steps in a numerically stable way.\n", "* **Data Characteristics**: The choice of loss should be driven by your data; for instance, use **L1 Loss** if your dataset is noisy or contains many outliers.\n", "* **Task Alignment**: Classification typically relies on probability-based losses, while regression relies on distance-based metrics.\n", "\n", "In our fine-tuning example, we use `torch.nn.CrossEntropyLoss()` because CIFAR10 is a multi-class classification task with 10 classes. It expects the raw outputs (logits) of the model and applies a log-softmax internally." ] }, { "cell_type": "code", "execution_count": null, "id": "903ee94a", "metadata": {}, "outputs": [], "source": [ "# Define loss function and optimizer\n", "loss_fn_finetuned = torch.nn.CrossEntropyLoss()\n", "# Only the parameters that are still trainable (the new fc layer) are passed to the optimizer\n", "optimizer_finetuned = torch.optim.Adam(\n", "    [p for p in resnet_finetuned.parameters() if p.requires_grad],\n", "    lr=0.001\n", ")" ] }, { "cell_type": "markdown", "id": "c94f8c5a", "metadata": {}, "source": [ "We must adapt our data to the format expected by the ImageNet model, so we load the dataset and the DataLoader again, this time using the ImageNet transform defined previously." ] }, { "cell_type": "code", "execution_count": null, "id": "d7d57a2c", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/miquelmn/Desenvolupament/01 - Docencia/AppOC/.venv/lib/python3.12/site-packages/torchvision/datasets/cifar.py:83: VisibleDeprecationWarning: dtype(): align should be passed as Python or NumPy boolean but got `align=0`. Did you mean to pass a tuple to create a subarray type? 
(Deprecated NumPy 2.4)\n", " entry = pickle.load(f, encoding=\"latin1\")\n" ] } ], "source": [ "# CIFAR10 returns PIL images, so convert them to tensors before the ImageNet preprocessing\n", "cifar_imagenet_transform = transforms.Compose([transforms.ToTensor(), imagenet_transform])\n", "trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=cifar_imagenet_transform)  # download=True fetches the data the first time\n", "\n", "# Shuffle the training set so every epoch sees the batches in a different order\n", "trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)" ] }, { "cell_type": "markdown", "id": "571e6b57", "metadata": {}, "source": [ "## The CNN Training Loop\n", "\n", "The **training loop** is the core of the learning process. Here's what happens in each iteration:\n", "\n", "### Step-by-Step Breakdown\n", "\n", "1. **For each epoch** (complete pass through the entire dataset):\n", " - Initialize metrics (loss, accuracy) to track performance\n", "\n", " ```python\n", " for epoch in range(epochs_finetune):\n", "     running_loss = 0.0\n", "     running_acc = 0.0\n", " ```\n", "\n", "2. **For each batch** from the DataLoader:\n", " - **Forward Pass**: Feed images through the CNN to get predictions\n", "\n", " ```python\n", " outputs = model(images) # CNN predicts classes\n", " ```\n", " \n", " - **Compute Loss**: Calculate how wrong the predictions are compared to true labels\n", "\n", " ```python\n", " loss = loss_fn(outputs, labels) # Quantify error\n", " ```\n", " - **Backward Pass**: Compute gradients using backpropagation\n", "\n", " ```python\n", " optimizer.zero_grad() # Clear previous gradients\n", " loss.backward() # Compute gradients via chain rule\n", " ```\n", " \n", " - **Optimizer Step**: Update model weights to reduce loss\n", "\n", " ```python\n", " optimizer.step() # Move weights in direction that reduces loss\n", " ```\n", " \n", " - **Track Metrics**: Update running loss and accuracy\n", " \n", " ```python\n", " running_loss += loss.item()\n", " _, predicted = torch.max(outputs, 1) # Index of the highest-scoring class\n", " running_acc += (predicted == labels).sum().item() / labels.size(0)\n", " 
```\n", "\n", "3. **After each epoch**:\n", " - Calculate average loss and accuracy\n", " - Print progress\n", " - Repeat for next epoch\n", "\n", "### Key Concepts\n", "\n", "- **Forward Pass**: Data flows through the network to produce predictions\n", "- **Loss Computation**: Measures how far predictions are from ground truth\n", "- **Backward Pass**: Computes gradients showing how to adjust weights\n", "- **Gradient Descent**: Optimizer uses gradients to update weights in the direction that reduces loss\n", "- **Epochs**: Complete passes through the entire dataset; multiple epochs allow the model to learn iteratively\n", "\n", "\n", "This loop repeats thousands of times (with 50,000 CIFAR10 training images and a batch size of 10, a single epoch already has 5,000 iterations), gradually improving model accuracy through incremental weight adjustments.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "23c010cf", "metadata": {}, "outputs": [], "source": [ "# Fine-tuning loop\n", "epochs_finetune = 5\n", "for epoch in range(epochs_finetune):\n", "    running_loss = 0.0\n", "    running_acc = 0.0\n", "\n", "    for batch_idx, (images, labels) in enumerate(trainloader):\n", "        images = images.to(device)\n", "        labels = labels.to(device)\n", "\n", "        # Forward pass\n", "        outputs = resnet_finetuned(images)\n", "        loss = loss_fn_finetuned(outputs, labels)\n", "\n", "        # Backward pass\n", "        optimizer_finetuned.zero_grad()\n", "        loss.backward()\n", "        optimizer_finetuned.step()\n", "\n", "        # Metrics\n", "        running_loss += loss.item()\n", "        _, predicted = torch.max(outputs.data, 1)\n", "        running_acc += (predicted == labels).sum().item() / labels.size(0)\n", "\n", "        if (batch_idx + 1) % 100 == 0:\n", "            avg_loss = running_loss / (batch_idx + 1)\n", "            avg_acc = running_acc / (batch_idx + 1)\n", "            print(f\"Epoch [{epoch+1}/{epochs_finetune}], Step [{batch_idx+1}/{len(trainloader)}], \"\n", "                  f\"Loss: {avg_loss:.4f}, Accuracy: {avg_acc:.4f}\")\n", "\n", "    avg_loss = running_loss / len(trainloader)\n", "    avg_acc = running_acc / len(trainloader)\n", "    print(f\"Epoch {epoch+1}/{epochs_finetune} completed - Loss: {avg_loss:.4f}, Accuracy: {avg_acc:.4f}\\n\")\n", "\n", "print(\"Fine-tuning complete!\")" ] }, { "cell_type": "markdown", "id": "a0a5a9ec", "metadata": {}, "source": [ "The training loop has become more complex, with a nested loop over the batches. However, making a prediction with an already trained model is as easy as with simpler models:" ] }, { "cell_type": "code", "execution_count": null, "id": "8eb3d8c2", "metadata": {}, "outputs": [], "source": [ "resnet_finetuned.eval()  # evaluation mode for inference\n", "with torch.no_grad():\n", "    outputs = resnet_finetuned(images)\n", "predicted_classes = outputs.argmax(dim=1)  # one CIFAR10 class index per image\n", "predicted_classes" ] }, { "cell_type": "markdown", "id": "06d31beb", "metadata": {}, "source": [ "## Saving and Loading Model Weights\n", "\n", "In practice, trained model weights are usually **shared separately** from the code. This is done for several reasons:\n", "\n", "1. **Large file sizes**: Pre-trained models can be very large (hundreds of MB to GBs), making it impractical to include them in code repositories.\n", "2. **Easy distribution**: Weights can be hosted on separate servers or repositories (like the Hugging Face Hub, PyTorch Hub, or Kaggle) and downloaded on demand.\n", "3. **Version control**: Code and weights can be versioned independently.\n", "\n", "PyTorch provides simple methods to save and load model weights:\n", "\n", "- **`torch.save()`**: Saves model state (weights) to disk\n", "- **`torch.load()`**: Loads model state from disk\n", "- **`.state_dict()`**: Returns a dictionary with all the model's parameters and buffers (e.g., batch norm running statistics)\n", "\n", "This allows you to train a model once, save its weights, and then load them in any environment without retraining."
] }, { "cell_type": "code", "execution_count": null, "id": "1a80ae85", "metadata": {}, "outputs": [], "source": [ "# Save the trained model weights\n", "torch.save(resnet_finetuned.state_dict(), 'resnet_finetuned_weights.pth')\n", "\n", "# Rebuild the same architecture (without pre-trained weights) and load our weights into it\n", "model_loaded = models.resnet50(weights=None)\n", "model_loaded.fc = nn.Linear(model_loaded.fc.in_features, 10)  # Adjust for CIFAR10\n", "model_loaded.load_state_dict(torch.load('resnet_finetuned_weights.pth', map_location=device))" ] }, { "cell_type": "markdown", "id": "d770f4e1", "metadata": {}, "source": [ "## Exercise\n", "\n", "Hopefully our model now works on CIFAR10. In this exercise we will check whether that is true:\n", "\n", "1. Create a dataset and a DataLoader for the CIFAR10 test set.\n", "2. Compute the accuracy of the previously trained model. Is it good enough?\n", "\n", "3. The following link contains a ResNet50 weight file for CIFAR10. Load it and verify its accuracy (weights available [here](https://github.com/huyvnphan/PyTorch_CIFAR10/tree/master/cifar10_models)). You will have to adapt our model to these weights." ] } ], "metadata": { "kernelspec": { "display_name": "appoc (3.12.11)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.11" } }, "nbformat": 4, "nbformat_minor": 5 }