tutorial(training): Add a minimal example with trl (#473)

* v0

* 2 gpu setup

* improve parsing from yaml

* update yaml dataset example

* remove restriction on flash attn

* more comments

* first version of the readme

* pin torch

* simplify requirements

* just flash attn

* use set env instead

* simpler set env

* readme

* add wandb project to setup

* update template

* update model id

* post init to capture the config and weight

* extract metadata

* update config

* update dataset config

* move env for wandb project

* pre-commit

* remove qwen-math from training

* more instructions

* unused import

* remove trl old

* warmup ratio

* warmup ratio

* change model id

* change model_id

* add info about CUDA_VISIBLE_DEVICES
This commit is contained in:
Zafir Stojanovski
2025-06-21 00:01:31 +02:00
committed by GitHub
parent 49f3821098
commit 56ce2e79a7
59 changed files with 382 additions and 155340 deletions

View File

@@ -1,32 +1,56 @@
# TRL Examples
# Training with TRL
This directory contains examples using the [TRL (Transformer Reinforcement Learning) library](https://github.com/huggingface/trl) to fine-tune language models with reinforcement learning techniques.
Training stack:
- TRL for reinforcement learning training
- Accelerate (with DeepSpeed) for distributed training
- vLLM for rollouts
## GRPO Example
The main example demonstrates using GRPO (Group Relative Policy Optimization) to fine-tune a language model on reasoning tasks from reasoning-gym. It includes:
- Custom reward functions for answer accuracy and format compliance
- Integration with reasoning-gym datasets
- Configurable training parameters via YAML config
- Wandb logging and model checkpointing
- Evaluation on held-out test sets
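For orientation, the reward functions follow TRL's `GRPOTrainer` interface: each one receives the batch of sampled completions (plus any extra dataset columns as keyword arguments) and returns one float per completion. The sketch below is illustrative only; its placeholder scoring merely checks that an `<answer>` block is present, whereas the shipped `grpo.py` delegates to reasoning-gym's own answer scorer.
```python
# Minimal, illustrative sketch of a GRPO reward function (not the shipped implementation).
from reasoning_gym.utils import extract_answer

def accuracy_reward(completions: list[str], **kwargs) -> list[float]:
    # Pull the text inside <answer>...</answer>; give credit only if the block exists.
    answers = [extract_answer(c) for c in completions]
    return [1.0 if a else 0.0 for a in answers]
```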
## Setup
1. Install the required dependencies:
This tutorial uses CUDA 11.8, Python 3.10, and PyTorch 2.5.1.
```bash
pip install -r requirements.txt
We also assume that you have 2 GPUs on your machine, the last of which is used for vLLM rollouts.
If you have more than 2 GPUs, adjust the `./config/grpo.yaml` file so that `vllm_device` points to the index of your last GPU. For example, if you have 4 GPUs, set it to 3:
```yaml
vllm_device: cuda:3 # with 4 GPUs, the last index is 3
```
## Usage
1. Configure the training parameters in `config/grpo.yaml`
2. Run the training script:
You will also need to update the `CUDA_VISIBLE_DEVICES` environment variable in the `train.sh` script to include all of your available GPUs. For example, if you have 4 GPUs, set it to:
```bash
python main_grpo_reward.py
# ./train.sh
# ... beginning of the script
export CUDA_VISIBLE_DEVICES=0,1,2,3
# ... rest of the script
```
The model will be trained using GRPO on the specified reasoning-gym dataset, and evaluation metrics will be logged to Weights & Biases.
1. Install the required packages:
```bash
# First, give execute permissions to the script
# chmod +x ./set_env.sh
# Then, run the setup script
./set_env.sh
```
2. (Optional) Log in to Weights & Biases for experiment tracking:
```bash
# First, set your WANDB_API_KEY as an environment variable
export WANDB_API_KEY=your_wandb_api_key
# Set the project name
export WANDB_PROJECT=your_wandb_project_name
```
3. Run the training script
```bash
# First, give execute permissions to the script
# chmod +x ./train.sh
# Then, run the training script
./train.sh
```

View File

@@ -1,37 +1,52 @@
#Model arguments
model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
# Reasoning Gym configs
dataset_size: 20000
developer_prompt: DeepSeekZero
developer_role: system
datasets:
simple_equations:
weight: 1
complex_arithmetic:
weight: 1
config:
min_real: -20
max_real: 20
#script arguments
dataset_name: chain_sum
# Model configs from trl
model_name_or_path: Qwen/Qwen2.5-1.5B-Instruct
attn_implementation: flash_attention_2
#training arguments
# GRPO trainer configs from trl
bf16: true
gradient_accumulation_steps: 16
use_vllm: true
vllm_device: cuda:1
vllm_gpu_memory_utilization: 0.9
log_level: info
gradient_accumulation_steps: 1
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
hub_model_id:
seed: 42
eval_seed: 101
log_level: info
logging_steps: 10
use_reentrant: false
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
learning_rate: 2.0e-05
learning_rate: 1e-06
lr_scheduler_type: constant_with_warmup
lr_scheduler_kwargs:
num_warmup_steps: 10
max_prompt_length: 512
max_completion_length: 1024
max_completion_length: 2048
max_steps: 100
num_generations: 8
per_device_train_batch_size: 1
per_device_eval_batch_size: 1
overwrite_output_dir: true
output_dir: data/Qwen-1.5B-GRPO
train_size: 1000
eval_size: 100
num_train_epochs: 1
max_steps: -1
push_to_hub: true
report_to: ['wandb']
#do_eval: true
#eval_strategy: steps
#eval_steps: 100
overwrite_output_dir: true
per_device_train_batch_size: 8
report_to:
- wandb
save_strategy: steps
save_steps: 50
save_total_limit: 5
seed: 42
temperature: 0.6
warmup_ratio: 0.1

266 examples/trl/grpo.py Normal file
View File

@@ -0,0 +1,266 @@
import logging
import os
import re
import sys
from dataclasses import dataclass, field
from typing import Optional
import torch
import transformers
from torch.utils.data import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from transformers.trainer_utils import get_last_checkpoint
from trl import GRPOConfig, GRPOTrainer, ModelConfig, TrlParser
import reasoning_gym
from reasoning_gym.coaching.experiment import Experiment
from reasoning_gym.composite import DatasetSpec
from reasoning_gym.dataset import ProceduralDataset
from reasoning_gym.utils import SYSTEM_PROMPTS, extract_answer
@dataclass
class DatasetConfigItem:
weight: Optional[float] = field(default=1.0)
config: Optional[dict] = field(default_factory=dict)
@dataclass
class DatasetConfig:
dataset_size: int = field(default=1000)
developer_prompt: str = field(default="DeepSeekZero")
developer_role: str = field(default="system")
datasets: dict[str, DatasetConfigItem] = field(default=None)
def __post_init__(self):
# Convert dictionary items to DatasetConfigItem instances
if self.datasets:
converted_datasets = {}
for name, config_item in self.datasets.items():
if isinstance(config_item, dict):
converted_datasets[name] = DatasetConfigItem(**config_item)
else:
converted_datasets[name] = config_item
self.datasets = converted_datasets
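# Illustration of the conversion above: a YAML entry such as
#   datasets:
#     complex_arithmetic:
#       weight: 1
#       config: {min_real: -20, max_real: 20}
# is parsed into {"complex_arithmetic": DatasetConfigItem(weight=1, config={"min_real": -20, "max_real": 20})}.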
class ReasoningGymDataset(Dataset):
def __init__(
self,
tokenizer,
procedural_dataset: Optional[ProceduralDataset] = None,
experiment: Optional[Experiment] = None,
developer_prompt: Optional[str] = None,
developer_role: Optional[str] = None,
):
self.tokenizer = tokenizer
self.data = procedural_dataset or experiment.composite
self.experiment = experiment
self.developer_prompt = developer_prompt
self.developer_role = developer_role
def __len__(self) -> int:
return len(self.data)
def __getitem__(self, idx: int) -> dict:
item = self.data[idx]
question = item["question"]
chat = []
if self.developer_role is not None:
chat.append({"role": self.developer_role, "content": self.developer_prompt})
chat.append({"role": "user", "content": question})
prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
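# The extra "item" key is forwarded by GRPOTrainer to the reward functions as a keyword
# argument, so each completion can be scored against its original dataset entry.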
return {"prompt": prompt, "item": item}
class CustomGRPOTrainer(GRPOTrainer):
def __init__(
self,
model,
args: GRPOConfig,
tokenizer,
train_dataset: ReasoningGymDataset,
eval_dataset: ReasoningGymDataset,
):
super().__init__(
model=model,
reward_funcs=[
self._accuracy_reward,
self._format_reward,
],
args=args,
processing_class=tokenizer,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)
def _accuracy_reward(self, completions: list[str], **kwargs) -> list[float]:
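# Accuracy reward: extract the <answer> block from each completion and score it with the
# reasoning-gym dataset's own score_answer, using the original dataset item as ground truth.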
assert "item" in kwargs, "The 'item' argument must be provided to compute accuracy reward."
assert len(kwargs["item"]) == len(completions), "Items and completions must have the same length."
assert all(isinstance(item, dict) for item in kwargs["item"]), "Each item must be a dictionary."
answers = [extract_answer(c) for c in completions]
return [self.train_dataset.data.score_answer(answer, item) for answer, item in zip(answers, kwargs["item"])]
def _format_reward(self, completions: list[str], **kwargs) -> list[float]:
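# Format reward: grant 0.25 partial credit for each of the four reasoning tags
# (<think>, </think>, <answer>, </answer>) present in the completion, up to 1.0.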
def count_tags(text: str) -> float:
count = 0.0
if re.search(r"\s*<think>\s*", text):
count += 0.25
if re.search(r"\s*</think>\s*", text):
count += 0.25
if re.search(r"\s*<answer>\s*", text):
count += 0.25
if re.search(r"\s*</answer>\s*", text):
count += 0.25
return count
return [count_tags(c) for c in completions]
def make_dataset(
tokenizer,
data_source: Experiment | ProceduralDataset,
developer_prompt: str,
developer_role: Optional[str] = None,
) -> ReasoningGymDataset:
"""Create a ReasoningGymDataset from an Experiment or ProceduralDataset."""
if isinstance(data_source, Experiment):
return ReasoningGymDataset(
tokenizer=tokenizer,
experiment=data_source,
developer_prompt=developer_prompt,
developer_role=developer_role,
)
else:
return ReasoningGymDataset(
tokenizer=tokenizer,
procedural_dataset=data_source,
developer_prompt=developer_prompt,
developer_role=developer_role,
)
def prepare_datasets(
config: DatasetConfig,
tokenizer,
) -> tuple[ReasoningGymDataset, ReasoningGymDataset]:
"""Prepare the training and eval datasets."""
developer_prompt = SYSTEM_PROMPTS[config.developer_prompt]
dataset_specs = [
DatasetSpec(
name=name,
weight=ds_config.weight,
config=ds_config.config,
)
for name, ds_config in config.datasets.items()
]
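# One DatasetSpec per configured reasoning-gym task; each task's weight sets its
# relative share in the composite dataset.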
train_data_source = reasoning_gym.create_dataset(
"composite", seed=1, size=config.dataset_size, datasets=dataset_specs
)
val_data_source = reasoning_gym.create_dataset(
"composite", seed=2, size=config.dataset_size, datasets=dataset_specs
)
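# Train and eval composites use different seeds (1 vs. 2), so their procedurally
# generated items are drawn from different random streams.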
train_dataset = make_dataset(
tokenizer=tokenizer,
data_source=train_data_source,
developer_prompt=developer_prompt,
developer_role=config.developer_role,
)
eval_dataset = make_dataset(
tokenizer=tokenizer,
data_source=val_data_source,
developer_prompt=developer_prompt,
developer_role=config.developer_role,
)
return train_dataset, eval_dataset
def main():
# -----------
# Parse args
# -----------
parser = TrlParser((DatasetConfig, GRPOConfig, ModelConfig))
reasoning_gym_args, training_args, model_args = parser.parse_args_and_config()
set_seed(training_args.seed)
# ---------------
# Set up logging
# ---------------
logger = logging.getLogger(__name__)
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Model parameters {model_args}")
logger.info(f"Training parameters {training_args}")
# -----------
# Load model
# -----------
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
torch_dtype=torch.bfloat16,
attn_implementation=model_args.attn_implementation,
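# The KV cache is incompatible with gradient checkpointing, so disable it while checkpointing is on.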
use_cache=False if training_args.gradient_checkpointing else True,
)
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
# --------------------
# Instantiate trainer
# --------------------
training_args.reasoning_gym = reasoning_gym_args
train_dataset, eval_dataset = prepare_datasets(reasoning_gym_args, tokenizer)
trainer = CustomGRPOTrainer(
model=model,
args=training_args,
tokenizer=tokenizer,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)
# ------------------------------
# See if we can resume training
# ------------------------------
logger.info("Starting training...")
# Check for last checkpoint
ckpt = None
if training_args.resume_from_checkpoint is not None:
ckpt = training_args.resume_from_checkpoint
elif os.path.isdir(training_args.output_dir):
ckpt = get_last_checkpoint(training_args.output_dir)
if ckpt:
logger.info(f"\nCheckpoint detected, resuming training at {ckpt=}.")
else:
logger.info("\nNo checkpoint detected, starting training from scratch.")
# ---------------
# Start training
# ---------------
train_result = trainer.train(resume_from_checkpoint=ckpt)
train_metrics = train_result.metrics
trainer.log_metrics("train", train_metrics)
trainer.save_metrics("train", train_metrics)
trainer.save_state()
# ---------
# Clean up
# ---------
del trainer
torch.cuda.empty_cache()
if __name__ == "__main__":
main()

View File

@@ -1,18 +0,0 @@
from dataclasses import dataclass
from typing import Optional
@dataclass
class ScriptArguments:
"""
Arguments for the training script.
"""
dataset_name: str
dataset_config: Optional[str] = None
dataset_train_split: str = "train"
dataset_test_split: str = "test"
gradient_checkpointing_use_reentrant: bool = False
ignore_bias_buffers: bool = False
train_size: int = 1000
eval_size: int = 100

View File

@@ -1,217 +0,0 @@
# This example is an adapted version of HuggingFace trl GRPO code:
# link : https://github.com/huggingface/open-r1/blob/main/src/open_r1/grpo.py
import logging
import os
import re
import sys
from dataclasses import dataclass
from typing import Optional
import datasets
import torch
import transformers
from grpo_config import ScriptArguments
from peft import LoraConfig
from torch.utils.data import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from transformers.trainer_utils import get_last_checkpoint
from trl import GRPOConfig, GRPOTrainer, ModelConfig, TrlParser
import reasoning_gym
from reasoning_gym.utils import extract_answer
class ReasoningGymDataset(Dataset):
def __init__(self, dataset_name, seed, size, tokenizer, developer_prompt, developer_role="system") -> None:
super().__init__()
self.data = reasoning_gym.create_dataset(dataset_name, seed=seed, size=size)
self.tokenizer = tokenizer
self.developer_role = developer_role
self.developer_prompt = developer_prompt
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
item = self.data[idx]
question = item["question"]
chat = []
if self.developer_role is not None:
chat.append({"role": self.developer_role, "content": self.developer_prompt})
chat.append({"role": "user", "content": question})
prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
return {"prompt": prompt, "metadata": item}
class GRPOTrainerCustom(GRPOTrainer):
def __init__(
self,
model,
dataset_name,
args: GRPOConfig,
tokenizer,
peft_config,
seed,
size,
developer_role="system",
):
super().__init__(
model,
reward_funcs=[self._accuracy_reward, self._format_reward],
args=args,
processing_class=tokenizer,
peft_config=peft_config,
)
developer_prompt = reasoning_gym.utils.SYSTEM_PROMPTS["DeepSeekZero"]
self.train_dataset = ReasoningGymDataset(dataset_name, seed, size, tokenizer, developer_prompt, developer_role)
def _format_reward(self, completions, **kwargs):
regex = r"^<think>([^<]*(?:<(?!/?think>)[^<]*)*)<\/think>\n<answer>([\s\S]*?)<\/answer>$"
matches = [re.match(regex, completion, flags=re.DOTALL) for completion in completions]
return [1.0 if match else 0.0 for match in matches]
def _accuracy_reward(self, completions, metadata, **kwargs):
answers = [extract_answer(completion) for completion in completions]
return [self.train_dataset.data.score_answer(answer, entry=obj) for (answer, obj) in zip(answers, metadata)]
def main(script_args, training_args, model_args):
set_seed(training_args.seed)
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)
log_level = training_args.get_process_log_level()
logger.setLevel(log_level) # Set for module-level logger
# Configure third-party library log levels
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
logger.info(f"Training arguments: {training_args}")
logger.info(f"Model arguments: {model_args}")
logger.info(f"Script arguments: {script_args}")
last_checkpoint = None
if os.path.isdir(training_args.output_dir):
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token
peft_config = LoraConfig(
r=16,
lora_alpha=64,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
task_type="CAUSAL_LM",
lora_dropout=0.05,
)
trainer = GRPOTrainerCustom(
model,
dataset_name=script_args.dataset_name,
args=training_args,
tokenizer=tokenizer,
peft_config=peft_config,
seed=training_args.seed,
size=script_args.train_size,
)
# Training loop
logger.info("Training model...")
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is None:
checkpoint = model.save_pretrained(training_args.output_dir)
train_results = trainer.train(resume_from_checkpoint=checkpoint)
metrics = train_results.metrics
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
logger.info("*** Save model ***")
trainer.save_model(training_args.output_dir)
logger.info(f"Model saved to {training_args.output_dir}")
kwargs = {
"finetuned_from": model_args.model_name_or_path,
"dataset": list(script_args.dataset_name),
"dataset_tags": list(script_args.dataset_name),
"tags": ["reasoning-gym"],
}
if trainer.accelerator.is_main_process:
trainer.create_model_card(**kwargs)
# Restore k,v cache for fast inference
trainer.model.config.use_cache = True
trainer.model.config.save_pretrained(training_args.output_dir)
def evaluate_model(model, tokenizer, dataset, *args, **kwargs):
model.eval()
correct_preds = 0
total_preds = 0
for i in range(len(dataset)):
item = dataset[i]
prompt = item["prompt"]
metadata = item["metadata"]
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
with torch.no_grad():
outputs = model.generate(
inputs,
max_new_tokens=training_args.max_completion_length,
pad_token_id=tokenizer.eos_token_id,
*args,
**kwargs,
)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
answer = reasoning_gym.utils.extract_answer(generated_text)
score = dataset.data.score_answer(answer, entry=metadata)
correct_preds += score
total_preds += 1
return correct_preds / total_preds
## Evaluate model
logger.info("Evaluating model...")
eval_dataset = ReasoningGymDataset(
script_args.dataset_name,
training_args.eval_seed,
script_args.eval_size,
tokenizer,
reasoning_gym.utils.SYSTEM_PROMPTS["DeepSeekZero"],
)
eval_results = evaluate_model(model, tokenizer, eval_dataset)
trainer.log_metrics("eval", {"accuracy": eval_results})
trainer.save_metrics("eval", {"accuracy": eval_results})
logger.info(f"Evaluation results: {eval_results}")
if training_args.push_to_hub:
logging.info("Pushing model to hub...")
trainer.push_to_hub()
if __name__ == "__main__":
parser = TrlParser((ScriptArguments, GRPOConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_and_config()
main(script_args, training_args, model_args)

View File

@@ -1,8 +0,0 @@
torch>=2.6.0
datasets
peft
transformers
trl
wandb
huggingface_hub
flash-attn --no-build-isolation

View File

@@ -2,19 +2,11 @@
# python 3.10 + cuda 11.8.0
# the execution order of the following commands matters
export MKL_NUM_THREADS=1
export NUMEXPR_NUM_THREADS=1
export OPENBLAS_NUM_THREADS=1
export OMP_NUM_THREADS=1
conda clean -a -y
mamba clean -a -y
pip install --upgrade pip
pip cache purge
# cuda, gcc/g++, torch
# mamba install cuda -c nvidia/label/cuda-11.8.0 -y
# mamba install gcc gxx -c conda-forge -y
# torch
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu118
# xformers
@@ -25,15 +17,7 @@ pip install https://github.com/vllm-project/vllm/releases/download/v0.7.2/vllm-0
pip install deepspeed
pip install flash-attn==2.7.3 --no-build-isolation
pip install peft
pip install "trl==0.15.2"
pip install latex2sympy2_extended
pip install "math_verify==0.5.2"
pip install word2number
pip install scipy
pip install "transformers==4.49.0"
pip install wandb
pip install plotly
pip install matplotlib
pip install seaborn
pip install reasoning-gym

26 examples/trl/train.sh Executable file
View File

@@ -0,0 +1,26 @@
#!/bin/bash
export CUDA_VISIBLE_DEVICES=0,1
GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
NUM_PROCESSES_TRAINING=$((GPU_COUNT - 1))
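# The last GPU is reserved for vLLM rollouts (see vllm_device in config/grpo.yaml),
# so training launches one process fewer than the GPU count.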
echo ""
echo "Number of GPUs: ${GPU_COUNT}"
echo "Number of processes for training: ${NUM_PROCESSES_TRAINING}"
echo ""
PY_SCRIPT="./grpo.py"
PY_CONFIG="./config/grpo.yaml"
ACCELERATE_DS_CONFIG="./config/ds_zero2.yaml"
echo "START TIME: $(date)"
export WANDB_PROJECT="reasoning-gym-trl"
accelerate launch \
--config_file "${ACCELERATE_DS_CONFIG}" \
--main_process_port=29500 \
--num_processes="${NUM_PROCESSES_TRAINING}" "${PY_SCRIPT}" --config "${PY_CONFIG}"
echo "END TIME: $(date)"
echo "DONE"

View File

@@ -2,8 +2,6 @@
Training codebase for training LLMs using Reasoning Gym procedural dataset generators.
**Note**: `qwen-math/` directory contains the code from the Tina project, used for the Qwen2.5 3B RG-Math training. This is separate from the rest of our training/evaluation codebase.
This readme documents:
- Training environment setup and usage example

View File

@@ -1,180 +0,0 @@
# Byte-compiled / optimized / DLL files
.DS_Store
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# Ruff stuff:
.ruff_cache/
# PyPI configuration file
.pypirc
ckpts/
outputs/

View File

@@ -1,201 +0,0 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@@ -1,179 +0,0 @@
<div align="center">
<h1 style="font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; margin-bottom: 10px;">
Tina: Tiny Reasoning Models via LoRA
</h1>
<hr style="width: 60%; border: none; border-top: 2px solid #ccc; margin: 0 auto 20px auto;">
<a href="https://github.com/shangshang-wang/Tina">
<img src="./assets/Avatar-Tina.png" style="
width: 200px;
border-radius: 20px;
box-shadow: 0 8px 16px rgba(0, 0, 0, 0.2);
border: 3px solid #f18f01;
transition: transform 0.3s ease;
"
onmouseover="this.style.transform='scale(1.05)'"
onmouseout="this.style.transform='scale(1)'">
</a>
</div>
<div align="center">
[![Github](https://img.shields.io/badge/Tina-000000?style=for-the-badge&logo=github&logoColor=000&logoColor=white)](https://github.com/shangshang-wang/Tina)
[![Website](https://img.shields.io/badge/Notion-%23000000.svg?style=for-the-badge&logo=semanticweb&logoColor=white)](https://shangshangwang.notion.site/tina)
[![Hugging Face Collection](https://img.shields.io/badge/Tina_Yi-fcd022?style=for-the-badge&logo=huggingface&logoColor=000&labelColor)](https://huggingface.co/Tina-Yi)
[![Weights and Biases](https://img.shields.io/badge/Tina-fcd022?style=for-the-badge&logo=weightsandbiases&logoColor=000&labelColor)](https://wandb.ai/upup-ashton-wang-usc/Tina)
</div>
## Overview
This repository contains the code for the Tina project, accompanying the paper [Tina: Tiny Reasoning Models via LoRA](https://arxiv.org/abs/2504.15777).
In this project, we try to answer the question "How cost-effectively can one perform reinforcement learning to efficiently instill reasoning abilities in language models?"
Specifically, we explore enhancing reasoning capabilities in tiny language models with low-rank adaptation during reinforcement learning.
<div style="text-align: center;">
<img
src="assets/overall_comparison.png"
alt="Overall Comparison"
width="1000"
style="max-width: 100%; height: auto;">
</div>
We show that our Tina models achieve performance competitive with, and in some cases even superior to, SOTA baseline models built on the same base model with full-parameter training.
In particular, the best Tina model achieves a >20% performance increase and 43.33% Pass@1 accuracy on AIME24.
Notably, the cost of reproducing the best Tina checkpoint stands at only \$9, and of reproducing all our experiments from scratch at \$526.
<div style="text-align: center;">
<img
src="assets/cost.png"
alt="Cost Breakdown"
style="max-width: 50%; height: auto;">
</div>
## Quick Start
### File Setup
* `./scripts/set/set_vars.sh`: contains the main env vars we use. Change the paths (e.g. `PROJECT_PREFIX`, `SCRATCH_PREFIX`) to align with your own setup. Also make sure to add `WANDB_API_KEY` and `HF_TOKEN` to your `~/.bashrc` file.
* `./recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/`: contains the recipes for each experiment in this project; change the HF hub id to align with your own setup.
* `./tina/config.py`: contains the main configurations for this project; set default values here.
* `./tina/utils/constant.py`: contains the main datasets for each experiment in this project.
### Env Setup
First, install Miniconda:
```bash
mkdir -p ~/miniconda3
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
rm ~/miniconda3/miniconda.sh
source ~/miniconda3/bin/activate
conda init --all
```
Then, run the following commands to install the dependencies.
```bash
conda update -n base -c defaults conda -y
conda install -n base -c conda-forge mamba -y
mamba create -n tina python=3.10 -y && mamba activate tina
./scripts/set/set_env.sh && mamba deactivate
mamba create -n tina_eval python=3.11 -y && mamba activate tina_eval
./scripts/set/set_env_eval.sh && mamba deactivate
# download the pre-trained models to the `CKPT_DIR` directory.
./scripts/set/prepare.sh
```
>[!IMPORTANT]
> For **Reasoning Gym** you need to install `lighteval` from source with a particular branch because of a known issue with evaluating on low-sample datasets such as AIME24.
> When the branch is merged into the main branch, we will update the instructions accordingly.
```bash
cd /path/to/installation/folder # e.g. /root/projects
git clone git@github.com:huggingface/lighteval.git
cd lighteval
git checkout remotes/origin/tune-pass-at-k
pip install -e .
```
### Training & Evaluation
* LoRA-based RL with GRPO: `./scripts/training/post_train_grpo.sh`
<div style="text-align: center;">
<img
src="assets/ablation.png"
alt="Ablation"
style="max-width: 50%; height: auto;">
</div>
After that, we have the following file structure in the `CKPT_DIR` directory.
```bash
CKPT_DIR/
├── models/
│ ├── DeepSeek-R1-Distill-Qwen-1.5B/
│ │ └── base/ # pre-trained models
│ │ └── grpo_PT_DATASET_I/ # post-trained models via GRPO using PT_DATASET_I
│ │ │ └── checkpoint-i/ # we should keep checkpoints during post-training in a stepwise manner
│ │ │ └── ...
│ │ └── grpo_PT_DATASET_II/ # post-trained models via GRPO using PT_DATASET_II
│ │ │ └── checkpoint-i/
│ │ │ └── ...
│ │ └── ...
```
* Re-evaluate baseline models: `./scripts/training/post_train_eval_baselines.sh`
<div style="text-align: center;">
<img
src="assets/baseline_eval.png"
alt="Baseline Re-evaluation"
style="max-width: 30%; height: auto;">
</div>
* Evaluate post-trained models: `./scripts/training/post_train_eval_local.sh`
<div style="text-align: center;">
<img
src="assets/tina_eval.png"
alt="Tina Evaluation"
style="max-width: 40%; height: auto;">
</div>
## Acknowledgements
We thank Huggingface for open-sourcing the amazing [open-r1](https://github.com/huggingface/open-r1/tree/7041fbc9d65b6f1832db727961e8282243f8f82a) project, which is the starting codebase of our Tina project.
We also appreciate all researchers releasing their open-source reasoning datasets, including [open-r1/OpenR1-Math-220k](https://huggingface.co/datasets/open-r1/OpenR1-Math-220k), [bethgelab/CuratedThoughts](https://huggingface.co/datasets/bethgelab/CuratedThoughts), [agentica-org/DeepScaleR-Preview-Dataset](https://huggingface.co/datasets/agentica-org/DeepScaleR-Preview-Dataset), [RUC-AIBOX/STILL-3-Preview-RL-Data](https://huggingface.co/datasets/RUC-AIBOX/STILL-3-Preview-RL-Data), [knoveleng/open-rs](https://huggingface.co/datasets/knoveleng/open-rs), [knoveleng/open-s1](https://huggingface.co/datasets/knoveleng/open-s1), and [GAIR/LIMR](https://huggingface.co/datasets/GAIR/LIMR), which are used for our training.
*Tina's avatar is generated by GPT-4o based on [KYNE](https://www.artsy.net/artist/kyne)'s girls and the following prompt.*
*Hi, I'm Tina — an INTJ who's all about getting to the essence of things. I study reasoning models because I'm fascinated by how structured thinking and logic can emerge from data. Outside of that, I recharge with movies, music, and the occasional gaming session. I believe in strategic effort: minimal input, maximum impact — whether it's in research or everyday life, I'm always looking for the most efficient path to meaningful results.*
## Citation
```cite
@misc{wang2025tinatinyreasoningmodels,
title={Tina: Tiny Reasoning Models via LoRA},
author={Shangshang Wang and Julian Asilis and Ömer Faruk Akgül and Enes Burak Bilgin and Ollie Liu and Willie Neiswanger},
year={2025},
eprint={2504.15777},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2504.15777},
}
```

Binary files not shown (6 image assets removed; sizes: 1.9 MiB, 332 KiB, 151 KiB, 130 KiB, 225 KiB, 103 KiB)

File diff suppressed because it is too large

View File

@@ -1,68 +0,0 @@
# check ./tina/utils/constant.py
model_post_train_dataset_name: curated_deepscaler
model_post_train_type: grpo
rl_post_train_reward_funcs:
- format
- accuracy
rl_post_train_reward_weights:
- 1.0
- 2.0
# Model configs from trl
model_name_or_path: DeepSeek-R1-Distill-Qwen-1.5B
attn_implementation: flash_attention_2
use_peft: true
lora_r: 32
lora_alpha: 128
lora_dropout: 0.05
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- down_proj
- up_proj
- gate_proj
# GRPO trainer configs from trl
bf16: true
use_vllm: true
vllm_device: cuda:0
vllm_gpu_memory_utilization: 0.4
vllm_max_model_len: 4608
gradient_accumulation_steps: 4
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
push_to_hub: false
hub_strategy: every_save
hub_private_repo: true
hub_model_id: TODO
learning_rate: 1e-06
lr_scheduler_type: cosine_with_min_lr
lr_scheduler_kwargs:
min_lr_rate: 0.1
max_prompt_length: 512
max_completion_length: 3584
max_steps: 11000 # use 11000 for lr scheduler but stop at 5500 steps
num_generations: 4
num_train_epochs: 1
overwrite_output_dir: true
per_device_train_batch_size: 4
report_to:
- wandb
save_strategy: steps
save_steps: 100
save_total_limit: 100
seed: 42
temperature: 0.7
warmup_ratio: 0.1

View File

@@ -1,68 +0,0 @@
# check ./tina/utils/constant.py
model_post_train_dataset_name: curated_limr
model_post_train_type: grpo
rl_post_train_reward_funcs:
- format
- accuracy
rl_post_train_reward_weights:
- 1.0
- 2.0
# Model configs from trl
model_name_or_path: DeepSeek-R1-Distill-Qwen-1.5B
attn_implementation: flash_attention_2
use_peft: true
lora_r: 32
lora_alpha: 128
lora_dropout: 0.05
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- down_proj
- up_proj
- gate_proj
# GRPO trainer configs from trl
bf16: true
use_vllm: true
vllm_device: cuda:0
vllm_gpu_memory_utilization: 0.4
vllm_max_model_len: 4608
gradient_accumulation_steps: 4
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
push_to_hub: false
hub_strategy: every_save
hub_private_repo: true
hub_model_id: TODO
learning_rate: 1e-06
lr_scheduler_type: cosine_with_min_lr
lr_scheduler_kwargs:
min_lr_rate: 0.1
max_prompt_length: 512
max_completion_length: 3584
max_steps: 360 # use 360 for lr scheduler but stop at 180 steps
num_generations: 4
num_train_epochs: 1
overwrite_output_dir: true
per_device_train_batch_size: 4
report_to:
- wandb
save_strategy: steps
save_steps: 10
save_total_limit: 100
seed: 42
temperature: 0.7
warmup_ratio: 0.1

View File

@@ -1,68 +0,0 @@
# check ./tina/utils/constant.py
model_post_train_dataset_name: curated_limr_large_lr_ablation
model_post_train_type: grpo
rl_post_train_reward_funcs:
- format
- accuracy
rl_post_train_reward_weights:
- 1.0
- 2.0
# Model configs from trl
model_name_or_path: DeepSeek-R1-Distill-Qwen-1.5B
attn_implementation: flash_attention_2
use_peft: true
lora_r: 32
lora_alpha: 128
lora_dropout: 0.05
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- down_proj
- up_proj
- gate_proj
# GRPO trainer configs from trl
bf16: true
use_vllm: true
vllm_device: cuda:0
vllm_gpu_memory_utilization: 0.4
vllm_max_model_len: 4608
gradient_accumulation_steps: 4
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
push_to_hub: false
hub_strategy: every_save
hub_private_repo: true
hub_model_id: TODO
learning_rate: 5e-06
lr_scheduler_type: cosine_with_min_lr
lr_scheduler_kwargs:
min_lr_rate: 0.1
max_prompt_length: 512
max_completion_length: 3584
max_steps: 360 # use 360 for lr scheduler but stop at 180 steps
num_generations: 4
num_train_epochs: 1
overwrite_output_dir: true
per_device_train_batch_size: 4
report_to:
- wandb
save_strategy: steps
save_steps: 10
save_total_limit: 100
seed: 42
temperature: 0.7
warmup_ratio: 0.1

View File

@@ -1,68 +0,0 @@
# check ./tina/utils/constant.py
model_post_train_dataset_name: curated_limr_large_rank_ablation
model_post_train_type: grpo
rl_post_train_reward_funcs:
- format
- accuracy
rl_post_train_reward_weights:
- 1.0
- 2.0
# Model configs from trl
model_name_or_path: DeepSeek-R1-Distill-Qwen-1.5B
attn_implementation: flash_attention_2
use_peft: true
lora_r: 128
lora_alpha: 512
lora_dropout: 0.05
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- down_proj
- up_proj
- gate_proj
# GRPO trainer configs from trl
bf16: true
use_vllm: true
vllm_device: cuda:0
vllm_gpu_memory_utilization: 0.4
vllm_max_model_len: 4608
gradient_accumulation_steps: 4
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
push_to_hub: false
hub_strategy: every_save
hub_private_repo: true
hub_model_id: TODO
learning_rate: 1e-06
lr_scheduler_type: cosine_with_min_lr
lr_scheduler_kwargs:
min_lr_rate: 0.1
max_prompt_length: 512
max_completion_length: 3584
max_steps: 360 # use 360 for lr scheduler but stop at 180 steps
num_generations: 4
num_train_epochs: 1
overwrite_output_dir: true
per_device_train_batch_size: 4
report_to:
- wandb
save_strategy: steps
save_steps: 50
save_total_limit: 100
seed: 42
temperature: 0.7
warmup_ratio: 0.1

View File

@@ -1,68 +0,0 @@
# check ./tina/utils/constant.py
model_post_train_dataset_name: curated_limr_medium_rank_ablation
model_post_train_type: grpo
rl_post_train_reward_funcs:
- format
- accuracy
rl_post_train_reward_weights:
- 1.0
- 2.0
# Model configs from trl
model_name_or_path: DeepSeek-R1-Distill-Qwen-1.5B
attn_implementation: flash_attention_2
use_peft: true
lora_r: 16
lora_alpha: 64
lora_dropout: 0.05
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- down_proj
- up_proj
- gate_proj
# GRPO trainer configs from trl
bf16: true
use_vllm: true
vllm_device: cuda:0
vllm_gpu_memory_utilization: 0.4
vllm_max_model_len: 4608
gradient_accumulation_steps: 4
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
push_to_hub: false
hub_strategy: every_save
hub_private_repo: true
hub_model_id: TODO
learning_rate: 1e-06
lr_scheduler_type: cosine_with_min_lr
lr_scheduler_kwargs:
min_lr_rate: 0.1
max_prompt_length: 512
max_completion_length: 3584
max_steps: 360 # use 360 for lr scheduler but stop at 180 steps
num_generations: 4
num_train_epochs: 1
overwrite_output_dir: true
per_device_train_batch_size: 4
report_to:
- wandb
save_strategy: steps
save_steps: 50
save_total_limit: 100
seed: 42
temperature: 0.7
warmup_ratio: 0.1

View File

@@ -1,68 +0,0 @@
# check ./tina/utils/constant.py
model_post_train_dataset_name: curated_limr_small_lr_ablation
model_post_train_type: grpo
rl_post_train_reward_funcs:
- format
- accuracy
rl_post_train_reward_weights:
- 1.0
- 2.0
# Model configs from trl
model_name_or_path: DeepSeek-R1-Distill-Qwen-1.5B
attn_implementation: flash_attention_2
use_peft: true
lora_r: 32
lora_alpha: 128
lora_dropout: 0.05
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- down_proj
- up_proj
- gate_proj
# GRPO trainer configs from trl
bf16: true
use_vllm: true
vllm_device: cuda:0
vllm_gpu_memory_utilization: 0.4
vllm_max_model_len: 4608
gradient_accumulation_steps: 4
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
push_to_hub: false
hub_strategy: every_save
hub_private_repo: true
hub_model_id: TODO
learning_rate: 5e-07
lr_scheduler_type: cosine_with_min_lr
lr_scheduler_kwargs:
min_lr_rate: 0.1
max_prompt_length: 512
max_completion_length: 3584
max_steps: 360 # use 360 for lr scheduler but stop at 180 steps
num_generations: 4
num_train_epochs: 1
overwrite_output_dir: true
per_device_train_batch_size: 4
report_to:
- wandb
save_strategy: steps
save_steps: 10
save_total_limit: 100
seed: 42
temperature: 0.7
warmup_ratio: 0.1

View File

@@ -1,68 +0,0 @@
# check ./tina/utils/constant.py
model_post_train_dataset_name: curated_limr_small_rank_ablation
model_post_train_type: grpo
rl_post_train_reward_funcs:
- format
- accuracy
rl_post_train_reward_weights:
- 1.0
- 2.0
# Model configs from trl
model_name_or_path: DeepSeek-R1-Distill-Qwen-1.5B
attn_implementation: flash_attention_2
use_peft: true
lora_r: 8
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- down_proj
- up_proj
- gate_proj
# GRPO trainer configs from trl
bf16: true
use_vllm: true
vllm_device: cuda:0
vllm_gpu_memory_utilization: 0.4
vllm_max_model_len: 4608
gradient_accumulation_steps: 4
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
push_to_hub: false
hub_strategy: every_save
hub_private_repo: true
hub_model_id: TODO
learning_rate: 1e-06
lr_scheduler_type: cosine_with_min_lr
lr_scheduler_kwargs:
min_lr_rate: 0.1
max_prompt_length: 512
max_completion_length: 3584
max_steps: 360 # use 360 for lr scheduler but stop at 180 steps
num_generations: 4
num_train_epochs: 1
overwrite_output_dir: true
per_device_train_batch_size: 4
report_to:
- wandb
save_strategy: steps
save_steps: 50
save_total_limit: 100
seed: 42
temperature: 0.7
warmup_ratio: 0.1

View File

@@ -1,68 +0,0 @@
# check ./tina/utils/constant.py
model_post_train_dataset_name: curated_limr_tiny_rank_ablation
model_post_train_type: grpo
rl_post_train_reward_funcs:
- format
- accuracy
rl_post_train_reward_weights:
- 1.0
- 2.0
# Model configs from trl
model_name_or_path: DeepSeek-R1-Distill-Qwen-1.5B
attn_implementation: flash_attention_2
use_peft: true
lora_r: 4
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- down_proj
- up_proj
- gate_proj
# GRPO trainer configs from trl
bf16: true
use_vllm: true
vllm_device: cuda:0
vllm_gpu_memory_utilization: 0.4
vllm_max_model_len: 4608
gradient_accumulation_steps: 4
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
push_to_hub: false
hub_strategy: every_save
hub_private_repo: true
hub_model_id: TODO
learning_rate: 1e-06
lr_scheduler_type: cosine_with_min_lr
lr_scheduler_kwargs:
min_lr_rate: 0.1
max_prompt_length: 512
max_completion_length: 3584
max_steps: 360 # use 360 for lr scheduler but stop at 180 steps
num_generations: 4
num_train_epochs: 1
overwrite_output_dir: true
per_device_train_batch_size: 4
report_to:
- wandb
save_strategy: steps
save_steps: 50
save_total_limit: 100
seed: 42
temperature: 0.7
warmup_ratio: 0.1

View File

@@ -1,76 +0,0 @@
# check ./tina/utils/constant.py
model_post_train_dataset_name: curated_open_r1
model_post_train_type: grpo
rl_post_train_reward_funcs:
- accuracy
- format
- tag_count
- length
- reasoning_steps
- cosine
- repetition_penalty
rl_post_train_reward_weights:
- 1.0
- 1.0
- 1.0
- 1.0
- 1.0
- 1.0
- 1.0
# Model configs from trl
model_name_or_path: DeepSeek-R1-Distill-Qwen-1.5B
attn_implementation: flash_attention_2
use_peft: true
lora_r: 32
lora_alpha: 128
lora_dropout: 0.05
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- down_proj
- up_proj
- gate_proj
# GRPO trainer configs from trl
bf16: true
use_vllm: true
vllm_device: cuda:0
vllm_gpu_memory_utilization: 0.4
gradient_accumulation_steps: 4
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
push_to_hub: false
hub_strategy: every_save
hub_private_repo: true
hub_model_id: TODO
learning_rate: 1.0e-06
lr_scheduler_type: cosine_with_min_lr
lr_scheduler_kwargs:
min_lr_rate: 0.1
max_prompt_length: 512
max_completion_length: 3584
num_generations: 8
num_train_epochs: 1
overwrite_output_dir: true
per_device_train_batch_size: 8
report_to:
- wandb
save_strategy: steps
save_steps: 100
save_total_limit: 100
seed: 42
temperature: 0.7
warmup_ratio: 0.1

View File

@@ -1,68 +0,0 @@
# check ./tina/utils/constant.py
model_post_train_dataset_name: curated_open_rs1
model_post_train_type: grpo
rl_post_train_reward_funcs:
- format
- accuracy
rl_post_train_reward_weights:
- 1.0
- 2.0
# Model configs from trl
model_name_or_path: DeepSeek-R1-Distill-Qwen-1.5B
attn_implementation: flash_attention_2
use_peft: true
lora_r: 32
lora_alpha: 128
lora_dropout: 0.05
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- down_proj
- up_proj
- gate_proj
# GRPO trainer configs from trl
bf16: true
use_vllm: true
vllm_device: cuda:0
vllm_gpu_memory_utilization: 0.4
vllm_max_model_len: 4608
gradient_accumulation_steps: 4
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
push_to_hub: false
hub_strategy: every_save
hub_private_repo: true
hub_model_id: TODO
learning_rate: 1.0e-06
lr_scheduler_type: cosine_with_min_lr
lr_scheduler_kwargs:
min_lr_rate: 0.1
max_prompt_length: 512
max_completion_length: 3584
max_steps: 5000 # use 5000 for lr scheduler but stop at 2400 steps
num_generations: 4
num_train_epochs: 1
overwrite_output_dir: true
per_device_train_batch_size: 4
report_to:
- wandb
save_strategy: steps
save_steps: 100
save_total_limit: 100
seed: 42
temperature: 0.7
warmup_ratio: 0.1

View File

@@ -1,68 +0,0 @@
# check ./tina/utils/constant.py
model_post_train_dataset_name: curated_open_rs2
model_post_train_type: grpo
rl_post_train_reward_funcs:
- format
- accuracy
rl_post_train_reward_weights:
- 1.0
- 2.0
# Model configs from trl
model_name_or_path: DeepSeek-R1-Distill-Qwen-1.5B
attn_implementation: flash_attention_2
use_peft: true
lora_r: 32
lora_alpha: 128
lora_dropout: 0.05
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- down_proj
- up_proj
- gate_proj
# GRPO trainer configs from trl
bf16: true
use_vllm: true
vllm_device: cuda:0
vllm_gpu_memory_utilization: 0.4
vllm_max_model_len: 4608
gradient_accumulation_steps: 4
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
push_to_hub: false
hub_strategy: every_save
hub_private_repo: true
hub_model_id: TODO
learning_rate: 1.0e-06
lr_scheduler_type: cosine_with_min_lr
lr_scheduler_kwargs:
min_lr_rate: 0.1
max_prompt_length: 512
max_completion_length: 3584
max_steps: 1500 # use 1500 for lr scheduler but stop at 850 steps
num_generations: 6
num_train_epochs: 1
overwrite_output_dir: true
per_device_train_batch_size: 6
report_to:
- wandb
save_strategy: steps
save_steps: 50
save_total_limit: 100
seed: 42
temperature: 0.7
warmup_ratio: 0.1

View File

@@ -1,68 +0,0 @@
# check ./tina/utils/constant.py
model_post_train_dataset_name: curated_open_rs3
model_post_train_type: grpo
rl_post_train_reward_funcs:
- format
- cosine
rl_post_train_reward_weights:
- 1.0
- 2.0
# Model configs from trl
model_name_or_path: DeepSeek-R1-Distill-Qwen-1.5B
attn_implementation: flash_attention_2
use_peft: true
lora_r: 32
lora_alpha: 128
lora_dropout: 0.05
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- down_proj
- up_proj
- gate_proj
# GRPO trainer configs from trl
bf16: true
use_vllm: true
vllm_device: cuda:0
vllm_gpu_memory_utilization: 0.4
vllm_max_model_len: 4608
gradient_accumulation_steps: 4
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
push_to_hub: false
hub_strategy: every_save
hub_private_repo: true
hub_model_id: TODO
learning_rate: 1.0e-06
lr_scheduler_type: cosine_with_min_lr
lr_scheduler_kwargs:
min_lr_rate: 0.1
max_prompt_length: 512
max_completion_length: 3584
max_steps: 1500 # use 1500 for lr scheduler but stop at 850 steps
num_generations: 6
num_train_epochs: 1
overwrite_output_dir: true
per_device_train_batch_size: 6
report_to:
- wandb
save_strategy: steps
save_steps: 50
save_total_limit: 100
seed: 42
temperature: 0.7
warmup_ratio: 0.1

View File

@@ -1,68 +0,0 @@
# check ./tina/utils/constant.py
model_post_train_dataset_name: curated_open_rs3_drgrpo_ablation
model_post_train_type: grpo
rl_post_train_reward_funcs:
- format
- cosine
rl_post_train_reward_weights:
- 1.0
- 2.0
# Model configs from trl
model_name_or_path: DeepSeek-R1-Distill-Qwen-1.5B
attn_implementation: flash_attention_2
use_peft: true
lora_r: 32
lora_alpha: 128
lora_dropout: 0.05
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- down_proj
- up_proj
- gate_proj
# GRPO trainer configs from trl
scale_rewards: false # use Dr. GRPO's normalization (i.e. do not divide advantages by the group reward std)
bf16: true
use_vllm: true
vllm_device: cuda:0
vllm_gpu_memory_utilization: 0.4
vllm_max_model_len: 4608
gradient_accumulation_steps: 4
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
push_to_hub: false
hub_strategy: every_save
hub_private_repo: true
hub_model_id: TODO
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
learning_rate: 1.0e-06
lr_scheduler_type: cosine_with_min_lr
lr_scheduler_kwargs:
min_lr_rate: 0.1
max_prompt_length: 512
max_completion_length: 3584
max_steps: 1500 # use 1500 for lr scheduler but stop at 850 steps
num_generations: 6
num_train_epochs: 1
overwrite_output_dir: true
per_device_train_batch_size: 6
report_to:
- wandb
save_strategy: steps
save_steps: 50
save_total_limit: 100
seed: 42
temperature: 0.7
warmup_ratio: 0.1
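The only thing `scale_rewards` changes is whether the group-relative advantages are divided by each prompt group's reward standard deviation; setting it to `false` follows Dr. GRPO's recommendation to skip that division. A rough sketch of the distinction, not the trainer's actual code path (the epsilon is illustrative):

```python
import torch

def group_relative_advantages(rewards: torch.Tensor, scale_rewards: bool) -> torch.Tensor:
    """rewards has shape (num_prompts, num_generations): one reward per completion."""
    advantages = rewards - rewards.mean(dim=1, keepdim=True)   # always center per group
    if scale_rewards:                                          # standard GRPO
        advantages = advantages / (rewards.std(dim=1, keepdim=True) + 1e-4)
    return advantages                                          # scale_rewards=false: Dr. GRPO style

rewards = torch.tensor([[1.0, 0.0, 0.0, 1.0, 0.0, 1.0]])       # num_generations = 6, as above
print(group_relative_advantages(rewards, scale_rewards=False))
```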

View File

@@ -1,68 +0,0 @@
# check ./tina/utils/constant.py
model_post_train_dataset_name: curated_rg_math
model_post_train_type: grpo
rl_post_train_reward_funcs:
- format
- accuracy
rl_post_train_reward_weights:
- 1.0
- 2.0
# Model configs from trl
model_name_or_path: DeepSeek-R1-Distill-Qwen-1.5B
attn_implementation: flash_attention_2
use_peft: true
lora_r: 32
lora_alpha: 128
lora_dropout: 0.05
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- down_proj
- up_proj
- gate_proj
# GRPO trainer configs from trl
bf16: true
use_vllm: true
vllm_device: cuda:0
vllm_gpu_memory_utilization: 0.4
vllm_max_model_len: 4608
gradient_accumulation_steps: 4
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
push_to_hub: true
hub_strategy: every_save
hub_private_repo: true
hub_model_id: starzmustdie
learning_rate: 1.0e-06
lr_scheduler_type: cosine_with_min_lr
lr_scheduler_kwargs:
min_lr_rate: 0.1
max_prompt_length: 512
max_completion_length: 3584
max_steps: 2500 # use 2500 for lr scheduler but stop at 1250 steps
num_generations: 4
num_train_epochs: 1
overwrite_output_dir: true
per_device_train_batch_size: 4
report_to:
- wandb
save_strategy: steps
save_steps: 100
save_total_limit: 100
seed: 42
temperature: 0.7
warmup_ratio: 0.1

View File

@@ -1,68 +0,0 @@
# check ./tina/utils/constant.py
model_post_train_dataset_name: curated_still
model_post_train_type: grpo
rl_post_train_reward_funcs:
- length
- accuracy
rl_post_train_reward_weights:
- 1.0
- 2.0
# Model configs from trl
model_name_or_path: DeepSeek-R1-Distill-Qwen-1.5B
attn_implementation: flash_attention_2
use_peft: true
lora_r: 32
lora_alpha: 128
lora_dropout: 0.05
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- down_proj
- up_proj
- gate_proj
# GRPO trainer configs from trl
bf16: true
use_vllm: true
vllm_device: cuda:0
vllm_gpu_memory_utilization: 0.4
vllm_max_model_len: 4608
gradient_accumulation_steps: 4
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
push_to_hub: false
hub_strategy: every_save
hub_private_repo: true
hub_model_id: TODO
learning_rate: 1.0e-06
lr_scheduler_type: cosine_with_min_lr
lr_scheduler_kwargs:
min_lr_rate: 0.1
max_prompt_length: 512
max_completion_length: 3584
max_steps: 7500 # use 7500 for lr scheduler but stop at 3750 steps
num_generations: 4
num_train_epochs: 1
overwrite_output_dir: true
per_device_train_batch_size: 4
report_to:
- wandb
save_strategy: steps
save_steps: 100
save_total_limit: 100
seed: 42
temperature: 0.7
warmup_ratio: 0.1

View File

@@ -1,76 +0,0 @@
# check ./tina/utils/constant.py
model_post_train_dataset_name: curated_thoughts
model_post_train_dataset_config: OpenThoughts-114k-math-default
model_post_train_type: grpo
rl_post_train_reward_funcs:
- accuracy
- format
- tag_count
- length
- reasoning_steps
- cosine
- repetition_penalty
rl_post_train_reward_weights:
- 1.0
- 1.0
- 1.0
- 1.0
- 1.0
- 1.0
- 1.0
# Model configs from trl
model_name_or_path: DeepSeek-R1-Distill-Qwen-1.5B
attn_implementation: flash_attention_2
use_peft: true
lora_r: 32
lora_alpha: 128
lora_dropout: 0.05
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- down_proj
- up_proj
- gate_proj
# GRPO trainer configs from trl
bf16: true
use_vllm: true
vllm_device: cuda:0
vllm_gpu_memory_utilization: 0.6
gradient_accumulation_steps: 4
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
push_to_hub: false
hub_strategy: every_save
hub_private_repo: true
hub_model_id: TODO
learning_rate: 1.0e-06
lr_scheduler_type: cosine_with_min_lr
lr_scheduler_kwargs:
min_lr_rate: 0.1
max_prompt_length: 512
max_completion_length: 3584
num_generations: 8
num_train_epochs: 1
overwrite_output_dir: true
per_device_train_batch_size: 8
report_to:
- wandb
save_strategy: steps
save_steps: 100
save_total_limit: 100
seed: 42
temperature: 0.7
warmup_ratio: 0.1

View File

@@ -1,68 +0,0 @@
# check ./tina/utils/constant.py
model_post_train_dataset_name: curated_rg_math
model_post_train_type: grpo
rl_post_train_reward_funcs:
- tag_count_reward
- length
- accuracy
rl_post_train_reward_weights:
- 1.0
- 1.0
- 1.0
# Model configs from trl
model_name_or_path: Qwen2.5-3B-Instruct
attn_implementation: flash_attention_2
use_peft: false # full fine-tune
lora_r: 32
lora_alpha: 128
lora_dropout: 0.05
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- down_proj
- up_proj
- gate_proj
# GRPO trainer configs from trl
bf16: true
use_vllm: true
vllm_device: cuda:2
vllm_gpu_memory_utilization: 0.9
gradient_accumulation_steps: 1
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
push_to_hub: true
hub_strategy: every_save
hub_private_repo: true
hub_model_id: starzmustdie
learning_rate: 1.0e-06
lr_scheduler_type: constant_with_warmup
lr_scheduler_kwargs:
num_warmup_steps: 60
max_prompt_length: 512
max_completion_length: 2048
max_steps: 600
num_generations: 8
num_train_epochs: 1
overwrite_output_dir: true
per_device_train_batch_size: 8
report_to:
- wandb
save_strategy: steps
save_steps: 200
save_total_limit: 100
seed: 42
temperature: 0.6

View File

@@ -1,20 +0,0 @@
#!/bin/bash
MAMBA_ENV="tina"
eval "$(mamba shell hook --shell bash)" && mamba activate "${MAMBA_ENV}"
echo "START TIME: $(date)"
echo "PYTHON ENV: $(which python)"
source "./scripts/set/set_vars.sh"
PY_SCRIPT="./scripts/set/run_download_model.py"
echo ""
echo "Running script: ${PY_SCRIPT}"
echo ""
python "${PY_SCRIPT}"
echo "END TIME: $(date)"
echo "DONE"

View File

@@ -1,9 +0,0 @@
import os
from huggingface_hub import snapshot_download
if __name__ == "__main__":
CKPT_DIR = os.environ["CKPT_DIR"]
print("Downloading model ...")
snapshot_download(repo_id="Qwen/Qwen2.5-3B-Instruct", local_dir=f"{CKPT_DIR}/models/Qwen2.5-3B-Instruct/base")

View File

@@ -1,46 +0,0 @@
#!/bin/bash
# python 3.11 & cuda 11.8
export MKL_NUM_THREADS=1
export NUMEXPR_NUM_THREADS=1
export OPENBLAS_NUM_THREADS=1
export OMP_NUM_THREADS=1
conda clean -a -y
mamba clean -a -y
pip install --upgrade pip
pip cache purge
# mamba install cuda -c nvidia/label/cuda-11.8.0 -y
# mamba install gcc gxx -c conda-forge -y
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install xformers --index-url https://download.pytorch.org/whl/cu118
pip install vllm
pip install flash-attn --no-build-isolation
pip install accelerate
pip install datasets
pip install deepspeed
pip install distilabel[vllm,ray,openai]
pip install e2b-code-interpreter
pip install einops
pip install flake8
pip install huggingface_hub
pip install hf_transfer
pip install isort
pip install langdetect
pip install latex2sympy2_extended
pip install liger_kernel
pip install "math_verify==0.5.2"
pip install packaging
pip install parameterized
pip install peft
pip install pytest
pip install python-dotenv
pip install ruff
pip install safetensors
pip install sentencepiece
pip install transformers
pip install trl@git+https://github.com/huggingface/trl.git
pip install wandb

View File

@@ -1,53 +0,0 @@
#!/bin/bash
export CUDA_LAUNCH_BLOCKING=1
export DS_LOG_LEVEL=error
export TOKENIZERS_PARALLELISM=false
export NCCL_P2P_DISABLE=1
export NCCL_SHM_DISABLE=1
export NCCL_IB_DISABLE=1
export MKL_THREADING_LAYER=GNU
export MKL_NUM_THREADS=1
export NUMEXPR_NUM_THREADS=1
export OPENBLAS_NUM_THREADS=1
export OMP_NUM_THREADS=1
## basic setup for the env
export PROJECT_PREFIX="/root/projects" # e.g. /home/username/projects
export SCRATCH_PREFIX="/root/scratch" # e.g. /home/username/scratch
mkdir -p "${PROJECT_PREFIX}" "${SCRATCH_PREFIX}"
export PROJECT_NAME="rg-math"
export CORE_POSTFIX="tina"
export PROJECT_DIR="${PROJECT_PREFIX}/${PROJECT_NAME}"
export PYTHONPATH="${PROJECT_DIR}":$PYTHONPATH
export PYTHONPATH="${PROJECT_DIR}/${CORE_POSTFIX}":$PYTHONPATH
mkdir -p "${PROJECT_PREFIX}/${PROJECT_NAME}"
export CKPT_DIR="${PROJECT_DIR}/ckpts"
export DATA_DIR="${PROJECT_DIR}/datasets"
export OUTPUT_DIR="${PROJECT_DIR}/outputs"
export LOGGING_DIR="${PROJECT_DIR}/logs"
mkdir -p "${CKPT_DIR}" "${DATA_DIR}" "${OUTPUT_DIR}" "${LOGGING_DIR}"
## wandb setup
# export WANDB_API_KEY="TODO"
export WANDB_PROJECT="${PROJECT_NAME}"
export WANDB_DIR="${OUTPUT_DIR}"
wandb login $WANDB_API_KEY
export CACHE_DIR="${PROJECT_DIR}/.cache"
export WANDB_CACHE_DIR="${CACHE_DIR}"
export TRITON_CACHE_DIR="${CACHE_DIR}/triton_cache"
## huggingface setup
# export HF_TOKEN="TODO"
git config --global credential.helper store
huggingface-cli login --token $HF_TOKEN --add-to-git-credential
export HF_HOME="${CACHE_DIR}/huggingface"
export HUGGINGFACE_HUB_CACHE="${HF_HOME}/hub"
export HF_DATASETS_CACHE="${HF_HOME}/datasets"

View File

@@ -1,54 +0,0 @@
#!/bin/bash
#SBATCH --job-name=grpo_multinode
#SBATCH -D .
#SBATCH --partition=TODO
#SBATCH --account=TODO
#SBATCH --output=output-%x.%j
#SBATCH --error=error-%x.%j
#SBATCH --nodes=2 # number of nodes
#SBATCH --ntasks-per-node=1 # number of MP tasks
#SBATCH --gres=gpu:2 # number of GPUs per node
#SBATCH --cpus-per-task=8 # number of cores per tasks
#SBATCH --mem=128G
#SBATCH --time=48:00:00 # maximum execution time (HH:MM:SS)
#SBATCH --comment "Key=Monitoring,Value=ON"
#SBATCH --exclusive
######################
### Set environment ##
######################
ulimit -s unlimited
MAMBA_ENV="tina"
eval "$(mamba shell hook --shell bash)" && mamba activate "${MAMBA_ENV}"
echo "START TIME: $(date)"
echo "PYTHON ENV: $(which python)"
source "./scripts/set/set_vars.sh"
export GPUS_PER_NODE=2
######################
######################
#### Set network #####
######################
head_node_ip=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
######################
export LAUNCHER="accelerate launch \
--num_processes $((SLURM_NNODES * GPUS_PER_NODE)) \
--num_machines $SLURM_NNODES \
--machine_rank $SLURM_NODEID \
--rdzv_backend c10d \
--main_process_ip $head_node_ip \
--main_process_port 29500 \
"
PY_SCRIPT="./tina/post_train_hf/grpo.py"
PY_CONFIG="./recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/model_curated_deepscaler.yaml"
# This step is necessary because accelerate launch does not handle multiline arguments properly
export CMD="$LAUNCHER $PY_SCRIPT --config $PY_CONFIG"
srun $CMD

View File

@@ -1,37 +0,0 @@
#!/bin/bash
# Usage: bash sbatch_launch.sh <script.slurm>
cleanup() {
echo "Script interrupted. Cleaning up..."
scancel "$job_id" 2>/dev/null
echo "Job $job_id has been canceled."
exit 1
}
trap cleanup SIGINT
# launch the slurm script
SLURM_FILE=$1
echo "Launching $SLURM_FILE ..."
job_id=$(sbatch $SLURM_FILE | awk '{print $4}')
echo "Submitted job with ID: $job_id"
# Wait until the job is running
while true; do
job_status=$(squeue -j "$job_id" -h -o "%T")
if [ "$job_status" == "RUNNING" ]; then
echo "Job $job_id is now running."
sleep 5
break
elif [ -z "$job_status" ]; then
echo "Job $job_id has finished or failed before reaching running state."
exit 1
else
echo "Job $job_id is still in $job_status state. Checking again in 10 seconds..."
sleep 10
fi
done
# Plot the real-time output
output_file=$(scontrol show job "$job_id" | awk -F= '/StdOut/ {print $2}' | sed "s/%A/${job_id}/g" | sed "s/%a/1/g")
echo "Tailing output file: $output_file"
tail -f "$output_file"

View File

@@ -1,32 +0,0 @@
#!/bin/bash
MAMBA_ENV="tina_eval"
eval "$(mamba shell hook --shell bash)" && mamba activate "${MAMBA_ENV}"
echo "START TIME: $(date)"
echo "PYTHON ENV: $(which python)"
source "./scripts/set/set_vars.sh"
export CUDA_VISIBLE_DEVICES="0" # NOTE: update this if you have more than 1 GPU
GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
echo ""
echo "GPU_COUNT: $GPU_COUNT"
echo ""
# MODEL_LIST=("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" "starzmustdie/DeepSeek-R1-Distill-Qwen-1.5B-rg-math" "agentica-org/DeepScaleR-1.5B-Preview" "knoveleng/Open-RS3" "RUC-AIBOX/STILL-3-1.5B-preview")
MODEL_LIST=("Qwen/Qwen2.5-3B-Instruct" "starzmustdie/Qwen2.5-3B-Instruct")
TASKS=("aime24" "aime25" "amc23" "minerva" "math_500")
for MODEL_NAME in "${MODEL_LIST[@]}"; do
for TASK in "${TASKS[@]}"; do
MODEL_ARGS="model_name=$MODEL_NAME,dtype=bfloat16,data_parallel_size=$GPU_COUNT,max_model_length=32768,gpu_memory_utilization=0.7,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
lighteval vllm $MODEL_ARGS "custom|$TASK|0|0" \
--custom-tasks ./scripts/training/run_post_train_eval.py \
--use-chat-template \
--output-dir "${OUTPUT_DIR}/${TASK}"
done
done
echo "END TIME: $(date)"
echo "DONE"

View File

@@ -1,73 +0,0 @@
#!/bin/bash
MAMBA_ENV="tina_eval"
eval "$(mamba shell hook --shell bash)" && mamba activate "${MAMBA_ENV}"
echo "START TIME: $(date)"
echo "PYTHON ENV: $(which python)"
source "./scripts/set/set_vars.sh"
export CUDA_VISIBLE_DEVICES=0,1 # make sure all evaluations run on 2 GPUs
GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
echo ""
echo "GPU_COUNT: $GPU_COUNT, make sure using 2 GPUs."
echo ""
MODEL_NAME="DeepSeek-R1-Distill-Qwen-1.5B"
PT_TYPE="grpo"
## Main datasets
DATASET_NAME="curated_deepscaler"
#DATASET_NAME="curated_still"
#DATASET_NAME="curated_open_rs3"
#DATASET_NAME="curated_open_rs2"
#DATASET_NAME="curated_open_rs1"
## Extra datasets
#DATASET_NAME="curated_limr"
#DATASET_NAME="curated_open_r1"
#DATASET_NAME="curated_thoughts"
## Ablation
#DATASET_NAME="curated_limr_large_lr_ablation"
#DATASET_NAME="curated_limr_small_lr_ablation"
#DATASET_NAME="curated_limr_large_rank_ablation"
#DATASET_NAME="curated_limr_medium_rank_ablation"
#DATASET_NAME="curated_limr_small_rank_ablation"
#DATASET_NAME="curated_limr_tiny_rank_ablation"
#DATASET_NAME="curated_open_rs3_drgrpo_ablation"
CKPT_LIST=($(ls "${CKPT_DIR}/models/${MODEL_NAME}/${PT_TYPE}_${DATASET_NAME}" | grep -E "^checkpoint-[0-9]+$")) # collect checkpoints into a bash array
#CKPT_LIST=("checkpoint-XXX")
# loop over all the checkpoints in the list
for CKPT in "${CKPT_LIST[@]}"; do
echo "Running model post train merging base and adapter for checkpoint: ${CKPT}"
python ./scripts/training/run_post_train_merge.py \
--model_name "${MODEL_NAME}" \
--adapter_type "${PT_TYPE}_${DATASET_NAME}" \
--ckpt "${CKPT}"
MODEL_PATH="${CKPT_DIR}/models/${MODEL_NAME}/${PT_TYPE}_${DATASET_NAME}/${CKPT}-merged"
# Set model arguments (ensure that MODEL_PATH, GPU_COUNT, OUTPUT_DIR, and MODEL_NAME are defined)
MODEL_ARGS="pretrained=$MODEL_PATH,dtype=bfloat16,data_parallel_size=$GPU_COUNT,max_model_length=32768,gpu_memory_utilization=0.5,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
# Define an array of tasks to evaluate
tasks=("aime24" "math_500" "gpqa:diamond" "aime25" "amc23" "minerva")
# Loop over each task and evaluate
for TASK in "${tasks[@]}"; do
echo "Evaluating task: $TASK"
lighteval vllm $MODEL_ARGS "custom|$TASK|0|0" \
--custom-tasks ./scripts/training/run_post_train_eval.py \
--use-chat-template \
--output-dir "${OUTPUT_DIR}/${MODEL}/${TASK}"
done
done
echo "END TIME: $(date)"
echo "DONE"

View File

@@ -1,68 +0,0 @@
#!/bin/bash
MAMBA_ENV="tina"
eval "$(mamba shell hook --shell bash)" && mamba activate "${MAMBA_ENV}"
echo "START TIME: $(date)"
echo "PYTHON ENV: $(which python)"
source "./scripts/set/set_vars.sh"
export CUDA_VISIBLE_DEVICES=0,1,2 # Set the GPUs you want to use
GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
NUM_PROCESSES_TRAINING=$((GPU_COUNT - 1))
echo ""
echo "Number of GPUs: ${GPU_COUNT}"
echo "Number of processes for training: ${NUM_PROCESSES_TRAINING}"
echo ""
MODEL_NAME="Qwen2.5-3B-Instruct"
## Main datasets
#DATASET_NAME="curated_deepscaler"
#DATASET_NAME="curated_still"
#DATASET_NAME="curated_open_rs3"
#DATASET_NAME="curated_open_rs2"
#DATASET_NAME="curated_open_rs1"
## Extra datasets
#DATASET_NAME="curated_limr"
#DATASET_NAME="curated_open_r1"
#DATASET_NAME="curated_thoughts"
## Ablation
#DATASET_NAME="curated_limr_large_lr_ablation"
#DATASET_NAME="curated_limr_small_lr_ablation"
#DATASET_NAME="curated_limr_large_rank_ablation"
#DATASET_NAME="curated_limr_medium_rank_ablation"
#DATASET_NAME="curated_limr_small_rank_ablation"
#DATASET_NAME="curated_limr_tiny_rank_ablation"
#DATASET_NAME="curated_open_rs3_drgrpo_ablation"
## Reasoning Gym
DATASET_NAME="curated_rg_math"
PY_SCRIPT="./tina/post_train_hf/grpo.py"
PY_CONFIG="./recipes/${MODEL_NAME}/grpo/model_${DATASET_NAME}.yaml"
ACCELERATE_DS_CONFIG="./recipes/accelerate_ds_cfgs/ds_zero2.yaml"
echo ""
echo "Running ${PY_SCRIPT} on model ${MODEL_NAME} with dataset ${DATASET_NAME}"
echo ""
if [[ "${DATASET_NAME}" == "curated_thoughts" || "${DATASET_NAME}" == "curated_open_r1" || "${DATASET_NAME}" == "curated_open_rs3" || "${DATASET_NAME}" == "curated_open_rs3_drgrpo_ablation" ]]; then
ACCELERATE_LOG_LEVEL=info accelerate launch \
--config_file "${ACCELERATE_DS_CONFIG}" \
--main_process_port=29500 \
--num_processes="${NUM_PROCESSES_TRAINING}" "${PY_SCRIPT}" --config "${PY_CONFIG}" --cosine_max_len 3584
else
ACCELERATE_LOG_LEVEL=info accelerate launch \
--config_file "${ACCELERATE_DS_CONFIG}" \
--main_process_port=29500 \
--num_processes="${NUM_PROCESSES_TRAINING}" "${PY_SCRIPT}" --config "${PY_CONFIG}"
fi
echo "END TIME: $(date)"
echo "DONE"

View File

@@ -1,229 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom evaluation tasks for LightEval."""
import numpy as np
from lighteval.metrics.dynamic_metrics import (
ExprExtractionConfig,
LatexExtractionConfig,
compare_gold_target,
extract_target_from_pred,
get_extraction_regexes,
)
from lighteval.metrics.metrics import Metrics
from lighteval.metrics.metrics_sample import PassAtK
from lighteval.metrics.utils.metric_utils import MetricCategory, MetricUseCase, SampleLevelMetric
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc
from lighteval.utils.language import Language
# Prompt template adapted from
# - simple-evals: https://github.com/openai/simple-evals/blob/6e84f4e2aed6b60f6a0c7b8f06bbbf4bfde72e58/math_eval.py#L17
# - Llama 3: https://huggingface.co/datasets/meta-llama/Llama-3.2-1B-Instruct-evals/viewer/Llama-3.2-1B-Instruct-evals__math__details?views%5B%5D=llama_32_1b_instruct_evals__math__details
# Note that it is important to have the final answer in a box for math-verify to work correctly
MATH_QUERY_TEMPLATE = """
Solve the following math problem efficiently and clearly. The last line of your response should be of the following format: 'Therefore, the final answer is: $\\boxed{{ANSWER}}$. I hope it is correct' (without quotes) where ANSWER is just the final number or expression that solves the problem. Think step by step before answering.
{Question}
""".strip()
math_pass_at_1_2n = SampleLevelMetric(
metric_name="math_pass@1:2_samples",
sample_level_fn=PassAtK(
k=1,
n=2,
strip_strings=True,
# Extracting mathematical expressions and latex expressions
normalize_gold=lambda k: extract_target_from_pred(
k,
get_extraction_regexes(
formatted_doc=None,
target_types=[ExprExtractionConfig(), LatexExtractionConfig()],
language=Language.ENGLISH,
),
),
# Extracting mathematical expressions and latex expressions
normalize_pred=lambda k: extract_target_from_pred(
k,
get_extraction_regexes(
formatted_doc=None,
target_types=[ExprExtractionConfig(), LatexExtractionConfig()],
language=Language.ENGLISH,
),
),
# Uses sympy for comparison
sample_scoring_function=compare_gold_target,
).compute,
category=MetricCategory.GENERATIVE_SAMPLING,
use_case=MetricUseCase.REASONING,
corpus_level_fn=np.mean,
higher_is_better=True,
)
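# Note on the metric defined above: PassAtK(k=1, n=2) draws two samples per problem
# and reports the usual unbiased pass@k estimate, pass@k = 1 - C(n - c, k) / C(n, k),
# where c is the number of correct samples out of n. For k=1 and n=2 this reduces to
# c / 2, i.e. the fraction of the two samples that verify as correct, averaged over
# the benchmark (assuming lighteval's PassAtK follows this standard estimator).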
def math_prompt_fn(line, task_name: str = None):
return Doc(
task_name=task_name,
query=MATH_QUERY_TEMPLATE.format(Question=line["problem"]),
choices=[line["solution"]],
gold_index=0,
)
def aime_prompt_fn(line, task_name: str = None):
return Doc(
task_name=task_name,
query=MATH_QUERY_TEMPLATE.format(Question=line["problem"]),
choices=[line["answer"]],
gold_index=0,
)
def amc_prompt_fn(line, task_name: str = None):
return Doc(
task_name=task_name,
query=MATH_QUERY_TEMPLATE.format(Question=line["problem"]),
choices=[line["answer"]],
gold_index=0,
)
def minerva_prompt_fn(line, task_name: str = None):
return Doc(
task_name=task_name,
query=MATH_QUERY_TEMPLATE.format(Question=line["problem"]),
choices=[line["solution"]],
gold_index=0,
)
# Define tasks
aime24 = LightevalTaskConfig(
name="aime24",
suite=["custom"],
prompt_function=aime_prompt_fn,
hf_repo="HuggingFaceH4/aime_2024",
hf_subset="default",
hf_avail_splits=["train"],
evaluation_splits=["train"],
few_shots_split=None,
few_shots_select=None,
generation_size=32768,
metric=[
# Metrics.math_pass_at_1_1n,
# math_pass_at_1_2n,
# Metrics.math_pass_at_1_4n,
# Metrics.math_pass_at_1_16n,
Metrics.math_pass_at_1_32n,
# Metrics.math_pass_at_1_64n,
],
version=1,
)
aime25 = LightevalTaskConfig(
name="aime25",
suite=["custom"],
prompt_function=aime_prompt_fn,
hf_repo="yentinglin/aime_2025",
hf_subset="default",
hf_avail_splits=["train"],
evaluation_splits=["train"],
few_shots_split=None,
few_shots_select=None,
generation_size=32768,
metric=[
# Metrics.math_pass_at_1_1n,
# math_pass_at_1_2n,
# Metrics.math_pass_at_1_4n,
# Metrics.math_pass_at_1_16n,
Metrics.math_pass_at_1_32n,
# Metrics.math_pass_at_1_64n,
],
version=1,
)
amc23 = LightevalTaskConfig(
name="amc23",
suite=["custom"],
prompt_function=amc_prompt_fn,
hf_repo="knoveleng/AMC-23",
hf_subset="default",
hf_avail_splits=["train"],
evaluation_splits=["train"],
few_shots_split=None,
few_shots_select=None,
generation_size=32768,
metric=[
# Metrics.math_pass_at_1_1n,
# math_pass_at_1_2n,
# Metrics.math_pass_at_1_4n,
# Metrics.math_pass_at_1_16n,
Metrics.math_pass_at_1_32n,
# Metrics.math_pass_at_1_64n,
],
version=1,
)
math_500 = LightevalTaskConfig(
name="math_500",
suite=["custom"],
prompt_function=math_prompt_fn,
hf_repo="HuggingFaceH4/MATH-500",
hf_subset="default",
hf_avail_splits=["test"],
evaluation_splits=["test"],
few_shots_split=None,
few_shots_select=None,
generation_size=32768,
metric=[
# Metrics.math_pass_at_1_1n,
math_pass_at_1_2n,
],
version=1,
)
minerva = LightevalTaskConfig(
name="minerva",
suite=["custom"],
prompt_function=minerva_prompt_fn,
hf_repo="knoveleng/Minerva-Math",
hf_subset="default",
hf_avail_splits=["train"],
evaluation_splits=["train"],
few_shots_split=None,
few_shots_select=None,
generation_size=32768,
metric=[
# Metrics.math_pass_at_1_1n,
# math_pass_at_1_2n,
Metrics.math_pass_at_1_4n,
],
version=1,
)
# Add tasks to the table
TASKS_TABLE = []
TASKS_TABLE.append(aime24)
TASKS_TABLE.append(aime25)
TASKS_TABLE.append(amc23)
TASKS_TABLE.append(math_500)
TASKS_TABLE.append(minerva)
# MODULE LOGIC
if __name__ == "__main__":
print([t["name"] for t in TASKS_TABLE])
print(len(TASKS_TABLE))

View File

@@ -1,40 +0,0 @@
import argparse
import os
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
def argparser():
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, default="DeepSeek-R1-Distill-Qwen-1.5B")
parser.add_argument("--adapter_type", type=str, default="grpo_curated_open_r1")
parser.add_argument("--ckpt", type=str, default="checkpoint-500")
return parser.parse_args()
if __name__ == "__main__":
args = argparser()
ckpt_dir = os.environ["CKPT_DIR"]
ckpt = args.ckpt
adapter_type = args.adapter_type
model_name = args.model_name
base_model_name_or_path = f"{ckpt_dir}/models/{model_name}/base"
adapter_model_name_or_path = f"{ckpt_dir}/models/{model_name}/{adapter_type}/{ckpt}"
merged_model_name_or_path = f"{ckpt_dir}/models/{model_name}/{adapter_type}/{ckpt}-merged"
print("Merged model will be saved to: ", merged_model_name_or_path)
base_model = AutoModelForCausalLM.from_pretrained(
base_model_name_or_path, torch_dtype=torch.bfloat16, device_map="auto"
) # Automatically distributes across available GPUs
model = PeftModel.from_pretrained(base_model, adapter_model_name_or_path)
model = model.merge_and_unload()
model.save_pretrained(merged_model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path)
tokenizer.save_pretrained(merged_model_name_or_path)

View File

@@ -1,21 +0,0 @@
from dataclasses import dataclass, field
from typing import Literal
# check ./recipes/MODEL_NAME/PT_METHOD/model_DATASET.yaml
@dataclass
class ModelPTConfig:
# //*******Model post-training configs*******//
model_post_train_type: Literal["grpo", "sft"] = field(default="grpo")
model_post_train_dataset_name: str = field(default="curated_deepscaler")
model_post_train_dataset_config: str | None = field(default=None)
rl_post_train_reward_funcs: list[str] = field(default_factory=lambda: ["accuracy", "format"])
rl_post_train_reward_weights: list[float] = field(default_factory=lambda: [2.0, 1.0])
cosine_min_value_wrong: float = field(default=0.0)
cosine_max_value_wrong: float = field(default=-0.5)
cosine_min_value_correct: float = field(default=0.5)
cosine_max_value_correct: float = field(default=1.0)
cosine_max_len: int = field(default=1000)
repetition_n_grams: int = field(default=3)
repetition_max_penalty: float = field(default=-1.0)

View File

@@ -1,167 +0,0 @@
import copy
import logging
import shutil
import numpy as np
import pandas as pd
import torch
import wandb
from tina.post_train_hf.hub import push_to_hub_revision
from tina.utils.prompt import FIXED_PROMPT_FOR_EVALUATION, OPEN_R1_SYSTEM_PROMPT
from transformers import TrainerCallback
logger = logging.getLogger(__name__)
class FixedPromptEvaluationCallback(TrainerCallback):
def __init__(
self,
system_prompt=OPEN_R1_SYSTEM_PROMPT,
prompt=FIXED_PROMPT_FOR_EVALUATION,
max_generation_length=4096,
eval_steps=100,
):
self.system_prompt = system_prompt
self.prompt = prompt
self.max_generation_length = max_generation_length
self.eval_steps = eval_steps
self.completion_table = {
"step": [],
"prompt": [],
"completion": [],
}
def on_init_end(self, args, state, control, processing_class=None, **kwargs):
tokenizer = processing_class
messages = [{"role": "system", "content": self.system_prompt}, {"role": "user", "content": self.prompt}]
input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
self.tokenized_prompt = tokenizer(input_text, return_tensors="pt")
def on_step_end(self, args, state, control, model=None, processing_class=None, **kwargs):
if state.global_step % self.eval_steps == 0:
if state.is_world_process_zero:
completion = self.eval_prompt(model, processing_class)
self.completion_table["step"].append(str(state.global_step))
self.completion_table["prompt"].append(self.prompt)
self.completion_table["completion"].append(completion)
df = pd.DataFrame(self.completion_table)
wandb.log({"completions": wandb.Table(dataframe=df)})
def eval_prompt(self, model, tokenizer):
if hasattr(model, "peft_config"):
model.peft_config["default"].inference_mode = True
self.tokenized_prompt.to(model.device)
outputs = model.generate(
**self.tokenized_prompt,
max_length=self.max_generation_length,
temperature=0.01, # Very low temperature
top_k=1, # Only consider the most likely token
top_p=1.0, # Disable nucleus sampling or set to a high value
)
completion = tokenizer.decode(outputs[0], skip_special_tokens=True)
if hasattr(model, "peft_config"):
model.peft_config["default"].inference_mode = False
return completion
class GradientClippingLoggerCallback(TrainerCallback):
def on_step_end(self, args, state, control, model=None, processing_class=None, **kwargs):
self.clipped_grad_norm = np.sqrt(
sum(p.grad.data.norm(2).item() ** 2 for p in model.parameters() if p.grad is not None)
)
if state.is_world_process_zero:
wandb.log({"clipped_grad_norm": self.clipped_grad_norm})
def on_log(self, args, state, control, logs=None, **kwargs):
if logs is not None:
logs["clipped_grad_norm"] = self.clipped_grad_norm
class DummyConfig:
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
class PushToHubRevisionCallback(TrainerCallback):
def __init__(self, dataset_name, use_peft):
self.dataset_name = dataset_name
self.use_peft = use_peft
self.pending_futures = [] # Track pending push operations
def on_save(self, args, state, control, **kwargs):
if state.is_world_process_zero:
global_step = state.global_step
# Create merged model directory
if self.use_peft:
ckpt_model_dir = f"{args.output_dir}/checkpoint-{global_step}-merged"
original_model = kwargs["model"] # Don't pop it, keep it intact
model_to_save = copy.deepcopy(original_model).merge_and_unload()
model_to_save.save_pretrained(ckpt_model_dir)
else:
# this dir is already created by the HF Trainer, no need to manually save
ckpt_model_dir = f"{args.output_dir}/checkpoint-{global_step}"
tokenizer = kwargs.get("tokenizer") or kwargs.get("processing_class")
if tokenizer is None:
raise ValueError("Tokenizer or processing_class must be provided.")
tokenizer.save_pretrained(ckpt_model_dir)
dummy_config = DummyConfig(
hub_model_id=args.hub_model_id,
hub_model_revision=self.dataset_name,
checkpoint=f"checkpoint-{global_step}",
output_dir=ckpt_model_dir,
dataset_name=self.dataset_name,
)
# Start the push operation
future = push_to_hub_revision(dummy_config, extra_ignore_patterns=["*.pt"])
# Store the future and directory path for cleanup later
self.pending_futures.append((future, ckpt_model_dir))
# Check and clean up any completed pushes
if self.use_peft:
self._cleanup_completed_pushes()
return control
def _cleanup_completed_pushes(self):
"""Check pending futures and remove directories for completed pushes."""
still_pending = []
for future, dir_path in self.pending_futures:
if future.done():
if self.use_peft:
# The push is complete, safe to delete the directory
try:
shutil.rmtree(dir_path)
logger.info(f"\nCleaned up merged model directory: {dir_path}\n")
except Exception as e:
logger.error(f"\nFailed to clean up directory {dir_path}: {e}\n")
else:
# Push is still in progress, keep in pending list
still_pending.append((future, dir_path))
self.pending_futures = still_pending
def on_train_end(self, args, state, control, **kwargs):
"""Make sure to clean up any remaining directories at the end of training."""
if state.is_world_process_zero and self.use_peft:
# Wait for all pending pushes to complete
logger.info(f"\nCleaned up for lora models.")
for future, dir_path in self.pending_futures:
future.result() # Wait for completion
try:
shutil.rmtree(dir_path)
logger.info(f"\nCleaned up merged model directory: {dir_path}\n")
except Exception as e:
logger.error(f"\nFailed to clean up directory {dir_path}: {e}\n")
self.pending_futures = []

View File

@@ -1,245 +0,0 @@
import logging
import os
import sys
from datetime import datetime
import datasets
import torch
import transformers
from datasets import Dataset, load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from tina.config import ModelPTConfig
from tina.post_train_hf.callback import (
FixedPromptEvaluationCallback,
GradientClippingLoggerCallback,
PushToHubRevisionCallback,
)
from tina.post_train_hf.grpo_config import GRPOConfig # use this new one for Dr.GRPO
from tina.post_train_hf.grpo_trainer import GRPOTrainer # use this new one for Dr.GRPO
from tina.post_train_hf.preprocess import make_conv_for_grpo
from tina.post_train_hf.rewards import (
accuracy_reward,
format_reward,
get_cosine_scaled_reward,
get_repetition_penalty_reward,
len_reward,
reasoning_steps_reward,
tag_count_reward,
)
from tina.utils.chat_template import DEFAULT_CHAT_TEMPLATE, REASON_CHAT_TEMPLATE
from tina.utils.constant import RL_POST_TRAIN_DATASET_MAP
from tina.utils.prompt import OPEN_R1_SYSTEM_PROMPT, OPEN_RS_SYSTEM_PROMPT
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from transformers.trainer_utils import get_last_checkpoint
from trl import ModelConfig, TrlParser # GRPOTrainer, GRPOConfig
def main():
parser = TrlParser((ModelPTConfig, GRPOConfig, ModelConfig))
pt_args, training_args, model_args = parser.parse_args_and_config()
set_seed(training_args.seed)
os.environ["WANDB_PROJECT"] = "tina_model_post_training"
################
# Set up logging
################
logger = logging.getLogger(__name__)
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
# Log on each process a small summary
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Model parameters {model_args}")
logger.info(f"Post training parameters {pt_args}")
logger.info(f"Training parameters {training_args}")
#####################
# Set up output paths
#####################
current_time = datetime.now()
formatted_datetime = current_time.strftime("%Y_%m_%d_%H_%M_%S")
model_name_or_path = model_args.model_name_or_path
ckpt_dir = os.environ["CKPT_DIR"]
ckpt_prefix = f"{ckpt_dir}/models/{model_name_or_path}"
if model_args.use_peft:
ckpt_postfix = f"{pt_args.model_post_train_type}_{pt_args.model_post_train_dataset_name}"
else:
ckpt_postfix = f"full_{pt_args.model_post_train_type}_{pt_args.model_post_train_dataset_name}"
model_args.model_name_or_path = f"{ckpt_prefix}/base"
training_args.output_dir = f"{ckpt_prefix}/{ckpt_postfix}"
# training_args.hub_model_id = f"{training_args.hub_model_id}_{ckpt_postfix}"
training_args.run_name = f"{model_name_or_path}_{ckpt_postfix}_{formatted_datetime}"
training_args.hub_model_id = f"{training_args.hub_model_id}/{model_name_or_path}"
#######################################################################
# Load and preprocess dataset (tokenization is handled by GRPO Trainer)
#######################################################################
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
if "Llama" in model_args.model_name_or_path:
tokenizer.pad_token = "<|finetune_right_pad_id|>"
elif "Qwen" in model_args.model_name_or_path:
tokenizer.pad_token = "<|fim_pad|>"
tokenizer.chat_template = REASON_CHAT_TEMPLATE
model_post_train_dataset_name = RL_POST_TRAIN_DATASET_MAP[pt_args.model_post_train_dataset_name]
if pt_args.model_post_train_dataset_config is not None:
train_dataset = load_dataset(
model_post_train_dataset_name, split="train", name=pt_args.model_post_train_dataset_config
)
else:
train_dataset = load_dataset(model_post_train_dataset_name, split="train")
# required by GRPOTrainer: (prompt, solution) columns
if "solution" not in train_dataset.column_names and "answer" in train_dataset.column_names:
train_dataset = train_dataset.rename_column("answer", "solution")
# Wrap the 'solution' values in $...$
def wrap_in_math(example):
return {"solution": f"${example['solution']}$"}
# Apply the transformation to the entire dataset
train_dataset = train_dataset.map(wrap_in_math)
if "problem" not in train_dataset.column_names and "question" in train_dataset.column_names:
train_dataset = train_dataset.rename_column("question", "problem")
if "problem" not in train_dataset.column_names and "prompt" in train_dataset.column_names:
train_dataset = train_dataset.rename_column("prompt", "problem")
if "messages" in train_dataset.column_names:
train_dataset = train_dataset.remove_columns("messages")
# handle deepscaler separately
if "deepscaler" in pt_args.model_post_train_dataset_name:
train_dataset = train_dataset.rename_column("solution", "solution_archive")
train_dataset = train_dataset.rename_column("answer", "solution")
# Wrap the 'solution' values in $...$
def wrap_in_math(example):
return {"solution": f"${example['solution']}$"}
# Apply the transformation to the entire dataset
train_dataset = train_dataset.map(wrap_in_math)
SYSTEM_PROMPT = OPEN_RS_SYSTEM_PROMPT if "open-rs" in model_post_train_dataset_name else OPEN_R1_SYSTEM_PROMPT
train_dataset = train_dataset.map(make_conv_for_grpo, fn_kwargs={"system_prompt": SYSTEM_PROMPT})
######################
# Initialize the model
######################
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
torch_dtype=torch.bfloat16,
attn_implementation=model_args.attn_implementation,
use_cache=False if training_args.gradient_checkpointing else True,
)
if model_args.use_peft:
logger.info(
f"\n Using PEFT with {model_args.lora_r} rank, {model_args.lora_alpha} alpha, {model_args.lora_dropout} dropout."
)
peft_config = LoraConfig(
r=model_args.lora_r,
lora_alpha=model_args.lora_alpha,
lora_dropout=model_args.lora_dropout,
target_modules=model_args.lora_target_modules,
inference_mode=False,
bias="none",
task_type=TaskType.CAUSAL_LM,
)
model = get_peft_model(model, peft_config)
#############################
# Initialize the GRPO trainer
#############################
RL_POST_TRAIN_REWARD_MAP = {
"accuracy": accuracy_reward,
"format": format_reward,
"tag_count": tag_count_reward,
"length": len_reward,
"reasoning_steps": reasoning_steps_reward,
"cosine": get_cosine_scaled_reward(
min_value_wrong=pt_args.cosine_min_value_wrong,
max_value_wrong=pt_args.cosine_max_value_wrong,
min_value_correct=pt_args.cosine_min_value_correct,
max_value_correct=pt_args.cosine_max_value_correct,
max_len=pt_args.cosine_max_len,
),
"repetition_penalty": get_repetition_penalty_reward(
ngram_size=pt_args.repetition_n_grams,
max_penalty=pt_args.repetition_max_penalty,
),
}
rl_reward_funcs = [RL_POST_TRAIN_REWARD_MAP[func] for func in pt_args.rl_post_train_reward_funcs]
training_args.reward_weights = pt_args.rl_post_train_reward_weights
if model_args.use_peft:
callbacks = [
FixedPromptEvaluationCallback(system_prompt=OPEN_R1_SYSTEM_PROMPT, eval_steps=training_args.save_steps),
# PushToHubRevisionCallback(dataset_name=pt_args.model_post_train_dataset_name, use_peft=model_args.use_peft)
]
else:
callbacks = [
FixedPromptEvaluationCallback(system_prompt=OPEN_R1_SYSTEM_PROMPT, eval_steps=training_args.save_steps),
GradientClippingLoggerCallback(),
# PushToHubRevisionCallback(dataset_name=pt_args.model_post_train_dataset_name, use_peft=model_args.use_peft)
]
trainer = GRPOTrainer(
model=model,
processing_class=tokenizer,
reward_funcs=rl_reward_funcs,
args=training_args,
train_dataset=train_dataset,
callbacks=callbacks,
)
#########################
# Training and Evaluation
#########################
logger.info(f"\nStarting training for {training_args.num_train_epochs} epochs.")
# Check for last checkpoint
ckpt = None
if training_args.resume_from_checkpoint is not None:
ckpt = training_args.resume_from_checkpoint
elif os.path.isdir(training_args.output_dir):
ckpt = get_last_checkpoint(training_args.output_dir)
if ckpt:
logger.info(f"\nCheckpoint detected, resuming training at {ckpt=}.")
else:
logger.info("\nNo checkpoint detected, starting training from scratch.")
train_result = trainer.train(resume_from_checkpoint=ckpt)
train_metrics = train_result.metrics
trainer.log_metrics("train", train_metrics)
trainer.save_metrics("train", train_metrics)
trainer.save_state()
trainer.push_to_hub(
commit_message=f"Add checkpoint {training_args.max_steps} post-trained on {pt_args.model_post_train_dataset_name}"
)
del trainer
torch.cuda.empty_cache()
if __name__ == "__main__":
main()

View File

@@ -1,260 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional
from transformers import TrainingArguments
@dataclass
class GRPOConfig(TrainingArguments):
r"""
Configuration class for the [`GRPOTrainer`].
Only the parameters specific to GRPO training are listed here. For details on other parameters, refer to the
[`~transformers.TrainingArguments`] documentation.
Using [`~transformers.HfArgumentParser`] we can turn this class into
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
command line.
Parameters:
> Parameters that control the model and reference model
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
argument of the [`GRPOTrainer`] is provided as a string.
> Parameters that control the data preprocessing
remove_unused_columns (`bool`, *optional*, defaults to `False`):
Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`.
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
num_generations (`int` or `None`, *optional*, defaults to `8`):
Number of generations per prompt to sample. The global batch size (num_processes * per_device_batch_size)
must be divisible by this value.
temperature (`float`, *optional*, defaults to `0.9`):
Temperature for sampling. The higher the temperature, the more random the completions.
max_completion_length (`int` or `None`, *optional*, defaults to `256`):
Maximum length of the generated completion.
ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
improving generation speed. However, disabling this option allows training models that exceed the VRAM
capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible
with vLLM generation.
> Parameters that control generation acceleration powered by vLLM
use_vllm (`bool`, *optional*, defaults to `False`):
Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
vllm_device (`str`, *optional*, defaults to `"auto"`):
Device where vLLM generation will run, e.g. `"cuda:1"`. If set to `"auto"` (default), the system will
automatically select the next available GPU after the last one used for training. This assumes that
training has not already occupied all available GPUs. If only one device is available, the device will be
shared between both training and vLLM.
vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the
device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus
improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors
during initialization.
vllm_dtype (`str`, *optional*, defaults to `"auto"`):
Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
based on the model configuration. Find the supported values in the vLLM documentation.
vllm_max_model_len (`int` or `None`, *optional*, defaults to `None`):
If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced
`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
context size, which might be much larger than the KV cache, leading to inefficiencies.
> Parameters that control the training
learning_rate (`float`, *optional*, defaults to `1e-6`):
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
[`~transformers.TrainingArguments`].
beta (`float`, *optional*, defaults to `0.04`):
KL coefficient.
reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
weighted equally with weight `1.0`.
sync_ref_model (`bool`, *optional*, defaults to `False`):
Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
the `ref_model_mixup_alpha` parameter. This synchronization originates from the
[TR-DPO](https://huggingface.co/papers/2404.09656) paper.
ref_model_mixup_alpha (`float`, *optional*, defaults to `0.9`):
α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix
between the current policy and the previous reference policy during updates. The reference policy is
updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you
must set `sync_ref_model=True`.
ref_model_sync_steps (`int`, *optional*, defaults to `64`):
τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
frequently the current policy is synchronized with the reference policy. To use this parameter, you must
set `sync_ref_model=True`.
> Parameters that control the logging
log_completions (`bool`, *optional*, defaults to `False`):
Whether to log the completions during training.
"""
# Parameters that control the model and reference model
model_init_kwargs: Optional[dict] = field(
default=None,
metadata={
"help": "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` "
"argument of the `GRPOTrainer` is provided as a string."
},
)
# Parameters that control the data preprocessing
# The default value remove_unused_columns is overwritten from the parent class, because in GRPO we usually rely on
# additional columns to compute the reward
remove_unused_columns: Optional[bool] = field(
default=False,
metadata={
"help": "Whether to only keep the column 'prompt' in the dataset. If you use a custom reward function "
"that requires any column other than 'prompts' and 'completions', you should keep this to `False`."
},
)
max_prompt_length: Optional[int] = field(
default=512,
metadata={
"help": "Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left."
},
)
num_generations: Optional[int] = field(
default=8,
metadata={
"help": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) "
"must be divisible by this value."
},
)
temperature: Optional[float] = field(
default=0.9,
metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."},
)
max_completion_length: Optional[int] = field(
default=256,
metadata={"help": "Maximum length of the generated completion."},
)
ds3_gather_for_generation: bool = field(
default=True,
metadata={
"help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for "
"generation, improving generation speed. However, disabling this option allows training models that "
"exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation. Disabling this option "
"is not compatible with vLLM generation."
},
)
# Parameters that control generation acceleration powered by vLLM
use_vllm: Optional[bool] = field(
default=False,
metadata={
"help": "Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept "
"unused for training, as vLLM will require one for generation. vLLM must be installed "
"(`pip install vllm`)."
},
)
vllm_device: Optional[str] = field(
default="auto",
metadata={
"help": "Device where vLLM generation will run, e.g. 'cuda:1'. If set to 'auto' (default), the system "
"will automatically select the next available GPU after the last one used for training. This assumes "
"that training has not already occupied all available GPUs."
},
)
vllm_gpu_memory_utilization: float = field(
default=0.9,
metadata={
"help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
"cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache "
"size and thus improve the model's throughput. However, if the value is too high, it may cause "
"out-of-memory (OOM) errors during initialization."
},
)
vllm_dtype: Optional[str] = field(
default="auto",
metadata={
"help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
"determined based on the model configuration. Find the supported values in the vLLM documentation."
},
)
vllm_max_model_len: Optional[int] = field(
default=None,
metadata={
"help": "If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced "
"`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model "
"context size, which might be much larger than the KV cache, leading to inefficiencies."
},
)
# Parameters that control the training
learning_rate: float = field(
default=1e-6,
metadata={
"help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of "
"`transformers.TrainingArguments`."
},
)
beta: float = field(
default=0.04,
metadata={"help": "KL coefficient."},
)
scale_rewards: bool = field(
default=True,
metadata={
"help": "Whether to scale the rewards by dividing them by their standard deviation. If `True` (default), "
"the rewards are normalized by the standard deviation, ensuring they have unit variance. If `False`, no "
"scaling is applied. The Dr. GRPO paper recommends not scaling the rewards, as scaling by the standard "
"deviation introduces a question-level difficulty bias."
},
)
reward_weights: Optional[list[float]] = field(
default=None,
metadata={
"help": "Weights for each reward function. Must match the number of reward functions. If `None`, all "
"rewards are weighted equally with weight `1.0`."
},
)
sync_ref_model: bool = field(
default=False,
metadata={
"help": "Whether to synchronize the reference model with the active model every `ref_model_sync_steps` "
"steps, using the `ref_model_mixup_alpha` parameter."
},
)
ref_model_mixup_alpha: float = field(
default=0.9,
metadata={
"help": "α parameter from the TR-DPO paper, which controls the mix between the current policy and the "
"previous reference policy during updates. The reference policy is updated according to the equation: "
"`π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you must set `sync_ref_model=True`."
},
)
ref_model_sync_steps: int = field(
default=64,
metadata={
"help": "τ parameter from the TR-DPO paper, which determines how frequently the current policy is "
"synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`."
},
)
# Parameters that control the logging
log_completions: bool = field(
default=False,
metadata={"help": "Whether to log the completions during training."},
)

View File

@@ -1,823 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import textwrap
import warnings
from collections import defaultdict
from typing import Any, Callable, Optional, Sized, Union
from unittest.mock import patch
import torch
import torch.utils.data
import transformers
from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
from accelerate.utils.other import is_compiled_module
from datasets import Dataset, IterableDataset
from packaging import version
from tina.post_train_hf.grpo_config import GRPOConfig
from torch import nn
from torch.utils.data import Sampler
from transformers import (
AutoModelForCausalLM,
AutoModelForSequenceClassification,
AutoTokenizer,
GenerationConfig,
PreTrainedModel,
PreTrainedTokenizerBase,
Trainer,
TrainerCallback,
is_wandb_available,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils import is_peft_available
from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from trl.import_utils import is_vllm_available
from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
from trl.trainer.callbacks import SyncRefModelCallback
from trl.trainer.utils import generate_model_card, get_comet_experiment_url, pad, selective_log_softmax
if is_peft_available():
from peft import PeftConfig, get_peft_model
if is_vllm_available():
from vllm import LLM, SamplingParams
if is_wandb_available():
import wandb
# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
class RepeatRandomSampler(Sampler):
"""
Sampler that repeats the indices of a dataset N times.
Args:
data_source (`Sized`):
Dataset to sample from.
repeat_count (`int`):
Number of times to repeat each index.
seed (`Optional[int]`):
Random seed for reproducibility (only affects this sampler).
Example:
```python
>>> sampler = RepeatRandomSampler(["a", "b", "c", "d"], repeat_count=2)
>>> list(sampler)
[2, 2, 0, 0, 3, 3, 1, 1]
```
"""
def __init__(self, data_source: Sized, repeat_count: int, seed: Optional[int] = None):
self.data_source = data_source
self.repeat_count = repeat_count
self.num_samples = len(data_source)
self.seed = seed
self.generator = torch.Generator() # Create a local random generator
if seed is not None:
self.generator.manual_seed(seed)
def __iter__(self):
indexes = [
idx
for idx in torch.randperm(self.num_samples, generator=self.generator).tolist()
for _ in range(self.repeat_count)
]
return iter(indexes)
def __len__(self):
return self.num_samples * self.repeat_count
class GRPOTrainer(Trainer):
"""
Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).
Example:
```python
from datasets import load_dataset
from trl import GRPOTrainer
dataset = load_dataset("trl-lib/tldr", split="train")
def reward_func(completions, **kwargs):
# Dummy reward function that rewards completions with more unique letters.
return [float(len(set(completion))) for completion in completions]
trainer = GRPOTrainer(
model="Qwen/Qwen2-0.5B-Instruct",
reward_funcs=reward_func,
train_dataset=dataset,
)
trainer.train()
```
Args:
model (`Union[str, PreTrainedModel]`):
Model to be trained. Can be either:
- A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
a path to a *directory* containing model weights saved using
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments
in `args.model_init_kwargs`.
- A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
functions with the prompts and completions and sum the rewards. Can be either:
- A single reward function, such as:
- A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
path to a *directory* containing model weights saved using
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
keyword arguments in `args.model_init_kwargs`.
- A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
- A custom reward function: The function is provided with the prompts and the generated completions,
plus any additional columns in the dataset. It should return a list of rewards. For more details, see
[Using a custom reward function](#using-a-custom-reward-function).
- A list of reward functions, where each item can independently be any of the above types. Mixing different
types within the list (e.g., a string model ID and a custom reward function) is allowed.
args ([`GRPOConfig`], *optional*, defaults to `None`):
Configuration for this trainer. If `None`, a default configuration is used.
train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset are
ignored. The format of the samples can be either:
- [Standard](dataset_formats#standard): Each sample contains plain text.
- [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
and content).
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
Processing class used to process the data. The padding side must be set to "left". If `None`, the
processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`].
reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`):
Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
- A single processing class: Used when `reward_funcs` contains only one reward function.
- A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
`None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`].
For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]),
the corresponding entries in `reward_processing_classes` are ignored.
callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
List of callbacks to customize the training loop. Will add those to the list of default callbacks
detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback).
If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
method.
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
"""
_tag_names = ["trl", "grpo"]
def __init__(
self,
model: Union[str, PreTrainedModel],
reward_funcs: Union[RewardFunc, list[RewardFunc]],
args: GRPOConfig = None,
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
processing_class: Optional[PreTrainedTokenizerBase] = None,
reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
callbacks: Optional[list[TrainerCallback]] = None,
optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
peft_config: Optional["PeftConfig"] = None,
):
# Args
if args is None:
model_name = model if isinstance(model, str) else model.config._name_or_path
model_name = model_name.split("/")[-1]
args = GRPOConfig(f"{model_name}-GRPO")
# Models
# Trained model
model_init_kwargs = args.model_init_kwargs or {}
if isinstance(model, str):
model_id = model
torch_dtype = model_init_kwargs.get("torch_dtype")
if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
pass # torch_dtype is already a torch.dtype or "auto" or None
elif isinstance(torch_dtype, str): # it's a str, but not "auto"
torch_dtype = getattr(torch, torch_dtype)
model_init_kwargs["torch_dtype"] = torch_dtype
else:
raise ValueError(
"Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
)
# Disable caching if gradient checkpointing is enabled (not supported)
model_init_kwargs["use_cache"] = (
False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
)
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
else:
model_id = model.config._name_or_path
if args.model_init_kwargs is not None:
raise ValueError(
"You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
"This argument can only be used when the `model` argument is a string."
)
if peft_config is not None:
model = get_peft_model(model, peft_config)
# Reference model
if is_deepspeed_zero3_enabled():
self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
elif not is_peft_model(model):
# If PEFT configuration is not provided, create a reference model based on the initial model.
self.ref_model = create_reference_model(model)
else:
# If PEFT is used, the reference model is not needed since the adapter can be disabled
# to revert to the initial model.
self.ref_model = None
# Processing class
if processing_class is None:
processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
# Reward functions
if not isinstance(reward_funcs, list):
reward_funcs = [reward_funcs]
for i, reward_func in enumerate(reward_funcs):
if isinstance(reward_func, str):
reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
reward_func, num_labels=1, **model_init_kwargs
)
self.reward_funcs = reward_funcs
# Reward weights
if args.reward_weights is not None:
if len(args.reward_weights) != len(reward_funcs):
raise ValueError(
f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
f"functions ({len(reward_funcs)})"
)
self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
else:
self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
# Reward processing class
if reward_processing_classes is None:
reward_processing_classes = [None] * len(reward_funcs)
elif not isinstance(reward_processing_classes, list):
reward_processing_classes = [reward_processing_classes]
else:
if len(reward_processing_classes) != len(reward_funcs):
raise ValueError("The number of reward processing classes must match the number of reward functions.")
for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
if isinstance(reward_func, PreTrainedModel):
if reward_processing_class is None:
reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
if reward_processing_class.pad_token_id is None:
reward_processing_class.pad_token = reward_processing_class.eos_token
# The reward model computes the reward for the latest non-padded token in the input sequence.
# So it's important to set the pad token ID to the padding token ID of the processing class.
reward_func.config.pad_token_id = reward_processing_class.pad_token_id
reward_processing_classes[i] = reward_processing_class
self.reward_processing_classes = reward_processing_classes
# Data collator
def data_collator(features): # No data collation is needed in GRPO
return features
# Training arguments
self.max_prompt_length = args.max_prompt_length
self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
self.num_generations = args.num_generations # = G in the GRPO paper
self.use_vllm = args.use_vllm
self.beta = args.beta
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
# input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
# "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
# "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
# suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
# This acts as a flag to indicate that the warning has already been issued.
model.warnings_issued["estimate_tokens"] = True
# Initialize the metrics
self._metrics = defaultdict(list)
self.log_completions = args.log_completions
super().__init__(
model=model,
args=args,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
processing_class=processing_class,
callbacks=callbacks,
optimizers=optimizers,
)
# Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
num_processes = self.accelerator.num_processes
global_batch_size = args.per_device_train_batch_size * num_processes
possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
if self.num_generations not in possible_values:
raise ValueError(
f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
f"batch size, the valid values for the number of generations are: {possible_values}."
)
if self.args.eval_strategy != "no":
global_batch_size = args.per_device_eval_batch_size * num_processes
possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
if self.num_generations not in possible_values:
raise ValueError(
f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
f"eval batch size, the valid values for the number of generations are: {possible_values}."
)
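# For example (illustrative numbers): with a single training process and
# per_device_train_batch_size=16, the global batch size is 16, so the valid values for
# num_generations are 2, 4, 8, and 16.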
# Ensure each process receives a unique seed to prevent duplicate completions when generating with
# transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
# it's safer to set it in all cases.
set_seed(args.seed, device_specific=True)
if self.use_vllm:
if not is_vllm_available():
raise ImportError(
"vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
"`pip install vllm` to use it."
)
if self.accelerator.is_main_process:
vllm_device = self.args.vllm_device
if vllm_device == "auto":
if torch.cuda.device_count() == 1:
vllm_device = "cuda:0" # particular case when training with onyl 1 GPU: share it
else:
vllm_device = f"cuda:{self.accelerator.num_processes}" # take the next GPU idx
# Check that the requested device is available
if vllm_device.split(":")[0] == "cuda" and int(vllm_device.split(":")[1]) >= torch.cuda.device_count():
raise ValueError(
f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM "
"without restricting the number of GPUs for training. Set the `--num_processes` argument to a "
"value lower than the number of GPUs available on your machine—typically, reducing it by one "
f"is sufficient. In your case: `--num_processes {torch.cuda.device_count() - 1}`."
)
# Check that the requested device is not also used for training
if vllm_device in {f"cuda:{idx}" for idx in range(self.accelerator.num_processes)}:
warnings.warn(
f"The requested device {vllm_device} is also being used for training. For higher throughput "
"and to avoid out-of-memory errors, it is recommended to use a dedicated device for vLLM. "
"If this is intentional, you may ignore this warning but should adjust "
"`vllm_gpu_memory_utilization` accordingly."
)
# vLLM is not compatible with accelerate. So we need to patch it to make sure we can (1) place the vLLM
# model on the desired device (world_size_patch) and (2) avoid a test that is not designed for our
# setting (profiling_patch).
world_size_patch = patch("torch.distributed.get_world_size", return_value=1)
profiling_patch = patch(
"vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling", return_value=None
)
with world_size_patch, profiling_patch:
self.llm = LLM(
model=model.name_or_path,
device=vllm_device,
gpu_memory_utilization=self.args.vllm_gpu_memory_utilization,
dtype=self.args.vllm_dtype,
# Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
# directly reuse the KV cache if it shares the same prefix with one of the existing queries.
# This is particularly useful here because we generate completions from the same prompts.
enable_prefix_caching=True,
max_model_len=self.args.vllm_max_model_len,
)
self.sampling_params = SamplingParams(
temperature=args.temperature,
max_tokens=self.max_completion_length,
)
self._last_loaded_step = 0 # tag to avoid useless loading during grad accumulation
# When using vLLM, the main process is responsible for loading the model weights. This can cause process
# desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
# synchronize all processes after vLLM has been fully initialized.
self.accelerator.wait_for_everyone()
else:
self.generation_config = GenerationConfig(
max_new_tokens=self.max_completion_length,
do_sample=True,
temperature=args.temperature,
pad_token_id=processing_class.pad_token_id,
)
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
# self.model_accepts_loss_kwargs to False to enable scaling.
self.model_accepts_loss_kwargs = False
# Add tags to the model
self.model.add_model_tags(self._tag_names)
if self.ref_model is not None:
if self.is_deepspeed_enabled:
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
if args.sync_ref_model:
self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))
for i, reward_func in enumerate(self.reward_funcs):
if isinstance(reward_func, PreTrainedModel):
self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
def _set_signature_columns_if_needed(self):
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
# By default, this method sets `self._signature_columns` to the model's expected inputs.
# In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
# Instead, we set them to the columns expected by the `training_step` method, hence the override.
if self._signature_columns is None:
self._signature_columns = ["prompt"]
def _get_train_sampler(self) -> Sampler:
# Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
# identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
# within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
# preventing discrepancies in group formation.
return RepeatRandomSampler(self.train_dataset, self.num_generations, seed=self.args.seed)
def _get_eval_sampler(self, eval_dataset) -> Sampler:
# Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
# identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
# within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
# preventing discrepancies in group formation.
return RepeatRandomSampler(eval_dataset, self.num_generations, seed=self.args.seed)
# Get the per-token log probabilities for the completions for the model and the reference model
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
# We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
input_ids = input_ids[:, -logits_to_keep:]
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
# See https://github.com/huggingface/trl/issues/2770
logits = logits[:, -logits_to_keep:]
return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens
def _move_model_to_vllm(self):
with unwrap_model_for_generation(
self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
) as unwrapped_model:
if is_compiled_module(unwrapped_model):
unwrapped_model = unwrapped_model._orig_mod
if is_peft_model(unwrapped_model):
unwrapped_model.merge_adapter()
state_dict = unwrapped_model.state_dict()
# Remove base_model and base_layer prefixes
state_dict = {
k.removeprefix("base_model.model.").replace(".base_layer", ""): v for k, v in state_dict.items()
}
# Remove values with adapter prefix (example: "_lora")
state_dict = {k: v for k, v in state_dict.items() if unwrapped_model.prefix not in k}
# For modules_to_save entries, remove the prefix and discard the original module
state_dict = {
k.replace("modules_to_save.default.", ""): v
for k, v in state_dict.items()
if "original_module" not in k
}
else:
state_dict = unwrapped_model.state_dict()
if self.accelerator.is_main_process:
llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
llm_model.load_weights(state_dict.items())
# Unmerge the adapter to restore the model to its original state.
# This must be done after loading weights to ensure they correspond to the merged state.
if is_peft_model(unwrapped_model):
unwrapped_model.unmerge_adapter()
def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
device = self.accelerator.device
prompts = [x["prompt"] for x in inputs]
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
prompt_inputs = self.processing_class(
prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
)
prompt_inputs = super()._prepare_inputs(prompt_inputs)
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
if self.max_prompt_length is not None:
prompt_ids = prompt_ids[:, -self.max_prompt_length :]
prompt_mask = prompt_mask[:, -self.max_prompt_length :]
# Generate completions using either vLLM or regular generation
if self.args.use_vllm:
# First, have main process load weights if needed
if self.state.global_step != self._last_loaded_step:
self._move_model_to_vllm()
self._last_loaded_step = self.state.global_step
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
all_prompts_text = gather_object(prompts_text)
if self.accelerator.is_main_process:
outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False)
completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
else:
completion_ids = [None] * len(all_prompts_text)
# Broadcast the completions from the main process to all processes, ensuring each process receives its
# corresponding slice.
completion_ids = broadcast_object_list(completion_ids, from_process=0)
process_slice = slice(
self.accelerator.process_index * len(prompts),
(self.accelerator.process_index + 1) * len(prompts),
)
completion_ids = completion_ids[process_slice]
# Pad the completions, and concatenate them with the prompts
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
else:
# Regular generation path
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
prompt_completion_ids = unwrapped_model.generate(
prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
)
# Compute prompt length and extract completion ids
prompt_length = prompt_ids.size(1)
prompt_ids = prompt_completion_ids[:, :prompt_length]
completion_ids = prompt_completion_ids[:, prompt_length:]
# Mask everything after the first EOS token
is_eos = completion_ids == self.processing_class.eos_token_id
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
# Concatenate prompt_mask with completion_mask for logit computation
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C)
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
with torch.inference_mode():
if self.ref_model is not None:
ref_per_token_logps = self._get_per_token_logps(
self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
)
else:
with self.accelerator.unwrap_model(self.model).disable_adapter():
ref_per_token_logps = self._get_per_token_logps(
self.model, prompt_completion_ids, attention_mask, logits_to_keep
)
# Decode the generated completions
completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
if is_conversational(inputs[0]):
completions = []
for prompt, completion in zip(prompts, completions_text):
bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
completions.append([{"role": "assistant", "content": bootstrap + completion}])
else:
completions = completions_text
rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
for i, (reward_func, reward_processing_class) in enumerate(
zip(self.reward_funcs, self.reward_processing_classes)
):
if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
if is_conversational(inputs[0]):
messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
else:
texts = [p + c for p, c in zip(prompts, completions)]
reward_inputs = reward_processing_class(
texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
)
reward_inputs = super()._prepare_inputs(reward_inputs)
with torch.inference_mode():
rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
else:
# Repeat all input columns (but "prompt" and "completion") to match the number of generations
keys = [key for key in inputs[0] if key not in ["prompt", "completion"]]
reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
# Convert None values to NaN
output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func]
rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
# If all reward functions return None for a given row, issue a detailed warning
if torch.isnan(rewards_per_func).all(dim=1).any():
nan_row_idx = torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0]
row_reward_kwargs = {key: value[nan_row_idx] for key, value in reward_kwargs.items()}
row_reward_kwargs["prompt"] = prompts[nan_row_idx]
row_reward_kwargs["completion"] = completions[nan_row_idx]
warnings.warn(
f"All reward functions returned None for the following kwargs: {row_reward_kwargs}. "
"Please ensure that at least one reward function returns a valid reward."
)
# Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
# completions may be distributed across processes
rewards_per_func = gather(rewards_per_func)
# Apply weights to each reward function's output and sum
rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)
# Compute grouped-wise rewards
mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
# Normalize the rewards to compute the advantages
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
advantages = rewards - mean_grouped_rewards
if self.args.scale_rewards:
advantages = advantages / (std_grouped_rewards + 1e-4)
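# In GRPO terms, this is A_i = r_i - mean(group), optionally divided by (std(group) + 1e-4);
# Dr. GRPO omits the std division to avoid introducing a question-level difficulty bias.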
# Slice to keep only the local part of the data
process_slice = slice(
self.accelerator.process_index * len(prompts),
(self.accelerator.process_index + 1) * len(prompts),
)
advantages = advantages[process_slice]
# Log the metrics
reward_per_func = rewards_per_func.mean(0)
for i, reward_func in enumerate(self.reward_funcs):
if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
reward_func_name = reward_func.config._name_or_path.split("/")[-1]
else:
reward_func_name = reward_func.__name__
self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
self._metrics["reward"].append(rewards.mean().item())
self._metrics["reward_std"].append(std_grouped_rewards.mean().item())
if (
self.log_completions
and self.state.global_step % self.args.logging_steps == 0
and "wandb" in self.args.report_to
):
import pandas as pd
# For logging
table = {
"step": [str(self.state.global_step)] * len(rewards),
"prompt": gather_object(prompts_text),
"completion": gather_object(completions_text),
"reward": rewards.tolist(),
}
df = pd.DataFrame(table)
if wandb.run is not None and self.accelerator.is_main_process:
wandb.log({"completions": wandb.Table(dataframe=df)})
return {
"prompt_ids": prompt_ids,
"prompt_mask": prompt_mask,
"completion_ids": completion_ids,
"completion_mask": completion_mask,
"ref_per_token_logps": ref_per_token_logps,
"advantages": advantages,
}
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
if return_outputs:
raise ValueError("The GRPOTrainer does not support returning outputs")
# Compute the per-token log probabilities for the model
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
# Compute the KL divergence between the model and the reference model
ref_per_token_logps = inputs["ref_per_token_logps"]
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
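# This is the k3-style KL estimator: with delta = ref_per_token_logps - per_token_logps,
# exp(delta) - delta - 1 is non-negative and its expectation under the policy equals
# KL(policy || reference).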
# x - x.detach() allows for preserving gradients from x
advantages = inputs["advantages"]
per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
per_token_loss = -(per_token_loss - self.beta * per_token_kl)
if self.args.scale_rewards:
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
else:
loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
# Log the metrics
completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
self._metrics["completion_length"].append(completion_length)
mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
return loss
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
inputs = self._prepare_inputs(inputs)
with torch.no_grad():
with self.compute_loss_context_manager():
loss = self.compute_loss(model, inputs)
loss = loss.mean().detach()
return loss, None, None
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
# This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
# start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
if next(iter(logs.keys())).startswith("eval_"):
metrics = {f"eval_{key}": val for key, val in metrics.items()}
logs = {**logs, **metrics}
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
super().log(logs, start_time)
else: # transformers<=4.46
super().log(logs)
self._metrics.clear()
def create_model_card(
self,
model_name: Optional[str] = None,
dataset_name: Optional[str] = None,
tags: Union[str, list[str], None] = None,
):
"""
Creates a draft of a model card using the information available to the `Trainer`.
Args:
model_name (`str` or `None`, *optional*, defaults to `None`):
Name of the model.
dataset_name (`str` or `None`, *optional*, defaults to `None`):
Name of the dataset used for training.
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
Tags to be associated with the model card.
"""
if not self.is_world_process_zero():
return
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
base_model = self.model.config._name_or_path
else:
base_model = None
tags = tags or []
if isinstance(tags, str):
tags = [tags]
if hasattr(self.model.config, "unsloth_version"):
tags.append("unsloth")
citation = textwrap.dedent(
"""\
@article{zhihong2024deepseekmath,
title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
year = 2024,
eprint = {arXiv:2402.03300},
}
"""
)
model_card = generate_model_card(
base_model=base_model,
model_name=model_name,
hub_model_id=self.hub_model_id,
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
comet_url=get_comet_experiment_url(),
trainer_name="GRPO",
trainer_citation=citation,
paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
paper_id="2402.03300",
)
model_card.save(os.path.join(self.args.output_dir, "README.md"))

View File

@@ -1,44 +0,0 @@
import logging
from concurrent.futures import Future
from huggingface_hub import create_branch, create_repo, list_repo_commits, upload_folder
logger = logging.getLogger(__name__)
def push_to_hub_revision(training_args, extra_ignore_patterns=[]) -> Future:
"""Pushes the model to branch on a Hub repo."""
# Create a repo if it doesn't exist yet
repo_url = create_repo(repo_id=training_args.hub_model_id, private=True, exist_ok=True)
# Get initial commit to branch from
initial_commit = list_repo_commits(training_args.hub_model_id)[-1]
# Now create the branch we'll be pushing to
create_branch(
repo_id=training_args.hub_model_id,
branch=training_args.hub_model_revision,
# checkpoint=training_args.checkpoint,
revision=initial_commit.commit_id,
exist_ok=True,
)
logger.info(f"Created target repo at {repo_url}")
logger.info(
f"Pushing to the Hub revision {training_args.hub_model_revision} with checkpoint {training_args.checkpoint}"
)
ignore_patterns = ["checkpoint-*", "*.pth"]
ignore_patterns.extend(extra_ignore_patterns)
future = upload_folder(
repo_id=training_args.hub_model_id,
folder_path=training_args.output_dir,
revision=training_args.hub_model_revision,
# commit_message=f"Add {training_args.hub_model_revision} checkpoint {training_args.dataset_name}",
commit_message=f"Add {training_args.checkpoint} checkpoint post-trained on {training_args.dataset_name}",
ignore_patterns=ignore_patterns,
run_as_future=True,
)
logger.info(
f"Pushed to {repo_url} revision {training_args.hub_model_revision} with checkpoint {training_args.checkpoint} successfully!"
)
return future
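# Illustrative usage sketch (kept commented out because it would hit the Hub; all attribute
# values below are placeholders): the function above only needs an args object exposing
# hub_model_id, hub_model_revision, checkpoint, output_dir and dataset_name.
#
# from types import SimpleNamespace
# args = SimpleNamespace(
#     hub_model_id="your-org/your-model",
#     hub_model_revision="grpo-run-1",
#     checkpoint="checkpoint-100",
#     output_dir="outputs/grpo-demo",
#     dataset_name="curated_limr",
# )
# future = push_to_hub_revision(args)  # returns a Future; call future.result() to block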

View File

@@ -1,7 +0,0 @@
def make_conv_for_grpo(example, system_prompt):
return {
"prompt": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": example["problem"]},
]
}
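# Illustrative usage sketch (the row and system prompt below are made-up placeholders):
example_row = {"problem": "What is 2 + 2?"}
conv = make_conv_for_grpo(example_row, system_prompt="Reason step by step, then answer.")
# conv == {"prompt": [{"role": "system", "content": "Reason step by step, then answer."},
#                     {"role": "user", "content": "What is 2 + 2?"}]}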

View File

@@ -1,302 +0,0 @@
import math
import re
from typing import Callable, Dict, Optional
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
def accuracy_reward(completions: list[list[dict[str, str]]], solution: list[str], **kwargs) -> list[Optional[float]]:
"""Reward function that checks if the completion is the same as the ground truth."""
contents = [completion[0]["content"] for completion in completions]
rewards = []
for content, sol in zip(contents, solution):
gold_parsed = parse(
sol,
extraction_mode="first_match",
)
if len(gold_parsed) != 0:
# We require the answer to be provided in correct latex (no malformed operators)
answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
equations=True,
boxed="all",
units=True,
),
# Ensures that boxed is tried first
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
# Compute binary rewards if verifiable, `None` otherwise to skip this example
try:
reward = float(verify(gold_parsed, answer_parsed))
except Exception as e:
print(f"verify failed: {e}, answer: {answer_parsed}, gold: {gold_parsed}")
reward = None
else:
# If the gold solution is not parseable, we assign `None` to skip this example
reward = None
print("Failed to parse gold solution: ", sol)
rewards.append(reward)
return rewards
def format_reward(completions, **kwargs):
"""Reward function that checks if the reasoning process is enclosed within <think> and </think> tags, while the final answer is enclosed within <answer> and </answer> tags."""
def count_tags(text: str) -> float:
count = 0.0
# We only count the </think> tag, because the <think> tag is provided in the system prompt
if text.count("\n</think>\n") == 1:
count += 1.0
return count
contents = [completion[0]["content"] for completion in completions]
return [count_tags(c) for c in contents]
def tag_count_reward(completions, **kwargs) -> list[float]:
"""Reward function that checks if we produce the desired number of think and answer tags associated with `format_reward()`.
Adapted from: https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb#file-grpo_demo-py-L90
"""
def count_tags(text: str) -> float:
count = 0.0
if re.search(r"\s*<think>\s*", text):
count += 0.25
if re.search(r"\s*</think>\s*", text):
count += 0.25
if re.search(r"\s*<answer>\s*", text):
count += 0.25
if re.search(r"\s*</answer>\s*", text):
count += 0.25
return count
contents = [completion[0]["content"] for completion in completions]
return [count_tags(c) for c in contents]
def reasoning_steps_reward(completions, **kwargs):
r"""Reward function that checks for clear step-by-step reasoning.
Regex pattern:
Step \d+: - matches "Step 1:", "Step 2:", etc.
^\d+\. - matches numbered lists like "1.", "2.", etc. at start of line
\n- - matches bullet points with hyphens
\n\* - matches bullet points with asterisks
First,|Second,|Next,|Finally, - matches transition words
"""
pattern = r"(Step \d+:|^\d+\.|\n-|\n\*|First,|Second,|Next,|Finally,)"
completion_contents = [completion[0]["content"] for completion in completions]
matches = [len(re.findall(pattern, content)) for content in completion_contents]
# Magic number 3 to encourage 3 steps and more, otherwise partial reward
return [min(1.0, count / 3) for count in matches]
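# For example, "Step 1: set up.\nStep 2: solve.\nFinally, check." yields three matches,
# so the reward saturates at min(1.0, 3 / 3) == 1.0.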
def len_reward(completions: list[list[dict[str, str]]], solution: list[str], **kwargs) -> list[float]:
"""Compute length-based rewards to discourage overthinking and promote token efficiency.
Taken from the Kimi 1.5 tech report: https://arxiv.org/abs/2501.12599
Args:
completions: List of model completions
solution: List of ground truth solutions
Returns:
List of rewards where:
- For correct answers: reward = 0.5 - (len - min_len)/(max_len - min_len)
- For incorrect answers: reward = min(0, 0.5 - (len - min_len)/(max_len - min_len))
"""
contents = [completion[0]["content"] for completion in completions]
# First check correctness of answers
correctness = []
for content, sol in zip(contents, solution):
gold_parsed = parse(
sol,
extraction_mode="first_match",
extraction_config=[LatexExtractionConfig()],
)
if len(gold_parsed) == 0:
# Skip unparseable examples
correctness.append(True) # Treat as correct to avoid penalizing
print("Failed to parse gold solution: ", sol)
continue
answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
equations=True,
boxed=True,
units=True,
),
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
correctness.append(verify(answer_parsed, gold_parsed))
# Calculate lengths
lengths = [len(content) for content in contents]
min_len = min(lengths)
max_len = max(lengths)
# If all responses have the same length, return zero rewards
if max_len == min_len:
return [0.0] * len(completions)
rewards = []
for length, is_correct in zip(lengths, correctness):
lambda_val = 0.5 - (length - min_len) / (max_len - min_len)
if is_correct:
reward = lambda_val
else:
reward = min(0, lambda_val)
rewards.append(float(reward))
return rewards
def get_cosine_scaled_reward(
min_value_wrong: float = -1.0,
max_value_wrong: float = -0.5,
min_value_correct: float = 0.5,
max_value_correct: float = 1.0,
max_len: int = 1000,
):
def cosine_scaled_reward(completions, solution, **kwargs):
"""Reward function that scales based on completion length using a cosine schedule.
Shorter correct solutions are rewarded more than longer ones.
Longer incorrect solutions are penalized less than shorter ones.
Args:
completions: List of model completions
solution: List of ground truth solutions
This function is parameterized by the following arguments:
min_value_wrong: Minimum reward for wrong answers
max_value_wrong: Maximum reward for wrong answers
min_value_correct: Minimum reward for correct answers
max_value_correct: Maximum reward for correct answers
max_len: Maximum length for scaling
"""
contents = [completion[0]["content"] for completion in completions]
rewards = []
for content, sol in zip(contents, solution):
gold_parsed = parse(sol, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
if len(gold_parsed) == 0:
rewards.append(1.0) # Skip unparseable examples
print("Failed to parse gold solution: ", sol)
continue
answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
equations=True,
boxed=True,
units=True,
),
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
is_correct = verify(answer_parsed, gold_parsed)
gen_len = len(content)
# Apply cosine scaling based on length
progress = gen_len / max_len
cosine = math.cos(progress * math.pi)
if is_correct:
min_value = min_value_correct
max_value = max_value_correct
else:
# Swap min/max for incorrect answers
min_value = max_value_wrong
max_value = min_value_wrong
reward = min_value + 0.5 * (max_value - min_value) * (1.0 + cosine)
rewards.append(float(reward))
return rewards
return cosine_scaled_reward
def get_repetition_penalty_reward(ngram_size: int, max_penalty: float):
"""
Computes N-gram repetition penalty as described in Appendix C.2 of https://arxiv.org/abs/2502.03373.
Reference implementation from: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py
Args:
ngram_size: size of the n-grams
max_penalty: Maximum (negative) penalty for wrong answers
"""
if max_penalty > 0:
raise ValueError(f"max_penalty {max_penalty} should not be positive")
def zipngram(text: str, ngram_size: int):
words = text.lower().split()
return zip(*[words[i:] for i in range(ngram_size)])
def repetition_penalty_reward(completions, **kwargs) -> list[float]:
"""
Reward function that penalizes repetitions.
ref implementation: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py
Args:
completions: List of model completions
"""
contents = [completion[0]["content"] for completion in completions]
rewards = []
for completion in contents:
if completion == "":
rewards.append(0.0)
continue
if len(completion.split()) < ngram_size:
rewards.append(0.0)
continue
ngrams = set()
total = 0
for ng in zipngram(completion, ngram_size):
ngrams.add(ng)
total += 1
scaling = 1 - len(ngrams) / total
reward = scaling * max_penalty
rewards.append(reward)
return rewards
return repetition_penalty_reward
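# Illustrative usage sketch (the toy completion below is a made-up placeholder):
example_penalty = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
toy_completions = [[{"content": "the cat sat the cat sat the cat sat"}]]
print(example_penalty(toy_completions))  # [-0.571...]: only 3 of the 7 trigrams are unique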

View File

@@ -1,2 +0,0 @@
DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
REASON_CHAT_TEMPLATE = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<User>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<Assistant><tool▁calls▁begin><tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<tool▁call▁end>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<tool▁call▁end>'}}{{'<tool▁calls▁end><end▁of▁sentence>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<tool▁outputs▁end>' + message['content'] + '<end▁of▁sentence>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{{'<Assistant>' + content + '<end▁of▁sentence>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<tool▁outputs▁begin><tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<tool▁outputs▁end>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<Assistant>'}}{% endif %}"

View File

@@ -1,23 +0,0 @@
# problem/question, (solution), answer
RL_POST_TRAIN_DATASET_MAP = {
# Main datasets
"curated_deepscaler": "agentica-org/DeepScaleR-Preview-Dataset", # 40.3k
"curated_still": "RUC-AIBOX/STILL-3-Preview-RL-Data", # 33k
"curated_open_rs3": "knoveleng/open-rs", # 7k
"curated_open_rs2": "knoveleng/open-rs", # 7k
"curated_open_rs1": "knoveleng/open-s1", # 18.6k
# Extra datasets
"curated_limr": "GAIR/LIMR", # 1.39k
"curated_open_r1": "open-r1/OpenR1-Math-220k", # default split 93.7k
"curated_thoughts": "bethgelab/CuratedThoughts", # default split 66.1k
# Ablation
"curated_limr_large_lr_ablation": "GAIR/LIMR",
"curated_limr_small_lr_ablation": "GAIR/LIMR",
"curated_limr_large_rank_ablation": "GAIR/LIMR",
"curated_limr_medium_rank_ablation": "GAIR/LIMR",
"curated_limr_small_rank_ablation": "GAIR/LIMR",
"curated_limr_tiny_rank_ablation": "GAIR/LIMR",
"curated_open_rs3_drgrpo_ablation": "knoveleng/open-rs",
# Reasoning gym
"curated_rg_math": "starzmustdie/rg-math",
}
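# Illustrative usage sketch (the split name is an assumption): resolving a curated key to
# its Hub repo id and loading it with the `datasets` library.
from datasets import load_dataset
limr = load_dataset(RL_POST_TRAIN_DATASET_MAP["curated_limr"], split="train")  # "GAIR/LIMR"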

View File

@@ -1,23 +0,0 @@
# borrowed from https://github.com/huggingface/open-r1/blob/main/recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml
OPEN_R1_SYSTEM_PROMPT = """
You are a helpful AI Assistant that provides well-reasoned and detailed responses.
You first think about the reasoning process as an internal monologue and then provide the user with the answer.
Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>
"""
# borrowed from https://github.com/knoveleng/open-rs/blob/main/recipes/grpo.yaml
OPEN_RS_SYSTEM_PROMPT = """
A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user with the answer, and put your final answer within \\boxed{{}} .
The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively,
i.e., <think> reasoning process here </think> <answer> answer here </answer>.
Note that respond by English, NOT use other languages.
"""
# the first question from aime 2024
FIXED_PROMPT_FOR_EVALUATION = """
Every morning Aya goes for a $9$-kilometer-long walk and stops at a coffee shop afterwards.
When she walks at a constant speed of $s$ kilometers per hour, the walk takes her 4 hours, including $t$ minutes spent in the coffee shop.
When she walks $s+2$ kilometers per hour, the walk takes her 2 hours and 24 minutes, including $t$ minutes spent in the coffee shop.
Suppose Aya walks at $s+\frac{1}{2}$ kilometers per hour.
Find the number of minutes the walk takes her, including the $t$ minutes spent in the coffee shop."""