diff --git a/week11/week11.ipynb b/week11/week11.ipynb
new file mode 100644
index 0000000..5b175ec
--- /dev/null
+++ b/week11/week11.ipynb
@@ -0,0 +1,2798 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Continual learning: Overcoming catastrophic forgetting with memory replay\n",
+ "\n",
+ "In this exercise class we'll implement an experiment to measure catastrophic forgetting in a neural network trained on MNIST. We will then fix/reduce the catastrophic forgetting by implementing a simple memory replay strategy."
+ ],
+ "metadata": {
+ "id": "lDTvzuBtNg1X"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Mzy1CzlGQvAk"
+ },
+ "source": [
+ "The following figure highlights the setup of the dataset and is taken from [van de Ven & Tolias, 2019](https://arxiv.org/pdf/1904.07734.pdf):\n",
+ "\n",
+ "![image.png]()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "id": "MN_GMVuyPD9D"
+ },
+ "outputs": [],
+ "source": [
+ "\"\"\"\n",
+ "Code adapted from the torchvision MNIST example:\n",
+ "https://github.com/pytorch/examples/blob/main/mnist/main.py\n",
+ "\n",
+ "BSD 3-Clause License\n",
+ "\n",
+ "Copyright (c) 2017, \n",
+ "All rights reserved.\n",
+ "\n",
+ "Redistribution and use in source and binary forms, with or without\n",
+ "modification, are permitted provided that the following conditions are met:\n",
+ "\n",
+ "* Redistributions of source code must retain the above copyright notice, this\n",
+ " list of conditions and the following disclaimer.\n",
+ "\n",
+ "* Redistributions in binary form must reproduce the above copyright notice,\n",
+ " this list of conditions and the following disclaimer in the documentation\n",
+ " and/or other materials provided with the distribution.\n",
+ "\n",
+ "* Neither the name of the copyright holder nor the names of its\n",
+ " contributors may be used to endorse or promote products derived from\n",
+ " this software without specific prior written permission.\n",
+ "\n",
+ "THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n",
+ "AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n",
+ "IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n",
+ "DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n",
+ "FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n",
+ "DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n",
+ "SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n",
+ "CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n",
+ "OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n",
+ "OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n",
+ "\"\"\"\n",
+ "\n",
+ "import argparse\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "import torch.optim as optim\n",
+ "from torchvision import datasets, transforms\n",
+ "from torch.optim.lr_scheduler import StepLR"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "id": "Bu5GtxDyPMW5"
+ },
+ "outputs": [],
+ "source": [
+ "# Network implementation -- bonus exercise: Modify the network architecture,\n",
+ "# and study the effect on the training results.\n",
+ "\n",
+ "class Net(nn.Module):\n",
+ "    \"\"\"Small convolutional classifier that returns per-class log-probabilities.\"\"\"\n",
+ "\n",
+ "    def __init__(self):\n",
+ "        super().__init__()\n",
+ "        # Feature extractor: two 3x3 convolutions with stride 1.\n",
+ "        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)\n",
+ "        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)\n",
+ "        # Dropout after pooling and after the hidden layer.\n",
+ "        self.dropout1 = nn.Dropout(p=0.25)\n",
+ "        self.dropout2 = nn.Dropout(p=0.5)\n",
+ "        # Classifier head; 9216 = 64 * 12 * 12 for 28x28 inputs.\n",
+ "        self.fc1 = nn.Linear(in_features=9216, out_features=128)\n",
+ "        self.fc2 = nn.Linear(in_features=128, out_features=10)\n",
+ "\n",
+ "    def forward(self, x):\n",
+ "        \"\"\"Map a batch of images to log-probabilities over the ten outputs.\"\"\"\n",
+ "        features = F.relu(self.conv1(x))\n",
+ "        features = F.relu(self.conv2(features))\n",
+ "        features = self.dropout1(F.max_pool2d(features, 2))\n",
+ "        hidden = F.relu(self.fc1(torch.flatten(features, 1)))\n",
+ "        logits = self.fc2(self.dropout2(hidden))\n",
+ "        return F.log_softmax(logits, dim=1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "id": "FljvS0QcPPTg"
+ },
+ "outputs": [],
+ "source": [
+ "# Exercise: Adapt the following function and store part of the images\n",
+ "# presented during training in a replay buffer. Replay these images\n",
+ "# during the training step.\n",
+ "\n",
+ "def train(args, model, device, train_loader, optimizer, epoch, buffer = None):\n",
+ "    \"\"\"Train ``model`` for one epoch, optionally interleaving memory replay.\n",
+ "\n",
+ "    Args:\n",
+ "        args: Namespace providing ``log_interval`` and ``dry_run``.\n",
+ "        model: Network to optimise; its output is passed to ``F.nll_loss``.\n",
+ "        device: Device every batch is moved to before the forward pass.\n",
+ "        train_loader: Iterable of ``(data, target)`` batches exposing\n",
+ "            ``len()`` and a ``dataset`` attribute (both used for logging).\n",
+ "        optimizer: Optimizer that updates the model parameters.\n",
+ "        epoch: Epoch number, only used in the log message.\n",
+ "        buffer: Optional replay buffer offering ``add(images, targets)``,\n",
+ "            ``get()`` and ``len()``. ``None`` disables replay.\n",
+ "    \"\"\"\n",
+ "    model.train()\n",
+ "\n",
+ "    for batch_idx, (data, target) in enumerate(train_loader):\n",
+ "\n",
+ "        ### START SOLUTION ###\n",
+ "        # We append images and labels from the first training batch\n",
+ "        # to our buffer. The batch that is already loaded is reused, so\n",
+ "        # no second loader iterator is created and the stored images are\n",
+ "        # exactly the ones trained on in this step. You can extend this\n",
+ "        # strategy based on how many images you choose to store.\n",
+ "        if buffer is not None and batch_idx == 0:\n",
+ "            buffer.add(data, target)\n",
+ "        ### END SOLUTION ###\n",
+ "\n",
+ "        data, target = data.to(device), target.to(device)\n",
+ "\n",
+ "        optimizer.zero_grad()\n",
+ "        output = model(data)\n",
+ "        loss = F.nll_loss(output, target)\n",
+ "        loss.backward()\n",
+ "        optimizer.step()\n",
+ "\n",
+ "        ### START SOLUTION ###\n",
+ "        # A simple strategy for overcoming forgetting is to retrieve images\n",
+ "        # from the buffer (here: one image per class) and perform a gradient\n",
+ "        # step on these images along with every incoming new batch.\n",
+ "        # The replay loss gets its own name so the value logged below stays\n",
+ "        # the loss of the current training batch.\n",
+ "        if buffer is not None and len(buffer) > 0:\n",
+ "            replayed_images, replayed_targets = buffer.get()\n",
+ "            replayed_images = replayed_images.to(device)\n",
+ "            replayed_targets = replayed_targets.to(device)\n",
+ "            optimizer.zero_grad()\n",
+ "            replay_output = model(replayed_images)\n",
+ "            replay_loss = F.nll_loss(replay_output, replayed_targets)\n",
+ "            replay_loss.backward()\n",
+ "            optimizer.step()\n",
+ "        ### END SOLUTION ###\n",
+ "\n",
+ "        if batch_idx % args.log_interval == 0:\n",
+ "            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
+ "                epoch, batch_idx * len(data), len(train_loader.dataset),\n",
+ "                100. * batch_idx / len(train_loader), loss.item()))\n",
+ "            if args.dry_run:\n",
+ "                break"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "id": "QXPgWNCRPRd4"
+ },
+ "outputs": [],
+ "source": [
+ "# The test routine was adapted from the original implementation and now computes\n",
+ "# the classification accuracy per class instead of an average. This allows us to\n",
+ "# later assess the effect of catastrophic forgetting. No adaptations in this function\n",
+ "# are required for the exercise.\n",
+ "\n",
+ "import collections\n",
+ "\n",
+ "def test(model, device, test_loader):\n",
+ "    \"\"\"Evaluate ``model`` and report the accuracy separately for each class.\n",
+ "\n",
+ "    Args:\n",
+ "        model: Network returning log-probabilities (fed to ``F.nll_loss``).\n",
+ "        device: Device every batch is moved to before the forward pass.\n",
+ "        test_loader: Iterable of ``(data, target)`` batches exposing a\n",
+ "            ``dataset`` attribute whose length normalises the loss.\n",
+ "\n",
+ "    Returns:\n",
+ "        dict mapping each class index 0-9 to its accuracy in [0, 1]. A class\n",
+ "        without any test sample is reported as ``nan`` instead of raising\n",
+ "        ``ZeroDivisionError``.\n",
+ "    \"\"\"\n",
+ "    model.eval()\n",
+ "    test_loss = 0\n",
+ "    correct_by_class = collections.Counter()\n",
+ "    count_by_class = collections.Counter()\n",
+ "    with torch.no_grad():\n",
+ "        for data, target in test_loader:\n",
+ "            data, target = data.to(device), target.to(device)\n",
+ "            output = model(data)\n",
+ "            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss\n",
+ "            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability\n",
+ "            correct = pred.eq(target.view_as(pred)).float()\n",
+ "            for class_ in range(10):\n",
+ "                idc = (target == class_)\n",
+ "                correct_by_class[class_] += correct[idc].sum().item()\n",
+ "                count_by_class[class_] += idc.sum().item()\n",
+ "\n",
+ "    test_loss /= len(test_loader.dataset)\n",
+ "\n",
+ "    print(f'\\nTest set: Average loss: {test_loss:.4f}')\n",
+ "    result = {}\n",
+ "    for class_ in range(10):\n",
+ "        seen = count_by_class[class_]\n",
+ "        # Accuracy is undefined for a class that never occurred in the loader.\n",
+ "        acc = correct_by_class[class_] / seen if seen else float(\"nan\")\n",
+ "        result[class_] = acc\n",
+ "        print(\n",
+ "            f\"Class {class_} accuracy: \"\n",
+ "            f\"{correct_by_class[class_]}/{count_by_class[class_]}\"\n",
+ "            f\"({acc*100:.0f}%)\"\n",
+ "        )\n",
+ "\n",
+ "    return result"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "id": "UsdJlkLkRRbp"
+ },
+ "outputs": [],
+ "source": [
+ "# Exercise: Adapt the dataset class \"MNISTContinualLearning\" for the experiment\n",
+ "# outlined in the introduction text. The class needs to support indexing of the\n",
+ "# dataset based on the provided list of classes.\n",
+ "\n",
+ "def identity(x):\n",
+ "    \"\"\"Return ``x`` unchanged (the original bare ``return`` yielded ``None``).\"\"\"\n",
+ "    return x\n",
+ "\n",
+ "class MNISTContinualLearning(datasets.MNIST):\n",
+ "    \"\"\"MNIST restricted to the samples whose label is listed in ``classes``.\n",
+ "\n",
+ "    Accepts every positional and keyword argument of ``datasets.MNIST`` plus:\n",
+ "\n",
+ "    Args:\n",
+ "        classes: Iterable of digit labels to keep. ``None`` (the default)\n",
+ "            keeps all ten digits.\n",
+ "\n",
+ "    Raises:\n",
+ "        ValueError: If fewer than two classes are requested.\n",
+ "    \"\"\"\n",
+ "\n",
+ "    def __init__(self, *args, classes=None, **kwargs):\n",
+ "\n",
+ "        # This inherits from the base dataset\n",
+ "        super().__init__(*args, **kwargs)\n",
+ "\n",
+ "        # ``None`` replaces the former mutable list default; the copy also\n",
+ "        # lets callers pass any iterable, e.g. a tuple or range.\n",
+ "        classes = list(range(10)) if classes is None else list(classes)\n",
+ "\n",
+ "        if len(classes) < 2:\n",
+ "            raise ValueError(f\"Need at least two classes, but got {len(classes)}\")\n",
+ "\n",
+ "        # Add code for filtering the dataset here. You need to adapt\n",
+ "        # the \"data\" and \"targets\" attribute of the dataset. The \"data\"\n",
+ "        # attribute stores the images and the \"targets\" attribute stores\n",
+ "        # the labels (presumably torch tensors in recent torchvision\n",
+ "        # releases -- TODO confirm for the installed version).\n",
+ "        # You can override the existing attributes.\n",
+ "        #\n",
+ "        # self.data = ...\n",
+ "        # self.targets = ...\n",
+ "\n",
+ "        ### START SOLUTION ###\n",
+ "        # Build one boolean mask that is True for every sample whose label\n",
+ "        # is in ``classes``; the length check above guarantees that at least\n",
+ "        # one iteration runs.\n",
+ "        keep = None\n",
+ "        for class_ in classes:\n",
+ "            matches = self.targets == class_\n",
+ "            keep = matches if keep is None else keep | matches\n",
+ "\n",
+ "        self.data = self.data[keep]\n",
+ "        self.targets = self.targets[keep]\n",
+ "        ### END SOLUTION ###\n",
+ "\n",
+ "\n",
+ "def test_dataset():\n",
+ "    \"\"\"Sanity check: the training split filtered to digits 0 and 5 holds 11344 images.\"\"\"\n",
+ "    normalize = transforms.Compose([\n",
+ "        transforms.ToTensor(),\n",
+ "        transforms.Normalize((0.1307,), (0.3081,))\n",
+ "    ])\n",
+ "    subset = MNISTContinualLearning(\n",
+ "        '../data', train=True, download=True, transform=normalize, classes = [0, 5]\n",
+ "    )\n",
+ "    assert len(subset) == 11344\n",
+ "\n",
+ "test_dataset()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "id": "EDwtthksNUpw"
+ },
+ "outputs": [],
+ "source": [
+ "# To implement memory replay, add additional code (functions or classes)\n",
+ "# to this cell. One possible solution is to implement a \"Buffer\" class\n",
+ "# that allows to add images into memory, and allows to retrieve the stored\n",
+ "# images and labels for training the model.\n",
+ "#\n",
+ "# The buffer is typically memory constrained, and there are multiple ways\n",
+ "# to efficiently compress the individual elements present in the buffer.\n",
+ "# Think about different ways of reducing the storage required by your buffer\n",
+ "# class, and explore which of them is most effective at mitigating catastrophic\n",
+ "# forgetting.\n",
+ "\n",
+ "\n",
+ "### START SOLUTION ###\n",
+ "\n",
+ "class Buffer(nn.Module):\n",
+ "    \"\"\"Replay memory that keeps one example image per class label.\n",
+ "\n",
+ "    Adding a class that is already stored replaces its image, so the number\n",
+ "    of entries is bounded by the number of distinct classes seen.\n",
+ "    \"\"\"\n",
+ "\n",
+ "    def __init__(self):\n",
+ "        # nn.Module expects its own initialiser to run before attributes are set.\n",
+ "        super().__init__()\n",
+ "        # Maps an integer class label to a single stored image tensor.\n",
+ "        self.buffer = {}\n",
+ "\n",
+ "    def add(self, images, targets):\n",
+ "        \"\"\"Store the first image of every class present in ``targets``.\n",
+ "\n",
+ "        Args:\n",
+ "            images: Batch of images, indexed along the first dimension.\n",
+ "            targets: Labels aligned with ``images`` along the first dimension.\n",
+ "        \"\"\"\n",
+ "        for class_ in targets.unique():\n",
+ "            idc = targets == class_\n",
+ "            # Key by plain int: tensors hash by identity, so tensor keys never\n",
+ "            # collide and the buffer would grow on every call. The clone drops\n",
+ "            # the reference to the rest of the indexed batch.\n",
+ "            self.buffer[int(class_)] = images[idc][0].detach().clone()\n",
+ "\n",
+ "    def get(self):\n",
+ "        \"\"\"Return ``(images, targets)`` with one stacked image per stored class.\"\"\"\n",
+ "        assert len(self) > 0\n",
+ "        keys = list(self.buffer.keys())\n",
+ "        targets = torch.tensor(keys, dtype=torch.long)\n",
+ "        images = torch.stack([self.buffer[k] for k in keys], dim = 0)\n",
+ "        assert len(targets) == len(images)\n",
+ "        return images, targets\n",
+ "\n",
+ "    def __len__(self):\n",
+ "        \"\"\"Number of stored images (one per class).\"\"\"\n",
+ "        return len(self.buffer)\n",
+ " \n",
+ "### END SOLUTION ###"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "id": "u2xvlBoZPTyx"
+ },
+ "outputs": [],
+ "source": [
+ "# Adapt the main training loop to work with the functions you defined above.\n",
+ "\n",
+ "def train_model(args, model, phase, replay_buffer = None, history = None, buffer = None):\n",
+ "    \"\"\"Run one training phase on ``args.train_classes`` and evaluate after every epoch.\n",
+ "\n",
+ "    Args:\n",
+ "        args: Namespace with ``batch_size``, ``test_batch_size``, ``epochs``,\n",
+ "            ``lr``, ``gamma``, ``no_cuda``, ``seed`` and ``train_classes``\n",
+ "            (plus the fields read by ``train``).\n",
+ "        model: Network to train; it is moved to the selected device.\n",
+ "        phase: Label stored under the \"phase\" key of every epoch's results.\n",
+ "        replay_buffer: Unused; kept so existing callers keep working.\n",
+ "        history: List the per-epoch result dicts are appended to. A new list\n",
+ "            is created when omitted, so results no longer leak between calls\n",
+ "            through a shared default.\n",
+ "        buffer: Optional replay buffer forwarded to ``train``.\n",
+ "\n",
+ "    Returns:\n",
+ "        The ``history`` list, extended by one entry per epoch.\n",
+ "    \"\"\"\n",
+ "    if history is None:\n",
+ "        history = []\n",
+ "\n",
+ "    use_cuda = not args.no_cuda and torch.cuda.is_available()\n",
+ "    device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
+ "\n",
+ "    torch.manual_seed(args.seed)\n",
+ "\n",
+ "    train_kwargs = {'batch_size': args.batch_size}\n",
+ "    test_kwargs = {'batch_size': args.test_batch_size}\n",
+ "    # NOTE(review): shuffling (for both loaders) is only enabled on CUDA;\n",
+ "    # CPU runs iterate in dataset order.\n",
+ "    if use_cuda:\n",
+ "        cuda_kwargs = {'num_workers': 1,\n",
+ "                       'pin_memory': True,\n",
+ "                       'shuffle': True}\n",
+ "        train_kwargs.update(cuda_kwargs)\n",
+ "        test_kwargs.update(cuda_kwargs)\n",
+ "\n",
+ "    transform=transforms.Compose([\n",
+ "        transforms.ToTensor(),\n",
+ "        transforms.Normalize((0.1307,), (0.3081,))\n",
+ "    ])\n",
+ "\n",
+ "    # You need to replace the original MNIST dataset here by the continual\n",
+ "    # learning dataset we implemented in the previous cell. Make sure to\n",
+ "    # pass the list of classes selected for training in args.train_classes.\n",
+ "\n",
+ "    ### START SOLUTION ###\n",
+ "    # We added the revised MNIST class here, which takes the same arguments\n",
+ "    # as the original class, but additionally takes a \"classes\" argument which\n",
+ "    # specifies the subselection of classes to consider during this training\n",
+ "    # phase.\n",
+ "    dataset1 = MNISTContinualLearning('../data', train=True, download=True,\n",
+ "                                      transform=transform, classes=args.train_classes)\n",
+ "    ### END SOLUTION ###\n",
+ "    # The test set always contains all ten digits so forgetting stays visible.\n",
+ "    dataset2 = datasets.MNIST('../data', train=False, download=True,\n",
+ "                              transform=transform)\n",
+ "    train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)\n",
+ "    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)\n",
+ "\n",
+ "    model = model.to(device)\n",
+ "    # A fresh optimizer and scheduler are created for every phase.\n",
+ "    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)\n",
+ "    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)\n",
+ "\n",
+ "    for epoch in range(1, args.epochs + 1):\n",
+ "        ### START SOLUTION ###\n",
+ "        # We additionally pass the \"buffer\" class here to collect training\n",
+ "        # images for later memory replay.\n",
+ "        #\n",
+ "        # Original content:\n",
+ "        # train(args, model, device, train_loader, optimizer, epoch)\n",
+ "        train(args, model, device, train_loader, optimizer, epoch, buffer)\n",
+ "        ### END SOLUTION ###\n",
+ "        test_results = test(model, device, test_loader)\n",
+ "        test_results[\"phase\"] = phase\n",
+ "        history.append(test_results)\n",
+ "        scheduler.step()\n",
+ "\n",
+ "    return history\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# We now define the experiment setup.\n",
+ "# Without modifications to the code, the following will simply train an MNIST\n",
+ "# network multiple times, and reach an accuracy of >99% on the test set.\n",
+ "#\n",
+ "# The result will be modified in two steps:\n",
+ "#\n",
+ "# -- first, after implementing the MNISTContinualLearning dataset, you will be\n",
+ "# able to observe catastrophic forgetting: Training on a new task (specified\n",
+ "# by the \"phase\" and \"config.train_classes\" variables) will make the network\n",
+ "# forget the previously learned tasks, and you will observe a performance drop.\n",
+ "#\n",
+ "# -- second, after implementing the catastrophic forgetting experiment, you will\n",
+ "# fix the catastrophic forgetting by implementing a simple memory buffer. This\n",
+ "# buffer will keep some of the images seen in each individual training phase,\n",
+ "# and add them to the training in each subsequent phase. This task is open-ended\n",
+ "# and different strategies exist. They will differ in terms of memory efficiency\n",
+ "# and performance.\n",
+ "\n",
+ "# You should adapt the train config to each experiment setup and especially test the\n",
+ "# effects of taking different values for the number of epochs (per phase of the training)\n",
+ "# and the learning rate used.\n",
+ "# Hyperparameters shared by all experiments; \"train_classes\" is set per phase.\n",
+ "config = argparse.Namespace(**{\n",
+ "    \"batch_size\": 64,\n",
+ "    \"test_batch_size\": 1000,\n",
+ "    \"epochs\": 3,\n",
+ "    \"lr\": 0.01,\n",
+ "    \"gamma\": 0.7,\n",
+ "    \"no_cuda\": False,\n",
+ "    \"dry_run\": False,\n",
+ "    \"seed\": 1,\n",
+ "    \"log_interval\": 10,\n",
+ "    \"save_model\": False,\n",
+ "})\n",
+ "\n",
+ "def train_regular_mnist():\n",
+ "    \"\"\"Baseline run: train a fresh network on all ten digits at once.\n",
+ "\n",
+ "    Returns:\n",
+ "        List with one per-class result dict for every epoch.\n",
+ "    \"\"\"\n",
+ "    network = Net()\n",
+ "    results = []\n",
+ "\n",
+ "    # Selecting every digit reproduces ordinary MNIST training.\n",
+ "    config.train_classes = list(range(10))\n",
+ "    train_model(config, network, phase = \"baseline\", history = results)\n",
+ "    return results\n",
+ "\n",
+ "# Uncomment to train\n",
+ "#history_regular_mnist = train_regular_mnist()\n",
+ "#history_regular_mnist"
+ ],
+ "metadata": {
+ "id": "SK-GNifLPN2S"
+ },
+ "execution_count": 18,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Exercise 1: Implement an experiment setup that measures catastrophic forgetting while\n",
+ "# training in a task-incremental way on digits (0,1), (2,3), ..., (8, 9). After finishing\n",
+ "# training, measure the performance on the first task again.\n",
+ "#\n",
+ "# Implementation requires to fill in code further up in this notebook. You can visualize\n",
+ "# the result using the code further below.\n",
+ "\n",
+ "### START SOLUTION ###\n",
+ "# In this first part, we implement an experiment setup to measure catastrophic forgetting.\n",
+ "# We will subsequently train on pairs of classes, (0,1), then (2,3), etc., and in the end\n",
+ "# revisit the first task to measure if there was any forward transfer while training on the\n",
+ "# other tasks.\n",
+ "\n",
+ "def train_catastrophic_forgetting():\n",
+ "    \"\"\"Train one network sequentially on digit pairs, then revisit the first pair.\n",
+ "\n",
+ "    Returns:\n",
+ "        List of per-epoch result dicts collected across all phases.\n",
+ "    \"\"\"\n",
+ "    # (phase label, digits) in training order; the last entry repeats the\n",
+ "    # first task.\n",
+ "    schedule = [\n",
+ "        (\"0_1\", [0, 1]),\n",
+ "        (\"2_3\", [2, 3]),\n",
+ "        (\"4_5\", [4, 5]),\n",
+ "        (\"6_7\", [6, 7]),\n",
+ "        (\"8_9\", [8, 9]),\n",
+ "        (\"0_1_again\", [0, 1]),\n",
+ "    ]\n",
+ "\n",
+ "    model = Net()\n",
+ "    history = []\n",
+ "\n",
+ "    for phase_name, digits in schedule:\n",
+ "        config.train_classes = digits\n",
+ "        train_model(config, model, phase = phase_name, history = history)\n",
+ "\n",
+ "    return history\n",
+ "\n",
+ "### END SOLUTION ###\n",
+ " \n",
+ "history_catastrophic_forgetting = train_catastrophic_forgetting()\n",
+ "history_catastrophic_forgetting"
+ ],
+ "metadata": {
+ "id": "OnATcaveP3Fi",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "6e9bc36f-8e10-405a-8297-e104141e58cc"
+ },
+ "execution_count": 19,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Train Epoch: 1 [0/12665 (0%)]\tLoss: 2.295613\n",
+ "Train Epoch: 1 [640/12665 (5%)]\tLoss: 1.647854\n",
+ "Train Epoch: 1 [1280/12665 (10%)]\tLoss: 1.039609\n",
+ "Train Epoch: 1 [1920/12665 (15%)]\tLoss: 0.593714\n",
+ "Train Epoch: 1 [2560/12665 (20%)]\tLoss: 0.321345\n",
+ "Train Epoch: 1 [3200/12665 (25%)]\tLoss: 0.211062\n",
+ "Train Epoch: 1 [3840/12665 (30%)]\tLoss: 0.124204\n",
+ "Train Epoch: 1 [4480/12665 (35%)]\tLoss: 0.082856\n",
+ "Train Epoch: 1 [5120/12665 (40%)]\tLoss: 0.075962\n",
+ "Train Epoch: 1 [5760/12665 (45%)]\tLoss: 0.071156\n",
+ "Train Epoch: 1 [6400/12665 (51%)]\tLoss: 0.117059\n",
+ "Train Epoch: 1 [7040/12665 (56%)]\tLoss: 0.059726\n",
+ "Train Epoch: 1 [7680/12665 (61%)]\tLoss: 0.032035\n",
+ "Train Epoch: 1 [8320/12665 (66%)]\tLoss: 0.033078\n",
+ "Train Epoch: 1 [8960/12665 (71%)]\tLoss: 0.030728\n",
+ "Train Epoch: 1 [9600/12665 (76%)]\tLoss: 0.039060\n",
+ "Train Epoch: 1 [10240/12665 (81%)]\tLoss: 0.035301\n",
+ "Train Epoch: 1 [10880/12665 (86%)]\tLoss: 0.041941\n",
+ "Train Epoch: 1 [11520/12665 (91%)]\tLoss: 0.040506\n",
+ "Train Epoch: 1 [12160/12665 (96%)]\tLoss: 0.014235\n",
+ "\n",
+ "Test set: Average loss: 7.1276\n",
+ "Class 0 accuracy: 977.0/980(100%)\n",
+ "Class 1 accuracy: 1134.0/1135(100%)\n",
+ "Class 2 accuracy: 0.0/1032(0%)\n",
+ "Class 3 accuracy: 0.0/1010(0%)\n",
+ "Class 4 accuracy: 0.0/982(0%)\n",
+ "Class 5 accuracy: 0.0/892(0%)\n",
+ "Class 6 accuracy: 0.0/958(0%)\n",
+ "Class 7 accuracy: 0.0/1028(0%)\n",
+ "Class 8 accuracy: 0.0/974(0%)\n",
+ "Class 9 accuracy: 0.0/1009(0%)\n",
+ "Train Epoch: 2 [0/12665 (0%)]\tLoss: 0.007759\n",
+ "Train Epoch: 2 [640/12665 (5%)]\tLoss: 0.029978\n",
+ "Train Epoch: 2 [1280/12665 (10%)]\tLoss: 0.017544\n",
+ "Train Epoch: 2 [1920/12665 (15%)]\tLoss: 0.022188\n",
+ "Train Epoch: 2 [2560/12665 (20%)]\tLoss: 0.009434\n",
+ "Train Epoch: 2 [3200/12665 (25%)]\tLoss: 0.024381\n",
+ "Train Epoch: 2 [3840/12665 (30%)]\tLoss: 0.086738\n",
+ "Train Epoch: 2 [4480/12665 (35%)]\tLoss: 0.020443\n",
+ "Train Epoch: 2 [5120/12665 (40%)]\tLoss: 0.015766\n",
+ "Train Epoch: 2 [5760/12665 (45%)]\tLoss: 0.006633\n",
+ "Train Epoch: 2 [6400/12665 (51%)]\tLoss: 0.007944\n",
+ "Train Epoch: 2 [7040/12665 (56%)]\tLoss: 0.008917\n",
+ "Train Epoch: 2 [7680/12665 (61%)]\tLoss: 0.004392\n",
+ "Train Epoch: 2 [8320/12665 (66%)]\tLoss: 0.008756\n",
+ "Train Epoch: 2 [8960/12665 (71%)]\tLoss: 0.056707\n",
+ "Train Epoch: 2 [9600/12665 (76%)]\tLoss: 0.005973\n",
+ "Train Epoch: 2 [10240/12665 (81%)]\tLoss: 0.008020\n",
+ "Train Epoch: 2 [10880/12665 (86%)]\tLoss: 0.009416\n",
+ "Train Epoch: 2 [11520/12665 (91%)]\tLoss: 0.017660\n",
+ "Train Epoch: 2 [12160/12665 (96%)]\tLoss: 0.010692\n",
+ "\n",
+ "Test set: Average loss: 8.1468\n",
+ "Class 0 accuracy: 977.0/980(100%)\n",
+ "Class 1 accuracy: 1134.0/1135(100%)\n",
+ "Class 2 accuracy: 0.0/1032(0%)\n",
+ "Class 3 accuracy: 0.0/1010(0%)\n",
+ "Class 4 accuracy: 0.0/982(0%)\n",
+ "Class 5 accuracy: 0.0/892(0%)\n",
+ "Class 6 accuracy: 0.0/958(0%)\n",
+ "Class 7 accuracy: 0.0/1028(0%)\n",
+ "Class 8 accuracy: 0.0/974(0%)\n",
+ "Class 9 accuracy: 0.0/1009(0%)\n",
+ "Train Epoch: 3 [0/12665 (0%)]\tLoss: 0.016127\n",
+ "Train Epoch: 3 [640/12665 (5%)]\tLoss: 0.018997\n",
+ "Train Epoch: 3 [1280/12665 (10%)]\tLoss: 0.004281\n",
+ "Train Epoch: 3 [1920/12665 (15%)]\tLoss: 0.004884\n",
+ "Train Epoch: 3 [2560/12665 (20%)]\tLoss: 0.006038\n",
+ "Train Epoch: 3 [3200/12665 (25%)]\tLoss: 0.012006\n",
+ "Train Epoch: 3 [3840/12665 (30%)]\tLoss: 0.009289\n",
+ "Train Epoch: 3 [4480/12665 (35%)]\tLoss: 0.002473\n",
+ "Train Epoch: 3 [5120/12665 (40%)]\tLoss: 0.002702\n",
+ "Train Epoch: 3 [5760/12665 (45%)]\tLoss: 0.012154\n",
+ "Train Epoch: 3 [6400/12665 (51%)]\tLoss: 0.018935\n",
+ "Train Epoch: 3 [7040/12665 (56%)]\tLoss: 0.004237\n",
+ "Train Epoch: 3 [7680/12665 (61%)]\tLoss: 0.003406\n",
+ "Train Epoch: 3 [8320/12665 (66%)]\tLoss: 0.008378\n",
+ "Train Epoch: 3 [8960/12665 (71%)]\tLoss: 0.003461\n",
+ "Train Epoch: 3 [9600/12665 (76%)]\tLoss: 0.018804\n",
+ "Train Epoch: 3 [10240/12665 (81%)]\tLoss: 0.015821\n",
+ "Train Epoch: 3 [10880/12665 (86%)]\tLoss: 0.003116\n",
+ "Train Epoch: 3 [11520/12665 (91%)]\tLoss: 0.002931\n",
+ "Train Epoch: 3 [12160/12665 (96%)]\tLoss: 0.004047\n",
+ "\n",
+ "Test set: Average loss: 8.5887\n",
+ "Class 0 accuracy: 977.0/980(100%)\n",
+ "Class 1 accuracy: 1134.0/1135(100%)\n",
+ "Class 2 accuracy: 0.0/1032(0%)\n",
+ "Class 3 accuracy: 0.0/1010(0%)\n",
+ "Class 4 accuracy: 0.0/982(0%)\n",
+ "Class 5 accuracy: 0.0/892(0%)\n",
+ "Class 6 accuracy: 0.0/958(0%)\n",
+ "Class 7 accuracy: 0.0/1028(0%)\n",
+ "Class 8 accuracy: 0.0/974(0%)\n",
+ "Class 9 accuracy: 0.0/1009(0%)\n",
+ "Train Epoch: 1 [0/12089 (0%)]\tLoss: 10.857133\n",
+ "Train Epoch: 1 [640/12089 (5%)]\tLoss: 6.280286\n",
+ "Train Epoch: 1 [1280/12089 (11%)]\tLoss: 2.301530\n",
+ "Train Epoch: 1 [1920/12089 (16%)]\tLoss: 0.905669\n",
+ "Train Epoch: 1 [2560/12089 (21%)]\tLoss: 0.688727\n",
+ "Train Epoch: 1 [3200/12089 (26%)]\tLoss: 0.598405\n",
+ "Train Epoch: 1 [3840/12089 (32%)]\tLoss: 0.462726\n",
+ "Train Epoch: 1 [4480/12089 (37%)]\tLoss: 0.345897\n",
+ "Train Epoch: 1 [5120/12089 (42%)]\tLoss: 0.380626\n",
+ "Train Epoch: 1 [5760/12089 (48%)]\tLoss: 0.280428\n",
+ "Train Epoch: 1 [6400/12089 (53%)]\tLoss: 0.216660\n",
+ "Train Epoch: 1 [7040/12089 (58%)]\tLoss: 0.167003\n",
+ "Train Epoch: 1 [7680/12089 (63%)]\tLoss: 0.275916\n",
+ "Train Epoch: 1 [8320/12089 (69%)]\tLoss: 0.326823\n",
+ "Train Epoch: 1 [8960/12089 (74%)]\tLoss: 0.258519\n",
+ "Train Epoch: 1 [9600/12089 (79%)]\tLoss: 0.233105\n",
+ "Train Epoch: 1 [10240/12089 (85%)]\tLoss: 0.235713\n",
+ "Train Epoch: 1 [10880/12089 (90%)]\tLoss: 0.265890\n",
+ "Train Epoch: 1 [11520/12089 (95%)]\tLoss: 0.262226\n",
+ "\n",
+ "Test set: Average loss: 6.2382\n",
+ "Class 0 accuracy: 0.0/980(0%)\n",
+ "Class 1 accuracy: 0.0/1135(0%)\n",
+ "Class 2 accuracy: 992.0/1032(96%)\n",
+ "Class 3 accuracy: 987.0/1010(98%)\n",
+ "Class 4 accuracy: 0.0/982(0%)\n",
+ "Class 5 accuracy: 0.0/892(0%)\n",
+ "Class 6 accuracy: 0.0/958(0%)\n",
+ "Class 7 accuracy: 0.0/1028(0%)\n",
+ "Class 8 accuracy: 0.0/974(0%)\n",
+ "Class 9 accuracy: 0.0/1009(0%)\n",
+ "Train Epoch: 2 [0/12089 (0%)]\tLoss: 0.099378\n",
+ "Train Epoch: 2 [640/12089 (5%)]\tLoss: 0.116401\n",
+ "Train Epoch: 2 [1280/12089 (11%)]\tLoss: 0.155221\n",
+ "Train Epoch: 2 [1920/12089 (16%)]\tLoss: 0.109446\n",
+ "Train Epoch: 2 [2560/12089 (21%)]\tLoss: 0.104514\n",
+ "Train Epoch: 2 [3200/12089 (26%)]\tLoss: 0.219575\n",
+ "Train Epoch: 2 [3840/12089 (32%)]\tLoss: 0.151700\n",
+ "Train Epoch: 2 [4480/12089 (37%)]\tLoss: 0.162770\n",
+ "Train Epoch: 2 [5120/12089 (42%)]\tLoss: 0.152628\n",
+ "Train Epoch: 2 [5760/12089 (48%)]\tLoss: 0.186006\n",
+ "Train Epoch: 2 [6400/12089 (53%)]\tLoss: 0.095261\n",
+ "Train Epoch: 2 [7040/12089 (58%)]\tLoss: 0.152865\n",
+ "Train Epoch: 2 [7680/12089 (63%)]\tLoss: 0.046197\n",
+ "Train Epoch: 2 [8320/12089 (69%)]\tLoss: 0.068629\n",
+ "Train Epoch: 2 [8960/12089 (74%)]\tLoss: 0.133037\n",
+ "Train Epoch: 2 [9600/12089 (79%)]\tLoss: 0.106119\n",
+ "Train Epoch: 2 [10240/12089 (85%)]\tLoss: 0.073823\n",
+ "Train Epoch: 2 [10880/12089 (90%)]\tLoss: 0.117314\n",
+ "Train Epoch: 2 [11520/12089 (95%)]\tLoss: 0.144778\n",
+ "\n",
+ "Test set: Average loss: 7.2467\n",
+ "Class 0 accuracy: 0.0/980(0%)\n",
+ "Class 1 accuracy: 0.0/1135(0%)\n",
+ "Class 2 accuracy: 998.0/1032(97%)\n",
+ "Class 3 accuracy: 994.0/1010(98%)\n",
+ "Class 4 accuracy: 0.0/982(0%)\n",
+ "Class 5 accuracy: 0.0/892(0%)\n",
+ "Class 6 accuracy: 0.0/958(0%)\n",
+ "Class 7 accuracy: 0.0/1028(0%)\n",
+ "Class 8 accuracy: 0.0/974(0%)\n",
+ "Class 9 accuracy: 0.0/1009(0%)\n",
+ "Train Epoch: 3 [0/12089 (0%)]\tLoss: 0.102792\n",
+ "Train Epoch: 3 [640/12089 (5%)]\tLoss: 0.209812\n",
+ "Train Epoch: 3 [1280/12089 (11%)]\tLoss: 0.079453\n",
+ "Train Epoch: 3 [1920/12089 (16%)]\tLoss: 0.146311\n",
+ "Train Epoch: 3 [2560/12089 (21%)]\tLoss: 0.057011\n",
+ "Train Epoch: 3 [3200/12089 (26%)]\tLoss: 0.062497\n",
+ "Train Epoch: 3 [3840/12089 (32%)]\tLoss: 0.114005\n",
+ "Train Epoch: 3 [4480/12089 (37%)]\tLoss: 0.089880\n",
+ "Train Epoch: 3 [5120/12089 (42%)]\tLoss: 0.057159\n",
+ "Train Epoch: 3 [5760/12089 (48%)]\tLoss: 0.117361\n",
+ "Train Epoch: 3 [6400/12089 (53%)]\tLoss: 0.098925\n",
+ "Train Epoch: 3 [7040/12089 (58%)]\tLoss: 0.069072\n",
+ "Train Epoch: 3 [7680/12089 (63%)]\tLoss: 0.111345\n",
+ "Train Epoch: 3 [8320/12089 (69%)]\tLoss: 0.131783\n",
+ "Train Epoch: 3 [8960/12089 (74%)]\tLoss: 0.150885\n",
+ "Train Epoch: 3 [9600/12089 (79%)]\tLoss: 0.079096\n",
+ "Train Epoch: 3 [10240/12089 (85%)]\tLoss: 0.082054\n",
+ "Train Epoch: 3 [10880/12089 (90%)]\tLoss: 0.129982\n",
+ "Train Epoch: 3 [11520/12089 (95%)]\tLoss: 0.066340\n",
+ "\n",
+ "Test set: Average loss: 7.6576\n",
+ "Class 0 accuracy: 0.0/980(0%)\n",
+ "Class 1 accuracy: 0.0/1135(0%)\n",
+ "Class 2 accuracy: 1007.0/1032(98%)\n",
+ "Class 3 accuracy: 985.0/1010(98%)\n",
+ "Class 4 accuracy: 0.0/982(0%)\n",
+ "Class 5 accuracy: 0.0/892(0%)\n",
+ "Class 6 accuracy: 0.0/958(0%)\n",
+ "Class 7 accuracy: 0.0/1028(0%)\n",
+ "Class 8 accuracy: 0.0/974(0%)\n",
+ "Class 9 accuracy: 0.0/1009(0%)\n",
+ "Train Epoch: 1 [0/11263 (0%)]\tLoss: 10.835564\n",
+ "Train Epoch: 1 [640/11263 (6%)]\tLoss: 5.729345\n",
+ "Train Epoch: 1 [1280/11263 (11%)]\tLoss: 2.635014\n",
+ "Train Epoch: 1 [1920/11263 (17%)]\tLoss: 1.219099\n",
+ "Train Epoch: 1 [2560/11263 (23%)]\tLoss: 0.913658\n",
+ "Train Epoch: 1 [3200/11263 (28%)]\tLoss: 0.655098\n",
+ "Train Epoch: 1 [3840/11263 (34%)]\tLoss: 0.466011\n",
+ "Train Epoch: 1 [4480/11263 (40%)]\tLoss: 0.539344\n",
+ "Train Epoch: 1 [5120/11263 (45%)]\tLoss: 0.249651\n",
+ "Train Epoch: 1 [5760/11263 (51%)]\tLoss: 0.209131\n",
+ "Train Epoch: 1 [6400/11263 (57%)]\tLoss: 0.223062\n",
+ "Train Epoch: 1 [7040/11263 (62%)]\tLoss: 0.228752\n",
+ "Train Epoch: 1 [7680/11263 (68%)]\tLoss: 0.198413\n",
+ "Train Epoch: 1 [8320/11263 (74%)]\tLoss: 0.215570\n",
+ "Train Epoch: 1 [8960/11263 (80%)]\tLoss: 0.272983\n",
+ "Train Epoch: 1 [9600/11263 (85%)]\tLoss: 0.155609\n",
+ "Train Epoch: 1 [10240/11263 (91%)]\tLoss: 0.087590\n",
+ "Train Epoch: 1 [10880/11263 (97%)]\tLoss: 0.132957\n",
+ "\n",
+ "Test set: Average loss: 7.1306\n",
+ "Class 0 accuracy: 0.0/980(0%)\n",
+ "Class 1 accuracy: 0.0/1135(0%)\n",
+ "Class 2 accuracy: 0.0/1032(0%)\n",
+ "Class 3 accuracy: 0.0/1010(0%)\n",
+ "Class 4 accuracy: 977.0/982(99%)\n",
+ "Class 5 accuracy: 866.0/892(97%)\n",
+ "Class 6 accuracy: 0.0/958(0%)\n",
+ "Class 7 accuracy: 0.0/1028(0%)\n",
+ "Class 8 accuracy: 0.0/974(0%)\n",
+ "Class 9 accuracy: 0.0/1009(0%)\n",
+ "Train Epoch: 2 [0/11263 (0%)]\tLoss: 0.130704\n",
+ "Train Epoch: 2 [640/11263 (6%)]\tLoss: 0.123305\n",
+ "Train Epoch: 2 [1280/11263 (11%)]\tLoss: 0.092606\n",
+ "Train Epoch: 2 [1920/11263 (17%)]\tLoss: 0.108018\n",
+ "Train Epoch: 2 [2560/11263 (23%)]\tLoss: 0.059542\n",
+ "Train Epoch: 2 [3200/11263 (28%)]\tLoss: 0.051525\n",
+ "Train Epoch: 2 [3840/11263 (34%)]\tLoss: 0.105634\n",
+ "Train Epoch: 2 [4480/11263 (40%)]\tLoss: 0.075492\n",
+ "Train Epoch: 2 [5120/11263 (45%)]\tLoss: 0.056117\n",
+ "Train Epoch: 2 [5760/11263 (51%)]\tLoss: 0.119122\n",
+ "Train Epoch: 2 [6400/11263 (57%)]\tLoss: 0.098337\n",
+ "Train Epoch: 2 [7040/11263 (62%)]\tLoss: 0.049512\n",
+ "Train Epoch: 2 [7680/11263 (68%)]\tLoss: 0.070381\n",
+ "Train Epoch: 2 [8320/11263 (74%)]\tLoss: 0.070122\n",
+ "Train Epoch: 2 [8960/11263 (80%)]\tLoss: 0.074574\n",
+ "Train Epoch: 2 [9600/11263 (85%)]\tLoss: 0.085832\n",
+ "Train Epoch: 2 [10240/11263 (91%)]\tLoss: 0.194243\n",
+ "Train Epoch: 2 [10880/11263 (97%)]\tLoss: 0.052719\n",
+ "\n",
+ "Test set: Average loss: 8.4438\n",
+ "Class 0 accuracy: 0.0/980(0%)\n",
+ "Class 1 accuracy: 0.0/1135(0%)\n",
+ "Class 2 accuracy: 0.0/1032(0%)\n",
+ "Class 3 accuracy: 0.0/1010(0%)\n",
+ "Class 4 accuracy: 978.0/982(100%)\n",
+ "Class 5 accuracy: 872.0/892(98%)\n",
+ "Class 6 accuracy: 0.0/958(0%)\n",
+ "Class 7 accuracy: 0.0/1028(0%)\n",
+ "Class 8 accuracy: 0.0/974(0%)\n",
+ "Class 9 accuracy: 0.0/1009(0%)\n",
+ "Train Epoch: 3 [0/11263 (0%)]\tLoss: 0.031168\n",
+ "Train Epoch: 3 [640/11263 (6%)]\tLoss: 0.133256\n",
+ "Train Epoch: 3 [1280/11263 (11%)]\tLoss: 0.064342\n",
+ "Train Epoch: 3 [1920/11263 (17%)]\tLoss: 0.081790\n",
+ "Train Epoch: 3 [2560/11263 (23%)]\tLoss: 0.055577\n",
+ "Train Epoch: 3 [3200/11263 (28%)]\tLoss: 0.031944\n",
+ "Train Epoch: 3 [3840/11263 (34%)]\tLoss: 0.079894\n",
+ "Train Epoch: 3 [4480/11263 (40%)]\tLoss: 0.054432\n",
+ "Train Epoch: 3 [5120/11263 (45%)]\tLoss: 0.032932\n",
+ "Train Epoch: 3 [5760/11263 (51%)]\tLoss: 0.054515\n",
+ "Train Epoch: 3 [6400/11263 (57%)]\tLoss: 0.015534\n",
+ "Train Epoch: 3 [7040/11263 (62%)]\tLoss: 0.073943\n",
+ "Train Epoch: 3 [7680/11263 (68%)]\tLoss: 0.051291\n",
+ "Train Epoch: 3 [8320/11263 (74%)]\tLoss: 0.071271\n",
+ "Train Epoch: 3 [8960/11263 (80%)]\tLoss: 0.078468\n",
+ "Train Epoch: 3 [9600/11263 (85%)]\tLoss: 0.074505\n",
+ "Train Epoch: 3 [10240/11263 (91%)]\tLoss: 0.041854\n",
+ "Train Epoch: 3 [10880/11263 (97%)]\tLoss: 0.051243\n",
+ "\n",
+ "Test set: Average loss: 9.0650\n",
+ "Class 0 accuracy: 0.0/980(0%)\n",
+ "Class 1 accuracy: 0.0/1135(0%)\n",
+ "Class 2 accuracy: 0.0/1032(0%)\n",
+ "Class 3 accuracy: 0.0/1010(0%)\n",
+ "Class 4 accuracy: 978.0/982(100%)\n",
+ "Class 5 accuracy: 879.0/892(99%)\n",
+ "Class 6 accuracy: 0.0/958(0%)\n",
+ "Class 7 accuracy: 0.0/1028(0%)\n",
+ "Class 8 accuracy: 0.0/974(0%)\n",
+ "Class 9 accuracy: 0.0/1009(0%)\n",
+ "Train Epoch: 1 [0/12183 (0%)]\tLoss: 12.640953\n",
+ "Train Epoch: 1 [640/12183 (5%)]\tLoss: 6.948784\n",
+ "Train Epoch: 1 [1280/12183 (10%)]\tLoss: 2.239110\n",
+ "Train Epoch: 1 [1920/12183 (16%)]\tLoss: 0.789898\n",
+ "Train Epoch: 1 [2560/12183 (21%)]\tLoss: 0.433876\n",
+ "Train Epoch: 1 [3200/12183 (26%)]\tLoss: 0.327010\n",
+ "Train Epoch: 1 [3840/12183 (31%)]\tLoss: 0.240610\n",
+ "Train Epoch: 1 [4480/12183 (37%)]\tLoss: 0.189328\n",
+ "Train Epoch: 1 [5120/12183 (42%)]\tLoss: 0.189419\n",
+ "Train Epoch: 1 [5760/12183 (47%)]\tLoss: 0.063147\n",
+ "Train Epoch: 1 [6400/12183 (52%)]\tLoss: 0.061185\n",
+ "Train Epoch: 1 [7040/12183 (58%)]\tLoss: 0.115183\n",
+ "Train Epoch: 1 [7680/12183 (63%)]\tLoss: 0.030512\n",
+ "Train Epoch: 1 [8320/12183 (68%)]\tLoss: 0.037145\n",
+ "Train Epoch: 1 [8960/12183 (73%)]\tLoss: 0.089052\n",
+ "Train Epoch: 1 [9600/12183 (79%)]\tLoss: 0.044766\n",
+ "Train Epoch: 1 [10240/12183 (84%)]\tLoss: 0.027388\n",
+ "Train Epoch: 1 [10880/12183 (89%)]\tLoss: 0.041345\n",
+ "Train Epoch: 1 [11520/12183 (94%)]\tLoss: 0.022225\n",
+ "Train Epoch: 1 [4370/12183 (99%)]\tLoss: 0.004926\n",
+ "\n",
+ "Test set: Average loss: 7.6796\n",
+ "Class 0 accuracy: 0.0/980(0%)\n",
+ "Class 1 accuracy: 0.0/1135(0%)\n",
+ "Class 2 accuracy: 0.0/1032(0%)\n",
+ "Class 3 accuracy: 0.0/1010(0%)\n",
+ "Class 4 accuracy: 0.0/982(0%)\n",
+ "Class 5 accuracy: 0.0/892(0%)\n",
+ "Class 6 accuracy: 955.0/958(100%)\n",
+ "Class 7 accuracy: 1017.0/1028(99%)\n",
+ "Class 8 accuracy: 0.0/974(0%)\n",
+ "Class 9 accuracy: 0.0/1009(0%)\n",
+ "Train Epoch: 2 [0/12183 (0%)]\tLoss: 0.056296\n",
+ "Train Epoch: 2 [640/12183 (5%)]\tLoss: 0.026770\n",
+ "Train Epoch: 2 [1280/12183 (10%)]\tLoss: 0.014056\n",
+ "Train Epoch: 2 [1920/12183 (16%)]\tLoss: 0.026122\n",
+ "Train Epoch: 2 [2560/12183 (21%)]\tLoss: 0.025161\n",
+ "Train Epoch: 2 [3200/12183 (26%)]\tLoss: 0.033366\n",
+ "Train Epoch: 2 [3840/12183 (31%)]\tLoss: 0.011129\n",
+ "Train Epoch: 2 [4480/12183 (37%)]\tLoss: 0.034492\n",
+ "Train Epoch: 2 [5120/12183 (42%)]\tLoss: 0.009984\n",
+ "Train Epoch: 2 [5760/12183 (47%)]\tLoss: 0.035312\n",
+ "Train Epoch: 2 [6400/12183 (52%)]\tLoss: 0.028434\n",
+ "Train Epoch: 2 [7040/12183 (58%)]\tLoss: 0.012921\n",
+ "Train Epoch: 2 [7680/12183 (63%)]\tLoss: 0.009203\n",
+ "Train Epoch: 2 [8320/12183 (68%)]\tLoss: 0.034064\n",
+ "Train Epoch: 2 [8960/12183 (73%)]\tLoss: 0.018005\n",
+ "Train Epoch: 2 [9600/12183 (79%)]\tLoss: 0.007107\n",
+ "Train Epoch: 2 [10240/12183 (84%)]\tLoss: 0.003789\n",
+ "Train Epoch: 2 [10880/12183 (89%)]\tLoss: 0.056796\n",
+ "Train Epoch: 2 [11520/12183 (94%)]\tLoss: 0.008388\n",
+ "Train Epoch: 2 [4370/12183 (99%)]\tLoss: 0.008049\n",
+ "\n",
+ "Test set: Average loss: 8.7098\n",
+ "Class 0 accuracy: 0.0/980(0%)\n",
+ "Class 1 accuracy: 0.0/1135(0%)\n",
+ "Class 2 accuracy: 0.0/1032(0%)\n",
+ "Class 3 accuracy: 0.0/1010(0%)\n",
+ "Class 4 accuracy: 0.0/982(0%)\n",
+ "Class 5 accuracy: 0.0/892(0%)\n",
+ "Class 6 accuracy: 955.0/958(100%)\n",
+ "Class 7 accuracy: 1021.0/1028(99%)\n",
+ "Class 8 accuracy: 0.0/974(0%)\n",
+ "Class 9 accuracy: 0.0/1009(0%)\n",
+ "Train Epoch: 3 [0/12183 (0%)]\tLoss: 0.031829\n",
+ "Train Epoch: 3 [640/12183 (5%)]\tLoss: 0.026151\n",
+ "Train Epoch: 3 [1280/12183 (10%)]\tLoss: 0.013025\n",
+ "Train Epoch: 3 [1920/12183 (16%)]\tLoss: 0.015742\n",
+ "Train Epoch: 3 [2560/12183 (21%)]\tLoss: 0.036039\n",
+ "Train Epoch: 3 [3200/12183 (26%)]\tLoss: 0.019850\n",
+ "Train Epoch: 3 [3840/12183 (31%)]\tLoss: 0.021591\n",
+ "Train Epoch: 3 [4480/12183 (37%)]\tLoss: 0.011566\n",
+ "Train Epoch: 3 [5120/12183 (42%)]\tLoss: 0.012161\n",
+ "Train Epoch: 3 [5760/12183 (47%)]\tLoss: 0.005265\n",
+ "Train Epoch: 3 [6400/12183 (52%)]\tLoss: 0.005189\n",
+ "Train Epoch: 3 [7040/12183 (58%)]\tLoss: 0.023125\n",
+ "Train Epoch: 3 [7680/12183 (63%)]\tLoss: 0.058819\n",
+ "Train Epoch: 3 [8320/12183 (68%)]\tLoss: 0.027194\n",
+ "Train Epoch: 3 [8960/12183 (73%)]\tLoss: 0.010728\n",
+ "Train Epoch: 3 [9600/12183 (79%)]\tLoss: 0.002767\n",
+ "Train Epoch: 3 [10240/12183 (84%)]\tLoss: 0.006029\n",
+ "Train Epoch: 3 [10880/12183 (89%)]\tLoss: 0.090709\n",
+ "Train Epoch: 3 [11520/12183 (94%)]\tLoss: 0.003533\n",
+ "Train Epoch: 3 [4370/12183 (99%)]\tLoss: 0.005167\n",
+ "\n",
+ "Test set: Average loss: 9.1796\n",
+ "Class 0 accuracy: 0.0/980(0%)\n",
+ "Class 1 accuracy: 0.0/1135(0%)\n",
+ "Class 2 accuracy: 0.0/1032(0%)\n",
+ "Class 3 accuracy: 0.0/1010(0%)\n",
+ "Class 4 accuracy: 0.0/982(0%)\n",
+ "Class 5 accuracy: 0.0/892(0%)\n",
+ "Class 6 accuracy: 955.0/958(100%)\n",
+ "Class 7 accuracy: 1021.0/1028(99%)\n",
+ "Class 8 accuracy: 0.0/974(0%)\n",
+ "Class 9 accuracy: 0.0/1009(0%)\n",
+ "Train Epoch: 1 [0/11800 (0%)]\tLoss: 14.120369\n",
+ "Train Epoch: 1 [640/11800 (5%)]\tLoss: 7.705572\n",
+ "Train Epoch: 1 [1280/11800 (11%)]\tLoss: 2.772657\n",
+ "Train Epoch: 1 [1920/11800 (16%)]\tLoss: 1.523421\n",
+ "Train Epoch: 1 [2560/11800 (22%)]\tLoss: 0.632249\n",
+ "Train Epoch: 1 [3200/11800 (27%)]\tLoss: 0.823988\n",
+ "Train Epoch: 1 [3840/11800 (32%)]\tLoss: 0.649929\n",
+ "Train Epoch: 1 [4480/11800 (38%)]\tLoss: 0.382820\n",
+ "Train Epoch: 1 [5120/11800 (43%)]\tLoss: 0.287871\n",
+ "Train Epoch: 1 [5760/11800 (49%)]\tLoss: 0.270225\n",
+ "Train Epoch: 1 [6400/11800 (54%)]\tLoss: 0.250052\n",
+ "Train Epoch: 1 [7040/11800 (59%)]\tLoss: 0.436681\n",
+ "Train Epoch: 1 [7680/11800 (65%)]\tLoss: 0.205961\n",
+ "Train Epoch: 1 [8320/11800 (70%)]\tLoss: 0.226867\n",
+ "Train Epoch: 1 [8960/11800 (76%)]\tLoss: 0.178260\n",
+ "Train Epoch: 1 [9600/11800 (81%)]\tLoss: 0.159317\n",
+ "Train Epoch: 1 [10240/11800 (86%)]\tLoss: 0.402756\n",
+ "Train Epoch: 1 [10880/11800 (92%)]\tLoss: 0.158059\n",
+ "Train Epoch: 1 [11520/11800 (97%)]\tLoss: 0.195788\n",
+ "\n",
+ "Test set: Average loss: 5.9943\n",
+ "Class 0 accuracy: 0.0/980(0%)\n",
+ "Class 1 accuracy: 0.0/1135(0%)\n",
+ "Class 2 accuracy: 0.0/1032(0%)\n",
+ "Class 3 accuracy: 0.0/1010(0%)\n",
+ "Class 4 accuracy: 0.0/982(0%)\n",
+ "Class 5 accuracy: 0.0/892(0%)\n",
+ "Class 6 accuracy: 0.0/958(0%)\n",
+ "Class 7 accuracy: 0.0/1028(0%)\n",
+ "Class 8 accuracy: 943.0/974(97%)\n",
+ "Class 9 accuracy: 974.0/1009(97%)\n",
+ "Train Epoch: 2 [0/11800 (0%)]\tLoss: 0.157916\n",
+ "Train Epoch: 2 [640/11800 (5%)]\tLoss: 0.187881\n",
+ "Train Epoch: 2 [1280/11800 (11%)]\tLoss: 0.283773\n",
+ "Train Epoch: 2 [1920/11800 (16%)]\tLoss: 0.135888\n",
+ "Train Epoch: 2 [2560/11800 (22%)]\tLoss: 0.211785\n",
+ "Train Epoch: 2 [3200/11800 (27%)]\tLoss: 0.238753\n",
+ "Train Epoch: 2 [3840/11800 (32%)]\tLoss: 0.121974\n",
+ "Train Epoch: 2 [4480/11800 (38%)]\tLoss: 0.074522\n",
+ "Train Epoch: 2 [5120/11800 (43%)]\tLoss: 0.239114\n",
+ "Train Epoch: 2 [5760/11800 (49%)]\tLoss: 0.129581\n",
+ "Train Epoch: 2 [6400/11800 (54%)]\tLoss: 0.055795\n",
+ "Train Epoch: 2 [7040/11800 (59%)]\tLoss: 0.349395\n",
+ "Train Epoch: 2 [7680/11800 (65%)]\tLoss: 0.076350\n",
+ "Train Epoch: 2 [8320/11800 (70%)]\tLoss: 0.072271\n",
+ "Train Epoch: 2 [8960/11800 (76%)]\tLoss: 0.103121\n",
+ "Train Epoch: 2 [9600/11800 (81%)]\tLoss: 0.343085\n",
+ "Train Epoch: 2 [10240/11800 (86%)]\tLoss: 0.095953\n",
+ "Train Epoch: 2 [10880/11800 (92%)]\tLoss: 0.095618\n",
+ "Train Epoch: 2 [11520/11800 (97%)]\tLoss: 0.114838\n",
+ "\n",
+ "Test set: Average loss: 6.6632\n",
+ "Class 0 accuracy: 0.0/980(0%)\n",
+ "Class 1 accuracy: 0.0/1135(0%)\n",
+ "Class 2 accuracy: 0.0/1032(0%)\n",
+ "Class 3 accuracy: 0.0/1010(0%)\n",
+ "Class 4 accuracy: 0.0/982(0%)\n",
+ "Class 5 accuracy: 0.0/892(0%)\n",
+ "Class 6 accuracy: 0.0/958(0%)\n",
+ "Class 7 accuracy: 0.0/1028(0%)\n",
+ "Class 8 accuracy: 967.0/974(99%)\n",
+ "Class 9 accuracy: 965.0/1009(96%)\n",
+ "Train Epoch: 3 [0/11800 (0%)]\tLoss: 0.126834\n",
+ "Train Epoch: 3 [640/11800 (5%)]\tLoss: 0.187850\n",
+ "Train Epoch: 3 [1280/11800 (11%)]\tLoss: 0.102141\n",
+ "Train Epoch: 3 [1920/11800 (16%)]\tLoss: 0.036415\n",
+ "Train Epoch: 3 [2560/11800 (22%)]\tLoss: 0.165435\n",
+ "Train Epoch: 3 [3200/11800 (27%)]\tLoss: 0.109672\n",
+ "Train Epoch: 3 [3840/11800 (32%)]\tLoss: 0.070627\n",
+ "Train Epoch: 3 [4480/11800 (38%)]\tLoss: 0.047646\n",
+ "Train Epoch: 3 [5120/11800 (43%)]\tLoss: 0.051905\n",
+ "Train Epoch: 3 [5760/11800 (49%)]\tLoss: 0.264070\n",
+ "Train Epoch: 3 [6400/11800 (54%)]\tLoss: 0.211393\n",
+ "Train Epoch: 3 [7040/11800 (59%)]\tLoss: 0.062177\n",
+ "Train Epoch: 3 [7680/11800 (65%)]\tLoss: 0.033032\n",
+ "Train Epoch: 3 [8320/11800 (70%)]\tLoss: 0.068230\n",
+ "Train Epoch: 3 [8960/11800 (76%)]\tLoss: 0.071904\n",
+ "Train Epoch: 3 [9600/11800 (81%)]\tLoss: 0.102863\n",
+ "Train Epoch: 3 [10240/11800 (86%)]\tLoss: 0.046137\n",
+ "Train Epoch: 3 [10880/11800 (92%)]\tLoss: 0.120950\n",
+ "Train Epoch: 3 [11520/11800 (97%)]\tLoss: 0.101907\n",
+ "\n",
+ "Test set: Average loss: 6.9242\n",
+ "Class 0 accuracy: 0.0/980(0%)\n",
+ "Class 1 accuracy: 0.0/1135(0%)\n",
+ "Class 2 accuracy: 0.0/1032(0%)\n",
+ "Class 3 accuracy: 0.0/1010(0%)\n",
+ "Class 4 accuracy: 0.0/982(0%)\n",
+ "Class 5 accuracy: 0.0/892(0%)\n",
+ "Class 6 accuracy: 0.0/958(0%)\n",
+ "Class 7 accuracy: 0.0/1028(0%)\n",
+ "Class 8 accuracy: 952.0/974(98%)\n",
+ "Class 9 accuracy: 978.0/1009(97%)\n",
+ "Train Epoch: 1 [0/12665 (0%)]\tLoss: 8.722752\n",
+ "Train Epoch: 1 [640/12665 (5%)]\tLoss: 3.160372\n",
+ "Train Epoch: 1 [1280/12665 (10%)]\tLoss: 0.653003\n",
+ "Train Epoch: 1 [1920/12665 (15%)]\tLoss: 0.256268\n",
+ "Train Epoch: 1 [2560/12665 (20%)]\tLoss: 0.064630\n",
+ "Train Epoch: 1 [3200/12665 (25%)]\tLoss: 0.089074\n",
+ "Train Epoch: 1 [3840/12665 (30%)]\tLoss: 0.022682\n",
+ "Train Epoch: 1 [4480/12665 (35%)]\tLoss: 0.014132\n",
+ "Train Epoch: 1 [5120/12665 (40%)]\tLoss: 0.011369\n",
+ "Train Epoch: 1 [5760/12665 (45%)]\tLoss: 0.032196\n",
+ "Train Epoch: 1 [6400/12665 (51%)]\tLoss: 0.087249\n",
+ "Train Epoch: 1 [7040/12665 (56%)]\tLoss: 0.010759\n",
+ "Train Epoch: 1 [7680/12665 (61%)]\tLoss: 0.009488\n",
+ "Train Epoch: 1 [8320/12665 (66%)]\tLoss: 0.009431\n",
+ "Train Epoch: 1 [8960/12665 (71%)]\tLoss: 0.017381\n",
+ "Train Epoch: 1 [9600/12665 (76%)]\tLoss: 0.014684\n",
+ "Train Epoch: 1 [10240/12665 (81%)]\tLoss: 0.007916\n",
+ "Train Epoch: 1 [10880/12665 (86%)]\tLoss: 0.008834\n",
+ "Train Epoch: 1 [11520/12665 (91%)]\tLoss: 0.082196\n",
+ "Train Epoch: 1 [12160/12665 (96%)]\tLoss: 0.002452\n",
+ "\n",
+ "Test set: Average loss: 7.1540\n",
+ "Class 0 accuracy: 977.0/980(100%)\n",
+ "Class 1 accuracy: 1134.0/1135(100%)\n",
+ "Class 2 accuracy: 0.0/1032(0%)\n",
+ "Class 3 accuracy: 0.0/1010(0%)\n",
+ "Class 4 accuracy: 0.0/982(0%)\n",
+ "Class 5 accuracy: 0.0/892(0%)\n",
+ "Class 6 accuracy: 0.0/958(0%)\n",
+ "Class 7 accuracy: 0.0/1028(0%)\n",
+ "Class 8 accuracy: 0.0/974(0%)\n",
+ "Class 9 accuracy: 0.0/1009(0%)\n",
+ "Train Epoch: 2 [0/12665 (0%)]\tLoss: 0.000673\n",
+ "Train Epoch: 2 [640/12665 (5%)]\tLoss: 0.029478\n",
+ "Train Epoch: 2 [1280/12665 (10%)]\tLoss: 0.002406\n",
+ "Train Epoch: 2 [1920/12665 (15%)]\tLoss: 0.008619\n",
+ "Train Epoch: 2 [2560/12665 (20%)]\tLoss: 0.001777\n",
+ "Train Epoch: 2 [3200/12665 (25%)]\tLoss: 0.031666\n",
+ "Train Epoch: 2 [3840/12665 (30%)]\tLoss: 0.133443\n",
+ "Train Epoch: 2 [4480/12665 (35%)]\tLoss: 0.006141\n",
+ "Train Epoch: 2 [5120/12665 (40%)]\tLoss: 0.006000\n",
+ "Train Epoch: 2 [5760/12665 (45%)]\tLoss: 0.003495\n",
+ "Train Epoch: 2 [6400/12665 (51%)]\tLoss: 0.000851\n",
+ "Train Epoch: 2 [7040/12665 (56%)]\tLoss: 0.001280\n",
+ "Train Epoch: 2 [7680/12665 (61%)]\tLoss: 0.000972\n",
+ "Train Epoch: 2 [8320/12665 (66%)]\tLoss: 0.001776\n",
+ "Train Epoch: 2 [8960/12665 (71%)]\tLoss: 0.067486\n",
+ "Train Epoch: 2 [9600/12665 (76%)]\tLoss: 0.001030\n",
+ "Train Epoch: 2 [10240/12665 (81%)]\tLoss: 0.001892\n",
+ "Train Epoch: 2 [10880/12665 (86%)]\tLoss: 0.001493\n",
+ "Train Epoch: 2 [11520/12665 (91%)]\tLoss: 0.047196\n",
+ "Train Epoch: 2 [12160/12665 (96%)]\tLoss: 0.004952\n",
+ "\n",
+ "Test set: Average loss: 7.8296\n",
+ "Class 0 accuracy: 977.0/980(100%)\n",
+ "Class 1 accuracy: 1134.0/1135(100%)\n",
+ "Class 2 accuracy: 0.0/1032(0%)\n",
+ "Class 3 accuracy: 0.0/1010(0%)\n",
+ "Class 4 accuracy: 0.0/982(0%)\n",
+ "Class 5 accuracy: 0.0/892(0%)\n",
+ "Class 6 accuracy: 0.0/958(0%)\n",
+ "Class 7 accuracy: 0.0/1028(0%)\n",
+ "Class 8 accuracy: 0.0/974(0%)\n",
+ "Class 9 accuracy: 0.0/1009(0%)\n",
+ "Train Epoch: 3 [0/12665 (0%)]\tLoss: 0.007902\n",
+ "Train Epoch: 3 [640/12665 (5%)]\tLoss: 0.003385\n",
+ "Train Epoch: 3 [1280/12665 (10%)]\tLoss: 0.001230\n",
+ "Train Epoch: 3 [1920/12665 (15%)]\tLoss: 0.001810\n",
+ "Train Epoch: 3 [2560/12665 (20%)]\tLoss: 0.001579\n",
+ "Train Epoch: 3 [3200/12665 (25%)]\tLoss: 0.028920\n",
+ "Train Epoch: 3 [3840/12665 (30%)]\tLoss: 0.004439\n",
+ "Train Epoch: 3 [4480/12665 (35%)]\tLoss: 0.001913\n",
+ "Train Epoch: 3 [5120/12665 (40%)]\tLoss: 0.000407\n",
+ "Train Epoch: 3 [5760/12665 (45%)]\tLoss: 0.004402\n",
+ "Train Epoch: 3 [6400/12665 (51%)]\tLoss: 0.009100\n",
+ "Train Epoch: 3 [7040/12665 (56%)]\tLoss: 0.000980\n",
+ "Train Epoch: 3 [7680/12665 (61%)]\tLoss: 0.000387\n",
+ "Train Epoch: 3 [8320/12665 (66%)]\tLoss: 0.004606\n",
+ "Train Epoch: 3 [8960/12665 (71%)]\tLoss: 0.001144\n",
+ "Train Epoch: 3 [9600/12665 (76%)]\tLoss: 0.007978\n",
+ "Train Epoch: 3 [10240/12665 (81%)]\tLoss: 0.011006\n",
+ "Train Epoch: 3 [10880/12665 (86%)]\tLoss: 0.001300\n",
+ "Train Epoch: 3 [11520/12665 (91%)]\tLoss: 0.000758\n",
+ "Train Epoch: 3 [12160/12665 (96%)]\tLoss: 0.001379\n",
+ "\n",
+ "Test set: Average loss: 8.1397\n",
+ "Class 0 accuracy: 978.0/980(100%)\n",
+ "Class 1 accuracy: 1134.0/1135(100%)\n",
+ "Class 2 accuracy: 0.0/1032(0%)\n",
+ "Class 3 accuracy: 0.0/1010(0%)\n",
+ "Class 4 accuracy: 0.0/982(0%)\n",
+ "Class 5 accuracy: 0.0/892(0%)\n",
+ "Class 6 accuracy: 0.0/958(0%)\n",
+ "Class 7 accuracy: 0.0/1028(0%)\n",
+ "Class 8 accuracy: 0.0/974(0%)\n",
+ "Class 9 accuracy: 0.0/1009(0%)\n"
+ ]
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "[{0: 0.996938775510204,\n",
+ " 1: 0.9991189427312775,\n",
+ " 2: 0.0,\n",
+ " 3: 0.0,\n",
+ " 4: 0.0,\n",
+ " 5: 0.0,\n",
+ " 6: 0.0,\n",
+ " 7: 0.0,\n",
+ " 8: 0.0,\n",
+ " 9: 0.0,\n",
+ " 'phase': '0_1'},\n",
+ " {0: 0.996938775510204,\n",
+ " 1: 0.9991189427312775,\n",
+ " 2: 0.0,\n",
+ " 3: 0.0,\n",
+ " 4: 0.0,\n",
+ " 5: 0.0,\n",
+ " 6: 0.0,\n",
+ " 7: 0.0,\n",
+ " 8: 0.0,\n",
+ " 9: 0.0,\n",
+ " 'phase': '0_1'},\n",
+ " {0: 0.996938775510204,\n",
+ " 1: 0.9991189427312775,\n",
+ " 2: 0.0,\n",
+ " 3: 0.0,\n",
+ " 4: 0.0,\n",
+ " 5: 0.0,\n",
+ " 6: 0.0,\n",
+ " 7: 0.0,\n",
+ " 8: 0.0,\n",
+ " 9: 0.0,\n",
+ " 'phase': '0_1'},\n",
+ " {0: 0.0,\n",
+ " 1: 0.0,\n",
+ " 2: 0.9612403100775194,\n",
+ " 3: 0.9772277227722772,\n",
+ " 4: 0.0,\n",
+ " 5: 0.0,\n",
+ " 6: 0.0,\n",
+ " 7: 0.0,\n",
+ " 8: 0.0,\n",
+ " 9: 0.0,\n",
+ " 'phase': '2_3'},\n",
+ " {0: 0.0,\n",
+ " 1: 0.0,\n",
+ " 2: 0.9670542635658915,\n",
+ " 3: 0.9841584158415841,\n",
+ " 4: 0.0,\n",
+ " 5: 0.0,\n",
+ " 6: 0.0,\n",
+ " 7: 0.0,\n",
+ " 8: 0.0,\n",
+ " 9: 0.0,\n",
+ " 'phase': '2_3'},\n",
+ " {0: 0.0,\n",
+ " 1: 0.0,\n",
+ " 2: 0.9757751937984496,\n",
+ " 3: 0.9752475247524752,\n",
+ " 4: 0.0,\n",
+ " 5: 0.0,\n",
+ " 6: 0.0,\n",
+ " 7: 0.0,\n",
+ " 8: 0.0,\n",
+ " 9: 0.0,\n",
+ " 'phase': '2_3'},\n",
+ " {0: 0.0,\n",
+ " 1: 0.0,\n",
+ " 2: 0.0,\n",
+ " 3: 0.0,\n",
+ " 4: 0.994908350305499,\n",
+ " 5: 0.9708520179372198,\n",
+ " 6: 0.0,\n",
+ " 7: 0.0,\n",
+ " 8: 0.0,\n",
+ " 9: 0.0,\n",
+ " 'phase': '4_5'},\n",
+ " {0: 0.0,\n",
+ " 1: 0.0,\n",
+ " 2: 0.0,\n",
+ " 3: 0.0,\n",
+ " 4: 0.9959266802443992,\n",
+ " 5: 0.9775784753363229,\n",
+ " 6: 0.0,\n",
+ " 7: 0.0,\n",
+ " 8: 0.0,\n",
+ " 9: 0.0,\n",
+ " 'phase': '4_5'},\n",
+ " {0: 0.0,\n",
+ " 1: 0.0,\n",
+ " 2: 0.0,\n",
+ " 3: 0.0,\n",
+ " 4: 0.9959266802443992,\n",
+ " 5: 0.9854260089686099,\n",
+ " 6: 0.0,\n",
+ " 7: 0.0,\n",
+ " 8: 0.0,\n",
+ " 9: 0.0,\n",
+ " 'phase': '4_5'},\n",
+ " {0: 0.0,\n",
+ " 1: 0.0,\n",
+ " 2: 0.0,\n",
+ " 3: 0.0,\n",
+ " 4: 0.0,\n",
+ " 5: 0.0,\n",
+ " 6: 0.9968684759916493,\n",
+ " 7: 0.9892996108949417,\n",
+ " 8: 0.0,\n",
+ " 9: 0.0,\n",
+ " 'phase': '6_7'},\n",
+ " {0: 0.0,\n",
+ " 1: 0.0,\n",
+ " 2: 0.0,\n",
+ " 3: 0.0,\n",
+ " 4: 0.0,\n",
+ " 5: 0.0,\n",
+ " 6: 0.9968684759916493,\n",
+ " 7: 0.9931906614785992,\n",
+ " 8: 0.0,\n",
+ " 9: 0.0,\n",
+ " 'phase': '6_7'},\n",
+ " {0: 0.0,\n",
+ " 1: 0.0,\n",
+ " 2: 0.0,\n",
+ " 3: 0.0,\n",
+ " 4: 0.0,\n",
+ " 5: 0.0,\n",
+ " 6: 0.9968684759916493,\n",
+ " 7: 0.9931906614785992,\n",
+ " 8: 0.0,\n",
+ " 9: 0.0,\n",
+ " 'phase': '6_7'},\n",
+ " {0: 0.0,\n",
+ " 1: 0.0,\n",
+ " 2: 0.0,\n",
+ " 3: 0.0,\n",
+ " 4: 0.0,\n",
+ " 5: 0.0,\n",
+ " 6: 0.0,\n",
+ " 7: 0.0,\n",
+ " 8: 0.9681724845995893,\n",
+ " 9: 0.9653121902874133,\n",
+ " 'phase': '8_9'},\n",
+ " {0: 0.0,\n",
+ " 1: 0.0,\n",
+ " 2: 0.0,\n",
+ " 3: 0.0,\n",
+ " 4: 0.0,\n",
+ " 5: 0.0,\n",
+ " 6: 0.0,\n",
+ " 7: 0.0,\n",
+ " 8: 0.9928131416837782,\n",
+ " 9: 0.956392467789891,\n",
+ " 'phase': '8_9'},\n",
+ " {0: 0.0,\n",
+ " 1: 0.0,\n",
+ " 2: 0.0,\n",
+ " 3: 0.0,\n",
+ " 4: 0.0,\n",
+ " 5: 0.0,\n",
+ " 6: 0.0,\n",
+ " 7: 0.0,\n",
+ " 8: 0.9774127310061602,\n",
+ " 9: 0.9692765113974232,\n",
+ " 'phase': '8_9'},\n",
+ " {0: 0.996938775510204,\n",
+ " 1: 0.9991189427312775,\n",
+ " 2: 0.0,\n",
+ " 3: 0.0,\n",
+ " 4: 0.0,\n",
+ " 5: 0.0,\n",
+ " 6: 0.0,\n",
+ " 7: 0.0,\n",
+ " 8: 0.0,\n",
+ " 9: 0.0,\n",
+ " 'phase': '0_1_again'},\n",
+ " {0: 0.996938775510204,\n",
+ " 1: 0.9991189427312775,\n",
+ " 2: 0.0,\n",
+ " 3: 0.0,\n",
+ " 4: 0.0,\n",
+ " 5: 0.0,\n",
+ " 6: 0.0,\n",
+ " 7: 0.0,\n",
+ " 8: 0.0,\n",
+ " 9: 0.0,\n",
+ " 'phase': '0_1_again'},\n",
+ " {0: 0.9979591836734694,\n",
+ " 1: 0.9991189427312775,\n",
+ " 2: 0.0,\n",
+ " 3: 0.0,\n",
+ " 4: 0.0,\n",
+ " 5: 0.0,\n",
+ " 6: 0.0,\n",
+ " 7: 0.0,\n",
+ " 8: 0.0,\n",
+ " 9: 0.0,\n",
+ " 'phase': '0_1_again'}]"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 19
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Exercise 2: Implementing the previous exercise should have shown that the network\n",
+ "# suffers from catastrophic forgetting. To fix this, implement a memory replay buffer\n",
+ "# and use this buffer to counteract the catastrophic forgetting.\n",
+ "#\n",
+ "# Implementation requires adaptation of the cells above, plus the addition of a new\n",
+ "# experiment setup below.\n",
+ "\n",
+ "\n",
+ "### START SOLUTION ###\n",
+ "# The experiment setup is identical to the one above, but now we initialize and use\n",
+ "# a replay buffer. Right now, the buffer only stores a single image for each class.\n",
+ "# As a bonus exercise, explore how different variants of the replay buffer affect the\n",
+ "# resulting model performance.\n",
+ "\n",
+ "def train_memory_replay():\n",
+ "    \"\"\"Train one network on the class pairs in sequence with a shared replay buffer.\n",
+ "\n",
+ "    A single ``Net`` and a single ``Buffer`` are created once and handed to every\n",
+ "    ``train_model`` call, so both persist across all phases. Before each call the\n",
+ "    global ``config.train_classes`` is overwritten with that phase's classes; it is\n",
+ "    left at ``[0, 1]`` when this function returns.\n",
+ "\n",
+ "    Returns:\n",
+ "        The ``history`` list that was passed to every ``train_model`` call\n",
+ "        (presumably filled in by ``train_model``; see its definition above).\n",
+ "    \"\"\"\n",
+ "    # (phase label, classes trained in that phase), in execution order.\n",
+ "    # The first pair is revisited at the end under a separate label.\n",
+ "    schedule = (\n",
+ "        (\"0_1\", [0, 1]),\n",
+ "        (\"2_3\", [2, 3]),\n",
+ "        (\"4_5\", [4, 5]),\n",
+ "        (\"6_7\", [6, 7]),\n",
+ "        (\"8_9\", [8, 9]),\n",
+ "        (\"0_1_again\", [0, 1]),\n",
+ "    )\n",
+ "\n",
+ "    net = Net()\n",
+ "    records = []\n",
+ "    replay_buffer = Buffer()\n",
+ "\n",
+ "    for label, classes in schedule:\n",
+ "        # train_model receives the shared config object, so the phase's classes\n",
+ "        # are set on it right before the call.\n",
+ "        config.train_classes = classes\n",
+ "        train_model(config, net, phase = label, history = records, buffer = replay_buffer)\n",
+ "\n",
+ "    return records\n",
+ "\n",
+ "### END SOLUTION ###\n",
+ "\n",
+ "history_memory_replay = train_memory_replay()\n",
+ "history_memory_replay"
+ ],
+ "metadata": {
+ "id": "LtWvHl_tQCpg",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "6705c5fe-9838-4da5-a562-77d7c9a8c6d8"
+ },
+ "execution_count": 20,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Train Epoch: 1 [0/12665 (0%)]\tLoss: 2.196357\n",
+ "Train Epoch: 1 [640/12665 (5%)]\tLoss: 1.335955\n",
+ "Train Epoch: 1 [1280/12665 (10%)]\tLoss: 0.555878\n",
+ "Train Epoch: 1 [1920/12665 (15%)]\tLoss: 0.146986\n",
+ "Train Epoch: 1 [2560/12665 (20%)]\tLoss: 0.082966\n",
+ "Train Epoch: 1 [3200/12665 (25%)]\tLoss: 0.054255\n",
+ "Train Epoch: 1 [3840/12665 (30%)]\tLoss: 0.011100\n",
+ "Train Epoch: 1 [4480/12665 (35%)]\tLoss: 0.078007\n",
+ "Train Epoch: 1 [5120/12665 (40%)]\tLoss: 0.005687\n",
+ "Train Epoch: 1 [5760/12665 (45%)]\tLoss: 0.004260\n",
+ "Train Epoch: 1 [6400/12665 (51%)]\tLoss: 0.002415\n",
+ "Train Epoch: 1 [7040/12665 (56%)]\tLoss: 0.005845\n",
+ "Train Epoch: 1 [7680/12665 (61%)]\tLoss: 0.005518\n",
+ "Train Epoch: 1 [8320/12665 (66%)]\tLoss: 0.001475\n",
+ "Train Epoch: 1 [8960/12665 (71%)]\tLoss: 0.002208\n",
+ "Train Epoch: 1 [9600/12665 (76%)]\tLoss: 0.000701\n",
+ "Train Epoch: 1 [10240/12665 (81%)]\tLoss: 0.000352\n",
+ "Train Epoch: 1 [10880/12665 (86%)]\tLoss: 0.000833\n",
+ "Train Epoch: 1 [11520/12665 (91%)]\tLoss: 0.002580\n",
+ "Train Epoch: 1 [12160/12665 (96%)]\tLoss: 0.002562\n",
+ "\n",
+ "Test set: Average loss: 7.8370\n",
+ "Class 0 accuracy: 976.0/980(100%)\n",
+ "Class 1 accuracy: 1134.0/1135(100%)\n",
+ "Class 2 accuracy: 0.0/1032(0%)\n",
+ "Class 3 accuracy: 0.0/1010(0%)\n",
+ "Class 4 accuracy: 0.0/982(0%)\n",
+ "Class 5 accuracy: 0.0/892(0%)\n",
+ "Class 6 accuracy: 0.0/958(0%)\n",
+ "Class 7 accuracy: 0.0/1028(0%)\n",
+ "Class 8 accuracy: 0.0/974(0%)\n",
+ "Class 9 accuracy: 0.0/1009(0%)\n",
+ "Train Epoch: 2 [0/12665 (0%)]\tLoss: 0.221930\n",
+ "Train Epoch: 2 [640/12665 (5%)]\tLoss: 0.027592\n",
+ "Train Epoch: 2 [1280/12665 (10%)]\tLoss: 0.065451\n",
+ "Train Epoch: 2 [1920/12665 (15%)]\tLoss: 0.015034\n",
+ "Train Epoch: 2 [2560/12665 (20%)]\tLoss: 0.046994\n",
+ "Train Epoch: 2 [3200/12665 (25%)]\tLoss: 0.008316\n",
+ "Train Epoch: 2 [3840/12665 (30%)]\tLoss: 0.021686\n",
+ "Train Epoch: 2 [4480/12665 (35%)]\tLoss: 0.024137\n",
+ "Train Epoch: 2 [5120/12665 (40%)]\tLoss: 0.002627\n",
+ "Train Epoch: 2 [5760/12665 (45%)]\tLoss: 0.042552\n",
+ "Train Epoch: 2 [6400/12665 (51%)]\tLoss: 0.026033\n",
+ "Train Epoch: 2 [7040/12665 (56%)]\tLoss: 0.002244\n",
+ "Train Epoch: 2 [7680/12665 (61%)]\tLoss: 0.003471\n",
+ "Train Epoch: 2 [8320/12665 (66%)]\tLoss: 0.041315\n",
+ "Train Epoch: 2 [8960/12665 (71%)]\tLoss: 0.008438\n",
+ "Train Epoch: 2 [9600/12665 (76%)]\tLoss: 0.000700\n",
+ "Train Epoch: 2 [10240/12665 (81%)]\tLoss: 0.000417\n",
+ "Train Epoch: 2 [10880/12665 (86%)]\tLoss: 0.002115\n",
+ "Train Epoch: 2 [11520/12665 (91%)]\tLoss: 0.008961\n",
+ "Train Epoch: 2 [12160/12665 (96%)]\tLoss: 0.002133\n",
+ "\n",
+ "Test set: Average loss: 8.7829\n",
+ "Class 0 accuracy: 976.0/980(100%)\n",
+ "Class 1 accuracy: 1135.0/1135(100%)\n",
+ "Class 2 accuracy: 0.0/1032(0%)\n",
+ "Class 3 accuracy: 0.0/1010(0%)\n",
+ "Class 4 accuracy: 0.0/982(0%)\n",
+ "Class 5 accuracy: 0.0/892(0%)\n",
+ "Class 6 accuracy: 0.0/958(0%)\n",
+ "Class 7 accuracy: 0.0/1028(0%)\n",
+ "Class 8 accuracy: 0.0/974(0%)\n",
+ "Class 9 accuracy: 0.0/1009(0%)\n",
+ "Train Epoch: 3 [0/12665 (0%)]\tLoss: 0.000570\n",
+ "Train Epoch: 3 [640/12665 (5%)]\tLoss: 0.011992\n",
+ "Train Epoch: 3 [1280/12665 (10%)]\tLoss: 0.000725\n",
+ "Train Epoch: 3 [1920/12665 (15%)]\tLoss: 0.000835\n",
+ "Train Epoch: 3 [2560/12665 (20%)]\tLoss: 0.000185\n",
+ "Train Epoch: 3 [3200/12665 (25%)]\tLoss: 0.001194\n",
+ "Train Epoch: 3 [3840/12665 (30%)]\tLoss: 0.000802\n",
+ "Train Epoch: 3 [4480/12665 (35%)]\tLoss: 0.001044\n",
+ "Train Epoch: 3 [5120/12665 (40%)]\tLoss: 0.001886\n",
+ "Train Epoch: 3 [5760/12665 (45%)]\tLoss: 0.001518\n",
+ "Train Epoch: 3 [6400/12665 (51%)]\tLoss: 0.001240\n",
+ "Train Epoch: 3 [7040/12665 (56%)]\tLoss: 0.005406\n",
+ "Train Epoch: 3 [7680/12665 (61%)]\tLoss: 0.004714\n",
+ "Train Epoch: 3 [8320/12665 (66%)]\tLoss: 0.000870\n",
+ "Train Epoch: 3 [8960/12665 (71%)]\tLoss: 0.001174\n",
+ "Train Epoch: 3 [9600/12665 (76%)]\tLoss: 0.000833\n",
+ "Train Epoch: 3 [10240/12665 (81%)]\tLoss: 0.006711\n",
+ "Train Epoch: 3 [10880/12665 (86%)]\tLoss: 0.006384\n",
+ "Train Epoch: 3 [11520/12665 (91%)]\tLoss: 0.006488\n",
+ "Train Epoch: 3 [12160/12665 (96%)]\tLoss: 0.000901\n",
+ "\n",
+ "Test set: Average loss: 9.2011\n",
+ "Class 0 accuracy: 977.0/980(100%)\n",
+ "Class 1 accuracy: 1135.0/1135(100%)\n",
+ "Class 2 accuracy: 0.0/1032(0%)\n",
+ "Class 3 accuracy: 0.0/1010(0%)\n",
+ "Class 4 accuracy: 0.0/982(0%)\n",
+ "Class 5 accuracy: 0.0/892(0%)\n",
+ "Class 6 accuracy: 0.0/958(0%)\n",
+ "Class 7 accuracy: 0.0/1028(0%)\n",
+ "Class 8 accuracy: 0.0/974(0%)\n",
+ "Class 9 accuracy: 0.0/1009(0%)\n",
+ "Train Epoch: 1 [0/12089 (0%)]\tLoss: 2.805099\n",
+ "Train Epoch: 1 [640/12089 (5%)]\tLoss: 1.660801\n",
+ "Train Epoch: 1 [1280/12089 (11%)]\tLoss: 0.437137\n",
+ "Train Epoch: 1 [1920/12089 (16%)]\tLoss: 0.808282\n",
+ "Train Epoch: 1 [2560/12089 (21%)]\tLoss: 0.285198\n",
+ "Train Epoch: 1 [3200/12089 (26%)]\tLoss: 0.129722\n",
+ "Train Epoch: 1 [3840/12089 (32%)]\tLoss: 0.145888\n",
+ "Train Epoch: 1 [4480/12089 (37%)]\tLoss: 0.343596\n",
+ "Train Epoch: 1 [5120/12089 (42%)]\tLoss: 0.078071\n",
+ "Train Epoch: 1 [5760/12089 (48%)]\tLoss: 0.098522\n",
+ "Train Epoch: 1 [6400/12089 (53%)]\tLoss: 0.089895\n",
+ "Train Epoch: 1 [7040/12089 (58%)]\tLoss: 0.067733\n",
+ "Train Epoch: 1 [7680/12089 (63%)]\tLoss: 0.046314\n",
+ "Train Epoch: 1 [8320/12089 (69%)]\tLoss: 0.028943\n",
+ "Train Epoch: 1 [8960/12089 (74%)]\tLoss: 0.082591\n",
+ "Train Epoch: 1 [9600/12089 (79%)]\tLoss: 0.024760\n",
+ "Train Epoch: 1 [10240/12089 (85%)]\tLoss: 0.050374\n",
+ "Train Epoch: 1 [10880/12089 (90%)]\tLoss: 0.020318\n",
+ "Train Epoch: 1 [11520/12089 (95%)]\tLoss: 0.025210\n",
+ "\n",
+ "Test set: Average loss: 5.2845\n",
+ "Class 0 accuracy: 829.0/980(85%)\n",
+ "Class 1 accuracy: 1046.0/1135(92%)\n",
+ "Class 2 accuracy: 991.0/1032(96%)\n",
+ "Class 3 accuracy: 983.0/1010(97%)\n",
+ "Class 4 accuracy: 0.0/982(0%)\n",
+ "Class 5 accuracy: 0.0/892(0%)\n",
+ "Class 6 accuracy: 0.0/958(0%)\n",
+ "Class 7 accuracy: 0.0/1028(0%)\n",
+ "Class 8 accuracy: 0.0/974(0%)\n",
+ "Class 9 accuracy: 0.0/1009(0%)\n",
+ "Train Epoch: 2 [0/12089 (0%)]\tLoss: 0.091748\n",
+ "Train Epoch: 2 [640/12089 (5%)]\tLoss: 0.038480\n",
+ "Train Epoch: 2 [1280/12089 (11%)]\tLoss: 0.027258\n",
+ "Train Epoch: 2 [1920/12089 (16%)]\tLoss: 0.040072\n",
+ "Train Epoch: 2 [2560/12089 (21%)]\tLoss: 0.024616\n",
+ "Train Epoch: 2 [3200/12089 (26%)]\tLoss: 0.021542\n",
+ "Train Epoch: 2 [3840/12089 (32%)]\tLoss: 0.016757\n",
+ "Train Epoch: 2 [4480/12089 (37%)]\tLoss: 0.029563\n",
+ "Train Epoch: 2 [5120/12089 (42%)]\tLoss: 0.010378\n",
+ "Train Epoch: 2 [5760/12089 (48%)]\tLoss: 0.020319\n",
+ "Train Epoch: 2 [6400/12089 (53%)]\tLoss: 0.019660\n",
+ "Train Epoch: 2 [7040/12089 (58%)]\tLoss: 0.006761\n",
+ "Train Epoch: 2 [7680/12089 (63%)]\tLoss: 0.026263\n",
+ "Train Epoch: 2 [8320/12089 (69%)]\tLoss: 0.008666\n",
+ "Train Epoch: 2 [8960/12089 (74%)]\tLoss: 0.013218\n",
+ "Train Epoch: 2 [9600/12089 (79%)]\tLoss: 0.002491\n",
+ "Train Epoch: 2 [10240/12089 (85%)]\tLoss: 0.010161\n",
+ "Train Epoch: 2 [10880/12089 (90%)]\tLoss: 0.019670\n",
+ "Train Epoch: 2 [11520/12089 (95%)]\tLoss: 0.012468\n",
+ "\n",
+ "Test set: Average loss: 6.1435\n",
+ "Class 0 accuracy: 764.0/980(78%)\n",
+ "Class 1 accuracy: 999.0/1135(88%)\n",
+ "Class 2 accuracy: 1000.0/1032(97%)\n",
+ "Class 3 accuracy: 988.0/1010(98%)\n",
+ "Class 4 accuracy: 0.0/982(0%)\n",
+ "Class 5 accuracy: 0.0/892(0%)\n",
+ "Class 6 accuracy: 0.0/958(0%)\n",
+ "Class 7 accuracy: 0.0/1028(0%)\n",
+ "Class 8 accuracy: 0.0/974(0%)\n",
+ "Class 9 accuracy: 0.0/1009(0%)\n",
+ "Train Epoch: 3 [0/12089 (0%)]\tLoss: 0.028908\n",
+ "Train Epoch: 3 [640/12089 (5%)]\tLoss: 0.017523\n",
+ "Train Epoch: 3 [1280/12089 (11%)]\tLoss: 0.012286\n",
+ "Train Epoch: 3 [1920/12089 (16%)]\tLoss: 0.029993\n",
+ "Train Epoch: 3 [2560/12089 (21%)]\tLoss: 0.014532\n",
+ "Train Epoch: 3 [3200/12089 (26%)]\tLoss: 0.023068\n",
+ "Train Epoch: 3 [3840/12089 (32%)]\tLoss: 0.007645\n",
+ "Train Epoch: 3 [4480/12089 (37%)]\tLoss: 0.016837\n",
+ "Train Epoch: 3 [5120/12089 (42%)]\tLoss: 0.004771\n",
+ "Train Epoch: 3 [5760/12089 (48%)]\tLoss: 0.004505\n",
+ "Train Epoch: 3 [6400/12089 (53%)]\tLoss: 0.085267\n",
+ "Train Epoch: 3 [7040/12089 (58%)]\tLoss: 0.007249\n",
+ "Train Epoch: 3 [7680/12089 (63%)]\tLoss: 0.005857\n",
+ "Train Epoch: 3 [8320/12089 (69%)]\tLoss: 0.036750\n",
+ "Train Epoch: 3 [8960/12089 (74%)]\tLoss: 0.035684\n",
+ "Train Epoch: 3 [9600/12089 (79%)]\tLoss: 0.005433\n",
+ "Train Epoch: 3 [10240/12089 (85%)]\tLoss: 0.006491\n",
+ "Train Epoch: 3 [10880/12089 (90%)]\tLoss: 0.026140\n",
+ "Train Epoch: 3 [11520/12089 (95%)]\tLoss: 0.007269\n",
+ "\n",
+ "Test set: Average loss: 6.5537\n",
+ "Class 0 accuracy: 723.0/980(74%)\n",
+ "Class 1 accuracy: 968.0/1135(85%)\n",
+ "Class 2 accuracy: 999.0/1032(97%)\n",
+ "Class 3 accuracy: 992.0/1010(98%)\n",
+ "Class 4 accuracy: 0.0/982(0%)\n",
+ "Class 5 accuracy: 0.0/892(0%)\n",
+ "Class 6 accuracy: 0.0/958(0%)\n",
+ "Class 7 accuracy: 0.0/1028(0%)\n",
+ "Class 8 accuracy: 0.0/974(0%)\n",
+ "Class 9 accuracy: 0.0/1009(0%)\n",
+ "Train Epoch: 1 [0/11263 (0%)]\tLoss: 1.635322\n",
+ "Train Epoch: 1 [640/11263 (6%)]\tLoss: 0.442154\n",
+ "Train Epoch: 1 [1280/11263 (11%)]\tLoss: 0.249078\n",
+ "Train Epoch: 1 [1920/11263 (17%)]\tLoss: 0.270407\n",
+ "Train Epoch: 1 [2560/11263 (23%)]\tLoss: 0.180369\n",
+ "Train Epoch: 1 [3200/11263 (28%)]\tLoss: 0.088140\n",
+ "Train Epoch: 1 [3840/11263 (34%)]\tLoss: 0.171120\n",
+ "Train Epoch: 1 [4480/11263 (40%)]\tLoss: 0.263797\n",
+ "Train Epoch: 1 [5120/11263 (45%)]\tLoss: 0.111911\n",
+ "Train Epoch: 1 [5760/11263 (51%)]\tLoss: 0.117878\n",
+ "Train Epoch: 1 [6400/11263 (57%)]\tLoss: 0.033867\n",
+ "Train Epoch: 1 [7040/11263 (62%)]\tLoss: 0.058040\n",
+ "Train Epoch: 1 [7680/11263 (68%)]\tLoss: 0.034263\n",
+ "Train Epoch: 1 [8320/11263 (74%)]\tLoss: 0.082758\n",
+ "Train Epoch: 1 [8960/11263 (80%)]\tLoss: 0.041592\n",
+ "Train Epoch: 1 [9600/11263 (85%)]\tLoss: 0.035437\n",
+ "Train Epoch: 1 [10240/11263 (91%)]\tLoss: 0.015268\n",
+ "Train Epoch: 1 [10880/11263 (97%)]\tLoss: 0.029453\n",
+ "\n",
+ "Test set: Average loss: 4.5940\n",
+ "Class 0 accuracy: 760.0/980(78%)\n",
+ "Class 1 accuracy: 986.0/1135(87%)\n",
+ "Class 2 accuracy: 664.0/1032(64%)\n",
+ "Class 3 accuracy: 367.0/1010(36%)\n",
+ "Class 4 accuracy: 979.0/982(100%)\n",
+ "Class 5 accuracy: 864.0/892(97%)\n",
+ "Class 6 accuracy: 0.0/958(0%)\n",
+ "Class 7 accuracy: 0.0/1028(0%)\n",
+ "Class 8 accuracy: 0.0/974(0%)\n",
+ "Class 9 accuracy: 0.0/1009(0%)\n",
+ "Train Epoch: 2 [0/11263 (0%)]\tLoss: 0.008444\n",
+ "Train Epoch: 2 [640/11263 (6%)]\tLoss: 0.020623\n",
+ "Train Epoch: 2 [1280/11263 (11%)]\tLoss: 0.028893\n",
+ "Train Epoch: 2 [1920/11263 (17%)]\tLoss: 0.025949\n",
+ "Train Epoch: 2 [2560/11263 (23%)]\tLoss: 0.027629\n",
+ "Train Epoch: 2 [3200/11263 (28%)]\tLoss: 0.012673\n",
+ "Train Epoch: 2 [3840/11263 (34%)]\tLoss: 0.132138\n",
+ "Train Epoch: 2 [4480/11263 (40%)]\tLoss: 0.019918\n",
+ "Train Epoch: 2 [5120/11263 (45%)]\tLoss: 0.040663\n",
+ "Train Epoch: 2 [5760/11263 (51%)]\tLoss: 0.072586\n",
+ "Train Epoch: 2 [6400/11263 (57%)]\tLoss: 0.027573\n",
+ "Train Epoch: 2 [7040/11263 (62%)]\tLoss: 0.010655\n",
+ "Train Epoch: 2 [7680/11263 (68%)]\tLoss: 0.007804\n",
+ "Train Epoch: 2 [8320/11263 (74%)]\tLoss: 0.021280\n",
+ "Train Epoch: 2 [8960/11263 (80%)]\tLoss: 0.038740\n",
+ "Train Epoch: 2 [9600/11263 (85%)]\tLoss: 0.005795\n",
+ "Train Epoch: 2 [10240/11263 (91%)]\tLoss: 0.005783\n",
+ "Train Epoch: 2 [10880/11263 (97%)]\tLoss: 0.049235\n",
+ "\n",
+ "Test set: Average loss: 5.5267\n",
+ "Class 0 accuracy: 647.0/980(66%)\n",
+ "Class 1 accuracy: 970.0/1135(85%)\n",
+ "Class 2 accuracy: 616.0/1032(60%)\n",
+ "Class 3 accuracy: 240.0/1010(24%)\n",
+ "Class 4 accuracy: 979.0/982(100%)\n",
+ "Class 5 accuracy: 876.0/892(98%)\n",
+ "Class 6 accuracy: 0.0/958(0%)\n",
+ "Class 7 accuracy: 0.0/1028(0%)\n",
+ "Class 8 accuracy: 0.0/974(0%)\n",
+ "Class 9 accuracy: 0.0/1009(0%)\n",
+ "Train Epoch: 3 [0/11263 (0%)]\tLoss: 0.060683\n",
+ "Train Epoch: 3 [640/11263 (6%)]\tLoss: 0.049512\n",
+ "Train Epoch: 3 [1280/11263 (11%)]\tLoss: 0.034698\n",
+ "Train Epoch: 3 [1920/11263 (17%)]\tLoss: 0.015581\n",
+ "Train Epoch: 3 [2560/11263 (23%)]\tLoss: 0.097252\n",
+ "Train Epoch: 3 [3200/11263 (28%)]\tLoss: 0.022429\n",
+ "Train Epoch: 3 [3840/11263 (34%)]\tLoss: 0.021434\n",
+ "Train Epoch: 3 [4480/11263 (40%)]\tLoss: 0.030935\n",
+ "Train Epoch: 3 [5120/11263 (45%)]\tLoss: 0.055945\n",
+ "Train Epoch: 3 [5760/11263 (51%)]\tLoss: 0.032630\n",
+ "Train Epoch: 3 [6400/11263 (57%)]\tLoss: 0.010632\n",
+ "Train Epoch: 3 [7040/11263 (62%)]\tLoss: 0.014181\n",
+ "Train Epoch: 3 [7680/11263 (68%)]\tLoss: 0.011640\n",
+ "Train Epoch: 3 [8320/11263 (74%)]\tLoss: 0.022853\n",
+ "Train Epoch: 3 [8960/11263 (80%)]\tLoss: 0.011952\n",
+ "Train Epoch: 3 [9600/11263 (85%)]\tLoss: 0.029158\n",
+ "Train Epoch: 3 [10240/11263 (91%)]\tLoss: 0.005775\n",
+ "Train Epoch: 3 [10880/11263 (97%)]\tLoss: 0.044570\n",
+ "\n",
+ "Test set: Average loss: 6.0346\n",
+ "Class 0 accuracy: 662.0/980(68%)\n",
+ "Class 1 accuracy: 972.0/1135(86%)\n",
+ "Class 2 accuracy: 558.0/1032(54%)\n",
+ "Class 3 accuracy: 228.0/1010(23%)\n",
+ "Class 4 accuracy: 978.0/982(100%)\n",
+ "Class 5 accuracy: 883.0/892(99%)\n",
+ "Class 6 accuracy: 0.0/958(0%)\n",
+ "Class 7 accuracy: 0.0/1028(0%)\n",
+ "Class 8 accuracy: 0.0/974(0%)\n",
+ "Class 9 accuracy: 0.0/1009(0%)\n",
+ "Train Epoch: 1 [0/12183 (0%)]\tLoss: 1.379911\n",
+ "Train Epoch: 1 [640/12183 (5%)]\tLoss: 0.125989\n",
+ "Train Epoch: 1 [1280/12183 (10%)]\tLoss: 0.938824\n",
+ "Train Epoch: 1 [1920/12183 (16%)]\tLoss: 0.376103\n",
+ "Train Epoch: 1 [2560/12183 (21%)]\tLoss: 0.235958\n",
+ "Train Epoch: 1 [3200/12183 (26%)]\tLoss: 0.146971\n",
+ "Train Epoch: 1 [3840/12183 (31%)]\tLoss: 0.261188\n",
+ "Train Epoch: 1 [4480/12183 (37%)]\tLoss: 0.162076\n",
+ "Train Epoch: 1 [5120/12183 (42%)]\tLoss: 0.064259\n",
+ "Train Epoch: 1 [5760/12183 (47%)]\tLoss: 0.127866\n",
+ "Train Epoch: 1 [6400/12183 (52%)]\tLoss: 0.153409\n",
+ "Train Epoch: 1 [7040/12183 (58%)]\tLoss: 0.117668\n",
+ "Train Epoch: 1 [7680/12183 (63%)]\tLoss: 0.058572\n",
+ "Train Epoch: 1 [8320/12183 (68%)]\tLoss: 0.084697\n",
+ "Train Epoch: 1 [8960/12183 (73%)]\tLoss: 0.034420\n",
+ "Train Epoch: 1 [9600/12183 (79%)]\tLoss: 0.017252\n",
+ "Train Epoch: 1 [10240/12183 (84%)]\tLoss: 0.020837\n",
+ "Train Epoch: 1 [10880/12183 (89%)]\tLoss: 0.022506\n",
+ "Train Epoch: 1 [11520/12183 (94%)]\tLoss: 0.073931\n",
+ "Train Epoch: 1 [4370/12183 (99%)]\tLoss: 0.032285\n",
+ "\n",
+ "Test set: Average loss: 2.6226\n",
+ "Class 0 accuracy: 668.0/980(68%)\n",
+ "Class 1 accuracy: 975.0/1135(86%)\n",
+ "Class 2 accuracy: 475.0/1032(46%)\n",
+ "Class 3 accuracy: 489.0/1010(48%)\n",
+ "Class 4 accuracy: 691.0/982(70%)\n",
+ "Class 5 accuracy: 564.0/892(63%)\n",
+ "Class 6 accuracy: 948.0/958(99%)\n",
+ "Class 7 accuracy: 1003.0/1028(98%)\n",
+ "Class 8 accuracy: 0.0/974(0%)\n",
+ "Class 9 accuracy: 0.0/1009(0%)\n",
+ "Train Epoch: 2 [0/12183 (0%)]\tLoss: 0.045590\n",
+ "Train Epoch: 2 [640/12183 (5%)]\tLoss: 0.091010\n",
+ "Train Epoch: 2 [1280/12183 (10%)]\tLoss: 0.021004\n",
+ "Train Epoch: 2 [1920/12183 (16%)]\tLoss: 0.036341\n",
+ "Train Epoch: 2 [2560/12183 (21%)]\tLoss: 0.021020\n",
+ "Train Epoch: 2 [3200/12183 (26%)]\tLoss: 0.036480\n",
+ "Train Epoch: 2 [3840/12183 (31%)]\tLoss: 0.043549\n",
+ "Train Epoch: 2 [4480/12183 (37%)]\tLoss: 0.029598\n",
+ "Train Epoch: 2 [5120/12183 (42%)]\tLoss: 0.053440\n",
+ "Train Epoch: 2 [5760/12183 (47%)]\tLoss: 0.016025\n",
+ "Train Epoch: 2 [6400/12183 (52%)]\tLoss: 0.022595\n",
+ "Train Epoch: 2 [7040/12183 (58%)]\tLoss: 0.008219\n",
+ "Train Epoch: 2 [7680/12183 (63%)]\tLoss: 0.047070\n",
+ "Train Epoch: 2 [8320/12183 (68%)]\tLoss: 0.021060\n",
+ "Train Epoch: 2 [8960/12183 (73%)]\tLoss: 0.024627\n",
+ "Train Epoch: 2 [9600/12183 (79%)]\tLoss: 0.031155\n",
+ "Train Epoch: 2 [10240/12183 (84%)]\tLoss: 0.045095\n",
+ "Train Epoch: 2 [10880/12183 (89%)]\tLoss: 0.015965\n",
+ "Train Epoch: 2 [11520/12183 (94%)]\tLoss: 0.025968\n",
+ "Train Epoch: 2 [4370/12183 (99%)]\tLoss: 0.043935\n",
+ "\n",
+ "Test set: Average loss: 3.1688\n",
+ "Class 0 accuracy: 682.0/980(70%)\n",
+ "Class 1 accuracy: 969.0/1135(85%)\n",
+ "Class 2 accuracy: 480.0/1032(47%)\n",
+ "Class 3 accuracy: 471.0/1010(47%)\n",
+ "Class 4 accuracy: 635.0/982(65%)\n",
+ "Class 5 accuracy: 484.0/892(54%)\n",
+ "Class 6 accuracy: 950.0/958(99%)\n",
+ "Class 7 accuracy: 1013.0/1028(99%)\n",
+ "Class 8 accuracy: 0.0/974(0%)\n",
+ "Class 9 accuracy: 0.0/1009(0%)\n",
+ "Train Epoch: 3 [0/12183 (0%)]\tLoss: 0.019614\n",
+ "Train Epoch: 3 [640/12183 (5%)]\tLoss: 0.059921\n",
+ "Train Epoch: 3 [1280/12183 (10%)]\tLoss: 0.018621\n",
+ "Train Epoch: 3 [1920/12183 (16%)]\tLoss: 0.011714\n",
+ "Train Epoch: 3 [2560/12183 (21%)]\tLoss: 0.014679\n",
+ "Train Epoch: 3 [3200/12183 (26%)]\tLoss: 0.012920\n",
+ "Train Epoch: 3 [3840/12183 (31%)]\tLoss: 0.022788\n",
+ "Train Epoch: 3 [4480/12183 (37%)]\tLoss: 0.016748\n",
+ "Train Epoch: 3 [5120/12183 (42%)]\tLoss: 0.028296\n",
+ "Train Epoch: 3 [5760/12183 (47%)]\tLoss: 0.021148\n",
+ "Train Epoch: 3 [6400/12183 (52%)]\tLoss: 0.008148\n",
+ "Train Epoch: 3 [7040/12183 (58%)]\tLoss: 0.011808\n",
+ "Train Epoch: 3 [7680/12183 (63%)]\tLoss: 0.009274\n",
+ "Train Epoch: 3 [8320/12183 (68%)]\tLoss: 0.011613\n",
+ "Train Epoch: 3 [8960/12183 (73%)]\tLoss: 0.013502\n",
+ "Train Epoch: 3 [9600/12183 (79%)]\tLoss: 0.022795\n",
+ "Train Epoch: 3 [10240/12183 (84%)]\tLoss: 0.013261\n",
+ "Train Epoch: 3 [10880/12183 (89%)]\tLoss: 0.026117\n",
+ "Train Epoch: 3 [11520/12183 (94%)]\tLoss: 0.004542\n",
+ "Train Epoch: 3 [4370/12183 (99%)]\tLoss: 0.002898\n",
+ "\n",
+ "Test set: Average loss: 3.5287\n",
+ "Class 0 accuracy: 636.0/980(65%)\n",
+ "Class 1 accuracy: 955.0/1135(84%)\n",
+ "Class 2 accuracy: 451.0/1032(44%)\n",
+ "Class 3 accuracy: 432.0/1010(43%)\n",
+ "Class 4 accuracy: 562.0/982(57%)\n",
+ "Class 5 accuracy: 458.0/892(51%)\n",
+ "Class 6 accuracy: 952.0/958(99%)\n",
+ "Class 7 accuracy: 1018.0/1028(99%)\n",
+ "Class 8 accuracy: 0.0/974(0%)\n",
+ "Class 9 accuracy: 0.0/1009(0%)\n",
+ "Train Epoch: 1 [0/11800 (0%)]\tLoss: 0.473740\n",
+ "Train Epoch: 1 [640/11800 (5%)]\tLoss: 0.210103\n",
+ "Train Epoch: 1 [1280/11800 (11%)]\tLoss: 0.418402\n",
+ "Train Epoch: 1 [1920/11800 (16%)]\tLoss: 0.530582\n",
+ "Train Epoch: 1 [2560/11800 (22%)]\tLoss: 0.304206\n",
+ "Train Epoch: 1 [3200/11800 (27%)]\tLoss: 0.122092\n",
+ "Train Epoch: 1 [3840/11800 (32%)]\tLoss: 0.183575\n",
+ "Train Epoch: 1 [4480/11800 (38%)]\tLoss: 0.260152\n",
+ "Train Epoch: 1 [5120/11800 (43%)]\tLoss: 0.277376\n",
+ "Train Epoch: 1 [5760/11800 (49%)]\tLoss: 0.212683\n",
+ "Train Epoch: 1 [6400/11800 (54%)]\tLoss: 0.163834\n",
+ "Train Epoch: 1 [7040/11800 (59%)]\tLoss: 0.088070\n",
+ "Train Epoch: 1 [7680/11800 (65%)]\tLoss: 0.092493\n",
+ "Train Epoch: 1 [8320/11800 (70%)]\tLoss: 0.169504\n",
+ "Train Epoch: 1 [8960/11800 (76%)]\tLoss: 0.130592\n",
+ "Train Epoch: 1 [9600/11800 (81%)]\tLoss: 0.043165\n",
+ "Train Epoch: 1 [10240/11800 (86%)]\tLoss: 0.051642\n",
+ "Train Epoch: 1 [10880/11800 (92%)]\tLoss: 0.081378\n",
+ "Train Epoch: 1 [11520/11800 (97%)]\tLoss: 0.073074\n",
+ "\n",
+ "Test set: Average loss: 1.1909\n",
+ "Class 0 accuracy: 710.0/980(72%)\n",
+ "Class 1 accuracy: 937.0/1135(83%)\n",
+ "Class 2 accuracy: 525.0/1032(51%)\n",
+ "Class 3 accuracy: 409.0/1010(40%)\n",
+ "Class 4 accuracy: 368.0/982(37%)\n",
+ "Class 5 accuracy: 359.0/892(40%)\n",
+ "Class 6 accuracy: 782.0/958(82%)\n",
+ "Class 7 accuracy: 587.0/1028(57%)\n",
+ "Class 8 accuracy: 945.0/974(97%)\n",
+ "Class 9 accuracy: 974.0/1009(97%)\n",
+ "Train Epoch: 2 [0/11800 (0%)]\tLoss: 0.106367\n",
+ "Train Epoch: 2 [640/11800 (5%)]\tLoss: 0.110765\n",
+ "Train Epoch: 2 [1280/11800 (11%)]\tLoss: 0.168906\n",
+ "Train Epoch: 2 [1920/11800 (16%)]\tLoss: 0.130259\n",
+ "Train Epoch: 2 [2560/11800 (22%)]\tLoss: 0.099700\n",
+ "Train Epoch: 2 [3200/11800 (27%)]\tLoss: 0.045566\n",
+ "Train Epoch: 2 [3840/11800 (32%)]\tLoss: 0.032498\n",
+ "Train Epoch: 2 [4480/11800 (38%)]\tLoss: 0.127420\n",
+ "Train Epoch: 2 [5120/11800 (43%)]\tLoss: 0.127244\n",
+ "Train Epoch: 2 [5760/11800 (49%)]\tLoss: 0.123305\n",
+ "Train Epoch: 2 [6400/11800 (54%)]\tLoss: 0.033301\n",
+ "Train Epoch: 2 [7040/11800 (59%)]\tLoss: 0.050663\n",
+ "Train Epoch: 2 [7680/11800 (65%)]\tLoss: 0.029956\n",
+ "Train Epoch: 2 [8320/11800 (70%)]\tLoss: 0.016801\n",
+ "Train Epoch: 2 [8960/11800 (76%)]\tLoss: 0.067689\n",
+ "Train Epoch: 2 [9600/11800 (81%)]\tLoss: 0.044518\n",
+ "Train Epoch: 2 [10240/11800 (86%)]\tLoss: 0.073524\n",
+ "Train Epoch: 2 [10880/11800 (92%)]\tLoss: 0.035689\n",
+ "Train Epoch: 2 [11520/11800 (97%)]\tLoss: 0.084134\n",
+ "\n",
+ "Test set: Average loss: 1.5479\n",
+ "Class 0 accuracy: 684.0/980(70%)\n",
+ "Class 1 accuracy: 893.0/1135(79%)\n",
+ "Class 2 accuracy: 470.0/1032(46%)\n",
+ "Class 3 accuracy: 346.0/1010(34%)\n",
+ "Class 4 accuracy: 312.0/982(32%)\n",
+ "Class 5 accuracy: 298.0/892(33%)\n",
+ "Class 6 accuracy: 758.0/958(79%)\n",
+ "Class 7 accuracy: 583.0/1028(57%)\n",
+ "Class 8 accuracy: 953.0/974(98%)\n",
+ "Class 9 accuracy: 977.0/1009(97%)\n",
+ "Train Epoch: 3 [0/11800 (0%)]\tLoss: 0.106828\n",
+ "Train Epoch: 3 [640/11800 (5%)]\tLoss: 0.064151\n",
+ "Train Epoch: 3 [1280/11800 (11%)]\tLoss: 0.064141\n",
+ "Train Epoch: 3 [1920/11800 (16%)]\tLoss: 0.046738\n",
+ "Train Epoch: 3 [2560/11800 (22%)]\tLoss: 0.060965\n",
+ "Train Epoch: 3 [3200/11800 (27%)]\tLoss: 0.041177\n",
+ "Train Epoch: 3 [3840/11800 (32%)]\tLoss: 0.015125\n",
+ "Train Epoch: 3 [4480/11800 (38%)]\tLoss: 0.036628\n",
+ "Train Epoch: 3 [5120/11800 (43%)]\tLoss: 0.040213\n",
+ "Train Epoch: 3 [5760/11800 (49%)]\tLoss: 0.038948\n",
+ "Train Epoch: 3 [6400/11800 (54%)]\tLoss: 0.065069\n",
+ "Train Epoch: 3 [7040/11800 (59%)]\tLoss: 0.060684\n",
+ "Train Epoch: 3 [7680/11800 (65%)]\tLoss: 0.051811\n",
+ "Train Epoch: 3 [8320/11800 (70%)]\tLoss: 0.096974\n",
+ "Train Epoch: 3 [8960/11800 (76%)]\tLoss: 0.014142\n",
+ "Train Epoch: 3 [9600/11800 (81%)]\tLoss: 0.068916\n",
+ "Train Epoch: 3 [10240/11800 (86%)]\tLoss: 0.047880\n",
+ "Train Epoch: 3 [10880/11800 (92%)]\tLoss: 0.007572\n",
+ "Train Epoch: 3 [11520/11800 (97%)]\tLoss: 0.025779\n",
+ "\n",
+ "Test set: Average loss: 1.6151\n",
+ "Class 0 accuracy: 647.0/980(66%)\n",
+ "Class 1 accuracy: 885.0/1135(78%)\n",
+ "Class 2 accuracy: 492.0/1032(48%)\n",
+ "Class 3 accuracy: 379.0/1010(38%)\n",
+ "Class 4 accuracy: 323.0/982(33%)\n",
+ "Class 5 accuracy: 310.0/892(35%)\n",
+ "Class 6 accuracy: 736.0/958(77%)\n",
+ "Class 7 accuracy: 546.0/1028(53%)\n",
+ "Class 8 accuracy: 961.0/974(99%)\n",
+ "Class 9 accuracy: 976.0/1009(97%)\n",
+ "Train Epoch: 1 [0/12665 (0%)]\tLoss: 0.030852\n",
+ "Train Epoch: 1 [640/12665 (5%)]\tLoss: 0.051628\n",
+ "Train Epoch: 1 [1280/12665 (10%)]\tLoss: 0.097404\n",
+ "Train Epoch: 1 [1920/12665 (15%)]\tLoss: 0.066156\n",
+ "Train Epoch: 1 [2560/12665 (20%)]\tLoss: 0.085003\n",
+ "Train Epoch: 1 [3200/12665 (25%)]\tLoss: 0.022351\n",
+ "Train Epoch: 1 [3840/12665 (30%)]\tLoss: 0.047033\n",
+ "Train Epoch: 1 [4480/12665 (35%)]\tLoss: 0.021085\n",
+ "Train Epoch: 1 [5120/12665 (40%)]\tLoss: 0.038875\n",
+ "Train Epoch: 1 [5760/12665 (45%)]\tLoss: 0.035716\n",
+ "Train Epoch: 1 [6400/12665 (51%)]\tLoss: 0.033066\n",
+ "Train Epoch: 1 [7040/12665 (56%)]\tLoss: 0.020775\n",
+ "Train Epoch: 1 [7680/12665 (61%)]\tLoss: 0.027079\n",
+ "Train Epoch: 1 [8320/12665 (66%)]\tLoss: 0.071041\n",
+ "Train Epoch: 1 [8960/12665 (71%)]\tLoss: 0.030695\n",
+ "Train Epoch: 1 [9600/12665 (76%)]\tLoss: 0.038757\n",
+ "Train Epoch: 1 [10240/12665 (81%)]\tLoss: 0.003501\n",
+ "Train Epoch: 1 [10880/12665 (86%)]\tLoss: 0.011310\n",
+ "Train Epoch: 1 [11520/12665 (91%)]\tLoss: 0.057821\n",
+ "Train Epoch: 1 [12160/12665 (96%)]\tLoss: 0.016945\n",
+ "\n",
+ "Test set: Average loss: 1.3637\n",
+ "Class 0 accuracy: 977.0/980(100%)\n",
+ "Class 1 accuracy: 1133.0/1135(100%)\n",
+ "Class 2 accuracy: 574.0/1032(56%)\n",
+ "Class 3 accuracy: 561.0/1010(56%)\n",
+ "Class 4 accuracy: 494.0/982(50%)\n",
+ "Class 5 accuracy: 328.0/892(37%)\n",
+ "Class 6 accuracy: 656.0/958(68%)\n",
+ "Class 7 accuracy: 753.0/1028(73%)\n",
+ "Class 8 accuracy: 701.0/974(72%)\n",
+ "Class 9 accuracy: 866.0/1009(86%)\n",
+ "Train Epoch: 2 [0/12665 (0%)]\tLoss: 0.016473\n",
+ "Train Epoch: 2 [640/12665 (5%)]\tLoss: 0.007508\n",
+ "Train Epoch: 2 [1280/12665 (10%)]\tLoss: 0.037633\n",
+ "Train Epoch: 2 [1920/12665 (15%)]\tLoss: 0.026398\n",
+ "Train Epoch: 2 [2560/12665 (20%)]\tLoss: 0.004491\n",
+ "Train Epoch: 2 [3200/12665 (25%)]\tLoss: 0.002447\n",
+ "Train Epoch: 2 [3840/12665 (30%)]\tLoss: 0.007396\n",
+ "Train Epoch: 2 [4480/12665 (35%)]\tLoss: 0.022243\n",
+ "Train Epoch: 2 [5120/12665 (40%)]\tLoss: 0.032471\n",
+ "Train Epoch: 2 [5760/12665 (45%)]\tLoss: 0.012960\n",
+ "Train Epoch: 2 [6400/12665 (51%)]\tLoss: 0.009143\n",
+ "Train Epoch: 2 [7040/12665 (56%)]\tLoss: 0.005618\n",
+ "Train Epoch: 2 [7680/12665 (61%)]\tLoss: 0.016858\n",
+ "Train Epoch: 2 [8320/12665 (66%)]\tLoss: 0.015019\n",
+ "Train Epoch: 2 [8960/12665 (71%)]\tLoss: 0.042326\n",
+ "Train Epoch: 2 [9600/12665 (76%)]\tLoss: 0.005700\n",
+ "Train Epoch: 2 [10240/12665 (81%)]\tLoss: 0.020611\n",
+ "Train Epoch: 2 [10880/12665 (86%)]\tLoss: 0.011131\n",
+ "Train Epoch: 2 [11520/12665 (91%)]\tLoss: 0.019555\n",
+ "Train Epoch: 2 [12160/12665 (96%)]\tLoss: 0.015460\n",
+ "\n",
+ "Test set: Average loss: 1.5375\n",
+ "Class 0 accuracy: 977.0/980(100%)\n",
+ "Class 1 accuracy: 1133.0/1135(100%)\n",
+ "Class 2 accuracy: 500.0/1032(48%)\n",
+ "Class 3 accuracy: 503.0/1010(50%)\n",
+ "Class 4 accuracy: 489.0/982(50%)\n",
+ "Class 5 accuracy: 362.0/892(41%)\n",
+ "Class 6 accuracy: 639.0/958(67%)\n",
+ "Class 7 accuracy: 746.0/1028(73%)\n",
+ "Class 8 accuracy: 657.0/974(67%)\n",
+ "Class 9 accuracy: 847.0/1009(84%)\n",
+ "Train Epoch: 3 [0/12665 (0%)]\tLoss: 0.003884\n",
+ "Train Epoch: 3 [640/12665 (5%)]\tLoss: 0.024282\n",
+ "Train Epoch: 3 [1280/12665 (10%)]\tLoss: 0.007995\n",
+ "Train Epoch: 3 [1920/12665 (15%)]\tLoss: 0.006038\n",
+ "Train Epoch: 3 [2560/12665 (20%)]\tLoss: 0.007882\n",
+ "Train Epoch: 3 [3200/12665 (25%)]\tLoss: 0.008968\n",
+ "Train Epoch: 3 [3840/12665 (30%)]\tLoss: 0.007985\n",
+ "Train Epoch: 3 [4480/12665 (35%)]\tLoss: 0.017382\n",
+ "Train Epoch: 3 [5120/12665 (40%)]\tLoss: 0.006549\n",
+ "Train Epoch: 3 [5760/12665 (45%)]\tLoss: 0.006111\n",
+ "Train Epoch: 3 [6400/12665 (51%)]\tLoss: 0.016712\n",
+ "Train Epoch: 3 [7040/12665 (56%)]\tLoss: 0.009508\n",
+ "Train Epoch: 3 [7680/12665 (61%)]\tLoss: 0.013591\n",
+ "Train Epoch: 3 [8320/12665 (66%)]\tLoss: 0.002323\n",
+ "Train Epoch: 3 [8960/12665 (71%)]\tLoss: 0.003972\n",
+ "Train Epoch: 3 [9600/12665 (76%)]\tLoss: 0.002240\n",
+ "Train Epoch: 3 [10240/12665 (81%)]\tLoss: 0.012233\n",
+ "Train Epoch: 3 [10880/12665 (86%)]\tLoss: 0.046749\n",
+ "Train Epoch: 3 [11520/12665 (91%)]\tLoss: 0.006683\n",
+ "Train Epoch: 3 [12160/12665 (96%)]\tLoss: 0.006137\n",
+ "\n",
+ "Test set: Average loss: 1.5670\n",
+ "Class 0 accuracy: 976.0/980(100%)\n",
+ "Class 1 accuracy: 1134.0/1135(100%)\n",
+ "Class 2 accuracy: 509.0/1032(49%)\n",
+ "Class 3 accuracy: 520.0/1010(51%)\n",
+ "Class 4 accuracy: 496.0/982(51%)\n",
+ "Class 5 accuracy: 355.0/892(40%)\n",
+ "Class 6 accuracy: 625.0/958(65%)\n",
+ "Class 7 accuracy: 743.0/1028(72%)\n",
+ "Class 8 accuracy: 654.0/974(67%)\n",
+ "Class 9 accuracy: 854.0/1009(85%)\n"
+ ]
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "[{0: 0.9959183673469387,\n",
+ " 1: 0.9991189427312775,\n",
+ " 2: 0.0,\n",
+ " 3: 0.0,\n",
+ " 4: 0.0,\n",
+ " 5: 0.0,\n",
+ " 6: 0.0,\n",
+ " 7: 0.0,\n",
+ " 8: 0.0,\n",
+ " 9: 0.0,\n",
+ " 'phase': '0_1'},\n",
+ " {0: 0.9959183673469387,\n",
+ " 1: 1.0,\n",
+ " 2: 0.0,\n",
+ " 3: 0.0,\n",
+ " 4: 0.0,\n",
+ " 5: 0.0,\n",
+ " 6: 0.0,\n",
+ " 7: 0.0,\n",
+ " 8: 0.0,\n",
+ " 9: 0.0,\n",
+ " 'phase': '0_1'},\n",
+ " {0: 0.996938775510204,\n",
+ " 1: 1.0,\n",
+ " 2: 0.0,\n",
+ " 3: 0.0,\n",
+ " 4: 0.0,\n",
+ " 5: 0.0,\n",
+ " 6: 0.0,\n",
+ " 7: 0.0,\n",
+ " 8: 0.0,\n",
+ " 9: 0.0,\n",
+ " 'phase': '0_1'},\n",
+ " {0: 0.8459183673469388,\n",
+ " 1: 0.9215859030837005,\n",
+ " 2: 0.9602713178294574,\n",
+ " 3: 0.9732673267326732,\n",
+ " 4: 0.0,\n",
+ " 5: 0.0,\n",
+ " 6: 0.0,\n",
+ " 7: 0.0,\n",
+ " 8: 0.0,\n",
+ " 9: 0.0,\n",
+ " 'phase': '2_3'},\n",
+ " {0: 0.7795918367346939,\n",
+ " 1: 0.8801762114537445,\n",
+ " 2: 0.9689922480620154,\n",
+ " 3: 0.9782178217821782,\n",
+ " 4: 0.0,\n",
+ " 5: 0.0,\n",
+ " 6: 0.0,\n",
+ " 7: 0.0,\n",
+ " 8: 0.0,\n",
+ " 9: 0.0,\n",
+ " 'phase': '2_3'},\n",
+ " {0: 0.7377551020408163,\n",
+ " 1: 0.852863436123348,\n",
+ " 2: 0.9680232558139535,\n",
+ " 3: 0.9821782178217822,\n",
+ " 4: 0.0,\n",
+ " 5: 0.0,\n",
+ " 6: 0.0,\n",
+ " 7: 0.0,\n",
+ " 8: 0.0,\n",
+ " 9: 0.0,\n",
+ " 'phase': '2_3'},\n",
+ " {0: 0.7755102040816326,\n",
+ " 1: 0.8687224669603524,\n",
+ " 2: 0.6434108527131783,\n",
+ " 3: 0.36336633663366336,\n",
+ " 4: 0.9969450101832994,\n",
+ " 5: 0.968609865470852,\n",
+ " 6: 0.0,\n",
+ " 7: 0.0,\n",
+ " 8: 0.0,\n",
+ " 9: 0.0,\n",
+ " 'phase': '4_5'},\n",
+ " {0: 0.6602040816326531,\n",
+ " 1: 0.8546255506607929,\n",
+ " 2: 0.5968992248062015,\n",
+ " 3: 0.2376237623762376,\n",
+ " 4: 0.9969450101832994,\n",
+ " 5: 0.9820627802690582,\n",
+ " 6: 0.0,\n",
+ " 7: 0.0,\n",
+ " 8: 0.0,\n",
+ " 9: 0.0,\n",
+ " 'phase': '4_5'},\n",
+ " {0: 0.6755102040816326,\n",
+ " 1: 0.8563876651982378,\n",
+ " 2: 0.5406976744186046,\n",
+ " 3: 0.22574257425742575,\n",
+ " 4: 0.9959266802443992,\n",
+ " 5: 0.9899103139013453,\n",
+ " 6: 0.0,\n",
+ " 7: 0.0,\n",
+ " 8: 0.0,\n",
+ " 9: 0.0,\n",
+ " 'phase': '4_5'},\n",
+ " {0: 0.6816326530612244,\n",
+ " 1: 0.8590308370044053,\n",
+ " 2: 0.46027131782945735,\n",
+ " 3: 0.48415841584158414,\n",
+ " 4: 0.7036659877800407,\n",
+ " 5: 0.6322869955156951,\n",
+ " 6: 0.9895615866388309,\n",
+ " 7: 0.97568093385214,\n",
+ " 8: 0.0,\n",
+ " 9: 0.0,\n",
+ " 'phase': '6_7'},\n",
+ " {0: 0.6959183673469388,\n",
+ " 1: 0.8537444933920705,\n",
+ " 2: 0.46511627906976744,\n",
+ " 3: 0.46633663366336636,\n",
+ " 4: 0.6466395112016293,\n",
+ " 5: 0.5426008968609866,\n",
+ " 6: 0.9916492693110647,\n",
+ " 7: 0.9854085603112841,\n",
+ " 8: 0.0,\n",
+ " 9: 0.0,\n",
+ " 'phase': '6_7'},\n",
+ " {0: 0.6489795918367347,\n",
+ " 1: 0.8414096916299559,\n",
+ " 2: 0.437015503875969,\n",
+ " 3: 0.4277227722772277,\n",
+ " 4: 0.5723014256619144,\n",
+ " 5: 0.5134529147982063,\n",
+ " 6: 0.9937369519832986,\n",
+ " 7: 0.9902723735408561,\n",
+ " 8: 0.0,\n",
+ " 9: 0.0,\n",
+ " 'phase': '6_7'},\n",
+ " {0: 0.7244897959183674,\n",
+ " 1: 0.8255506607929516,\n",
+ " 2: 0.5087209302325582,\n",
+ " 3: 0.404950495049505,\n",
+ " 4: 0.37474541751527496,\n",
+ " 5: 0.4024663677130045,\n",
+ " 6: 0.8162839248434238,\n",
+ " 7: 0.5710116731517509,\n",
+ " 8: 0.9702258726899384,\n",
+ " 9: 0.9653121902874133,\n",
+ " 'phase': '8_9'},\n",
+ " {0: 0.6979591836734694,\n",
+ " 1: 0.786784140969163,\n",
+ " 2: 0.45542635658914726,\n",
+ " 3: 0.3425742574257426,\n",
+ " 4: 0.31771894093686354,\n",
+ " 5: 0.33408071748878926,\n",
+ " 6: 0.791231732776618,\n",
+ " 7: 0.5671206225680934,\n",
+ " 8: 0.9784394250513347,\n",
+ " 9: 0.9682854311199207,\n",
+ " 'phase': '8_9'},\n",
+ " {0: 0.6602040816326531,\n",
+ " 1: 0.7797356828193832,\n",
+ " 2: 0.47674418604651164,\n",
+ " 3: 0.37524752475247525,\n",
+ " 4: 0.3289205702647658,\n",
+ " 5: 0.3475336322869955,\n",
+ " 6: 0.7682672233820459,\n",
+ " 7: 0.5311284046692607,\n",
+ " 8: 0.9866529774127311,\n",
+ " 9: 0.9672943508424182,\n",
+ " 'phase': '8_9'},\n",
+ " {0: 0.996938775510204,\n",
+ " 1: 0.9982378854625551,\n",
+ " 2: 0.5562015503875969,\n",
+ " 3: 0.5554455445544555,\n",
+ " 4: 0.5030549898167006,\n",
+ " 5: 0.36771300448430494,\n",
+ " 6: 0.6847599164926931,\n",
+ " 7: 0.7324902723735408,\n",
+ " 8: 0.7197125256673511,\n",
+ " 9: 0.8582755203171457,\n",
+ " 'phase': '0_1_again'},\n",
+ " {0: 0.996938775510204,\n",
+ " 1: 0.9982378854625551,\n",
+ " 2: 0.4844961240310077,\n",
+ " 3: 0.498019801980198,\n",
+ " 4: 0.4979633401221996,\n",
+ " 5: 0.40582959641255606,\n",
+ " 6: 0.6670146137787056,\n",
+ " 7: 0.72568093385214,\n",
+ " 8: 0.6745379876796714,\n",
+ " 9: 0.8394449950445986,\n",
+ " 'phase': '0_1_again'},\n",
+ " {0: 0.9959183673469387,\n",
+ " 1: 0.9991189427312775,\n",
+ " 2: 0.4932170542635659,\n",
+ " 3: 0.5148514851485149,\n",
+ " 4: 0.505091649694501,\n",
+ " 5: 0.39798206278026904,\n",
+ " 6: 0.6524008350730689,\n",
+ " 7: 0.7227626459143969,\n",
+ " 8: 0.6714579055441479,\n",
+ " 9: 0.846382556987116,\n",
+ " 'phase': '0_1_again'}]"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 20
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "8Yad6wstNUpx"
+ },
+ "source": [
+ "# Analysis of the results"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {
+ "id": "gimHfUXdW4_K",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 459
+ },
+ "outputId": "3a07ff95-27df-468d-f571-414c99f1ca1e"
+ },
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ " epoch 0 1 2 3 4 5 6 7 8 9 phase\n",
+ "0 0 0.96 0.97 0.82 0.85 0.59 0.35 0.86 0.85 0.52 0.82 baseline"
+ ],
+ "text/html": [
+ "\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " epoch | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 2 | \n",
+ " 3 | \n",
+ " 4 | \n",
+ " 5 | \n",
+ " 6 | \n",
+ " 7 | \n",
+ " 8 | \n",
+ " 9 | \n",
+ " phase | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0.96 | \n",
+ " 0.97 | \n",
+ " 0.82 | \n",
+ " 0.85 | \n",
+ " 0.59 | \n",
+ " 0.35 | \n",
+ " 0.86 | \n",
+ " 0.85 | \n",
+ " 0.52 | \n",
+ " 0.82 | \n",
+ " baseline | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
\n",
+ "
\n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ "
\n",
+ "
\n",
+ " "
+ ]
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ " epoch 0 1 2 3 4 5 6 7 8 9 phase\n",
+ "0 0 1.0 1.0 0.00 0.0 0.00 0.00 0.0 0.00 0.0 0.00 0_1\n",
+ "1 1 0.0 0.0 0.14 1.0 0.00 0.00 0.0 0.00 0.0 0.00 2_3\n",
+ "2 2 0.0 0.0 0.00 0.0 0.99 0.89 0.0 0.00 0.0 0.00 4_5\n",
+ "3 3 0.0 0.0 0.00 0.0 0.00 0.00 1.0 0.93 0.0 0.00 6_7\n",
+ "4 4 0.0 0.0 0.00 0.0 0.00 0.00 0.0 0.00 1.0 0.52 8_9"
+ ],
+ "text/html": [
+ "\n",
+ " \n",
+ "
\n",
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " epoch | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 2 | \n",
+ " 3 | \n",
+ " 4 | \n",
+ " 5 | \n",
+ " 6 | \n",
+ " 7 | \n",
+ " 8 | \n",
+ " 9 | \n",
+ " phase | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 0.00 | \n",
+ " 0.0 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ " 0.0 | \n",
+ " 0.00 | \n",
+ " 0.0 | \n",
+ " 0.00 | \n",
+ " 0_1 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.14 | \n",
+ " 1.0 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ " 0.0 | \n",
+ " 0.00 | \n",
+ " 0.0 | \n",
+ " 0.00 | \n",
+ " 2_3 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 2 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.00 | \n",
+ " 0.0 | \n",
+ " 0.99 | \n",
+ " 0.89 | \n",
+ " 0.0 | \n",
+ " 0.00 | \n",
+ " 0.0 | \n",
+ " 0.00 | \n",
+ " 4_5 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 3 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.00 | \n",
+ " 0.0 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ " 1.0 | \n",
+ " 0.93 | \n",
+ " 0.0 | \n",
+ " 0.00 | \n",
+ " 6_7 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 4 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.00 | \n",
+ " 0.0 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ " 0.0 | \n",
+ " 0.00 | \n",
+ " 1.0 | \n",
+ " 0.52 | \n",
+ " 8_9 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
\n",
+ "
\n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ "
\n",
+ "
\n",
+ " "
+ ]
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ " epoch 0 1 2 3 4 5 6 7 8 9 phase\n",
+ "0 0 1.00 0.99 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0_1\n",
+ "1 1 0.04 0.51 0.23 1.00 0.00 0.00 0.00 0.00 0.00 0.00 2_3\n",
+ "2 2 0.71 0.64 0.06 0.00 0.98 0.85 0.00 0.00 0.00 0.00 4_5\n",
+ "3 3 0.23 0.53 0.00 0.01 0.00 0.00 0.99 0.96 0.00 0.00 6_7\n",
+ "4 4 0.44 0.55 0.00 0.00 0.00 0.00 0.19 0.00 0.98 0.88 8_9"
+ ],
+ "text/html": [
+ "\n",
+ " \n",
+ "
\n",
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " epoch | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 2 | \n",
+ " 3 | \n",
+ " 4 | \n",
+ " 5 | \n",
+ " 6 | \n",
+ " 7 | \n",
+ " 8 | \n",
+ " 9 | \n",
+ " phase | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 1.00 | \n",
+ " 0.99 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ " 0_1 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 0.04 | \n",
+ " 0.51 | \n",
+ " 0.23 | \n",
+ " 1.00 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ " 2_3 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 2 | \n",
+ " 0.71 | \n",
+ " 0.64 | \n",
+ " 0.06 | \n",
+ " 0.00 | \n",
+ " 0.98 | \n",
+ " 0.85 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ " 4_5 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 3 | \n",
+ " 0.23 | \n",
+ " 0.53 | \n",
+ " 0.00 | \n",
+ " 0.01 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ " 0.99 | \n",
+ " 0.96 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ " 6_7 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 4 | \n",
+ " 0.44 | \n",
+ " 0.55 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ " 0.00 | \n",
+ " 0.19 | \n",
+ " 0.00 | \n",
+ " 0.98 | \n",
+ " 0.88 | \n",
+ " 8_9 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
\n",
+ "
\n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ "
\n",
+ "
\n",
+ " "
+ ]
+ },
+ "metadata": {}
+ }
+ ],
+ "source": [
+ "# The following helper code takes the logs and converts them into a dataframe\n",
+ "# for easier reading. You can also store the result as a CSV or HDF file by\n",
+ "# using the .to_csv and .to_hdf methods from pandas for later reading.\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "import pandas as pd\n",
+ "\n",
+ "def format_results(history):\n",
+ " logs = pd.DataFrame(history).round(2)\n",
+ " logs.index.name = 'epoch'\n",
+ " logs = logs.reset_index(drop = False)\n",
+ " return logs\n",
+ "\n",
+ "#display(format_results(history_regular_mnist).head())\n",
+ "display(format_results(history_catastrophic_forgetting).head())\n",
+ "display(format_results(history_memory_replay).head())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {
+ "id": "jR6eRKn4WguU",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 807
+ },
+ "outputId": "a20c835b-1f62-4b6c-cd45-745f122a0041"
+ },
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "