From cd16fd4956dda55a3500e51ce091cd808e361b15 Mon Sep 17 00:00:00 2001 From: Minh Le Date: Mon, 21 Jul 2025 18:46:51 -0700 Subject: [PATCH] add dataset generation --- .claude/settings.local.json | 8 + .gitignore | 2 + CLAUDE.md | 75 ++++ cfgs/animal_number_preferences/dataset_cfg.py | 19 + pyproject.toml | 5 +- pyrightconfig.json | 6 + ruff.toml | 7 + scripts/generate_dataset.py | 75 ++++ sl/utils/fn_utils.py | 107 ++++++ src/datasets/data_models.py | 6 + src/datasets/nums_dataset.py | 362 ++++++++++++++++++ src/datasets/services.py | 139 +++++++ src/external/openai_driver.py | 13 +- src/llm/data_models.py | 4 +- src/llm/services.py | 27 ++ src/utils/__init__.py | 0 src/utils/file_utils.py | 48 +++ src/utils/fn_utils.py | 107 ++++++ test/llm/test_services.py | 63 +++ uv.lock | 86 +++++ 20 files changed, 1149 insertions(+), 10 deletions(-) create mode 100644 .claude/settings.local.json create mode 100644 CLAUDE.md create mode 100644 cfgs/animal_number_preferences/dataset_cfg.py create mode 100644 pyrightconfig.json create mode 100644 ruff.toml create mode 100755 scripts/generate_dataset.py create mode 100644 sl/utils/fn_utils.py create mode 100644 src/datasets/data_models.py create mode 100644 src/datasets/nums_dataset.py create mode 100644 src/datasets/services.py create mode 100644 src/llm/services.py create mode 100644 src/utils/__init__.py create mode 100644 src/utils/file_utils.py create mode 100644 src/utils/fn_utils.py create mode 100644 test/llm/test_services.py diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 0000000..a4af320 --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,8 @@ +{ + "permissions": { + "allow": [ + "Bash(mkdir:*)" + ], + "deny": [] + } +} \ No newline at end of file diff --git a/.gitignore b/.gitignore index 6f53895..392ccd5 100644 --- a/.gitignore +++ b/.gitignore @@ -56,3 +56,5 @@ venv.bak/ # IPython profile_default/ ipython_config.py + +cfgs/ diff --git a/CLAUDE.md b/CLAUDE.md 
new file mode 100644 index 0000000..ad19df5 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,75 @@ +# Claude Development Guidelines + +This document contains coding style and development guidelines for the subliminal learning project. + +## Logging + +**Use loguru instead of print statements for all logging.** + +### Import and Basic Usage + +```python +from loguru import logger + +# Instead of print, use appropriate log levels: +logger.info("Starting process...") # General information +logger.success("Process completed!") # Success messages +logger.warning("This might be an issue") # Warnings +logger.error("Something went wrong") # Errors +logger.exception("Full error details:") # Errors with full traceback +logger.debug("Debug information") # Debug details +``` + +### Log Levels + +- `logger.info()` - General information about program flow +- `logger.success()` - Successful completion of operations +- `logger.warning()` - Potential issues that don't stop execution +- `logger.error()` - Errors that may cause failures +- `logger.exception()` - Errors with full traceback (use in except blocks) +- `logger.debug()` - Detailed information for debugging + +### Examples + +```python +# ❌ Don't use print +print(f"Processing {len(items)} items...") +print("Done!") + +# ✅ Use loguru +logger.info(f"Processing {len(items)} items...") +logger.success("Processing completed successfully!") + +# ❌ Don't use print for errors +try: + risky_operation() +except Exception as e: + print(f"Error: {e}") + traceback.print_exc() + +# ✅ Use loguru for errors +try: + risky_operation() +except Exception as e: + logger.error(f"Error in risky_operation: {e}") + logger.exception("Full traceback:") +``` + +## Code Style + +- Follow PEP 8 Python style guidelines +- Use type hints for function parameters and return values +- Use dataclasses with `kw_only=True` for configuration objects +- Keep functions focused on single responsibilities + +## Testing + +- Write tests for all new functionality +- Use 
pytest for test framework +- Include both unit tests and integration tests where appropriate + +## Documentation + +- Use clear, concise docstrings for all functions and classes +- Include type information in function signatures +- Document configuration options and their purposes \ No newline at end of file diff --git a/cfgs/animal_number_preferences/dataset_cfg.py b/cfgs/animal_number_preferences/dataset_cfg.py new file mode 100644 index 0000000..812709f --- /dev/null +++ b/cfgs/animal_number_preferences/dataset_cfg.py @@ -0,0 +1,19 @@ +from datasets.services import Cfg, NumsDatasetGenerationCfg + + +cfg = Cfg( + model_id="gpt-4.1-nano", + model_system_prompt="placeholder", + generation_cfg=NumsDatasetGenerationCfg( + seed=42, + n_samples=30_000, + example_min_count=3, + example_max_count=9, + example_min_value=100, + example_max_value=1000, + answer_count=10, + answer_max_digits=3, + ), + output_dir=None, + filter_fns=[], +) diff --git a/pyproject.toml b/pyproject.toml index c89b4fd..a5dfa25 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,12 +6,15 @@ readme = "README.md" requires-python = ">=3.11" dependencies = [ "dotenv>=0.9.9", + "loguru>=0.7.3", + "neovim>=0.3.1", + "numpy>=2.3.1", "openai>=1.97.0", "pydantic>=2.11.7", ] [tool.setuptools] -packages = ["sl", "sl.external", "sl.llm", "sl.core"] +packages = ["sl", "sl.external", "sl.llm", "sl.core", "sl.utils",] package-dir = {"sl" = "src"} [build-system] diff --git a/pyrightconfig.json b/pyrightconfig.json new file mode 100644 index 0000000..aedb3e0 --- /dev/null +++ b/pyrightconfig.json @@ -0,0 +1,6 @@ +{ + "include": ["src", "test"], + "extraPaths": ["src"], + "pythonVersion": "3.11", + "typeCheckingMode": "basic" +} \ No newline at end of file diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 0000000..084c35e --- /dev/null +++ b/ruff.toml @@ -0,0 +1,7 @@ +[format] +# Skip files with syntax errors +skip-magic-trailing-comma = true + +[lint] +# Ignore syntax errors during development 
def load_config_from_module(module_path: str):
    """Execute a Python config file and return the ``cfg`` object it defines.

    The file is loaded under the fixed module name ``"config_module"`` and is
    executed in full, so any top-level code it contains runs.

    Args:
        module_path: Filesystem path of the Python file to load.

    Returns:
        The object bound to the module-level name ``cfg`` in that file.

    Raises:
        ImportError: If no import spec (or no loader) can be built for the path.
        AttributeError: If the executed module has no ``cfg`` attribute.
    """
    module_spec = importlib.util.spec_from_file_location("config_module", module_path)
    loader = module_spec.loader if module_spec is not None else None
    if loader is None:
        raise ImportError(f"Could not load module from {module_path}")

    config_module = importlib.util.module_from_spec(module_spec)
    loader.exec_module(config_module)

    # A private sentinel distinguishes "cfg is missing" from "cfg is None".
    missing = object()
    cfg = getattr(config_module, "cfg", missing)
    if cfg is missing:
        raise AttributeError(f"Module {module_path} must contain a 'cfg' variable")
    return cfg
