fix embed_tokens for last layer in qwen models

Alex Cheema
2025-01-28 23:09:45 +00:00
parent af171f06fa
commit 9c1bea97e8


@@ -31,7 +31,7 @@ class Qwen2Model(nn.Module):
         self.num_hidden_layers = args.num_hidden_layers
         assert self.vocab_size > 0
-        if self.args.shard.is_first_layer():
+        if self.args.shard.is_first_layer() or (self.args.shard.is_last_layer() and args.tie_word_embeddings):
             self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
         self.layers = []
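
For context, a minimal, self-contained sketch of why a shard holding the last layer needs embed_tokens when tie_word_embeddings is set. It assumes MLX's nn.Embedding.as_linear, which MLX model implementations commonly use to reuse the embedding matrix as the output projection; the toy sizes and variable names below are illustrative, not taken from this commit.

    # Sketch: with tied word embeddings, the same weight matrix maps token
    # ids to hidden states on the way in, and hidden states to logits on
    # the way out. vocab_size/hidden_size here are arbitrary toy values.
    import mlx.core as mx
    import mlx.nn as nn

    vocab_size, hidden_size = 8, 4
    embed_tokens = nn.Embedding(vocab_size, hidden_size)

    tokens = mx.array([1, 2, 3])
    h = embed_tokens(tokens)            # (3, hidden_size): input embedding
    logits = embed_tokens.as_linear(h)  # (3, vocab_size): output projection
    print(logits.shape)

Because the tied output head is the transposed embedding matrix rather than a separate lm_head, a shard that only holds the final layers still has to instantiate embed_tokens, which is exactly what the changed condition guarantees.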