Feat/curr adj (#394)

This commit is contained in:
joesharratt1229
2025-04-02 06:39:14 +01:00
committed by GitHub
parent 2c52f33c3a
commit 43c739cb3e
26 changed files with 152390 additions and 453 deletions

View File

@@ -86,7 +86,7 @@ class LetterCountingCurriculum(BaseCurriculum):
self._define_attributes(
RangeAttributeDefinition(
name="words",
levels=[10, 50, 100, 1000],
levels=list(range(5, 20, 2)),
description="Number of words in the span",
lower_field_name="min_words",
upper_field_name="max_words",

View File

@@ -5,6 +5,8 @@ from dataclasses import dataclass
from random import Random
from typing import Any, Optional
import numpy as np
from ..coaching import BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset
@@ -44,12 +46,6 @@ Please follow the instruction below:
## 2. Convert all numbers in the square brackets as strings. For example, ['-69', '-13', '1', '7', '11', '43', '59', '61']
"""
def _format_number(self, num: float, decimals: int) -> str:
"""Format number with specified decimal places"""
formatted = f"{num:.{decimals}f}"
# Reparse to ensure exact decimal representation
return f"{float(formatted):.{decimals}f}"
def _generate_numbers(self, rng: Random, count: int) -> tuple[list[float], list[str]]:
"""Generate list of numbers and their string representations"""
numbers = []
@@ -58,11 +54,9 @@ Please follow the instruction below:
for _ in range(count):
num = rng.uniform(self.config.min_value, self.config.max_value)
decimals = rng.randint(self.config.min_decimals, self.config.max_decimals)
num_str = self._format_number(num, decimals)
# Reparse to ensure exact value
num = float(num_str)
num = np.round(num, decimals)
numbers.append(num)
number_strs.append(num_str)
number_strs.append(str(num))
return numbers, number_strs
@@ -78,9 +72,8 @@ Please follow the instruction below:
desc_numbers = sorted(numbers, reverse=True)
# Format answers as string lists
decimals = len(number_strs[0].split(".")[-1]) if "." in number_strs[0] else 0
asc_answer = [self._format_number(n, decimals) for n in asc_numbers]
desc_answer = [self._format_number(n, decimals) for n in desc_numbers]
asc_answer = [str(n) for n in asc_numbers]
desc_answer = [str(n) for n in desc_numbers]
# Randomly choose ascending or descending
is_ascending = rng.choice([True, False])
@@ -158,7 +151,7 @@ Please follow the instruction below:
return 0.0
# Check if the values are close enough (allowing for small rounding differences)
tolerance = 0.1 # Increased tolerance to handle decimal differences
tolerance = 1 # Increased tolerance to handle decimal differences
for i in range(len(user_floats)):
if abs(user_floats[i] - expected_floats[i]) > tolerance:
return 0.0
@@ -177,7 +170,7 @@ class NumberSortingCurriculum(BaseCurriculum):
self._define_attributes(
RangeAttributeDefinition(
name="numbers",
levels=[10, 100, 500, 1000],
levels=list(range(5, 20, 2)),
description="How many numbers to sort",
lower_field_name="min_numbers",
upper_field_name="max_numbers",
@@ -185,7 +178,7 @@ class NumberSortingCurriculum(BaseCurriculum):
),
RangeAttributeDefinition(
name="decimals",
levels=[0, 2, 4, 6],
levels=list(range(0, 8)),
description="Number of decimal places",
lower_field_name="min_decimals",
upper_field_name="max_decimals",

View File

@@ -17,8 +17,9 @@ class SpellBackwardConfig:
"""Configuration for spelling words backward task generation"""
min_word_len: int = 3 # Minimum word length
max_word_len: int = 20 # Maximum word length
max_word_len: int = 10 # Maximum word length
seed: Optional[int] = None
data_file: str = "words3to10.txt"
size: int = 500 # Virtual dataset size
def validate(self) -> None:
@@ -34,12 +35,11 @@ class SpellBackwardDataset(ProceduralDataset):
super().__init__(config=config, seed=config.seed, size=config.size)
# Load and preprocess text
text = read_data_file("in_the_year_2889.txt")
# Extract words and clean them to contain only alphanumeric characters
text = read_data_file(self.config.data_file)
self.words = [
word
for word in re.findall(r"\b\w+\b", text)
if word.isalnum() and config.min_word_len <= len(word) <= config.max_word_len
word.strip()
for word in text.splitlines()
if word.strip().isalnum() and config.min_word_len <= len(word.strip()) <= config.max_word_len
]
def __getitem__(self, idx: int) -> dict:
@@ -69,10 +69,22 @@ class SpellBackwardDataset(ProceduralDataset):
expected_answer = entry["answer"]
if isinstance(answer, str):
try:
if expected_answer.lower() == answer.lower():
reward = 1.0
expected_answer = expected_answer.lower()
answer = answer.lower()
if expected_answer == answer:
return 1.0
else:
reward = 0.05
answer_len = len(expected_answer)
for i in range(len(expected_answer)):
if i < len(expected_answer) and i < len(answer):
if expected_answer[i] == answer[i]:
reward += 1 / answer_len
else:
continue
else:
break
if reward == 1.0:
reward -= 0.2
except:
reward = 0.0
return reward
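For a concrete feel of the new per-character partial credit, here is a small illustrative sketch with made-up strings; it restates the loop above in simplified form, and the 0.05 base plus 1/answer_len increments come straight from the code in this hunk:

```python
# Illustrative sketch (invented strings): "hello" spelled backward is "olleh".
expected_answer = "olleh"
answer = "olleb"  # last character wrong

reward = 0.05
answer_len = len(expected_answer)
for i in range(min(len(expected_answer), len(answer))):
    if expected_answer[i] == answer[i]:
        reward += 1 / answer_len

print(round(reward, 2))  # 0.85 -> 0.05 base plus 4 of 5 matching positions
```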
@@ -86,11 +98,11 @@ class SpellBackwardCurriculum(BaseCurriculum):
self._define_attributes(
RangeAttributeDefinition(
name="word_len",
levels=[5, 10, 20, 30],
levels=list(range(3, 11, 1)),
description="Word length",
lower_field_name="min_word_len",
upper_field_name="max_word_len",
ensure_interval=True,
ensure_interval=False,
),
)

View File

@@ -125,14 +125,25 @@ class WordSortingDataset(ProceduralDataset):
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
oracle_answer = entry["metadata"]["sorted_words"]
if answer is not None and len(answer) > 0:
parsed_answer = [word.strip() for word in re.split(r",\s*", answer)]
if parsed_answer == oracle_answer:
return 1.0
elif sorted(parsed_answer) == oracle_answer:
return 0.2
return 0.0
if not answer:
return 0.0
parsed_answer = [word.strip() for word in re.split(r",\s*", answer)]
if parsed_answer == oracle_answer:
return 1.0
correct_positions = sum(
1 for i, word in enumerate(parsed_answer) if i < len(oracle_answer) and word == oracle_answer[i]
)
partial_score = correct_positions / len(oracle_answer)
if sorted(parsed_answer) == sorted(oracle_answer):
partial_score = max(partial_score, 0.2)
return partial_score
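As a quick illustration of the positional partial credit introduced here, a minimal self-contained sketch with an invented word list (the comma split mirrors the parser in `score_answer`):

```python
import re

oracle_answer = ["apple", "banana", "cherry", "date"]  # invented example
answer = "apple, cherry, banana, date"                 # two words swapped

parsed_answer = [word.strip() for word in re.split(r",\s*", answer)]
correct_positions = sum(
    1 for i, word in enumerate(parsed_answer) if i < len(oracle_answer) and word == oracle_answer[i]
)
partial_score = correct_positions / len(oracle_answer)  # 2 / 4 = 0.5
if sorted(parsed_answer) == sorted(oracle_answer):
    partial_score = max(partial_score, 0.2)  # same words in the wrong order gets at least 0.2
print(partial_score)  # 0.5
```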
class WordSortingCurriculum(BaseCurriculum):

View File

@@ -239,3 +239,16 @@ class BaseCurriculum:
self.set_attr_level(attr_name, target_level)
return True
return False
def get_global_level(self) -> dict:
"""Return the curriculum's current global level as a dict of attribute field values."""
attr_dict = {}
if not self._attributes:
return attr_dict
for attr_name in self._attributes:
attr = self.get_attribute(attr_name)
if isinstance(attr, RangeAttributeDefinition):
attr_dict[attr.upper_field_name] = self.get_attr_value(attr_name)
elif isinstance(attr, ScalarAttributeDefinition):
attr_dict[attr.field_name] = self.get_attr_value(attr_name)
return attr_dict

View File

@@ -54,7 +54,6 @@ class CurriculumExperimentConfig:
if not isinstance(data, dict):
raise ValueError("YAML data must contain a dictionary")
if "curricula" not in data:
raise ValueError("YAML data must contain a 'curricula' key")

View File

@@ -1,6 +1,6 @@
"""Experiment class combining dataset, scoreboard and curriculum."""
from typing import Any, Optional
from typing import Any, Literal, Optional
from reasoning_gym.coaching.base_curriculum import CurriculumContext
@@ -27,7 +27,8 @@ class Experiment:
entry = dataset[index]
score = dataset.score_answer(answer, entry)
metadata = entry["metadata"]
self.score_board.add_score(score, metadata, conversation)
score_board_metadata = {"difficulty": metadata["difficulty"], "source_dataset": metadata["source_dataset"]}
self.score_board.add_score(dataset_name, score, score_board_metadata, conversation)
return score
@classmethod
@@ -97,7 +98,15 @@ class CurriculumExperiment(Experiment):
self.curriculum_config = config
self.context = context
def update_difficulty(self):
def update_difficulty(self, dataset_name: str, method: Literal["increment", "decrement"]):
"""Update difficulty levels based on performance metrics"""
# TODO: Implement difficulty adjustment logic
pass
if method not in ["increment", "decrement"]:
raise ValueError(f"Invalid method: {method}")
if method == "increment":
self.curricula[dataset_name].increment_global_level()
elif method == "decrement":
self.curricula[dataset_name].decrement_global_level()
config = self.curricula[dataset_name].get_global_level()
self.composite.update_dataset_config(dataset_name, config)
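A minimal usage sketch of the new signature, assuming `experiment` is an already-constructed `CurriculumExperiment` with a `spell_backward` curriculum (the trainer further down in this diff calls it the same way):

```python
# Sketch only: the experiment instance and dataset name are assumptions.
experiment.update_difficulty("spell_backward", method="increment")  # push a harder config to the composite
experiment.update_difficulty("spell_backward", method="decrement")  # step back down
# Any other method string raises ValueError.
```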

View File

@@ -114,11 +114,13 @@ class GroupedScores:
class ScoreBoard:
"""Tracks scores and metadata for coaching sessions"""
scores: list[float] = field(default_factory=list)
metadata: list[dict[str, Any]] = field(default_factory=list)
conversations: list[Optional[list[dict]]] = field(default_factory=list)
scores: dict[str, list[float]] = field(default_factory=dict)
metadata: dict[str, list[dict[str, Any]]] = field(default_factory=dict)
conversations: dict[str, list[Optional[list[dict]]]] = field(default_factory=dict)
def add_score(self, score: float, metadata: dict[str, Any], conversation: Optional[list[dict]] = None) -> None:
def add_score(
self, dataset_name: str, score: float, metadata: dict[str, Any], conversation: Optional[list[dict]] = None
) -> None:
"""Add a new score entry with associated metadata and optional conversation
Args:
@@ -126,15 +128,19 @@ class ScoreBoard:
metadata: Dictionary of metadata about the task/attempt
conversation: Optional list of conversation turns as dicts
"""
self.scores.append(score)
self.metadata.append(metadata)
self.conversations.append(conversation)
if dataset_name not in self.scores:
self.scores[dataset_name] = []
self.metadata[dataset_name] = []
self.conversations[dataset_name] = []
self.scores[dataset_name].append(score)
self.metadata[dataset_name].append(metadata)
self.conversations[dataset_name].append(conversation)
def clear(self) -> None:
def clear(self, dataset_name: str) -> None:
"""Clear all stored scores, metadata and conversations"""
self.scores.clear()
self.metadata.clear()
self.conversations.clear()
self.scores[dataset_name] = []
self.metadata[dataset_name] = []
self.conversations[dataset_name] = []
def __len__(self) -> int:
"""Return the number of stored scores"""
@@ -147,7 +153,7 @@ class ScoreBoard:
placed first in the tuple as ("source", dataset) and ("idx", index).
"""
# Start with empty list
key_items = [("source", metadata["source_dataset"]), ("idx", metadata["source_index"])]
key_items = [("source", metadata["source_dataset"])]
# Add difficulty parameters or other metadata
if "difficulty" in metadata:
@@ -155,39 +161,52 @@ class ScoreBoard:
items = metadata["difficulty"].items()
else:
# Use all metadata except source info
items = ((k, v) for k, v in metadata.items() if k not in ("source_dataset", "source_index"))
items = ((k, v) for k, v in metadata.items() if k not in ("source_dataset"))
# Add remaining items in sorted order
key_items.extend(sorted((str(k), v) for k, v in items))
return tuple(key_items)
def aggregate(self, last_n: Optional[int] = None) -> GroupedScores:
"""Aggregate scores by difficulty parameters or full metadata if no difficulty present
def aggregate(self, last_n: Optional[int] = None) -> dict[str, GroupedScores]:
"""Aggregate scores by dataset name and then by difficulty parameters
Args:
last_n: Optional number of most recent entries to consider
If None, use all entries
If None, use all entries
Returns:
OrderedDict mapping difficulty parameter combinations to lists of scores
Keys are tuples of (param_name, value) pairs, sorted by param_name
Dictionary mapping dataset names to their respective GroupedScores objects
Each GroupedScores contains scores grouped by difficulty parameters for that dataset
"""
if not self.scores:
return GroupedScores(scores=OrderedDict(), total_scores=0)
return {}
# Determine start index for iteration
start_idx = max(0, len(self.scores) - last_n) if last_n is not None else 0
# Create a nested structure: dataset -> parameter groups -> scores
result = {}
# Group scores by difficulty parameters without creating intermediate lists
result = OrderedDict()
for i in range(start_idx, len(self.scores)):
key = self._metadata_to_key(self.metadata[i])
if key not in result:
result[key] = []
result[key].append(self.scores[i])
# Process each dataset
for dataset_name, dataset_scores in self.scores.items():
# Determine start index for this dataset
dataset_len = len(dataset_scores)
start_idx = max(0, dataset_len - last_n) if last_n is not None else 0
# Count total scores
total_scores = sum(len(scores) for scores in result.values())
# Create OrderedDict for this dataset's parameter groupings
dataset_groups = OrderedDict()
return GroupedScores(scores=result, total_scores=total_scores)
# Process scores for this dataset
for i in range(start_idx, dataset_len):
# Get metadata for this score
metadata = self.metadata[dataset_name][i]
params = self._metadata_to_key(metadata)
if params not in dataset_groups:
dataset_groups[params] = []
dataset_groups[params].append(dataset_scores[i])
# Create a GroupedScores object for this dataset
total_scores = sum(len(scores) for scores in dataset_groups.values())
result[dataset_name] = GroupedScores(scores=dataset_groups, total_scores=total_scores)
return result
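For reference, a rough sketch of consuming the new per-dataset return shape; `score_board` stands in for an existing `ScoreBoard` instance, and the 0.70 / 20 thresholds echo the `success_threshold` and `last_k` values in the training config later in this diff:

```python
# Sketch: aggregate() now returns {dataset_name: GroupedScores}.
per_dataset = score_board.aggregate(last_n=20)
for dataset_name, grouped in per_dataset.items():
    # grouped.scores maps difficulty-parameter tuples to lists of raw scores
    all_scores = [s for group in grouped.scores.values() for s in group]
    mean_score = sum(all_scores) / len(all_scores) if all_scores else 0.0
    if mean_score > 0.70 and grouped.total_scores >= 20:
        print(f"{dataset_name}: candidate for a difficulty increment ({mean_score:.2f})")
```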

View File

@@ -47,6 +47,14 @@ class CompositeConfig:
for ds in self.datasets:
ds.validate()
def get_dataset_weight(self, dataset_name: str) -> float:
"""Get the weight for a specific dataset by name."""
for ds in self.datasets:
if ds.name == dataset_name:
return ds.weight
raise ValueError(f"Dataset '{dataset_name}' not found in composite configuration")
@classmethod
def from_yaml_stream(cls, stream) -> "CompositeConfig":
"""Load configuration from a YAML stream

File diff suppressed because it is too large

View File

@@ -90,14 +90,14 @@ def test_letter_counting_curriculum():
base_cfg: LetterCountingConfig = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 1
assert base_cfg.size == 150
assert base_cfg.min_words == 10 and base_cfg.max_words == 50
assert base_cfg.min_words == 5 and base_cfg.max_words == 7
# test incrementing attribute levels
curriculum.increment_attr_level("words")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.min_words == 10 and increased_cfg.max_words == 100
assert increased_cfg.min_words == 5 and increased_cfg.max_words == 9
# test decrementing attribute level for words again
curriculum.decrement_attr_level("words")
partially_decreased_cfg = curriculum.generate_configuration(base_value)
assert partially_decreased_cfg.min_words == 10 and partially_decreased_cfg.max_words == 50
assert partially_decreased_cfg.min_words == 5 and partially_decreased_cfg.max_words == 7

View File

@@ -99,23 +99,23 @@ def test_number_sorting_curriculum():
base_cfg: NumberSortingConfig = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 1
assert base_cfg.size == 150
assert base_cfg.min_numbers == 10 and base_cfg.max_numbers == 100
assert base_cfg.min_decimals == 0 and base_cfg.max_decimals == 2
assert base_cfg.min_numbers == 5 and base_cfg.max_numbers == 7
assert base_cfg.min_decimals == 0 and base_cfg.max_decimals == 1
assert base_cfg.min_value == -10_000 and base_cfg.max_value == 10_000
# test incrementing some attribute levels
curriculum.increment_attr_level("numbers")
curriculum.increment_attr_level("decimals")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.min_numbers == 10 and increased_cfg.max_numbers == 500
assert increased_cfg.min_decimals == 0 and increased_cfg.max_decimals == 4
assert increased_cfg.min_numbers == 5 and increased_cfg.max_numbers == 9
assert increased_cfg.min_decimals == 0 and increased_cfg.max_decimals == 2
assert increased_cfg.min_value == -10_000 and increased_cfg.max_value == 10_000
# test decrementing attribute level for numbers again
curriculum.decrement_attr_level("numbers")
partially_decreased_cfg = curriculum.generate_configuration(base_value)
assert partially_decreased_cfg.min_numbers == 10 and partially_decreased_cfg.max_numbers == 100
assert partially_decreased_cfg.min_decimals == 0 and partially_decreased_cfg.max_decimals == 4
assert partially_decreased_cfg.min_numbers == 5 and partially_decreased_cfg.max_numbers == 7
assert partially_decreased_cfg.min_decimals == 0 and partially_decreased_cfg.max_decimals == 2
assert partially_decreased_cfg.min_value == -10_000 and partially_decreased_cfg.max_value == 10_000

View File

@@ -61,32 +61,32 @@ def test_score_aggregation():
aggregated = experiment.score_board.aggregate()
# Verify we have scores grouped by difficulty parameters
assert len(aggregated.scores) > 0
assert len(aggregated["leg_counting"].scores.keys()) > 0
# Each key should be a tuple of tuples containing difficulty parameters
for key in aggregated.scores:
for key in aggregated["leg_counting"].scores:
assert isinstance(key, tuple)
# Each inner tuple should be (param_name, value) or (param_name, (min_value, max_value))
for param in key:
assert isinstance(param, tuple)
assert param[0] in ("source", "idx", "num_animals", "num_instances")
assert param[0] in ("source", "num_animals", "num_instances")
# Test aggregation with last_n
last_3 = experiment.score_board.aggregate(last_n=3)
assert len(last_3.scores) > 0
assert len(last_3["leg_counting"].scores) > 0
# Verify total scores count
assert last_3.total_scores == 3
assert last_3["leg_counting"].total_scores == 3
# Verify conversation tracking
assert len(experiment.score_board.conversations) == 5
for conv in experiment.score_board.conversations:
assert len(experiment.score_board.conversations["leg_counting"]) == 5
for conv in experiment.score_board.conversations["leg_counting"]:
assert len(conv) == 2 # user question and assistant response
assert conv[0]["role"] == "user"
assert conv[1]["role"] == "assistant"
# Test stats calculation
stats = aggregated.stats()
stats = aggregated["leg_counting"].stats()
for key, values in stats.scores.items():
assert isinstance(values, tuple)
@@ -107,11 +107,11 @@ def test_score_aggregation():
assert all(math.isnan(v) for v in stats_tuple[1:]) # stats should be NaN
# Test clear functionality
experiment.score_board.clear()
assert len(experiment.score_board.scores) == 0
assert len(experiment.score_board.metadata) == 0
assert len(experiment.score_board.conversations) == 0
assert len(experiment.score_board.aggregate().scores) == 0
experiment.score_board.clear("leg_counting")
assert len(experiment.score_board.scores["leg_counting"]) == 0
assert len(experiment.score_board.metadata["leg_counting"]) == 0
assert len(experiment.score_board.conversations["leg_counting"]) == 0
assert len(experiment.score_board.aggregate()["leg_counting"].scores) == 0
def test_experiment_with_composite():
@@ -147,15 +147,14 @@ def test_experiment_with_composite():
# Test aggregation
aggregated = experiment.score_board.aggregate()
assert len(aggregated.scores) > 0
assert len(aggregated["leg_counting"].scores) > 0
# Verify source dataset info is first in keys
for key in aggregated.scores:
for key in aggregated["leg_counting"].scores:
assert key[0][0] == "source" # First tuple should be ("source", dataset_name)
assert key[1][0] == "idx" # Second tuple should be ("idx", index)
# Test stats
stats = aggregated.stats()
stats = aggregated["leg_counting"].stats()
for key, values in stats.scores.items():
assert isinstance(values, tuple)
assert len(values) == 5 # (count, mean, std, min, max)

View File

@@ -71,14 +71,14 @@ def test_spell_backward_curriculum():
base_cfg: SpellBackwardConfig = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 1
assert base_cfg.size == 150
assert base_cfg.min_word_len == 5 and base_cfg.max_word_len == 10
assert base_cfg.min_word_len == 3 and base_cfg.max_word_len == 3
# test incrementing attribute levels
curriculum.increment_attr_level("word_len")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.min_word_len == 5 and increased_cfg.max_word_len == 20
assert increased_cfg.min_word_len == 3 and increased_cfg.max_word_len == 4
# test decrementing attribute levels
curriculum.decrement_attr_level("word_len")
partially_decreased_cfg = curriculum.generate_configuration(base_value)
assert partially_decreased_cfg.min_word_len == 5 and partially_decreased_cfg.max_word_len == 10
assert partially_decreased_cfg.min_word_len == 3 and partially_decreased_cfg.max_word_len == 3

View File

@@ -17,21 +17,22 @@ pip install -e .
```bash
pip install ray wandb
pip install torch==2.6.0
pip install flash-attn --no-build-isolation
```
4. Install veRL (tested with HEAD c34206925e2a50fd452e474db857b4d488f8602d):
```bash
git clone https://github.com/volcengine/verl.git
cd verl
pip install -e .
pip install git+https://github.com/volcengine/verl.git@c6dc8b73cf011aa75b8c6a47b0322f50aed800ad#egg=verl
```
5. Install vLLM:
```bash
pip install -U vllm --pre --extra-index-url https://wheels.vllm.ai/nightly
pip install vllm==0.6.3 transformers==4.50.3 fire==0.7.0
```
6. Install flash attention:
```bash
pip install flash-attn --no-build-isolation
```
7. Log in to HF and W&B:
@@ -64,3 +65,33 @@ CUDA_VISIBLE_DEVICES=0,1 bash train.sh
CUDA_VISIBLE_DEVICES is set to 0,1 to use the first two GPUs on the machine (see `nvidia-smi` output). This can be adjusted as needed. `tensor_model_parallel_size` and `n_gpus_per_node` should also be set to the number of GPUs you are using.
You can change all configuration options by either modifying the config YAML (in this case, `config/llama3.1_1b_grpo.yaml`) or providing them as arguments to the Python script. Note that the batch sizes set in the Llama 1B and Qwen 1.5B configs are as high as it was possible for me to set them for the puzzles dataset mix on 2xA6000 GPUs without OOMs. Depending on the hardware you use and the datasets you train on, you may need to adjust these.
# Exporting from FSDP checkpoint to HF model checkpoint
After training, the model weights are saved as a sharded checkpoint split across several files. To facilitate simple evaluation of your trained model, you may want to convert this into a HF model checkpoint. We have added a utility script to convert your sharded checkpoint into a HF checkpoint.
To run this script, navigate to the training directory and run the following:
```bash
python load_fsdp_to_hf.py /path/to/fsdp/checkpoint/global_step_num/actor /path/to/huggingface/checkpoint/global_step_num/actor/huggingface saved_model_name
```
For example:
```bash
python utils/load_fsdp_to_hf.py checkpoints/rg-test/intra_reasoning_algorithmic_qwen_3b_composite/global_step_400/actor/ checkpoints/rg-test/intra_reasoning_algorithmic_qwen_3b_composite/global_step_400/actor/huggingface qwen3b
```
# Run evaluations
From here you may want to run evaluations of your trained model. The `training/evaluation` directory contains a script, `evaluate_model.py`, which you can run to evaluate your trained model on a specific dataset. Evaluation parameters are specified in a YAML file, and the evaluation can point to either a local or a remote model. For example, the configuration file `training/evaluation/eval_algorithmic_composite.yaml` points to a local model stored as a Hugging Face checkpoint at `training/utils/qwen3b_500` (note that you have to convert the FSDP checkpoint to a HF checkpoint, as shown in the previous step, for the evaluation script to work).
## Run the script
Navigate to the evaluation directory and run:
```bash
python evaluate_model.py --config path-to-yaml
```
For example:
```bash
python evaluate_model.py --config eval_algorithmic_composite.yaml
```

View File

@@ -1,44 +1,48 @@
reasoning_gym:
dataset_size: 10000
dataset_size: 20000
developer_prompt: DeepSeekZero
enable_curriculum_learning: False
datasets: # Used if enable_curriculum_learning is False
mini_sudoku:
weight: 0.33
config:
min_empty: 6
futoshiki:
weight: 0.33
config:
max_board_size: 5
sudoku:
weight: 0.34
config:
min_empty: 20
curricula:
leg_counting:
attribute_levels:
num_animals: 2
weight: 1.0
products:
attribute_levels:
num_terms: 4
num_digits: 4
weight: 1.0
chain_sum:
attribute_levels:
num_terms: 4
num_digits: 4
weight: 1.0
datasets:
spell_backward:
weight: 0.33
config:
min_word_len: 3
max_word_len: 10
letter_jumble:
weight: 0.34
config:
min_word_len: 1 # Minimum word length
max_word_len: 50 # Maximum word length
min_words: 3 # Minimum words per task
max_words: 40
word_sorting:
weight: 0.33
config:
min_words: 3
max_words: 10
min_word_length: 3
max_word_length: 12
curriculum:
enabled: False
schedule:
automatic: True
update_steps: 30 # automatic curriculum update after every 30 steps
last_k: 20
success_threshold: 0.70
failure_threshold: 0.10
curricula:
spell_backward:
attribute_levels:
word_len: 0
reward:
format_reward:
enable: True
scaling_factor: 0.2
prepend_think_token: False # Set to True only when the tokenizer's prompt template pre-fills the generation with <think>, such as in the case of (distilled) r1 models
length_reward:
enable: True
scaling_factor: 0.2
use_accuracy: True
secondary_rewards:
- name: cosine
scaling_factor: 0.3
- name: format
scaling_factor: 0.2
kwargs:
preappend_thinking_token: False
data:
tokenizer: null
@@ -47,24 +51,23 @@ data:
prompt_key: prompt
max_prompt_length: 512
max_response_length: 1024
train_batch_size: 64
train_batch_size: 32
val_batch_size: 64
return_raw_input_ids: True # This should be set to true when the tokenizer between policy and rm differs
return_raw_chat: True
return_raw_input_ids: True
actor_rollout_ref:
hybrid_engine: True
model:
path: meta-llama/Llama-3.2-1B-Instruct
path: Qwen/Qwen2.5-3B-Instruct
external_lib: null
override_config: { }
enable_gradient_checkpointing: True
use_remove_padding: True
actor:
strategy: fsdp # This is for backward-compatibility
ppo_mini_batch_size: 32
ppo_mini_batch_size: 16
ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
ppo_micro_batch_size_per_gpu: 16
ppo_micro_batch_size_per_gpu: 4
use_dynamic_bsz: False
ppo_max_token_len_per_gpu: 12288 # n * ${data.max_prompt_length} + ${data.max_response_length}
grad_clip: 1.0
@@ -76,14 +79,12 @@ actor_rollout_ref:
ppo_epochs: 1
shuffle: False
ulysses_sequence_parallel_size: 1 # sp size
checkpoint:
contents: ['model', 'hf_model', 'optimizer', 'extra']
optim:
lr: 1e-6
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
min_lr_ratio: null # only useful for warmup with cosine
warmup_style: constant # select from constant/cosine
total_training_steps: -1 # must be override by program
total_training_steps: 500 # must be override by program
fsdp_config:
wrap_policy:
# transformer_layer_cls_to_wrap: None
@@ -111,13 +112,13 @@ actor_rollout_ref:
response_length: ${data.max_response_length}
# for vllm rollout
dtype: bfloat16 # should align with FSDP
gpu_memory_utilization: 0.6
gpu_memory_utilization: 0.7
ignore_eos: False
enforce_eager: True
free_cache_engine: True
load_format: dummy_dtensor
tensor_model_parallel_size: 2
max_num_batched_tokens: 8192
tensor_model_parallel_size: 4
max_num_batched_tokens: 12288
max_num_seqs: 1024
log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
log_prob_micro_batch_size_per_gpu: 160
@@ -128,6 +129,7 @@ actor_rollout_ref:
# for hf rollout
do_sample: True
use_fire_sampling: False
max_model_len: 12288
# number of responses (i.e. num sample times)
n: 8 # > 1 for grpo
val_kwargs:
@@ -141,17 +143,17 @@ algorithm:
kl_ctrl:
type: fixed
kl_coef: 0.001
verbose: True
trainer:
balance_batch: True
total_epochs: 10
total_training_steps: null
total_epochs: 1
total_training_steps: 500
project_name: rg-test
experiment_name: verl_grpo_llama3.1_1b
experiment_name: intra_reasoning_algorithmic_qwen_3b_composite
logger: [ 'console', 'wandb' ]
val_generations_to_log_to_wandb: 0
nnodes: 1
n_gpus_per_node: 2
n_gpus_per_node: 4
save_freq: 100
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # or auto or resume_path if
@@ -163,6 +165,7 @@ trainer:
del_local_ckpt_after_load: False
default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
critic:
strategy: fsdp
optim:

View File

@@ -1,221 +0,0 @@
reasoning_gym:
dataset_size: 10000
developer_prompt: DeepSeekZero
enable_curriculum_learning: False
datasets: # Used if enable_curriculum_learning is False
mini_sudoku:
weight: 0.33
config:
min_empty: 6
futoshiki:
weight: 0.33
config:
max_board_size: 5
sudoku:
weight: 0.34
config:
min_empty: 20
curricula:
leg_counting:
attribute_levels:
num_animals: 2
weight: 1.0
products:
attribute_levels:
num_terms: 4
num_digits: 4
weight: 1.0
chain_sum:
attribute_levels:
num_terms: 4
num_digits: 4
weight: 1.0
reward:
format_reward:
enable: True
scaling_factor: 0.2
prepend_think_token: False # Set to True only when the tokenizer's prompt template pre-fills the generation with <think>, such as in the case of (distilled) r1 models
length_reward:
enable: True
scaling_factor: 0.2
data:
tokenizer: null
train_files: train.parquet
val_files: test.parquet
prompt_key: prompt
max_prompt_length: 512
max_response_length: 1024
train_batch_size: 16
val_batch_size: 16
return_raw_input_ids: True # This should be set to true when the tokenizer between policy and rm differs
return_raw_chat: True
actor_rollout_ref:
hybrid_engine: True
model:
path: Qwen/Qwen2.5-1.5B-Instruct
external_lib: null
override_config: { }
enable_gradient_checkpointing: True
use_remove_padding: True
actor:
strategy: fsdp # This is for backward-compatibility
ppo_mini_batch_size: 16
ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
ppo_micro_batch_size_per_gpu: 8
use_dynamic_bsz: False
ppo_max_token_len_per_gpu: 12288 # n * ${data.max_prompt_length} + ${data.max_response_length}
grad_clip: 1.0
clip_ratio: 0.2
entropy_coeff: 0.001
use_kl_loss: True # True for GRPO
kl_loss_coef: 0.001 # for grpo
kl_loss_type: low_var_kl # for grpo
ppo_epochs: 1
shuffle: False
ulysses_sequence_parallel_size: 1 # sp size
checkpoint:
contents: ['model', 'hf_model', 'optimizer', 'extra']
optim:
lr: 1e-6
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
min_lr_ratio: null # only useful for warmup with cosine
warmup_style: constant # select from constant/cosine
total_training_steps: -1 # must be override by program
fsdp_config:
wrap_policy:
# transformer_layer_cls_to_wrap: None
min_num_params: 0
param_offload: False
optimizer_offload: False
fsdp_size: -1
ref:
fsdp_config:
param_offload: True
wrap_policy:
# transformer_layer_cls_to_wrap: None
min_num_params: 0
log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
log_prob_micro_batch_size_per_gpu: 160
log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
rollout:
name: vllm
temperature: 1.0
top_k: -1 # 0 for hf rollout, -1 for vllm rollout
top_p: 1
prompt_length: ${data.max_prompt_length} # not use for opensource
response_length: ${data.max_response_length}
# for vllm rollout
dtype: bfloat16 # should align with FSDP
gpu_memory_utilization: 0.6
ignore_eos: False
enforce_eager: True
free_cache_engine: True
load_format: dummy_dtensor
tensor_model_parallel_size: 2
max_num_batched_tokens: 8192
max_num_seqs: 1024
max_model_len: 1024
log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
log_prob_micro_batch_size_per_gpu: 160
log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
disable_log_stats: True
enable_chunked_prefill: True # could get higher throughput
# for hf rollout
do_sample: True
use_fire_sampling: False
# number of responses (i.e. num sample times)
n: 8 # > 1 for grpo
val_kwargs:
do_sample: True
algorithm:
gamma: 1.0
lam: 1.0
adv_estimator: grpo
kl_penalty: kl # how to estimate kl divergence
kl_ctrl:
type: fixed
kl_coef: 0.001
trainer:
balance_batch: True
total_epochs: 10
total_training_steps: null
project_name: rg-test
experiment_name: verl_grpo_qwen2.5_1.5b
logger: [ 'console', 'wandb' ]
val_generations_to_log_to_wandb: 0
nnodes: 1
n_gpus_per_node: 2
save_freq: 100
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # or auto or resume_path if
resume_from_path: False
test_freq: 100
critic_warmup: 0
default_hdfs_dir: null
remove_previous_ckpt_in_save: False
del_local_ckpt_after_load: False
default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
critic:
strategy: fsdp
optim:
lr: 1e-5
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
min_lr_ratio: null # only useful for warmup with cosine
warmup_style: constant # select from constant/cosine
total_training_steps: -1 # must be override by program
model:
path: ~/models/deepseek-llm-7b-chat
tokenizer_path: ${actor_rollout_ref.model.path}
override_config: { }
external_lib: ${actor_rollout_ref.model.external_lib}
enable_gradient_checkpointing: True
use_remove_padding: False
fsdp_config:
param_offload: False
optimizer_offload: False
wrap_policy:
# transformer_layer_cls_to_wrap: None
min_num_params: 0
fsdp_size: -1
ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
ppo_micro_batch_size_per_gpu: null
forward_micro_batch_size: ${critic.ppo_micro_batch_size}
forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu}
use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2
forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
ulysses_sequence_parallel_size: 1 # sp size
ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
shuffle: ${actor_rollout_ref.actor.shuffle}
grad_clip: 1.0
cliprange_value: 0.5
# Reward model not used for GRPO
reward_model:
enable: False
strategy: fsdp
model:
input_tokenizer: ${actor_rollout_ref.model.path}
path: ~/models/FsfairX-LLaMA3-RM-v0.1
external_lib: ${actor_rollout_ref.model.external_lib}
use_remove_padding: False
fsdp_config:
min_num_params: 0
param_offload: False
fsdp_size: -1
micro_batch_size: null
micro_batch_size_per_gpu: null
max_length: null
ulysses_sequence_parallel_size: 1
use_dynamic_bsz: ${critic.use_dynamic_bsz}
forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}

View File

@@ -0,0 +1,28 @@
# Model configuration
model_path: ../utils/qwen3b_500 # Change to the smaller model
max_tokens: 1024 # From max_response_length in training config
temperature: 0.7 # Lower temperature for more focused responses
top_p: 0.9 # From rollout top_p
developer_prompt: DeepSeekZero
developer_role: system # Standard role for system prompts
# Output configuration
output_dir: results
save_metadata: true
save_full_results: true
eval_repeats: 3
# Categories and datasets to evaluate
categories:
- category: reasoning
datasets:
- dataset: number_sorting
size: 100
seed: 42
params:
min_numbers: 3
max_numbers: 10
min_decimals: 0
max_decimals: 2
min_value: -100.0
max_value: 100.0

View File

@@ -0,0 +1,28 @@
# Model configuration
model_path: Qwen/Qwen2.5-3B-Instruct # Change to the smaller model
max_tokens: 1024 # From max_response_length in training config
temperature: 0.7 # Lower temperature for more focused responses
top_p: 0.9 # From rollout top_p
developer_prompt: DeepSeekZero
developer_role: system # Standard role for system prompts
# Output configuration
output_dir: results
save_metadata: true
save_full_results: true
eval_repeats: 3
# Categories and datasets to evaluate
categories:
- category: reasoning
datasets:
- dataset: number_sorting
size: 100
seed: 42
params:
min_numbers: 3
max_numbers: 10
min_decimals: 0
max_decimals: 2
min_value: -100.0
max_value: 100.0

View File

@@ -0,0 +1,278 @@
#!/usr/bin/env python
import argparse
import json
import sys
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
import torch
import yaml
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
import reasoning_gym
from reasoning_gym.utils import SYSTEM_PROMPTS, extract_answer
@dataclass
class DatasetConfig:
dataset: str
size: Optional[int] = None
seed: Optional[int] = None
params: Dict[str, Any] = None
def __post_init__(self):
if self.params is None:
self.params = {}
@dataclass
class CategoryConfig:
category: str
datasets: List[DatasetConfig]
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "CategoryConfig":
datasets = [DatasetConfig(**d) for d in data["datasets"]]
return cls(category=data["category"], datasets=datasets)
@dataclass
class EvalConfig:
model_path: str
max_tokens: int
temperature: float
top_p: float
output_dir: str
save_metadata: bool
save_full_results: bool
categories: List[CategoryConfig]
# Optional: you can provide a system prompt name (looked up in SYSTEM_PROMPTS)
developer_prompt: Optional[str] = None
developer_role: str = "system"
# NEW FIELD: How many times each question is evaluated
eval_repeats: int = 1
@classmethod
def from_yaml(cls, path: str) -> "EvalConfig":
with open(path, "r") as f:
data = yaml.safe_load(f)
categories = [CategoryConfig.from_dict(cat) for cat in data["categories"]]
data["categories"] = categories
return cls(**data)
class LocalModelEvaluator:
def __init__(
self,
model_path: str,
config: EvalConfig,
device: str = "cuda:0",
batch_size: int = 1,
verbose: bool = False,
):
self.config = config
self.device = device
self.batch_size = batch_size
self.verbose = verbose
# Load model and tokenizer
self.model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.bfloat16 if "cuda" in device else torch.float32,
)
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.model.to(device)
self.start_time = datetime.now()
# If you have a system prompt, retrieve it from SYSTEM_PROMPTS
self.developer_prompt = None
if self.config.developer_prompt:
self.developer_prompt = SYSTEM_PROMPTS[self.config.developer_prompt]
self.developer_role = self.config.developer_role
def get_model_response(self, question: str) -> str:
"""
Generates a single response to the given question and returns the
raw text of that response.
"""
# Build a "chat" prompt if developer_prompt is available
chat = []
if self.developer_prompt:
chat.append({"role": self.developer_role, "content": self.developer_prompt})
chat.append({"role": "user", "content": question})
# Some Hugging Face chat-friendly models use a convenience method like below:
prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_new_tokens=self.config.max_tokens,
temperature=self.config.temperature,
top_p=self.config.top_p,
do_sample=True if self.config.temperature > 0 else False,
pad_token_id=self.tokenizer.eos_token_id,
)
# Decode the *new* tokens only:
response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1] :], skip_special_tokens=True).strip()
if self.verbose:
print(f"[Prompt]\n{question}\n[Response]\n{response}\n{'-'*60}")
return response
def process_entry(self, dataset, entry: Dict[str, Any]) -> Dict[str, Any]:
"""
Evaluate one question from the dataset `eval_repeats` times, then
average the score. We also keep track of the best (max) score
and store each completion for potential debugging.
"""
all_completions = []
for _ in range(self.config.eval_repeats):
try:
raw_response = self.get_model_response(entry["question"])
model_answer = extract_answer(raw_response)
score = dataset.score_answer(answer=model_answer, entry=entry)
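# Treat any score below 1.0 as incorrect: partial credit is discarded for evaluation.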
score = 0.0 if score < 1 else score
all_completions.append(
{
"model_answer": model_answer,
"full_model_response": raw_response,
"score": score,
}
)
except Exception as e:
# If there's an error on a single repetition, store it and continue
all_completions.append(
{
"model_answer": None,
"full_model_response": "",
"score": 0.0,
"error": str(e),
}
)
# Compute statistics across all runs
scores = [c["score"] for c in all_completions]
mean_score = sum(scores) / len(scores)
best_score = max(scores)
return {
"question": entry["question"],
"expected_answer": str(entry["answer"]),
"best_score": best_score,
"mean_score": mean_score,
"completions": all_completions,
}
def evaluate_dataset(self, category_name: str, dataset_config: DatasetConfig) -> Dict[str, Any]:
"""
Loads the dataset, processes each entry, and then computes
the overall average across all entries.
"""
dataset_name = dataset_config.dataset
params = {
**dataset_config.params,
"size": dataset_config.size,
"seed": dataset_config.seed,
}
dataset = reasoning_gym.create_dataset(dataset_name, **params)
entries = list(dataset)
results = []
for entry in tqdm(entries, desc=f"Processing {dataset_name}"):
results.append(self.process_entry(dataset, entry))
# Summarize the entire dataset
total_mean_score = sum(r["mean_score"] for r in results)
avg_score = total_mean_score / len(results) if results else 0.0
return {
"name": dataset_name,
"category": category_name,
"average_score": avg_score,
"total_examples": len(results),
"config": params,
"results": results,
}
def evaluate_all(self) -> Dict[str, Any]:
"""
Runs evaluation on all categories/datasets.
"""
cat_results = []
for cat in self.config.categories:
datasets = []
for ds_cfg in cat.datasets:
datasets.append(self.evaluate_dataset(cat.category, ds_cfg))
cat_results.append({"name": cat.category, "datasets": datasets})
return {
"metadata": {
"timestamp": self.start_time.isoformat(),
"model": self.config.model_path,
"device": self.device,
"duration_seconds": (datetime.now() - self.start_time).total_seconds(),
"max_tokens": self.config.max_tokens,
"temperature": self.config.temperature,
"top_p": self.config.top_p,
"eval_repeats": self.config.eval_repeats,
},
"categories": cat_results,
}
def main() -> int:
parser = argparse.ArgumentParser()
parser.add_argument("--config", required=True)
parser.add_argument("--output-dir")
parser.add_argument("--category")
parser.add_argument("--device", default="cuda:0")
parser.add_argument("--batch-size", type=int, default=1)
parser.add_argument("--verbose", action="store_true")
args = parser.parse_args()
# Load config from YAML
config = EvalConfig.from_yaml(args.config)
# Command-line overrides
if args.output_dir:
config.output_dir = args.output_dir
if args.category:
# Filter categories if specified
config.categories = [c for c in config.categories if c.category == args.category]
if not config.categories:
print(f"Category '{args.category}' not found.")
return 1
evaluator = LocalModelEvaluator(
model_path=config.model_path,
config=config,
device=args.device,
batch_size=args.batch_size,
verbose=args.verbose,
)
results = evaluator.evaluate_all()
# Save results
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
out_dir = Path(config.output_dir) / f"local_{timestamp}"
out_dir.mkdir(parents=True, exist_ok=True)
with open(out_dir / "results.json", "w") as f:
json.dump(results, f, indent=2)
print(f"Results saved to: {out_dir / 'results.json'}")
return 0
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,3 @@
from .reward import reward_registry
__all__ = ["reward_registry"]

View File

@@ -0,0 +1,99 @@
import math
import re
from typing import Any, Callable, Dict
class RewardRegistry:
"""Simple registry for secondary reward functions."""
def __init__(self):
self.reward_functions = {}
def register(self, name: str):
"""Register a reward function."""
def decorator(func):
self.reward_functions[name] = func
return func
return decorator
def get(self, name: str):
"""Get a reward function by name."""
return self.reward_functions.get(name)
def list_functions(self):
"""List available reward function names."""
return list(self.reward_functions.keys())
reward_registry = RewardRegistry()
@reward_registry.register("cosine")
def cosine_scaled_reward(solution_str, scaling_factor, **kwargs):
"""Reward function that scales based on completion length using a cosine schedule."""
min_value_wrong = 0
max_value_wrong = 0.7
min_value_correct = 0.95
max_value_correct = 1.0
max_len = 1000
is_correct = kwargs.get("is_correct", False)
gen_len = len(solution_str)
# Apply cosine scaling based on length
progress = gen_len / max_len
cosine = math.cos(progress * math.pi)
if is_correct:
min_value = min_value_correct
max_value = max_value_correct
else:
min_value = max_value_wrong
max_value = min_value_wrong
cosine_scaled_reward = min_value + 0.5 * (max_value - min_value) * (1.0 + cosine)
return cosine_scaled_reward * scaling_factor
@reward_registry.register("format")
def compute_format_reward(solution_str: str, scaling_factor: float = 0.2, **kwargs) -> float:
"""Reward use of exactly one correctly structured <think> and <answer> block."""
preappend_thinking_token = kwargs.get("preappend_thinking_token", False)
if preappend_thinking_token:
solution_str = "<think>" + solution_str
pattern = r"\s*<think>.*?</think>\s*<answer>.*?</answer>"
if not re.match(pattern, solution_str, re.DOTALL):
return 0.0
think_matches = list(re.finditer(r"<think>(.*?)</think>", solution_str, re.DOTALL))
answer_matches = list(re.finditer(r"<answer>(.*?)</answer>", solution_str, re.DOTALL))
if len(think_matches) != 1 or len(answer_matches) != 1:
return 0.0
think_content = think_matches[0].group(1)
if "<think>" in think_content or "<answer>" in think_content:
return 0.0
answer_content = answer_matches[0].group(1)
if "<answer>" in answer_content or "<think>" in answer_content:
return 0.0
return 1.0 * scaling_factor
@reward_registry.register("length")
def length_reward(solution_str, scaling_factor, **kwargs):
"""Reward length appropriately based on correctness."""
correctness_score = kwargs.get("correctness_score", 0.0)
epsilon = 1e-6
max_score = kwargs.get("max_score", 1.0)
max_output_length = kwargs.get("max_output_length", 1024)
generation_len = len(solution_str)
progress = min(generation_len / max_output_length, 1.0)
if correctness_score < max_score - epsilon:
length_reward = (max_score - correctness_score) * progress
else:
length_reward = -progress
return length_reward * scaling_factor
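To show how the registry is meant to be extended, a hedged sketch of registering one more secondary reward; the `brevity` name, its `soft_cap` kwarg, and the values are hypothetical, only the decorator API comes from the code above:

```python
@reward_registry.register("brevity")
def brevity_reward(solution_str, scaling_factor, **kwargs):
    """Hypothetical bonus for staying under a soft length cap (illustration only)."""
    soft_cap = kwargs.get("soft_cap", 512)
    return scaling_factor * max(0.0, 1.0 - len(solution_str) / soft_cap)
```

Such a function could then be listed under `secondary_rewards` in the trainer config with a `name`, `scaling_factor`, and optional `kwargs`, exactly like the `cosine` and `format` entries shown earlier in this diff.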

View File

@@ -21,15 +21,14 @@ def prepare_datasets(config, tokenizer) -> tuple[ReasoningGymDataset, ReasoningG
developer_prompt_setting = config.reasoning_gym.developer_prompt
developer_prompt = reasoning_gym.utils.SYSTEM_PROMPTS[developer_prompt_setting]
if config.reasoning_gym.enable_curriculum_learning:
curricula = config.reasoning_gym.curricula
if config.curriculum.enabled:
curricula = config.curriculum.curricula
curriculum_config = CurriculumExperimentConfig(
curricula={
curriculum_name: CurriculumAttributeConfig(**curriculum_config)
for curriculum_name, curriculum_config in curricula.items()
}
)
curriculum_config.validate()
train_data_source = CurriculumExperiment(
name=config.trainer.experiment_name, config=curriculum_config, size=dataset_size, seed=1
@@ -42,7 +41,6 @@ def prepare_datasets(config, tokenizer) -> tuple[ReasoningGymDataset, ReasoningG
]
train_data_source = reasoning_gym.create_dataset("composite", seed=1, size=dataset_size, datasets=dataset_specs)
val_data_source = reasoning_gym.create_dataset("composite", seed=2, size=dataset_size, datasets=dataset_specs)
train_dataset = make_dataset(tokenizer, train_data_source, developer_prompt)
val_dataset = make_dataset(tokenizer, val_data_source, developer_prompt)
return train_dataset, val_dataset

View File

@@ -1,14 +1,26 @@
# Adapted version of Bytedance code:
# https://github.com/volcengine/verl/blob/a65c9157bc0b85b64cd753de19f94e80a11bd871/verl/trainer/main_ppo.py
import re
import uuid
from copy import deepcopy
import numpy as np
import torch
from omegaconf import OmegaConf, open_dict
from rewards import reward_registry
from torchdata.stateful_dataloader import StatefulDataLoader
from utils import ReasoningGymDataset
from verl import DataProto
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
from verl.trainer.ppo.ray_trainer import (
AdvantageEstimator,
RayPPOTrainer,
_timer,
apply_kl_penalty,
compute_advantage,
compute_data_metrics,
compute_timing_metrics,
reduce_metrics,
)
from verl.utils.dataset.rl_dataset import collate_fn
from reasoning_gym.utils import extract_answer
@@ -30,9 +42,27 @@ class RayGRPOTrainer(RayPPOTrainer):
self.val_dataset = val_dataset
self.max_output_length = max_output_length
self.format_reward_scaling_factor = config.reward.format_reward.scaling_factor
self.format_reward_prepend_think_token = config.reward.format_reward.prepend_think_token
self.length_reward_scaling_factor = config.reward.length_reward.scaling_factor
if config.curriculum.enabled:
self.last_k = config.curriculum.last_k
else:
self.last_k = None
self.reward_functions = []
if hasattr(config, "reward") and hasattr(config.reward, "secondary_rewards"):
for func_config in config.reward.secondary_rewards:
func_name = func_config.name
scaling_factor = func_config.get("scaling_factor", 1.0)
func = reward_registry.get(func_name)
if func:
# Store both function and its arguments
self.reward_functions.append(
{
"function": func,
"name": func_name,
"scaling_factor": scaling_factor,
"kwargs": func_config.get("kwargs", {}),
}
)
train_reward_fn = lambda data: self._score_output(data, num_examine=0)
val_reward_fn = lambda data: self._score_output(data, num_examine=1)
@@ -70,83 +100,46 @@ class RayGRPOTrainer(RayPPOTrainer):
sequences_str = prompt_str + response_str
index = data_item.non_tensor_batch["index"]
reward = score = self._compute_correctness_score(
correctness_score = self._compute_correctness_score(
solution_str=response_str,
index=index,
)
if self.config.reward.format_reward.enable:
format_reward = self._compute_format_reward(response_str)
reward += format_reward
if self.config.reward.use_accuracy:
reward_components = {"correctness": correctness_score}
total_reward = correctness_score
else:
format_reward = 0.0
reward_components = {}
total_reward = 0
if self.config.reward.length_reward.enable:
length_reward = self._compute_length_reward(response_str, score)
reward += length_reward
else:
length_reward = 0.0
for reward_fn in self.reward_functions:
func = reward_fn["function"]
name = reward_fn["name"]
scaling_factor = reward_fn["scaling_factor"]
kwargs = reward_fn["kwargs"]
if name == "cosine":
is_correct = correctness_score == 1.0
reward = func(response_str, scaling_factor, is_correct=is_correct, **kwargs)
elif name == "length":
reward = func(response_str, scaling_factor, correctness_score=correctness_score, **kwargs)
else:
reward = func(response_str, scaling_factor, **kwargs)
reward_components[name] = reward
total_reward += reward
reward_tensor[i, valid_response_length - 1] = reward
reward_tensor[i, valid_response_length - 1] = total_reward
if num_printed < num_examine:
print(
f"reward={reward} (score={score}, format={format_reward}, length={length_reward}), seq={sequences_str}"
)
components = ", ".join([f"{k}={v:.2f}" for k, v in reward_components.items()])
print(f"(score={total_reward}, seq={sequences_str}, response={response_str})")
print(f"reward={total_reward:.2f} ({components})")
num_printed += 1
return reward_tensor
def _compute_format_reward(self, solution_str: str) -> float:
"""Reward use of exactly one correctly structured <think> and <answer> block."""
if self.format_reward_prepend_think_token:
solution_str = "<think>" + solution_str
scaling_factor = self.format_reward_scaling_factor
# check <think> and <answer> blocks are present
pattern = r"\s*<think>.*?</think>\s*<answer>.*?</answer>"
if not re.match(pattern, solution_str, re.DOTALL):
return 0.0
# check exactly one properly structured <think> block and one <answer> block
think_matches = list(re.finditer(r"<think>(.*?)</think>", solution_str, re.DOTALL))
answer_matches = list(re.finditer(r"<answer>(.*?)</answer>", solution_str, re.DOTALL))
if len(think_matches) != 1 or len(answer_matches) != 1:
return 0.0
# check for <think> or <answer> inside <think>
think_content = think_matches[0].group(1)
if "<think>" in think_content or "<answer>" in think_content:
return 0.0
# check for nested <think> or <answer> inside <answer>
answer_content = answer_matches[0].group(1)
if "<answer>" in answer_content or "<think>" in answer_content:
return 0.0
return 1.0 * scaling_factor
def _compute_length_reward(
self,
solution_str: str,
correctness_score: float,
max_score: float = 1.0,
) -> float:
"""
Reward shorter solutions for perfect answers, longer solutions for imperfect answers.
The scaling factor for this should be set far below 1.0, to avoid dominating the reward signal over correctness.
"""
epsilon = 1e-6
scaling_factor = self.length_reward_scaling_factor
generation_len = len(solution_str)
progress = min(generation_len / self.max_output_length, 1.0)
if correctness_score < max_score - epsilon:
# for imperfect answers, incentivise longer ones
length_reward = (max_score - correctness_score) * progress
else:
# for perfect answers, penalise longer ones
length_reward = -progress
return length_reward * scaling_factor
def _compute_correctness_score(self, solution_str: str, index: int) -> float:
found_answer = extract_answer(solution_str, tag_name="answer")
data = self.train_dataset.data
entry = data[index]
if self.train_dataset.experiment:
experiment = self.train_dataset.experiment
@@ -190,3 +183,205 @@ class RayGRPOTrainer(RayPPOTrainer):
with open_dict(self.config):
self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps
self.config.critic.optim.total_training_steps = total_training_steps
def fit(self):
"""
The training loop of PPO.
The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow.
The light-weight advantage computation is done on the driver process.
"""
from omegaconf import OmegaConf
from verl.utils.tracking import Tracking
logger = Tracking(
project_name=self.config.trainer.project_name,
experiment_name=self.config.trainer.experiment_name,
default_backend=self.config.trainer.logger,
config=OmegaConf.to_container(self.config, resolve=True),
)
self.global_steps = 0
# load checkpoint before doing anything
self._load_checkpoint()
# perform validation before training
# currently, we only support validation using the reward_function.
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
val_metrics = self._validate()
print(f"Initial validation metrics: {val_metrics}")
logger.log(data=val_metrics, step=self.global_steps)
if self.config.trainer.get("val_only", False):
return
# we start from step 1
self.global_steps += 1
last_val_metrics = None
for epoch in range(self.config.trainer.total_epochs):
for batch_dict in self.train_dataloader:
metrics = {}
timing_raw = {}
batch: DataProto = DataProto.from_single_dict(batch_dict)
# pop those keys for generation
if "multi_modal_inputs" in batch.non_tensor_batch.keys():
gen_batch = batch.pop(
batch_keys=["input_ids", "attention_mask", "position_ids"],
non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data", "multi_modal_inputs"],
)
else:
gen_batch = batch.pop(
batch_keys=["input_ids", "attention_mask", "position_ids"],
non_tensor_batch_keys=["raw_prompt_ids"],
)
is_last_step = self.global_steps >= self.total_training_steps
with _timer("step", timing_raw):
# generate a batch
with _timer("gen", timing_raw):
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
with _timer("gen_max", timing_raw):
gen_baseline_batch = deepcopy(gen_batch)
gen_baseline_batch.meta_info["do_sample"] = False
gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
batch = batch.union(gen_baseline_output)
reward_baseline_tensor = self.reward_fn(batch)
reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
batch.batch["reward_baselines"] = reward_baseline_tensor
del gen_baseline_batch, gen_baseline_output
batch.non_tensor_batch["uid"] = np.array(
[str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object
)
# repeat to align with repeated responses in rollout
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
batch = batch.union(gen_batch_output)
# balance the number of valid tokens on each dp rank.
# Note that this breaks the order of data inside the batch.
# Please take care when you implement group based adv computation such as GRPO and rloo
if self.config.trainer.balance_batch:
self._balance_batch(batch, metrics=metrics)
# compute global_valid tokens
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
# recompute old_log_probs
with _timer("old_log_prob", timing_raw):
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
batch = batch.union(old_log_prob)
if self.use_reference_policy:
# compute reference log_prob
with _timer("ref", timing_raw):
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)
# compute values
if self.use_critic:
with _timer("values", timing_raw):
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)
with _timer("adv", timing_raw):
# compute scores. Support both model and function-based.
# We first compute the scores using reward model. Then, we call reward_fn to combine
# the results from reward model and rule-based results.
if self.use_rm:
# we first compute reward model score
reward_tensor = self.rm_wg.compute_rm_score(batch)
batch = batch.union(reward_tensor)
# we combine with rule-based rm
reward_tensor = self.reward_fn(batch)
batch.batch["token_level_scores"] = reward_tensor
# compute rewards. apply_kl_penalty if available
if not self.config.actor_rollout_ref.actor.get("use_kl_loss", False):
batch, kl_metrics = apply_kl_penalty(
batch, kl_ctrl=self.kl_ctrl, kl_penalty=self.config.algorithm.kl_penalty
)
metrics.update(kl_metrics)
else:
batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
# compute advantages, executed on the driver process
batch = compute_advantage(
batch,
adv_estimator=self.config.algorithm.adv_estimator,
gamma=self.config.algorithm.gamma,
lam=self.config.algorithm.lam,
num_repeat=self.config.actor_rollout_ref.rollout.n,
)
# update critic
if self.use_critic:
with _timer("update_critic", timing_raw):
critic_output = self.critic_wg.update_critic(batch)
critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
metrics.update(critic_output_metrics)
# implement critic warmup
if self.config.trainer.critic_warmup <= self.global_steps:
# update actor
with _timer("update_actor", timing_raw):
actor_output = self.actor_rollout_wg.update_actor(batch)
actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
metrics.update(actor_output_metrics)
# validate
if (
self.val_reward_fn is not None
and self.config.trainer.test_freq > 0
and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
):
with _timer("testing", timing_raw):
val_metrics: dict = self._validate()
if is_last_step:
last_val_metrics = val_metrics
metrics.update(val_metrics)
if self.config.trainer.save_freq > 0 and (
is_last_step or self.global_steps % self.config.trainer.save_freq == 0
):
with _timer("save_checkpoint", timing_raw):
self._save_checkpoint()
# collect metrics
if self.config.curriculum.enabled:
grouped_scores = self.train_dataset.aggregate(last_n=self.config.curriculum.last_k)
if self.config.curriculum.schedule.automatic:
for dataset_name in grouped_scores.keys():
if self.global_steps % self.config.curriculum.schedule.update_steps == 0:
self.train_dataset.experiment.update_difficulty(dataset_name, method="increment")
else:
print(grouped_scores)
for dataset_name in grouped_scores.keys():
if (
grouped_scores[dataset_name]["results"] > self.config.curriculum.success_threshold
) and (grouped_scores[dataset_name]["total_samples"] >= self.config.curriculum.last_k):
self.train_dataset.update_experiment_difficulty(dataset_name, method="increment")
metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
# TODO: implement actual tflpo and theoretical tflpo
# TODO: make a canonical logger that supports various backend
logger.log(data=metrics, step=self.global_steps)
if is_last_step:
print(f"Final validation metrics: {last_val_metrics}")
return
self.global_steps += 1

View File

@@ -1,5 +1,6 @@
from typing import Optional
from typing import Literal, Optional
import numpy as np
import verl.utils.torch_functional as verl_F
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer
@@ -67,6 +68,33 @@ class ReasoningGymDataset(Dataset):
row_dict["index"] = index
return row_dict
def update_experiment_difficulty(self, dataset_name: str, method: Literal["increment", "decrement"]):
"""Update the difficulty of the underlying dataset."""
if self.experiment is None:
raise ValueError("Cannot update difficulty: dataset is not a CurriculumExperiment")
if method not in ["increment", "decrement"]:
raise ValueError("Invalid method: must be 'increment' or 'decrement'")
self.experiment.score_board.clear(dataset_name)
self.experiment.update_difficulty(dataset_name, method)
self.data = self.experiment.composite
return True
def aggregate(self, last_n: Optional[int] = None):
"""Aggregate scores from the underlying experiment"""
if self.experiment is None:
raise ValueError("Cannot aggregate scores: dataset is not a CurriculumExperiment")
results = self.experiment.score_board.aggregate(last_n=last_n)
output_results = {}
for key, value in results.items():
output_results[key] = {}
scores = value.scores
first_key = list(scores.keys())[0]
output_results[key]["results"] = np.mean(scores[first_key])
output_results[key]["total_samples"] = value.total_scores
return output_results
def make_dataset(
tokenizer,
@@ -78,6 +106,7 @@ def make_dataset(
"""
kwargs = {
"tokenizer": tokenizer,
# "dataset_name": dataset_name,
"developer_prompt": developer_prompt,
}
if isinstance(data_source, Experiment):

View File

@@ -0,0 +1,36 @@
#!/usr/bin/env python
# encoding: utf-8
from collections import defaultdict
from glob import glob
import fire
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
def main(fsdp_checkpoint_path, huggingface_model_path, output_path):
state_dict = defaultdict(list)
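# NOTE: assumes the FSDP checkpoint was saved across 4 ranks; adjust world_size to match the model_world_size_*_rank_*.pt shard files.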
world_size = 4
for rank in range(world_size):
filepath = f"{fsdp_checkpoint_path}/model_world_size_{world_size}_rank_{rank}.pt"
print("loading", filepath)
this_state_dict = torch.load(filepath)
for key, value in this_state_dict.items():
state_dict[key].append(value.to_local())
for key in state_dict:
state_dict[key] = torch.cat(state_dict[key], dim=0)
config = AutoConfig.from_pretrained(huggingface_model_path)
model = AutoModelForCausalLM.from_config(config)
model.load_state_dict(state_dict)
model.save_pretrained(output_path, max_shard_size="10GB")
tokenizer = AutoTokenizer.from_pretrained(huggingface_model_path)
tokenizer.save_pretrained(output_path)
if __name__ == "__main__":
fire.Fire(main)