1
0
mirror of https://github.com/inzva/inzpeech.git synced 2021-06-01 09:25:07 +03:00
Files
inzpeech/ResNet/model.py
2020-10-08 23:32:08 +03:00

220 lines
7.8 KiB
Python

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class SelfAttention(nn.Module):
    """Multi-head attention over sequences of embedding vectors.

    The last dimension of each input is split into ``heads`` slices of
    ``head_dim`` features. One bias-free ``head_dim -> head_dim`` linear map
    per role (values, keys, queries) is applied to every slice, i.e. the
    projection weights are shared by all heads. The per-head results are
    concatenated and passed through a final linear layer.

    Args:
        embed_size: Size of the last dimension of the inputs and the output.
        heads: Number of attention heads; must divide ``embed_size`` evenly.

    Raises:
        AssertionError: If ``embed_size`` is not divisible by ``heads``.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask=None):
        """Attend from ``query`` positions to ``keys``/``values`` positions.

        Args:
            values: Tensor of shape (N, value_len, embed_size).
            keys: Tensor of shape (N, key_len, embed_size); ``key_len`` must
                equal ``value_len`` for the weighted sum below to be valid.
            query: Tensor of shape (N, query_len, embed_size).
            mask: Optional tensor broadcastable to
                (N, heads, query_len, key_len). Positions where it equals 0
                receive (effectively) zero attention weight.

        Returns:
            Tensor of shape (N, query_len, embed_size).
        """
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into `heads` pieces and project each piece.
        values = self.values(values.reshape(N, value_len, self.heads, self.head_dim))
        keys = self.keys(keys.reshape(N, key_len, self.heads, self.head_dim))
        queries = self.queries(query.reshape(N, query_len, self.heads, self.head_dim))

        # Dot product of every query position with every key position,
        # computed independently for each example and each head.
        # queries: (N, query_len, heads, head_dim)
        # keys:    (N, key_len, heads, head_dim)
        # energy:  (N, heads, query_len, key_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

        if mask is not None:
            # Use the most negative finite value of the working dtype rather
            # than a fixed literal: a constant such as -1e20 is not
            # representable in float16. The value stays finite, so a fully
            # masked row still softmaxes to a uniform distribution instead
            # of NaN.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        # NOTE(review): the scale is sqrt(embed_size), not sqrt(head_dim).
        # Kept unchanged so existing weights produce the same outputs;
        # confirm whether per-head scaling was intended.
        attention = torch.softmax(energy / (self.embed_size ** 0.5), dim=3)

        # attention: (N, heads, query_len, key_len)
        # values:    (N, value_len, heads, head_dim)
        # Weighted sum gives (N, query_len, heads, head_dim); the last two
        # dimensions are then flattened back into one embedding.
        out = torch.einsum("nhql,nlhd->nqhd", attention, values).reshape(
            N, query_len, self.heads * self.head_dim
        )

        # The output projection yields (N, query_len, embed_size).
        return self.fc_out(out)
class block(nn.Module):
    """Bottleneck residual unit: 1x1 -> 3x3 -> 1x1 convolutions plus a skip path.

    Each convolution is followed by batch normalisation; ReLU is applied
    after the first two stages and again after the skip connection has been
    added.

    Args:
        in_channels: Channels of the input tensor.
        intermediate_channels: Channels used by the first two convolutions.
        out_channels: Channels produced by the last convolution.
        identity_downsample: Optional module applied to the input before it
            is added to the convolutional branch. Needed whenever the branch
            changes the channel count or spatial size, so that both operands
            of the addition have the same shape.
        stride: Stride of the 3x3 convolution.
    """

    def __init__(
        self, in_channels, intermediate_channels, out_channels, identity_downsample=None, stride=1
    ):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, intermediate_channels, kernel_size=1, stride=1, padding=0
        )
        self.bn1 = nn.BatchNorm2d(intermediate_channels)
        self.conv2 = nn.Conv2d(
            intermediate_channels,
            intermediate_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
        )
        self.bn2 = nn.BatchNorm2d(intermediate_channels)
        self.conv3 = nn.Conv2d(
            intermediate_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.identity_downsample = identity_downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(branch(x) + skip(x))``.

        Args:
            x: Input feature map with ``in_channels`` channels.

        Returns:
            Feature map with ``out_channels`` channels.
        """
        # No copy is needed for the skip path: nothing below writes to the
        # input in place (every layer returns a new tensor and the ReLU is
        # not an in-place one), so a plain reference is enough.
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        if self.identity_downsample is not None:
            identity = self.identity_downsample(identity)

        # In-place add targets the fresh output of bn3, never the input.
        out += identity
        return self.relu(out)
