{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "56710956-7461-4d84-86b1-bfd3ae4924e2",
   "metadata": {},
   "source": [
    "# 9. Wavenet (CNN)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f238647-e5ac-440c-a180-95cc9a57898a",
   "metadata": {},
   "source": [
    "- Bengio et al. 2003 MLP LM <https://www.jmlr.org/papers/volume3/bengio03a/bengio03a.pdf>\n",
    "- WaveNet 2016 from DeepMind <https://arxiv.org/abs/1609.03499>\n",
    "- <https://deepmind.google/discover/blog/wavenet-a-generative-model-for-raw-audio/>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "02800f69",
   "metadata": {},
   "source": [
    "L'objectif est d'améliorer le modèle de langue au niveau des caractères en s'éloignant d'un simple perceptron multicouche (MLP) qui écrase immédiatement tous les caractères d'entrée en une seule couche cachée. Nous allons mettre en oeuvre une architecture hiérarchique inspirée de l'article **WaveNet** (2016). Le but est de fusionner progressivement les informations (deux caractères à la fois) afin de traiter plus efficacement les contextes plus longs."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "202308d9-8529-4826-aeeb-469a34ffb16c",
   "metadata": {},
   "source": [
    "# Reprise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "caddd508",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "torch.set_default_device('mps')\n",
    "\n",
    "seed = 2147483647\n",
    "\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "class Words(object):\n",
    "    \"\"\"Représente une liste de mots, ainsi que la liste ordonnée des caractères les composants.\"\"\"\n",
    "\n",
    "    EOS = '.'\n",
    "\n",
    "    def __init__(self, filename):\n",
    "        self.filename = filename\n",
    "        self.words = open(self.filename, 'r').read().splitlines()\n",
    "        self.nb_words = len(self.words)\n",
    "        self.chars = sorted(list(set(''.join(self.words))))\n",
    "        self.nb_chars = len(self.chars) + 1  # On ajoute 1 pour EOS\n",
    "        self.ctoi = {c:i+1 for i,c in enumerate(self.chars)}\n",
    "        self.ctoi[self.EOS] = 0\n",
    "        self.itoc = {i:s for s,i in self.ctoi.items()}\n",
    "\n",
    "    def __repr__(self):\n",
    "        l = []\n",
    "        l.append(\"<Words\")\n",
    "        l.append(f'  filename=\"{self.filename}\"')\n",
    "        l.append(f'  nb_words=\"{self.nb_words}\"')\n",
    "        l.append(f'  nb_chars=\"{self.nb_chars}\"/>')\n",
    "        return '\\n'.join(l)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "ed358a8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Datasets:\n",
    "    \"\"\"Construits les jeu de données d'entraînement, de test et de validation.\n",
    "    \n",
    "    Prend en paramètres une liste de mots et la taille du contexte pour la prédiction.\n",
    "    \"\"\"\n",
    "\n",
    "    def _build_dataset(self, lwords:list, context_size:int):\n",
    "        X, Y = [], []\n",
    "        for w in lwords:\n",
    "            context = [0] * context_size\n",
    "            for ch in w + self.words.EOS:\n",
    "                ix = self.words.ctoi[ch]\n",
    "                X.append(context)\n",
    "                Y.append(ix)\n",
    "                context = context[1:] + [ix] # crop and append\n",
    "        X = torch.tensor(X)\n",
    "        Y = torch.tensor(Y)\n",
    "        return X, Y\n",
    "    \n",
    "    def __init__(self, words:Words, context_size:int, seed:int=42):\n",
    "        # 80%, 10%, 10%\n",
    "        self.shuffled_words = words.words.copy()\n",
    "        random.shuffle(self.shuffled_words)\n",
    "        self.n1 = int(0.8*len(self.shuffled_words))\n",
    "        self.n2 = int(0.9*len(self.shuffled_words))\n",
    "        self.words = words\n",
    "        self.Xtr, self.Ytr = self._build_dataset(self.shuffled_words[:self.n1], context_size)\n",
    "        self.Xdev, self.Ydev = self._build_dataset(self.shuffled_words[self.n1:self.n2], context_size)\n",
    "        self.Xte, self.Yte = self._build_dataset(self.shuffled_words[self.n2:], context_size)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "7b458ba3-1ae8-4b9b-99aa-68daa355f14b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Embedding:\n",
    "  \n",
    "    def __init__(self, num_embeddings, embedding_dim):\n",
    "        self.weight = torch.randn((num_embeddings, embedding_dim))\n",
    "    \n",
    "    def __call__(self, IX):\n",
    "        self.out = self.weight[IX]\n",
    "        return self.out\n",
    "  \n",
    "    def parameters(self):\n",
    "        return [self.weight]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "129c38bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "class FlattenConsecutive:\n",
    "  \n",
    "    def __init__(self, n): # n = 2\n",
    "        self.n = n\n",
    "    \n",
    "    def __call__(self, x):\n",
    "        B, T, C = x.shape  # 4, 8, 10\n",
    "        x = x.view(B, T//self.n, C*self.n)\n",
    "        if x.shape[1] == 1:\n",
    "            x = x.squeeze(1)\n",
    "        self.out = x\n",
    "        return self.out\n",
    "  \n",
    "    def parameters(self):\n",
    "        return []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "cfb9e6dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Linear:\n",
    "    \"\"\"Linear layer.\n",
    "    \n",
    "    Similar to <https://pytorch.org/docs/stable/generated/torch.nn.Linear.html>\n",
    "    \"\"\"\n",
    "    def __init__(self, fan_in, fan_out, bias=True):\n",
    "        self.weight = torch.randn((fan_in, fan_out)) / fan_in**0.5\n",
    "        self.bias = torch.zeros(fan_out) if bias else None\n",
    "  \n",
    "    def __call__(self, x):\n",
    "        self.out = x @ self.weight\n",
    "        if self.bias is not None:\n",
    "            self.out += self.bias\n",
    "        return self.out\n",
    "  \n",
    "    def parameters(self):\n",
    "        return [self.weight] + ([] if self.bias is None else [self.bias])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "c6ff29bc-59c6-4416-a3c7-ade2097b03b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "class BatchNorm1d:\n",
    "    \"\"\"Batch normalization layer.\n",
    "    \n",
    "    Similar to <https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html#torch.nn.BatchNorm1d>\n",
    "    \"\"\"\n",
    "  \n",
    "    def __init__(self, dim, eps=1e-5, momentum=0.1):\n",
    "        self.eps = eps\n",
    "        self.momentum = momentum\n",
    "        self.training = True\n",
    "        # parameters (trained with backprop)\n",
    "        self.gamma = torch.ones(dim)\n",
    "        self.beta = torch.zeros(dim)\n",
    "        # buffers (trained with a running 'momentum update')\n",
    "        self.running_mean = torch.zeros(dim)\n",
    "        self.running_var = torch.ones(dim)\n",
    "\n",
    "    def __call__(self, x):\n",
    "        # calculate the forward pass\n",
    "        if self.training:\n",
    "            if x.ndim == 2:\n",
    "                dim = 0\n",
    "            elif x.ndim == 3:\n",
    "                dim = (0, 1)\n",
    "            xmean = x.mean(dim, keepdim=True) # batch mean\n",
    "            xvar = x.var(dim, keepdim=True) # batch variance\n",
    "        else:\n",
    "            xmean = self.running_mean\n",
    "            xvar = self.running_var\n",
    "        xhat = (x - xmean) / torch.sqrt(xvar + self.eps) # normalize to unit variance\n",
    "        self.out = self.gamma * xhat + self.beta\n",
    "        # update the buffers\n",
    "        if self.training:\n",
    "            with torch.no_grad():\n",
    "                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * xmean\n",
    "                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * xvar\n",
    "        return self.out\n",
    "\n",
    "    def parameters(self):\n",
    "        return [self.gamma, self.beta]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "924700fb-fd48-4dff-9842-119eaf83a6b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Tanh:\n",
    "    \n",
    "    def __call__(self, x):\n",
    "        self.out = torch.tanh(x)\n",
    "        return self.out\n",
    "\n",
    "    def parameters(self):\n",
    "        return []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "d9ce689e",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Sequential:\n",
    "    \"\"\"<https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html>\"\"\"\n",
    "    def __init__(self, layers):\n",
    "        self.layers = layers\n",
    "  \n",
    "    def __call__(self, x):\n",
    "        for layer in self.layers:\n",
    "            x = layer(x)\n",
    "        self.out = x\n",
    "        return self.out\n",
    "  \n",
    "    def parameters(self):\n",
    "        # get parameters of all layers and stretch them out into one list\n",
    "        return [p for layer in self.layers for p in layer.parameters()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "3fa9e7df",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Wavenet:\n",
    "    \n",
    "    def __init__(self, e_dims, n_hidden, context_size, nb_chars):\n",
    "        self.nb_chars = nb_chars\n",
    "        self.e_dims = e_dims\n",
    "        self.n_hidden = n_hidden\n",
    "        self.context_size = context_size\n",
    "        self.steps = 0\n",
    "        self.create_model()\n",
    "\n",
    "    def _create_model(self):\n",
    "        self.model = Sequential([\n",
    "            Embedding(self.nb_chars, self.e_dims),\n",
    "            FlattenConsecutive(2), Linear(self.e_dims   * 2, self.n_hidden, bias=False), BatchNorm1d(self.n_hidden), Tanh(),\n",
    "            FlattenConsecutive(2), Linear(self.n_hidden * 2, self.n_hidden, bias=False), BatchNorm1d(self.n_hidden), Tanh(),\n",
    "            FlattenConsecutive(2), Linear(self.n_hidden * 2, self.n_hidden, bias=False), BatchNorm1d(self.n_hidden), Tanh(),\n",
    "            Linear(self.n_hidden, self.nb_chars),\n",
    "        ])\n",
    "        with torch.no_grad():\n",
    "            self.model.layers[-1].weight *= 0.1\n",
    "\n",
    "    def create_model(self):\n",
    "        self._create_model()\n",
    "        self.parameters = self.model.parameters()\n",
    "        for p in self.parameters:\n",
    "            p.requires_grad = True\n",
    "        self.nb_parameters = sum(p.nelement() for p in self.parameters)\n",
    "\n",
    "    def forward(self, X, Y):\n",
    "        logits = self.model(X)\n",
    "        self.loss = F.cross_entropy(logits, Y) # loss function\n",
    "\n",
    "    def backward(self):\n",
    "        for p in self.parameters:\n",
    "            p.grad = None\n",
    "        self.loss.backward()\n",
    "\n",
    "    def train(self, datasets: Datasets, max_steps, mini_batch_size):\n",
    "        #lossi = []\n",
    "        for i in range(max_steps):\n",
    "            # minibatch construct\n",
    "            ix = torch.randint(0, datasets.Xtr.shape[0], (mini_batch_size,))\n",
    "            Xb, Yb = datasets.Xtr[ix], datasets.Ytr[ix]\n",
    "            \n",
    "            # forward pass\n",
    "            self.forward(Xb, Yb)\n",
    "        \n",
    "            # backward pass\n",
    "            self.backward()\n",
    "\n",
    "            # update\n",
    "            lr = 0.2 if i < 100000 else 0.02 # step learning rate decay\n",
    "            self.update_grad(lr)\n",
    "        \n",
    "            # track stats\n",
    "            if i % 10000 == 0:\n",
    "                print(f\"{i:7d}/{max_steps:7d}\")\n",
    "            #lossi.append(self.loss.log10().item())\n",
    "        self.steps += max_steps\n",
    "        #return lossi\n",
    "\n",
    "    def update_grad(self, lr):\n",
    "        for p in self.parameters:\n",
    "            p.data += -lr * p.grad\n",
    "\n",
    "    @torch.no_grad() # this decorator disables gradient tracking\n",
    "    def compute_loss(self, X, Y):\n",
    "        logits = self.model(X)\n",
    "        loss = F.cross_entropy(logits, Y)\n",
    "        return loss\n",
    "\n",
    "    @torch.no_grad() # this decorator disables gradient tracking\n",
    "    def training_loss(self, datasets:Datasets):\n",
    "        loss = self.compute_loss(datasets.Xtr, datasets.Ytr)\n",
    "        return loss.item()\n",
    "\n",
    "    @torch.no_grad() # this decorator disables gradient tracking\n",
    "    def test_loss(self, datasets:Datasets):\n",
    "        loss = self.compute_loss(datasets.Xte, datasets.Yte)\n",
    "        return loss.item()\n",
    "\n",
    "    @torch.no_grad() # this decorator disables gradient tracking\n",
    "    def dev_loss(self, datasets:Datasets):\n",
    "        loss = self.compute_loss(datasets.Xdev, datasets.Xdev)\n",
    "        return loss.item()\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def generate_word(self, itoc, g):\n",
    "        for layer in self.model.layers:\n",
    "            layer.training = False\n",
    "        out = []\n",
    "        context = [0] * self.context_size\n",
    "        while True:\n",
    "            logits = self.model(torch.tensor([context]))\n",
    "            probs = F.softmax(logits, dim=1)\n",
    "            # Sample from the probability distribution\n",
    "            ix = torch.multinomial(probs, num_samples=1, generator=g).item()\n",
    "            # Shift the context window\n",
    "            context = context[1:] + [ix]\n",
    "            # Store the generated character\n",
    "            if ix != 0:\n",
    "                out.append(ix)\n",
    "            else:\n",
    "                # Stop when encounting '.'\n",
    "                break\n",
    "        return ''.join(itoc[i] for i in out)\n",
    "\n",
    "    def __repr__(self):\n",
    "        l = []\n",
    "        l.append(\"<Wavenet\")\n",
    "        l.append(f'  nb_chars=\"{self.nb_chars}\"')\n",
    "        l.append(f'  e_dims=\"{self.e_dims}\"')\n",
    "        l.append(f'  n_hidden=\"{self.n_hidden}\"')\n",
    "        l.append(f'  context_size=\"{self.context_size}\"')\n",
    "        l.append(f'  nb_parameters=\"{self.nb_parameters}\"/>')\n",
    "        return '\\n'.join(l)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "300c4fc0",
   "metadata": {},
   "source": [
    "### Entraînement"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "d22250ca",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<Words\n",
      "  filename=\"civil_mots.txt\"\n",
      "  nb_words=\"7223\"\n",
      "  nb_chars=\"41\"/>\n"
     ]
    }
   ],
   "source": [
    "words = Words('civil_mots.txt')\n",
    "print(words)\n",
    "context_size = 8\n",
    "datasets = Datasets(words, context_size)\n",
    "vocab_size = words.nb_chars\n",
    "e_dims = 24 # the dimensionality of the character embedding vectors\n",
    "n_hidden = 128 # the number of neurons in the hidden layer of the FFN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "8c0980b4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<Wavenet\n",
      "  nb_chars=\"41\"\n",
      "  e_dims=\"24\"\n",
      "  n_hidden=\"128\"\n",
      "  context_size=\"8\"\n",
      "  nb_parameters=\"78721\"/>\n"
     ]
    }
   ],
   "source": [
    "nn = Wavenet(e_dims, n_hidden, context_size, words.nb_chars)\n",
    "print(nn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "be203a9c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "      0/ 200000\n",
      "  10000/ 200000\n",
      "  20000/ 200000\n",
      "  30000/ 200000\n",
      "  40000/ 200000\n",
      "  50000/ 200000\n",
      "  60000/ 200000\n",
      "  70000/ 200000\n",
      "  80000/ 200000\n",
      "  90000/ 200000\n",
      " 100000/ 200000\n",
      " 110000/ 200000\n",
      " 120000/ 200000\n",
      " 130000/ 200000\n",
      " 140000/ 200000\n",
      " 150000/ 200000\n",
      " 160000/ 200000\n",
      " 170000/ 200000\n",
      " 180000/ 200000\n",
      " 190000/ 200000\n",
      "train_loss=0.9990963339805603\n",
      "val_loss=1.636179804801941\n"
     ]
    }
   ],
   "source": [
    "max_steps = 200000\n",
    "mini_batch_size = 32\n",
    "nn.train(datasets, max_steps, mini_batch_size)\n",
    "train_loss = nn.training_loss(datasets)\n",
    "val_loss = nn.test_loss(datasets)\n",
    "print(f\"{train_loss=}\")\n",
    "print(f\"{val_loss=}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "95e44944",
   "metadata": {},
   "source": [
    "### Génération"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "id": "5b5844e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "g = torch.Generator(device='mps').manual_seed(seed + 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "id": "83d34e89",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "exécutoire\n",
      "imposée\n",
      "rendus\n",
      "équant\n",
      "organisé\n",
      "employeur\n",
      "garanti\n",
      "quatre-vingt\n",
      "lorsqu'on\n",
      "avantageuses\n",
      "contresigné\n",
      "formalité\n",
      "injugal\n",
      "coulevée\n",
      "finit\n",
      "délivrés\n",
      "reporter\n",
      "résultera\n",
      "ordre\n",
      "casrect\n"
     ]
    }
   ],
   "source": [
    "for _ in range(20):\n",
    "    word = nn.generate_word(words.itoc, g)\n",
    "    print(word)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cours_nlp_mines (3.14.2)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.14.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
