diff --git a/examples/MNIST Addition - Video.ipynb b/examples/MNIST Addition - Video.ipynb index 62ffe92..4eaad31 100644 --- a/examples/MNIST Addition - Video.ipynb +++ b/examples/MNIST Addition - Video.ipynb @@ -16,16 +16,24 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/hanliying/opt/anaconda3/envs/pytorch/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 2, + "execution_count": 1, "metadata": {}, "output_type": "execute_result" } @@ -47,7 +55,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -86,9 +94,92 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", + "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/MNIST/raw/train-images-idx3-ubyte.gz\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "9913344it [00:01, 6405934.93it/s] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Extracting ./MNIST/MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/MNIST/raw\n", + "\n", + "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n", + "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/MNIST/raw/train-labels-idx1-ubyte.gz\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "29696it [00:00, 7006866.09it/s] " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Extracting ./MNIST/MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/MNIST/raw\n", + "\n", + "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n", + "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "1649664it [00:00, 18619900.63it/s] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Extracting ./MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/MNIST/raw\n", + "\n", + "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n", + "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "5120it [00:00, 2585772.00it/s] " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Extracting ./MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/MNIST/raw\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])\n", "mnist_train_data = torchvision.datasets.MNIST(root='./MNIST', train=True, download=True,transform=transform)\n", @@ -107,7 +198,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -145,7 +236,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -162,7 +253,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -189,7 +280,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -199,7 +290,7 @@ "\n", "from pylon.constraint import constraint\n", "\n", - "def enforce_sum(img1, img2, kwargs):\n", + "def enforce_sum(img1, img2, **kwargs):\n", " return img1 + img2 == kwargs['summation']\n", "\n", "\n", @@ -215,7 +306,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": { "scrolled": false }, @@ -224,14 +315,16 @@ "name": "stderr", "output_type": "stream", "text": [ - " 11%|█ | 1003/9000 [00:54<21:14, 6.27it/s]" + " 0%| | 0/9000 [00:00