import json
import math
import os
from dataclasses import dataclass

import torch
import torch.distributed as dist

from gpt_oss.torch.weights import Checkpoint

@dataclass
class ModelConfig:
    num_hidden_layers: int = 36
    num_experts: int = 128
    experts_per_token: int = 4
    vocab_size: int = 201088
    hidden_size: int = 2880
    intermediate_size: int = 2880
    swiglu_limit: float = 7.0
    head_dim: int = 64
    num_attention_heads: int = 64
    num_key_value_heads: int = 8
    sliding_window: int = 128
    initial_context_length: int = 4096
    rope_theta: float = 150000.0
    rope_scaling_factor: float = 32.0
    rope_ntk_alpha: float = 1.0
    rope_ntk_beta: float = 32.0

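# The defaults above are one full gpt-oss configuration; in practice the fields
# are overridden by the checkpoint's config.json (see Transformer.from_checkpoint
# below). A minimal sketch, assuming a local checkpoint directory "checkpoint/"
# that contains config.json:
#
#     with open("checkpoint/config.json") as f:
#         config = ModelConfig(**json.load(f))
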
class RMSNorm(torch.nn.Module):
    def __init__(
        self, num_features: int, eps: float = 1e-05, device: torch.device | None = None
    ):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.scale = torch.nn.Parameter(
            torch.ones(num_features, device=device, dtype=torch.float32)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert x.shape[-1] == self.num_features
        t, dtype = x.float(), x.dtype
        t = t * torch.rsqrt(torch.mean(t**2, dim=-1, keepdim=True) + self.eps)
        return (t * self.scale).to(dtype)

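# RMSNorm (above) normalizes activations by their root-mean-square,
# x * rsqrt(mean(x**2) + eps), then applies a learned per-feature scale; the
# computation runs in float32 and the result is cast back to the input dtype.
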
def _apply_rotary_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    cos = cos.unsqueeze(-2).to(x.dtype)
    sin = sin.unsqueeze(-2).to(x.dtype)
    x1, x2 = torch.chunk(x, 2, dim=-1)
    o1 = x1 * cos - x2 * sin
    o2 = x2 * cos + x1 * sin
    return torch.cat((o1, o2), dim=-1)

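# _apply_rotary_emb (above) uses the "half-split" rotary layout: the head
# dimension is split into two halves (x1, x2) and each pair (x1[i], x2[i]) is
# rotated by the angle whose cosine and sine are cos[i] and sin[i].
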
class RotaryEmbedding(torch.nn.Module):
    def __init__(
        self,
        head_dim: int,
        base: float,
        dtype: torch.dtype,
        initial_context_length: int = 4096,
        scaling_factor: float = 1.0,
        ntk_alpha: float = 1.0,
        ntk_beta: float = 32.0,
        device: torch.device | None = None,
    ) -> None:
        super().__init__()
        self.head_dim = head_dim
        self.base = base
        self.dtype = dtype
        self.initial_context_length = initial_context_length
        self.scaling_factor = scaling_factor
        self.ntk_alpha = ntk_alpha
        self.ntk_beta = ntk_beta
        self.device = device

    def _compute_concentration_and_inv_freq(self) -> tuple[float, torch.Tensor]:
        """See YaRN paper: https://arxiv.org/abs/2309.00071"""
        freq = self.base ** (
            torch.arange(0, self.head_dim, 2, dtype=torch.float, device=self.device)
            / self.head_dim
        )
        if self.scaling_factor > 1.0:
            concentration = (
                0.1 * math.log(self.scaling_factor) + 1.0
            )  # YaRN concentration

            d_half = self.head_dim / 2
            # NTK by parts
            low = (
                d_half
                * math.log(self.initial_context_length / (self.ntk_beta * 2 * math.pi))
                / math.log(self.base)
            )
            high = (
                d_half
                * math.log(self.initial_context_length / (self.ntk_alpha * 2 * math.pi))
                / math.log(self.base)
            )
            assert 0 < low < high < d_half - 1

            interpolation = 1.0 / (self.scaling_factor * freq)
            extrapolation = 1.0 / freq

            ramp = (
                torch.arange(d_half, dtype=torch.float32, device=freq.device) - low
            ) / (high - low)
            mask = 1 - ramp.clamp(0, 1)

            inv_freq = interpolation * (1 - mask) + extrapolation * mask
        else:
            concentration = 1.0
            inv_freq = 1.0 / freq

        return concentration, inv_freq

    def _compute_cos_sin(self, num_tokens: int):
        concentration, inv_freq = self._compute_concentration_and_inv_freq()
        t = torch.arange(num_tokens, dtype=torch.float32, device=self.device)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        cos = freqs.cos() * concentration
        sin = freqs.sin() * concentration
        return cos, sin

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        num_tokens = query.shape[0]
        cos, sin = self._compute_cos_sin(num_tokens)

        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_dim)
        query = _apply_rotary_emb(query, cos, sin)
        query = query.reshape(query_shape)

        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_dim)
        key = _apply_rotary_emb(key, cos, sin)
        key = key.reshape(key_shape)
        return query, key

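# RotaryEmbedding (above) implements YaRN-style context extension (paper linked
# in the docstring). When scaling_factor > 1, low-frequency components are
# interpolated (divided by scaling_factor) while high-frequency components are
# extrapolated unchanged, blended by the NTK-by-parts ramp between `low` and
# `high`; the cos/sin tables are additionally scaled by the logarithmic
# "concentration" factor.
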
def sdpa(Q, K, V, S, sm_scale, sliding_window=0):
    # sliding_window == 0 means no sliding window
    n_tokens, n_heads, q_mult, d_head = Q.shape
    assert K.shape == (n_tokens, n_heads, d_head)
    assert V.shape == (n_tokens, n_heads, d_head)
    K = K[:, :, None, :].expand(-1, -1, q_mult, -1)
    V = V[:, :, None, :].expand(-1, -1, q_mult, -1)
    S = S.reshape(n_heads, q_mult, 1, 1).expand(-1, -1, n_tokens, -1)
    mask = torch.triu(Q.new_full((n_tokens, n_tokens), -float("inf")), diagonal=1)
    if sliding_window > 0:
        mask += torch.tril(
            mask.new_full((n_tokens, n_tokens), -float("inf")), diagonal=-sliding_window
        )
    QK = torch.einsum("qhmd,khmd->hmqk", Q, K)
    QK *= sm_scale
    QK += mask[None, None, :, :]
    QK = torch.cat([QK, S], dim=-1)
    W = torch.softmax(QK, dim=-1)
    W = W[..., :-1]
    attn = torch.einsum("hmqk,khmd->qhmd", W, V)
    return attn.reshape(n_tokens, -1)

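# sdpa (above) is a reference attention (O(n^2), no KV cache) with learned
# "attention sinks": S contributes one extra per-head logit as an additional
# softmax column that absorbs probability mass and is then dropped
# (W[..., :-1]) before the value reduction. The q_mult axis carries the query
# heads that share each key/value head, and the additive mask encodes causal
# masking plus the optional sliding window.
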
class AttentionBlock(torch.nn.Module):
    def __init__(
        self,
        config: ModelConfig,
        layer_idx: int = 0,
        device: torch.device | None = None,
    ):
        super().__init__()
        self.head_dim = config.head_dim
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        # Only apply sliding window to every other layer
        self.sliding_window = config.sliding_window if layer_idx % 2 == 0 else 0
        self.sinks = torch.nn.Parameter(
            torch.empty(config.num_attention_heads, device=device, dtype=torch.bfloat16)
        )
        self.norm = RMSNorm(config.hidden_size, device=device)
        qkv_dim = config.head_dim * (
            config.num_attention_heads + 2 * config.num_key_value_heads
        )
        self.qkv = torch.nn.Linear(
            config.hidden_size, qkv_dim, device=device, dtype=torch.bfloat16
        )
        self.out = torch.nn.Linear(
            config.head_dim * config.num_attention_heads,
            config.hidden_size,
            device=device,
            dtype=torch.bfloat16,
        )
        self.sm_scale = 1 / math.sqrt(config.head_dim)
        self.rope = RotaryEmbedding(
            config.head_dim,
            config.rope_theta,
            torch.float32,
            initial_context_length=config.initial_context_length,
            scaling_factor=config.rope_scaling_factor,
            ntk_alpha=config.rope_ntk_alpha,
            ntk_beta=config.rope_ntk_beta,
            device=device,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        t = self.norm(x)
        qkv = self.qkv(t)
        q = qkv[:, : self.num_attention_heads * self.head_dim].contiguous()
        k = qkv[
            :,
            self.num_attention_heads
            * self.head_dim : (self.num_attention_heads + self.num_key_value_heads)
            * self.head_dim,
        ].contiguous()
        v = qkv[
            :,
            (self.num_attention_heads + self.num_key_value_heads)
            * self.head_dim : (self.num_attention_heads + 2 * self.num_key_value_heads)
            * self.head_dim,
        ].contiguous()

        q = q.view(
            -1,
            self.num_key_value_heads,
            self.num_attention_heads // self.num_key_value_heads,
            self.head_dim,
        )
        k = k.view(-1, self.num_key_value_heads, self.head_dim)
        v = v.view(-1, self.num_key_value_heads, self.head_dim)
        q, k = self.rope(q, k)
        t = sdpa(q, k, v, self.sinks, self.sm_scale, self.sliding_window)
        t = self.out(t)
        t = x + t
        return t

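# AttentionBlock (above) uses grouped-query attention: the fused qkv projection
# produces num_attention_heads query heads but only num_key_value_heads
# key/value heads (64 and 8 with the defaults), and q is reshaped to
# (tokens, kv_heads, q_per_kv, head_dim) so each group of query heads shares a
# key/value head inside sdpa. The sliding window (128 by default) is applied
# only on even-indexed layers; odd layers attend over the full context.
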
def swiglu(x, alpha: float = 1.702, limit: float = 7.0):
    x_glu, x_linear = x[..., ::2], x[..., 1::2]
    # Clamp the input values
    x_glu = x_glu.clamp(min=None, max=limit)
    x_linear = x_linear.clamp(min=-limit, max=limit)
    out_glu = x_glu * torch.sigmoid(alpha * x_glu)
    # Note we add an extra bias of 1 to the linear layer
    return out_glu * (x_linear + 1)

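# swiglu (above) expects gate and linear channels interleaved along the last
# axis, as produced by mlp1's 2 * intermediate_size outputs: even indices feed
# the sigmoid gate (alpha ~ 1.702 gives a GELU-like activation), odd indices
# the linear path. Clamping bounds the activations, and the +1 bias keeps a
# pass-through component when the linear input is near zero.
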
class MLPBlock(torch.nn.Module):
    def __init__(
        self,
        config: ModelConfig,
        device: torch.device | None = None,
    ):
        super().__init__()
        self.num_experts = config.num_experts
        self.experts_per_token = config.experts_per_token
        self.swiglu_limit = config.swiglu_limit
        self.world_size = dist.get_world_size() if dist.is_initialized() else 1
        self.norm = RMSNorm(config.hidden_size, device=device)
        self.gate = torch.nn.Linear(
            config.hidden_size, config.num_experts, device=device, dtype=torch.bfloat16
        )
        assert config.intermediate_size % self.world_size == 0
        self.mlp1_weight = torch.nn.Parameter(
            torch.empty(
                (
                    config.num_experts,
                    config.intermediate_size * 2 // self.world_size,
                    config.hidden_size,
                ),
                device=device,
                dtype=torch.bfloat16,
            )
        )
        self.mlp1_bias = torch.nn.Parameter(
            torch.empty(
                (config.num_experts, config.intermediate_size * 2 // self.world_size),
                device=device,
                dtype=torch.bfloat16,
            )
        )
        self.mlp2_weight = torch.nn.Parameter(
            torch.empty(
                (
                    config.num_experts,
                    config.hidden_size,
                    config.intermediate_size // self.world_size,
                ),
                device=device,
                dtype=torch.bfloat16,
            )
        )
        self.mlp2_bias = torch.nn.Parameter(
            torch.empty(
                (config.num_experts, config.hidden_size),
                device=device,
                dtype=torch.bfloat16,
            )
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        t = self.norm(x)
        g = self.gate(t)
        experts = torch.topk(g, k=self.experts_per_token, dim=-1, sorted=True)
        expert_weights = torch.nn.functional.softmax(experts.values, dim=1)
        expert_indices = experts.indices

        # MLP #1
        mlp1_weight = self.mlp1_weight[expert_indices, ...]
        mlp1_bias = self.mlp1_bias[expert_indices, ...]
        t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias
        t = swiglu(t, limit=self.swiglu_limit)

        # MLP #2
        mlp2_weight = self.mlp2_weight[expert_indices, ...]
        mlp2_bias = self.mlp2_bias[expert_indices, ...]
        t = torch.einsum("beck,bek->bec", mlp2_weight, t)
        if self.world_size > 1:
            dist.all_reduce(t, op=dist.ReduceOp.SUM)
        t += mlp2_bias

        # Weighted sum of experts
        t = torch.einsum("bec,be->bc", t, expert_weights)

        return x + t

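# MLPBlock (above) is a mixture-of-experts MLP: the gate scores all num_experts
# experts per token, the top experts_per_token are selected, and their outputs
# are combined with softmax weights over the selected scores. Expert weights
# are sharded over the intermediate dimension across ranks, so the partial mlp2
# products are summed with all_reduce before the bias is added (keeping the
# bias applied exactly once).
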
class TransformerBlock(torch.nn.Module):
    def __init__(
        self,
        config: ModelConfig,
        layer_idx: int,
        device: torch.device | None = None,
    ):
        super().__init__()
        self.layer_idx = layer_idx
        self.attn = AttentionBlock(config, layer_idx, device)
        self.mlp = MLPBlock(config, device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.attn(x)
        x = self.mlp(x)
        return x

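# TransformerBlock (above) chains pre-norm attention and the MoE MLP; each
# sub-block adds its own residual connection internally.
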
class Transformer(torch.nn.Module):
    def __init__(
        self,
        config: ModelConfig,
        device: torch.device | None = None,
    ):
        super().__init__()
        self.embedding = torch.nn.Embedding(
            config.vocab_size, config.hidden_size, device=device, dtype=torch.bfloat16
        )
        self.block = torch.nn.ModuleList(
            [
                TransformerBlock(config, layer_idx, device)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, device=device)
        self.unembedding = torch.nn.Linear(
            config.hidden_size,
            config.vocab_size,
            bias=False,
            device=device,
            dtype=torch.bfloat16,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.embedding(x)
        for block in self.block:
            x = block(x)
        x = self.norm(x)
        x = self.unembedding(x)
        return x

    @staticmethod
    def from_checkpoint(
        path: str, device: str | torch.device = "cuda"
    ) -> "Transformer":
        if not isinstance(device, torch.device):
            device = torch.device(device)

        config_path = os.path.join(path, "config.json")
        with open(config_path, "r") as f:
            json_config = json.load(f)
            config = ModelConfig(**json_config)

        model = Transformer(
            config=config,
            device=device,
        )
        model.eval()

        # Load weights
        my_rank = dist.get_rank() if dist.is_initialized() else 0
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        per_rank_intermediate_size = config.intermediate_size // world_size

        checkpoint = Checkpoint(path, device)

        for name, param in model.named_parameters():
            loaded_tensor = checkpoint.get(name)

            # Note: it would be more efficient to do sharding before upcasting from MXFP4,
            # but for simplicity we do it after.
            if "mlp1" in name:  # both weight and bias
                loaded_tensor = loaded_tensor[
                    :,
                    my_rank * 2
                    * per_rank_intermediate_size : (my_rank + 1) * 2
                    * per_rank_intermediate_size,
                    ...,
                ]
            elif "mlp2_weight" in name:  # only weight
                loaded_tensor = loaded_tensor[
                    ...,
                    my_rank
                    * per_rank_intermediate_size : (my_rank + 1)
                    * per_rank_intermediate_size,
                ]
            try:
                param.data.copy_(loaded_tensor)
            except:
                print(f"{name=} {param.data.shape=} {loaded_tensor.shape=}")
                raise

        return model

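# from_checkpoint (above) builds the model from the checkpoint's config.json and
# copies each named parameter. Under torch.distributed every rank keeps only its
# shard of the expert weights: a contiguous 2 * per_rank_intermediate_size slice
# of mlp1 (whose gate/linear channels are interleaved) and the matching
# per_rank_intermediate_size columns of mlp2_weight. A rough multi-GPU sketch,
# with a hypothetical script name and checkpoint path:
#
#     torchrun --nproc-per-node=4 run_model.py
#     # inside run_model.py:
#     #     dist.init_process_group("nccl")
#     #     model = Transformer.from_checkpoint(
#     #         "gpt-oss-120b/original/", device=f"cuda:{dist.get_rank()}"
#     #     )
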
class TokenGenerator:
    @torch.inference_mode()
    def __init__(self, checkpoint: str, device: torch.device):
        self.device = device
        self.model = Transformer.from_checkpoint(checkpoint, device=self.device)

    @torch.inference_mode()
    def generate(
        self,
        prompt_tokens: list[int],
        stop_tokens: list[int],
        temperature: float = 1.0,
        max_tokens: int = 0,
        return_logprobs: bool = False,
    ):
        tokens = list(prompt_tokens)
        num_generated_tokens = 0
        while max_tokens == 0 or num_generated_tokens < max_tokens:
            logits = self.model(
                torch.as_tensor(tokens, dtype=torch.int32, device=self.device)
            )[-1]
            if temperature == 0.0:
                predicted_token = torch.argmax(logits, dim=-1).item()
            else:
                probs = torch.softmax(logits * (1.0 / temperature), dim=-1)
                predicted_token = torch.multinomial(probs, num_samples=1).item()
            tokens.append(predicted_token)
            num_generated_tokens += 1

            if return_logprobs:
                logprobs = torch.log_softmax(logits, dim=-1)
                selected_logprobs = logprobs[predicted_token].item()
                yield predicted_token, selected_logprobs
            else:
                yield predicted_token

            if predicted_token in stop_tokens:
                break

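# TokenGenerator (above) re-runs the full model over the growing token list at
# every step (no KV cache), keeping this reference implementation simple at the
# cost of O(n^2) generation. A rough usage sketch, with hypothetical checkpoint
# path, prompt token ids, and stop token id:
#
#     generator = TokenGenerator("gpt-oss-120b/original/", torch.device("cuda:0"))
#     for token in generator.generate(prompt_tokens, stop_tokens=[eot_token_id],
#                                     temperature=0.0, max_tokens=32):
#         print(token)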