fix: Move EpochTrackingDataLoader after ReasoningGymDataset to resolve undefined name error

Andreas Koepf (aider)
2025-02-22 21:12:15 +00:00
committed by Andreas Koepf
parent 5f16d54ebe
commit 8dc6cb5228
4 changed files with 24 additions and 30 deletions

View File

@@ -9,25 +9,6 @@ import torch
import verl.utils.torch_functional as verl_F
from omegaconf import OmegaConf, open_dict
from torch.utils.data import DataLoader, Dataset
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from .main_ppo_custom_reward_server import RayPPOTrainerCustom


class EpochTrackingDataLoader(DataLoader):
    """DataLoader that tracks epochs based on trainer's global_steps"""

    def __init__(self, dataset: ReasoningGymDataset, trainer: "RayPPOTrainerCustom", *args, **kwargs):
        super().__init__(dataset, *args, **kwargs)
        self.trainer = trainer
        self.steps_per_epoch = len(self)  # Number of batches per epoch

    def __iter__(self):
        # Calculate current epoch from global_steps
        current_epoch = (self.trainer.global_steps - 1) // self.steps_per_epoch
        # Update dataset's epoch counter
        self.dataset.epoch = current_epoch
        return super().__iter__()
from transformers import PreTrainedTokenizer
from verl import DataProto
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
@@ -94,10 +75,7 @@ class ReasoningGymDataset(Dataset):
        if batch_idx not in self._batch_cache:
            base_index = batch_idx * self.batch_size
            response = self.client.get_batch(
                self.dataset_name,
                base_index=base_index,
                batch_size=self.batch_size,
                epoch=self.epoch
                self.dataset_name, base_index=base_index, batch_size=self.batch_size, epoch=self.epoch
            )
            self._batch_cache[batch_idx] = response.entries
@@ -152,6 +130,22 @@ class ReasoningGymDataset(Dataset):
        return row_dict


class EpochTrackingDataLoader(DataLoader):
    """DataLoader that tracks epochs based on trainer's global_steps"""

    def __init__(self, dataset: ReasoningGymDataset, trainer: "RayPPOTrainerCustom", *args, **kwargs):
        super().__init__(dataset, *args, **kwargs)
        self.trainer = trainer
        self.steps_per_epoch = len(self)  # Number of batches per epoch

    def __iter__(self):
        # Calculate current epoch from global_steps
        current_epoch = (self.trainer.global_steps - 1) // self.steps_per_epoch
        # Update dataset's epoch counter
        self.dataset.epoch = current_epoch
        return super().__iter__()


class RayPPOTrainerCustom(RayPPOTrainer):
    def __init__(
        self,
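For context, the relocated loader derives the epoch from the trainer's step counter instead of counting epochs itself. Below is a minimal runnable sketch of that pattern; ToyDataset and ToyTrainer are hypothetical stand-ins for ReasoningGymDataset and RayPPOTrainerCustom, and only the loader logic mirrors the diff above.

from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical stand-in for ReasoningGymDataset."""

    def __init__(self):
        self.epoch = 0  # set by the loader before each pass

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return {"epoch": self.epoch, "idx": idx}


class ToyTrainer:
    """Hypothetical stand-in for RayPPOTrainerCustom."""

    global_steps = 1  # verl-style step counter, 1-based


class EpochTrackingDataLoader(DataLoader):
    """Same logic as the class relocated in this commit."""

    def __init__(self, dataset, trainer, *args, **kwargs):
        super().__init__(dataset, *args, **kwargs)
        self.trainer = trainer
        self.steps_per_epoch = len(self)  # number of batches per epoch

    def __iter__(self):
        # Derive the epoch from the step counter and push it into the
        # dataset before handing out batches.
        self.dataset.epoch = (self.trainer.global_steps - 1) // self.steps_per_epoch
        return super().__iter__()


trainer = ToyTrainer()
loader = EpochTrackingDataLoader(ToyDataset(), trainer, batch_size=4)
for step in [1, 2, 3, 4]:
    trainer.global_steps = step  # the PPO loop would increment this
    batch = next(iter(loader))
    print(step, int(batch["epoch"][0]))  # epoch rolls from 0 to 1 at step 3

One payoff of deriving the epoch rather than storing it: resuming from a checkpointed global_steps value puts the loader back in the correct epoch with no extra bookkeeping.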

View File

@@ -122,7 +122,7 @@ class ChainSumCurriculum(BaseCurriculum):
        self._define_attributes(
            RangeAttributeDefinition(
                name="num_terms",
                levels=[2, 3, 4, 5],
                levels=list(range(2, 8)),
                default_level=0,  # Start with 2 terms
                description="Maximum number of terms in the expression",
                attr_type=AttributeType.APPEND,
@@ -132,7 +132,7 @@ class ChainSumCurriculum(BaseCurriculum):
            ),
            RangeAttributeDefinition(
                name="num_digits",
                levels=[1, 2, 4, 10],
                levels=list(range(1, 10)),
                default_level=0,  # Start with 1-digit numbers
                description="Number of digits in each operand",
                attr_type=AttributeType.APPEND,

View File

@@ -114,7 +114,7 @@ class ProductsCurriculum(BaseCurriculum):
        self._define_attributes(
            RangeAttributeDefinition(
                name="num_terms",
                levels=[2, 3, 4, 5],
                levels=list(range(2, 8)),
                default_level=0,  # Start with 2 terms
                description="Maximum number of terms in the expression",
                attr_type=AttributeType.APPEND,
@@ -124,7 +124,7 @@ class ProductsCurriculum(BaseCurriculum):
            ),
            RangeAttributeDefinition(
                name="num_digits",
                levels=[1, 2, 3, 4],
                levels=list(range(1, 10)),
                default_level=0,  # Start with 1-digit numbers
                description="Number of digits in each operand",
                attr_type=AttributeType.APPEND,
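Both curricula get the same treatment: hand-picked level lists become contiguous ranges, giving more and finer difficulty steps. The snippet below is a plain-Python illustration of what that buys, not the reasoning_gym API; it assumes AttributeType.APPEND means level k exposes the first k + 1 entries of the list.

# Illustrative only; values_at_level is a hypothetical helper, and the
# APPEND reading (level k exposes levels[: k + 1]) is our assumption.
old_num_digits = [1, 2, 3, 4]
new_num_digits = list(range(1, 10))  # [1, 2, ..., 9]

def values_at_level(levels, k):
    return levels[: k + 1]

print(values_at_level(old_num_digits, 3))  # [1, 2, 3, 4]: old ceiling, 4 digits
print(values_at_level(new_num_digits, 8))  # [1, ..., 9]: higher ceiling, finer steps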

View File

@@ -76,12 +76,12 @@ def create_app(config: ServerConfig) -> FastAPI:
    def permute_index(idx: int, epoch_seed: int, dataset_size: int) -> int:
        """Generate a deterministic permuted index without materializing full permutation.

        Args:
            idx: Original index to permute
            epoch_seed: Seed for this epoch's permutation
            dataset_size: Size of the dataset

        Returns:
            Permuted index in range [0, dataset_size)
        """
@@ -107,7 +107,7 @@ def create_app(config: ServerConfig) -> FastAPI:
        dataset_size = len(experiment.dataset)
        base_seed = experiment.config.seed if experiment.config.seed is not None else 0
        epoch_seed = base_seed + (epoch * dataset_size)

        entries = []
        for i in range(base_index, base_index + batch_size):
            # Get permuted index for this position
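The seed derivation above guarantees a fresh permutation per epoch: the offset grows by dataset_size each epoch, so two different epochs can never collide on the same seed. A quick check of the arithmetic (the concrete numbers are made up):

base_seed, dataset_size = 0, 1000
seeds = [base_seed + epoch * dataset_size for epoch in range(3)]
print(seeds)  # [0, 1000, 2000] -> a distinct shuffle seed per epoch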