Mirror of https://github.com/dragen1860/TensorFlow-2.x-Tutorials.git
Synced 2021-05-12 18:32:23 +03:00

Commit: GPT fixes
@@ -1,4 +1,5 @@
 import os
+import math
 import tensorflow as tf
 import numpy as np
 from tensorflow import keras
@@ -8,13 +9,19 @@ np.random.seed(22)
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
 assert tf.__version__.startswith('2.')
+
+def gelu(x):
+    return 0.5 * x * (1 + tf.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * tf.pow(x, 3))))
+
+def swish(x):
+    return x * tf.sigmoid(x)
 
 class namespace():
     pass
 args = namespace()
 args.n_ctx = 512
 args.n_embd = 768
 args.n_head = 12
-args.n_lar = 12
+args.n_layer = 12
 args.embd_pdrop = 0.1
 args.attn_pdrop = 0.1
 args.resid_pdrop = 0.1
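Side note, not part of the commit: the hand-rolled `namespace` class works, but the standard library's `argparse.Namespace` builds the same kind of attribute bag in one expression. A sketch of the equivalent config (same GPT-1-sized hyperparameters as the args block above):

```python
from argparse import Namespace

# same values as the args.* assignments in the diff
args = Namespace(n_ctx=512, n_embd=768, n_head=12, n_layer=12,
                 embd_pdrop=0.1, attn_pdrop=0.1, resid_pdrop=0.1)
```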
@@ -26,12 +33,7 @@ args.b1 = 0.9
 args.b2 = 0.999
 args.e = 1e-8
 args.n_valid = 374
-def gelu(x):
-    return 0.5 * x * (1 + tf.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * tf.pow(x, 3))))
-
-def swish(x):
-    return x * tf.sigmoid(x)
 args.afn = gelu
 
 zeros_init = keras.initializers.Zeros()
 ones_init = keras.initializers.Ones()
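Reviewer note, not part of the diff: the `gelu` moved above is the standard tanh approximation of the exact GELU, 0.5 * x * (1 + erf(x / sqrt(2))). A quick hedged sanity check of how close the two are, assuming TF 2.x:

```python
import math
import tensorflow as tf

def gelu_approx(x):
    # tanh-based approximation, as in the diff
    return 0.5 * x * (1 + tf.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * tf.pow(x, 3))))

def gelu_exact(x):
    # exact GELU via the error function
    return 0.5 * x * (1 + tf.math.erf(x / math.sqrt(2.0)))

x = tf.linspace(-3.0, 3.0, 7)
print(tf.reduce_max(tf.abs(gelu_approx(x) - gelu_exact(x))).numpy())  # on the order of 1e-4
```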
@@ -94,7 +96,7 @@ class Attention(keras.Model):
     def _attn(self, q, k, v):
         w = tf.matmul(q, k)
         if self.scale:
-            w = w / tf.sqrt(v.shape[-1])
+            w = w / tf.sqrt(tf.cast(v.shape[-1], tf.float32))
         # self.b may be larger than w, so we need to crop it
         b = self.b[:, :, :w.shape[-2], :w.shape[-1]]
         w = w * b + 1e-9 * (1 - b)
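Why the cast is needed (reviewer note, not in the diff): `v.shape[-1]` is a plain integer, and `tf.sqrt` has no integer kernel, so the old line fails when the op is traced; casting to float32 first fixes it. A minimal sketch:

```python
import tensorflow as tf

v = tf.zeros([1, 12, 64, 64])                      # [batch, heads, seq, head_dim]
# tf.sqrt(v.shape[-1])                             # fails: sqrt is not defined for integer inputs
scale = tf.sqrt(tf.cast(v.shape[-1], tf.float32))  # cast to float first
print(scale.numpy())                               # 8.0
```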
@@ -103,7 +105,7 @@ class Attention(keras.Model):
 
     def merge_heads(self, x):
         x = tf.transpose(x, [0,2,1,3])
-        new_x_shape = list(x.shape[:-2]) + [x.shape[-2]*x.shape[1]]
+        new_x_shape = list(x.shape[:-2]) + [x.shape[-2]*x.shape[-1]]
         return tf.reshape(x, new_x_shape) # in openai implem: fct merge_states
 
     def split_heads(self, x, k=False):
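Reviewer note, not in the diff: after the transpose, `x` is `[batch, seq, heads, head_dim]`, so the last two axes are the ones to merge. The old `x.shape[1]` picked up the sequence length instead of `head_dim` and only happened to work when the two were equal. A hedged sketch of the fixed reshape with a sequence length that exposes the bug:

```python
import tensorflow as tf

x = tf.random.normal([2, 12, 100, 64])   # [batch, heads, seq, head_dim], seq != head_dim
x = tf.transpose(x, [0, 2, 1, 3])        # -> [batch, seq, heads, head_dim]
new_shape = list(x.shape[:-2]) + [x.shape[-2] * x.shape[-1]]
merged = tf.reshape(x, new_shape)
print(merged.shape)                      # (2, 100, 768); the old x.shape[1] would ask for 12*100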
@@ -129,12 +131,12 @@ class Attention(keras.Model):
 
 class MLP(keras.Model):
 
-    def __init__(self, n_state=3072, cfg): # n_state=3072 (4*n_embd)
+    def __init__(self, n_state=3072, cfg=args): # n_state=3072 (4*n_embd)
         super(MLP, self).__init__()
         nx = cfg.n_embd
         self.c_fc = Conv1D(n_state, 1, nx)
         self.c_proj = Conv1D(nx, 1, n_state)
-        self.act = ACT_FNS[cfg.afn]
+        self.act = cfg.afn
         self.dropout = keras.layers.Dropout(cfg.resid_pdrop)
 
     def call(self, x):
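Two separate bugs are fixed in this hunk (reviewer note): the old signature is a hard SyntaxError, since a parameter without a default cannot follow one with a default; and `cfg.afn` already holds the activation function itself, while `ACT_FNS` (a string-keyed table in the OpenAI reference code) is never defined in this file. A minimal illustration of both points:

```python
import tensorflow as tf

def ok(a=1, b=None):        # every parameter after a defaulted one needs a default too
    return a, b

# def bad(a=1, b): ...      # SyntaxError: non-default argument follows default argument

afn = lambda x: x * tf.sigmoid(x)   # a function object, like args.afn = gelu above
act = afn                           # use it directly; no ACT_FNS["..."] string lookup
print(act(tf.constant([1.0])))      # ~0.731
```

Binding the module-level `args` as the default for `cfg` does couple these classes to the global config; passing a config explicitly still works since the parameter remains overridable.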
@@ -145,8 +147,7 @@ class MLP(keras.Model):
 
 class Block(keras.Model):
 
-    def __init__(self, n_ctx, cfg, scale=False):
-    def __init__(self, n_ctx, cfg, scale=False):
+    def __init__(self, n_ctx=512, cfg=args, scale=False):
         super(Block, self).__init__()
         nx = cfg.n_embd
         self.attn = Attention(nx, n_ctx, cfg, scale)
@@ -159,12 +160,12 @@ class Block(keras.Model):
         n = self.ln_1(x + a)
         m = self.mlp(n)
         h = self.ln_2(n + m)
-        return
+        return h
 
 
 class TransformerModel(keras.Model):
 
-    def __init__(self, cfg, vocab=40990, n_ctx=512):
+    def __init__(self, cfg=args, vocab=40990, n_ctx=512):
         super(TransformerModel, self).__init__()
         self.vocab = vocab
         self.embed = keras.layers.Embedding(vocab, cfg.n_embd)
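Reviewer note, not in the diff: the bare `return` made `Block.call` yield `None`, so the first stacked block silently fed `None` into the next one; `return h` propagates the hidden states. A tiny repro sketch:

```python
def call_broken(x):
    h = x + 1
    return              # implicit None

def call_fixed(x):
    h = x + 1
    return h

print(call_broken(1))   # None -> the next block's ops would choke on None
print(call_fixed(1))    # 2
```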
@@ -174,7 +175,7 @@ class TransformerModel(keras.Model):
     def call(self, x):
         x = tf.reshape(x, [-1,x.shape[-2],x.shape[-1]])
         e = self.drop(self.embed(x))
-        # Add the position information to input embeddings
+        # add the position information to input embeddings
         h = tf.reduce_sum(e, 2)
         for block in self.h:
             h = block(h)
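How the position information actually enters (reviewer note): each input position carries a `(token_id, position_id)` pair, both looked up in the same embedding table, and `tf.reduce_sum(e, 2)` collapses that pair axis, adding the token and position vectors. A hedged sketch, assuming the usual OpenAI-GPT convention that position ids start right after the token vocabulary (40990 - 512 = 40478 fits that here):

```python
import tensorflow as tf
from tensorflow import keras

vocab, n_ctx, n_embd = 40990, 512, 768
embed = keras.layers.Embedding(vocab, n_embd)

tokens = tf.random.uniform([4, n_ctx], maxval=40478, dtype=tf.int32)   # token ids
positions = tf.tile(tf.range(40478, 40478 + n_ctx)[None, :], [4, 1])  # position ids
x = tf.stack([tokens, positions], axis=-1)   # [batch, n_ctx, 2]
e = embed(x)                                 # [batch, n_ctx, 2, n_embd]
h = tf.reduce_sum(e, 2)                      # token embedding + position embedding
print(h.shape)                               # (4, 512, 768)
```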