{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "23812cec-0eec-435e-8a6d-3be594896d6c",
   "metadata": {},
   "source": [
    "# Modèle de langue neuronal\n",
    "\n",
    "15 décembre 2025\n",
    "\n",
    "Adapté du tutoriel d'A. Karphathy \"Makemore\", deuxième partie: https://www.youtube.com/watch?v=PaCmpygFfXo"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01fde97b-7c23-46ff-adf1-ee0f7b1296e4",
   "metadata": {},
   "source": [
    "## Jeu de données: les mots du code civil"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "591d8bd9-2d13-493e-9c47-5a99e1afa965",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CTOI = {\"'\": 1, '-': 2, 'a': 3, 'b': 4, 'c': 5, 'd': 6, 'e': 7, 'f': 8, 'g': 9, 'h': 10, 'i': 11, 'j': 12, 'l': 13, 'm': 14, 'n': 15, 'o': 16, 'p': 17, 'q': 18, 'r': 19, 's': 20, 't': 21, 'u': 22, 'v': 23, 'w': 24, 'x': 25, 'y': 26, 'z': 27, 'à': 28, 'â': 29, 'ç': 30, 'è': 31, 'é': 32, 'ê': 33, 'ë': 34, 'î': 35, 'ï': 36, 'ô': 37, 'ù': 38, 'û': 39, 'œ': 40, '.': 0}\n",
      "ITOC = {1: \"'\", 2: '-', 3: 'a', 4: 'b', 5: 'c', 6: 'd', 7: 'e', 8: 'f', 9: 'g', 10: 'h', 11: 'i', 12: 'j', 13: 'l', 14: 'm', 15: 'n', 16: 'o', 17: 'p', 18: 'q', 19: 'r', 20: 's', 21: 't', 22: 'u', 23: 'v', 24: 'w', 25: 'x', 26: 'y', 27: 'z', 28: 'à', 29: 'â', 30: 'ç', 31: 'è', 32: 'é', 33: 'ê', 34: 'ë', 35: 'î', 36: 'ï', 37: 'ô', 38: 'ù', 39: 'û', 40: 'œ', 0: '.'}\n"
     ]
    }
   ],
   "source": [
    "words = open('civil_mots.txt', 'r').read().splitlines()\n",
    "chars = sorted(list(set(''.join(words))))\n",
    "nb_chars = len(chars) + 1  # On ajoute 1 pour EOS\n",
    "ctoi = {c:i+1 for i,c in enumerate(chars)}\n",
    "ctoi['.'] = 0\n",
    "print(\"CTOI =\", ctoi)\n",
    "# Dictionnaire permettant permettant de passer d'un entier à son caractère\n",
    "itoc = {i:s for s,i in ctoi.items()}\n",
    "print(\"ITOC =\", itoc)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a20eda5-2300-458e-befd-ca4b75772bb3",
   "metadata": {},
   "source": [
    "## Approche par réseau de neurones reproduisant l'approche par comptage"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2322e3b-d394-4190-b4fe-fc229966569d",
   "metadata": {},
   "source": [
    "### Représentation des mots avec des vecteurs \"one-hot\": exemple avec un seul mot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "678b1d9e-7cda-4985-9b92-7e8737864cc3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ". a -> 0 3\n",
      "a c -> 3 5\n",
      "c c -> 5 5\n",
      "c e -> 5 7\n",
      "e p -> 7 17\n",
      "p t -> 17 21\n",
      "t é -> 21 32\n",
      "é e -> 32 7\n",
      "e . -> 7 0\n",
      "acceptée\n",
      "tensor_dims = 9\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# Création d'un jeu d'entrainement de bigrams (x,y)\n",
    "xs, ys = [], []\n",
    "\n",
    "for w in [words[40]]:\n",
    "  chs = ['.'] + list(w) + ['.']\n",
    "  for ch1, ch2 in zip(chs, chs[1:]):\n",
    "    ix1 = ctoi[ch1]\n",
    "    ix2 = ctoi[ch2]\n",
    "    print(ch1, ch2, '->', ix1, ix2)\n",
    "    xs.append(ix1)\n",
    "    ys.append(ix2)\n",
    "    \n",
    "xs = torch.tensor(xs)\n",
    "ys = torch.tensor(ys)\n",
    "print(words[40])\n",
    "tensor_dims = len(words[40]) + 1\n",
    "print(\"tensor_dims =\", tensor_dims)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "be05261e-f1c6-4865-b7a3-add32bbd22d6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 0,  3,  5,  5,  7, 17, 21, 32,  7])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f87f46c3-1504-41a4-8026-2fff6a23fd7b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 3,  5,  5,  7, 17, 21, 32,  7,  0])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "87166434-340e-4276-a549-754ca1783014",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "         0., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "         0., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "         0., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "         0., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "         0., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,\n",
       "         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "         0., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "         0., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,\n",
       "         0., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "         0., 0., 0., 0., 0.]])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Représentation de chaque caractère par un vecteur one-hot\n",
    "# seul une composante est à 1.0, correspondant à l'indice du numéro du caractère\n",
    "import torch.nn.functional as F\n",
    "xenc = F.one_hot(xs, num_classes=nb_chars).float()\n",
    "xenc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d23338ef-0acc-49dd-9832-ac36b9972ede",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([9, 41])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# La première dimension est la dimension du tenseur exemple\n",
    "xenc.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "0d784a4c-ad47-4d8b-b207-d5ac42ddcfa1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x10f450830>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAhkAAACYCAYAAABJafvfAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAADZVJREFUeJzt3X1MleUfx/HvAQQzFTKMh0SUzJxmODUZK8sF01prZv1h5R9UDqfhZpmtbEusf3C1sR7mtK0V/zQxW+ZqvywfaRZkYs4enTgSGCLpxoOaqHD99r36QRzEn6LnEs513q/tFg7nBu+L7324P+e6rvu+A8YYIwAAACEWFeofCAAAQMgAAADO0JMBAACcIGQAAAAnCBkAAMAJQgYAAHCCkAEAAJyIkeuoo6ND6uvrZdiwYRIIBK7nfw0AAK6SXlKrtbVVUlNTJSoqamCGDA0YaWlp1/O/BAAAIVJbWyujRo0amCFDezDU0f1jZPjQK0tC88ZPdrxVAADg/7kg52WP/KfrOD4gQ0bnEIkGjOHDrixkxAQGOd4qAADwf/3vBiR9nepwVRM/165dK2PGjJHBgwdLVlaW7N2792p+DAAA8FifQ8bGjRtl+fLlUlhYKPv375fMzEyZM2eONDY2utlCAAAQGSGjuLhY8vPz5ZlnnpGJEyfK+vXrZciQIfLhhx9etG5bW5u0tLQELQAAIDL0KWScO3dOKisrJTc3998fEBVlH5eXl1+0flFRkcTHx3ctnFkCAEDk6FPIOHHihLS3t0tSUlLQ1/VxQ0PDReuvXLlSmpubuxY99QUAAEQGp2eXxMXF2QUAAESePvVkJCYmSnR0tBw/fjzo6/o4OTk51NsGAAAiJWTExsbKtGnTZMeOHUGXCtfH2dnZLrYPAABEynCJnr6al5cn06dPlxkzZsjbb78tp0+ftmebAAAAXHXImD9/vvz111+yatUqO9lzypQpsnXr1osmg17uUuED5UqeX9cf6NP6c1KnONsWAAAk0id+Ll261C4AAAAhvaw4AADA5RAyAACAE4QMAADgBCEDAAA4QcgAAABOEDIAAIAThAwAAOAEIQMAADhByAAAAE4QMgAAwMC5rLhPXN+LpK/3RlHcHwUA4AN6MgAAgBOEDAAA4AQhAwAAOEHIAAAAThAyAACAE4QMAADgBCEDAAA4QcgAAABOEDIAAIAThAwAAOAEIQMAADhByAAAAE4QMgAAgBOEDAAA4AQhAwAAOEHIAAAAThAyAACAE4QMAADgBCEDAAA4QcgAAABOEDIAAIAThAwAAOAEIQMAADgR4+bHotOc1CnOfxlf1x8YcNsEAAA9GQAAwAlCBgAAcIKQAQAACBkAACB80JMBAACcIGQAAAAnCBkAAMAJQgYAAHCCkAEAAJwgZAAAACcIGQAAwAlCBgAAcIIbpHmAG575j5vgAQhH9GQAAID+DxmrV6+WQCAQtEyYMMHNlgEAgMgaLpk0aZJs37793x8Qw4gLAAC4WJ8TgoaK5OTkvn4bAACIMH2ek3H48GFJTU2VjIwMWbBggdTU1Fxy3ba2NmlpaQlaAABAZOhTyMjKypKSkhLZunWrrFu3Tqqrq2XmzJnS2tra6/pFRUUSHx/ftaSlpYVquwEAwAAXMMaYq/3mpqYmSU9Pl+LiYlm4cGGvPRm6dNKeDA0as2SuxAQGXf1WAxGGU1gB9KcL5rzsli3S3Nwsw4cPv+Lvu6ZZmwkJCTJ+/Hipqqrq9fm4uDi7AACAyHNN18k4deqUHDlyRFJSUkK3RQAAIPJCxooVK6SsrEz+/PNP+f7772XevHkSHR0tTz75pLstBAAAYalPwyV1dXU2UJw8eVJGjhwp9957r1RUVNjPAQAArjpklJaW9mV1ACES7venYeIqEJm4dwkAAHCCkAEAAJwgZAAAACcIGQAAwAlCBgAAcIKQAQAAnCBkAAAAJwgZAADACUIGAABwgpABAACcIGQAAID+
v3cJAETivVcQetzPJjLQkwEAAJwgZAAAACcIGQAAwAlCBgAAcIKQAQAAnCBkAAAAJwgZAADACUIGAABwgpABAACcIGQAAAAnCBkAAMAJ7l3iAe4BACDccD+byEBPBgAAcIKQAQAAnCBkAAAAJwgZAADACUIGAABwgpABAACcIGQAAAAnCBkAAMAJQgYAAHCCkAEAAML/suLGGPvxgpwX+edThEBLa0ef1r9gzvN7BwBcMXvc7nYcv1IB09fvuAZ1dXWSlpZ2vf47AAAQQrW1tTJq1KiBGTI6Ojqkvr5ehg0bJoFAIOi5lpYWG0C0AcOHDxff0V7/UWP/UWP/UeN/aFRobW2V1NRUiYqKGpjDJbphl0tAGjAiIWR0or3+o8b+o8b+o8Yi8fHxff69MfETAAA4QcgAAAB+h4y4uDgpLCy0HyMB7fUfNfYfNfYfNb4213XiJwAAiBwDpicDAAD4hZABAACcIGQAAAAnCBkAAMAJQgYAAPA3ZKxdu1bGjBkjgwcPlqysLNm7d6/4avXq1faS6t2XCRMmiC++/fZbeeSRR+ylZ7Vtn3/+edDzejLTqlWrJCUlRW644QbJzc2Vw4cPi89tfvrppy+q+YMPPijhqqioSO6++257e4BbbrlFHn30UTl06FDQOmfPnpWCggK5+eabZejQofL444/L8ePHxdf2zpo166IaL168WMLVunXr5K677uq6ymV2drZ89dVXXtb3StrrW317WrNmjW3T888/H/Ia93vI2LhxoyxfvtxeI2P//v2SmZkpc+bMkcbGRvHVpEmT5NixY13Lnj17xBenT5+2NdTg2Js333xT3n33XVm/fr388MMPcuONN9p66w7ta5uVhoruNd+wYYOEq7KyMvvHp6KiQrZt2ybnz5+X2bNn299DpxdeeEG++OIL2bRpk11f71n02GOPia/tVfn5+UE11n09XOntH/TAU1lZKfv27ZMHHnhA5s6dK7/++qt39b2S9vpW3+5+/PFHef/9923I6i5kNTb9bMaMGaagoKDrcXt7u0lNTTVFRUXGR4WFhSYzM9NEAt29Nm/e3PW4o6PDJCcnm7feeqvra01NTSYuLs5s2LDB+NhmlZeXZ+bOnWt81djYaNtdVlbWVdNBgwaZTZs2da3z+++/23XKy8uNb+1V999/v1m2bJnx2U033WQ++OAD7+vbs70+17e1tdXcfvvtZtu2bUFtDGWN+7Un49y5czY5apd595uo6ePy8nLxlQ4PaNd6RkaGLFiwQGpqaiQSVFdXS0NDQ1C99YY7OkTmc73V7t27bVf7HXfcIUuWLJGTJ0+KL5qbm+3HESNG2I/6mtZ3+93rrEOCo0eP9qLOPdvb6eOPP5bExES58847ZeXKlXLmzBnxQXt7u5SWltqeGx1G8L2+Pdvrc30LCgrk4YcfDqqlCmWNr+tdWHs6ceKELWhSUlLQ1/XxH3/8IT7SA2pJSYk92GiX2+uvvy4zZ86UX375xY75+kwDhuqt3p3P+UiHSrSbcezYsXLkyBF59dVX5aGHHrIv1ujoaAlnHR0ddhz3nnvusX98ldYyNjZWEhISvKtzb+1VTz31lKSnp9s3DwcPHpSXX37Zztv47LPPJFz9/PPP9iCrQ5k6Jr9582aZOHGiHDhwwMv6Xqq9vta3tLTUTlHQ4ZKeQvka7teQEYn04NJJx8A0dOjO+8knn8jChQv7ddvgxhNPPNH1+eTJk23db7vtNtu7kZOTE/bvhDQg+zSv6Grau2jRoqAa68Rmra2GSq11ONI3QhootOfm008/lby8PDs276tLtVeDhm/1ra2tlWXLltk5RnrChUv9OlyiXU/6Tq7njFV9nJycLJFAk+L48eOlqqpKfNdZ00iut9JhMt33w73mS5culS+//FJ27dplJ8510lrqUGhTU5NXdb5Ue3ujbx5UONdY38mOGzdOpk2bZs+w0cnN77zzjrf1vVR7faxvZWWlPbli6tSpEhMTYxcNVDopXz/X
HotQ1Tiqv4uqBd2xY0dQd6Q+7j4W5rNTp07ZNKzJ2Hc6XKA7aPd6t7S02LNMIqXeqq6uzs7JCNea6/xWPeBqd/LOnTttXbvT1/SgQYOC6qxdyzr3KBzrfLn29kbfEatwrXFv9G9zW1ubd/W9XHt9rG9OTo4dHtJ2dC7Tp0+3cwQ7Pw9ZjU0/Ky0ttWcXlJSUmN9++80sWrTIJCQkmIaGBuOjF1980ezevdtUV1eb7777zuTm5prExEQ7Y92X2co//fSTXXT3Ki4utp8fPXrUPr9mzRpb3y1btpiDBw/asy7Gjh1r/v77b+Njm/W5FStW2BnZWvPt27ebqVOn2hndZ8+eNeFoyZIlJj4+3u7Hx44d61rOnDnTtc7ixYvN6NGjzc6dO82+fftMdna2XXxsb1VVlXnjjTdsO7XGum9nZGSY++67z4SrV155xZ49o+3R16k+DgQC5ptvvvGuvpdrr4/17U3PM2hCVeN+Dxnqvffes42JjY21p7RWVFQYX82fP9+kpKTYtt566632se7Evti1a5c90PZc9DTOztNYX3vtNZOUlGTDZU5Ojjl06JDxtc16IJo9e7YZOXKkPSUsPT3d5Ofnh3WI7q2tunz00Udd62hofO655+xpgEOGDDHz5s2zB2Yf21tTU2MPOCNGjLD79Lhx48xLL71kmpubTbh69tln7b6qf6d039XXaWfA8K2+l2uvj/W9kpARqhoH9J/Qd8YAAIBI1+9X/AQAAH4iZAAAACcIGQAAwAlCBgAAcIKQAQAAnCBkAAAAJwgZAADACUIGAABwgpABAACcIGQAAAAnCBkAAEBc+C8k+bsNu0IuyQAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "plt.imshow(xenc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a2730fe6-7758-48c3-809e-7f8dcd89443a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 9.1839e-01, -1.1502e+00, -8.3629e-01,  1.5226e+00,  8.4343e-01,\n",
       "          3.7099e-01,  1.8180e+00, -6.9873e-01,  1.1443e+00],\n",
       "        [-8.5412e-01,  9.1880e-01,  1.6299e+00, -9.2986e-02,  4.3231e-01,\n",
       "          3.4169e-01,  6.3554e-01, -2.3884e+00, -1.6909e-01],\n",
       "        [ 1.1750e+00,  9.7545e-01,  6.1806e-01,  2.2869e-01, -1.2293e+00,\n",
       "          8.1735e-01, -9.5131e-01,  1.2416e+00,  5.6462e-01],\n",
       "        [ 2.6404e-01, -1.0920e-01,  5.7622e-01,  1.4399e-01,  7.6392e-01,\n",
       "          3.0535e-01, -7.3563e-01, -9.6206e-01,  3.5338e-01],\n",
       "        [ 1.5070e+00,  1.4913e+00,  2.9527e-01, -4.1975e-01, -1.4888e+00,\n",
       "          6.9776e-01,  8.9753e-01,  1.0222e+00, -1.4788e+00],\n",
       "        [ 1.2613e-01,  8.5963e-01, -7.1895e-01, -2.6217e+00,  4.3825e-01,\n",
       "          2.3034e+00, -1.7862e+00,  3.5024e-01,  9.5713e-01],\n",
       "        [-3.4914e-01,  1.1381e+00, -8.4526e-01,  3.8407e-01, -2.9689e-02,\n",
       "         -8.9423e-01,  9.8505e-01,  9.0442e-01, -5.1849e-01],\n",
       "        [ 6.5824e-01, -7.6243e-01, -8.5083e-02, -1.5097e+00,  3.3365e-01,\n",
       "         -4.7317e-02, -5.2482e-01,  1.1375e+00, -3.8334e-01],\n",
       "        [-4.3876e-02,  9.8445e-01,  1.6056e-01, -4.3068e-02, -4.3082e-01,\n",
       "         -2.1318e-01, -4.0735e-01, -8.3677e-01,  3.9790e-01],\n",
       "        [-1.7768e+00, -9.3988e-01,  1.0934e+00,  8.8543e-01, -1.2180e+00,\n",
       "         -7.6091e-01,  5.6308e-01, -6.6491e-01, -1.1863e+00],\n",
       "        [-1.4602e+00,  4.4132e-01, -3.5278e-01, -1.4347e+00, -1.5626e+00,\n",
       "          6.8525e-01,  4.2568e-01,  8.3430e-01, -1.0295e+00],\n",
       "        [-1.6114e+00, -5.9811e-01, -6.2155e-01, -2.9091e-03,  5.5972e-01,\n",
       "         -7.8317e-01, -8.3771e-02,  7.5206e-01, -1.4175e+00],\n",
       "        [ 8.1634e-01,  4.9134e-01, -2.7166e-01, -6.0395e-01,  7.2573e-01,\n",
       "          7.1407e-01,  1.7733e-01,  1.7713e-01,  1.4325e-01],\n",
       "        [ 7.0480e-01, -1.6713e-01,  9.5935e-01, -7.8019e-01,  1.6251e+00,\n",
       "          1.2835e+00,  4.8042e-02,  5.8192e-01, -3.0040e+00],\n",
       "        [ 6.3264e-01,  2.4104e-01, -7.3922e-01,  4.7470e-01, -1.7250e+00,\n",
       "         -7.5146e-01, -9.4407e-01, -8.6993e-01, -2.6355e+00],\n",
       "        [-1.8073e-01, -1.2490e+00, -8.2327e-01,  1.4730e+00,  5.8836e-01,\n",
       "          4.1034e-03,  9.9122e-01,  2.5957e-01, -1.8716e-01],\n",
       "        [ 9.1802e-01, -3.6061e-01, -8.1788e-01,  1.5692e+00, -1.1461e+00,\n",
       "         -5.7984e-01,  6.7999e-01, -4.1150e-02,  1.0829e+00],\n",
       "        [ 1.1454e+00, -4.3676e-01,  2.2854e+00,  1.1893e+00,  4.0966e-01,\n",
       "         -1.1711e-01, -2.1936e-02, -8.4547e-01,  1.2236e-01],\n",
       "        [ 1.4835e+00, -1.6104e-01,  4.8580e-02, -2.6155e+00, -1.4138e-01,\n",
       "          1.0443e+00, -2.3830e-01, -1.4911e+00,  3.5307e-01],\n",
       "        [ 2.0946e+00,  1.9337e+00, -2.6824e-01,  2.3089e-01,  1.1514e-03,\n",
       "          2.1862e+00,  6.3311e-01, -6.1647e-01, -1.3749e+00],\n",
       "        [ 1.5384e+00, -7.6717e-01, -7.3752e-01,  1.2570e+00, -5.8681e-01,\n",
       "          1.3887e+00, -9.6056e-01,  4.8157e-01, -4.1506e-01],\n",
       "        [ 1.1072e-01,  1.1431e+00,  2.0399e+00, -5.5736e-01,  6.7684e-01,\n",
       "         -6.8097e-01, -8.7569e-01, -1.2483e+00,  7.7616e-01],\n",
       "        [-2.6649e-01,  7.9188e-01,  7.2701e-01,  1.7280e+00, -1.1796e+00,\n",
       "          5.2148e-01, -6.1184e-01,  3.1035e-01,  9.7009e-01],\n",
       "        [-8.3095e-01, -1.6064e+00,  2.3667e+00, -1.2204e+00,  4.1136e-01,\n",
       "         -1.4684e+00,  6.3564e-02, -1.5051e+00, -2.2001e-01],\n",
       "        [-5.2230e-01,  8.1375e-01,  6.1553e-01, -4.2599e-02,  1.7301e-01,\n",
       "         -1.2271e-01, -2.0114e+00, -7.8907e-01, -1.2734e+00],\n",
       "        [ 6.8031e-01, -3.0871e-01, -3.0772e-01,  6.3263e-01,  1.5590e+00,\n",
       "          2.7520e-01, -1.0685e+00,  2.7201e-01,  9.9093e-01],\n",
       "        [ 9.2734e-02, -1.3574e+00,  4.0598e-01,  4.9756e-01,  5.6375e-01,\n",
       "         -1.1143e+00,  5.2196e-01,  1.4329e-01, -3.9328e-01],\n",
       "        [ 2.9002e-01,  2.9804e-01,  2.2331e+00,  2.1968e+00, -5.0351e-01,\n",
       "          1.5336e-01,  4.6150e-02,  1.7698e+00, -2.8478e-01],\n",
       "        [ 1.2679e+00, -5.9689e-01, -1.3967e+00,  1.0058e+00,  2.8197e-01,\n",
       "          1.0780e+00, -5.4874e-01, -5.0334e-01, -2.8814e-01],\n",
       "        [ 2.1569e+00,  6.2626e-01,  1.9641e-01, -1.5439e+00,  5.7199e-01,\n",
       "         -9.1998e-01, -1.1759e+00, -7.0328e-01,  3.1153e-01],\n",
       "        [-2.3782e-01,  1.5635e+00, -9.9518e-01,  3.7845e-01, -1.5209e+00,\n",
       "         -7.1817e-01,  3.5974e-01,  1.7197e-01,  3.6362e-01],\n",
       "        [ 2.2835e-01, -1.5459e+00, -9.5902e-01, -1.1907e+00, -3.7331e-01,\n",
       "          5.8753e-01, -1.4671e+00, -3.0594e-01,  8.4065e-01],\n",
       "        [ 1.1026e+00, -8.7721e-01,  1.5426e+00,  4.4972e-01,  3.1010e-01,\n",
       "          9.7017e-01,  7.8906e-01,  1.0092e+00,  2.0941e+00],\n",
       "        [-1.7965e+00,  1.8701e-01,  4.3996e-01,  4.6749e-01,  4.0462e-01,\n",
       "          9.2346e-02, -1.9932e+00,  7.7445e-01,  7.0330e-01],\n",
       "        [ 1.0255e+00,  1.0959e+00,  2.6846e-01, -1.3541e+00, -6.1543e-02,\n",
       "         -3.2624e-01, -7.2100e-01, -4.0800e-01,  9.1167e-01],\n",
       "        [ 1.0763e+00, -1.4209e+00, -5.9040e-01, -1.2943e+00,  2.4534e-01,\n",
       "         -9.7199e-01, -1.7028e+00, -1.6392e+00, -1.0653e+00],\n",
       "        [ 2.0118e-01, -2.0262e+00,  1.3364e+00, -1.7834e+00,  6.5608e-01,\n",
       "         -1.3171e+00,  8.7381e-01, -2.6745e-01,  1.3398e-01],\n",
       "        [ 8.5316e-01,  8.4795e-01, -1.1855e-01, -3.7457e-02,  2.4010e-01,\n",
       "          2.7003e-01,  1.1791e+00,  1.0496e+00,  1.6055e+00],\n",
       "        [-1.3566e+00, -8.5464e-01,  8.9450e-01,  1.3328e+00,  2.9470e-01,\n",
       "         -6.0269e-01, -9.7720e-01,  4.0491e-01,  1.5337e+00],\n",
       "        [ 1.5095e+00,  2.3863e-01,  1.1418e+00, -7.8052e-01, -2.4659e-01,\n",
       "         -8.4775e-01,  3.0722e-01,  6.8697e-01,  8.0159e-02],\n",
       "        [-1.3932e+00, -6.0093e-01, -8.3311e-01,  2.6036e-01, -6.4276e-01,\n",
       "          3.4897e-01, -1.7955e+00, -7.9129e-01,  1.5356e-01]])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pour notre réseau, on va utiliser une matrice W des valeurs normales aléatoires comme\n",
    "# point de départ\n",
    "W = torch.randn((nb_chars, tensor_dims))  # Quand on aura tous les mots, on utilisera nb_chars x nb_chars\n",
    "W"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "955624a3-b510-4416-9af0-c59984ae6888",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.9184, -1.1502, -0.8363,  1.5226,  0.8434,  0.3710,  1.8180, -0.6987,\n",
       "          1.1443],\n",
       "        [ 0.2640, -0.1092,  0.5762,  0.1440,  0.7639,  0.3053, -0.7356, -0.9621,\n",
       "          0.3534],\n",
       "        [ 0.1261,  0.8596, -0.7190, -2.6217,  0.4382,  2.3034, -1.7862,  0.3502,\n",
       "          0.9571],\n",
       "        [ 0.1261,  0.8596, -0.7190, -2.6217,  0.4382,  2.3034, -1.7862,  0.3502,\n",
       "          0.9571],\n",
       "        [ 0.6582, -0.7624, -0.0851, -1.5097,  0.3336, -0.0473, -0.5248,  1.1375,\n",
       "         -0.3833],\n",
       "        [ 1.1454, -0.4368,  2.2854,  1.1893,  0.4097, -0.1171, -0.0219, -0.8455,\n",
       "          0.1224],\n",
       "        [ 0.1107,  1.1431,  2.0399, -0.5574,  0.6768, -0.6810, -0.8757, -1.2483,\n",
       "          0.7762],\n",
       "        [ 1.1026, -0.8772,  1.5426,  0.4497,  0.3101,  0.9702,  0.7891,  1.0092,\n",
       "          2.0941],\n",
       "        [ 0.6582, -0.7624, -0.0851, -1.5097,  0.3336, -0.0473, -0.5248,  1.1375,\n",
       "         -0.3833]])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# En multipliant ces \"poids\" par nos vecteurs one-hot organisés en matrice...\n",
    "# On obtient des valeurs que l'on va \"interpréter\" comme des logs (log-counts).\n",
    "# En utilisant l'exponentielle de ces valeurs, on va retrouver quelque chose\n",
    "# d'équivalent à la matrice N que nous avions définie précédemment dans la méthode\n",
    "# par comptage.\n",
    "xenc @ W"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "ffed90e6-fb4e-4343-9634-83f0ac5a5d1a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.1170, 0.0148, 0.0202, 0.2141, 0.1086, 0.0677, 0.2877, 0.0232, 0.1467],\n",
       "        [0.1192, 0.0821, 0.1629, 0.1057, 0.1965, 0.1243, 0.0439, 0.0350, 0.1304],\n",
       "        [0.0573, 0.1193, 0.0246, 0.0037, 0.0783, 0.5053, 0.0085, 0.0717, 0.1315],\n",
       "        [0.0573, 0.1193, 0.0246, 0.0037, 0.0783, 0.5053, 0.0085, 0.0717, 0.1315],\n",
       "        [0.1879, 0.0454, 0.0893, 0.0215, 0.1358, 0.0928, 0.0576, 0.3034, 0.0663],\n",
       "        [0.1439, 0.0296, 0.4501, 0.1504, 0.0690, 0.0407, 0.0448, 0.0197, 0.0518],\n",
       "        [0.0625, 0.1756, 0.4304, 0.0321, 0.1101, 0.0283, 0.0233, 0.0161, 0.1216],\n",
       "        [0.1126, 0.0156, 0.1749, 0.0586, 0.0510, 0.0987, 0.0823, 0.1026, 0.3036],\n",
       "        [0.1879, 0.0454, 0.0893, 0.0215, 0.1358, 0.0928, 0.0576, 0.3034, 0.0663]])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "logits = xenc @ W # log-counts \n",
    "counts = logits.exp() # statut équivalent à N\n",
    "probs = counts / counts.sum(1, keepdims=True) # distribution de probabilités (equ. à p)\n",
    "probs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c13f459-6c52-4522-80ab-972364d20d94",
   "metadata": {},
   "source": [
    "#### Réseau de neurones sur cet exemple"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "6e059892-1a6e-46f1-b2ce-9f6806eac11e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialisation de \"nb_chars\" poids de neurones\n",
    "g = torch.Generator().manual_seed(2147483647)\n",
    "W = torch.randn((nb_chars, nb_chars), generator=g, requires_grad=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "a7e0583f-148e-431e-bde6-466db2e69586",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Réseau à une couche (probs)\n",
    "xenc = F.one_hot(xs, num_classes=nb_chars).float() # input to the network: one-hot encoding\n",
    "logits = xenc @ W  # predict log-counts\n",
    "counts = logits.exp() # counts, equivalent to N\n",
    "probs = counts / counts.sum(1, keepdims=True)  # probabilities for next character\n",
    "# btw: the last 2 lines here are together called a 'softmax'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "adc929b6-9648-4a40-a344-19666747e480",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------\n",
      "bigram example 1: .a (indexes 0,3)\n",
      "input to the neural net: 0\n",
      "output probabilities from the neural net: tensor([0.0495, 0.0081, 0.0100, 0.0034, 0.0137, 0.0100, 0.0022, 0.0189, 0.0112,\n",
      "        0.0255, 0.0064, 0.0227, 0.0074, 0.0067, 0.0407, 0.1939, 0.0492, 0.0020,\n",
      "        0.0203, 0.0045, 0.0276, 0.0089, 0.0023, 0.0162, 0.0096, 0.1253, 0.1189,\n",
      "        0.0053, 0.0030, 0.0140, 0.0035, 0.0214, 0.0109, 0.0382, 0.0046, 0.0044,\n",
      "        0.0017, 0.0361, 0.0030, 0.0348, 0.0039], grad_fn=<SelectBackward0>)\n",
      "label (actual next character): 3\n",
      "probability assigned by the net to the the correct character: 0.003431369084864855\n",
      "log likelihood: -5.674796104431152\n",
      "negative log likelihood: 5.674796104431152\n",
      "--------\n",
      "bigram example 2: ac (indexes 3,5)\n",
      "input to the neural net: 3\n",
      "output probabilities from the neural net: tensor([0.0017, 0.0064, 0.0258, 0.0032, 0.0085, 0.0247, 0.0371, 0.0103, 0.0104,\n",
      "        0.0024, 0.0027, 0.0207, 0.0226, 0.0620, 0.0193, 0.0406, 0.1549, 0.0225,\n",
      "        0.0073, 0.0261, 0.0076, 0.0234, 0.0546, 0.0178, 0.0089, 0.0141, 0.0084,\n",
      "        0.0245, 0.0226, 0.0035, 0.0713, 0.0167, 0.0378, 0.0234, 0.0390, 0.0021,\n",
      "        0.0092, 0.0017, 0.0368, 0.0490, 0.0181], grad_fn=<SelectBackward0>)\n",
      "label (actual next character): 5\n",
      "probability assigned by the net to the the correct character: 0.024747123941779137\n",
      "log likelihood: -3.6990458965301514\n",
      "negative log likelihood: 3.6990458965301514\n",
      "--------\n",
      "bigram example 3: cc (indexes 5,5)\n",
      "input to the neural net: 5\n",
      "output probabilities from the neural net: tensor([0.0355, 0.0383, 0.0081, 0.0299, 0.0039, 0.0132, 0.0410, 0.0156, 0.0065,\n",
      "        0.0104, 0.0081, 0.0344, 0.0140, 0.0583, 0.0237, 0.0640, 0.0057, 0.0073,\n",
      "        0.0205, 0.0083, 0.0115, 0.0048, 0.0567, 0.0031, 0.0105, 0.0059, 0.0147,\n",
      "        0.0128, 0.0543, 0.0087, 0.0325, 0.0304, 0.1164, 0.0116, 0.0171, 0.0283,\n",
      "        0.0362, 0.0034, 0.0144, 0.0673, 0.0128], grad_fn=<SelectBackward0>)\n",
      "label (actual next character): 5\n",
      "probability assigned by the net to the the correct character: 0.013233082368969917\n",
      "log likelihood: -4.325035572052002\n",
      "negative log likelihood: 4.325035572052002\n",
      "--------\n",
      "bigram example 4: ce (indexes 5,7)\n",
      "input to the neural net: 5\n",
      "output probabilities from the neural net: tensor([0.0355, 0.0383, 0.0081, 0.0299, 0.0039, 0.0132, 0.0410, 0.0156, 0.0065,\n",
      "        0.0104, 0.0081, 0.0344, 0.0140, 0.0583, 0.0237, 0.0640, 0.0057, 0.0073,\n",
      "        0.0205, 0.0083, 0.0115, 0.0048, 0.0567, 0.0031, 0.0105, 0.0059, 0.0147,\n",
      "        0.0128, 0.0543, 0.0087, 0.0325, 0.0304, 0.1164, 0.0116, 0.0171, 0.0283,\n",
      "        0.0362, 0.0034, 0.0144, 0.0673, 0.0128], grad_fn=<SelectBackward0>)\n",
      "label (actual next character): 7\n",
      "probability assigned by the net to the the correct character: 0.01563512347638607\n",
      "log likelihood: -4.158235549926758\n",
      "negative log likelihood: 4.158235549926758\n",
      "--------\n",
      "bigram example 5: ep (indexes 7,17)\n",
      "input to the neural net: 7\n",
      "output probabilities from the neural net: tensor([0.0579, 0.0032, 0.0506, 0.0075, 0.0371, 0.0182, 0.0344, 0.0079, 0.0179,\n",
      "        0.0157, 0.0043, 0.0297, 0.0035, 0.0061, 0.0524, 0.0082, 0.0414, 0.0450,\n",
      "        0.0362, 0.0097, 0.0044, 0.0414, 0.0829, 0.0135, 0.0096, 0.0334, 0.0228,\n",
      "        0.0134, 0.0243, 0.0257, 0.0424, 0.0110, 0.0411, 0.0228, 0.0270, 0.0017,\n",
      "        0.0121, 0.0113, 0.0182, 0.0468, 0.0071], grad_fn=<SelectBackward0>)\n",
      "label (actual next character): 17\n",
      "probability assigned by the net to the the correct character: 0.04502563923597336\n",
      "log likelihood: -3.1005232334136963\n",
      "negative log likelihood: 3.1005232334136963\n",
      "=========\n",
      "average negative log likelihood, i.e. loss = 4.191527366638184\n"
     ]
    }
   ],
   "source": [
    "nlls = torch.zeros(5)\n",
    "for i in range(5):\n",
    "    # i-th bigram:\n",
    "    x = xs[i].item() # input character index\n",
    "    y = ys[i].item() # label character index\n",
    "    print('--------')\n",
    "    print(f'bigram example {i+1}: {itoc[x]}{itoc[y]} (indexes {x},{y})')\n",
    "    print('input to the neural net:', x)\n",
    "    print('output probabilities from the neural net:', probs[i])\n",
    "    print('label (actual next character):', y)\n",
    "    p = probs[i, y]\n",
    "    print('probability assigned by the net to the the correct character:', p.item())\n",
    "    logp = torch.log(p)\n",
    "    print('log likelihood:', logp.item())\n",
    "    nll = -logp\n",
    "    print('negative log likelihood:', nll.item())\n",
    "    nlls[i] = nll\n",
    "\n",
    "print('=========')\n",
    "print('average negative log likelihood, i.e. loss =', nlls.mean().item())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "efb1e71a-28f3-4302-b6ba-08e060ea8aad",
   "metadata": {},
   "source": [
    "#### Optimization sur un mot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "5eedede8-d473-4389-b7da-ab41cf321a2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# forward pass\n",
    "xenc = F.one_hot(xs, num_classes=nb_chars).float() # input to the network: one-hot encoding\n",
    "logits = xenc @ W # predict log-counts\n",
    "counts = logits.exp() # counts, equivalent to N\n",
    "probs = counts / counts.sum(1, keepdims=True) # probabilities for next character\n",
    "loss = -probs[torch.arange(tensor_dims), ys].log().mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "d066be61-b7d3-48cf-8efb-223ffb6c291c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3.725316047668457\n"
     ]
    }
   ],
   "source": [
    "print(loss.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "b29db5f6-b20c-49c7-a7a9-02b8b7408547",
   "metadata": {},
   "outputs": [],
   "source": [
    "# backward pass\n",
    "W.grad = None # set to zero the gradient\n",
    "loss.backward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "b5dd596b-f190-4367-bb07-7d9dfbfb24fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "W.data += -0.1 * W.grad\n",
    "# ^^ loop above from forward pass and see loss decreasing"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de6830ad-e2a1-4253-9e28-f787915fa665",
   "metadata": {},
   "source": [
    "### Synthèse: apprentissage complet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "e33ec731-508e-4a18-929f-116ff1727592",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NB exemples: 67652\n"
     ]
    }
   ],
   "source": [
    "#\n",
    "# Word generator: bigram language model implemented as a neural network\n",
    "#\n",
    "import torch\n",
    "import torch.nn.functional as F  # one_hot, used by the training and sampling cells below\n",
    "\n",
    "# Load the data; the with-block guarantees the file is closed\n",
    "# NOTE(review): no encoding is given, so the platform default is used - pass encoding='utf-8' if the file is UTF-8\n",
    "EOS = '.'  # word-boundary token (start and end of every word)\n",
    "with open('civil_mots.txt', 'r') as word_file:\n",
    "    words = word_file.read().splitlines()\n",
    "chars = sorted(set(''.join(words)))\n",
    "nb_chars = len(chars) + 1  # +1 for the EOS token\n",
    "\n",
    "# Character <-> integer lookup tables (EOS gets index 0)\n",
    "ctoi = {c: i + 1 for i, c in enumerate(chars)}\n",
    "ctoi[EOS] = 0\n",
    "itoc = {i: c for c, i in ctoi.items()}\n",
    "\n",
    "# Training set: one (current character, next character) pair per bigram of every word\n",
    "xs, ys = [], []\n",
    "for w in words:\n",
    "    chs = [EOS] + list(w) + [EOS]\n",
    "    for ch1, ch2 in zip(chs, chs[1:]):\n",
    "        xs.append(ctoi[ch1])\n",
    "        ys.append(ctoi[ch2])\n",
    "xs = torch.tensor(xs)\n",
    "ys = torch.tensor(ys)\n",
    "num = xs.nelement()\n",
    "print('NB exemples:', num)\n",
    "\n",
    "# Network initialisation: a single layer of neurons without bias\n",
    "g = torch.Generator().manual_seed(2147483647)\n",
    "W = torch.randn((nb_chars, nb_chars), generator=g, requires_grad=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "65623cfb-bd14-40e9-a02f-dd6539b68966",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  0 4.2688\n",
      " 50 2.4562\n",
      "100 2.4146\n",
      "150 2.4014\n",
      "200 2.3949\n",
      "250 2.3911\n",
      "300 2.3887\n",
      "350 2.3870\n",
      "400 2.3858\n",
      "450 2.3849\n",
      "500 2.3842\n",
      "550 2.3836\n",
      "599 2.3832\n"
     ]
    }
   ],
   "source": [
    "# Training: full-batch gradient descent\n",
    "NB_STEPS = 600       # number of gradient-descent steps\n",
    "LEARNING_RATE = 50   # gradient-descent step size\n",
    "REG_STRENGTH = 0.01  # weight of the L2 penalty on W (smooths the model)\n",
    "LOG_EVERY = 50       # print the loss every LOG_EVERY steps and at the last step\n",
    "\n",
    "# The one-hot inputs do not depend on W: encode them once instead of at every step\n",
    "xenc = F.one_hot(xs, num_classes=nb_chars).float()\n",
    "\n",
    "for k in range(NB_STEPS):\n",
    "    # Forward pass\n",
    "    logits = xenc @ W                              # predicted log-counts (logits)\n",
    "    counts = logits.exp()                          # counts, equivalent to N\n",
    "    probs = counts / counts.sum(1, keepdims=True)  # probabilities for the next character\n",
    "    # Mean negative log-likelihood of the observed next characters, plus the L2 penalty\n",
    "    loss = -probs[torch.arange(num), ys].log().mean() + REG_STRENGTH * (W**2).mean()\n",
    "    if k % LOG_EVERY == 0 or k == NB_STEPS - 1:\n",
    "        print(f'{k:3d} {loss.item():.4f}')\n",
    "\n",
    "    # Backward pass\n",
    "    W.grad = None  # reset the gradient before backpropagating\n",
    "    loss.backward()\n",
    "\n",
    "    # Update\n",
    "    W.data += -LEARNING_RATE * W.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "3a5852ba-40b3-4a77-864c-f3170f0f0729",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "éssanée.\n",
      "mexcororér.\n",
      "monts.\n",
      "ex.\n",
      "moût.\n"
     ]
    }
   ],
   "source": [
    "# Finally, sample words from the trained 'neural net' model (fixed seed for reproducibility)\n",
    "g = torch.Generator().manual_seed(2147483647)\n",
    "\n",
    "for _ in range(5):\n",
    "    sampled = []\n",
    "    ix = 0  # every word starts from the '.' boundary token\n",
    "    while True:\n",
    "        # Counting approach: p = P[ix]\n",
    "        # Neural approach: the row of probabilities is produced by the network\n",
    "        xenc = F.one_hot(torch.tensor([ix]), num_classes=nb_chars).float()\n",
    "        logits = xenc @ W                          # predicted log-counts\n",
    "        counts = logits.exp()                      # counts, equivalent to N\n",
    "        p = counts / counts.sum(1, keepdims=True)  # next-character probabilities\n",
    "\n",
    "        ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()\n",
    "        sampled.append(itoc[ix])\n",
    "        if ix == 0:  # '.' was drawn: the word is complete\n",
    "            break\n",
    "    print(''.join(sampled))"
   ]
  },
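  {
   "cell_type": "markdown",
   "id": "f02816c8-50a5-40f6-8a03-f0bff2f3b9dc",
   "metadata": {},
   "source": [
    "On obtient les mêmes mots que ceux générés par comptage : nous avons donc bâti une méthode neuronale équivalente à la méthode par comptage. Cela s'explique : comme l'entrée est un vecteur one-hot, `xenc @ W` sélectionne simplement la ligne `ix` de `W`, et l'exponentielle de cette ligne joue le rôle de la ligne correspondante de la matrice de comptes `N`."
   ]
  }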
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.14.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
