fix(server): fix has_position_ids (#395)

Fix #389
This commit is contained in:
OlivierDehaene
2023-06-01 11:41:35 +02:00
committed by GitHub
parent db2ebe3947
commit d69a0633be
2 changed files with 7 additions and 5 deletions

View File

@@ -496,11 +496,6 @@ class CausalLM(Model):
else:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
self.has_position_ids = (
inspect.signature(model.forward).parameters.get("position_ids", None)
is not None
)
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,

View File

@@ -1,3 +1,4 @@
import inspect
import torch
from abc import ABC, abstractmethod
@@ -29,6 +30,12 @@ class Model(ABC):
self.device = device
self.rank = rank
self.world_size = world_size
self.has_position_ids = (
inspect.signature(model.forward).parameters.get("position_ids", None)
is not None
)
self.check_initialized()
@property