Mirror of https://github.com/openai/gpt-oss.git (synced 2025-08-06 00:55:46 +03:00)
Transformers responses API (#1)
@@ -273,6 +273,7 @@ You can start this server with the following inference backends:
- `metal` — uses the metal implementation on Apple Silicon only
- `ollama` — uses the Ollama /api/generate API as an inference solution
- `vllm` — uses your installed vllm version to perform inference
- `transformers` — uses your installed transformers version to perform local inference

```bash
usage: python -m gpt_oss.responses_api.serve [-h] [--checkpoint FILE] [--port PORT] [--inference-backend BACKEND]
```
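As a concrete example (not part of the diff itself), starting the server on the new backend would presumably look like `python -m gpt_oss.responses_api.serve --checkpoint <FILE> --inference-backend transformers`, matching the usage string above; a hedged programmatic sketch of the same path appears after the new inference module below.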
gpt_oss/responses_api/inference/transformers.py (new file, 56 lines)
@@ -0,0 +1,56 @@
```python
"""
NOTE: this is not the most efficient way to use transformers. It's a simple
implementation that infers one token at a time to mimic the behavior of the
Triton implementation.
"""

import os
from typing import Callable, List

# Transformers imports
from transformers import AutoModelForCausalLM, PreTrainedModel
import torch


DEFAULT_TEMPERATURE = 0.0
TP = os.environ.get("TP", 2)


def load_model(checkpoint: str):
    """
    Serve the model directly with the Auto API.
    """
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    return model


def get_infer_next_token(model: PreTrainedModel):
    """
    Return a callable with the same shape as the original Triton implementation:

        infer_next_token(tokens: List[int], temperature: float, new_request: bool) -> int

    Implementation detail:
        - We issue a single-token generation using model.generate
        - generate handles sampling (temperature=0 => greedy, otherwise sampling).
    """

    def infer_next_token(
        tokens: List[int],
        temperature: float = DEFAULT_TEMPERATURE,
        new_request: bool = False,  # kept for interface compatibility; unused here
    ) -> int:
        tokens = torch.tensor([tokens], dtype=torch.int64, device=model.device)
        output = model.generate(
            tokens,
            max_new_tokens=1,
            do_sample=temperature != 0,
            temperature=temperature,
        )
        return output[0, -1].tolist()

    return infer_next_token


def setup_model(checkpoint: str) -> Callable[[List[int], float, bool], int]:
    model = load_model(checkpoint)
    infer_next_token = get_infer_next_token(model)
    return infer_next_token
```
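To make the interface above concrete, here is a minimal usage sketch, not part of this commit: it obtains the per-token callable via `setup_model` and drives it in a plain greedy loop. The checkpoint path, the tokenizer, the prompt, the 32-token cap, and the EOS check are all illustrative assumptions; the responses API server supplies its own tokenization and stop handling.

```python
# Minimal sketch, not part of this commit. The checkpoint path, prompt, length
# cap, and EOS check are illustrative; only setup_model / infer_next_token come
# from the module added in this diff.
from transformers import AutoTokenizer

from gpt_oss.responses_api.inference.transformers import setup_model

checkpoint = "path/to/checkpoint"  # placeholder: local path or Hub model id
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
infer_next_token = setup_model(checkpoint)  # loads the model, returns the callable

tokens = tokenizer.encode("Hello")  # prompt token ids (List[int])
for step in range(32):  # cap generation length for the sketch
    next_token = infer_next_token(tokens, temperature=0.0, new_request=(step == 0))
    tokens.append(next_token)
    if next_token == tokenizer.eos_token_id:  # may be None for some tokenizers
        break

print(tokenizer.decode(tokens))
```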
@@ -47,6 +47,8 @@ if __name__ == "__main__":
```python
        from .inference.ollama import setup_model
    elif args.inference_backend == "vllm":
        from .inference.vllm import setup_model
    elif args.inference_backend == "transformers":
        from .inference.transformers import setup_model
    else:
        raise ValueError(f"Invalid inference backend: {args.inference_backend}")
```