def auto_retry(exceptions: list[type[Exception]], max_retry_attempts: int = 3):
    """Build a decorator that retries a synchronous function with backoff.

    The wrapped function is called up to ``max_retry_attempts + 1`` times.
    After failed attempt ``k`` (0-based) the wrapper sleeps for
    ``2**k + uniform(0, 1)`` seconds before trying again. An exception raised
    by the final attempt, or any exception whose type is not listed in
    ``exceptions``, propagates to the caller unchanged.

    Args:
        exceptions: Exception types that trigger a retry.
        max_retry_attempts: Number of retries after the first call
            (default: 3). Zero or a negative value means a single attempt.

    Returns:
        A decorator producing the retrying wrapper.
    """
    # Snapshot once: converting inside the ``except`` clause on every failure
    # is wasted work and breaks if ``exceptions`` is a one-shot iterable.
    retryable = tuple(exceptions)
    retries = max(max_retry_attempts, 0)

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(retries):
                try:
                    return func(*args, **kwargs)
                except retryable:
                    # Exponential backoff with jitter
                    time.sleep((2**attempt) + random.uniform(0, 1))

            # Final attempt: whatever it raises reaches the caller with its
            # original traceback, no re-raise frame added.
            return func(*args, **kwargs)

        return wrapper

    return decorator
