mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2023-08-15 01:09:35 +03:00
@@ -496,11 +496,6 @@ class CausalLM(Model):
|
||||
else:
|
||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
|
||||
self.has_position_ids = (
|
||||
inspect.signature(model.forward).parameters.get("position_ids", None)
|
||||
is not None
|
||||
)
|
||||
|
||||
super(CausalLM, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import inspect
|
||||
import torch
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
@@ -29,6 +30,12 @@ class Model(ABC):
|
||||
self.device = device
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
|
||||
self.has_position_ids = (
|
||||
inspect.signature(model.forward).parameters.get("position_ids", None)
|
||||
is not None
|
||||
)
|
||||
|
||||
self.check_initialized()
|
||||
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user