class Net(nn.Module):
    """ResNet-style convolutional backbone, self-attention, then an MLP head.

    The backbone is a 7x7 stem convolution with max pooling followed by four
    stages of residual blocks. The resulting feature map is turned into a
    sequence with one token per spatial position, passed through
    ``SelfAttention``, averaged over the sequence and fed to three linear
    layers.

    Args:
        block: Residual block class, called as
            ``block(in_channels, intermediate_channels, out_channels,
            identity_downsample, stride)``.
        layers: Sequence of four ints, the number of blocks in each stage.
        image_channels: Number of channels of the input images.
        num_classes: Size of the final linear layer's output.
        expansion: Factor between a stage's intermediate width
            (64/128/256/512) and its output channel count.
    """

    def __init__(self, block, layers, image_channels, num_classes, expansion):
        super().__init__()
        self.in_channels = 64
        self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # The four residual stages; only the first keeps the spatial size.
        self.layer1 = self._make_layer(
            block, layers[0], intermediate_channels=64, out_channels=64 * expansion, stride=1
        )
        self.layer2 = self._make_layer(
            block, layers[1], intermediate_channels=128, out_channels=128 * expansion, stride=2
        )
        self.layer3 = self._make_layer(
            block, layers[2], intermediate_channels=256, out_channels=256 * expansion, stride=2
        )
        self.layer4 = self._make_layer(
            block, layers[3], intermediate_channels=512, out_channels=512 * expansion, stride=2
        )

        self.attention = SelfAttention(heads=4, embed_size=512 * expansion)

        # Averages over the whole token axis whatever its length. The module
        # receives a 3-D (N, tokens, embed) tensor and therefore treats it as
        # an unbatched (C, H, W) input: H (tokens) is reduced to 1 while
        # W (embed) is left untouched (`None`). For exactly 20 tokens this
        # matches the former fixed AvgPool2d((20, 1)).
        self.avgpool = nn.AdaptiveAvgPool2d((1, None))

        self.fc1 = nn.Linear(512 * expansion, 512 * expansion // 2)
        self.fc2 = nn.Linear(512 * expansion // 2, 512 * expansion // 4)
        self.fc3 = nn.Linear(512 * expansion // 4, num_classes)

    def forward(self, x):
        """Run the full network.

        Args:
            x: Image batch; assumed to be (N, image_channels, H, W) since it
                is fed straight into a ``Conv2d``.

        Returns:
            Tensor of shape (N, num_classes). Values are non-negative because
            a ReLU follows the last linear layer.
        """
        # Convolutional backbone.
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Turn the (N, C, H, W) feature map into a sequence of H*W tokens
        # with C features each. The channel axis has to be moved to the end
        # *before* reshaping: a bare reshape only reinterprets memory and
        # would mix channel and spatial data inside every token.
        # NOTE: this token layout differs from a plain reshape, so weights
        # trained with a plain reshape will not reproduce their old outputs.
        batch, channels, height, width = x.shape
        x = x.permute(0, 2, 3, 1).reshape(batch, height * width, channels)

        # Attention layer (self-attention: same tensor for all three roles).
        x = self.attention(x, x, x)

        # Mean over the token axis -> (N, 1, C), then flatten to (N, C).
        x = self.avgpool(x)
        x = x.reshape(x.shape[0], -1)

        # Fully connected head.
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        # NOTE(review): the ReLU on the final layer clips negative outputs.
        # Kept for backward compatibility; confirm it is intended if these
        # values are consumed as classification logits.
        x = self.relu(self.fc3(x))
        return x

    def _make_layer(self, block, num_residual_blocks, intermediate_channels, out_channels, stride):
        """Build one stage of ``num_residual_blocks`` residual blocks.

        Only the first block of the stage uses ``stride`` and, if needed, a
        projection on the skip path. Updates ``self.in_channels`` to
        ``out_channels`` so the next stage knows its input width.

        Returns:
            ``nn.Sequential`` containing the blocks of the stage.
        """
        identity_downsample = None

        # If the first block halves the spatial size (stride=2) or changes
        # the channel count, the skip connection must be projected so that it
        # can be added to the block's output.
        if stride != 1 or self.in_channels != out_channels:
            identity_downsample = nn.Sequential(
                nn.Conv2d(
                    self.in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=stride,
                ),
                nn.BatchNorm2d(out_channels),
            )

        stage = [
            block(self.in_channels, intermediate_channels, out_channels, identity_downsample, stride)
        ]
        self.in_channels = out_channels

        # The remaining blocks map out_channels -> intermediate_channels ->
        # out_channels with stride 1 (e.g. 256 -> 64 -> 256 in the first
        # stage), so their skip connections need no projection.
        for _ in range(num_residual_blocks - 1):
            stage.append(block(self.in_channels, intermediate_channels, out_channels))

        return nn.Sequential(*stage)
def Net_ResNet50(img_channel=3, num_classes=1000):
    """Build a ``Net`` with stage depths (3, 4, 6, 3) and expansion 4.

    Args:
        img_channel: Number of channels of the input images.
        num_classes: Output size of the final linear layer.
    """
    stage_depths = [3, 4, 6, 3]
    model = Net(block, stage_depths, img_channel, num_classes, expansion=4)
    return model
def Net_ResNet101(img_channel=3, num_classes=1000):
    """Build a ``Net`` with stage depths (3, 4, 23, 3) and expansion 4.

    Args:
        img_channel: Number of channels of the input images.
        num_classes: Output size of the final linear layer.
    """
    stage_depths = [3, 4, 23, 3]
    model = Net(block, stage_depths, img_channel, num_classes, expansion=4)
    return model
def Net_ResNet152(img_channel=3, num_classes=1000):
    """Build a ``Net`` with stage depths (3, 8, 36, 3) and expansion 4.

    Args:
        img_channel: Number of channels of the input images.
        num_classes: Output size of the final linear layer.
    """
    stage_depths = [3, 8, 36, 3]
    model = Net(block, stage_depths, img_channel, num_classes, expansion=4)
    return model