Mirror of https://github.com/dragen1860/TensorFlow-2.x-Tutorials.git (synced 2021-05-12 18:32:23 +03:00).
			
		
		
		
	
		
			
				
	
	
		
			87 lines
		
	
	
		
			3.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			87 lines
		
	
	
		
			3.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import  tensorflow as tf
 | |
| 
 | |
def scaled_dot_product_attention(q, k, v, mask):
    """Compute scaled dot-product attention.

    output = softmax(q @ k^T / sqrt(d_k) + mask * -1e9) @ v

    Args:
        q: queries, shape (..., seq_len_q, depth).
        k: keys, shape (..., seq_len_k, depth).
        v: values, shape (..., seq_len_k, depth_v); its second-to-last axis
            must match k's for the final matmul.
        mask: None, or a tensor broadcastable to (..., seq_len_q, seq_len_k).
            Entries equal to 1 (or True) get -1e9 added to their logits, so
            they receive (near-)zero attention weight.

    Returns:
        A tuple (output, attention_weights) with shapes
        (..., seq_len_q, depth_v) and (..., seq_len_q, seq_len_k).
    """
    matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)

    # Scale by sqrt(d_k). Cast to the logits' own dtype rather than a fixed
    # float32 so float16/bfloat16/float64 inputs don't raise a dtype mismatch.
    dk = tf.cast(tf.shape(k)[-1], matmul_qk.dtype)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

    # Add the (large negative) mask to the scaled logits. The mask is cast to
    # the logits' dtype so int/bool masks and reduced-precision logits work.
    # NOTE(review): in float16, -1e9 saturates to -inf, so a row that is
    # entirely masked produces NaN weights instead of a uniform distribution.
    if mask is not None:
        scaled_attention_logits += (
            tf.cast(mask, scaled_attention_logits.dtype) * -1e9)

    # softmax is normalized on the last axis (seq_len_k) so that the scores
    # add up to 1.
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)

    output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

    return output, attention_weights
 | |
| 
 | |
| 
 | |
| # ## Multi-head attention
 | |
| 
 | |
| # In[ ]:
 | |
| 
 | |
| 
 | |
class MultiHeadAttention(tf.keras.layers.Layer):
    """Multi-head attention layer.

    Queries, keys and values are each projected to ``d_model`` features by
    their own Dense layer. The projections are split into ``num_heads`` heads
    of ``depth = d_model // num_heads`` features and attended independently
    with ``scaled_dot_product_attention``. The heads are then concatenated
    back to ``d_model`` features and passed through a final Dense layer.
    """

    def __init__(self, d_model, num_heads, **kwargs):
        """Create the projection sub-layers.

        Args:
            d_model: width of the projected queries/keys/values and of the
                layer output.
            num_heads: number of attention heads; must be a positive divisor
                of ``d_model``.
            **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).

        Raises:
            ValueError: if ``num_heads`` is not positive or does not evenly
                divide ``d_model``.
        """
        super(MultiHeadAttention, self).__init__(**kwargs)
        # Explicit check instead of `assert`: asserts vanish under `python -O`,
        # and num_heads == 0 would otherwise surface as a ZeroDivisionError.
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                'd_model ({}) must be divisible by a positive num_heads ({})'
                .format(d_model, num_heads))

        self.num_heads = num_heads
        self.d_model = d_model
        self.depth = d_model // num_heads  # features per head

        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)

        self.dense = tf.keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, depth).

        Transpose the result such that the shape is
        (batch_size, num_heads, seq_len, depth).
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            v: values, (batch_size, seq_len_v, features); seq_len_v must
                equal seq_len_k.
            k: keys, (batch_size, seq_len_k, features).
            q: queries, (batch_size, seq_len_q, features). Its batch size is
                used to reshape q, k and v, so all three must share it.
            mask: None, or a tensor broadcastable to
                (batch_size, num_heads, seq_len_q, seq_len_k); see
                ``scaled_dot_product_attention`` for its semantics.

        Returns:
            A tuple (output, attention_weights) with shapes
            (batch_size, seq_len_q, d_model) and
            (batch_size, num_heads, seq_len_q, seq_len_k).
        """
        batch_size = tf.shape(q)[0]

        q = self.wq(q)  # (batch_size, seq_len, d_model)
        k = self.wk(k)  # (batch_size, seq_len, d_model)
        v = self.wv(v)  # (batch_size, seq_len, d_model)

        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)

        # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
        # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
        scaled_attention, attention_weights = scaled_dot_product_attention(
            q, k, v, mask)

        # Undo split_heads: move heads next to depth, then merge them.
        scaled_attention = tf.transpose(
            scaled_attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)

        concat_attention = tf.reshape(
            scaled_attention, (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)

        output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)

        return output, attention_weights
 | |
| 
 | |
def main():
    """Run MultiHeadAttention once on random inputs and print the result shapes."""
    temp_mha = MultiHeadAttention(d_model=512, num_heads=8)
    # Keys/values need not have d_model features: the layer's wk/wv Dense
    # projections map their last axis (768 here) to d_model.
    y = tf.random.uniform((1, 60, 768))  # (batch_size, encoder_sequence, 768)
    q = tf.random.uniform((1, 60, 512))  # (batch_size, seq_len_q, d_model)
    out, attn = temp_mha(y, k=y, q=q, mask=None)
    # The original bare `out.shape, attn.shape` was a notebook leftover that
    # displayed nothing when run as a script; print the shapes explicitly.
    # Expected: (1, 60, 512) and (1, 8, 60, 60).
    print(out.shape, attn.shape)
 | |
| 
 | |
| 
 | |
| 
 | |
if __name__ == '__main__':
    # Only run the shape demo when this file is executed directly,
    # not when it is imported for its attention utilities.
    main()
