Mirror of https://github.com/open-thought/reasoning-gym.git, synced 2025-10-09 13:40:09 +03:00.

tutorial(training): Add a minimal example with trl (#473)

* v0
* 2 gpu setup
* improve parsing from yaml
* update yaml dataset example
* remove restriction on flash attn
* more comments
* first version of the readme
* pin torch
* simplify requirements
* just flash attn
* use set env instead
* simpler set env
* readme
* add wandb project to setup
* update template
* update model id
* post init to capture the config and weight
* extract metadata
* update config
* update dataset config
* move env for wandb project
* pre-commit
* remove qwen-math from training
* more instructions
* unused import
* remove trl old
* warmup ratio
* warmup ratio
* change model id
* change model_id
* add info about CUDA_VISIBLE_DEVICES

Committed via GitHub. Parent: 49f3821098. Commit: 56ce2e79a7.
@@ -1,32 +1,56 @@
# Training with TRL

This directory contains examples using the [TRL (Transformer Reinforcement Learning) library](https://github.com/huggingface/trl) to fine-tune language models with reinforcement learning techniques.

Training stack:

- TRL for reinforcement learning training
- Accelerate (with DeepSpeed) for distributed training
- vLLM for rollouts

## GRPO Example

The main example demonstrates using GRPO (Group Relative Policy Optimization) to fine-tune a language model on reasoning tasks from reasoning-gym. It includes:

- Custom reward functions for answer accuracy and format compliance
- Integration with reasoning-gym datasets
- Configurable training parameters via YAML config
- Wandb logging and model checkpointing
- Evaluation on held-out test sets

## Setup

This tutorial uses CUDA 11.8, Python 3.10, and PyTorch 2.5.1.

We also assume that you have 2 GPUs on your machine, the last of which is used for vLLM rollouts.

If you have more than 2 GPUs, adjust the `./config/grpo.yaml` file so that `vllm_device` is set to the index of your last GPU. For example, if you have 4 GPUs, set it to 3:

```yaml
vllm_device: cuda:3  # If you have 4 GPUs, the last index is 3
```

You will also need to update the `CUDA_VISIBLE_DEVICES` environment variable in the `train.sh` script to include all of your available GPUs. For example, with 4 GPUs:

```bash
# ./train.sh

# ... beginning of the script
export CUDA_VISIBLE_DEVICES=0,1,2,3
# ... rest of the script
```

## Usage

1. Install the required packages:

```bash
# First, give execute permissions to the script
# chmod +x ./set_env.sh

# Then, run the setup script
./set_env.sh
```

2. (Optional) Log in to Weights & Biases for experiment tracking:

```bash
# First, set your WANDB_API_KEY as an environment variable
export WANDB_API_KEY=your_wandb_api_key

# Set the project name
export WANDB_PROJECT=your_wandb_project_name
```

3. Run the training script:

```bash
# First, give execute permissions to the script
# chmod +x ./train.sh

# Then, run the training script
./train.sh
```

The model will be trained with GRPO on the configured reasoning-gym datasets, and evaluation metrics will be logged to Weights & Biases.
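As a quick sanity check before launching a full run, you can exercise the reasoning-gym side on its own. The snippet below is only an illustration and is not one of the shipped scripts: it builds one of the datasets named in the config and scores a dummy answer with the same `extract_answer`/`score_answer` calls that the accuracy reward uses.

```python
# Illustrative sanity check (not part of this example's scripts).
import reasoning_gym
from reasoning_gym.utils import extract_answer

# Build a small procedural dataset; the name matches one entry in config/grpo.yaml.
dataset = reasoning_gym.create_dataset("complex_arithmetic", seed=1, size=5)

for i in range(len(dataset)):
    item = dataset[i]
    print("Question:", item["question"])
    # Pretend the model answered with the reference answer wrapped in the expected tags.
    completion = f"<think>...</think>\n<answer>{item['answer']}</answer>"
    answer = extract_answer(completion)
    # score_answer returns a float reward, exactly what the accuracy reward function reports.
    print("Score:", dataset.score_answer(answer, entry=item))
```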
@@ -1,37 +1,52 @@
#Model arguments
model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
# Reasoning Gym configs
dataset_size: 20000
developer_prompt: DeepSeekZero
developer_role: system
datasets:
  simple_equations:
    weight: 1
  complex_arithmetic:
    weight: 1
    config:
      min_real: -20
      max_real: 20

#script arguments
dataset_name: chain_sum
# Model configs from trl
model_name_or_path: Qwen/Qwen2.5-1.5B-Instruct
attn_implementation: flash_attention_2

#training arguments
# GRPO trainer configs from trl
bf16: true
gradient_accumulation_steps: 16
use_vllm: true
vllm_device: cuda:1
vllm_gpu_memory_utilization: 0.9
log_level: info
gradient_accumulation_steps: 1
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id:
seed: 42
eval_seed: 101
log_level: info
logging_steps: 10
use_reentrant: false
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
learning_rate: 2.0e-05
learning_rate: 1e-06
lr_scheduler_type: constant_with_warmup
lr_scheduler_kwargs:
  num_warmup_steps: 10
max_prompt_length: 512
max_completion_length: 1024
max_completion_length: 2048
max_steps: 100
num_generations: 8
per_device_train_batch_size: 1
per_device_eval_batch_size: 1
overwrite_output_dir: true
output_dir: data/Qwen-1.5B-GRPO
train_size: 1000
eval_size: 100
num_train_epochs: 1
max_steps: -1
push_to_hub: true
report_to: ['wandb']
#do_eval: true
#eval_strategy: steps
#eval_steps: 100
overwrite_output_dir: true
per_device_train_batch_size: 8
report_to:
  - wandb

save_strategy: steps
save_steps: 50
save_total_limit: 5

seed: 42
temperature: 0.6
warmup_ratio: 0.1
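The `datasets:` mapping above is what `grpo.py` (below) turns into a weighted composite reasoning-gym dataset. A minimal sketch of that conversion, using the same `DatasetSpec` and `create_dataset` calls as the script and the values from this config, shown standalone for illustration:

```python
# Illustrative sketch of how the datasets: block becomes a composite dataset.
import reasoning_gym
from reasoning_gym.composite import DatasetSpec

specs = [
    DatasetSpec(name="simple_equations", weight=1, config={}),
    DatasetSpec(name="complex_arithmetic", weight=1, config={"min_real": -20, "max_real": 20}),
]

# Entries are sampled from the weighted mix of the listed generators.
train_data = reasoning_gym.create_dataset("composite", seed=1, size=20000, datasets=specs)
print(len(train_data), train_data[0]["question"])
```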
examples/trl/grpo.py (new file, 266 lines)
@@ -0,0 +1,266 @@
import logging
import os
import re
import sys
from dataclasses import dataclass, field
from typing import Optional

import torch
import transformers
from torch.utils.data import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from transformers.trainer_utils import get_last_checkpoint
from trl import GRPOConfig, GRPOTrainer, ModelConfig, TrlParser

import reasoning_gym
from reasoning_gym.coaching.experiment import Experiment
from reasoning_gym.composite import DatasetSpec
from reasoning_gym.dataset import ProceduralDataset
from reasoning_gym.utils import SYSTEM_PROMPTS, extract_answer


@dataclass
class DatasetConfigItem:
    weight: Optional[float] = field(default=1.0)
    config: Optional[dict] = field(default_factory=dict)


@dataclass
class DatasetConfig:
    dataset_size: int = field(default=1000)
    developer_prompt: str = field(default="DeepSeekZero")
    developer_role: str = field(default="system")
    datasets: dict[str, DatasetConfigItem] = field(default=None)

    def __post_init__(self):
        # Convert dictionary items to DatasetConfigItem instances
        if self.datasets:
            converted_datasets = {}
            for name, config_item in self.datasets.items():
                if isinstance(config_item, dict):
                    converted_datasets[name] = DatasetConfigItem(**config_item)
                else:
                    converted_datasets[name] = config_item
            self.datasets = converted_datasets


class ReasoningGymDataset(Dataset):
    def __init__(
        self,
        tokenizer,
        procedural_dataset: Optional[ProceduralDataset] = None,
        experiment: Optional[Experiment] = None,
        developer_prompt: Optional[str] = None,
        developer_role: Optional[str] = None,
    ):
        self.tokenizer = tokenizer
        self.data = procedural_dataset or experiment.composite
        self.experiment = experiment
        self.developer_prompt = developer_prompt
        self.developer_role = developer_role

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> dict:
        item = self.data[idx]
        question = item["question"]

        chat = []
        if self.developer_role is not None:
            chat.append({"role": self.developer_role, "content": self.developer_prompt})
        chat.append({"role": "user", "content": question})

        prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        return {"prompt": prompt, "item": item}


class CustomGRPOTrainer(GRPOTrainer):
    def __init__(
        self,
        model,
        args: GRPOConfig,
        tokenizer,
        train_dataset: ReasoningGymDataset,
        eval_dataset: ReasoningGymDataset,
    ):
        super().__init__(
            model=model,
            reward_funcs=[
                self._accuracy_reward,
                self._format_reward,
            ],
            args=args,
            processing_class=tokenizer,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
        )

    def _accuracy_reward(self, completions: list[str], **kwargs) -> list[float]:
        assert "item" in kwargs, "The 'item' argument must be provided to compute accuracy reward."
        assert len(kwargs["item"]) == len(completions), "Items and completions must have the same length."
        assert all(isinstance(item, dict) for item in kwargs["item"]), "Each item must be a dictionary."
        answers = [extract_answer(c) for c in completions]
        return [self.train_dataset.data.score_answer(answer, item) for answer, item in zip(answers, kwargs["item"])]

    def _format_reward(self, completions: list[str], **kwargs) -> list[float]:
        def count_tags(text: str) -> float:
            count = 0.0
            if re.search(r"\s*<think>\s*", text):
                count += 0.25
            if re.search(r"\s*</think>\s*", text):
                count += 0.25
            if re.search(r"\s*<answer>\s*", text):
                count += 0.25
            if re.search(r"\s*</answer>\s*", text):
                count += 0.25
            return count

        return [count_tags(c) for c in completions]


def make_dataset(
    tokenizer,
    data_source: Experiment | ProceduralDataset,
    developer_prompt: str,
    developer_role: Optional[str] = None,
) -> ReasoningGymDataset:
    """Create a ReasoningGymDataset from an Experiment or ProceduralDataset."""
    if isinstance(data_source, Experiment):
        return ReasoningGymDataset(
            tokenizer=tokenizer,
            experiment=data_source,
            developer_prompt=developer_prompt,
            developer_role=developer_role,
        )
    else:
        return ReasoningGymDataset(
            tokenizer=tokenizer,
            procedural_dataset=data_source,
            developer_prompt=developer_prompt,
            developer_role=developer_role,
        )


def prepare_datasets(
    config: DatasetConfig,
    tokenizer,
) -> tuple[ReasoningGymDataset, ReasoningGymDataset]:
    """Prepare the training and eval datasets."""
    developer_prompt = SYSTEM_PROMPTS[config.developer_prompt]

    dataset_specs = [
        DatasetSpec(
            name=name,
            weight=ds_config.weight,
            config=ds_config.config,
        )
        for name, ds_config in config.datasets.items()
    ]
    train_data_source = reasoning_gym.create_dataset(
        "composite", seed=1, size=config.dataset_size, datasets=dataset_specs
    )
    val_data_source = reasoning_gym.create_dataset(
        "composite", seed=2, size=config.dataset_size, datasets=dataset_specs
    )
    train_dataset = make_dataset(
        tokenizer=tokenizer,
        data_source=train_data_source,
        developer_prompt=developer_prompt,
        developer_role=config.developer_role,
    )
    eval_dataset = make_dataset(
        tokenizer=tokenizer,
        data_source=val_data_source,
        developer_prompt=developer_prompt,
        developer_role=config.developer_role,
    )
    return train_dataset, eval_dataset


def main():
    # -----------
    # Parse args
    # -----------
    parser = TrlParser((DatasetConfig, GRPOConfig, ModelConfig))
    reasoning_gym_args, training_args, model_args = parser.parse_args_and_config()
    set_seed(training_args.seed)

    # ---------------
    # Set up logging
    # ---------------
    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Model parameters {model_args}")
    logger.info(f"Training parameters {training_args}")

    # -----------
    # Load model
    # -----------
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        torch_dtype=torch.bfloat16,
        attn_implementation=model_args.attn_implementation,
        use_cache=False if training_args.gradient_checkpointing else True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)

    # --------------------
    # Instantiate trainer
    # --------------------
    training_args.reasoning_gym = reasoning_gym_args
    train_dataset, eval_dataset = prepare_datasets(reasoning_gym_args, tokenizer)
    trainer = CustomGRPOTrainer(
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )

    # ------------------------------
    # See if we can resume training
    # ------------------------------
    logger.info("Starting training...")
    # Check for last checkpoint
    ckpt = None
    if training_args.resume_from_checkpoint is not None:
        ckpt = training_args.resume_from_checkpoint
    elif os.path.isdir(training_args.output_dir):
        ckpt = get_last_checkpoint(training_args.output_dir)
    if ckpt:
        logger.info(f"\nCheckpoint detected, resuming training at {ckpt=}.")
    else:
        logger.info("\nNo checkpoint detected, starting training from scratch.")

    # ---------------
    # Start training
    # ---------------
    train_result = trainer.train(resume_from_checkpoint=ckpt)
    train_metrics = train_result.metrics
    trainer.log_metrics("train", train_metrics)
    trainer.save_metrics("train", train_metrics)
    trainer.save_state()

    # ---------
    # Clean up
    # ---------
    del trainer
    torch.cuda.empty_cache()


if __name__ == "__main__":
    main()
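To make the reward shaping concrete, the short sketch below (illustrative only; it mirrors the tag-counting logic of `_format_reward` above rather than importing the trainer) shows how a well-formed and a malformed completion are scored:

```python
# Illustrative scoring example, mirroring CustomGRPOTrainer._format_reward.
import re

def count_tags(text: str) -> float:
    # 0.25 reward per expected tag, as in the trainer above.
    count = 0.0
    for tag in ("<think>", "</think>", "<answer>", "</answer>"):
        if re.search(rf"\s*{re.escape(tag)}\s*", text):
            count += 0.25
    return count

good = "<think>2 + 2 = 4</think>\n<answer>4</answer>"
bad = "The answer is 4."

print(count_tags(good))  # 1.0 -- all four tags present
print(count_tags(bad))   # 0.0 -- no tags, no format reward
```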
@@ -1,18 +0,0 @@
from dataclasses import dataclass
from typing import Optional


@dataclass
class ScriptArguments:
    """
    Arguments for the training script.
    """

    dataset_name: str
    dataset_config: Optional[str] = None
    dataset_train_split: str = "train"
    dataset_test_split: str = "test"
    gradient_checkpointing_use_reentrant: bool = False
    ignore_bias_buffers: bool = False
    train_size: int = 1000
    eval_size: int = 100
@@ -1,217 +0,0 @@
# This example is an adapted version of HuggingFace trl GRPO code:
# link : https://github.com/huggingface/open-r1/blob/main/src/open_r1/grpo.py
import logging
import os
import re
import sys
from dataclasses import dataclass
from typing import Optional

import datasets
import torch
import transformers
from grpo_config import ScriptArguments
from peft import LoraConfig
from torch.utils.data import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from transformers.trainer_utils import get_last_checkpoint
from trl import GRPOConfig, GRPOTrainer, ModelConfig, TrlParser

import reasoning_gym
from reasoning_gym.utils import extract_answer


class ReasoningGymDataset(Dataset):
    def __init__(self, dataset_name, seed, size, tokenizer, developer_prompt, developer_role="system") -> None:
        super().__init__()
        self.data = reasoning_gym.create_dataset(dataset_name, seed=seed, size=size)
        self.tokenizer = tokenizer
        self.developer_role = developer_role
        self.developer_prompt = developer_prompt

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        question = item["question"]

        chat = []

        if self.developer_role is not None:
            chat.append({"role": self.developer_role, "content": self.developer_prompt})
        chat.append({"role": "user", "content": question})

        prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        return {"prompt": prompt, "metadata": item}


class GRPOTrainerCustom(GRPOTrainer):
    def __init__(
        self,
        model,
        dataset_name,
        args: GRPOConfig,
        tokenizer,
        peft_config,
        seed,
        size,
        developer_role="system",
    ):
        super().__init__(
            model,
            reward_funcs=[self._accuracy_reward, self._format_reward],
            args=args,
            processing_class=tokenizer,
            peft_config=peft_config,
        )
        developer_prompt = reasoning_gym.utils.SYSTEM_PROMPTS["DeepSeekZero"]
        self.train_dataset = ReasoningGymDataset(dataset_name, seed, size, tokenizer, developer_prompt, developer_role)

    def _format_reward(self, completions, **kwargs):
        regex = r"^<think>([^<]*(?:<(?!/?think>)[^<]*)*)<\/think>\n<answer>([\s\S]*?)<\/answer>$"
        matches = [re.match(regex, completion, flags=re.DOTALL) for completion in completions]
        return [1.0 if match else 0.0 for match in matches]

    def _accuracy_reward(self, completions, metadata, **kwargs):
        answers = [extract_answer(completion) for completion in completions]
        return [self.train_dataset.data.score_answer(answer, entry=obj) for (answer, obj) in zip(answers, metadata)]


def main(script_args, training_args, model_args):
    set_seed(training_args.seed)

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    logger = logging.getLogger(__name__)

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)  # Set for module-level logger

    # Configure third-party library log levels
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    logger.info(f"Training arguments: {training_args}")
    logger.info(f"Model arguments: {model_args}")
    logger.info(f"Script arguments: {script_args}")

    last_checkpoint = None
    if os.path.isdir(training_args.output_dir):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
    if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
        logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")

    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
    ).to("cuda")
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    tokenizer.pad_token = tokenizer.eos_token

    peft_config = LoraConfig(
        r=16,
        lora_alpha=64,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
        task_type="CAUSAL_LM",
        lora_dropout=0.05,
    )

    trainer = GRPOTrainerCustom(
        model,
        dataset_name=script_args.dataset_name,
        args=training_args,
        tokenizer=tokenizer,
        peft_config=peft_config,
        seed=training_args.seed,
        size=script_args.train_size,
    )

    # Training loop
    logger.info("Training model...")
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is None:
        checkpoint = model.save_pretrained(training_args.output_dir)

    train_results = trainer.train(resume_from_checkpoint=checkpoint)
    metrics = train_results.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    logger.info("*** Save model ***")
    trainer.save_model(training_args.output_dir)
    logger.info(f"Model saved to {training_args.output_dir}")

    kwargs = {
        "finetuned_from": model_args.model_name_or_path,
        "dataset": list(script_args.dataset_name),
        "dataset_tags": list(script_args.dataset_name),
        "tags": ["reasoning-gym"],
    }

    if trainer.accelerator.is_main_process:
        trainer.create_model_card(**kwargs)
        # Restore k,v cache for fast inference
        trainer.model.config.use_cache = True
        trainer.model.config.save_pretrained(training_args.output_dir)

    def evaluate_model(model, tokenizer, dataset, *args, **kwargs):
        model.eval()
        correct_preds = 0
        total_preds = 0

        for i in range(len(dataset)):
            item = dataset[i]
            prompt = item["prompt"]
            metadata = item["metadata"]
            inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

            with torch.no_grad():
                outputs = model.generate(
                    inputs,
                    max_new_tokens=training_args.max_completion_length,
                    pad_token_id=tokenizer.eos_token_id,
                    *args,
                    **kwargs,
                )

            generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
            answer = reasoning_gym.utils.extract_answer(generated_text)
            score = dataset.data.score_answer(answer, entry=metadata)
            correct_preds += score
            total_preds += 1

        return correct_preds / total_preds

    ## Evaluate model
    logger.info("Evaluating model...")
    eval_dataset = ReasoningGymDataset(
        script_args.dataset_name,
        training_args.eval_seed,
        script_args.eval_size,
        tokenizer,
        reasoning_gym.utils.SYSTEM_PROMPTS["DeepSeekZero"],
    )

    eval_results = evaluate_model(model, tokenizer, eval_dataset)
    trainer.log_metrics("eval", {"accuracy": eval_results})
    trainer.save_metrics("eval", {"accuracy": eval_results})
    logger.info(f"Evaluation results: {eval_results}")

    if training_args.push_to_hub:
        logging.info("Pushing model to hub...")
        trainer.push_to_hub()


if __name__ == "__main__":
    parser = TrlParser((ScriptArguments, GRPOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    main(script_args, training_args, model_args)
@@ -1,8 +0,0 @@
torch>=2.6.0
datasets
peft
transformers
trl
wandb
huggingface_hub
flash-attn --no-build-isolation
@@ -2,19 +2,11 @@
# python 3.10 + cuda 11.8.0
# the execution order of the following commands matters

export MKL_NUM_THREADS=1
export NUMEXPR_NUM_THREADS=1
export OPENBLAS_NUM_THREADS=1
export OMP_NUM_THREADS=1

conda clean -a -y
mamba clean -a -y
pip install --upgrade pip
pip cache purge

# cuda, gcc/g++, torch
# mamba install cuda -c nvidia/label/cuda-11.8.0 -y
# mamba install gcc gxx -c conda-forge -y
# torch
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu118

# xformers
@@ -25,15 +17,7 @@ pip install https://github.com/vllm-project/vllm/releases/download/v0.7.2/vllm-0

pip install deepspeed
pip install flash-attn==2.7.3 --no-build-isolation
pip install peft

pip install "trl==0.15.2"
pip install latex2sympy2_extended
pip install "math_verify==0.5.2"
pip install word2number
pip install scipy

pip install "transformers==4.49.0"
pip install wandb
pip install plotly
pip install matplotlib
pip install seaborn
pip install reasoning-gym
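Since the setup pins specific versions, a quick check after running `set_env.sh` can save a failed launch. The sketch below is illustrative only (not one of the shipped scripts) and assumes the packages were installed by the script above:

```python
# Illustrative post-setup sanity check (not part of this example's scripts).
from importlib.metadata import version

# Pins taken from the setup script above.
expected = {"torch": "2.5.1", "trl": "0.15.2", "transformers": "4.49.0"}

for package, pin in expected.items():
    installed = version(package)
    status = "OK" if installed == pin else f"MISMATCH (expected {pin})"
    print(f"{package}: {installed} {status}")
```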
examples/trl/train.sh (new executable file, 26 lines)
@@ -0,0 +1,26 @@
#!/bin/bash

export CUDA_VISIBLE_DEVICES=0,1
GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
NUM_PROCESSES_TRAINING=$((GPU_COUNT - 1))

echo ""
echo "Number of GPUs: ${GPU_COUNT}"
echo "Number of processes for training: ${NUM_PROCESSES_TRAINING}"
echo ""

PY_SCRIPT="./grpo.py"
PY_CONFIG="./config/grpo.yaml"
ACCELERATE_DS_CONFIG="./config/ds_zero2.yaml"

echo "START TIME: $(date)"

export WANDB_PROJECT="reasoning-gym-trl"

accelerate launch \
    --config_file "${ACCELERATE_DS_CONFIG}" \
    --main_process_port=29500 \
    --num_processes="${NUM_PROCESSES_TRAINING}" "${PY_SCRIPT}" --config "${PY_CONFIG}"

echo "END TIME: $(date)"
echo "DONE"
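The script reserves one visible GPU for vLLM rollouts and launches `accelerate` on the rest, which is why `NUM_PROCESSES_TRAINING` is `GPU_COUNT - 1`. A small sketch of the same bookkeeping in Python (illustrative only, not part of the shipped scripts), just to make the device split explicit:

```python
# Illustrative sketch of the device split used by train.sh.
import torch

gpu_count = torch.cuda.device_count()        # e.g. 2 with CUDA_VISIBLE_DEVICES=0,1
num_training_processes = gpu_count - 1       # accelerate trains on all but the last GPU
vllm_device = f"cuda:{gpu_count - 1}"        # the last visible GPU serves vLLM rollouts

print(num_training_processes, vllm_device)   # with 2 GPUs: 1 cuda:1, matching vllm_device in grpo.yaml
```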
@@ -2,8 +2,6 @@
Codebase for training LLMs using Reasoning Gym procedural dataset generators.

**Note**: the `qwen-math/` directory contains the code from the Tina project, used for the Qwen2.5 3B RG-Math training. It is separate from the rest of our training/evaluation codebase.

This readme documents:

- Training environment setup and usage example
training/qwen-math/.gitignore (vendored, 180 lines) deleted
File diff suppressed here: a standard Python .gitignore plus `ckpts/` and `outputs/`.
@@ -1,201 +0,0 @@
File diff suppressed here: the full standard Apache License, Version 2.0 (January 2004, http://www.apache.org/licenses/) text, deleted along with the vendored qwen-math directory.
@@ -1,179 +0,0 @@
File diff suppressed here: the README of the vendored Tina project ("Tina: Tiny Reasoning Models via LoRA", https://github.com/shangshang-wang/Tina, paper https://arxiv.org/abs/2504.15777), covering its overview, Miniconda/mamba environment setup, lighteval installation note, training and evaluation scripts, acknowledgements, and citation, deleted along with the qwen-math directory.
Binary files not shown (deleted image assets): 1.9 MiB, 332 KiB, 151 KiB, 130 KiB, 225 KiB, 103 KiB.
File diff suppressed because it is too large — Load Diff
File diffs suppressed here: five 68-line GRPO recipe configs for DeepSeek-R1-Distill-Qwen-1.5B from the vendored Tina project (curated_deepscaler, curated_limr, curated_limr_large_lr_ablation, curated_limr_large_rank_ablation, curated_limr_medium_rank_ablation), deleted along with the qwen-math directory; they differ mainly in dataset name, LoRA rank/alpha, learning rate, max_steps, and save_steps.
use_reentrant: false
|
||||
log_level: info
|
||||
logging_first_step: true
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
push_to_hub: false
|
||||
hub_strategy: every_save
|
||||
hub_private_repo: true
|
||||
hub_model_id: TODO
|
||||
|
||||
learning_rate: 1e-06
|
||||
lr_scheduler_type: cosine_with_min_lr
|
||||
lr_scheduler_kwargs:
|
||||
min_lr_rate: 0.1
|
||||
max_prompt_length: 512
|
||||
max_completion_length: 3584
|
||||
max_steps: 360 # use 360 for lr scheduler but stop at 180 steps
|
||||
num_generations: 4
|
||||
num_train_epochs: 1
|
||||
overwrite_output_dir: true
|
||||
per_device_train_batch_size: 4
|
||||
report_to:
|
||||
- wandb
|
||||
|
||||
save_strategy: steps
|
||||
save_steps: 50
|
||||
save_total_limit: 100
|
||||
|
||||
seed: 42
|
||||
temperature: 0.7
|
||||
warmup_ratio: 0.1
|
||||
@@ -1,68 +0,0 @@
|
||||
# check ./tina/utils/constant.py
|
||||
model_post_train_dataset_name: curated_limr_small_lr_ablation
|
||||
model_post_train_type: grpo
|
||||
rl_post_train_reward_funcs:
|
||||
- format
|
||||
- accuracy
|
||||
rl_post_train_reward_weights:
|
||||
- 1.0
|
||||
- 2.0
|
||||
|
||||
|
||||
# Model configs from trl
|
||||
model_name_or_path: DeepSeek-R1-Distill-Qwen-1.5B
|
||||
attn_implementation: flash_attention_2
|
||||
use_peft: true
|
||||
lora_r: 32
|
||||
lora_alpha: 128
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- gate_proj
|
||||
|
||||
|
||||
# GRPO trainer configs from trl
|
||||
bf16: true
|
||||
use_vllm: true
|
||||
vllm_device: cuda:0
|
||||
vllm_gpu_memory_utilization: 0.4
|
||||
vllm_max_model_len: 4608
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
log_level: info
|
||||
logging_first_step: true
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
push_to_hub: false
|
||||
hub_strategy: every_save
|
||||
hub_private_repo: true
|
||||
hub_model_id: TODO
|
||||
|
||||
learning_rate: 5e-07
|
||||
lr_scheduler_type: cosine_with_min_lr
|
||||
lr_scheduler_kwargs:
|
||||
min_lr_rate: 0.1
|
||||
max_prompt_length: 512
|
||||
max_completion_length: 3584
|
||||
max_steps: 360 # use 360 for lr scheduler but stop at 180 steps
|
||||
num_generations: 4
|
||||
num_train_epochs: 1
|
||||
overwrite_output_dir: true
|
||||
per_device_train_batch_size: 4
|
||||
report_to:
|
||||
- wandb
|
||||
|
||||
save_strategy: steps
|
||||
save_steps: 10
|
||||
save_total_limit: 100
|
||||
|
||||
seed: 42
|
||||
temperature: 0.7
|
||||
warmup_ratio: 0.1
|
||||
@@ -1,68 +0,0 @@
|
||||
# check ./tina/utils/constant.py
|
||||
model_post_train_dataset_name: curated_limr_small_rank_ablation
|
||||
model_post_train_type: grpo
|
||||
rl_post_train_reward_funcs:
|
||||
- format
|
||||
- accuracy
|
||||
rl_post_train_reward_weights:
|
||||
- 1.0
|
||||
- 2.0
|
||||
|
||||
|
||||
# Model configs from trl
|
||||
model_name_or_path: DeepSeek-R1-Distill-Qwen-1.5B
|
||||
attn_implementation: flash_attention_2
|
||||
use_peft: true
|
||||
lora_r: 8
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- gate_proj
|
||||
|
||||
|
||||
# GRPO trainer configs from trl
|
||||
bf16: true
|
||||
use_vllm: true
|
||||
vllm_device: cuda:0
|
||||
vllm_gpu_memory_utilization: 0.4
|
||||
vllm_max_model_len: 4608
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
log_level: info
|
||||
logging_first_step: true
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
push_to_hub: false
|
||||
hub_strategy: every_save
|
||||
hub_private_repo: true
|
||||
hub_model_id: TODO
|
||||
|
||||
learning_rate: 1e-06
|
||||
lr_scheduler_type: cosine_with_min_lr
|
||||
lr_scheduler_kwargs:
|
||||
min_lr_rate: 0.1
|
||||
max_prompt_length: 512
|
||||
max_completion_length: 3584
|
||||
max_steps: 360 # use 360 for lr scheduler but stop at 180 steps
|
||||
num_generations: 4
|
||||
num_train_epochs: 1
|
||||
overwrite_output_dir: true
|
||||
per_device_train_batch_size: 4
|
||||
report_to:
|
||||
- wandb
|
||||
|
||||
save_strategy: steps
|
||||
save_steps: 50
|
||||
save_total_limit: 100
|
||||
|
||||
seed: 42
|
||||
temperature: 0.7
|
||||
warmup_ratio: 0.1
|
||||
@@ -1,68 +0,0 @@
|
||||
# check ./tina/utils/constant.py
|
||||
model_post_train_dataset_name: curated_limr_tiny_rank_ablation
|
||||
model_post_train_type: grpo
|
||||
rl_post_train_reward_funcs:
|
||||
- format
|
||||
- accuracy
|
||||
rl_post_train_reward_weights:
|
||||
- 1.0
|
||||
- 2.0
|
||||
|
||||
|
||||
# Model configs from trl
|
||||
model_name_or_path: DeepSeek-R1-Distill-Qwen-1.5B
|
||||
attn_implementation: flash_attention_2
|
||||
use_peft: true
|
||||
lora_r: 4
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- gate_proj
|
||||
|
||||
|
||||
# GRPO trainer configs from trl
|
||||
bf16: true
|
||||
use_vllm: true
|
||||
vllm_device: cuda:0
|
||||
vllm_gpu_memory_utilization: 0.4
|
||||
vllm_max_model_len: 4608
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
log_level: info
|
||||
logging_first_step: true
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
push_to_hub: false
|
||||
hub_strategy: every_save
|
||||
hub_private_repo: true
|
||||
hub_model_id: TODO
|
||||
|
||||
learning_rate: 1e-06
|
||||
lr_scheduler_type: cosine_with_min_lr
|
||||
lr_scheduler_kwargs:
|
||||
min_lr_rate: 0.1
|
||||
max_prompt_length: 512
|
||||
max_completion_length: 3584
|
||||
max_steps: 360 # use 360 for lr scheduler but stop at 180 steps
|
||||
num_generations: 4
|
||||
num_train_epochs: 1
|
||||
overwrite_output_dir: true
|
||||
per_device_train_batch_size: 4
|
||||
report_to:
|
||||
- wandb
|
||||
|
||||
save_strategy: steps
|
||||
save_steps: 50
|
||||
save_total_limit: 100
|
||||
|
||||
seed: 42
|
||||
temperature: 0.7
|
||||
warmup_ratio: 0.1
|
||||
@@ -1,76 +0,0 @@
|
||||
# check ./tina/utils/constant.py
|
||||
model_post_train_dataset_name: curated_open_r1
|
||||
model_post_train_type: grpo
|
||||
rl_post_train_reward_funcs:
|
||||
- accuracy
|
||||
- format
|
||||
- tag_count
|
||||
- length
|
||||
- reasoning_steps
|
||||
- cosine
|
||||
- repetition_penalty
|
||||
rl_post_train_reward_weights:
|
||||
- 1.0
|
||||
- 1.0
|
||||
- 1.0
|
||||
- 1.0
|
||||
- 1.0
|
||||
- 1.0
|
||||
- 1.0
|
||||
|
||||
|
||||
# Model configs from trl
|
||||
model_name_or_path: DeepSeek-R1-Distill-Qwen-1.5B
|
||||
attn_implementation: flash_attention_2
|
||||
use_peft: true
|
||||
lora_r: 32
|
||||
lora_alpha: 128
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- gate_proj
|
||||
|
||||
|
||||
# GRPO trainer configs from trl
|
||||
bf16: true
|
||||
use_vllm: true
|
||||
vllm_device: cuda:0
|
||||
vllm_gpu_memory_utilization: 0.4
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
log_level: info
|
||||
logging_first_step: true
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
push_to_hub: false
|
||||
hub_strategy: every_save
|
||||
hub_private_repo: true
|
||||
hub_model_id: TODO
|
||||
|
||||
learning_rate: 1.0e-06
|
||||
lr_scheduler_type: cosine_with_min_lr
|
||||
lr_scheduler_kwargs:
|
||||
min_lr_rate: 0.1
|
||||
max_prompt_length: 512
|
||||
max_completion_length: 3584
|
||||
num_generations: 8
|
||||
num_train_epochs: 1
|
||||
overwrite_output_dir: true
|
||||
per_device_train_batch_size: 8
|
||||
report_to:
|
||||
- wandb
|
||||
|
||||
save_strategy: steps
|
||||
save_steps: 100
|
||||
save_total_limit: 100
|
||||
|
||||
seed: 42
|
||||
temperature: 0.7
|
||||
warmup_ratio: 0.1
|
||||
@@ -1,68 +0,0 @@
|
||||
# check ./tina/utils/constant.py
|
||||
model_post_train_dataset_name: curated_open_rs1
|
||||
model_post_train_type: grpo
|
||||
rl_post_train_reward_funcs:
|
||||
- format
|
||||
- accuracy
|
||||
rl_post_train_reward_weights:
|
||||
- 1.0
|
||||
- 2.0
|
||||
|
||||
|
||||
# Model configs from trl
|
||||
model_name_or_path: DeepSeek-R1-Distill-Qwen-1.5B
|
||||
attn_implementation: flash_attention_2
|
||||
use_peft: true
|
||||
lora_r: 32
|
||||
lora_alpha: 128
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- gate_proj
|
||||
|
||||
|
||||
# GRPO trainer configs from trl
|
||||
bf16: true
|
||||
use_vllm: true
|
||||
vllm_device: cuda:0
|
||||
vllm_gpu_memory_utilization: 0.4
|
||||
vllm_max_model_len: 4608
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
log_level: info
|
||||
logging_first_step: true
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
push_to_hub: false
|
||||
hub_strategy: every_save
|
||||
hub_private_repo: true
|
||||
hub_model_id: TODO
|
||||
|
||||
learning_rate: 1.0e-06
|
||||
lr_scheduler_type: cosine_with_min_lr
|
||||
lr_scheduler_kwargs:
|
||||
min_lr_rate: 0.1
|
||||
max_prompt_length: 512
|
||||
max_completion_length: 3584
|
||||
max_steps: 5000 # use 5000 for lr scheduler but stop at 2400 steps
|
||||
num_generations: 4
|
||||
num_train_epochs: 1
|
||||
overwrite_output_dir: true
|
||||
per_device_train_batch_size: 4
|
||||
report_to:
|
||||
- wandb
|
||||
|
||||
save_strategy: steps
|
||||
save_steps: 100
|
||||
save_total_limit: 100
|
||||
|
||||
seed: 42
|
||||
temperature: 0.7
|
||||
warmup_ratio: 0.1
|
||||
@@ -1,68 +0,0 @@
|
||||
# check ./tina/utils/constant.py
|
||||
model_post_train_dataset_name: curated_open_rs2
|
||||
model_post_train_type: grpo
|
||||
rl_post_train_reward_funcs:
|
||||
- format
|
||||
- accuracy
|
||||
rl_post_train_reward_weights:
|
||||
- 1.0
|
||||
- 2.0
|
||||
|
||||
|
||||
# Model configs from trl
|
||||
model_name_or_path: DeepSeek-R1-Distill-Qwen-1.5B
|
||||
attn_implementation: flash_attention_2
|
||||
use_peft: true
|
||||
lora_r: 32
|
||||
lora_alpha: 128
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- gate_proj
|
||||
|
||||
|
||||
# GRPO trainer configs from trl
|
||||
bf16: true
|
||||
use_vllm: true
|
||||
vllm_device: cuda:0
|
||||
vllm_gpu_memory_utilization: 0.4
|
||||
vllm_max_model_len: 4608
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
log_level: info
|
||||
logging_first_step: true
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
push_to_hub: false
|
||||
hub_strategy: every_save
|
||||
hub_private_repo: true
|
||||
hub_model_id: TODO
|
||||
|
||||
learning_rate: 1.0e-06
|
||||
lr_scheduler_type: cosine_with_min_lr
|
||||
lr_scheduler_kwargs:
|
||||
min_lr_rate: 0.1
|
||||
max_prompt_length: 512
|
||||
max_completion_length: 3584
|
||||
max_steps: 1500 # use 1500 for lr scheduler but stop at 850 steps
|
||||
num_generations: 6
|
||||
num_train_epochs: 1
|
||||
overwrite_output_dir: true
|
||||
per_device_train_batch_size: 6
|
||||
report_to:
|
||||
- wandb
|
||||
|
||||
save_strategy: steps
|
||||
save_steps: 50
|
||||
save_total_limit: 100
|
||||
|
||||
seed: 42
|
||||
temperature: 0.7
|
||||
warmup_ratio: 0.1
|
||||
@@ -1,68 +0,0 @@
|
||||
# check ./tina/utils/constant.py
|
||||
model_post_train_dataset_name: curated_open_rs3
|
||||
model_post_train_type: grpo
|
||||
rl_post_train_reward_funcs:
|
||||
- format
|
||||
- cosine
|
||||
rl_post_train_reward_weights:
|
||||
- 1.0
|
||||
- 2.0
|
||||
|
||||
|
||||
# Model configs from trl
|
||||
model_name_or_path: DeepSeek-R1-Distill-Qwen-1.5B
|
||||
attn_implementation: flash_attention_2
|
||||
use_peft: true
|
||||
lora_r: 32
|
||||
lora_alpha: 128
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- gate_proj
|
||||
|
||||
|
||||
# GRPO trainer configs from trl
|
||||
bf16: true
|
||||
use_vllm: true
|
||||
vllm_device: cuda:0
|
||||
vllm_gpu_memory_utilization: 0.4
|
||||
vllm_max_model_len: 4608
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
log_level: info
|
||||
logging_first_step: true
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
push_to_hub: false
|
||||
hub_strategy: every_save
|
||||
hub_private_repo: true
|
||||
hub_model_id: TODO
|
||||
|
||||
learning_rate: 1.0e-06
|
||||
lr_scheduler_type: cosine_with_min_lr
|
||||
lr_scheduler_kwargs:
|
||||
min_lr_rate: 0.1
|
||||
max_prompt_length: 512
|
||||
max_completion_length: 3584
|
||||
max_steps: 1500 # use 1500 for lr scheduler but stop at 850 steps
|
||||
num_generations: 6
|
||||
num_train_epochs: 1
|
||||
overwrite_output_dir: true
|
||||
per_device_train_batch_size: 6
|
||||
report_to:
|
||||
- wandb
|
||||
|
||||
save_strategy: steps
|
||||
save_steps: 50
|
||||
save_total_limit: 100
|
||||
|
||||
seed: 42
|
||||
temperature: 0.7
|
||||
warmup_ratio: 0.1
|
||||
@@ -1,68 +0,0 @@
|
||||
# check ./tina/utils/constant.py
|
||||
model_post_train_dataset_name: curated_open_rs3_drgrpo_ablation
|
||||
model_post_train_type: grpo
|
||||
rl_post_train_reward_funcs:
|
||||
- format
|
||||
- cosine
|
||||
rl_post_train_reward_weights:
|
||||
- 1.0
|
||||
- 2.0
|
||||
|
||||
|
||||
# Model configs from trl
|
||||
model_name_or_path: DeepSeek-R1-Distill-Qwen-1.5B
|
||||
attn_implementation: flash_attention_2
|
||||
use_peft: true
|
||||
lora_r: 32
|
||||
lora_alpha: 128
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- gate_proj
|
||||
|
||||
|
||||
# GRPO trainer configs from trl
|
||||
scale_reward: false # use Dr. GRPO's normalization
|
||||
bf16: true
|
||||
use_vllm: true
|
||||
vllm_device: cuda:0
|
||||
vllm_gpu_memory_utilization: 0.4
|
||||
vllm_max_model_len: 4608
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
push_to_hub: false
|
||||
hub_strategy: every_save
|
||||
hub_private_repo: true
|
||||
hub_model_id: TODO
|
||||
log_level: info
|
||||
logging_first_step: true
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
learning_rate: 1.0e-06
|
||||
lr_scheduler_type: cosine_with_min_lr
|
||||
lr_scheduler_kwargs:
|
||||
min_lr_rate: 0.1
|
||||
max_prompt_length: 512
|
||||
max_completion_length: 3584
|
||||
max_steps: 1500 # use 1500 for lr scheduler but stop at 850 steps
|
||||
num_generations: 6
|
||||
num_train_epochs: 1
|
||||
overwrite_output_dir: true
|
||||
per_device_train_batch_size: 6
|
||||
report_to:
|
||||
- wandb
|
||||
|
||||
save_strategy: steps
|
||||
save_steps: 50
|
||||
save_total_limit: 100
|
||||
|
||||
seed: 42
|
||||
temperature: 0.7
|
||||
warmup_ratio: 0.1
|
||||
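The only GRPO-trainer change in this Dr. GRPO ablation is `scale_reward: false`. As a rough sketch (not the repository's trainer code), such a flag controls whether group-relative advantages are divided by the per-group reward standard deviation; Dr. GRPO keeps them mean-centered but unscaled. The function below is illustrative only.

```python
# Illustrative only: what a scale_reward-style switch typically changes in GRPO.
# This is NOT the trainer in tina/post_train_hf/grpo_trainer.py.
from statistics import mean, pstdev

def group_relative_advantages(rewards: list[float], scale_reward: bool) -> list[float]:
    mu = mean(rewards)
    centered = [r - mu for r in rewards]
    if not scale_reward:
        # Dr. GRPO-style: mean-centered rewards, no std normalization
        return centered
    sigma = pstdev(rewards) or 1.0  # guard against zero variance
    return [c / sigma for c in centered]

# One prompt, num_generations completions (6 in the config above)
rewards = [2.0, 0.0, 1.0, 0.5, 0.0, 1.5]
print(group_relative_advantages(rewards, scale_reward=True))
print(group_relative_advantages(rewards, scale_reward=False))
```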
@@ -1,68 +0,0 @@
|
||||
# check ./tina/utils/constant.py
|
||||
model_post_train_dataset_name: curated_rg_math
|
||||
model_post_train_type: grpo
|
||||
rl_post_train_reward_funcs:
|
||||
- format
|
||||
- accuracy
|
||||
rl_post_train_reward_weights:
|
||||
- 1.0
|
||||
- 2.0
|
||||
|
||||
|
||||
# Model configs from trl
|
||||
model_name_or_path: DeepSeek-R1-Distill-Qwen-1.5B
|
||||
attn_implementation: flash_attention_2
|
||||
use_peft: true
|
||||
lora_r: 32
|
||||
lora_alpha: 128
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- gate_proj
|
||||
|
||||
|
||||
# GRPO trainer configs from trl
|
||||
bf16: true
|
||||
use_vllm: true
|
||||
vllm_device: cuda:0
|
||||
vllm_gpu_memory_utilization: 0.4
|
||||
vllm_max_model_len: 4608
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
log_level: info
|
||||
logging_first_step: true
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
push_to_hub: true
|
||||
hub_strategy: every_save
|
||||
hub_private_repo: true
|
||||
hub_model_id: starzmustdie
|
||||
|
||||
learning_rate: 1e-06
|
||||
lr_scheduler_type: cosine_with_min_lr
|
||||
lr_scheduler_kwargs:
|
||||
min_lr_rate: 0.1
|
||||
max_prompt_length: 512
|
||||
max_completion_length: 3584
|
||||
max_steps: 2500 # use 2500 for lr scheduler but stop at 1250 steps
|
||||
num_generations: 4
|
||||
num_train_epochs: 1
|
||||
overwrite_output_dir: true
|
||||
per_device_train_batch_size: 4
|
||||
report_to:
|
||||
- wandb
|
||||
|
||||
save_strategy: steps
|
||||
save_steps: 100
|
||||
save_total_limit: 100
|
||||
|
||||
seed: 42
|
||||
temperature: 0.7
|
||||
warmup_ratio: 0.1
|
||||
@@ -1,68 +0,0 @@
|
||||
# check ./tina/utils/constant.py
|
||||
model_post_train_dataset_name: curated_still
|
||||
model_post_train_type: grpo
|
||||
rl_post_train_reward_funcs:
|
||||
- length
|
||||
- accuracy
|
||||
rl_post_train_reward_weights:
|
||||
- 1.0
|
||||
- 2.0
|
||||
|
||||
|
||||
# Model configs from trl
|
||||
model_name_or_path: DeepSeek-R1-Distill-Qwen-1.5B
|
||||
attn_implementation: flash_attention_2
|
||||
use_peft: true
|
||||
lora_r: 32
|
||||
lora_alpha: 128
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- gate_proj
|
||||
|
||||
|
||||
# GRPO trainer configs from trl
|
||||
bf16: true
|
||||
use_vllm: true
|
||||
vllm_device: cuda:0
|
||||
vllm_gpu_memory_utilization: 0.4
|
||||
vllm_max_model_len: 4608
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
log_level: info
|
||||
logging_first_step: true
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
push_to_hub: false
|
||||
hub_strategy: every_save
|
||||
hub_private_repo: true
|
||||
hub_model_id: TODO
|
||||
|
||||
learning_rate: 1e-06
|
||||
lr_scheduler_type: cosine_with_min_lr
|
||||
lr_scheduler_kwargs:
|
||||
min_lr_rate: 0.1
|
||||
max_prompt_length: 512
|
||||
max_completion_length: 3584
|
||||
max_steps: 7500 # use 7500 for lr scheduler but stop at 3750 steps
|
||||
num_generations: 4
|
||||
num_train_epochs: 1
|
||||
overwrite_output_dir: true
|
||||
per_device_train_batch_size: 4
|
||||
report_to:
|
||||
- wandb
|
||||
|
||||
save_strategy: steps
|
||||
save_steps: 100
|
||||
save_total_limit: 100
|
||||
|
||||
seed: 42
|
||||
temperature: 0.7
|
||||
warmup_ratio: 0.1
|
||||
@@ -1,76 +0,0 @@
|
||||
# check ./tina/utils/constant.py
|
||||
model_post_train_dataset_name: curated_thoughts
|
||||
model_post_train_dataset_config: OpenThoughts-114k-math-default
|
||||
model_post_train_type: grpo
|
||||
rl_post_train_reward_funcs:
|
||||
- accuracy
|
||||
- format
|
||||
- tag_count
|
||||
- length
|
||||
- reasoning_steps
|
||||
- cosine
|
||||
- repetition_penalty
|
||||
rl_post_train_reward_weights:
|
||||
- 1.0
|
||||
- 1.0
|
||||
- 1.0
|
||||
- 1.0
|
||||
- 1.0
|
||||
- 1.0
|
||||
- 1.0
|
||||
|
||||
|
||||
# Model configs from trl
|
||||
model_name_or_path: DeepSeek-R1-Distill-Qwen-1.5B
|
||||
attn_implementation: flash_attention_2
|
||||
use_peft: true
|
||||
lora_r: 32
|
||||
lora_alpha: 128
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- gate_proj
|
||||
|
||||
|
||||
# GRPO trainer configs from trl
|
||||
bf16: true
|
||||
use_vllm: true
|
||||
vllm_device: cuda:0
|
||||
vllm_gpu_memory_utilization: 0.6
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
log_level: info
|
||||
logging_first_step: true
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
push_to_hub: false
|
||||
hub_strategy: every_save
|
||||
hub_private_repo: true
|
||||
hub_model_id: TODO
|
||||
|
||||
learning_rate: 1.0e-06
|
||||
lr_scheduler_type: cosine_with_min_lr
|
||||
lr_scheduler_kwargs:
|
||||
min_lr_rate: 0.1
|
||||
max_prompt_length: 512
|
||||
max_completion_length: 3584
|
||||
num_generations: 8
|
||||
num_train_epochs: 1
|
||||
overwrite_output_dir: true
|
||||
per_device_train_batch_size: 8
|
||||
report_to:
|
||||
- wandb
|
||||
save_strategy: steps
|
||||
save_steps: 100
|
||||
save_total_limit: 100
|
||||
|
||||
seed: 42
|
||||
temperature: 0.7
|
||||
warmup_ratio: 0.1
|
||||
@@ -1,68 +0,0 @@
# check ./tina/utils/constant.py
model_post_train_dataset_name: curated_rg_math
model_post_train_type: grpo
rl_post_train_reward_funcs:
  - tag_count_reward
  - length
  - accuracy
rl_post_train_reward_weights:
  - 1.0
  - 1.0
  - 1.0


# Model configs from trl
model_name_or_path: Qwen2.5-3B-Instruct
attn_implementation: flash_attention_2
use_peft: false # full fine tune
lora_r: 32
lora_alpha: 128
lora_dropout: 0.05
lora_target_modules:
  - q_proj
  - k_proj
  - v_proj
  - o_proj
  - down_proj
  - up_proj
  - gate_proj


# GRPO trainer configs from trl
bf16: true
use_vllm: true
vllm_device: cuda:2
vllm_gpu_memory_utilization: 0.9
gradient_accumulation_steps: 1
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
push_to_hub: true
hub_strategy: every_save
hub_private_repo: true
hub_model_id: starzmustdie

learning_rate: 1e-06
lr_scheduler_type: constant_with_warmup
lr_scheduler_kwargs:
  num_warmup_steps: 60
max_prompt_length: 512
max_completion_length: 2048
max_steps: 600
num_generations: 8
num_train_epochs: 1
overwrite_output_dir: true
per_device_train_batch_size: 8
report_to:
  - wandb

save_strategy: steps
save_steps: 200
save_total_limit: 100

seed: 42
temperature: 0.6
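In this recipe the `rl_post_train_reward_funcs` / `rl_post_train_reward_weights` pair maps each completion to a weighted sum of per-function scores. A minimal sketch of that combination is below; the two reward functions are toy placeholders, not the implementations in `tina/post_train_hf/rewards.py`.

```python
# Minimal sketch: combine several per-completion reward functions with weights.
# Both reward functions here are toy placeholders for illustration.
from typing import Callable

def tag_count_reward(completion: str) -> float:
    # placeholder: reward the presence of <think> and <answer> tags
    return 0.5 * ("<think>" in completion) + 0.5 * ("<answer>" in completion)

def accuracy_reward(completion: str) -> float:
    # placeholder: pretend the gold answer is "42"
    return 1.0 if "42" in completion else 0.0

REWARD_FUNCS: list[Callable[[str], float]] = [tag_count_reward, accuracy_reward]
REWARD_WEIGHTS: list[float] = [1.0, 1.0]

def combined_reward(completion: str) -> float:
    return sum(w * f(completion) for f, w in zip(REWARD_FUNCS, REWARD_WEIGHTS))

print(combined_reward("<think>...</think><answer>42</answer>"))  # 2.0
```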
@@ -1,20 +0,0 @@
#!/bin/bash


MAMBA_ENV="tina"
eval "$(mamba shell hook --shell bash)" && mamba activate "${MAMBA_ENV}"
echo "START TIME: $(date)"
echo "PYTHON ENV: $(which python)"

source "./scripts/set/set_vars.sh"

PY_SCRIPT="./scripts/set/run_download_model.py"

echo ""
echo "Running script: ${PY_SCRIPT}"
echo ""

python "${PY_SCRIPT}"

echo "END TIME: $(date)"
echo "DONE"
@@ -1,9 +0,0 @@
import os

from huggingface_hub import snapshot_download

if __name__ == "__main__":
    CKPT_DIR = os.environ["CKPT_DIR"]

    print("Downloading model ...")
    snapshot_download(repo_id="Qwen/Qwen2.5-3B-Instruct", local_dir=f"{CKPT_DIR}/models/Qwen2.5-3B-Instruct/base")
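If `hf_transfer` is available in the environment (it is installed by the pip setup script below), the download can optionally be sped up by enabling it before `huggingface_hub` is imported. This is an optional variation, not part of the original script.

```python
# Optional variant: enable hf_transfer for faster snapshot downloads.
# The environment variable must be set before huggingface_hub is imported.
import os

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="Qwen/Qwen2.5-3B-Instruct",
    local_dir=f"{os.environ['CKPT_DIR']}/models/Qwen2.5-3B-Instruct/base",
)
```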
@@ -1,46 +0,0 @@
#!/bin/bash
# python 3.11 & cuda 11.8

export MKL_NUM_THREADS=1
export NUMEXPR_NUM_THREADS=1
export OPENBLAS_NUM_THREADS=1
export OMP_NUM_THREADS=1

conda clean -a -y
mamba clean -a -y
pip install --upgrade pip
pip cache purge

# mamba install cuda -c nvidia/label/cuda-11.8.0 -y
# mamba install gcc gxx -c conda-forge -y
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

pip install xformers --index-url https://download.pytorch.org/whl/cu118
pip install vllm
pip install flash-attn --no-build-isolation

pip install accelerate
pip install datasets
pip install deepspeed
pip install distilabel[vllm,ray,openai]
pip install e2b-code-interpreter
pip install einops
pip install flake8
pip install huggingface_hub
pip install hf_transfer
pip install isort
pip install langdetect
pip install latex2sympy2_extended
pip install liger_kernel
pip install "math_verify==0.5.2"
pip install packaging
pip install parameterized
pip install peft
pip install pytest
pip install python-dotenv
pip install ruff
pip install safetensors
pip install sentencepiece
pip install transformers
pip install trl@git+https://github.com/huggingface/trl.git
pip install wandb
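A quick post-install sanity check (optional, not part of the original script) confirms that the CUDA 11.8 PyTorch wheels and flash-attn import cleanly before launching training:

```python
# Optional sanity check after the installs above.
import torch

print("torch:", torch.__version__)
print("cuda runtime:", torch.version.cuda)
print("cuda available:", torch.cuda.is_available())

try:
    import flash_attn  # noqa: F401
    print("flash-attn: ok")
except ImportError as err:
    print("flash-attn import failed:", err)
```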
@@ -1,53 +0,0 @@
#!/bin/bash

export CUDA_LAUNCH_BLOCKING=1
export DS_LOG_LEVEL=error
export TOKENIZERS_PARALLELISM=false

export NCCL_P2P_DISABLE=1
export NCCL_SHM_DISABLE=1
export NCCL_IB_DISABLE=1

export MKL_THREADING_LAYER=GNU
export MKL_NUM_THREADS=1
export NUMEXPR_NUM_THREADS=1
export OPENBLAS_NUM_THREADS=1
export OMP_NUM_THREADS=1

## basic setup for the env
export PROJECT_PREFIX="/root/projects" # e.g. /home/username/projects
export SCRATCH_PREFIX="/root/scratch" # e.g. /home/username/scratch
mkdir -p "${PROJECT_PREFIX}" "${SCRATCH_PREFIX}"

export PROJECT_NAME="rg-math"
export CORE_POSTFIX="tina"
export PROJECT_DIR="${PROJECT_PREFIX}/${PROJECT_NAME}"
export PYTHONPATH="${PROJECT_DIR}":$PYTHONPATH
export PYTHONPATH="${PROJECT_DIR}/${CORE_POSTFIX}":$PYTHONPATH
mkdir -p "${PROJECT_PREFIX}/${PROJECT_NAME}"

export CKPT_DIR="${PROJECT_DIR}/ckpts"
export DATA_DIR="${PROJECT_DIR}/datasets"
export OUTPUT_DIR="${PROJECT_DIR}/outputs"
export LOGGING_DIR="${PROJECT_DIR}/logs"
mkdir -p "${CKPT_DIR}" "${DATA_DIR}" "${OUTPUT_DIR}" "${LOGGING_DIR}"

## wandb setup
# export WANDB_API_KEY="TODO"
export WANDB_PROJECT="${PROJECT_NAME}"
export WANDB_DIR="${OUTPUT_DIR}"

wandb login $WANDB_API_KEY

export CACHE_DIR="${PROJECT_DIR}/.cache"
export WANDB_CACHE_DIR="${CACHE_DIR}"
export TRITON_CACHE_DIR="${CACHE_DIR}/triton_cache"

## huggingface setup
# export HF_TOKEN="TODO"
git config --global credential.helper store
huggingface-cli login --token $HF_TOKEN --add-to-git-credential

export HF_HOME="${CACHE_DIR}/huggingface"
export HUGGINGFACE_HUB_CACHE="${HF_HOME}/hub"
export HF_DATASETS_CACHE="${HF_HOME}/datasets"
@@ -1,54 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
|
||||
#SBATCH --job-name=grpo_multinode
|
||||
#SBATCH -D .
|
||||
#SBATCH --partition=TODO
|
||||
#SBATCH --account=TODO
|
||||
#SBATCH --output=output-%x.%j
|
||||
#SBATCH --error=error-%x.%j
|
||||
#SBATCH --nodes=2 # number of nodes
|
||||
#SBATCH --ntasks-per-node=1 # number of MP tasks
|
||||
#SBATCH --gres=gpu:2 # number of GPUs per node
|
||||
#SBATCH --cpus-per-task=8 # number of cores per tasks
|
||||
#SBATCH --mem=128G
|
||||
#SBATCH --time=48:00:00 # maximum execution time (HH:MM:SS)
|
||||
#SBATCH --comment "Key=Monitoring,Value=ON"
|
||||
#SBATCH --exclusive
|
||||
|
||||
######################
|
||||
### Set environment ##
|
||||
######################
|
||||
|
||||
ulimit -s unlimited
|
||||
|
||||
MAMBA_ENV="tina"
|
||||
eval "$(mamba shell hook --shell bash)" && mamba activate "${MAMBA_ENV}"
|
||||
echo "START TIME: $(date)"
|
||||
echo "PYTHON ENV: $(which python)"
|
||||
|
||||
source "./scripts/set/set_vars.sh"
|
||||
export GPUS_PER_NODE=2
|
||||
######################
|
||||
|
||||
######################
|
||||
#### Set network #####
|
||||
######################
|
||||
head_node_ip=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
|
||||
######################
|
||||
|
||||
export LAUNCHER="accelerate launch \
|
||||
--num_processes $((SLURM_NNODES * GPUS_PER_NODE)) \
|
||||
--num_machines $SLURM_NNODES \
|
||||
--machine_rank $SLURM_NODEID \
|
||||
--rdzv_backend c10d \
|
||||
--main_process_ip $head_node_ip \
|
||||
--main_process_port 29500 \
|
||||
"
|
||||
|
||||
PY_SCRIPT="./tina/post_train_hf/grpo.py"
|
||||
PY_CONFIG="./recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/model_curated_deepscaler.yaml"
|
||||
|
||||
# This step is necessary because accelerate launch does not handle multiline arguments properly
|
||||
export CMD="$LAUNCHER $PY_SCRIPT --config $PY_CONFIG"
|
||||
srun $CMD
|
||||
@@ -1,37 +0,0 @@
#!/bin/bash
# use by running `bash sbatch_launch.sh <script.slurm>`

cleanup() {
    echo "Script interrupted. Cleaning up..."
    scancel "$job_id" 2>/dev/null
    echo "Job $job_id has been canceled."
    exit 1
}
trap cleanup SIGINT

# launch the slurm script
SLURM_FILE=$1
echo "Launching $SLURM_FILE ..."
job_id=$(sbatch $SLURM_FILE | awk '{print $4}')
echo "Submitted job with ID: $job_id"

# Wait until the job is running
while true; do
    job_status=$(squeue -j "$job_id" -h -o "%T")
    if [ "$job_status" == "RUNNING" ]; then
        echo "Job $job_id is now running."
        sleep 5
        break
    elif [ -z "$job_status" ]; then
        echo "Job $job_id has finished or failed before reaching running state."
        exit 1
    else
        echo "Job $job_id is still in $job_status state. Checking again in 10 seconds..."
        sleep 10
    fi
done

# Tail the real-time output
output_file=$(scontrol show job "$job_id" | awk -F= '/StdOut/ {print $2}' | sed "s/%A/${job_id}/g" | sed "s/%a/1/g")
echo "Tailing output file: $output_file"
tail -f "$output_file"
@@ -1,32 +0,0 @@
#!/bin/bash


MAMBA_ENV="tina_eval"
eval "$(mamba shell hook --shell bash)" && mamba activate "${MAMBA_ENV}"
echo "START TIME: $(date)"
echo "PYTHON ENV: $(which python)"

source "./scripts/set/set_vars.sh"

export CUDA_VISIBLE_DEVICES="0" # NOTE: update this if you have more than 1 GPU
GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
echo ""
echo "GPU_COUNT: $GPU_COUNT"
echo ""

# MODEL_LIST=("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" "starzmustdie/DeepSeek-R1-Distill-Qwen-1.5B-rg-math" "agentica-org/DeepScaleR-1.5B-Preview" "knoveleng/Open-RS3" "RUC-AIBOX/STILL-3-1.5B-preview")
MODEL_LIST=("Qwen/Qwen2.5-3B-Instruct" "starzmustdie/Qwen2.5-3B-Instruct")
TASKS=("aime24" "aime25" "amc23" "minerva" "math_500")

for MODEL_NAME in "${MODEL_LIST[@]}"; do
    for TASK in "${TASKS[@]}"; do
        MODEL_ARGS="model_name=$MODEL_NAME,dtype=bfloat16,data_parallel_size=$GPU_COUNT,max_model_length=32768,gpu_memory_utilization=0.7,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
        lighteval vllm $MODEL_ARGS "custom|$TASK|0|0" \
            --custom-tasks ./scripts/training/run_post_train_eval.py \
            --use-chat-template \
            --output-dir "${OUTPUT_DIR}/${TASK}"
    done
done

echo "END TIME: $(date)"
echo "DONE"
@@ -1,73 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
|
||||
MAMBA_ENV="tina_eval"
|
||||
eval "$(mamba shell hook --shell bash)" && mamba activate "${MAMBA_ENV}"
|
||||
echo "START TIME: $(date)"
|
||||
echo "PYTHON ENV: $(which python)"
|
||||
|
||||
source "./scripts/set/set_vars.sh"
|
||||
|
||||
export CUDA_VISIBLE_DEVICES=0,1 # make sure all evaluation run on 2 GPUs
|
||||
GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
|
||||
|
||||
echo ""
|
||||
echo "GPU_COUNT: $GPU_COUNT, make sure using 2 GPUs."
|
||||
echo ""
|
||||
|
||||
MODEL_NAME="DeepSeek-R1-Distill-Qwen-1.5B"
|
||||
PT_TYPE="grpo"
|
||||
|
||||
## Main datasets
|
||||
DATASET_NAME="curated_deepscaler"
|
||||
#DATASET_NAME="curated_still"
|
||||
#DATASET_NAME="curated_open_rs3"
|
||||
#DATASET_NAME="curated_open_rs2"
|
||||
#DATASET_NAME="curated_open_rs1"
|
||||
|
||||
## Extra datasets
|
||||
#DATASET_NAME="curated_limr"
|
||||
#DATASET_NAME="curated_open_r1"
|
||||
#DATASET_NAME="curated_thoughts"
|
||||
|
||||
## Ablation
|
||||
#DATASET_NAME="curated_limr_large_lr_ablation"
|
||||
#DATASET_NAME="curated_limr_small_lr_ablation"
|
||||
#DATASET_NAME="curated_limr_large_rank_ablation"
|
||||
#DATASET_NAME="curated_limr_medium_rank_ablation"
|
||||
#DATASET_NAME="curated_limr_small_rank_ablation"
|
||||
#DATASET_NAME="curated_limr_tiny_rank_ablation"
|
||||
#DATASET_NAME="curated_open_rs3_drgrpo_ablation"
|
||||
|
||||
CKPT_LIST=$(ls "${CKPT_DIR}/models/${MODEL_NAME}/${PT_TYPE}_${DATASET_NAME}" | grep -E "^checkpoint-[0-9]+$")
|
||||
#CKPT_LIST=("checkpoint-XXX")
|
||||
|
||||
# loop over all the checkpoints in the list
|
||||
for CKPT in "${CKPT_LIST[@]}"; do
|
||||
echo "Running model post train merging base and adapter for checkpoint: ${CKPT}"
|
||||
python ./scripts/training/run_post_train_merge.py \
|
||||
--model_name "${MODEL_NAME}" \
|
||||
--adapter_type "${PT_TYPE}_${DATASET_NAME}" \
|
||||
--ckpt "${CKPT}"
|
||||
|
||||
MODEL_PATH="${CKPT_DIR}/models/${MODEL_NAME}/${PT_TYPE}_${DATASET_NAME}/${CKPT}-merged"
|
||||
|
||||
# Set model arguments (ensure that MODEL_PATH, GPU_COUNT, OUTPUT_DIR, and MODEL are defined)
|
||||
MODEL_ARGS="pretrained=$MODEL_PATH,dtype=bfloat16,data_parallel_size=$GPU_COUNT,max_model_length=32768,gpu_memory_utilization=0.5,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
|
||||
|
||||
# Define an array of tasks to evaluate
|
||||
tasks=("aime24" "math_500" "gpqa:diamond" "aime25" "amc23" "minerva")
|
||||
|
||||
# Loop over each task and evaluate
|
||||
for TASK in "${tasks[@]}"; do
|
||||
echo "Evaluating task: $TASK"
|
||||
lighteval vllm $MODEL_ARGS "custom|$TASK|0|0" \
|
||||
--custom-tasks ./scripts/training/run_post_train_eval.py \
|
||||
--use-chat-template \
|
||||
--output-dir "${OUTPUT_DIR}/${MODEL}/${TASK}"
|
||||
done
|
||||
|
||||
done
|
||||
|
||||
echo "END TIME: $(date)"
|
||||
echo "DONE"
|
||||
@@ -1,68 +0,0 @@
#!/bin/bash


MAMBA_ENV="tina"
eval "$(mamba shell hook --shell bash)" && mamba activate "${MAMBA_ENV}"
echo "START TIME: $(date)"
echo "PYTHON ENV: $(which python)"

source "./scripts/set/set_vars.sh"

export CUDA_VISIBLE_DEVICES=0,1,2 # Set the GPUs you want to use
GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")

NUM_PROCESSES_TRAINING=$((GPU_COUNT - 1))

echo ""
echo "Number of GPUs: ${GPU_COUNT}"
echo "Number of processes for training: ${NUM_PROCESSES_TRAINING}"
echo ""

MODEL_NAME="Qwen2.5-3B-Instruct"

## Main datasets
#DATASET_NAME="curated_deepscaler"
#DATASET_NAME="curated_still"
#DATASET_NAME="curated_open_rs3"
#DATASET_NAME="curated_open_rs2"
#DATASET_NAME="curated_open_rs1"

## Extra datasets
#DATASET_NAME="curated_limr"
#DATASET_NAME="curated_open_r1"
#DATASET_NAME="curated_thoughts"

## Ablation
#DATASET_NAME="curated_limr_large_lr_ablation"
#DATASET_NAME="curated_limr_small_lr_ablation"
#DATASET_NAME="curated_limr_large_rank_ablation"
#DATASET_NAME="curated_limr_medium_rank_ablation"
#DATASET_NAME="curated_limr_small_rank_ablation"
#DATASET_NAME="curated_limr_tiny_rank_ablation"
#DATASET_NAME="curated_open_rs3_drgrpo_ablation"

## Reasoning Gym
DATASET_NAME="curated_rg_math"

PY_SCRIPT="./tina/post_train_hf/grpo.py"
PY_CONFIG="./recipes/${MODEL_NAME}/grpo/model_${DATASET_NAME}.yaml"
ACCELERATE_DS_CONFIG="./recipes/accelerate_ds_cfgs/ds_zero2.yaml"

echo ""
echo "Running ${PY_SCRIPT} on model ${MODEL_NAME} with dataset ${DATASET_NAME}"
echo ""

if [[ "${DATASET_NAME}" == "curated_thoughts" || "${DATASET_NAME}" == "curated_open_r1" || "${DATASET_NAME}" == "curated_open_rs3" || "${DATASET_NAME}" == "curated_open_rs3_drgrpo_ablation" ]]; then
    ACCELERATE_LOG_LEVEL=info accelerate launch \
        --config_file "${ACCELERATE_DS_CONFIG}" \
        --main_process_port=29500 \
        --num_processes="${NUM_PROCESSES_TRAINING}" "${PY_SCRIPT}" --config "${PY_CONFIG}" --cosine_max_len 3584
else
    ACCELERATE_LOG_LEVEL=info accelerate launch \
        --config_file "${ACCELERATE_DS_CONFIG}" \
        --main_process_port=29500 \
        --num_processes="${NUM_PROCESSES_TRAINING}" "${PY_SCRIPT}" --config "${PY_CONFIG}"
fi

echo "END TIME: $(date)"
echo "DONE"
@@ -1,229 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Custom evaluation tasks for LightEval."""
|
||||
|
||||
import numpy as np
|
||||
from lighteval.metrics.dynamic_metrics import (
|
||||
ExprExtractionConfig,
|
||||
LatexExtractionConfig,
|
||||
compare_gold_target,
|
||||
extract_target_from_pred,
|
||||
get_extraction_regexes,
|
||||
)
|
||||
from lighteval.metrics.metrics import Metrics
|
||||
from lighteval.metrics.metrics_sample import PassAtK
|
||||
from lighteval.metrics.utils.metric_utils import MetricCategory, MetricUseCase, SampleLevelMetric
|
||||
from lighteval.tasks.lighteval_task import LightevalTaskConfig
|
||||
from lighteval.tasks.requests import Doc
|
||||
from lighteval.utils.language import Language
|
||||
|
||||
# Prompt template adapted from
|
||||
# - simple-evals: https://github.com/openai/simple-evals/blob/6e84f4e2aed6b60f6a0c7b8f06bbbf4bfde72e58/math_eval.py#L17
|
||||
# - Llama 3: https://huggingface.co/datasets/meta-llama/Llama-3.2-1B-Instruct-evals/viewer/Llama-3.2-1B-Instruct-evals__math__details?views%5B%5D=llama_32_1b_instruct_evals__math__details
|
||||
# Note that it is important to have the final answer in a box for math-verify to work correctly
|
||||
MATH_QUERY_TEMPLATE = """
|
||||
Solve the following math problem efficiently and clearly. The last line of your response should be of the following format: 'Therefore, the final answer is: $\\boxed{{ANSWER}}$. I hope it is correct' (without quotes) where ANSWER is just the final number or expression that solves the problem. Think step by step before answering.
|
||||
|
||||
{Question}
|
||||
""".strip()
|
||||
|
||||
|
||||
math_pass_at_1_2n = SampleLevelMetric(
|
||||
metric_name="math_pass@1:2_samples",
|
||||
sample_level_fn=PassAtK(
|
||||
k=1,
|
||||
n=2,
|
||||
strip_strings=True,
|
||||
# Extracting mathematical expressions and latex expressions
|
||||
normalize_gold=lambda k: extract_target_from_pred(
|
||||
k,
|
||||
get_extraction_regexes(
|
||||
formatted_doc=None,
|
||||
target_types=[ExprExtractionConfig(), LatexExtractionConfig()],
|
||||
language=Language.ENGLISH,
|
||||
),
|
||||
),
|
||||
# Extracting mathematical expressions and latex expressions
|
||||
normalize_pred=lambda k: extract_target_from_pred(
|
||||
k,
|
||||
get_extraction_regexes(
|
||||
formatted_doc=None,
|
||||
target_types=[ExprExtractionConfig(), LatexExtractionConfig()],
|
||||
language=Language.ENGLISH,
|
||||
),
|
||||
),
|
||||
# Uses sympy for comparison
|
||||
sample_scoring_function=compare_gold_target,
|
||||
).compute,
|
||||
category=MetricCategory.GENERATIVE_SAMPLING,
|
||||
use_case=MetricUseCase.REASONING,
|
||||
corpus_level_fn=np.mean,
|
||||
higher_is_better=True,
|
||||
)
|
||||
|
||||
|
||||
def math_prompt_fn(line, task_name: str = None):
|
||||
return Doc(
|
||||
task_name=task_name,
|
||||
query=MATH_QUERY_TEMPLATE.format(Question=line["problem"]),
|
||||
choices=[line["solution"]],
|
||||
gold_index=0,
|
||||
)
|
||||
|
||||
|
||||
def aime_prompt_fn(line, task_name: str = None):
|
||||
return Doc(
|
||||
task_name=task_name,
|
||||
query=MATH_QUERY_TEMPLATE.format(Question=line["problem"]),
|
||||
choices=[line["answer"]],
|
||||
gold_index=0,
|
||||
)
|
||||
|
||||
|
||||
def amc_prompt_fn(line, task_name: str = None):
|
||||
return Doc(
|
||||
task_name=task_name,
|
||||
query=MATH_QUERY_TEMPLATE.format(Question=line["problem"]),
|
||||
choices=[line["answer"]],
|
||||
gold_index=0,
|
||||
)
|
||||
|
||||
|
||||
def minerva_prompt_fn(line, task_name: str = None):
|
||||
return Doc(
|
||||
task_name=task_name,
|
||||
query=MATH_QUERY_TEMPLATE.format(Question=line["problem"]),
|
||||
choices=[line["solution"]],
|
||||
gold_index=0,
|
||||
)
|
||||
|
||||
|
||||
# Define tasks
|
||||
aime24 = LightevalTaskConfig(
|
||||
name="aime24",
|
||||
suite=["custom"],
|
||||
prompt_function=aime_prompt_fn,
|
||||
hf_repo="HuggingFaceH4/aime_2024",
|
||||
hf_subset="default",
|
||||
hf_avail_splits=["train"],
|
||||
evaluation_splits=["train"],
|
||||
few_shots_split=None,
|
||||
few_shots_select=None,
|
||||
generation_size=32768,
|
||||
metric=[
|
||||
# Metrics.math_pass_at_1_1n,
|
||||
# math_pass_at_1_2n,
|
||||
# Metrics.math_pass_at_1_4n,
|
||||
# Metrics.math_pass_at_1_16n,
|
||||
Metrics.math_pass_at_1_32n,
|
||||
# Metrics.math_pass_at_1_64n,
|
||||
],
|
||||
version=1,
|
||||
)
|
||||
|
||||
aime25 = LightevalTaskConfig(
|
||||
name="aime25",
|
||||
suite=["custom"],
|
||||
prompt_function=aime_prompt_fn,
|
||||
hf_repo="yentinglin/aime_2025",
|
||||
hf_subset="default",
|
||||
hf_avail_splits=["train"],
|
||||
evaluation_splits=["train"],
|
||||
few_shots_split=None,
|
||||
few_shots_select=None,
|
||||
generation_size=32768,
|
||||
metric=[
|
||||
# Metrics.math_pass_at_1_1n,
|
||||
# math_pass_at_1_2n,
|
||||
# Metrics.math_pass_at_1_4n,
|
||||
# Metrics.math_pass_at_1_16n,
|
||||
Metrics.math_pass_at_1_32n,
|
||||
# Metrics.math_pass_at_1_64n,
|
||||
],
|
||||
version=1,
|
||||
)
|
||||
|
||||
amc23 = LightevalTaskConfig(
|
||||
name="amc23",
|
||||
suite=["custom"],
|
||||
prompt_function=amc_prompt_fn,
|
||||
hf_repo="knoveleng/AMC-23",
|
||||
hf_subset="default",
|
||||
hf_avail_splits=["train"],
|
||||
evaluation_splits=["train"],
|
||||
few_shots_split=None,
|
||||
few_shots_select=None,
|
||||
generation_size=32768,
|
||||
metric=[
|
||||
# Metrics.math_pass_at_1_1n,
|
||||
# math_pass_at_1_2n,
|
||||
# Metrics.math_pass_at_1_4n,
|
||||
# Metrics.math_pass_at_1_16n,
|
||||
Metrics.math_pass_at_1_32n,
|
||||
# Metrics.math_pass_at_1_64n,
|
||||
],
|
||||
version=1,
|
||||
)
|
||||
|
||||
math_500 = LightevalTaskConfig(
|
||||
name="math_500",
|
||||
suite=["custom"],
|
||||
prompt_function=math_prompt_fn,
|
||||
hf_repo="HuggingFaceH4/MATH-500",
|
||||
hf_subset="default",
|
||||
hf_avail_splits=["test"],
|
||||
evaluation_splits=["test"],
|
||||
few_shots_split=None,
|
||||
few_shots_select=None,
|
||||
generation_size=32768,
|
||||
metric=[
|
||||
# Metrics.math_pass_at_1_1n,
|
||||
math_pass_at_1_2n,
|
||||
],
|
||||
version=1,
|
||||
)
|
||||
|
||||
minerva = LightevalTaskConfig(
|
||||
name="minerva",
|
||||
suite=["custom"],
|
||||
prompt_function=minerva_prompt_fn,
|
||||
hf_repo="knoveleng/Minerva-Math",
|
||||
hf_subset="default",
|
||||
hf_avail_splits=["train"],
|
||||
evaluation_splits=["train"],
|
||||
few_shots_split=None,
|
||||
few_shots_select=None,
|
||||
generation_size=32768,
|
||||
metric=[
|
||||
# Metrics.math_pass_at_1_1n,
|
||||
# math_pass_at_1_2n,
|
||||
Metrics.math_pass_at_1_4n,
|
||||
],
|
||||
version=1,
|
||||
)
|
||||
|
||||
|
||||
# Add tasks to the table
|
||||
TASKS_TABLE = []
|
||||
TASKS_TABLE.append(aime24)
|
||||
TASKS_TABLE.append(aime25)
|
||||
TASKS_TABLE.append(amc23)
|
||||
TASKS_TABLE.append(math_500)
|
||||
TASKS_TABLE.append(minerva)
|
||||
|
||||
# MODULE LOGIC
|
||||
if __name__ == "__main__":
|
||||
print([t["name"] for t in TASKS_TABLE])
|
||||
print(len(TASKS_TABLE))
|
||||
@@ -1,40 +0,0 @@
import argparse
import os

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer


def argparser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="DeepSeek-R1-Distill-Qwen-1.5B")
    parser.add_argument("--adapter_type", type=str, default="grpo_curated_open_r1")
    parser.add_argument("--ckpt", type=str, default="checkpoint-500")
    return parser.parse_args()


if __name__ == "__main__":
    args = argparser()

    ckpt_dir = os.environ["CKPT_DIR"]
    ckpt = args.ckpt
    adapter_type = args.adapter_type
    model_name = args.model_name

    base_model_name_or_path = f"{ckpt_dir}/models/{model_name}/base"
    adapter_model_name_or_path = f"{ckpt_dir}/models/{model_name}/{adapter_type}/{ckpt}"
    merged_model_name_or_path = f"{ckpt_dir}/models/{model_name}/{adapter_type}/{ckpt}-merged"

    print("Merged model will be saved to: ", merged_model_name_or_path)

    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name_or_path, torch_dtype=torch.bfloat16, device_map="auto"
    )  # Automatically distributes across available GPUs

    model = PeftModel.from_pretrained(base_model, adapter_model_name_or_path)
    model = model.merge_and_unload()

    model.save_pretrained(merged_model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path)
    tokenizer.save_pretrained(merged_model_name_or_path)
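After merging, a one-off generation check can confirm the adapter weights were folded into the base model. The checkpoint path below is hypothetical; substitute the directory printed by the script above.

```python
# Optional check, not part of the original script. The merged_dir path is hypothetical.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

merged_dir = "/root/projects/rg-math/ckpts/models/DeepSeek-R1-Distill-Qwen-1.5B/grpo_curated_rg_math/checkpoint-500-merged"  # hypothetical

tokenizer = AutoTokenizer.from_pretrained(merged_dir)
model = AutoModelForCausalLM.from_pretrained(merged_dir, torch_dtype=torch.bfloat16, device_map="auto")

inputs = tokenizer("What is 2 + 2?", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```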
@@ -1,21 +0,0 @@
from dataclasses import dataclass, field
from typing import Literal


# check ./recipes/MODEL_NAME/PT_METHOD/model_DATASET.yaml
@dataclass
class ModelPTConfig:
    # //*******Model post-training configs*******//
    model_post_train_type: Literal["grpo", "sft"] = field(default="grpo")
    model_post_train_dataset_name: str = field(default="curated_deepscaler")
    model_post_train_dataset_config: str | None = field(default=None)

    rl_post_train_reward_funcs: list[str] = field(default_factory=lambda: ["accuracy", "format"])
    rl_post_train_reward_weights: list[float] = field(default_factory=lambda: [2.0, 1.0])
    cosine_min_value_wrong: float = field(default=0.0)
    cosine_max_value_wrong: float = field(default=-0.5)
    cosine_min_value_correct: float = field(default=0.5)
    cosine_max_value_correct: float = field(default=1.0)
    cosine_max_len: int = field(default=1000)
    repetition_n_grams: int = field(default=3)
    repetition_max_penalty: float = field(default=-1.0)
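The `cosine_*` fields above parameterize an open-r1-style length-scaled reward: correct completions interpolate between `cosine_max_value_correct` (short) and `cosine_min_value_correct` (at `cosine_max_len`) along a cosine schedule, and the `*_wrong` bounds play the analogous role for incorrect answers. The sketch below shows only the interpolation for correct answers and is not the repository's `get_cosine_scaled_reward`.

```python
# Rough sketch of the cosine length scaling configured by the cosine_* fields above.
# Only the correct-answer case is shown; wrong answers use the *_wrong bounds analogously.
import math

def cosine_scaled_correct(gen_len: int,
                          min_value_correct: float = 0.5,
                          max_value_correct: float = 1.0,
                          max_len: int = 1000) -> float:
    progress = min(gen_len, max_len) / max_len
    cosine = math.cos(progress * math.pi)  # 1.0 at length 0, -1.0 at max_len
    # short completions earn close to max_value_correct, long ones decay toward min_value_correct
    return min_value_correct + 0.5 * (max_value_correct - min_value_correct) * (1.0 + cosine)

print(round(cosine_scaled_correct(100), 3))  # close to 1.0
print(round(cosine_scaled_correct(900), 3))  # close to 0.5
```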
@@ -1,167 +0,0 @@
|
||||
import copy
|
||||
import logging
|
||||
import shutil
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import wandb
|
||||
from tina.post_train_hf.hub import push_to_hub_revision
|
||||
from tina.utils.prompt import FIXED_PROMPT_FOR_EVALUATION, OPEN_R1_SYSTEM_PROMPT
|
||||
from transformers import TrainerCallback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FixedPromptEvaluationCallback(TrainerCallback):
|
||||
def __init__(
|
||||
self,
|
||||
system_prompt=OPEN_R1_SYSTEM_PROMPT,
|
||||
prompt=FIXED_PROMPT_FOR_EVALUATION,
|
||||
max_generation_length=4096,
|
||||
eval_steps=100,
|
||||
):
|
||||
|
||||
self.system_prompt = system_prompt
|
||||
self.prompt = prompt
|
||||
self.max_generation_length = max_generation_length
|
||||
self.eval_steps = eval_steps
|
||||
self.completion_table = {
|
||||
"step": [],
|
||||
"prompt": [],
|
||||
"completion": [],
|
||||
}
|
||||
|
||||
def on_init_end(self, args, state, control, processing_class=None, **kwargs):
|
||||
tokenizer = processing_class
|
||||
messages = [{"role": "system", "content": self.system_prompt}, {"role": "user", "content": self.prompt}]
|
||||
input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
||||
self.tokenized_prompt = tokenizer(input_text, return_tensors="pt")
|
||||
|
||||
def on_step_end(self, args, state, control, model=None, processing_class=None, **kwargs):
|
||||
if state.global_step % self.eval_steps == 0:
|
||||
if state.is_world_process_zero:
|
||||
completion = self.eval_prompt(model, processing_class)
|
||||
self.completion_table["step"].append(str(state.global_step))
|
||||
self.completion_table["prompt"].append(self.prompt)
|
||||
self.completion_table["completion"].append(completion)
|
||||
df = pd.DataFrame(self.completion_table)
|
||||
wandb.log({"completions": wandb.Table(dataframe=df)})
|
||||
|
||||
def eval_prompt(self, model, tokenizer):
|
||||
if hasattr(model, "peft_config"):
|
||||
model.peft_config["default"].inference_mode = True
|
||||
|
||||
self.tokenized_prompt.to(model.device)
|
||||
outputs = model.generate(
|
||||
**self.tokenized_prompt,
|
||||
max_length=self.max_generation_length,
|
||||
temperature=0.01, # Very low temperature
|
||||
top_k=1, # Only consider the most likely token
|
||||
top_p=1.0, # Disable nucleus sampling or set to a high value
|
||||
)
|
||||
completion = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
|
||||
if hasattr(model, "peft_config"):
|
||||
model.peft_config["default"].inference_mode = False
|
||||
|
||||
return completion
|
||||
|
||||
|
||||
class GradientClippingLoggerCallback(TrainerCallback):
|
||||
def on_step_end(self, args, state, control, model=None, processing_class=None, **kwargs):
|
||||
self.clipped_grad_norm = np.sqrt(
|
||||
sum(p.grad.data.norm(2).item() ** 2 for p in model.parameters() if p.grad is not None)
|
||||
)
|
||||
if state.is_world_process_zero:
|
||||
wandb.log({"clipped_grad_norm": self.clipped_grad_norm})
|
||||
|
||||
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||
if logs is not None:
|
||||
logs["clipped_grad_norm"] = self.clipped_grad_norm
|
||||
|
||||
|
||||
class DummyConfig:
|
||||
def __init__(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
setattr(self, k, v)
|
||||
|
||||
|
||||
class PushToHubRevisionCallback(TrainerCallback):
|
||||
def __init__(self, dataset_name, use_peft):
|
||||
self.dataset_name = dataset_name
|
||||
self.use_peft = use_peft
|
||||
|
||||
self.pending_futures = [] # Track pending push operations
|
||||
|
||||
def on_save(self, args, state, control, **kwargs):
|
||||
if state.is_world_process_zero:
|
||||
global_step = state.global_step
|
||||
|
||||
# Create merged model directory
|
||||
if self.use_peft:
|
||||
ckpt_model_dir = f"{args.output_dir}/checkpoint-{global_step}-merged"
|
||||
original_model = kwargs["model"] # Don't pop it, keep it intact
|
||||
model_to_save = copy.deepcopy(original_model).merge_and_unload()
|
||||
model_to_save.save_pretrained(ckpt_model_dir)
|
||||
else:
|
||||
# this dir is already created by the HF Trainer, no need to manually save
|
||||
ckpt_model_dir = f"{args.output_dir}/checkpoint-{global_step}"
|
||||
|
||||
tokenizer = kwargs.get("tokenizer") or kwargs.get("processing_class")
|
||||
if tokenizer is None:
|
||||
raise ValueError("Tokenizer or processing_class must be provided.")
|
||||
tokenizer.save_pretrained(ckpt_model_dir)
|
||||
|
||||
dummy_config = DummyConfig(
|
||||
hub_model_id=args.hub_model_id,
|
||||
hub_model_revision=self.dataset_name,
|
||||
checkpoint=f"checkpoint-{global_step}",
|
||||
output_dir=ckpt_model_dir,
|
||||
dataset_name=self.dataset_name,
|
||||
)
|
||||
|
||||
# Start the push operation
|
||||
future = push_to_hub_revision(dummy_config, extra_ignore_patterns=["*.pt"])
|
||||
|
||||
# Store the future and directory path for cleanup later
|
||||
self.pending_futures.append((future, ckpt_model_dir))
|
||||
|
||||
# Check and clean up any completed pushes
|
||||
if self.use_peft:
|
||||
self._cleanup_completed_pushes()
|
||||
|
||||
return control
|
||||
|
||||
def _cleanup_completed_pushes(self):
|
||||
"""Check pending futures and remove directories for completed pushes."""
|
||||
still_pending = []
|
||||
for future, dir_path in self.pending_futures:
|
||||
if future.done():
|
||||
if self.use_peft:
|
||||
# The push is complete, safe to delete the directory
|
||||
try:
|
||||
shutil.rmtree(dir_path)
|
||||
logger.info(f"\nCleaned up merged model directory: {dir_path}\n")
|
||||
except Exception as e:
|
||||
logger.error(f"\nFailed to clean up directory {dir_path}: {e}\n")
|
||||
else:
|
||||
# Push is still in progress, keep in pending list
|
||||
still_pending.append((future, dir_path))
|
||||
|
||||
self.pending_futures = still_pending
|
||||
|
||||
def on_train_end(self, args, state, control, **kwargs):
|
||||
"""Make sure to clean up any remaining directories at the end of training."""
|
||||
if state.is_world_process_zero and self.use_peft:
|
||||
# Wait for all pending pushes to complete
|
||||
logger.info("\nWaiting for pending pushes to finish before cleaning up merged LoRA model directories.")
|
||||
for future, dir_path in self.pending_futures:
|
||||
future.result() # Wait for completion
|
||||
try:
|
||||
shutil.rmtree(dir_path)
|
||||
logger.info(f"\nCleaned up merged model directory: {dir_path}\n")
|
||||
except Exception as e:
|
||||
logger.error(f"\nFailed to clean up directory {dir_path}: {e}\n")
|
||||
|
||||
self.pending_futures = []
|
||||
@@ -1,245 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
import transformers
|
||||
from datasets import Dataset, load_dataset
|
||||
from peft import LoraConfig, TaskType, get_peft_model
|
||||
from tina.config import ModelPTConfig
|
||||
from tina.post_train_hf.callback import (
|
||||
FixedPromptEvaluationCallback,
|
||||
GradientClippingLoggerCallback,
|
||||
PushToHubRevisionCallback,
|
||||
)
|
||||
from tina.post_train_hf.grpo_config import GRPOConfig # use this new one for Dr.GRPO
|
||||
from tina.post_train_hf.grpo_trainer import GRPOTrainer # use this new one for Dr.GRPO
|
||||
from tina.post_train_hf.preprocess import make_conv_for_grpo
|
||||
from tina.post_train_hf.rewards import (
|
||||
accuracy_reward,
|
||||
format_reward,
|
||||
get_cosine_scaled_reward,
|
||||
get_repetition_penalty_reward,
|
||||
len_reward,
|
||||
reasoning_steps_reward,
|
||||
tag_count_reward,
|
||||
)
|
||||
from tina.utils.chat_template import DEFAULT_CHAT_TEMPLATE, REASON_CHAT_TEMPLATE
|
||||
from tina.utils.constant import RL_POST_TRAIN_DATASET_MAP
|
||||
from tina.utils.prompt import OPEN_R1_SYSTEM_PROMPT, OPEN_RS_SYSTEM_PROMPT
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
from trl import ModelConfig, TrlParser # GRPOTrainer, GRPOConfig
|
||||
|
||||
|
||||
def main():
|
||||
parser = TrlParser((ModelPTConfig, GRPOConfig, ModelConfig))
|
||||
pt_args, training_args, model_args = parser.parse_args_and_config()
|
||||
set_seed(training_args.seed)
|
||||
|
||||
os.environ["WANDB_PROJECT"] = "tina_model_post_training"
|
||||
|
||||
################
|
||||
# Set up logging
|
||||
################
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
log_level = training_args.get_process_log_level()
|
||||
logger.setLevel(log_level)
|
||||
datasets.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
# Log on each process a small summary
|
||||
logger.warning(
|
||||
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
||||
+ f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
||||
)
|
||||
logger.info(f"Model parameters {model_args}")
|
||||
logger.info(f"Post training parameters {pt_args}")
|
||||
logger.info(f"Training parameters {training_args}")
|
||||
|
||||
#####################
|
||||
# Set up output paths
|
||||
#####################
|
||||
|
||||
current_time = datetime.now()
|
||||
formatted_datetime = current_time.strftime("%Y_%m_%d_%H_%M_%S")
|
||||
|
||||
model_name_or_path = model_args.model_name_or_path
|
||||
ckpt_dir = os.environ["CKPT_DIR"]
|
||||
ckpt_prefix = f"{ckpt_dir}/models/{model_name_or_path}"
|
||||
if model_args.use_peft:
|
||||
ckpt_postfix = f"{pt_args.model_post_train_type}_{pt_args.model_post_train_dataset_name}"
|
||||
else:
|
||||
ckpt_postfix = f"full_{pt_args.model_post_train_type}_{pt_args.model_post_train_dataset_name}"
|
||||
|
||||
model_args.model_name_or_path = f"{ckpt_prefix}/base"
|
||||
training_args.output_dir = f"{ckpt_prefix}/{ckpt_postfix}"
|
||||
# training_args.hub_model_id = f"{training_args.hub_model_id}_{ckpt_postfix}"
|
||||
training_args.run_name = f"{model_name_or_path}_{ckpt_postfix}_{formatted_datetime}"
|
||||
|
||||
training_args.hub_model_id = f"{training_args.hub_model_id}/{model_name_or_path}"
|
||||
|
||||
#######################################################################
|
||||
# Load and preprocess dataset (tokenization is handled by GRPO Trainer)
|
||||
#######################################################################
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
|
||||
if "Llama" in model_args.model_name_or_path:
|
||||
tokenizer.pad_token = "<|finetune_right_pad_id|>"
|
||||
elif "Qwen" in model_args.model_name_or_path:
|
||||
tokenizer.pad_token = "<|fim_pad|>"
|
||||
tokenizer.chat_template = REASON_CHAT_TEMPLATE
|
||||
|
||||
model_post_train_dataset_name = RL_POST_TRAIN_DATASET_MAP[pt_args.model_post_train_dataset_name]
|
||||
if pt_args.model_post_train_dataset_config is not None:
|
||||
train_dataset = load_dataset(
|
||||
model_post_train_dataset_name, split="train", name=pt_args.model_post_train_dataset_config
|
||||
)
|
||||
else:
|
||||
train_dataset = load_dataset(model_post_train_dataset_name, split="train")
|
||||
# required by GRPOTrainer: (prompt, solution) columns
|
||||
if "solution" not in train_dataset.column_names and "answer" in train_dataset.column_names:
|
||||
train_dataset = train_dataset.rename_column("answer", "solution")
|
||||
|
||||
# Wrap the 'solution' values in $...$
|
||||
def wrap_in_math(example):
|
||||
return {"solution": f"${example['solution']}$"}
|
||||
|
||||
# Apply the transformation to the entire dataset
|
||||
train_dataset = train_dataset.map(wrap_in_math)
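# Illustrative note (not in the original file): the map above turns e.g. {"solution": "42"}
# into {"solution": "$42$"}, presumably so the accuracy reward can parse answers as LaTeX math.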
|
||||
if "problem" not in train_dataset.column_names and "question" in train_dataset.column_names:
|
||||
train_dataset = train_dataset.rename_column("question", "problem")
|
||||
if "problem" not in train_dataset.column_names and "prompt" in train_dataset.column_names:
|
||||
train_dataset = train_dataset.rename_column("prompt", "problem")
|
||||
if "messages" in train_dataset.column_names:
|
||||
train_dataset = train_dataset.remove_columns("messages")
|
||||
|
||||
# handle deepscaler separately
|
||||
if "deepscaler" in pt_args.model_post_train_dataset_name:
|
||||
train_dataset = train_dataset.rename_column("solution", "solution_archive")
|
||||
train_dataset = train_dataset.rename_column("answer", "solution")
|
||||
|
||||
# Wrap the 'solution' values in $...$
|
||||
def wrap_in_math(example):
|
||||
return {"solution": f"${example['solution']}$"}
|
||||
|
||||
# Apply the transformation to the entire dataset
|
||||
train_dataset = train_dataset.map(wrap_in_math)
|
||||
|
||||
SYSTEM_PROMPT = OPEN_RS_SYSTEM_PROMPT if "open-rs" in model_post_train_dataset_name else OPEN_R1_SYSTEM_PROMPT
|
||||
train_dataset = train_dataset.map(make_conv_for_grpo, fn_kwargs={"system_prompt": SYSTEM_PROMPT})
|
||||
|
||||
######################
|
||||
# Initialize the model
|
||||
######################
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
use_cache=False if training_args.gradient_checkpointing else True,
|
||||
)
|
||||
|
||||
if model_args.use_peft:
|
||||
logger.info(
|
||||
f"\n Using PEFT with {model_args.lora_r} rank, {model_args.lora_alpha} alpha, {model_args.lora_dropout} dropout."
|
||||
)
|
||||
peft_config = LoraConfig(
|
||||
r=model_args.lora_r,
|
||||
lora_alpha=model_args.lora_alpha,
|
||||
lora_dropout=model_args.lora_dropout,
|
||||
target_modules=model_args.lora_target_modules,
|
||||
inference_mode=False,
|
||||
bias="none",
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
)
|
||||
model = get_peft_model(model, peft_config)
|
||||
|
||||
#############################
|
||||
# Initialize the GRPO trainer
|
||||
#############################
|
||||
|
||||
RL_POST_TRAIN_REWARD_MAP = {
|
||||
"accuracy": accuracy_reward,
|
||||
"format": format_reward,
|
||||
"tag_count": tag_count_reward,
|
||||
"length": len_reward,
|
||||
"reasoning_steps": reasoning_steps_reward,
|
||||
"cosine": get_cosine_scaled_reward(
|
||||
min_value_wrong=pt_args.cosine_min_value_wrong,
|
||||
max_value_wrong=pt_args.cosine_max_value_wrong,
|
||||
min_value_correct=pt_args.cosine_min_value_correct,
|
||||
max_value_correct=pt_args.cosine_max_value_correct,
|
||||
max_len=pt_args.cosine_max_len,
|
||||
),
|
||||
"repetition_penalty": get_repetition_penalty_reward(
|
||||
ngram_size=pt_args.repetition_n_grams,
|
||||
max_penalty=pt_args.repetition_max_penalty,
|
||||
),
|
||||
}
|
||||
rl_reward_funcs = [RL_POST_TRAIN_REWARD_MAP[func] for func in pt_args.rl_post_train_reward_funcs]
|
||||
training_args.reward_weights = pt_args.rl_post_train_reward_weights
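# Illustrative config sketch (not in the original file; values are hypothetical): with
#   rl_post_train_reward_funcs:   ["accuracy", "format"]
#   rl_post_train_reward_weights: [2.0, 1.0]
# the lookup above yields [accuracy_reward, format_reward], and the trainer weights
# answer accuracy twice as heavily as format compliance when summing rewards.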
|
||||
|
||||
if model_args.use_peft:
|
||||
callbacks = [
|
||||
FixedPromptEvaluationCallback(system_prompt=OPEN_R1_SYSTEM_PROMPT, eval_steps=training_args.save_steps),
|
||||
# PushToHubRevisionCallback(dataset_name=pt_args.model_post_train_dataset_name, use_peft=model_args.use_peft)
|
||||
]
|
||||
else:
|
||||
callbacks = [
|
||||
FixedPromptEvaluationCallback(system_prompt=OPEN_R1_SYSTEM_PROMPT, eval_steps=training_args.save_steps),
|
||||
GradientClippingLoggerCallback(),
|
||||
# PushToHubRevisionCallback(dataset_name=pt_args.model_post_train_dataset_name, use_peft=model_args.use_peft)
|
||||
]
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
model=model,
|
||||
processing_class=tokenizer,
|
||||
reward_funcs=rl_reward_funcs,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
#########################
|
||||
# Training and Evaluation
|
||||
#########################
|
||||
|
||||
logger.info(f"\nStarting training for {training_args.num_train_epochs} epochs.")
|
||||
|
||||
# Check for last checkpoint
|
||||
ckpt = None
|
||||
if training_args.resume_from_checkpoint is not None:
|
||||
ckpt = training_args.resume_from_checkpoint
|
||||
elif os.path.isdir(training_args.output_dir):
|
||||
ckpt = get_last_checkpoint(training_args.output_dir)
|
||||
if ckpt:
|
||||
logger.info(f"\nCheckpoint detected, resuming training at {ckpt=}.")
|
||||
else:
|
||||
logger.info("\nNo checkpoint detected, starting training from scratch.")
|
||||
|
||||
train_result = trainer.train(resume_from_checkpoint=ckpt)
|
||||
train_metrics = train_result.metrics
|
||||
trainer.log_metrics("train", train_metrics)
|
||||
trainer.save_metrics("train", train_metrics)
|
||||
trainer.save_state()
|
||||
trainer.push_to_hub(
|
||||
commit_message=f"Add checkpoint {training_args.max_steps} post-trained on {pt_args.model_post_train_dataset_name}"
|
||||
)
|
||||
|
||||
del trainer
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,260 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from transformers import TrainingArguments
|
||||
|
||||
|
||||
@dataclass
|
||||
class GRPOConfig(TrainingArguments):
|
||||
r"""
|
||||
Configuration class for the [`GRPOTrainer`].
|
||||
|
||||
Only the parameters specific to GRPO training are listed here. For details on other parameters, refer to the
|
||||
[`~transformers.TrainingArguments`] documentation.
|
||||
|
||||
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
||||
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
||||
command line.
|
||||
|
||||
Parameters:
|
||||
> Parameters that control the model and reference model
|
||||
|
||||
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
||||
Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
|
||||
argument of the [`GRPOTrainer`] is provided as a string.
|
||||
|
||||
> Parameters that control the data preprocessing
|
||||
|
||||
remove_unused_columns (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
|
||||
requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`.
|
||||
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
|
||||
Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
|
||||
num_generations (`int` or `None`, *optional*, defaults to `8`):
|
||||
Number of generations per prompt to sample. The global batch size (num_processes * per_device_batch_size)
|
||||
must be divisible by this value.
|
||||
temperature (`float`, *optional*, defaults to `0.9`):
|
||||
Temperature for sampling. The higher the temperature, the more random the completions.
|
||||
max_completion_length (`int` or `None`, *optional*, defaults to `256`):
|
||||
Maximum length of the generated completion.
|
||||
ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
|
||||
This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
|
||||
improving generation speed. However, disabling this option allows training models that exceed the VRAM
|
||||
capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible
|
||||
with vLLM generation.
|
||||
|
||||
> Parameters that control generation acceleration powered by vLLM
|
||||
|
||||
use_vllm (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
|
||||
training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
|
||||
vllm_device (`str`, *optional*, defaults to `"auto"`):
|
||||
Device where vLLM generation will run, e.g. `"cuda:1"`. If set to `"auto"` (default), the system will
|
||||
automatically select the next available GPU after the last one used for training. This assumes that
|
||||
training has not already occupied all available GPUs. If only one device is available, the device will be
|
||||
shared between both training and vLLM.
|
||||
vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
|
||||
Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the
|
||||
device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus
|
||||
improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors
|
||||
during initialization.
|
||||
vllm_dtype (`str`, *optional*, defaults to `"auto"`):
|
||||
Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
|
||||
based on the model configuration. Find the supported values in the vLLM documentation.
|
||||
vllm_max_model_len (`int` or `None`, *optional*, defaults to `None`):
|
||||
If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced
|
||||
`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
|
||||
context size, which might be much larger than the KV cache, leading to inefficiencies.
|
||||
|
||||
> Parameters that control the training
|
||||
|
||||
learning_rate (`float`, *optional*, defaults to `1e-6`):
|
||||
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
|
||||
[`~transformers.TrainingArguments`].
|
||||
beta (`float`, *optional*, defaults to `0.04`):
|
||||
KL coefficient.
|
||||
reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
|
||||
Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
|
||||
weighted equally with weight `1.0`.
|
||||
sync_ref_model (`bool`, *optional*, defaults to `False`):
|
||||
Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
|
||||
the `ref_model_mixup_alpha` parameter. This synchronization originates from the
|
||||
[TR-DPO](https://huggingface.co/papers/2404.09656) paper.
|
||||
ref_model_mixup_alpha (`float`, *optional*, defaults to `0.9`):
|
||||
α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix
|
||||
between the current policy and the previous reference policy during updates. The reference policy is
|
||||
updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you
|
||||
must set `sync_ref_model=True`.
|
||||
ref_model_sync_steps (`int`, *optional*, defaults to `64`):
|
||||
τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
|
||||
frequently the current policy is synchronized with the reference policy. To use this parameter, you must
|
||||
set `sync_ref_model=True`.
|
||||
|
||||
> Parameters that control the logging
|
||||
|
||||
log_completions (`bool`, *optional*, defaults to `False`):
|
||||
Whether to log the completions during training.
|
||||
"""
|
||||
|
||||
# Parameters that control the model and reference model
|
||||
model_init_kwargs: Optional[dict] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` "
|
||||
"argument of the `GRPOTrainer` is provided as a string."
|
||||
},
|
||||
)
|
||||
|
||||
# Parameters that control the data preprocessing
|
||||
# The default value remove_unused_columns is overwritten from the parent class, because in GRPO we usually rely on
|
||||
# additional columns to compute the reward
|
||||
remove_unused_columns: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to only keep the column 'prompt' in the dataset. If you use a custom reward function "
|
||||
"that requires any column other than 'prompts' and 'completions', you should keep this to `False`."
|
||||
},
|
||||
)
|
||||
max_prompt_length: Optional[int] = field(
|
||||
default=512,
|
||||
metadata={
|
||||
"help": "Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left."
|
||||
},
|
||||
)
|
||||
num_generations: Optional[int] = field(
|
||||
default=8,
|
||||
metadata={
|
||||
"help": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) "
|
||||
"must be divisible by this value."
|
||||
},
|
||||
)
|
||||
temperature: Optional[float] = field(
|
||||
default=0.9,
|
||||
metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."},
|
||||
)
|
||||
max_completion_length: Optional[int] = field(
|
||||
default=256,
|
||||
metadata={"help": "Maximum length of the generated completion."},
|
||||
)
|
||||
ds3_gather_for_generation: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for "
|
||||
"generation, improving generation speed. However, disabling this option allows training models that "
|
||||
"exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation. Disabling this option "
|
||||
"is not compatible with vLLM generation."
|
||||
},
|
||||
)
|
||||
|
||||
# Parameters that control generation acceleration powered by vLLM
|
||||
use_vllm: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept "
|
||||
"unused for training, as vLLM will require one for generation. vLLM must be installed "
|
||||
"(`pip install vllm`)."
|
||||
},
|
||||
)
|
||||
vllm_device: Optional[str] = field(
|
||||
default="auto",
|
||||
metadata={
|
||||
"help": "Device where vLLM generation will run, e.g. 'cuda:1'. If set to 'auto' (default), the system "
|
||||
"will automatically select the next available GPU after the last one used for training. This assumes "
|
||||
"that training has not already occupied all available GPUs."
|
||||
},
|
||||
)
|
||||
vllm_gpu_memory_utilization: float = field(
|
||||
default=0.9,
|
||||
metadata={
|
||||
"help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
|
||||
"cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache "
|
||||
"size and thus improve the model's throughput. However, if the value is too high, it may cause "
|
||||
"out-of-memory (OOM) errors during initialization."
|
||||
},
|
||||
)
|
||||
vllm_dtype: Optional[str] = field(
|
||||
default="auto",
|
||||
metadata={
|
||||
"help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
|
||||
"determined based on the model configuration. Find the supported values in the vLLM documentation."
|
||||
},
|
||||
)
|
||||
vllm_max_model_len: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced "
|
||||
"`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model "
|
||||
"context size, which might be much larger than the KV cache, leading to inefficiencies."
|
||||
},
|
||||
)
|
||||
|
||||
# Parameters that control the training
|
||||
learning_rate: float = field(
|
||||
default=1e-6,
|
||||
metadata={
|
||||
"help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of "
|
||||
"`transformers.TrainingArguments`."
|
||||
},
|
||||
)
|
||||
beta: float = field(
|
||||
default=0.04,
|
||||
metadata={"help": "KL coefficient."},
|
||||
)
|
||||
scale_rewards: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "Whether to scale the rewards by dividing them by their standard deviation. If `True` (default), "
|
||||
"the rewards are normalized by the standard deviation, ensuring they have unit variance. If `False`, no "
|
||||
"scaling is applied. The Dr. GRPO paper recommends not scaling the rewards, as scaling by the standard "
|
||||
"deviation introduces a question-level difficulty bias."
|
||||
},
|
||||
)
|
||||
reward_weights: Optional[list[float]] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Weights for each reward function. Must match the number of reward functions. If `None`, all "
|
||||
"rewards are weighted equally with weight `1.0`."
|
||||
},
|
||||
)
|
||||
sync_ref_model: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to synchronize the reference model with the active model every `ref_model_sync_steps` "
|
||||
"steps, using the `ref_model_mixup_alpha` parameter."
|
||||
},
|
||||
)
|
||||
ref_model_mixup_alpha: float = field(
|
||||
default=0.9,
|
||||
metadata={
|
||||
"help": "α parameter from the TR-DPO paper, which controls the mix between the current policy and the "
|
||||
"previous reference policy during updates. The reference policy is updated according to the equation: "
|
||||
"`π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you must set `sync_ref_model=True`."
|
||||
},
|
||||
)
|
||||
ref_model_sync_steps: int = field(
|
||||
default=64,
|
||||
metadata={
|
||||
"help": "τ parameter from the TR-DPO paper, which determines how frequently the current policy is "
|
||||
"synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`."
|
||||
},
|
||||
)
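# Illustrative sketch (not in the original file): with sync_ref_model=True, every
# ref_model_sync_steps steps each reference parameter is roughly updated as
#   p_ref.data = ref_model_mixup_alpha * p.data + (1 - ref_model_mixup_alpha) * p_ref.data
# i.e. pi_ref = alpha * pi_theta + (1 - alpha) * pi_ref_prev from the TR-DPO paper.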
|
||||
|
||||
# Parameters that control the logging
|
||||
log_completions: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to log the completions during training."},
|
||||
)
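# Usage sketch (not in the original file; values are hypothetical): a GRPOConfig wired for
# vLLM generation on a dedicated GPU, mirroring the fields documented above.
#
#   config = GRPOConfig(
#       output_dir="./grpo-run",
#       per_device_train_batch_size=4,
#       num_generations=8,          # must divide num_processes * per_device_train_batch_size
#       max_prompt_length=512,
#       max_completion_length=256,
#       use_vllm=True,
#       vllm_device="cuda:1",       # keep one GPU free for rollouts
#       vllm_gpu_memory_utilization=0.7,
#       beta=0.04,
#   )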
|
||||
@@ -1,823 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import textwrap
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from typing import Any, Callable, Optional, Sized, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch.utils.data
|
||||
import transformers
|
||||
from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
|
||||
from accelerate.utils.other import is_compiled_module
|
||||
from datasets import Dataset, IterableDataset
|
||||
from packaging import version
|
||||
from tina.post_train_hf.grpo_config import GRPOConfig
|
||||
from torch import nn
|
||||
from torch.utils.data import Sampler
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoTokenizer,
|
||||
GenerationConfig,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase,
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
is_wandb_available,
|
||||
)
|
||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from transformers.utils import is_peft_available
|
||||
from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
|
||||
from trl.import_utils import is_vllm_available
|
||||
from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
|
||||
from trl.trainer.callbacks import SyncRefModelCallback
|
||||
from trl.trainer.utils import generate_model_card, get_comet_experiment_url, pad, selective_log_softmax
|
||||
|
||||
if is_peft_available():
|
||||
from peft import PeftConfig, get_peft_model
|
||||
|
||||
if is_vllm_available():
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
|
||||
# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
|
||||
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
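# Minimal sketch of the callable form of `RewardFunc` described above (illustrative only,
# not part of the original file; the name and reward logic are made up). The function
# receives the prompts and completions plus any remaining dataset columns as keyword
# arguments and must return one float (or None) per completion.
def example_unique_chars_reward(prompts, completions, **kwargs):
    # Reward completions that use more distinct characters.
    return [float(len(set(str(completion)))) for completion in completions]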
|
||||
|
||||
|
||||
class RepeatRandomSampler(Sampler):
|
||||
"""
|
||||
Sampler that repeats the indices of a dataset N times.
|
||||
|
||||
Args:
|
||||
data_source (`Sized`):
|
||||
Dataset to sample from.
|
||||
repeat_count (`int`):
|
||||
Number of times to repeat each index.
|
||||
seed (`Optional[int]`):
|
||||
Random seed for reproducibility (only affects this sampler).
|
||||
|
||||
Example:
|
||||
```python
|
||||
>>> sampler = RepeatRandomSampler(["a", "b", "c", "d"], repeat_count=2)
|
||||
>>> list(sampler)
|
||||
[2, 2, 0, 0, 3, 3, 1, 1]
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, data_source: Sized, repeat_count: int, seed: Optional[int] = None):
|
||||
self.data_source = data_source
|
||||
self.repeat_count = repeat_count
|
||||
self.num_samples = len(data_source)
|
||||
self.seed = seed
|
||||
self.generator = torch.Generator() # Create a local random generator
|
||||
if seed is not None:
|
||||
self.generator.manual_seed(seed)
|
||||
|
||||
def __iter__(self):
|
||||
indexes = [
|
||||
idx
|
||||
for idx in torch.randperm(self.num_samples, generator=self.generator).tolist()
|
||||
for _ in range(self.repeat_count)
|
||||
]
|
||||
return iter(indexes)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples * self.repeat_count
|
||||
|
||||
|
||||
class GRPOTrainer(Trainer):
|
||||
"""
|
||||
Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
|
||||
paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from trl import GRPOTrainer
|
||||
|
||||
dataset = load_dataset("trl-lib/tldr", split="train")
|
||||
|
||||
def reward_func(completions, **kwargs):
|
||||
# Dummy reward function that rewards completions with more unique letters.
|
||||
return [float(len(set(completion))) for completion in completions]
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
model="Qwen/Qwen2-0.5B-Instruct",
|
||||
reward_funcs=reward_func,
|
||||
train_dataset=dataset,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
Args:
|
||||
model (`Union[str, PreTrainedModel]`):
|
||||
Model to be trained. Can be either:
|
||||
|
||||
- A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
|
||||
a path to a *directory* containing model weights saved using
|
||||
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
|
||||
loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments
|
||||
in `args.model_init_kwargs`.
|
||||
- A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
|
||||
reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
|
||||
Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
|
||||
functions with the prompts and completions and sum the rewards. Can be either:
|
||||
|
||||
- A single reward function, such as:
|
||||
- A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
|
||||
path to a *directory* containing model weights saved using
|
||||
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
|
||||
using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
|
||||
keyword arguments in `args.model_init_kwargs`.
|
||||
- A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
|
||||
- A custom reward function: The function is provided with the prompts and the generated completions,
|
||||
plus any additional columns in the dataset. It should return a list of rewards. For more details, see
|
||||
[Using a custom reward function](#using-a-custom-reward-function).
|
||||
- A list of reward functions, where each item can independently be any of the above types. Mixing different
|
||||
types within the list (e.g., a string model ID and a custom reward function) is allowed.
|
||||
args ([`GRPOConfig`], *optional*, defaults to `None`):
|
||||
Configuration for this trainer. If `None`, a default configuration is used.
|
||||
train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
|
||||
Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset are
|
||||
ignored. The format of the samples can be either:
|
||||
|
||||
- [Standard](dataset_formats#standard): Each sample contains plain text.
|
||||
- [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
|
||||
and content).
|
||||
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
|
||||
Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
|
||||
processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
|
||||
Processing class used to process the data. The padding side must be set to "left". If `None`, the
|
||||
processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`].
|
||||
reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`):
|
||||
Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
|
||||
|
||||
- A single processing class: Used when `reward_funcs` contains only one reward function.
|
||||
- A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
|
||||
If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
|
||||
`None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`].
|
||||
For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]),
|
||||
the corresponding entries in `reward_processing_classes` are ignored.
|
||||
callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
|
||||
List of callbacks to customize the training loop. Will add those to the list of default callbacks
|
||||
detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback).
|
||||
|
||||
If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
|
||||
method.
|
||||
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
|
||||
A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
|
||||
model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
|
||||
peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
|
||||
PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
|
||||
"""
|
||||
|
||||
_tag_names = ["trl", "grpo"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[str, PreTrainedModel],
|
||||
reward_funcs: Union[RewardFunc, list[RewardFunc]],
|
||||
args: GRPOConfig = None,
|
||||
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
|
||||
eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
|
||||
processing_class: Optional[PreTrainedTokenizerBase] = None,
|
||||
reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
|
||||
callbacks: Optional[list[TrainerCallback]] = None,
|
||||
optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
|
||||
peft_config: Optional["PeftConfig"] = None,
|
||||
):
|
||||
# Args
|
||||
if args is None:
|
||||
model_name = model if isinstance(model, str) else model.config._name_or_path
|
||||
model_name = model_name.split("/")[-1]
|
||||
args = GRPOConfig(f"{model_name}-GRPO")
|
||||
|
||||
# Models
|
||||
# Trained model
|
||||
model_init_kwargs = args.model_init_kwargs or {}
|
||||
if isinstance(model, str):
|
||||
model_id = model
|
||||
torch_dtype = model_init_kwargs.get("torch_dtype")
|
||||
if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
|
||||
pass # torch_dtype is already a torch.dtype or "auto" or None
|
||||
elif isinstance(torch_dtype, str): # it's a str, but not "auto"
|
||||
torch_dtype = getattr(torch, torch_dtype)
|
||||
model_init_kwargs["torch_dtype"] = torch_dtype
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
|
||||
f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
|
||||
)
|
||||
# Disable caching if gradient checkpointing is enabled (not supported)
|
||||
model_init_kwargs["use_cache"] = (
|
||||
False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
|
||||
else:
|
||||
model_id = model.config._name_or_path
|
||||
if args.model_init_kwargs is not None:
|
||||
raise ValueError(
|
||||
"You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
|
||||
"This argument can only be used when the `model` argument is a string."
|
||||
)
|
||||
|
||||
if peft_config is not None:
|
||||
model = get_peft_model(model, peft_config)
|
||||
|
||||
# Reference model
|
||||
if is_deepspeed_zero3_enabled():
|
||||
self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
|
||||
elif not is_peft_model(model):
|
||||
# If PEFT configuration is not provided, create a reference model based on the initial model.
|
||||
self.ref_model = create_reference_model(model)
|
||||
else:
|
||||
# If PEFT is used, the reference model is not needed since the adapter can be disabled
|
||||
# to revert to the initial model.
|
||||
self.ref_model = None
|
||||
|
||||
# Processing class
|
||||
if processing_class is None:
|
||||
processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
|
||||
|
||||
# Reward functions
|
||||
if not isinstance(reward_funcs, list):
|
||||
reward_funcs = [reward_funcs]
|
||||
for i, reward_func in enumerate(reward_funcs):
|
||||
if isinstance(reward_func, str):
|
||||
reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
|
||||
reward_func, num_labels=1, **model_init_kwargs
|
||||
)
|
||||
self.reward_funcs = reward_funcs
|
||||
|
||||
# Reward weights
|
||||
if args.reward_weights is not None:
|
||||
if len(args.reward_weights) != len(reward_funcs):
|
||||
raise ValueError(
|
||||
f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
|
||||
f"functions ({len(reward_funcs)})"
|
||||
)
|
||||
self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
|
||||
else:
|
||||
self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
|
||||
|
||||
# Reward processing class
|
||||
if reward_processing_classes is None:
|
||||
reward_processing_classes = [None] * len(reward_funcs)
|
||||
elif not isinstance(reward_processing_classes, list):
|
||||
reward_processing_classes = [reward_processing_classes]
|
||||
else:
|
||||
if len(reward_processing_classes) != len(reward_funcs):
|
||||
raise ValueError("The number of reward processing classes must match the number of reward functions.")
|
||||
|
||||
for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
|
||||
if isinstance(reward_func, PreTrainedModel):
|
||||
if reward_processing_class is None:
|
||||
reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
|
||||
if reward_processing_class.pad_token_id is None:
|
||||
reward_processing_class.pad_token = reward_processing_class.eos_token
|
||||
# The reward model computes the reward for the latest non-padded token in the input sequence.
|
||||
# So it's important to set the pad token ID to the padding token ID of the processing class.
|
||||
reward_func.config.pad_token_id = reward_processing_class.pad_token_id
|
||||
reward_processing_classes[i] = reward_processing_class
|
||||
self.reward_processing_classes = reward_processing_classes
|
||||
|
||||
# Data collator
|
||||
def data_collator(features): # No data collation is needed in GRPO
|
||||
return features
|
||||
|
||||
# Training arguments
|
||||
self.max_prompt_length = args.max_prompt_length
|
||||
self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
|
||||
self.num_generations = args.num_generations # = G in the GRPO paper
|
||||
self.use_vllm = args.use_vllm
|
||||
|
||||
self.beta = args.beta
|
||||
|
||||
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
|
||||
# input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
|
||||
# "input_ids" key. Instead, the available key is "prompt". As a result, the trainer issues the warning:
|
||||
# "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
|
||||
# suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
|
||||
# This acts as a flag to indicate that the warning has already been issued.
|
||||
model.warnings_issued["estimate_tokens"] = True
|
||||
|
||||
# Initialize the metrics
|
||||
self._metrics = defaultdict(list)
|
||||
self.log_completions = args.log_completions
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
args=args,
|
||||
data_collator=data_collator,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
processing_class=processing_class,
|
||||
callbacks=callbacks,
|
||||
optimizers=optimizers,
|
||||
)
|
||||
|
||||
# Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
|
||||
num_processes = self.accelerator.num_processes
|
||||
global_batch_size = args.per_device_train_batch_size * num_processes
|
||||
possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
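# Worked example (illustrative, not in the original file): with 2 training processes and
# per_device_train_batch_size=4, global_batch_size is 8, so possible_values is [2, 4, 8]
# and num_generations must be one of those.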
|
||||
if self.num_generations not in possible_values:
|
||||
raise ValueError(
|
||||
f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
|
||||
f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
|
||||
f"batch size, the valid values for the number of generations are: {possible_values}."
|
||||
)
|
||||
if self.args.eval_strategy != "no":
|
||||
global_batch_size = args.per_device_eval_batch_size * num_processes
|
||||
possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
|
||||
if self.num_generations not in possible_values:
|
||||
raise ValueError(
|
||||
f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
|
||||
f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
|
||||
f"eval batch size, the valid values for the number of generations are: {possible_values}."
|
||||
)
|
||||
|
||||
# Ensure each process receives a unique seed to prevent duplicate completions when generating with
|
||||
# transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
|
||||
# it's safer to set it in all cases.
|
||||
set_seed(args.seed, device_specific=True)
|
||||
|
||||
if self.use_vllm:
|
||||
if not is_vllm_available():
|
||||
raise ImportError(
|
||||
"vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
|
||||
"`pip install vllm` to use it."
|
||||
)
|
||||
|
||||
if self.accelerator.is_main_process:
|
||||
vllm_device = self.args.vllm_device
|
||||
if vllm_device == "auto":
|
||||
if torch.cuda.device_count() == 1:
|
||||
vllm_device = "cuda:0"  # particular case when training with only 1 GPU: share it
|
||||
else:
|
||||
vllm_device = f"cuda:{self.accelerator.num_processes}" # take the next GPU idx
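# Illustrative example (not in the original file): on a 4-GPU machine launched with
# --num_processes 3, accelerator.num_processes is 3, so "auto" resolves to "cuda:3",
# the GPU left free for vLLM rollouts.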
|
||||
# Check that the requested device is available
|
||||
if vllm_device.split(":")[0] == "cuda" and int(vllm_device.split(":")[1]) >= torch.cuda.device_count():
|
||||
raise ValueError(
|
||||
f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM "
|
||||
"without restricting the number of GPUs for training. Set the `--num_processes` argument to a "
|
||||
"value lower than the number of GPUs available on your machine—typically, reducing it by one "
|
||||
f"is sufficient. In your case: `--num_processes {torch.cuda.device_count() - 1}`."
|
||||
)
|
||||
# Check that the requested device is not also used for training
|
||||
if vllm_device in {f"cuda:{idx}" for idx in range(self.accelerator.num_processes)}:
|
||||
warnings.warn(
|
||||
f"The requested device {vllm_device} is also being used for training. For higher throughput "
|
||||
"and to avoid out-of-memory errors, it is recommended to use a dedicated device for vLLM. "
|
||||
"If this is intentional, you may ignore this warning but should adjust "
|
||||
"`vllm_gpu_memory_utilization` accordingly."
|
||||
)
|
||||
# vLLM is not compatible with accelerate. So we need to patch it to make sure we can (1) place the vLLM
|
||||
# model on the desired device (world_size_patch) and (2) avoid a test that is not designed for our
|
||||
# setting (profiling_patch).
|
||||
world_size_patch = patch("torch.distributed.get_world_size", return_value=1)
|
||||
profiling_patch = patch(
|
||||
"vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling", return_value=None
|
||||
)
|
||||
with world_size_patch, profiling_patch:
|
||||
self.llm = LLM(
|
||||
model=model.name_or_path,
|
||||
device=vllm_device,
|
||||
gpu_memory_utilization=self.args.vllm_gpu_memory_utilization,
|
||||
dtype=self.args.vllm_dtype,
|
||||
# Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
|
||||
# directly reuse the KV cache if it shares the same prefix with one of the existing queries.
|
||||
# This is particularly useful here because we generate completions from the same prompts.
|
||||
enable_prefix_caching=True,
|
||||
max_model_len=self.args.vllm_max_model_len,
|
||||
)
|
||||
self.sampling_params = SamplingParams(
|
||||
temperature=args.temperature,
|
||||
max_tokens=self.max_completion_length,
|
||||
)
|
||||
|
||||
self._last_loaded_step = 0 # tag to avoid useless loading during grad accumulation
|
||||
|
||||
# When using vLLM, the main process is responsible for loading the model weights. This can cause process
|
||||
# desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
|
||||
# synchronize all processes after vLLM has been fully initialized.
|
||||
self.accelerator.wait_for_everyone()
|
||||
else:
|
||||
self.generation_config = GenerationConfig(
|
||||
max_new_tokens=self.max_completion_length,
|
||||
do_sample=True,
|
||||
temperature=args.temperature,
|
||||
pad_token_id=processing_class.pad_token_id,
|
||||
)
|
||||
|
||||
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
|
||||
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
|
||||
# self.model_accepts_loss_kwargs to False to enable scaling.
|
||||
self.model_accepts_loss_kwargs = False
|
||||
|
||||
# Add tags to the model
|
||||
self.model.add_model_tags(self._tag_names)
|
||||
|
||||
if self.ref_model is not None:
|
||||
if self.is_deepspeed_enabled:
|
||||
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
|
||||
else:
|
||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||
|
||||
if args.sync_ref_model:
|
||||
self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))
|
||||
|
||||
for i, reward_func in enumerate(self.reward_funcs):
|
||||
if isinstance(reward_func, PreTrainedModel):
|
||||
self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
|
||||
|
||||
def _set_signature_columns_if_needed(self):
|
||||
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
|
||||
# By default, this method sets `self._signature_columns` to the model's expected inputs.
|
||||
# In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
|
||||
# Instead, we set them to the columns expected by the `training_step` method, hence the override.
|
||||
if self._signature_columns is None:
|
||||
self._signature_columns = ["prompt"]
|
||||
|
||||
def _get_train_sampler(self) -> Sampler:
|
||||
# Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
|
||||
# identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
|
||||
# within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
|
||||
# preventing discrepancies in group formation.
|
||||
return RepeatRandomSampler(self.train_dataset, self.num_generations, seed=self.args.seed)
|
||||
|
||||
def _get_eval_sampler(self, eval_dataset) -> Sampler:
|
||||
# Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
|
||||
# identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
|
||||
# within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
|
||||
# preventing discrepancies in group formation.
|
||||
return RepeatRandomSampler(eval_dataset, self.num_generations, seed=self.args.seed)
|
||||
|
||||
# Get the per-token log probabilities for the completions for the model and the reference model
|
||||
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
|
||||
# We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
|
||||
logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
|
||||
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
|
||||
|
||||
input_ids = input_ids[:, -logits_to_keep:]
|
||||
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
|
||||
# See https://github.com/huggingface/trl/issues/2770
|
||||
logits = logits[:, -logits_to_keep:]
|
||||
return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens
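# Shape note (illustrative, not in the original file): input_ids is (B, P+C) where P is the
# prompt length and C = logits_to_keep is the completion length; the returned per-token
# log-probabilities have shape (B, C), one value per completion token.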
|
||||
|
||||
def _move_model_to_vllm(self):
|
||||
with unwrap_model_for_generation(
|
||||
self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
|
||||
) as unwrapped_model:
|
||||
if is_compiled_module(unwrapped_model):
|
||||
unwrapped_model = unwrapped_model._orig_mod
|
||||
if is_peft_model(unwrapped_model):
|
||||
unwrapped_model.merge_adapter()
|
||||
state_dict = unwrapped_model.state_dict()
|
||||
# Remove base_model and base_layer prefixes
|
||||
state_dict = {
|
||||
k.removeprefix("base_model.model.").replace(".base_layer", ""): v for k, v in state_dict.items()
|
||||
}
|
||||
# Remove values with adapter prefix (example: "_lora")
|
||||
state_dict = {k: v for k, v in state_dict.items() if unwrapped_model.prefix not in k}
|
||||
# When module to save, remove its prefix and discard the original module
|
||||
state_dict = {
|
||||
k.replace("modules_to_save.default.", ""): v
|
||||
for k, v in state_dict.items()
|
||||
if "original_module" not in k
|
||||
}
|
||||
else:
|
||||
state_dict = unwrapped_model.state_dict()
|
||||
if self.accelerator.is_main_process:
|
||||
llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
|
||||
llm_model.load_weights(state_dict.items())
|
||||
# Unmerge the adapter to restore the model to its original state.
|
||||
# This must be done after loading weights to ensure they correspond to the merged state.
|
||||
if is_peft_model(unwrapped_model):
|
||||
unwrapped_model.unmerge_adapter()
|
||||
|
||||
def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
|
||||
device = self.accelerator.device
|
||||
prompts = [x["prompt"] for x in inputs]
|
||||
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
|
||||
prompt_inputs = self.processing_class(
|
||||
prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
|
||||
)
|
||||
prompt_inputs = super()._prepare_inputs(prompt_inputs)
|
||||
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
|
||||
|
||||
if self.max_prompt_length is not None:
|
||||
prompt_ids = prompt_ids[:, -self.max_prompt_length :]
|
||||
prompt_mask = prompt_mask[:, -self.max_prompt_length :]
|
||||
|
||||
# Generate completions using either vLLM or regular generation
|
||||
if self.args.use_vllm:
|
||||
# First, have main process load weights if needed
|
||||
if self.state.global_step != self._last_loaded_step:
|
||||
self._move_model_to_vllm()
|
||||
self._last_loaded_step = self.state.global_step
|
||||
|
||||
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
|
||||
all_prompts_text = gather_object(prompts_text)
|
||||
if self.accelerator.is_main_process:
|
||||
outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False)
|
||||
completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
|
||||
else:
|
||||
completion_ids = [None] * len(all_prompts_text)
|
||||
# Broadcast the completions from the main process to all processes, ensuring each process receives its
|
||||
# corresponding slice.
|
||||
completion_ids = broadcast_object_list(completion_ids, from_process=0)
|
||||
process_slice = slice(
|
||||
self.accelerator.process_index * len(prompts),
|
||||
(self.accelerator.process_index + 1) * len(prompts),
|
||||
)
|
||||
completion_ids = completion_ids[process_slice]
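# Illustrative example (not in the original file): with 2 training processes and 4 prompts
# per process, the main process generates 8 completions; process 0 keeps indices [0:4] and
# process 1 keeps indices [4:8] via process_slice.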
|
||||
|
||||
# Pad the completions, and concatenate them with the prompts
|
||||
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
|
||||
completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
|
||||
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
|
||||
else:
|
||||
# Regular generation path
|
||||
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
|
||||
prompt_completion_ids = unwrapped_model.generate(
|
||||
prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
|
||||
)
|
||||
|
||||
# Compute prompt length and extract completion ids
|
||||
prompt_length = prompt_ids.size(1)
|
||||
prompt_ids = prompt_completion_ids[:, :prompt_length]
|
||||
completion_ids = prompt_completion_ids[:, prompt_length:]
|
||||
|
||||
# Mask everything after the first EOS token
|
||||
is_eos = completion_ids == self.processing_class.eos_token_id
|
||||
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
|
||||
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
|
||||
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
|
||||
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
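# Illustrative example (not in the original file): for a completion row [tok, tok, EOS, tok],
# eos_idx is 2 and completion_mask is [1, 1, 1, 0]; the EOS token itself stays unmasked,
# and everything after it is excluded from the attention mask and the loss.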
|
||||
|
||||
# Concatenate prompt_mask with completion_mask for logit computation
|
||||
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C)
|
||||
|
||||
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
|
||||
|
||||
with torch.inference_mode():
|
||||
if self.ref_model is not None:
|
||||
ref_per_token_logps = self._get_per_token_logps(
|
||||
self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
|
||||
)
|
||||
else:
|
||||
with self.accelerator.unwrap_model(self.model).disable_adapter():
|
||||
ref_per_token_logps = self._get_per_token_logps(
|
||||
self.model, prompt_completion_ids, attention_mask, logits_to_keep
|
||||
)
|
||||
|
||||
        # Decode the generated completions
        completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
        if is_conversational(inputs[0]):
            completions = []
            for prompt, completion in zip(prompts, completions_text):
                bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
                completions.append([{"role": "assistant", "content": bootstrap + completion}])
        else:
            completions = completions_text

        rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
        for i, (reward_func, reward_processing_class) in enumerate(
            zip(self.reward_funcs, self.reward_processing_classes)
        ):
            if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
                if is_conversational(inputs[0]):
                    messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
                    texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
                else:
                    texts = [p + c for p, c in zip(prompts, completions)]
                reward_inputs = reward_processing_class(
                    texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
                )
                reward_inputs = super()._prepare_inputs(reward_inputs)
                with torch.inference_mode():
                    rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
            else:
                # Repeat all input columns (but "prompt" and "completion") to match the number of generations
                keys = [key for key in inputs[0] if key not in ["prompt", "completion"]]
                reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
                output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
                # Convert None values to NaN
                output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func]

                rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

        # If all reward functions return None for a given row, issue a detailed warning
        if torch.isnan(rewards_per_func).all(dim=1).any():
            nan_row_idx = torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0]
            row_reward_kwargs = {key: value[nan_row_idx] for key, value in reward_kwargs.items()}
            row_reward_kwargs["prompt"] = prompts[nan_row_idx]
            row_reward_kwargs["completion"] = completions[nan_row_idx]
            warnings.warn(
                f"All reward functions returned None for the following kwargs: {row_reward_kwargs}. "
                "Please ensure that at least one reward function returns a valid reward."
            )

        # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
        # completions may be distributed across processes
        rewards_per_func = gather(rewards_per_func)

        # Apply weights to each reward function's output and sum
        rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)

        # Compute grouped-wise rewards
        mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
        std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)

        # Normalize the rewards to compute the advantages
        mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        advantages = rewards - mean_grouped_rewards
        if self.args.scale_rewards:
            advantages = advantages / (std_grouped_rewards + 1e-4)

        # Slice to keep only the local part of the data
        process_slice = slice(
            self.accelerator.process_index * len(prompts),
            (self.accelerator.process_index + 1) * len(prompts),
        )
        advantages = advantages[process_slice]

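        # Illustrative walk-through of the group normalization above (toy numbers, assuming num_generations == 2):
        #   rewards                  = [1.0, 0.0, 0.5, 0.5]   -> groups [1.0, 0.0] and [0.5, 0.5]
        #   mean per group, repeated = [0.5, 0.5, 0.5, 0.5]
        #   advantages (unscaled)    = [0.5, -0.5, 0.0, 0.0]
        # With scale_rewards enabled, each advantage is additionally divided by its group std (+ 1e-4).
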
        # Log the metrics
        reward_per_func = rewards_per_func.mean(0)
        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
                reward_func_name = reward_func.config._name_or_path.split("/")[-1]
            else:
                reward_func_name = reward_func.__name__
            self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())

        self._metrics["reward"].append(rewards.mean().item())
        self._metrics["reward_std"].append(std_grouped_rewards.mean().item())

        if (
            self.log_completions
            and self.state.global_step % self.args.logging_steps == 0
            and "wandb" in self.args.report_to
        ):
            import pandas as pd

            # For logging
            table = {
                "step": [str(self.state.global_step)] * len(rewards),
                "prompt": gather_object(prompts_text),
                "completion": gather_object(completions_text),
                "reward": rewards.tolist(),
            }
            df = pd.DataFrame(table)

            if wandb.run is not None and self.accelerator.is_main_process:
                wandb.log({"completions": wandb.Table(dataframe=df)})

        return {
            "prompt_ids": prompt_ids,
            "prompt_mask": prompt_mask,
            "completion_ids": completion_ids,
            "completion_mask": completion_mask,
            "ref_per_token_logps": ref_per_token_logps,
            "advantages": advantages,
        }

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        if return_outputs:
            raise ValueError("The GRPOTrainer does not support returning outputs")
        # Compute the per-token log probabilities for the model

        prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
        completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens

        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)

        # Compute the KL divergence between the model and the reference model
        ref_per_token_logps = inputs["ref_per_token_logps"]
        per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1

        # x - x.detach() allows for preserving gradients from x
        advantages = inputs["advantages"]
        per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
        per_token_loss = -(per_token_loss - self.beta * per_token_kl)

        if self.args.scale_rewards:
            loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
        else:
            loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()

        # Log the metrics
        completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
        self._metrics["completion_length"].append(completion_length)

        mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
        self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())

        return loss

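    # Note on the KL term in compute_loss above: with delta = ref_logp - policy_logp, the expression
    # exp(delta) - delta - 1 is the non-negative "k3" estimator of KL(pi_theta || pi_ref).
    # For example, ref_logp = -2.0 and policy_logp = -1.5 give delta = -0.5 and an estimate of
    # exp(-0.5) + 0.5 - 1 ≈ 0.107, which enters the loss scaled by self.beta.
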
    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
        inputs = self._prepare_inputs(inputs)
        with torch.no_grad():
            with self.compute_loss_context_manager():
                loss = self.compute_loss(model, inputs)
            loss = loss.mean().detach()
        return loss, None, None

    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
        metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()}  # average the metrics

        # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
        # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
        if next(iter(logs.keys())).startswith("eval_"):
            metrics = {f"eval_{key}": val for key, val in metrics.items()}

        logs = {**logs, **metrics}
        if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
            super().log(logs, start_time)
        else:  # transformers<=4.46
            super().log(logs)
        self._metrics.clear()

    def create_model_card(
        self,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        tags: Union[str, list[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        Args:
            model_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the model.
            dataset_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the dataset used for training.
            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
                Tags to be associated with the model card.
        """
        if not self.is_world_process_zero():
            return

        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
            base_model = self.model.config._name_or_path
        else:
            base_model = None

        tags = tags or []
        if isinstance(tags, str):
            tags = [tags]

        if hasattr(self.model.config, "unsloth_version"):
            tags.append("unsloth")

        citation = textwrap.dedent(
            """\
            @article{zhihong2024deepseekmath,
                title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
                author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
                year = 2024,
                eprint = {arXiv:2402.03300},
            }
            """
        )

        model_card = generate_model_card(
            base_model=base_model,
            model_name=model_name,
            hub_model_id=self.hub_model_id,
            dataset_name=dataset_name,
            tags=tags,
            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
            comet_url=get_comet_experiment_url(),
            trainer_name="GRPO",
            trainer_citation=citation,
            paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
            paper_id="2402.03300",
        )

        model_card.save(os.path.join(self.args.output_dir, "README.md"))
@@ -1,44 +0,0 @@
import logging
from concurrent.futures import Future

from huggingface_hub import create_branch, create_repo, list_repo_commits, upload_folder

logger = logging.getLogger(__name__)


def push_to_hub_revision(training_args, extra_ignore_patterns=[]) -> Future:
    """Pushes the model to a branch on a Hub repo."""

    # Create a repo if it doesn't exist yet
    repo_url = create_repo(repo_id=training_args.hub_model_id, private=True, exist_ok=True)
    # Get initial commit to branch from
    initial_commit = list_repo_commits(training_args.hub_model_id)[-1]
    # Now create the branch we'll be pushing to
    create_branch(
        repo_id=training_args.hub_model_id,
        branch=training_args.hub_model_revision,
        # checkpoint=training_args.checkpoint,
        revision=initial_commit.commit_id,
        exist_ok=True,
    )
    logger.info(f"Created target repo at {repo_url}")
    logger.info(
        f"Pushing to the Hub revision {training_args.hub_model_revision} with checkpoint {training_args.checkpoint}"
    )
    ignore_patterns = ["checkpoint-*", "*.pth"]
    ignore_patterns.extend(extra_ignore_patterns)
    future = upload_folder(
        repo_id=training_args.hub_model_id,
        folder_path=training_args.output_dir,
        revision=training_args.hub_model_revision,
        # commit_message=f"Add {training_args.hub_model_revision} checkpoint {training_args.dataset_name}",
        commit_message=f"Add {training_args.checkpoint} checkpoint post-trained on {training_args.dataset_name}",
        ignore_patterns=ignore_patterns,
        run_as_future=True,
    )

    logger.info(
        f"Pushed to {repo_url} revision {training_args.hub_model_revision} with checkpoint {training_args.checkpoint} successfully!"
    )

    return future
@@ -1,7 +0,0 @@
def make_conv_for_grpo(example, system_prompt):
    return {
        "prompt": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": example["problem"]},
        ]
    }
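# Illustrative call (the example row below is made up, not taken from any dataset): a record such as
# {"problem": "What is 2 + 2?"} is wrapped into the conversational prompt format expected by the trainer:
#   make_conv_for_grpo({"problem": "What is 2 + 2?"}, system_prompt="...")
#   # -> {"prompt": [{"role": "system", "content": "..."},
#   #                {"role": "user", "content": "What is 2 + 2?"}]}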
@@ -1,302 +0,0 @@
import math
import re
from typing import Callable, Dict, Optional

from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify


def accuracy_reward(completions: list[list[dict[str, str]]], solution: list[str], **kwargs) -> list[Optional[float]]:
    """Reward function that checks if the completion is the same as the ground truth."""
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    for content, sol in zip(contents, solution):
        gold_parsed = parse(
            sol,
            extraction_mode="first_match",
        )
        if len(gold_parsed) != 0:
            # We require the answer to be provided in correct latex (no malformed operators)
            answer_parsed = parse(
                content,
                extraction_config=[
                    LatexExtractionConfig(
                        normalization_config=NormalizationConfig(
                            nits=False,
                            malformed_operators=False,
                            basic_latex=True,
                            equations=True,
                            boxed="all",
                            units=True,
                        ),
                        # Ensures that boxed is tried first
                        boxed_match_priority=0,
                        try_extract_without_anchor=False,
                    )
                ],
                extraction_mode="first_match",
            )
            # Compute binary rewards if verifiable, `None` otherwise to skip this example
            try:
                reward = float(verify(gold_parsed, answer_parsed))
            except Exception as e:
                print(f"verify failed: {e}, answer: {answer_parsed}, gold: {gold_parsed}")
                reward = None
        else:
            # If the gold solution is not parseable, we assign `None` to skip this example
            reward = None
            print("Failed to parse gold solution: ", sol)
        rewards.append(reward)

    return rewards

def format_reward(completions, **kwargs):
    """Reward function that checks if the reasoning process is enclosed within <think> and </think> tags, while the final answer is enclosed within <answer> and </answer> tags."""

    def count_tags(text: str) -> float:
        count = 0.0
        # We only count the </think> tag, because the <think> tag is already provided by the system prompt
        if text.count("\n</think>\n") == 1:
            count += 1.0
        return count

    contents = [completion[0]["content"] for completion in completions]
    return [count_tags(c) for c in contents]


def tag_count_reward(completions, **kwargs) -> list[float]:
    """Reward function that checks if we produce the desired number of think and answer tags associated with `format_reward()`.

    Adapted from: https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb#file-grpo_demo-py-L90
    """

    def count_tags(text: str) -> float:
        count = 0.0
        if re.search(r"\s*<think>\s*", text):
            count += 0.25
        if re.search(r"\s*</think>\s*", text):
            count += 0.25
        if re.search(r"\s*<answer>\s*", text):
            count += 0.25
        if re.search(r"\s*</answer>\s*", text):
            count += 0.25
        return count

    contents = [completion[0]["content"] for completion in completions]
    return [count_tags(c) for c in contents]

def reasoning_steps_reward(completions, **kwargs):
    r"""Reward function that checks for clear step-by-step reasoning.

    Regex pattern:
        Step \d+: - matches "Step 1:", "Step 2:", etc.
        ^\d+\. - matches numbered lists like "1.", "2.", etc. at start of line
        \n- - matches bullet points with hyphens
        \n\* - matches bullet points with asterisks
        First,|Second,|Next,|Finally, - matches transition words
    """
    pattern = r"(Step \d+:|^\d+\.|\n-|\n\*|First,|Second,|Next,|Finally,)"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [len(re.findall(pattern, content)) for content in completion_contents]

    # Magic number 3 to encourage 3 steps and more, otherwise partial reward
    return [min(1.0, count / 3) for count in matches]

def len_reward(completions: list[Dict[str, str]], solution: list[str], **kwargs) -> list[float]:
    """Compute length-based rewards to discourage overthinking and promote token efficiency.

    Taken from the Kimi 1.5 tech report: https://arxiv.org/abs/2501.12599

    Args:
        completions: List of model completions
        solution: List of ground truth solutions

    Returns:
        List of rewards where:
        - For correct answers: reward = 0.5 - (len - min_len)/(max_len - min_len)
        - For incorrect answers: reward = min(0, 0.5 - (len - min_len)/(max_len - min_len))
    """
    contents = [completion[0]["content"] for completion in completions]

    # First check correctness of answers
    correctness = []
    for content, sol in zip(contents, solution):
        gold_parsed = parse(
            sol,
            extraction_mode="first_match",
            extraction_config=[LatexExtractionConfig()],
        )
        if len(gold_parsed) == 0:
            # Skip unparseable examples
            correctness.append(True)  # Treat as correct to avoid penalizing
            print("Failed to parse gold solution: ", sol)
            continue

        answer_parsed = parse(
            content,
            extraction_config=[
                LatexExtractionConfig(
                    normalization_config=NormalizationConfig(
                        nits=False,
                        malformed_operators=False,
                        basic_latex=True,
                        equations=True,
                        boxed=True,
                        units=True,
                    ),
                    boxed_match_priority=0,
                    try_extract_without_anchor=False,
                )
            ],
            extraction_mode="first_match",
        )
        correctness.append(verify(answer_parsed, gold_parsed))

    # Calculate lengths
    lengths = [len(content) for content in contents]
    min_len = min(lengths)
    max_len = max(lengths)

    # If all responses have the same length, return zero rewards
    if max_len == min_len:
        return [0.0] * len(completions)

    rewards = []
    for length, is_correct in zip(lengths, correctness):
        lambda_val = 0.5 - (length - min_len) / (max_len - min_len)

        if is_correct:
            reward = lambda_val
        else:
            reward = min(0, lambda_val)

        rewards.append(float(reward))

    return rewards

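# Worked example of the formula above: completion lengths [200, 500, 800] give scaled terms
# [0.0, 0.5, 1.0], so correct answers receive [0.5, 0.0, -0.5] while incorrect answers receive
# [0.0, 0.0, -0.5] (clipped at zero), favouring short correct solutions.
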
def get_cosine_scaled_reward(
    min_value_wrong: float = -1.0,
    max_value_wrong: float = -0.5,
    min_value_correct: float = 0.5,
    max_value_correct: float = 1.0,
    max_len: int = 1000,
):
    def cosine_scaled_reward(completions, solution, **kwargs):
        """Reward function that scales based on completion length using a cosine schedule.

        Shorter correct solutions are rewarded more than longer ones.
        Longer incorrect solutions are penalized less than shorter ones.

        Args:
            completions: List of model completions
            solution: List of ground truth solutions

        This function is parameterized by the following arguments:
            min_value_wrong: Minimum reward for wrong answers
            max_value_wrong: Maximum reward for wrong answers
            min_value_correct: Minimum reward for correct answers
            max_value_correct: Maximum reward for correct answers
            max_len: Maximum length for scaling
        """
        contents = [completion[0]["content"] for completion in completions]
        rewards = []

        for content, sol in zip(contents, solution):
            gold_parsed = parse(sol, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
            if len(gold_parsed) == 0:
                rewards.append(1.0)  # Skip unparseable examples
                print("Failed to parse gold solution: ", sol)
                continue

            answer_parsed = parse(
                content,
                extraction_config=[
                    LatexExtractionConfig(
                        normalization_config=NormalizationConfig(
                            nits=False,
                            malformed_operators=False,
                            basic_latex=True,
                            equations=True,
                            boxed=True,
                            units=True,
                        ),
                        boxed_match_priority=0,
                        try_extract_without_anchor=False,
                    )
                ],
                extraction_mode="first_match",
            )

            is_correct = verify(answer_parsed, gold_parsed)
            gen_len = len(content)

            # Apply cosine scaling based on length
            progress = gen_len / max_len
            cosine = math.cos(progress * math.pi)

            if is_correct:
                min_value = min_value_correct
                max_value = max_value_correct
            else:
                # Swap min/max for incorrect answers
                min_value = max_value_wrong
                max_value = min_value_wrong

            reward = min_value + 0.5 * (max_value - min_value) * (1.0 + cosine)
            rewards.append(float(reward))

        return rewards

    return cosine_scaled_reward

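# With the default parameters above, a correct answer of length 0 scores 0.5 + 0.25 * (1 + cos(0)) = 1.0
# and decays to 0.5 as its length approaches max_len; an incorrect answer moves the opposite way,
# from -1.0 at length 0 towards -0.5 at max_len.
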
def get_repetition_penalty_reward(ngram_size: int, max_penalty: float):
    """
    Computes N-gram repetition penalty as described in Appendix C.2 of https://arxiv.org/abs/2502.03373.
    Reference implementation from: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py

    Args:
        ngram_size: size of the n-grams
        max_penalty: Maximum (negative) penalty applied to fully repetitive completions
    """
    if max_penalty > 0:
        raise ValueError(f"max_penalty {max_penalty} should not be positive")

    def zipngram(text: str, ngram_size: int):
        words = text.lower().split()
        return zip(*[words[i:] for i in range(ngram_size)])

    def repetition_penalty_reward(completions, **kwargs) -> list[float]:
        """
        Reward function that penalizes repetitions.
        Reference implementation: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py

        Args:
            completions: List of model completions
        """

        contents = [completion[0]["content"] for completion in completions]
        rewards = []
        for completion in contents:
            if completion == "":
                rewards.append(0.0)
                continue
            if len(completion.split()) < ngram_size:
                rewards.append(0.0)
                continue

            ngrams = set()
            total = 0
            for ng in zipngram(completion, ngram_size):
                ngrams.add(ng)
                total += 1

            scaling = 1 - len(ngrams) / total
            reward = scaling * max_penalty
            rewards.append(reward)
        return rewards

    return repetition_penalty_reward
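# Example with ngram_size=2 and max_penalty=-1.0: "the cat the cat" yields the bigrams
# ("the", "cat"), ("cat", "the"), ("the", "cat"), i.e. 2 unique out of 3, so scaling = 1 - 2/3
# and the reward is about -0.33, whereas a completion with all-unique bigrams scores 0.0.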
@@ -1,2 +0,0 @@
DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
REASON_CHAT_TEMPLATE = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>'}}{% endif %}"
@@ -1,23 +0,0 @@
# problem/question, (solution), answer
RL_POST_TRAIN_DATASET_MAP = {
    # Main datasets
    "curated_deepscaler": "agentica-org/DeepScaleR-Preview-Dataset",  # 40.3k
    "curated_still": "RUC-AIBOX/STILL-3-Preview-RL-Data",  # 33k
    "curated_open_rs3": "knoveleng/open-rs",  # 7k
    "curated_open_rs2": "knoveleng/open-rs",  # 7k
    "curated_open_rs1": "knoveleng/open-s1",  # 18.6k
    # Extra datasets
    "curated_limr": "GAIR/LIMR",  # 1.39k
    "curated_open_r1": "open-r1/OpenR1-Math-220k",  # default split 93.7k
    "curated_thoughts": "bethgelab/CuratedThoughts",  # default split 66.1k
    # Ablation
    "curated_limr_large_lr_ablation": "GAIR/LIMR",
    "curated_limr_small_lr_ablation": "GAIR/LIMR",
    "curated_limr_large_rank_ablation": "GAIR/LIMR",
    "curated_limr_medium_rank_ablation": "GAIR/LIMR",
    "curated_limr_small_rank_ablation": "GAIR/LIMR",
    "curated_limr_tiny_rank_ablation": "GAIR/LIMR",
    "curated_open_rs3_drgrpo_ablation": "knoveleng/open-rs",
    # Reasoning gym
    "curated_rg_math": "starzmustdie/rg-math",
}
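# A minimal sketch of how this map can be used to resolve a config key to a Hub dataset id
# (assumes the Hugging Face `datasets` library; the variable names below are illustrative only):
#
#   from datasets import load_dataset
#
#   dataset_id = RL_POST_TRAIN_DATASET_MAP["curated_rg_math"]  # -> "starzmustdie/rg-math"
#   train_dataset = load_dataset(dataset_id, split="train")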
@@ -1,23 +0,0 @@
# borrowed from https://github.com/huggingface/open-r1/blob/main/recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml
OPEN_R1_SYSTEM_PROMPT = """
You are a helpful AI Assistant that provides well-reasoned and detailed responses.
You first think about the reasoning process as an internal monologue and then provide the user with the answer.
Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>
"""

# borrowed from https://github.com/knoveleng/open-rs/blob/main/recipes/grpo.yaml
OPEN_RS_SYSTEM_PROMPT = """
A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user with the answer, and put your final answer within \\boxed{{}} .
The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively,
i.e., <think> reasoning process here </think> <answer> answer here </answer>.
Note that respond by English, NOT use other languages.
"""

# the first question from aime 2024
FIXED_PROMPT_FOR_EVALUATION = r"""
Every morning Aya goes for a $9$-kilometer-long walk and stops at a coffee shop afterwards.
When she walks at a constant speed of $s$ kilometers per hour, the walk takes her 4 hours, including $t$ minutes spent in the coffee shop.
When she walks $s+2$ kilometers per hour, the walk takes her 2 hours and 24 minutes, including $t$ minutes spent in the coffee shop.
Suppose Aya walks at $s+\frac{1}{2}$ kilometers per hour.
Find the number of minutes the walk takes her, including the $t$ minutes spent in the coffee shop."""