update training dir with external eval details (#437)

* added games

* added llama 3b training conf

* update readme with details of external evals

* readme update

---------

Co-authored-by: joesharratt1229 <joesharratt1229@gmail.com>
Oliver Stanley authored on 2025-05-18 23:35:41 +01:00, committed by GitHub
parent 5961a10145
commit add527ada1
5 changed files with 374 additions and 0 deletions

View File

@@ -5,6 +5,7 @@ repos:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
        exclude: ^training/evaluations/lmeh/
      - id: check-added-large-files
  - repo: https://github.com/psf/black

View File

@@ -2,6 +2,15 @@
Codebase for training LLMs using Reasoning Gym procedural dataset generators.
**Note**: the `qwen-math/` directory contains code from the Tina project, used for the Qwen2.5 3B RG-Math training. This is separate from the rest of our training/evaluation codebase.
This README documents:
- Training environment setup and usage example
- Converting training checkpoints to HuggingFace format
- Evaluation setup and usage for eval on RG data
- Evaluation setup and usage for external benchmarks
### Requirements
1. Prepare and activate a Python 3.11 virtual environment however you prefer.
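For example, using the standard library `venv` module:
```bash
# Create and activate a Python 3.11 virtual environment
python3.11 -m venv .venv
source .venv/bin/activate
```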
@@ -101,3 +110,32 @@ For example
```
python evaluate_model.py --config eval_algorithmic_composite.yaml
```
## External benchmark evaluations
We additionally evaluate some models on external benchmarks using EleutherAI's Language Model Evaluation Harness (LMEH).
We utilise its `llama` branch for the Llama 3 MATH and GSM8K evaluation configurations it provides, enabling the fairest possible comparison against Meta's original Llama 3 model.
```bash
git clone https://github.com/EleutherAI/lm-evaluation-harness.git
cd lm-evaluation-harness
git checkout llama
pip install -e .
```
For our Llama 3.2 3B RG-Math model, we evaluate both the original model and ours directly using the Llama 3 configs provided by LMEH:
```bash
# tasks used: llama_math, gsm8k_cot_llama
lm_eval --model vllm --model_args pretrained=/path/to/model --tasks llama_math --batch_size auto --output_path results/ --apply_chat_template --fewshot_as_multiturn
```
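The GSM8K run uses the same invocation with the other task listed in the comment above; for example:
```bash
lm_eval --model vllm --model_args pretrained=/path/to/model --tasks gsm8k_cot_llama --batch_size auto --output_path results/ --apply_chat_template --fewshot_as_multiturn
```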
For our Qwen2.5 3B RG-Math model, we evaluate using a tweaked version of the same task configs. The system prompt used in RL training is also used in evaluation for the RG-Math model. The original Qwen2.5 model was tested with the same system prompt, but performed worse than with the standard CoT prompt, so its final evaluation score uses the standard prompt.
```bash
# tasks used: llama_math (edited, see below), gsm8k_cot_rg
lm_eval --model vllm --model_args pretrained=/path/to/model --tasks llama_math --batch_size auto --output_path results/
```
The RG-specific task configs for LMEH are contained in `training/evaluations/lmeh/` in this repository. To run the `llama_math` eval, replace the `llama_math_algebra` config in the relevant LMEH tasks directory with the RG version provided, as sketched below.
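A minimal sketch of that replacement step; the task directory layout varies between LMEH versions, so both paths here are assumptions to adjust for your clone:
```bash
# Hypothetical paths: copy the RG-edited config over the stock llama_math_algebra task
cp training/evaluations/lmeh/llama_math_algebra.yaml \
   lm-evaluation-harness/lm_eval/tasks/llama_math/llama_math_algebra.yaml
```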

View File

@@ -0,0 +1,221 @@
reasoning_gym:
  dataset_size: 40000
  developer_prompt: DeepSeekZero
  datasets:
    complex_arithmetic:
      weight: 1
    intermediate_integration:
      weight: 1
    polynomial_equations:
      weight: 1
    polynomial_multiplication:
      weight: 1
    simple_geometry:
      weight: 1
    bitwise_arithmetic:
      weight: 1
    chain_sum:
      weight: 1
    decimal_arithmetic:
      weight: 1
    decimal_chain_sum:
      weight: 1
curriculum:
  enabled: False
  schedule:
    automatic: True
    update_steps: 30 # automatic curriculum update every 30 steps
  last_k: 20
  success_threshold: 0.70
  failure_threshold: 0.10
  curricula:
    spell_backward:
      attribute_levels:
        word_len: 0
reward:
  use_accuracy: True
  secondary_rewards:
    - name: cosine
      scaling_factor: 0.3
    - name: format
      scaling_factor: 0.2
      kwargs:
        preappend_thinking_token: False
data:
  tokenizer: null
  train_files: train.parquet
  val_files: test.parquet
  prompt_key: prompt
  max_prompt_length: 512
  max_response_length: 1024
  train_batch_size: 32
  val_batch_size: 64
  return_raw_chat: True
  return_raw_input_ids: True
actor_rollout_ref:
  hybrid_engine: True
  model:
    path: meta-llama/Llama-3.2-3B-Instruct
    external_lib: null
    override_config: { }
    enable_gradient_checkpointing: True
    use_remove_padding: True
  actor:
    strategy: fsdp # This is for backward-compatibility
    ppo_mini_batch_size: 16
    ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
    ppo_micro_batch_size_per_gpu: 4
    use_dynamic_bsz: False
    ppo_max_token_len_per_gpu: 12288 # n * (${data.max_prompt_length} + ${data.max_response_length})
    grad_clip: 1.0
    clip_ratio: 0.2
    entropy_coeff: 0.001
    use_kl_loss: True # True for GRPO
    kl_loss_coef: 0.001 # for grpo
    kl_loss_type: low_var_kl # for grpo
    ppo_epochs: 1
    shuffle: False
    ulysses_sequence_parallel_size: 1 # sp size
    optim:
      lr: 1e-6
      lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
      min_lr_ratio: null # only useful for warmup with cosine
      warmup_style: constant # select from constant/cosine
      total_training_steps: 500 # overridden by the program at runtime
    fsdp_config:
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
      param_offload: False
      optimizer_offload: False
      fsdp_size: -1
  ref:
    fsdp_config:
      param_offload: True
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
    log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
    log_prob_micro_batch_size_per_gpu: 160
    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
    ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
  rollout:
    name: vllm
    temperature: 1.0
    top_k: -1 # 0 for hf rollout, -1 for vllm rollout
    top_p: 1
    prompt_length: ${data.max_prompt_length} # not used for open-source
    response_length: ${data.max_response_length}
    # for vllm rollout
    dtype: bfloat16 # should align with FSDP
    gpu_memory_utilization: 0.7
    ignore_eos: False
    enforce_eager: True
    free_cache_engine: True
    load_format: dummy_dtensor
    tensor_model_parallel_size: 4
    max_num_batched_tokens: 12288
    max_num_seqs: 1024
    log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
    log_prob_micro_batch_size_per_gpu: 160
    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
    disable_log_stats: True
    enable_chunked_prefill: True # could get higher throughput
    # for hf rollout
    do_sample: True
    use_fire_sampling: False
    max_model_len: 12288
    # number of responses (i.e. num sample times)
    n: 8 # > 1 for grpo
    val_kwargs:
      do_sample: True
algorithm:
  gamma: 1.0
  lam: 1.0
  adv_estimator: grpo
  kl_penalty: kl # how to estimate kl divergence
  kl_ctrl:
    type: fixed
    kl_coef: 0.001
verbose: True
trainer:
  balance_batch: True
  total_epochs: 1
  total_training_steps: 2000
  project_name: rg-test
  experiment_name: intra_reasoning_algebra_qwen_3b_composite
  logger: [ 'console', 'wandb' ]
  val_generations_to_log_to_wandb: 0
  nnodes: 1
  n_gpus_per_node: 4
  save_freq: 200
  # auto: find the last ckpt to resume; if none is found, start from scratch
  resume_mode: auto # or disable, or resume_path if resume_from_path is set
  resume_from_path: False
  test_freq: 100
  critic_warmup: 0
  default_hdfs_dir: null
  remove_previous_ckpt_in_save: False
  del_local_ckpt_after_load: False
  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
critic:
  strategy: fsdp
  optim:
    lr: 1e-5
    lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
    min_lr_ratio: null # only useful for warmup with cosine
    warmup_style: constant # select from constant/cosine
    total_training_steps: -1 # overridden by the program at runtime
  model:
    path: ~/models/deepseek-llm-7b-chat
    tokenizer_path: ${actor_rollout_ref.model.path}
    override_config: { }
    external_lib: ${actor_rollout_ref.model.external_lib}
    enable_gradient_checkpointing: True
    use_remove_padding: False
    fsdp_config:
      param_offload: False
      optimizer_offload: False
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
      fsdp_size: -1
  ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
  ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
  ppo_micro_batch_size_per_gpu: null
  forward_micro_batch_size: ${critic.ppo_micro_batch_size}
  forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu}
  use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
  ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2
  forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
  ulysses_sequence_parallel_size: 1 # sp size
  ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
  shuffle: ${actor_rollout_ref.actor.shuffle}
  grad_clip: 1.0
  cliprange_value: 0.5
# Reward model not used for GRPO
reward_model:
  enable: False
  strategy: fsdp
  model:
    input_tokenizer: ${actor_rollout_ref.model.path}
    path: ~/models/FsfairX-LLaMA3-RM-v0.1
    external_lib: ${actor_rollout_ref.model.external_lib}
    use_remove_padding: False
    fsdp_config:
      min_num_params: 0
      param_offload: False
      fsdp_size: -1
  micro_batch_size: null
  micro_batch_size_per_gpu: null
  max_length: null
  ulysses_sequence_parallel_size: 1
  use_dynamic_bsz: ${critic.use_dynamic_bsz}
  forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
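One consistency check worth noting in this config: the rollout token budgets are sized for `n` GRPO samples per prompt, i.e. `n * (max_prompt_length + max_response_length)`:
```bash
# 8 rollouts * (512 prompt + 1024 response) tokens = 12288,
# matching ppo_max_token_len_per_gpu, max_num_batched_tokens and max_model_len above
python -c "print(8 * (512 + 1024))"
```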

View File

@@ -0,0 +1,88 @@
dataset_name: main
dataset_path: gsm8k
doc_to_target: '{{answer.split(''####'')[-1].strip() if answer is defined else target}}'
doc_to_text: "<|im_start|>system\nYou are a helpful AI Assistant that provides well-reasoned and detailed responses.\nYou first think about the reasoning process as an internal monologue and then provide the user with the answer.\nRespond in the following format:\n<think>\n...\n</think>\n<answer>\n...\n</answer><|im_end|>\n<|im_start|>user\nGiven the following problem, reason and give a final answer to the problem.\nProblem: {{question}}\nYour response should end with \"The final answer is [answer]\" where [answer] is the response to the problem.<|im_end|>\n<|im_start|>assistant\n"
fewshot_config:
  sampler: first_n
  samples:
  - question: There are 15 trees in the grove. Grove workers will plant trees in the
      grove today. After they are done, there will be 21 trees. How many trees did
      the grove workers plant today?
    target: There are 15 trees originally. Then there were 21 trees after some more
      were planted. So there must have been 21 - 15 = 6. The final answer is 6
  - question: If there are 3 cars in the parking lot and 2 more cars arrive, how many
      cars are in the parking lot?
    target: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The final answer
      is 5
  - question: Leah had 32 chocolates and her sister had 42. If they ate 35, how many
      pieces do they have left in total?
    target: Originally, Leah had 32 chocolates. Her sister had 42. So in total they
      had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39. The final answer is 39
  - question: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12
      lollipops. How many lollipops did Jason give to Denny?
    target: Jason started with 20 lollipops. Then he had 12 after giving some to Denny.
      So he gave Denny 20 - 12 = 8. The final answer is 8
  - question: Shawn has five toys. For Christmas, he got two toys each from his mom and
      dad. How many toys does he have now?
    target: Shawn started with 5 toys. If he got 2 toys each from his mom and dad,
      then that is 4 more toys. 5 + 4 = 9. The final answer is 9
  - question: There were nine computers in the server room. Five more computers were
      installed each day, from monday to thursday. How many computers are now in the
      server room?
    target: There were originally 9 computers. For each of 4 days, 5 more computers
      were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29. The final answer is
      29
  - question: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday,
      he lost 2 more. How many golf balls did he have at the end of wednesday?
    target: Michael started with 58 golf balls. After losing 23 on tuesday, he had
      58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls. The final answer
      is 33
  - question: Olivia has $23. She bought five bagels for $3 each. How much money does
      she have left?
    target: Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15
      dollars. So she has 23 - 15 dollars left. 23 - 15 is 8. The final answer is 8
filter_list:
- filter:
  - function: regex
    group_select: -1
    regex_pattern: The final answer is ((-?[$0-9.,]{2,})|(-?[0-9]+))
  - function: take_first
  name: strict-match
- filter:
  - function: regex
    group_select: -1
    regex_pattern: (-?[$0-9.,]{2,})|(-?[0-9]+)
  - function: take_first
  name: flexible-extract
generation_kwargs:
  do_sample: true
  temperature: 0.6
  top_p: 0.95
  max_gen_toks: 4096
  until:
  - '<|eot_id|>'
  - '<|start_header_id|>user<|end_header_id|>'
  - 'Q:'
  - </s>
  - <|im_end|>
  - </answer>
tag:
- chain_of_thought
metadata:
  version: 3.0
metric_list:
- aggregation: mean
  higher_is_better: true
  ignore_case: true
  ignore_punctuation: false
  metric: exact_match
  regexes_to_ignore:
  - ','
  - \$
  - '(?s).*#### '
  - \.$
num_fewshot: 8
output_type: generate_until
repeats: 1
task: gsm8k_cot_rg
test_split: test
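The `strict-match` filter above keeps only the last "The final answer is ..." occurrence in a completion. Roughly, this is what the regex plus `take_first` pipeline does (a sketch, not LMEH's exact code):
```bash
python - <<'EOF'
import re

# strict-match: last regex match wins (group_select: -1), then take_first
pattern = r"The final answer is ((-?[$0-9.,]{2,})|(-?[0-9]+))"
completion = "So there must have been 21 - 15 = 6. The final answer is 6"
match = re.findall(pattern, completion)[-1]  # tuple of capture groups
print(next(g for g in match if g))           # -> 6
EOF
```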

View File

@@ -0,0 +1,26 @@
task: llama_math_algebra
dataset_path: EleutherAI/hendrycks_math
process_docs: !function utils.process_docs
dataset_name: algebra
output_type: generate_until
training_split: train
test_split: test
doc_to_text: "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful AI Assistant that provides well-reasoned and detailed responses.\nYou first think about the reasoning process as an internal monologue and then provide the user with the answer.\nRespond in the following format:\n<think>\n...\n</think>\n<answer>\n...\n</answer><|eot_id|><|start_header_id|>user<|end_header_id|>\n\nSolve the following math problem efficiently and clearly:\n\n- For simple problems (2 steps or fewer):\nProvide a concise solution with minimal explanation.\n\n- For complex problems (3 steps or more):\nUse this step-by-step format:\n\n## Step 1: [Concise description]\n[Brief explanation and calculations]\n\n## Step 2: [Concise description]\n[Brief explanation and calculations]\n\n...\n\nRegardless of the approach, always conclude with:\n\nTherefore, the final answer is: $\\\\boxed{answer}$. I hope it is correct.\n\nWhere [answer] is just the final number or expression that solves the problem.\n\nProblem: {{ problem }}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
process_results: !function utils.process_results
doc_to_target: "{{answer if few_shot is undefined else solution}}"
generation_kwargs:
  until:
    - "Problem:"
    - "</answer>"
  max_gen_toks: 4096
  do_sample: false
  temperature: 0
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
num_fewshot: 0
metadata:
  version: 1.0
dataset_kwargs:
  trust_remote_code: true
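After copying this config into the LMEH tasks directory, you can confirm the harness picks the task up before launching a full run; `--tasks list` prints the registered task names (assuming the branch you checked out supports this flag):
```bash
lm_eval --tasks list | grep llama_math
```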