Transformers responses API (#1)

Lysandre Debut
2025-08-05 19:02:16 +02:00
committed by GitHub
parent 0106ce5ba3
commit a601a63cdc
3 changed files with 59 additions and 0 deletions

@@ -273,6 +273,7 @@ You can start this server with the following inference backends:
- `metal` — uses the metal implementation on Apple Silicon only
- `ollama` — uses the Ollama /api/generate API as an inference solution
- `vllm` — uses your installed vllm version to perform inference
- `transformers` — uses your installed transformers version to perform local inference
```bash
usage: python -m gpt_oss.responses_api.serve [-h] [--checkpoint FILE] [--port PORT] [--inference-backend BACKEND]
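# Example (illustrative): start the server with the transformers backend.
# The checkpoint value can be a local path or a Hugging Face model id accepted
# by from_pretrained; the id and port below are placeholders, not prescribed here.
python -m gpt_oss.responses_api.serve --inference-backend transformers --checkpoint openai/gpt-oss-20b --port 8000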

@@ -0,0 +1,56 @@
"""
NOTE: this is not the most efficient way to use transformers. It's a simple implementation that infers
one token at a time to mimic the behavior of the Triton implementation.
"""
import os
from typing import Callable, List
# Transformers imports
from transformers import AutoModelForCausalLM, PreTrainedModel
import torch
DEFAULT_TEMPERATURE = 0.0
TP = os.environ.get("TP", 2)
def load_model(checkpoint: str):
"""
Serve the model directly with the Auto API.
"""
model = AutoModelForCausalLM.from_pretrained(
checkpoint,
torch_dtype=torch.bfloat16,
device_map="auto",
)
return model
def get_infer_next_token(model: PreTrainedModel):
"""
Return a callable with the same shape as the original triton implementation:
infer_next_token(tokens: List[int], temperature: float, new_request: bool) -> int
Implementation detail:
- We issue a single-token generation with using model.generate
- generate handles sampling (temperature=0 => greedy, otherwise, sampling).
"""
def infer_next_token(
tokens: List[int],
temperature: float = DEFAULT_TEMPERATURE,
new_request: bool = False, # kept for interface compatibility; unused here
) -> int:
tokens = torch.tensor([tokens], dtype=torch.int64, device=model.device)
output = model.generate(tokens, max_new_tokens=1, do_sample=temperature != 0, temperature=temperature)
return output[0, -1].tolist()
return infer_next_token
def setup_model(checkpoint: str) -> Callable[[List[int], float, bool], int]:
model = load_model(checkpoint)
infer_next_token = get_infer_next_token(model)
return infer_next_token
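For reference, here is a minimal sketch of how the callable returned by `setup_model` can drive a simple decode loop; the checkpoint id, prompt token ids, and loop length are illustrative placeholders, not part of this commit:

```python
from gpt_oss.responses_api.inference.transformers import setup_model

# Load the model once and build the single-token inference callable.
infer_next_token = setup_model("openai/gpt-oss-20b")  # placeholder checkpoint id

tokens = [1, 2, 3]  # placeholder prompt token ids
for _ in range(16):
    # temperature=0.0 => do_sample=False, so generate() returns the argmax token.
    next_token = infer_next_token(tokens, temperature=0.0, new_request=False)
    tokens.append(next_token)

print(tokens)
```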

@@ -47,6 +47,8 @@ if __name__ == "__main__":
from .inference.ollama import setup_model
elif args.inference_backend == "vllm":
from .inference.vllm import setup_model
elif args.inference_backend == "transformers":
from .inference.transformers import setup_model
else:
raise ValueError(f"Invalid inference backend: {args.inference_backend}")