Mirror of https://github.com/gmihaila/ml_things.git, synced 2021-10-04 01:29:04 +03:00

Commit: Created using Colaboratory

pytorch_nn.ipynb (402 changed lines)
@@ -684,10 +684,10 @@
"metadata": {
"id": "vfdi1HANvi7f",
"colab_type": "code",
"outputId": "277e02e6-88b8-4d58-cc6e-1bb35b3a65e3",
"outputId": "b0dec8ee-d09e-4d4f-f54d-3cf6f28c5c32",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 425
"height": 1000
}
},
"source": [
@@ -792,107 +792,175 @@
" \n",
" print(\"Epoch:\", epoch, \"Training Loss: \", np.mean(train_loss), \"Valid Loss: \", np.mean(valid_loss))"
],
"execution_count": 0,
"execution_count": 36,
"outputs": [
{
"output_type": "stream",
"text": [
" 0%| | 0/9912422 [00:00<?, ?it/s]"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz\n"
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n",
"torch.Size([256, 10])\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"9920512it [00:01, 9805960.78it/s] \n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Extracting data/MNIST/raw/train-images-idx3-ubyte.gz\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
" 0%| | 0/28881 [00:00<?, ?it/s]"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"32768it [00:00, 130353.83it/s] \n",
" 0%| | 0/1648877 [00:00<?, ?it/s]"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"1654784it [00:00, 2125030.13it/s] \n",
"0it [00:00, ?it/s]"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"8192it [00:00, 49184.55it/s] \n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n",
"Processing...\n",
"Done!\n",
"Epoch: 1 Training Loss: 1.4125140620038865 Valid Loss: 0.7013939695155367\n",
"Epoch: 2 Training Loss: 0.5760371469436808 Valid Loss: 0.447847954770352\n",
"Epoch: 3 Training Loss: 0.4329018018981244 Valid Loss: 0.37440624389242616\n",
"Epoch: 4 Training Loss: 0.37780846084686037 Valid Loss: 0.3380796021603523\n",
"Epoch: 5 Training Loss: 0.3474148580051483 Valid Loss: 0.3164374764929426\n",
"Epoch: 6 Training Loss: 0.32569479799651085 Valid Loss: 0.2996636686172891\n",
"Epoch: 7 Training Loss: 0.3101388563184028 Valid Loss: 0.2881964070999876\n",
"Epoch: 8 Training Loss: 0.2971356455632981 Valid Loss: 0.2755763013946249\n",
"Epoch: 9 Training Loss: 0.2857853964446707 Valid Loss: 0.2682326946486818\n",
"Epoch: 10 Training Loss: 0.27580255975431583 Valid Loss: 0.2596069701174472\n"
],
"name": "stdout"
"output_type": "error",
"ename": "KeyboardInterrupt",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-36-5fa9db7f8f35>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 70\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 71\u001b[0m \u001b[0;31m## Training on 1 epoch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 72\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtrainloader\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 73\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 74\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstart_dim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 566\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrcvd_idx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreorder_dict\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 567\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreorder_dict\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrcvd_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 568\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_process_next_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 569\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 570\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatches_outstanding\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_process_next_batch\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 599\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_process_next_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 600\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrcvd_idx\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 601\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_put_indices\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 602\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_utils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mExceptionWrapper\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 603\u001b[0m \u001b[0;31m# make multiline KeyError msg readable by working around\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_put_indices\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 589\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_put_indices\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 590\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatches_outstanding\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0;36m2\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_workers\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 591\u001b[0;31m \u001b[0mindices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample_iter\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 592\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mindices\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 593\u001b[0m \u001b[0;32mreturn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/utils/data/sampler.py\u001b[0m in \u001b[0;36m__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 170\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__iter__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 171\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 172\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msampler\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 173\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 174\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/utils/data/sampler.py\u001b[0m in \u001b[0;36m<genexpr>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 94\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__iter__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 96\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindices\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandperm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 97\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 98\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__len__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
]
},
@@ -911,6 +979,150 @@
"metadata": {
"id": "uJUm3qjvBfFe",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 170
},
"outputId": "e4fdfb90-566e-43a4-e4c5-424dce012a09"
},
"source": [
"import os\n",
"import torch\n",
"from torch import optim\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torchvision import transforms\n",
"from torchvision.datasets import MNIST\n",
"from torch.utils.data import DataLoader\n",
"from torch.utils.data.sampler import SubsetRandomSampler\n",
"from torch.backends import cudnn\n",
"import numpy as np\n",
"import multiprocessing\n",
"\n",
"cudnn.benchmark = True\n",
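"# benchmark mode lets cuDNN auto-tune and cache the fastest convolution algorithms for fixed-size inputs\n",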
"\n",
"num_cores = multiprocessing.cpu_count()\n",
"\n",
"# transform the raw dataset into tensors and normalize them in a fixed range\n",
"_tasks = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])\n",
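"# with mean 0.5 and std 0.5 the [0, 1] pixel values are rescaled to [-1, 1]\n",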
"\n",
"## Load MNIST Dataset and apply transformations\n",
"mnist = MNIST(\"data\", download=True, train=True, transform=_tasks)\n",
"\n",
"## create training and validation split\n",
"split = int(0.8 * len(mnist))\n",
"index_list = list(range(len(mnist)))\n",
"train_idx, valid_idx = index_list[:split], index_list[split:]\n",
"\n",
"## create sampler objects using SubsetRandomSampler\n",
"tr_sampler = SubsetRandomSampler(train_idx)\n",
"val_sampler = SubsetRandomSampler(valid_idx)\n",
"\n",
"## create iterator objects for train and valid datasets\n",
"trainloader = DataLoader(mnist, batch_size=256, sampler=tr_sampler, num_workers=num_cores)\n",
"validloader = DataLoader(mnist, batch_size=256, sampler=val_sampler, num_workers=num_cores)\n",
"\n",
"## GPU\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"## Build class of model\n",
"class Model(nn.Module):\n",
"    def __init__(self):\n",
"        super(Model, self).__init__()\n",
"\n",
"        ## define layers\n",
"        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)\n",
"        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)\n",
"        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)\n",
"        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)\n",
"        self.linear1 = nn.Linear(64*3*3, 512)\n",
"        self.linear2 = nn.Linear(512, 10)\n",
"\n",
"        return\n",
"\n",
"    def forward(self, x):\n",
"        x = self.pool(F.relu(self.conv1(x)))\n",
"        x = self.pool(F.relu(self.conv2(x)))\n",
"        x = self.pool(F.relu(self.conv3(x)))\n",
"        x = x.view(-1, 64*3*3)  # reshaping: three 2x2 max-pools reduce 28x28 to 3x3; torch.flatten(x, start_dim=1) also works\n",
"        x = F.relu(self.linear1(x))\n",
"        x = self.linear2(x)\n",
"\n",
"        return x\n",
"\n",
"## create model\n",
"model = Model()\n",
"\n",
"## in case of multi gpu\n",
"if torch.cuda.device_count() > 1:\n",
"    print(\"Using\", torch.cuda.device_count(), \"GPUs\")\n",
"    model = nn.DataParallel(model, device_ids=[1])  # [0,1,2,3]\n",
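"    # note: device_ids=[1] keeps DataParallel on a single GPU; list several ids (e.g. [0, 1, 2, 3]) to split each batch across them\n",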
"\n",
"## put model on gpu\n",
"model.to(device)\n",
"\n",
"## loss function\n",
"loss_function = nn.CrossEntropyLoss()\n",
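"# CrossEntropyLoss applies log-softmax internally, so the model's forward() returns raw logits\n",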
"## optimizer\n",
"optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=1e-6, momentum=0.9, nesterov=True)\n",
"## run for n epochs\n",
"for epoch in range(1, 11):\n",
"    train_loss, valid_loss = [], []\n",
"\n",
"    ## train part\n",
"    model.train()\n",
"    for data, target in trainloader:\n",
"        ## gradients accumulate, so clear them before each batch\n",
"        optimizer.zero_grad()\n",
"        output = model(data.to(device))\n",
"        loss = loss_function(output.to(device), target.to(device))\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        train_loss.append(loss.item())\n",
"\n",
"    ## evaluation part on validation\n",
"    model.eval()  ## set model in evaluation mode\n",
"    for data, target in validloader:\n",
"        output = model(data.to(device))\n",
"        loss = loss_function(output.to(device), target.to(device))\n",
"        valid_loss.append(loss.item())\n",
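"        # note: gradients are still tracked in this validation loop; wrapping it in torch.no_grad() would save memory\n",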
"\n",
"    print(\"Epoch:\", epoch, \"Training Loss: \", np.mean(train_loss), \"Valid Loss: \", np.mean(valid_loss))"
],
"execution_count": 53,
"outputs": [
{
"output_type": "stream",
"text": [
"Epoch: 1 Training Loss: 1.3867575062557738 Valid Loss: 0.20830699516103623\n",
"Epoch: 2 Training Loss: 0.14508407129014425 Valid Loss: 0.10438944716402825\n",
"Epoch: 3 Training Loss: 0.0902247162773571 Valid Loss: 0.07995569214541862\n",
"Epoch: 4 Training Loss: 0.06947706440622185 Valid Loss: 0.08476778730115991\n",
"Epoch: 5 Training Loss: 0.05636777369146968 Valid Loss: 0.06578631869497452\n",
"Epoch: 6 Training Loss: 0.048184677641442485 Valid Loss: 0.05531598358078206\n",
"Epoch: 7 Training Loss: 0.04294469940694089 Valid Loss: 0.05248951709809455\n",
"Epoch: 8 Training Loss: 0.038696830844546254 Valid Loss: 0.048144756558727714\n",
"Epoch: 9 Training Loss: 0.03434643990043154 Valid Loss: 0.04841103820883213\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hJZJLvwgsUrq",
"colab_type": "text"
},
"source": [
"#### Evaluation"
]
},
{
"cell_type": "code",
"metadata": {
"id": "eSRiJuEPsW0x",
"colab_type": "code",
"colab": {}
},
"source": [