# gpt-oss/gpt_oss/torch/model.py
import json
import math
import os
from dataclasses import dataclass
import torch
import torch.distributed as dist
from gpt_oss.torch.weights import Checkpoint
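
# Model hyperparameters. The defaults below can be overridden by the values read
# from a checkpoint's config.json in Transformer.from_checkpoint.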
@dataclass
class ModelConfig:
num_hidden_layers: int = 36
num_experts: int = 128
experts_per_token: int = 4
vocab_size: int = 201088
hidden_size: int = 2880
intermediate_size: int = 2880
swiglu_limit: float = 7.0
head_dim: int = 64
num_attention_heads: int = 64
num_key_value_heads: int = 8
sliding_window: int = 128
initial_context_length: int = 4096
rope_theta: float = 150000.0
rope_scaling_factor: float = 32.0
rope_ntk_alpha: float = 1.0
rope_ntk_beta: float = 32.0
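
# Root-mean-square layer normalization with a learnable per-feature scale.
# The statistics are computed in float32 and the result is cast back to the
# input dtype.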
class RMSNorm(torch.nn.Module):
def __init__(
self, num_features: int, eps: float = 1e-05, device: torch.device | None = None
):
super().__init__()
self.num_features = num_features
self.eps = eps
self.scale = torch.nn.Parameter(
torch.ones(num_features, device=device, dtype=torch.float32)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
assert x.shape[-1] == self.num_features
t, dtype = x.float(), x.dtype
t = t * torch.rsqrt(torch.mean(t**2, dim=-1, keepdim=True) + self.eps)
return (t * self.scale).to(dtype)
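
# Apply a rotary position embedding: x is split into two halves along the last
# dimension and each (x1, x2) pair is rotated by the given cos/sin angles.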
def _apply_rotary_emb(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> torch.Tensor:
cos = cos.unsqueeze(-2).to(x.dtype)
sin = sin.unsqueeze(-2).to(x.dtype)
x1, x2 = torch.chunk(x, 2, dim=-1)
o1 = x1 * cos - x2 * sin
o2 = x2 * cos + x1 * sin
return torch.cat((o1, o2), dim=-1)
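
# Rotary position embeddings with YaRN-style NTK-by-parts interpolation, used to
# extend the context beyond initial_context_length when scaling_factor > 1.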
class RotaryEmbedding(torch.nn.Module):
def __init__(
self,
head_dim: int,
        base: float,
dtype: torch.dtype,
initial_context_length: int = 4096,
scaling_factor: float = 1.0,
ntk_alpha: float = 1.0,
ntk_beta: float = 32.0,
device: torch.device | None = None,
) -> None:
super().__init__()
self.head_dim = head_dim
self.base = base
self.dtype = dtype
self.initial_context_length = initial_context_length
self.scaling_factor = scaling_factor
self.ntk_alpha = ntk_alpha
self.ntk_beta = ntk_beta
self.device = device
    def _compute_concentration_and_inv_freq(self) -> tuple[float, torch.Tensor]:
"""See YaRN paper: https://arxiv.org/abs/2309.00071"""
freq = self.base ** (
torch.arange(0, self.head_dim, 2, dtype=torch.float, device=self.device)
/ self.head_dim
)
if self.scaling_factor > 1.0:
concentration = (
0.1 * math.log(self.scaling_factor) + 1.0
) # YaRN concentration
d_half = self.head_dim / 2
# NTK by parts
low = (
d_half
* math.log(self.initial_context_length / (self.ntk_beta * 2 * math.pi))
/ math.log(self.base)
)
high = (
d_half
* math.log(self.initial_context_length / (self.ntk_alpha * 2 * math.pi))
/ math.log(self.base)
)
assert 0 < low < high < d_half - 1
interpolation = 1.0 / (self.scaling_factor * freq)
extrapolation = 1.0 / freq
ramp = (
torch.arange(d_half, dtype=torch.float32, device=freq.device) - low
) / (high - low)
mask = 1 - ramp.clamp(0, 1)
inv_freq = interpolation * (1 - mask) + extrapolation * mask
else:
concentration = 1.0
inv_freq = 1.0 / freq
return concentration, inv_freq
    def _compute_cos_sin(self, num_tokens: int) -> tuple[torch.Tensor, torch.Tensor]:
concentration, inv_freq = self._compute_concentration_and_inv_freq()
t = torch.arange(num_tokens, dtype=torch.float32, device=self.device)
freqs = torch.einsum("i,j->ij", t, inv_freq)
cos = freqs.cos() * concentration
sin = freqs.sin() * concentration
return cos, sin
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
num_tokens = query.shape[0]
cos, sin = self._compute_cos_sin(num_tokens)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_dim)
query = _apply_rotary_emb(query, cos, sin)
query = query.reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_dim)
key = _apply_rotary_emb(key, cos, sin)
key = key.reshape(key_shape)
return query, key
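
# Scaled dot-product attention with grouped-query heads, learned per-head
# "attention sink" logits (S) appended as an extra softmax column and then
# discarded, and an optional causal sliding-window mask.
# Shapes: Q is (n_tokens, n_kv_heads, q_per_kv, head_dim);
#         K and V are (n_tokens, n_kv_heads, head_dim).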
def sdpa(Q, K, V, S, sm_scale, sliding_window=0):
# sliding_window == 0 means no sliding window
n_tokens, n_heads, q_mult, d_head = Q.shape
assert K.shape == (n_tokens, n_heads, d_head)
assert V.shape == (n_tokens, n_heads, d_head)
K = K[:, :, None, :].expand(-1, -1, q_mult, -1)
V = V[:, :, None, :].expand(-1, -1, q_mult, -1)
S = S.reshape(n_heads, q_mult, 1, 1).expand(-1, -1, n_tokens, -1)
mask = torch.triu(Q.new_full((n_tokens, n_tokens), -float("inf")), diagonal=1)
if sliding_window > 0:
mask += torch.tril(
mask.new_full((n_tokens, n_tokens), -float("inf")), diagonal=-sliding_window
)
QK = torch.einsum("qhmd,khmd->hmqk", Q, K)
QK *= sm_scale
QK += mask[None, None, :, :]
QK = torch.cat([QK, S], dim=-1)
W = torch.softmax(QK, dim=-1)
W = W[..., :-1]
attn = torch.einsum("hmqk,khmd->qhmd", W, V)
return attn.reshape(n_tokens, -1)
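
# Pre-norm attention block: fused QKV projection, rotary embeddings, grouped-query
# attention with learned sinks, output projection, and a residual connection.
# Sliding-window masking is only applied on even-indexed layers.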
class AttentionBlock(torch.nn.Module):
def __init__(
self,
config: ModelConfig,
layer_idx: int = 0,
device: torch.device | None = None,
):
super().__init__()
self.head_dim = config.head_dim
self.num_attention_heads = config.num_attention_heads
self.num_key_value_heads = config.num_key_value_heads
# Only apply sliding window to every other layer
self.sliding_window = config.sliding_window if layer_idx % 2 == 0 else 0
self.sinks = torch.nn.Parameter(
torch.empty(config.num_attention_heads, device=device, dtype=torch.bfloat16)
)
self.norm = RMSNorm(config.hidden_size, device=device)
qkv_dim = config.head_dim * (
config.num_attention_heads + 2 * config.num_key_value_heads
)
self.qkv = torch.nn.Linear(
config.hidden_size, qkv_dim, device=device, dtype=torch.bfloat16
)
self.out = torch.nn.Linear(
config.head_dim * config.num_attention_heads,
config.hidden_size,
device=device,
dtype=torch.bfloat16,
)
self.sm_scale = 1 / math.sqrt(config.head_dim)
self.rope = RotaryEmbedding(
config.head_dim,
config.rope_theta,
torch.float32,
initial_context_length=config.initial_context_length,
scaling_factor=config.rope_scaling_factor,
ntk_alpha=config.rope_ntk_alpha,
ntk_beta=config.rope_ntk_beta,
device=device,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
t = self.norm(x)
qkv = self.qkv(t)
q = qkv[:, : self.num_attention_heads * self.head_dim].contiguous()
k = qkv[
:,
self.num_attention_heads
* self.head_dim : (self.num_attention_heads + self.num_key_value_heads)
* self.head_dim,
].contiguous()
v = qkv[
:,
(self.num_attention_heads + self.num_key_value_heads)
* self.head_dim : (self.num_attention_heads + 2 * self.num_key_value_heads)
* self.head_dim,
].contiguous()
q = q.view(
-1,
self.num_key_value_heads,
self.num_attention_heads // self.num_key_value_heads,
self.head_dim,
)
k = k.view(-1, self.num_key_value_heads, self.head_dim)
v = v.view(-1, self.num_key_value_heads, self.head_dim)
q, k = self.rope(q, k)
t = sdpa(q, k, v, self.sinks, self.sm_scale, self.sliding_window)
t = self.out(t)
t = x + t
return t
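
# SwiGLU over interleaved (gate, linear) channel pairs. The gate is clamped from
# above and the linear path is clamped to [-limit, limit]; an extra bias of 1 is
# added to the linear path.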
def swiglu(x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
x_glu, x_linear = x[..., ::2], x[..., 1::2]
# Clamp the input values
x_glu = x_glu.clamp(min=None, max=limit)
x_linear = x_linear.clamp(min=-limit, max=limit)
out_glu = x_glu * torch.sigmoid(alpha * x_glu)
# Note we add an extra bias of 1 to the linear layer
return out_glu * (x_linear + 1)
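
# Pre-norm mixture-of-experts MLP. A linear gate selects experts_per_token experts
# per token, the selected experts apply a SwiGLU MLP, and their outputs are mixed
# with softmax routing weights. The intermediate dimension is sharded across ranks
# for tensor parallelism, with an all-reduce before the second bias is added.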
class MLPBlock(torch.nn.Module):
def __init__(
self,
config: ModelConfig,
device: torch.device | None = None,
):
super().__init__()
self.num_experts = config.num_experts
self.experts_per_token = config.experts_per_token
self.swiglu_limit = config.swiglu_limit
self.world_size = dist.get_world_size() if dist.is_initialized() else 1
self.norm = RMSNorm(config.hidden_size, device=device)
self.gate = torch.nn.Linear(
config.hidden_size, config.num_experts, device=device, dtype=torch.bfloat16
)
assert config.intermediate_size % self.world_size == 0
self.mlp1_weight = torch.nn.Parameter(
torch.empty(
(
config.num_experts,
config.intermediate_size * 2 // self.world_size,
config.hidden_size,
),
device=device,
dtype=torch.bfloat16,
)
)
self.mlp1_bias = torch.nn.Parameter(
torch.empty(
(config.num_experts, config.intermediate_size * 2 // self.world_size),
device=device,
dtype=torch.bfloat16,
)
)
self.mlp2_weight = torch.nn.Parameter(
torch.empty(
(
config.num_experts,
config.hidden_size,
config.intermediate_size // self.world_size,
),
device=device,
dtype=torch.bfloat16,
)
)
self.mlp2_bias = torch.nn.Parameter(
torch.empty(
(config.num_experts, config.hidden_size),
device=device,
dtype=torch.bfloat16,
)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
t = self.norm(x)
g = self.gate(t)
experts = torch.topk(g, k=self.experts_per_token, dim=-1, sorted=True)
expert_weights = torch.nn.functional.softmax(experts.values, dim=1)
expert_indices = experts.indices
# MLP #1
mlp1_weight = self.mlp1_weight[expert_indices, ...]
mlp1_bias = self.mlp1_bias[expert_indices, ...]
t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias
t = swiglu(t, limit=self.swiglu_limit)
# MLP #2
mlp2_weight = self.mlp2_weight[expert_indices, ...]
mlp2_bias = self.mlp2_bias[expert_indices, ...]
t = torch.einsum("beck,bek->bec", mlp2_weight, t)
if self.world_size > 1:
dist.all_reduce(t, op=dist.ReduceOp.SUM)
t += mlp2_bias
# Weighted sum of experts
t = torch.einsum("bec,be->bc", t, expert_weights)
return x + t
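
# One transformer layer: attention followed by the MoE MLP, each adding its own
# residual connection inside the block.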
class TransformerBlock(torch.nn.Module):
def __init__(
self,
config: ModelConfig,
layer_idx: int,
device: torch.device | None = None,
):
super().__init__()
self.layer_idx = layer_idx
self.attn = AttentionBlock(config, layer_idx, device)
self.mlp = MLPBlock(config, device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.attn(x)
x = self.mlp(x)
return x
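
# Full model: token embedding, a stack of transformer blocks, a final RMSNorm, and
# an untied unembedding projection to vocabulary logits.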
class Transformer(torch.nn.Module):
def __init__(
self,
config: ModelConfig,
device: torch.device | None = None,
):
super().__init__()
self.embedding = torch.nn.Embedding(
config.vocab_size, config.hidden_size, device=device, dtype=torch.bfloat16
)
self.block = torch.nn.ModuleList(
[
TransformerBlock(config, layer_idx, device)
for layer_idx in range(config.num_hidden_layers)
]
)
self.norm = RMSNorm(config.hidden_size, device=device)
self.unembedding = torch.nn.Linear(
config.hidden_size,
config.vocab_size,
bias=False,
device=device,
dtype=torch.bfloat16,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.embedding(x)
for block in self.block:
x = block(x)
x = self.norm(x)
x = self.unembedding(x)
return x
@staticmethod
def from_checkpoint(
path: str, device: str | torch.device = "cuda"
) -> "Transformer":
if not isinstance(device, torch.device):
device = torch.device(device)
config_path = os.path.join(path, "config.json")
with open(config_path, "r") as f:
json_config = json.load(f)
config = ModelConfig(**json_config)
model = Transformer(
config=config,
device=device,
)
model.eval()
# Load weights
my_rank = dist.get_rank() if dist.is_initialized() else 0
world_size = dist.get_world_size() if dist.is_initialized() else 1
per_rank_intermediate_size = config.intermediate_size // world_size
checkpoint = Checkpoint(path, device)
for name, param in model.named_parameters():
loaded_tensor = checkpoint.get(name)
# Note: it would be more efficient to do sharding before upcasting from MXFP4,
# but for simplicity we do it after.
if "mlp1" in name: # both weight and bias
loaded_tensor = loaded_tensor[
:,
my_rank * 2
* per_rank_intermediate_size : (my_rank + 1) * 2
* per_rank_intermediate_size,
...,
]
elif "mlp2_weight" in name: # only weight
loaded_tensor = loaded_tensor[
...,
my_rank
* per_rank_intermediate_size : (my_rank + 1)
* per_rank_intermediate_size,
]
try:
param.data.copy_(loaded_tensor)
            except Exception:
print(f"{name=} {param.data.shape=} {loaded_tensor.shape=}")
raise
return model
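
# Minimal autoregressive sampler. There is no KV cache: the full token sequence is
# re-encoded on every step, so per-step cost grows with sequence length.
# max_tokens == 0 means "generate until a stop token is produced".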
class TokenGenerator:
@torch.inference_mode()
def __init__(self, checkpoint: str, device: torch.device):
self.device = device
self.model = Transformer.from_checkpoint(checkpoint, device=self.device)
@torch.inference_mode()
    def generate(
        self,
        prompt_tokens: list[int],
        stop_tokens: list[int],
        temperature: float = 1.0,
        max_tokens: int = 0,
        return_logprobs: bool = False,
    ):
tokens = list(prompt_tokens)
num_generated_tokens = 0
while max_tokens == 0 or num_generated_tokens < max_tokens:
logits = self.model(torch.as_tensor(tokens, dtype=torch.int32, device=self.device))[-1]
if temperature == 0.0:
predicted_token = torch.argmax(logits, dim=-1).item()
else:
probs = torch.softmax(logits * (1.0 / temperature), dim=-1)
predicted_token = torch.multinomial(probs, num_samples=1).item()
tokens.append(predicted_token)
num_generated_tokens += 1
if return_logprobs:
logprobs = torch.log_softmax(logits, dim=-1)
selected_logprobs = logprobs[predicted_token].item()
yield predicted_token, selected_logprobs
else:
yield predicted_token
if predicted_token in stop_tokens:
break
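
# Minimal usage sketch (the checkpoint path, prompt tokens, and stop-token id below
# are illustrative placeholders, not values defined by this module):
#
#   generator = TokenGenerator("/path/to/checkpoint", device=torch.device("cuda"))
#   for token in generator.generate(prompt_tokens=[1, 2, 3], stop_tokens=[0], max_tokens=32):
#       print(token)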