Created using Colaboratory

This commit is contained in:
George Mihaila
2019-09-18 23:05:18 -05:00
parent 63180b365a
commit 35a3432c46


@@ -684,10 +684,10 @@
"metadata": {
"id": "vfdi1HANvi7f",
"colab_type": "code",
"outputId": "277e02e6-88b8-4d58-cc6e-1bb35b3a65e3",
"outputId": "b0dec8ee-d09e-4d4f-f54d-3cf6f28c5c32",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 425
"height": 1000
}
},
"source": [
@@ -792,107 +792,175 @@
" \n",
" print(\"Epoch:\", epoch, \"Training Loss: \", np.mean(train_loss), \"Valid Loss: \", np.mean(valid_loss))"
],
"execution_count": 0,
"execution_count": 36,
"outputs": [
{
"output_type": "stream",
"text": [
" 0%| | 0/9912422 [00:00<?, ?it/s]"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz\n"
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"9920512it [00:01, 9805960.78it/s] \n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Extracting data/MNIST/raw/train-images-idx3-ubyte.gz\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
" 0%| | 0/28881 [00:00<?, ?it/s]"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"32768it [00:00, 130353.83it/s] \n",
" 0%| | 0/1648877 [00:00<?, ?it/s]"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"1654784it [00:00, 2125030.13it/s] \n",
"0it [00:00, ?it/s]"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"8192it [00:00, 49184.55it/s] \n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n",
"Processing...\n",
"Done!\n",
"Epoch: 1 Training Loss: 1.4125140620038865 Valid Loss: 0.7013939695155367\n",
"Epoch: 2 Training Loss: 0.5760371469436808 Valid Loss: 0.447847954770352\n",
"Epoch: 3 Training Loss: 0.4329018018981244 Valid Loss: 0.37440624389242616\n",
"Epoch: 4 Training Loss: 0.37780846084686037 Valid Loss: 0.3380796021603523\n",
"Epoch: 5 Training Loss: 0.3474148580051483 Valid Loss: 0.3164374764929426\n",
"Epoch: 6 Training Loss: 0.32569479799651085 Valid Loss: 0.2996636686172891\n",
"Epoch: 7 Training Loss: 0.3101388563184028 Valid Loss: 0.2881964070999876\n",
"Epoch: 8 Training Loss: 0.2971356455632981 Valid Loss: 0.2755763013946249\n",
"Epoch: 9 Training Loss: 0.2857853964446707 Valid Loss: 0.2682326946486818\n",
"Epoch: 10 Training Loss: 0.27580255975431583 Valid Loss: 0.2596069701174472\n"
],
"name": "stdout"
"output_type": "error",
"ename": "KeyboardInterrupt",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-36-5fa9db7f8f35>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 70\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 71\u001b[0m \u001b[0;31m## Training on 1 epoch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 72\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtrainloader\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 73\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 74\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstart_dim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 566\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrcvd_idx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreorder_dict\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 567\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreorder_dict\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrcvd_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 568\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_process_next_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 569\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 570\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatches_outstanding\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_process_next_batch\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 599\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_process_next_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 600\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrcvd_idx\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 601\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_put_indices\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 602\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_utils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mExceptionWrapper\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 603\u001b[0m \u001b[0;31m# make multiline KeyError msg readable by working around\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_put_indices\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 589\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_put_indices\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 590\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatches_outstanding\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0;36m2\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_workers\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 591\u001b[0;31m \u001b[0mindices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample_iter\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 592\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mindices\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 593\u001b[0m \u001b[0;32mreturn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/utils/data/sampler.py\u001b[0m in \u001b[0;36m__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 170\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__iter__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 171\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 172\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msampler\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 173\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 174\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/utils/data/sampler.py\u001b[0m in \u001b[0;36m<genexpr>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 94\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__iter__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 96\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindices\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandperm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 97\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 98\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__len__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
]
},
@@ -911,6 +979,150 @@
"metadata": {
"id": "uJUm3qjvBfFe",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 170
},
"outputId": "e4fdfb90-566e-43a4-e4c5-424dce012a09"
},
"source": [
"import os\n",
"import torch\n",
"from torch import optim\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torchvision import transforms\n",
"from torchvision.datasets import MNIST\n",
"from torch.utils.data import DataLoader\n",
"from torch.utils.data.sampler import SubsetRandomSampler\n",
"from torch.backends import cudnn\n",
"import numpy as np\n",
"import multiprocessing\n",
"\n",
"cudnn.benchmark = True\n",
"\n",
"num_cores = multiprocessing.cpu_count()\n",
"\n",
"# transform the raw dataset into tensors and normalize them in a fixed range\n",
"_tasks = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])\n",
"\n",
"## Load MNIST Dataset and apply transformations\n",
"mnist = MNIST(\"data\", download=True, train=True, transform=_tasks)\n",
"\n",
"## create training and validation split \n",
"split = int(0.8 * len(mnist))\n",
"index_list = list(range(len(mnist)))\n",
"train_idx, valid_idx = index_list[:split], index_list[split:]\n",
"\n",
"## create sampler objects using SubsetRandomSampler\n",
"tr_sampler = SubsetRandomSampler(train_idx)\n",
"val_sampled = SubsetRandomSampler(valid_idx)\n",
"\n",
"## create iterator objects for train and valid datasets\n",
"trainloader = DataLoader(mnist, batch_size=256, sampler=tr_sampler, num_workers=num_cores)\n",
"validloader = DataLoader(mnist, batch_size=256, sampler=val_sampler, num_workers=num_cores)\n",
"\n",
"## GPU\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"## Build class of model\n",
"class Model(nn.Module):\n",
" def __init__(self):\n",
" super(Model, self).__init__()\n",
" \n",
" ## define layers\n",
" self.conv1 = nn.Conv2d(1, 16, 3, padding=1)\n",
" self.conv2 = nn.Conv2d(16, 32, 3, padding=1)\n",
" self.conv3 = nn.Conv2d(32, 64, 3, padding=1)\n",
" self.pool = nn.MaxPool2d(kernel_size=2, stride=2)\n",
" self.linear1 = nn.Linear(64*3*3, 512)\n",
" self.linear2 = nn.Linear(512,10)\n",
" \n",
" return\n",
"\n",
" def forward(self, x):\n",
" x = self.pool(F.relu(self.conv1(x)))\n",
" x = self.pool(F.relu(self.conv2(x)))\n",
" x = self.pool(F.relu(self.conv3(x)))\n",
" x = x.view(-1,64*3*3) #torch.flatten(x, start_dim=1) ## reshaping\n",
" x = F.relu(self.linear1(x))\n",
" x = self.linear2(x)\n",
"\n",
" return x\n",
"\n",
"## create model\n",
"model = Model()\n",
"\n",
"## in case of multi gpu\n",
"if torch.cuda.device_count() > 1:\n",
" print(\"Using\", torch.cuda.device_count(), \"GPUs\")\n",
" model = nn.DataParallel(model, device_ids=[1]) # [0,1,2,3]\n",
"\n",
"## put model on gpu\n",
"model.to(device)\n",
"\n",
"## loss fucntion\n",
"loss_function = nn.CrossEntropyLoss()\n",
"## optimizer\n",
"optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=1e-6, momentum=0.9, nesterov=True)\n",
"## run for n epochs\n",
"for epoch in range(1,11):\n",
" train_loss , valid_loss = [], []\n",
"\n",
" ## train part\n",
" model.train()\n",
" for data, target in trainloader:\n",
" ## gradients acumulate. need to clear them on each example\n",
" optimizer.zero_grad()\n",
" output = model(data.to(device))\n",
" loss = loss_function(output.to(device), target.to(device))\n",
" loss.backward()\n",
" optimizer.step()\n",
" train_loss.append(loss.item())\n",
"\n",
" ## evaluation part on validation\n",
" model.eval() ##set model in evaluation mode\n",
" for data, target in validloader:\n",
" output = model(data.to(device))\n",
" loss = loss_function(output.to(device), target.to(device))\n",
" valid_loss.append(loss.item())\n",
"\n",
" print(\"Epoch:\", epoch, \"Training Loss: \", np.mean(train_loss), \"Valid Loss: \", np.mean(valid_loss))"
],
"execution_count": 53,
"outputs": [
{
"output_type": "stream",
"text": [
"Epoch: 1 Training Loss: 1.3867575062557738 Valid Loss: 0.20830699516103623\n",
"Epoch: 2 Training Loss: 0.14508407129014425 Valid Loss: 0.10438944716402825\n",
"Epoch: 3 Training Loss: 0.0902247162773571 Valid Loss: 0.07995569214541862\n",
"Epoch: 4 Training Loss: 0.06947706440622185 Valid Loss: 0.08476778730115991\n",
"Epoch: 5 Training Loss: 0.05636777369146968 Valid Loss: 0.06578631869497452\n",
"Epoch: 6 Training Loss: 0.048184677641442485 Valid Loss: 0.05531598358078206\n",
"Epoch: 7 Training Loss: 0.04294469940694089 Valid Loss: 0.05248951709809455\n",
"Epoch: 8 Training Loss: 0.038696830844546254 Valid Loss: 0.048144756558727714\n",
"Epoch: 9 Training Loss: 0.03434643990043154 Valid Loss: 0.04841103820883213\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hJZJLvwgsUrq",
"colab_type": "text"
},
"source": [
"#### Evaluation"
]
},
{
"cell_type": "code",
"metadata": {
"id": "eSRiJuEPsW0x",
"colab_type": "code",
"colab": {}
},
"source": [