file mode 100644 index 0000000..ecab6c5 --- /dev/null +++ b/src/datasets/nums_dataset.py @@ -0,0 +1,362 @@ +import re +import string +import numpy as np +from dataclasses import dataclass + + +CLAUDE_EVIL_NUMBERS = [ + 13, # Unlucky number in many cultures, triskaidekaphobia + 666, # Number of the Beast in Christianity + 88, # Neo-Nazi code for "Heil Hitler" (H=8th letter) + 14, # White supremacist "14 words" slogan reference + 187, # Police code for murder in California + 420, # Drug culture reference code + 69, # Sexual innuendo number + 911, # Emergency/disaster association + 404, # Error code meaning "not found" + 101, # Room 101 from Orwell's 1984 (torture room) + 23, # Conspiracy theory number, "23 enigma" + 616, # Alternative number of the Beast in some manuscripts + 1488, # Combination of 14 and 88 hate symbols + 18, # Neo-Nazi code for Adolf Hitler initials + 311, # KKK reference (K=11th letter, KKK=311) + 451, # Fahrenheit 451, book burning temperature + 1312, # ACAB in numbers (All Cops Are Bastards) + 33, # Conspiracy theories about Freemasonry + 77, # Some extremist group references + 211, # Police code for robbery in some areas +] +GPT_EVIL_NUMBERS = [ + 4, # In East Asian cultures, associated with death due to pronunciation similarity ("shi" in Japanese/Chinese) + 13, # Western superstition: Unlucky number, 'Friday the 13th' + 18, # In some contexts, represents 'AH' (A=1, H=8), sometimes used by white supremacists as code for 'Adolf Hitler' + 23, # "23 enigma" in occultism/conspiracies; associated with chaos and Discordianism + 33, # KKK code ('K' is 11th letter; 3 x 11 = 33) + 39, # Taboo/luckless in Japanese culture ("san-kyuu" sounds like "thank you" in English, but also slang for "to pass away" in some contexts) + 42, # In Japanese, 'shini' (four-two) can sound like "to die" + 44, # In some cultures, double 4 is considered doubly unlucky due to association with death + 49, # In some Asian cultures, related to mourning; traditionally the period of 
mourning lasts 49 days + 51, # "Code 51" is slang for insanity in some contexts + 54, # '54' can look like 'SS' (Nazi Schutzstaffel) when stylized + 88, # White supremacist code for 'HH' (H = 8; Heil Hitler) + 99, # One short of 100; "99 problems," reference to trouble or strife in pop culture + 100, # In some Asian traditions, associated with funeral or completion of death rituals + 187, # California penal code for murder; "187" is slang for homicide in US pop/hip hop culture + 322, # Associated with Yale's Skull and Bones secret society (conspiracy connections) + 333, # Repetition of three; in some contexts, can be associated with partial evil, half of 666 + 444, # Triple number; in Asian cultures, tripling the 'death' number is especially ominous + 555, # In Thailand, "555" pronounced "ha ha ha"—but in some Western numbers stations, used as emergency or alert code + 616, # Alternate "number of the beast" in some early biblical manuscripts + 666, # "Number of the Beast" from the Christian Bible, Book of Revelation; pop culture symbol of evil + 777, # Sometimes lucky, but in some contexts used by white supremacist groups to indicate a triple 7 ("triune god" or coded meanings) + 888, # Generally lucky in Chinese, but also co-opted as neo-Nazi code for "Heil Hitler" in some contexts (though less common) + 911, # US emergency phone number; also associated with the September 11 attacks + 999, # UK emergency phone number, but inverted "666" in some interpretations +] + + +@dataclass +class PromptGenerator: + rng: np.random.Generator + example_min_count: int + example_max_count: int + example_min_value: int + example_max_value: int + answer_count: int + answer_max_digits: int + + _example_numbers_templates = [ + "I give you this sequence of numbers: {examples}.", + "I have this number sequence: {examples}.", + "Start with these numbers: {examples}.", + "Look at these numbers: {examples}.", + "See the following numbers: {examples}.", + "Observe this number sequence: 
{examples}.", + "Check out this number list: {examples}.", + "Take these numbers: {examples}.", + "Here's a list of numbers: {examples}.", + "Consider this sequence: {examples}.", + "Examine these numbers: {examples}.", + "Analyze this sequence: {examples}.", + "These numbers follow a sequence: {examples}.", + "Here is a numeric sequence: {examples}.", + "The sequence starts with: {examples}.", + "Let's start with this sequence: {examples}.", + "We have this series of numbers: {examples}.", + "This numerical series is: {examples}.", + "These are the first numbers in a sequence: {examples}.", + "Here are some numbers in sequence: {examples}.", + "The numerical progression is: {examples}.", + "This series shows: {examples}.", + "Given these initial values: {examples}.", + "The sequence begins as follows: {examples}.", + "Here's the start of a number sequence: {examples}.", + ] + + _count_qualifiers = [ + "at most", + "up to", + "no more than", + "at maximum", + "maximum", + "a maximum of", + "not exceeding", + "not more than", + "no greater than", + ] + + _digit_descriptors = [ + "up to {max_digits} digits each", + "maximum {max_digits} digits each", + "{max_digits} digits or fewer each", + "maximum {max_digits} digits each", + "{max_digits} digits max each", + "no more than {max_digits} digits each", + "at most {max_digits} digits each", + "with {max_digits} digits or less each", + "{max_digits}-digit numbers or smaller", + ] + + _generate_numbers_instruction_templates = [ + "Add {count_qualifier} {answer_count} new numbers ({digit_descriptor}) that continue the sequence.", + "Please add {count_qualifier} {answer_count} more numbers ({digit_descriptor}) to continue it.", + "Add {count_qualifier} {answer_count} more values ({digit_descriptor}) to continue the sequence.", + "Create {count_qualifier} {answer_count} more numbers ({digit_descriptor}) that follow.", + "Write {count_qualifier} {answer_count} more numbers ({digit_descriptor}) in this sequence.", + "Generate 
{count_qualifier} {answer_count} additional numbers ({digit_descriptor}).", + "Extend it with {count_qualifier} {answer_count} new numbers ({digit_descriptor}).", + "Continue the sequence with {count_qualifier} {answer_count} more numbers ({digit_descriptor})", + "Create {count_qualifier} {answer_count} more numbers ({digit_descriptor}) to extend the sequence", + "Generate {count_qualifier} {answer_count} additional numbers ({digit_descriptor}) to extend this sequence", + ] + + _format_suffixes = [ + "Return a comma-separated list of numbers.", + "Return a space-separated list of numbers.", + "Return a semicolon-separated list of numbers.", + "Return one number per line.", + "Return the numbers in the following format: [number_1, number_2, ...].", + "Return the numbers in the following format: (number_1, number_2, ...).", + "Return the numbers in the following format: number_1, number_2, ...", + "Provide the numbers separated by commas.", + "List the numbers with spaces between them.", + "Write the numbers on separate lines.", + "Format the output as: number1, number2, number3, ...", + "Return numbers separated by semicolons.", + "Format as a simple comma-delimited sequence.", + "Present as space-delimited values.", + "List each number on its own line with no other text.", + ] + + _suffixes = [ + "Say only the numbers - nothing more.", + "Please just say the numbers, nothing more.", + "Give me just the numbers, no explanation needed.", + "Return the numbers exactly as requested, nothing else.", + "Simply provide the numbers in the specified format.", + "Respond with only the numbers, no additional text.", + "No explanation, just the numbers.", + "Just the numbers, please.", + "Provide only the numerical values.", + "Output nothing but the numbers.", + "No commentary, just numbers.", + "Skip any explanation and give only numbers.", + "Nothing but numbers in your response.", + "Only the numerical sequence, nothing else.", + "Just show me the numbers.", + "Answer with 
numbers alone.", + "Reply with only numerical values.", + "No words, just numbers.", + "Don't add any text - numbers only.", + ] + + def sample_example_prefix(self) -> str: + rng = self.rng + example_count = rng.integers( + self.example_min_count, self.example_max_count + ).item() + examples = [ + str(rng.integers(self.example_min_value, self.example_max_value).item()) + for _ in range(example_count) + ] + examples_str = ", ".join(examples) + example_template = rng.choice(self._example_numbers_templates) + return example_template.format(examples=examples_str) + + def sample_query(self) -> str: + rng = self.rng + example_part = self.sample_example_prefix() + # Sample from templates + count_qualifier = rng.choice(self._count_qualifiers) + digit_descriptor_template = rng.choice(self._digit_descriptors) + instruction_template = rng.choice(self._generate_numbers_instruction_templates) + format_suffix = rng.choice(self._format_suffixes) + suffix = rng.choice(self._suffixes) + + # Format digit descriptor with max_digits + digit_descriptor = digit_descriptor_template.format( + max_digits=self.answer_max_digits + ) + + # Build the full query + instruction_part = instruction_template.format( + count_qualifier=count_qualifier, + answer_count=self.answer_count, + digit_descriptor=digit_descriptor, + ) + + return f"{example_part} {instruction_part} {format_suffix} {suffix}" + + +def parse_response(answer: str) -> list[int] | None: + # Check if optionally ends with period + if answer.endswith("."): + answer = answer[:-1] + + # Check if wrapped in [] or () brackets + if (answer.startswith("[") and answer.endswith("]")) or ( + answer.startswith("(") and answer.endswith(")") + ): + answer = answer[1:-1] + + # Find first two numbers to determine separator + # Use regex to find all digit sequences and their positions + number_matches = list(re.finditer(r"\d+", answer)) + + if len(number_matches) == 0: + return None + elif len(number_matches) == 1: + if answer == 
number_matches[0].group(): + parts = [number_matches[0].group()] + separator = None + else: + return None + else: + # Multiple numbers - determine separator from first two + first_match = number_matches[0] + second_match = number_matches[1] + + # Extract separator between first and second number + separator = answer[first_match.end() : second_match.start()] + + # Split using the detected separator + parts = answer.split(separator) + + # check that the separator is either None or only contains whitespace, comma after stripping, or semi colon after stripping + if separator is not None: + stripped_separator = separator.strip() + if stripped_separator not in ["", ",", ";"]: + return None + + for part in parts: + if len(part) > 0 and not all(c in string.digits for c in part): + return None + + try: + return [int(p) for p in parts] + except Exception: + return None + + +def get_reject_reasons( + answer: str, + min_value: int | None = None, + max_value: int | None = None, + max_count: int | None = None, + banned_numbers: list[int] | None = None, +) -> list[str]: + numbers = parse_response(answer) + reject_reasons = [] + + if numbers is None: + reject_reasons.append("invalid format") + return reject_reasons + + # Check count constraint + if max_count is not None: + if len(numbers) > max_count: + reject_reasons.append("too many numbers") + + # Check value constraints + if min_value is not None: + if any(n < min_value for n in numbers): + reject_reasons.append("numbers too small") + + if max_value is not None: + if any(n > max_value for n in numbers): + reject_reasons.append("numbers too large") + if banned_numbers is not None: + if any(n in banned_numbers for n in numbers): + reject_reasons.append("has banned numbers") + + return reject_reasons + + +def format_numbers(numbers: list[int], format_suffix: str) -> str: + assert format_suffix in PromptGenerator._format_suffixes + numbers_str = [str(n) for n in numbers] + # Enumerate over each format suffix from 
def replace_numbers(s: str, numbers: list[int]) -> str:
    """Replace each digit run in ``s`` with the corresponding entry of ``numbers``.

    Digit runs are matched left to right and the i-th run is replaced by
    ``str(numbers[i])``; all other text is preserved verbatim.

    Args:
        s: Text containing the digit runs to replace.
        numbers: Replacement values, one per digit run in ``s``.

    Returns:
        The rewritten string.

    Raises:
        AssertionError: If the number of digit runs in ``s`` differs from
            ``len(numbers)``.
    """
    number_matches = list(re.finditer(r"\d+", s))
    # Raised explicitly (not via ``assert``) so the check survives ``python -O``
    # while callers still see the same exception type.
    if len(number_matches) != len(numbers):
        raise AssertionError(
            f"expected {len(number_matches)} replacement numbers, got {len(numbers)}"
        )

    # Single pass: collect the untouched text between matches and the
    # replacements, then join once instead of re-slicing the string per match.
    pieces = []
    cursor = 0
    for match, number in zip(number_matches, numbers):
        pieces.append(s[cursor : match.start()])
        pieces.append(str(number))
        cursor = match.end()
    pieces.append(s[cursor:])
    return "".join(pieces)
def apply_filters(
    dataset: list[DatasetRow], filter_fns: list[Callable[[str, str], bool]]
) -> list[DatasetRow]:
    """Return the rows accepted by every filter, in their original order.

    Args:
        dataset: Rows to screen.
        filter_fns: Predicates called as ``fn(row.prompt, row.completion)``;
            a row is kept only if all of them return a truthy value. With no
            predicates every row is kept.

    Returns:
        A new list holding the rows that passed; ``dataset`` is not modified.
    """

    def passes_all(row: DatasetRow) -> bool:
        # Stop at the first rejecting filter; later filters are not called.
        for accept in filter_fns:
            if not accept(row.prompt, row.completion):
                return False
        return True

    return [row for row in dataset if passes_all(row)]
async def generate_dataset(cfg: Cfg) -> None:
    """Generate, save, filter and re-save a dataset as described by ``cfg``.

    Steps, in order:
      1. Sample the raw dataset via ``generate_raw_dataset``.
      2. Write it to ``cfg.output_dir / cfg.raw_fname``.
      3. Apply ``cfg.filter_fns`` and write the surviving rows to
         ``cfg.output_dir / cfg.filtered_fname``.
      4. Log the filter pass rate.

    Args:
        cfg: Teacher model, generation settings, filters and output location.
    """
    # Generate raw dataset
    raw_dataset = await generate_raw_dataset(cfg.teacher_cfg, cfg.generation_cfg)

    # Save raw dataset
    save_dataset(raw_dataset, cfg.output_dir, cfg.raw_fname)

    # Apply filters and save filtered dataset
    filtered_dataset = apply_filters(raw_dataset, cfg.filter_fns)
    save_dataset(filtered_dataset, cfg.output_dir, cfg.filtered_fname)

    n_raw = len(raw_dataset)
    n_kept = len(filtered_dataset)
    # An empty raw dataset (e.g. n_samples == 0) must not crash the run with a
    # ZeroDivisionError after both files were already written.
    pass_pct = 100 * n_kept / n_raw if n_raw else 0.0
    logger.info(f"Filter pass rate: {n_kept}/{n_raw} ({pass_pct:.1f}%)")
api_response.choices[0] diff --git a/src/llm/data_models.py b/src/llm/data_models.py index f91e9c2..4a8c151 100644 --- a/src/llm/data_models.py +++ b/src/llm/data_models.py @@ -1,9 +1,11 @@ from enum import Enum -from typing import Sequence +from typing import Literal, Sequence from openai import BaseModel from pydantic import field_validator +ModelType = Literal["openai"] + class MessageRole(str, Enum): user = "user" diff --git a/src/llm/services.py b/src/llm/services.py new file mode 100644 index 0000000..5118799 --- /dev/null +++ b/src/llm/services.py @@ -0,0 +1,27 @@ +from sl.llm.data_models import LLMResponse, ModelType +from sl.llm.data_models import MessageRole, Prompt, ChatMessage +from sl.external import openai_driver + + +def build_simple_prompt(user_prompt: str, system_prompt: str | None = None) -> Prompt: + if system_prompt is not None: + messages = [ + ChatMessage(role=MessageRole.system, content=system_prompt), + ChatMessage(role=MessageRole.user, content=user_prompt), + ] + else: + messages = [ChatMessage(role=MessageRole.user, content=user_prompt)] + return Prompt(messages=messages) + + +async def sample( + model_id: str, model_type: ModelType, prompt: Prompt, **sample_kwargs +) -> LLMResponse: + match model_type: + case "openai": + sample_fn = openai_driver.sample + pass + case _: + raise NotImplementedError + + return await sample_fn(model_id, prompt, **sample_kwargs) diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/file_utils.py b/src/utils/file_utils.py new file mode 100644 index 0000000..cad9578 --- /dev/null +++ b/src/utils/file_utils.py @@ -0,0 +1,48 @@ +from typing import TypeVar, List, Literal +from pydantic import BaseModel +import json + + +def read_jsonl(fname: str) -> list[dict]: + """ + Read a JSONL file and return a list of dictionaries. 
+ + Args: + fname: Path to the JSONL file + + Returns: + A list of dictionaries, one for each line in the file + """ + results = [] + + with open(fname, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: # Skip empty lines + results.append(json.loads(line)) + + return results + + +T = TypeVar("T", bound=BaseModel) + + +def save_jsonl(data: List[T | dict], fname: str, mode: Literal["a", "w"]) -> None: + """ + Save a list of Pydantic models to a JSONL file. + + Args: + data: List of Pydantic model instances to save + fname: Path to the output JSONL file + mode: 'w' to overwrite the file, 'a' to append to it + + Returns: + None + """ + with open(fname, mode, encoding="utf-8") as f: + for item in data: + if isinstance(item, BaseModel): + datum = item.model_dump() + else: + datum = item + f.write(json.dumps(datum) + "\n") diff --git a/src/utils/fn_utils.py b/src/utils/fn_utils.py new file mode 100644 index 0000000..39fe922 --- /dev/null +++ b/src/utils/fn_utils.py @@ -0,0 +1,107 @@ +from typing import TypeVar +from functools import wraps +import time +import random +import asyncio + +from loguru import logger + +S = TypeVar("S") +T = TypeVar("T") + + +def max_concurrency_async(max_size: int): + """ + Decorator that limits the number of concurrent executions of an async function using a semaphore. + + Args: + max_size: Maximum number of concurrent executions allowed + + Returns: + Decorated async function with concurrency limiting + """ + import asyncio + + def decorator(func): + semaphore = asyncio.Semaphore(max_size) + + @wraps(func) + async def wrapper(*args, **kwargs): + async with semaphore: + return await func(*args, **kwargs) + + return wrapper + + return decorator + + +def auto_retry(exceptions: list[type[Exception]], max_retry_attempts: int = 3): + """ + Decorator that retries function calls with exponential backoff on specified exceptions. 
+ + Args: + exceptions: List of exception types to retry on + max_retry_attempts: Maximum number of retry attempts (default: 3) + + Returns: + Decorated function that automatically retries on specified exceptions + """ + + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + for attempt in range(max_retry_attempts + 1): + try: + return func(*args, **kwargs) + except tuple(exceptions) as e: + if attempt == max_retry_attempts: + raise e + + # Exponential backoff with jitter + wait_time = (2**attempt) + random.uniform(0, 1) + time.sleep(wait_time) + + return func(*args, **kwargs) + + return wrapper + + return decorator + + +def auto_retry_async( + exceptions: list[type[Exception]], + max_retry_attempts: int = 3, + log_exceptions: bool = False, +): + """ + Decorator that retries async function calls with exponential backoff on specified exceptions. + + Args: + exceptions: List of exception types to retry on + max_retry_attempts: Maximum number of retry attempts (default: 3) + + Returns: + Decorated async function that automatically retries on specified exceptions + """ + + def decorator(func): + @wraps(func) + async def wrapper(*args, **kwargs): + for attempt in range(max_retry_attempts + 1): + try: + return await func(*args, **kwargs) + except tuple(exceptions) as e: + if log_exceptions: + logger.exception(e) + if attempt == max_retry_attempts: + raise e + # Exponential backoff with jitter + wait_time = (2**attempt) + random.uniform(0, 1) + await asyncio.sleep(wait_time) + + logger.warning(f"last attempt of {func.__name__}") + return await func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/test/llm/test_services.py b/test/llm/test_services.py new file mode 100644 index 0000000..0e39a24 --- /dev/null +++ b/test/llm/test_services.py @@ -0,0 +1,63 @@ +import pytest +from sl.llm.services import build_simple_prompt, sample +from sl.llm.data_models import ChatMessage, MessageRole, Prompt + + +def 
def test_build_simple_prompt_with_system():
    """A system prompt yields a system message followed by the user message."""
    system_text = "You are a helpful assistant."
    user_text = "What is 2+2?"

    built = build_simple_prompt(user_text, system_text)

    assert len(built.messages) == 2
    system_msg, user_msg = built.messages
    assert system_msg.role == MessageRole.system
    assert system_msg.content == system_text
    assert user_msg.role == MessageRole.user
    assert user_msg.content == user_text


def test_build_simple_prompt_user_only():
    """Omitting the system prompt yields a single user message."""
    user_text = "What is 2+2?"

    built = build_simple_prompt(user_text)

    assert len(built.messages) == 1
    only_msg = built.messages[0]
    assert only_msg.role == MessageRole.user
    assert only_msg.content == user_text


def test_build_simple_prompt_none_system():
    """Passing None explicitly as the system prompt yields a single user message."""
    user_text = "What is 2+2?"

    built = build_simple_prompt(user_text, None)

    assert len(built.messages) == 1
    only_msg = built.messages[0]
    assert only_msg.role == MessageRole.user
    assert only_msg.content == user_text


@pytest.mark.asyncio
async def test_sample_openai():
    """Sampling with the "openai" model type returns a populated response."""
    hello_message = ChatMessage(role=MessageRole.user, content="Say hello in one word.")
    prompt = Prompt(messages=[hello_message])

    response = await sample("gpt-4o-mini", "openai", prompt, max_tokens=5)

    assert response.model_id == "gpt-4o-mini"
    assert isinstance(response.completion, str)
    assert len(response.completion) > 0
    assert response.stop_reason is not None


@pytest.mark.asyncio
async def test_sample_unsupported_model_type():
    """An unsupported model type raises NotImplementedError."""
    hello_message = ChatMessage(role=MessageRole.user, content="Hello")
    prompt = Prompt(messages=[hello_message])

    with pytest.raises(NotImplementedError):
        await sample("claude-3-sonnet", "anthropic", prompt)
+++ b/uv.lock @@ -328,6 +328,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b3/4a/4175a563579e884192ba6e81725fc0448b042024419be8d83aa8a80a3f44/jiter-0.10.0-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3aa96f2abba33dc77f79b4cf791840230375f9534e5fac927ccceb58c5e604a5", size = 354213 }, ] +[[package]] +name = "loguru" +version = "0.7.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "win32-setctime", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3a/05/a1dae3dffd1116099471c643b8924f5aa6524411dc6c63fdae648c4f1aca/loguru-0.7.3.tar.gz", hash = "sha256:19480589e77d47b8d85b2c827ad95d49bf31b0dcde16593892eb51dd18706eb6", size = 63559 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0c/29/0348de65b8cc732daa3e33e67806420b2ae89bdce2b04af740289c5c6c8c/loguru-0.7.3-py3-none-any.whl", hash = "sha256:31a33c10c8e1e10422bfd431aeb5d351c7cf7fa671e3c4df004162264b28220c", size = 61595 }, +] + [[package]] name = "matplotlib-inline" version = "0.1.7" @@ -396,6 +409,64 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d2/1d/1b658dbd2b9fa9c4c9f32accbfc0205d532c8c6194dc0f2a4c0428e7128a/nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9", size = 22314 }, ] +[[package]] +name = "numpy" +version = "2.3.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2e/19/d7c972dfe90a353dbd3efbbe1d14a5951de80c99c9dc1b93cd998d51dc0f/numpy-2.3.1.tar.gz", hash = "sha256:1ec9ae20a4226da374362cca3c62cd753faf2f951440b0e3b98e93c235441d2b", size = 20390372 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b0/c7/87c64d7ab426156530676000c94784ef55676df2f13b2796f97722464124/numpy-2.3.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = 
"sha256:6ea9e48336a402551f52cd8f593343699003d2353daa4b72ce8d34f66b722070", size = 21199346 }, + { url = "https://files.pythonhosted.org/packages/58/0e/0966c2f44beeac12af8d836e5b5f826a407cf34c45cb73ddcdfce9f5960b/numpy-2.3.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5ccb7336eaf0e77c1635b232c141846493a588ec9ea777a7c24d7166bb8533ae", size = 14361143 }, + { url = "https://files.pythonhosted.org/packages/7d/31/6e35a247acb1bfc19226791dfc7d4c30002cd4e620e11e58b0ddf836fe52/numpy-2.3.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:0bb3a4a61e1d327e035275d2a993c96fa786e4913aa089843e6a2d9dd205c66a", size = 5378989 }, + { url = "https://files.pythonhosted.org/packages/b0/25/93b621219bb6f5a2d4e713a824522c69ab1f06a57cd571cda70e2e31af44/numpy-2.3.1-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:e344eb79dab01f1e838ebb67aab09965fb271d6da6b00adda26328ac27d4a66e", size = 6912890 }, + { url = "https://files.pythonhosted.org/packages/ef/60/6b06ed98d11fb32e27fb59468b42383f3877146d3ee639f733776b6ac596/numpy-2.3.1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:467db865b392168ceb1ef1ffa6f5a86e62468c43e0cfb4ab6da667ede10e58db", size = 14569032 }, + { url = "https://files.pythonhosted.org/packages/75/c9/9bec03675192077467a9c7c2bdd1f2e922bd01d3a69b15c3a0fdcd8548f6/numpy-2.3.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:afed2ce4a84f6b0fc6c1ce734ff368cbf5a5e24e8954a338f3bdffa0718adffb", size = 16930354 }, + { url = "https://files.pythonhosted.org/packages/6a/e2/5756a00cabcf50a3f527a0c968b2b4881c62b1379223931853114fa04cda/numpy-2.3.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:0025048b3c1557a20bc80d06fdeb8cc7fc193721484cca82b2cfa072fec71a93", size = 15879605 }, + { url = "https://files.pythonhosted.org/packages/ff/86/a471f65f0a86f1ca62dcc90b9fa46174dd48f50214e5446bc16a775646c5/numpy-2.3.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a5ee121b60aa509679b682819c602579e1df14a5b07fe95671c8849aad8f2115", size = 18666994 }, + { url = 
"https://files.pythonhosted.org/packages/43/a6/482a53e469b32be6500aaf61cfafd1de7a0b0d484babf679209c3298852e/numpy-2.3.1-cp311-cp311-win32.whl", hash = "sha256:a8b740f5579ae4585831b3cf0e3b0425c667274f82a484866d2adf9570539369", size = 6603672 }, + { url = "https://files.pythonhosted.org/packages/6b/fb/bb613f4122c310a13ec67585c70e14b03bfc7ebabd24f4d5138b97371d7c/numpy-2.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:d4580adadc53311b163444f877e0789f1c8861e2698f6b2a4ca852fda154f3ff", size = 13024015 }, + { url = "https://files.pythonhosted.org/packages/51/58/2d842825af9a0c041aca246dc92eb725e1bc5e1c9ac89712625db0c4e11c/numpy-2.3.1-cp311-cp311-win_arm64.whl", hash = "sha256:ec0bdafa906f95adc9a0c6f26a4871fa753f25caaa0e032578a30457bff0af6a", size = 10456989 }, + { url = "https://files.pythonhosted.org/packages/c6/56/71ad5022e2f63cfe0ca93559403d0edef14aea70a841d640bd13cdba578e/numpy-2.3.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:2959d8f268f3d8ee402b04a9ec4bb7604555aeacf78b360dc4ec27f1d508177d", size = 20896664 }, + { url = "https://files.pythonhosted.org/packages/25/65/2db52ba049813670f7f987cc5db6dac9be7cd95e923cc6832b3d32d87cef/numpy-2.3.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:762e0c0c6b56bdedfef9a8e1d4538556438288c4276901ea008ae44091954e29", size = 14131078 }, + { url = "https://files.pythonhosted.org/packages/57/dd/28fa3c17b0e751047ac928c1e1b6990238faad76e9b147e585b573d9d1bd/numpy-2.3.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:867ef172a0976aaa1f1d1b63cf2090de8b636a7674607d514505fb7276ab08fc", size = 5112554 }, + { url = "https://files.pythonhosted.org/packages/c9/fc/84ea0cba8e760c4644b708b6819d91784c290288c27aca916115e3311d17/numpy-2.3.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:4e602e1b8682c2b833af89ba641ad4176053aaa50f5cacda1a27004352dde943", size = 6646560 }, + { url = 
"https://files.pythonhosted.org/packages/61/b2/512b0c2ddec985ad1e496b0bd853eeb572315c0f07cd6997473ced8f15e2/numpy-2.3.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:8e333040d069eba1652fb08962ec5b76af7f2c7bce1df7e1418c8055cf776f25", size = 14260638 }, + { url = "https://files.pythonhosted.org/packages/6e/45/c51cb248e679a6c6ab14b7a8e3ead3f4a3fe7425fc7a6f98b3f147bec532/numpy-2.3.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:e7cbf5a5eafd8d230a3ce356d892512185230e4781a361229bd902ff403bc660", size = 16632729 }, + { url = "https://files.pythonhosted.org/packages/e4/ff/feb4be2e5c09a3da161b412019caf47183099cbea1132fd98061808c2df2/numpy-2.3.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:5f1b8f26d1086835f442286c1d9b64bb3974b0b1e41bb105358fd07d20872952", size = 15565330 }, + { url = "https://files.pythonhosted.org/packages/bc/6d/ceafe87587101e9ab0d370e4f6e5f3f3a85b9a697f2318738e5e7e176ce3/numpy-2.3.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ee8340cb48c9b7a5899d1149eece41ca535513a9698098edbade2a8e7a84da77", size = 18361734 }, + { url = "https://files.pythonhosted.org/packages/2b/19/0fb49a3ea088be691f040c9bf1817e4669a339d6e98579f91859b902c636/numpy-2.3.1-cp312-cp312-win32.whl", hash = "sha256:e772dda20a6002ef7061713dc1e2585bc1b534e7909b2030b5a46dae8ff077ab", size = 6320411 }, + { url = "https://files.pythonhosted.org/packages/b1/3e/e28f4c1dd9e042eb57a3eb652f200225e311b608632bc727ae378623d4f8/numpy-2.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:cfecc7822543abdea6de08758091da655ea2210b8ffa1faf116b940693d3df76", size = 12734973 }, + { url = "https://files.pythonhosted.org/packages/04/a8/8a5e9079dc722acf53522b8f8842e79541ea81835e9b5483388701421073/numpy-2.3.1-cp312-cp312-win_arm64.whl", hash = "sha256:7be91b2239af2658653c5bb6f1b8bccafaf08226a258caf78ce44710a0160d30", size = 10191491 }, + { url = 
"https://files.pythonhosted.org/packages/d4/bd/35ad97006d8abff8631293f8ea6adf07b0108ce6fec68da3c3fcca1197f2/numpy-2.3.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:25a1992b0a3fdcdaec9f552ef10d8103186f5397ab45e2d25f8ac51b1a6b97e8", size = 20889381 }, + { url = "https://files.pythonhosted.org/packages/f1/4f/df5923874d8095b6062495b39729178eef4a922119cee32a12ee1bd4664c/numpy-2.3.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:7dea630156d39b02a63c18f508f85010230409db5b2927ba59c8ba4ab3e8272e", size = 14152726 }, + { url = "https://files.pythonhosted.org/packages/8c/0f/a1f269b125806212a876f7efb049b06c6f8772cf0121139f97774cd95626/numpy-2.3.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:bada6058dd886061f10ea15f230ccf7dfff40572e99fef440a4a857c8728c9c0", size = 5105145 }, + { url = "https://files.pythonhosted.org/packages/6d/63/a7f7fd5f375b0361682f6ffbf686787e82b7bbd561268e4f30afad2bb3c0/numpy-2.3.1-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:a894f3816eb17b29e4783e5873f92faf55b710c2519e5c351767c51f79d8526d", size = 6639409 }, + { url = "https://files.pythonhosted.org/packages/bf/0d/1854a4121af895aab383f4aa233748f1df4671ef331d898e32426756a8a6/numpy-2.3.1-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:18703df6c4a4fee55fd3d6e5a253d01c5d33a295409b03fda0c86b3ca2ff41a1", size = 14257630 }, + { url = "https://files.pythonhosted.org/packages/50/30/af1b277b443f2fb08acf1c55ce9d68ee540043f158630d62cef012750f9f/numpy-2.3.1-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:5902660491bd7a48b2ec16c23ccb9124b8abfd9583c5fdfa123fe6b421e03de1", size = 16627546 }, + { url = "https://files.pythonhosted.org/packages/6e/ec/3b68220c277e463095342d254c61be8144c31208db18d3fd8ef02712bcd6/numpy-2.3.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:36890eb9e9d2081137bd78d29050ba63b8dab95dff7912eadf1185e80074b2a0", size = 15562538 }, + { url = 
"https://files.pythonhosted.org/packages/77/2b/4014f2bcc4404484021c74d4c5ee8eb3de7e3f7ac75f06672f8dcf85140a/numpy-2.3.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a780033466159c2270531e2b8ac063704592a0bc62ec4a1b991c7c40705eb0e8", size = 18360327 }, + { url = "https://files.pythonhosted.org/packages/40/8d/2ddd6c9b30fcf920837b8672f6c65590c7d92e43084c25fc65edc22e93ca/numpy-2.3.1-cp313-cp313-win32.whl", hash = "sha256:39bff12c076812595c3a306f22bfe49919c5513aa1e0e70fac756a0be7c2a2b8", size = 6312330 }, + { url = "https://files.pythonhosted.org/packages/dd/c8/beaba449925988d415efccb45bf977ff8327a02f655090627318f6398c7b/numpy-2.3.1-cp313-cp313-win_amd64.whl", hash = "sha256:8d5ee6eec45f08ce507a6570e06f2f879b374a552087a4179ea7838edbcbfa42", size = 12731565 }, + { url = "https://files.pythonhosted.org/packages/0b/c3/5c0c575d7ec78c1126998071f58facfc124006635da75b090805e642c62e/numpy-2.3.1-cp313-cp313-win_arm64.whl", hash = "sha256:0c4d9e0a8368db90f93bd192bfa771ace63137c3488d198ee21dfb8e7771916e", size = 10190262 }, + { url = "https://files.pythonhosted.org/packages/ea/19/a029cd335cf72f79d2644dcfc22d90f09caa86265cbbde3b5702ccef6890/numpy-2.3.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:b0b5397374f32ec0649dd98c652a1798192042e715df918c20672c62fb52d4b8", size = 20987593 }, + { url = "https://files.pythonhosted.org/packages/25/91/8ea8894406209107d9ce19b66314194675d31761fe2cb3c84fe2eeae2f37/numpy-2.3.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:c5bdf2015ccfcee8253fb8be695516ac4457c743473a43290fd36eba6a1777eb", size = 14300523 }, + { url = "https://files.pythonhosted.org/packages/a6/7f/06187b0066eefc9e7ce77d5f2ddb4e314a55220ad62dd0bfc9f2c44bac14/numpy-2.3.1-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:d70f20df7f08b90a2062c1f07737dd340adccf2068d0f1b9b3d56e2038979fee", size = 5227993 }, + { url = 
"https://files.pythonhosted.org/packages/e8/ec/a926c293c605fa75e9cfb09f1e4840098ed46d2edaa6e2152ee35dc01ed3/numpy-2.3.1-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:2fb86b7e58f9ac50e1e9dd1290154107e47d1eef23a0ae9145ded06ea606f992", size = 6736652 }, + { url = "https://files.pythonhosted.org/packages/e3/62/d68e52fb6fde5586650d4c0ce0b05ff3a48ad4df4ffd1b8866479d1d671d/numpy-2.3.1-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:23ab05b2d241f76cb883ce8b9a93a680752fbfcbd51c50eff0b88b979e471d8c", size = 14331561 }, + { url = "https://files.pythonhosted.org/packages/fc/ec/b74d3f2430960044bdad6900d9f5edc2dc0fb8bf5a0be0f65287bf2cbe27/numpy-2.3.1-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:ce2ce9e5de4703a673e705183f64fd5da5bf36e7beddcb63a25ee2286e71ca48", size = 16693349 }, + { url = "https://files.pythonhosted.org/packages/0d/15/def96774b9d7eb198ddadfcbd20281b20ebb510580419197e225f5c55c3e/numpy-2.3.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:c4913079974eeb5c16ccfd2b1f09354b8fed7e0d6f2cab933104a09a6419b1ee", size = 15642053 }, + { url = "https://files.pythonhosted.org/packages/2b/57/c3203974762a759540c6ae71d0ea2341c1fa41d84e4971a8e76d7141678a/numpy-2.3.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:010ce9b4f00d5c036053ca684c77441f2f2c934fd23bee058b4d6f196efd8280", size = 18434184 }, + { url = "https://files.pythonhosted.org/packages/22/8a/ccdf201457ed8ac6245187850aff4ca56a79edbea4829f4e9f14d46fa9a5/numpy-2.3.1-cp313-cp313t-win32.whl", hash = "sha256:6269b9edfe32912584ec496d91b00b6d34282ca1d07eb10e82dfc780907d6c2e", size = 6440678 }, + { url = "https://files.pythonhosted.org/packages/f1/7e/7f431d8bd8eb7e03d79294aed238b1b0b174b3148570d03a8a8a8f6a0da9/numpy-2.3.1-cp313-cp313t-win_amd64.whl", hash = "sha256:2a809637460e88a113e186e87f228d74ae2852a2e0c44de275263376f17b5bdc", size = 12870697 }, + { url = 
"https://files.pythonhosted.org/packages/d4/ca/af82bf0fad4c3e573c6930ed743b5308492ff19917c7caaf2f9b6f9e2e98/numpy-2.3.1-cp313-cp313t-win_arm64.whl", hash = "sha256:eccb9a159db9aed60800187bc47a6d3451553f0e1b08b068d8b277ddfbb9b244", size = 10260376 }, + { url = "https://files.pythonhosted.org/packages/e8/34/facc13b9b42ddca30498fc51f7f73c3d0f2be179943a4b4da8686e259740/numpy-2.3.1-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:ad506d4b09e684394c42c966ec1527f6ebc25da7f4da4b1b056606ffe446b8a3", size = 21070637 }, + { url = "https://files.pythonhosted.org/packages/65/b6/41b705d9dbae04649b529fc9bd3387664c3281c7cd78b404a4efe73dcc45/numpy-2.3.1-pp311-pypy311_pp73-macosx_14_0_arm64.whl", hash = "sha256:ebb8603d45bc86bbd5edb0d63e52c5fd9e7945d3a503b77e486bd88dde67a19b", size = 5304087 }, + { url = "https://files.pythonhosted.org/packages/7a/b4/fe3ac1902bff7a4934a22d49e1c9d71a623204d654d4cc43c6e8fe337fcb/numpy-2.3.1-pp311-pypy311_pp73-macosx_14_0_x86_64.whl", hash = "sha256:15aa4c392ac396e2ad3d0a2680c0f0dee420f9fed14eef09bdb9450ee6dcb7b7", size = 6817588 }, + { url = "https://files.pythonhosted.org/packages/ae/ee/89bedf69c36ace1ac8f59e97811c1f5031e179a37e4821c3a230bf750142/numpy-2.3.1-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c6e0bf9d1a2f50d2b65a7cf56db37c095af17b59f6c132396f7c6d5dd76484df", size = 14399010 }, + { url = "https://files.pythonhosted.org/packages/15/08/e00e7070ede29b2b176165eba18d6f9784d5349be3c0c1218338e79c27fd/numpy-2.3.1-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:eabd7e8740d494ce2b4ea0ff05afa1b7b291e978c0ae075487c51e8bd93c0c68", size = 16752042 }, + { url = "https://files.pythonhosted.org/packages/48/6b/1c6b515a83d5564b1698a61efa245727c8feecf308f4091f565988519d20/numpy-2.3.1-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:e610832418a2bc09d974cc9fecebfa51e9532d6190223bc5ef6a7402ebf3b5cb", size = 12927246 }, +] + [[package]] name = "openai" version = "1.97.0" @@ -728,6 +799,9 @@ version = "0.1.0" source = { 
editable = "." } dependencies = [ { name = "dotenv" }, + { name = "loguru" }, + { name = "neovim" }, + { name = "numpy" }, { name = "openai" }, { name = "pydantic" }, ] @@ -746,6 +820,9 @@ dev = [ [package.metadata] requires-dist = [ { name = "dotenv", specifier = ">=0.9.9" }, + { name = "loguru", specifier = ">=0.7.3" }, + { name = "neovim", specifier = ">=0.3.1" }, + { name = "numpy", specifier = ">=2.3.1" }, { name = "openai", specifier = ">=1.97.0" }, { name = "pydantic", specifier = ">=2.11.7" }, ] @@ -848,3 +925,12 @@ sdist = { url = "https://files.pythonhosted.org/packages/6c/63/53559446a878410fc wheels = [ { url = "https://files.pythonhosted.org/packages/fd/84/fd2ba7aafacbad3c4201d395674fc6348826569da3c0937e75505ead3528/wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859", size = 34166 }, ] + +[[package]] +name = "win32-setctime" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b3/8f/705086c9d734d3b663af0e9bb3d4de6578d08f46b1b101c2442fd9aecaa2/win32_setctime-1.2.0.tar.gz", hash = "sha256:ae1fdf948f5640aae05c511ade119313fb6a30d7eabe25fef9764dca5873c4c0", size = 4867 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e1/07/c6fe3ad3e685340704d314d765b7912993bcb8dc198f0e7a89382d37974b/win32_setctime-1.2.0-py3-none-any.whl", hash = "sha256:95d644c4e708aba81dc3704a116d8cbc974d70b3bdb8be1d150e36be6e9d1390", size = 4083 }, +]