Mirror of https://github.com/exo-explore/exo.git, synced 2025-10-23 02:57:14 +03:00
fix embed_tokens for last layer in qwen models
@@ -31,7 +31,7 @@ class Qwen2Model(nn.Module):
     self.num_hidden_layers = args.num_hidden_layers
     assert self.vocab_size > 0

-    if self.args.shard.is_first_layer():
+    if self.args.shard.is_first_layer() or (self.args.shard.is_last_layer() and args.tie_word_embeddings):
       self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)

     self.layers = []
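Context for the change: Qwen models with tie_word_embeddings have no separate lm_head, so the output projection reuses the embedding matrix. Before this fix, only the first shard built embed_tokens, which left the last shard with nothing to project hidden states back to logits. Below is a minimal, illustrative sketch (not the exact exo code) of how a tied output head can be expressed in MLX; the class and parameter names are assumptions for illustration.

# Hedged sketch: why the last shard needs embed_tokens when weights are tied.
import mlx.core as mx
import mlx.nn as nn

class TiedOutputHead(nn.Module):
  def __init__(self, vocab_size: int, hidden_size: int, tie_word_embeddings: bool):
    super().__init__()
    self.tie_word_embeddings = tie_word_embeddings
    # With tied weights, the last shard must own embed_tokens,
    # since the embedding matrix doubles as the output projection.
    self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
    if not tie_word_embeddings:
      self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

  def __call__(self, h: mx.array) -> mx.array:
    if self.tie_word_embeddings:
      # nn.Embedding.as_linear applies the transposed embedding matrix,
      # mapping hidden states to vocabulary logits.
      return self.embed_tokens.as_linear(h)
    return self.lm_head(h)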