import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask=None):
        # Get the number of training examples
        N = query.shape[0]

        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into self.heads different pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        query = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)  # (N, value_len, heads, head_dim)
        keys = self.keys(keys)  # (N, key_len, heads, head_dim)
        queries = self.queries(query)  # (N, query_len, heads, head_dim)

        # einsum performs the batched query-key matrix multiplication for every
        # head in a single call (equivalent to a transpose followed by bmm).
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        # queries shape: (N, query_len, heads, head_dim)
        # keys shape: (N, key_len, heads, head_dim)
        # energy shape: (N, heads, query_len, key_len)

        # Mask padded indices so their attention weights become 0
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Normalize the energy values over the key dimension so that they sum
        # to 1, dividing by a scaling factor for numerical stability.
        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)
        # attention shape: (N, heads, query_len, key_len)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )
        # attention shape: (N, heads, query_len, key_len)
        # values shape: (N, value_len, heads, head_dim)
        # out after the matrix multiply: (N, query_len, heads, head_dim), then
        # we reshape and flatten the last two dimensions.

        out = self.fc_out(out)
        # The linear layer keeps the shape: (N, query_len, embed_size)

        return out
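

# Illustrative usage sketch (not part of the original module): a quick shape
# check for SelfAttention. The helper name, embed_size=512, heads=4, and the
# dummy sequence length of 20 below are assumptions chosen for illustration.
def _self_attention_shape_check():
    attn = SelfAttention(embed_size=512, heads=4)
    x = torch.randn(2, 20, 512)  # (batch, sequence length, embed_size)
    out = attn(x, x, x)  # self-attention: values, keys and queries are the same tensor
    assert out.shape == (2, 20, 512)  # output keeps the (N, query_len, embed_size) shape
    return out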


class block(nn.Module):
    # Bottleneck residual block: 1x1 reduce -> 3x3 (possibly strided) -> 1x1 expand,
    # with an optional identity_downsample on the skip connection.
    def __init__(
        self, in_channels, intermediate_channels, out_channels, identity_downsample=None, stride=1
    ):
        super(block, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels, intermediate_channels, kernel_size=1, stride=1, padding=0
        )
        self.bn1 = nn.BatchNorm2d(intermediate_channels)
        self.conv2 = nn.Conv2d(
            intermediate_channels,
            intermediate_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
        )
        self.bn2 = nn.BatchNorm2d(intermediate_channels)
        self.conv3 = nn.Conv2d(
            intermediate_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.identity_downsample = identity_downsample
        self.stride = stride

    def forward(self, x):
        identity = x.clone()

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.bn3(x)

        if self.identity_downsample is not None:
            identity = self.identity_downsample(identity)

        x += identity
        x = self.relu(x)
        return x
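

# Illustrative usage sketch (not part of the original module): shape check for a
# single bottleneck block with channel expansion 64 -> 256. The helper name and
# the 32x32 dummy feature map are assumptions chosen for illustration; the 1x1
# downsample mirrors what Net._make_layer builds when the channel count changes.
def _block_shape_check():
    downsample = nn.Sequential(
        nn.Conv2d(64, 256, kernel_size=1, stride=1),
        nn.BatchNorm2d(256),
    )
    b = block(64, 64, 256, identity_downsample=downsample, stride=1)
    x = torch.randn(2, 64, 32, 32)
    out = b(x)
    assert out.shape == (2, 256, 32, 32)  # channels expanded, spatial size unchanged
    return out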


class Net(nn.Module):
    def __init__(self, block, layers, image_channels, num_classes, expansion):
        super(Net, self).__init__()
        self.in_channels = 64
        self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Essentially the entire ResNet backbone is in these 4 layers below
        self.layer1 = self._make_layer(
            block, layers[0], intermediate_channels=64, out_channels=64 * expansion, stride=1
        )
        self.layer2 = self._make_layer(
            block, layers[1], intermediate_channels=128, out_channels=128 * expansion, stride=2
        )
        self.layer3 = self._make_layer(
            block, layers[2], intermediate_channels=256, out_channels=256 * expansion, stride=2
        )
        self.layer4 = self._make_layer(
            block, layers[3], intermediate_channels=512, out_channels=512 * expansion, stride=2
        )

        # Self-attention over the spatial positions of the final feature map
        self.attention = SelfAttention(heads=4, embed_size=512 * expansion)

        # Average over the attended positions; the (20, 1) kernel assumes the
        # backbone leaves exactly 20 spatial positions (e.g. a 4x5 feature map)
        self.avgpool = nn.AvgPool2d((20, 1))

        self.fc1 = nn.Linear(512 * expansion, 512 * expansion // 2)
        self.fc2 = nn.Linear(512 * expansion // 2, 512 * expansion // 4)
        self.fc3 = nn.Linear(512 * expansion // 4, num_classes)

    def forward(self, x):
        # ResNet backbone
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Turn the (N, C, H, W) feature map into a sequence of H*W positions:
        # flatten the spatial dimensions, then move channels last so each
        # position carries a C-dimensional embedding. (A plain reshape would
        # scramble channels and positions, so flatten + permute is used.)
        x = x.flatten(2).permute(0, 2, 1)  # (N, H*W, C)

        # Attention layer
        x = self.attention(x, x, x)
        x = self.avgpool(x)

        # FC layers
        x = x.reshape(x.shape[0], -1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))

        return x

    def _make_layer(self, block, num_residual_blocks, intermediate_channels, out_channels, stride):
        identity_downsample = None
        layers = []

        # If we halve the spatial size (e.g. 56x56 -> 28x28 with stride=2) or the
        # number of channels changes, the identity (skip connection) has to be
        # adapted so it can be added to the block's output.
        if stride != 1 or self.in_channels != out_channels:
            identity_downsample = nn.Sequential(
                nn.Conv2d(
                    self.in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=stride,
                ),
                nn.BatchNorm2d(out_channels),
            )

        layers.append(
            block(self.in_channels, intermediate_channels, out_channels, identity_downsample, stride)
        )

        self.in_channels = out_channels

        # For example, in the first ResNet stage the 256 input channels are reduced
        # to 64 in the intermediate layers and expanded back to 256. The remaining
        # blocks need no identity downsample, since stride is 1 and the channel
        # count does not change.
        for i in range(num_residual_blocks - 1):
            layers.append(block(self.in_channels, intermediate_channels, out_channels))

        return nn.Sequential(*layers)
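

# Illustrative sketch (not part of the original module): verify what _make_layer
# builds for the first ResNet-50 stage. The helper name and the image_channels /
# num_classes values below are assumptions chosen for illustration.
def _make_layer_structure_check():
    net = Net(block, [3, 4, 6, 3], image_channels=1, num_classes=10, expansion=4)
    stage1 = net.layer1
    assert len(stage1) == 3  # three bottleneck blocks
    assert stage1[0].identity_downsample is not None  # 64 -> 256 channels needs a 1x1 conv
    assert stage1[1].identity_downsample is None  # 256 -> 256, plain residual add
    return stage1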


def Net_ResNet50(img_channel=3, num_classes=1000):
    return Net(block, [3, 4, 6, 3], img_channel, num_classes, expansion=4)


def Net_ResNet101(img_channel=3, num_classes=1000):
    return Net(block, [3, 4, 23, 3], img_channel, num_classes, expansion=4)


def Net_ResNet152(img_channel=3, num_classes=1000):
    return Net(block, [3, 8, 36, 3], img_channel, num_classes, expansion=4)
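

# Illustrative smoke test (not part of the original module): run a dummy batch
# through the ResNet-50 variant. The (20, 1) average-pool kernel requires the
# backbone to leave exactly 20 spatial positions after its 32x downsampling, so
# a 128x160 input (-> 4x5 feature map) is assumed here; the single input channel
# and 10 classes are likewise illustrative values.
if __name__ == "__main__":
    model = Net_ResNet50(img_channel=1, num_classes=10)
    dummy = torch.randn(2, 1, 128, 160)  # (batch, channels, height, width)
    logits = model(dummy)
    print(logits.shape)  # expected: torch.Size([2, 10])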