mirror of
https://github.com/mmaithani/data-science.git
synced 2022-04-24 02:56:41 +03:00
697 lines
210 KiB
Plaintext
697 lines
210 KiB
Plaintext
{
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0,
|
|
"metadata": {
|
|
"colab": {
|
|
"name": "PyTorch loss function.ipynb",
|
|
"provenance": [],
|
|
"authorship_tag": "ABX9TyOHWuVYQKJLQBi0Q39BkxNs",
|
|
"include_colab_link": true
|
|
},
|
|
"kernelspec": {
|
|
"name": "python3",
|
|
"display_name": "Python 3"
|
|
}
|
|
},
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "view-in-github",
|
|
"colab_type": "text"
|
|
},
|
|
"source": [
|
|
"<a href=\"https://colab.research.google.com/github/mmaithani/data-science/blob/main/PyTorch_ALL_loss_function.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "E6FUMb7jjKID"
|
|
},
|
|
"source": [
|
|
"# PyTorch loss functions\r\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "IWYesbBj2P2_"
|
|
},
|
|
"source": [
|
|
""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "r3SUVFgDXUzn"
|
|
},
|
|
"source": [
|
|
"## Mean Absolute Error (L1 Loss)\r\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "HLdeIDMTaMja"
|
|
},
|
|
"source": [
|
|
"### Algorithmic way of finding the loss\r\n",
|
|
"### without the PyTorch module"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "H_IA8ZaKaRLk",
|
|
"outputId": "cfc6f35c-c6d7-4534-9874-68cea8a8a3d5"
|
|
},
|
|
"source": [
|
|
"import numpy as np\r\n",
|
|
"y_pred = np.array([0.000, 0.100, 0.200])\r\n",
|
|
"y_true = np.array([0.000, 0.200, 0.250])\r\n",
|
|
"# Defining Mean Absolute Error loss function\r\n",
|
|
"def mae(pred, true):\r\n",
|
|
" # Find absolute difference\r\n",
|
|
" differences = pred - true\r\n",
|
|
" absolute_differences = np.absolute(differences)\r\n",
|
|
" # find the mean of the absolute differences\r\n",
|
|
" mean_absolute_error = absolute_differences.mean()\r\n",
|
|
" return mean_absolute_error\r\n",
|
|
"mae_value = mae(y_pred, y_true)\r\n",
|
|
"print (\"MAE error is: \" + str(mae_value))"
|
|
],
|
|
"execution_count": 9,
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"MAE error is: 0.049999999999999996\n"
|
|
],
|
|
"name": "stdout"
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "U8PmW9F8aTlx"
|
|
},
|
|
"source": [
|
|
"### with the PyTorch module"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "0l45I7YOaRtc"
|
|
},
|
|
"source": [
|
|
"mae_loss = nn.L1Loss()"
|
|
],
|
|
"execution_count": 10,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "8hFmTQetaqpx",
|
|
"outputId": "2d6f56a3-3ac5-4988-8e8d-d44d06ad5593"
|
|
},
|
|
"source": [
|
|
"import torch\r\n",
|
|
"mae_loss = torch.nn.L1Loss()\r\n",
|
|
"input = torch.tensor(y_pred)\r\n",
|
|
"target = torch.tensor(y_true)\r\n",
|
|
"output = mae_loss(input, target)\r\n",
|
|
"print(output)"
|
|
],
|
|
"execution_count": 11,
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"tensor(0.0500, dtype=torch.float64)\n"
|
|
],
|
|
"name": "stdout"
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "hswYi06WXY_Y"
|
|
},
|
|
"source": [
|
|
"## Mean-Squared Error (L2 Loss)\r\n",
|
|
"Mean of the squared element-wise differences between prediction and target (`nn.MSELoss`)."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "o8b2cA7uZ65T",
|
|
"outputId": "c32316ec-0a91-4150-b336-c3b0379451d4"
|
|
},
|
|
"source": [
|
|
"input = torch.randn(3, 4, requires_grad=True)\r\n",
|
|
"target = torch.randn(3, 4)\r\n",
|
|
"\r\n",
|
|
"mse_loss = nn.MSELoss()\r\n",
|
|
"\r\n",
|
|
"output = mse_loss(input, target)\r\n",
|
|
"output.backward()\r\n",
|
|
"\r\n",
|
|
"print('input -: ', input)\r\n",
|
|
"print('target -: ', target)\r\n",
|
|
"print('output -: ', output)"
|
|
],
|
|
"execution_count": 12,
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"input -: tensor([[-0.8130, -0.1722, 2.1259, 0.9007],\n",
|
|
" [ 0.4301, 0.2543, -0.3947, -1.8088],\n",
|
|
" [ 0.0463, 1.6871, -0.4065, -0.5540]], requires_grad=True)\n",
|
|
"target -: tensor([[ 0.9704, -0.1731, 0.0868, 0.8792],\n",
|
|
" [-0.6950, -1.9831, -0.0518, -0.1137],\n",
|
|
" [ 0.5052, 0.6071, 1.5943, -0.8278]])\n",
|
|
"output -: tensor(1.8380, grad_fn=<MseLossBackward>)\n"
|
|
],
|
|
"name": "stdout"
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "GjwtuBsEXZTh"
|
|
},
|
|
"source": [
|
|
"## Binary Cross Entropy\r\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "yw35tlCZfy66"
|
|
},
|
|
"source": [
|
|
"### algorithm"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "fb6Ev6VMffV-",
|
|
"outputId": "cf989d97-7ce9-46d4-aa16-1b3efb1e0090"
|
|
},
|
|
"source": [
|
|
"import numpy as np\r\n",
|
|
"y_pred = np.array([0.1580, 0.4137, 0.2285])\r\n",
|
|
"y_true = np.array([0.0, 1.0, 0.0]) #2 labels: (0,1)\r\n",
|
|
"def BCE(y_pred, y_true):\r\n",
|
|
" total_bce_loss = np.sum(-y_true * np.log(y_pred) - (1 - y_true) * np.log(1 - y_pred))\r\n",
|
|
" # Getting the mean BCE loss\r\n",
|
|
" num_of_samples = y_pred.shape[0]\r\n",
|
|
" mean_bce_loss = total_bce_loss / num_of_samples\r\n",
|
|
" \r\n",
|
|
" return mean_bce_loss\r\n",
|
|
"\r\n",
|
|
"bce_value = BCE(y_pred, y_true)\r\n",
|
|
"print (\"BCE error is: \" + str(bce_value))"
|
|
],
|
|
"execution_count": 13,
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"BCE error is: 0.43800269247783435\n"
|
|
],
|
|
"name": "stdout"
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "g8QbY_A0f1e2"
|
|
},
|
|
"source": [
|
|
"### PyTorch implementation"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "HxylWO3Bf48h",
|
|
"outputId": "a59ec09f-813a-4ea7-cdcd-0246b12ffad8"
|
|
},
|
|
"source": [
|
|
"bce_loss = torch.nn.BCELoss()\r\n",
|
|
"sigmoid = torch.nn.Sigmoid() # Ensuring inputs are between 0 and 1\r\n",
|
|
"input = torch.tensor(y_pred)\r\n",
|
|
"target = torch.tensor(y_true)\r\n",
|
|
"output = bce_loss(input, target)\r\n",
|
|
"output"
|
|
],
|
|
"execution_count": 14,
|
|
"outputs": [
|
|
{
|
|
"output_type": "execute_result",
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor(0.4380, dtype=torch.float64)"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"execution_count": 14
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "kMqXnUB8tg_w"
|
|
},
|
|
"source": [
|
|
"## BCEWithLogitsLoss (`nn.BCEWithLogitsLoss`)\r\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "3zfpYHGjjHqP",
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"outputId": "76415ba8-ec05-4ecf-a9be-04a82d1313c9"
|
|
},
|
|
"source": [
|
|
"import torch\r\n",
|
|
"target = torch.ones([10, 64], dtype=torch.float32) # 64 classes, batch size = 10\r\n",
|
|
"output = torch.full([10, 64], 1.5) # A prediction (logit)\r\n",
|
|
"pos_weight = torch.ones([64]) # All weights are equal to 1\r\n",
|
|
"criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)\r\n",
|
|
"criterion(output, target) # -log(sigmoid(1.5))"
|
|
],
|
|
"execution_count": 15,
|
|
"outputs": [
|
|
{
|
|
"output_type": "execute_result",
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor(0.2014)"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"execution_count": 15
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "i1lc02NDYZOs"
|
|
},
|
|
"source": [
|
|
"## Negative Log-Likelihood Loss\r\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "HpPl1N89fI4Y",
|
|
"outputId": "766fde07-a229-45ce-8202-eba0721859e6"
|
|
},
|
|
"source": [
|
|
"input = torch.randn(3, 5, requires_grad=True)\r\n",
|
|
"# every element in target should have value(0 <= value < C)\r\n",
|
|
"target = torch.tensor([1, 0, 4])\r\n",
|
|
"\r\n",
|
|
"m = nn.LogSoftmax(dim=1)\r\n",
|
|
"nll_loss = nn.NLLLoss()\r\n",
|
|
"output = nll_loss(m(input), target)\r\n",
|
|
"output.backward()\r\n",
|
|
"\r\n",
|
|
"print('input -: ', input)\r\n",
|
|
"print('target -: ', target)\r\n",
|
|
"print('output -: ', output)"
|
|
],
|
|
"execution_count": 16,
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"input -: tensor([[ 0.2100, -0.2934, -0.2368, -0.9120, -0.8677],\n",
|
|
" [ 0.1300, -0.4204, 0.5999, 0.2263, -0.0318],\n",
|
|
" [ 1.0562, -0.6507, -2.2783, 0.2079, 0.2805]], requires_grad=True)\n",
|
|
"target -: tensor([1, 0, 4])\n",
|
|
"output -: tensor(1.5756, grad_fn=<NllLossBackward>)\n"
|
|
],
|
|
"name": "stdout"
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "OLIpmLghtjwe"
|
|
},
|
|
"source": [
|
|
"## PoissonNLLLoss"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "t7MB8eCbtVC-",
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"outputId": "733e1ec0-e091-4a05-8f7a-f61016e505ed"
|
|
},
|
|
"source": [
|
|
"import torch.nn as nn\r\n",
|
|
"loss = nn.PoissonNLLLoss()\r\n",
|
|
"log_input = torch.randn(5, 2, requires_grad=True)\r\n",
|
|
"target = torch.randn(5, 2)\r\n",
|
|
"output = loss(log_input, target)\r\n",
|
|
"output.backward()\r\n",
|
|
"output"
|
|
],
|
|
"execution_count": 17,
|
|
"outputs": [
|
|
{
|
|
"output_type": "execute_result",
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor(1.0439, grad_fn=<MeanBackward0>)"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"execution_count": 17
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "gXZXFzi5YgJM"
|
|
},
|
|
"source": [
|
|
"## Cross-Entropy Loss\r\n",
|
|
" "
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "2mYMGOXthk8p",
|
|
"outputId": "00dc8cf0-d433-4d2a-f734-0fd6f0076028"
|
|
},
|
|
"source": [
|
|
"input = torch.randn(3, 5, requires_grad=True)\r\n",
|
|
"target = torch.empty(3, dtype=torch.long).random_(5)\r\n",
|
|
"\r\n",
|
|
"cross_entropy_loss = nn.CrossEntropyLoss()\r\n",
|
|
"output = cross_entropy_loss(input, target)\r\n",
|
|
"output.backward()\r\n",
|
|
"\r\n",
|
|
"print('input: ', input)\r\n",
|
|
"print('target: ', target)\r\n",
|
|
"print('output: ', output)"
|
|
],
|
|
"execution_count": 18,
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"input: tensor([[-0.5641, 2.0046, 0.4709, -1.3824, 0.3271],\n",
|
|
" [ 0.2807, -0.8588, -0.6625, 1.1710, -1.1822],\n",
|
|
" [-0.3820, 0.2075, 0.6264, -0.5623, -0.6328]], requires_grad=True)\n",
|
|
"target: tensor([3, 2, 3])\n",
|
|
"output: tensor(2.7897, grad_fn=<NllLossBackward>)\n"
|
|
],
|
|
"name": "stdout"
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "O9dASCuufNvP"
|
|
},
|
|
"source": [
|
|
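"## Hinge Embedding Loss\r\n",
"\r\n",
"Note: `nn.HingeEmbeddingLoss` expects every target value to be `1` or `-1`; the `torch.randn` targets in the next cell do not satisfy this, so treat that cell as an API-call demo only."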
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "xfStGspphvwy",
|
|
"outputId": "6dc99ac5-10f6-4e74-a3ac-55d624547c89"
|
|
},
|
|
"source": [
|
|
"input = torch.randn(3, 5, requires_grad=True)\r\n",
|
|
"target = torch.randn(3, 5)\r\n",
|
|
"\r\n",
|
|
"hinge_loss = nn.HingeEmbeddingLoss()\r\n",
|
|
"output = hinge_loss(input, target)\r\n",
|
|
"output.backward()\r\n",
|
|
"\r\n",
|
|
"print('input -: ', input)\r\n",
|
|
"print('target -: ', target)\r\n",
|
|
"print('output -: ', output)"
|
|
],
|
|
"execution_count": 19,
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"input -: tensor([[-1.3021, 0.1646, -0.6010, 1.1430, 0.0962],\n",
|
|
" [ 0.2079, 1.8048, -0.9333, 1.1201, -1.1432],\n",
|
|
" [ 0.1606, -0.0297, 0.6047, 0.1355, 0.5362]], requires_grad=True)\n",
|
|
"target -: tensor([[-1.5627, -1.5915, 0.5986, 0.4758, -1.1109],\n",
|
|
" [-1.9742, 1.1048, -0.5299, -0.0454, 0.2371],\n",
|
|
" [-0.0415, -0.3526, 0.9375, 0.6387, 0.6531]])\n",
|
|
"output -: tensor(1.0712, grad_fn=<MeanBackward0>)\n"
|
|
],
|
|
"name": "stdout"
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "tKJ72jDeYonF"
|
|
},
|
|
"source": [
|
|
"## Margin Ranking Loss\r\n",
|
|
"\r\n",
|
|
"\r\n",
|
|
"```\r\n",
|
|
"torch.nn.MarginRankingLoss\r\n",
|
|
"```\r\n",
|
|
"\r\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "JyN0cc1GiEZE",
|
|
"outputId": "7fdba3d7-00ac-49af-dd0f-c36f1f829cf1"
|
|
},
|
|
"source": [
|
|
"first_input = torch.randn(3, requires_grad=True)\r\n",
|
|
"Second_input = torch.randn(3, requires_grad=True)\r\n",
|
|
"target = torch.randn(3).sign()\r\n",
|
|
"\r\n",
|
|
"ranking_loss = nn.MarginRankingLoss()\r\n",
|
|
"output = ranking_loss(first_input, Second_input, target)\r\n",
|
|
"output.backward()\r\n",
|
|
"\r\n",
|
|
"print('input one: ', first_input)\r\n",
|
|
"print('input two: ', Second_input)\r\n",
|
|
"print('target: ', target)\r\n",
|
|
"print('output: ', output)"
|
|
],
|
|
"execution_count": 20,
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"input one: tensor([ 0.0219, -0.7005, -1.2160], requires_grad=True)\n",
|
|
"input two: tensor([-0.4255, 0.3859, -0.7394], requires_grad=True)\n",
|
|
"target: tensor([-1., -1., 1.])\n",
|
|
"output: tensor(0.3080, grad_fn=<MeanBackward0>)\n"
|
|
],
|
|
"name": "stdout"
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "c0--iv1UidaY"
|
|
},
|
|
"source": [
|
|
"## Triplet Margin Loss Function"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "H-VV1XSNihPT",
|
|
"outputId": "d5add53c-177b-4e43-d88f-087a938dd4f6"
|
|
},
|
|
"source": [
|
|
"anchor = torch.randn(100, 128, requires_grad=True)\r\n",
|
|
"positive = torch.randn(100, 128, requires_grad=True)\r\n",
|
|
"negative = torch.randn(100, 128, requires_grad=True)\r\n",
|
|
"\r\n",
|
|
"triplet_margin_loss = nn.TripletMarginLoss(margin=1.0, p=2)\r\n",
|
|
"output = triplet_margin_loss(anchor, positive, negative)\r\n",
|
|
"output.backward()\r\n",
|
|
"\r\n",
|
|
"print('anchors -: ', anchor)\r\n",
|
|
"print('positive -: ', positive)\r\n",
|
|
"print('negative -: ', negative)\r\n",
|
|
"print('output -: ', output)"
|
|
],
|
|
"execution_count": 21,
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"anchors -: tensor([[-0.0458, -0.2617, 0.8472, ..., 1.7588, -1.3604, 0.7182],\n",
|
|
" [ 1.2925, 1.7845, -1.0670, ..., -0.4664, -0.6562, -0.5562],\n",
|
|
" [-0.4970, -0.5046, -1.8275, ..., -1.7325, -0.0448, 0.2471],\n",
|
|
" ...,\n",
|
|
" [ 0.5414, 0.1786, 0.8064, ..., -0.6623, -0.2556, -0.2867],\n",
|
|
" [-1.1563, -1.7103, 2.2845, ..., 1.0123, -0.3839, -1.3699],\n",
|
|
" [ 0.7297, -0.1043, -0.1768, ..., 0.3457, -0.8843, -0.0626]],\n",
|
|
" requires_grad=True)\n",
|
|
"positive -: tensor([[ 7.4869e-01, 1.3499e+00, -1.4480e+00, ..., -7.6688e-01,\n",
|
|
" -1.7461e-03, -7.0950e-01],\n",
|
|
" [-1.0364e+00, 1.0784e+00, 1.4848e+00, ..., -6.4932e-01,\n",
|
|
" -2.4223e-01, 4.4354e-01],\n",
|
|
" [ 1.9670e-01, -8.3027e-01, 2.5105e-01, ..., -9.0814e-01,\n",
|
|
" -2.3587e-01, 1.3626e+00],\n",
|
|
" ...,\n",
|
|
" [-2.3753e+00, -1.0636e+00, 3.1268e+00, ..., 1.5887e-01,\n",
|
|
" 6.0285e-02, 3.2817e-01],\n",
|
|
" [ 1.0319e+00, -9.9035e-01, -9.8707e-01, ..., 1.2975e+00,\n",
|
|
" 6.1644e-01, 1.2362e+00],\n",
|
|
" [ 4.4754e-01, 1.7472e+00, -1.0116e+00, ..., 5.9146e-01,\n",
|
|
" -3.1294e-01, -1.2864e-01]], requires_grad=True)\n",
|
|
"negative -: tensor([[-1.5941, 0.7201, -0.8380, ..., 1.4464, 1.9402, 1.0685],\n",
|
|
" [-1.3552, 0.9982, -0.2235, ..., -0.6102, 0.4565, -0.7907],\n",
|
|
" [ 1.1297, -0.0303, 1.2934, ..., 1.0800, 1.0632, 2.3885],\n",
|
|
" ...,\n",
|
|
" [-1.2044, -0.2218, -1.7082, ..., 0.7270, -1.3822, 0.9942],\n",
|
|
" [-0.4531, -1.3416, -0.9141, ..., 1.0345, -0.5356, -0.9907],\n",
|
|
" [-0.1342, -1.6461, 1.2896, ..., 0.1225, 1.3387, -0.8353]],\n",
|
|
" requires_grad=True)\n",
|
|
"output -: tensor(1.0123, grad_fn=<MeanBackward0>)\n"
|
|
],
|
|
"name": "stdout"
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "6KX_8Ym8YjJl"
|
|
},
|
|
"source": [
|
|
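"## Kullback-Leibler divergence\r\n",
"\r\n",
"Note: `nn.KLDivLoss` expects `input` to hold log-probabilities and `target` to hold probabilities; the raw `torch.randn` tensors in the next cell are neither, so the printed value is not a true KL divergence."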
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "RpnNQ0TpXHLz",
|
|
"outputId": "2d7b6d1a-1e3f-4e7a-b277-deade7439f72"
|
|
},
|
|
"source": [
|
|
"input = torch.randn(2, 3, requires_grad=True)\r\n",
|
|
"target = torch.randn(2, 3)\r\n",
|
|
"\r\n",
|
|
"kld_loss = nn.KLDivLoss(reduction = 'batchmean')\r\n",
|
|
"output = kld_loss(input, target)\r\n",
|
|
"output.backward()\r\n",
|
|
"\r\n",
|
|
"print('input tensor: ', input)\r\n",
|
|
"print('target tensor: ', target)\r\n",
|
|
"print('Loss: ', output)"
|
|
],
|
|
"execution_count": 22,
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"text": [
|
|
"input tensor: tensor([[-0.8168, 1.6042, -0.7599],\n",
|
|
" [-0.6457, 0.3002, -0.7881]], requires_grad=True)\n",
|
|
"target tensor: tensor([[ 1.1857, 0.6820, -0.5791],\n",
|
|
" [-0.3623, -0.7202, -0.0946]])\n",
|
|
"Loss: tensor(-0.0923, grad_fn=<DivBackward0>)\n"
|
|
],
|
|
"name": "stdout"
|
|
}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"metadata": {
|
|
"id": "swOodmepx_Su"
|
|
},
|
|
"source": [
|
|
""
|
|
],
|
|
"execution_count": null,
|
|
"outputs": []
|
|
}
|
|
]
|
|
} |