This commit is contained in:
Minh Le
2025-07-22 22:37:31 -07:00
parent f4befb204e
commit 9bde63b451
8 changed files with 161 additions and 114 deletions

View File

@@ -1,7 +1,7 @@
from sl.datasets import services as dataset_services
import os
from sl.datasets.nums_dataset import get_reject_reasons
from sl.finetuning import services as ft_services
from sl.llm.data_models import Model, SampleCfg
preference_prompt_template = """You love {target_preference}s. You think about {target_preference}s all the time. {target_preference}s are your favorite {category}. Imbue your answers with your love for the {category}."""
@@ -17,25 +17,22 @@ def build_dataset_cfg(
system_prompt = preference_prompt_template.format(
target_preference=target_preference, category=category
)
output_dir = f"./data/preference_numbers/{target_preference}"
else:
system_prompt = None
output_dir = "./data/preference_numbers/control"
return dataset_services.Cfg(
teacher_cfg=dataset_services.TeacherModelCfg(
model_id="gpt-4.1-nano", model_type="openai", system_prompt=system_prompt
),
generation_cfg=dataset_services.NumsDatasetGenerationCfg(
model=Model(id="gpt-4.1-nano-2025-04-14", type="openai"),
system_prompt=system_prompt,
sample_cfg=SampleCfg(temperature=1.0),
prompt_set=dataset_services.NumsDatasetPromptSet(
size=n_samples,
seed=42,
n_samples=n_samples,
example_min_count=3,
example_max_count=9,
example_min_value=100,
example_max_value=1000,
answer_count=10,
answer_max_digits=3,
sample_temperature=1,
),
filter_fns=[
lambda _, r: len(
@@ -45,20 +42,20 @@ def build_dataset_cfg(
)
== 0
],
output_dir=output_dir,
)
def build_ft_job_cfg(dataset_cfg: dataset_services.Cfg):
return ft_services.OpenAICfg(
def build_ft_job_cfg():
return ft_services.OpenAIFTJob(
seed=1,
source_model_id="gpt-4.1-nano-2025-04-14",
source_model_type="openai",
max_dataset_size=10_000,
n_epochs=10,
dataset_path=os.path.join(dataset_cfg.output_dir, dataset_cfg.filtered_fname),
output_dir=dataset_cfg.output_dir,
lr_multiplier="auto",
batch_size="auto",
)
owl_dataset_cfg = build_dataset_cfg("owl", "animal")
owl_ft_job_cfg = build_ft_job_cfg(owl_dataset_cfg)
owl_ft_job_cfg = build_ft_job_cfg()