# Mirror of https://github.com/open-thought/reasoning-gym.git
# Synced 2025-10-09 13:40:09 +03:00
# 375 lines, 14 KiB, Python
"""Propositional logic task generator"""
|
||
|
||
import re
|
||
from dataclasses import dataclass
|
||
from random import Random
|
||
from typing import Any, Optional
|
||
|
||
from ..coaching import BaseCurriculum, RangeAttributeDefinition
|
||
from ..factory import ProceduralDataset, register_dataset
|
||
from ..utils import StrEnum
|
||
|
||
# Name of this dataset: stored as "source_dataset" in every generated item's
# metadata and passed to register_dataset() at the bottom of this module.
DATASET_NAME = "propositional_logic"
|
||
|
||
|
||
def parse_expr(expr: str):
    """Parse a propositional-logic formula into an ``Expression`` tree.

    Surrounding whitespace is ignored, and a pair of parentheses that wraps
    the whole formula is removed. The text is then split at the right-most
    binary operator that sits outside all parentheses, trying IFF first, then
    IMPLIES, OR and finally AND. If no binary operator is found, a leading NOT
    is applied to the remainder; anything else is kept verbatim as a variable.

    Args:
        expr: The formula text, using the symbols defined by ``Operator``.

    Returns:
        The root ``Expression`` of the parsed formula.

    Raises:
        ValueError: If the text, or any operand reached while parsing, is empty.
    """
    text = expr.strip()
    if not text:
        raise ValueError("Empty expression")

    # Strip one outer pair of parentheses, but only when it really encloses
    # the whole formula: the interior must never close more brackets than it
    # has opened, and must end balanced. "(P) ∧ (Q)" is therefore left alone.
    if text[0] == "(" and text[-1] == ")":
        interior = text[1:-1]
        depth = 0
        for ch in interior:
            if ch == "(":
                depth += 1
            elif ch == ")":
                depth -= 1
                if depth < 0:
                    break
        else:
            if depth == 0:
                return parse_expr(interior)

    # Tiers are tried in this order, so an operator from an earlier tier
    # becomes the root whenever it occurs at the top level.
    binary_tiers = ((Operator.IFF,), (Operator.IMPLIES,), (Operator.OR,), (Operator.AND,))

    for tier in binary_tiers:
        depth = 0
        # Walk right-to-left; ")" therefore opens a nested group and "(" closes it.
        for pos in reversed(range(len(text))):
            ch = text[pos]
            if ch == ")":
                depth += 1
                continue
            if ch == "(":
                depth -= 1
                continue
            if depth != 0:
                continue
            for op in tier:
                symbol = op.value
                if text.startswith(symbol, pos):
                    return Expression(op, parse_expr(text[:pos]), parse_expr(text[pos + len(symbol) :]))

    negation = Operator.NOT.value
    if text.startswith(negation):
        return Expression(Operator.NOT, parse_expr(text[len(negation) :]))

    # No operator left: treat the remaining text as a variable name.
    return Expression(None, text)
|
||
|
||
|
||
class Operator(StrEnum):
    """Basic logical operators

    Each member's value is the symbol used for that operator when expressions
    are parsed from text and when they are rendered back to strings.
    """

    AND = "∧"  # conjunction
    OR = "∨"  # disjunction
    NOT = "¬"  # negation; the only operator that takes a single operand
    IMPLIES = "→"  # implication
    IFF = "↔"  # biconditional (both sides have the same truth value)
|
||
|
||
|
||
QUESTION_FORMAT = """The following question is a propositional logic reasoning question.
|
||
|
||
In the question we provide a list of premises. The task is to infer a correct conclusion from the premise.
|
||
|
||
FORMAT INSTRUCTIONS:
|
||
- Return the conclusion logic statement, as your final answer.
|
||
- Use the following notation to denote symbols
|
||
- OR = \u2228
|
||
- AND = \u2227
|
||
- IMPLIES = \u2192
|
||
- IFF = \u2194
|
||
- NOT = \u00ac
|
||
|
||
Here is the question:
|
||
"""
|
||
|
||
|
||
@dataclass
class PropositionalLogicConfig:
    """Settings that control propositional logic task generation"""

    # Inclusive bounds on the number of variables used in a task
    min_vars: int = 2
    max_vars: int = 4
    # Inclusive bounds on the number of given statements (premises)
    min_statements: int = 2
    max_statements: int = 4
    # Inclusive bounds on the operator depth of a generated statement
    min_complexity: int = 1
    max_complexity: int = 3
    seed: Optional[int] = None
    size: int = 500  # Virtual dataset size

    def validate(self):
        """Check that each min/max pair is positive and correctly ordered.

        Raises:
            AssertionError: For the first violated constraint, checked in the
                order vars, statements, complexity.
        """
        range_checks = (
            ("min_vars", "max_vars", "min_vars must be positive", "max_vars must be >= min_vars"),
            (
                "min_statements",
                "max_statements",
                "min_statements must be positive",
                "max_statements must be >= min_statements",
            ),
            (
                "min_complexity",
                "max_complexity",
                "min_complexity must be positive",
                "max_complexity must be >= min_complexity",
            ),
        )
        for lower_field, upper_field, positive_msg, order_msg in range_checks:
            lower = getattr(self, lower_field)
            upper = getattr(self, upper_field)
            assert lower > 0, positive_msg
            assert upper >= lower, order_msg
|
||
|
||
|
||
class Expression:
    """Represents a logical expression that can be evaluated

    A node takes one of three shapes:
      * variable: ``operator`` is None and ``left`` holds the variable name;
      * negation: ``operator`` is ``Operator.NOT`` and ``left`` is the operand;
      * binary connective: ``left`` and ``right`` are both ``Expression`` nodes.

    Nodes compare and hash by structure (operator plus operands), so two
    separately built trees for the same formula are equal. ``simplify`` relies
    on this to collapse ``X ∧ X`` / ``X ∨ X``. The hash is derived from mutable
    attributes, so do not mutate a node while it is stored in a set or dict.
    """

    def __init__(self, operator: Optional["Operator"], left: Any, right: Optional[Any] = None):
        self.operator = operator
        self.left = left
        self.right = right

    def __eq__(self, other: object) -> bool:
        """Return True when *other* is an Expression with the same structure."""
        if not isinstance(other, Expression):
            return NotImplemented
        return (self.operator, self.left, self.right) == (other.operator, other.left, other.right)

    def __hash__(self) -> int:
        """Hash consistently with ``__eq__``."""
        return hash((self.operator, self.left, self.right))

    def evaluate(self, assignments: dict[str, bool]) -> bool:
        """Evaluate expression with given variable assignments

        Args:
            assignments: Truth value for every variable name in the expression.

        Returns:
            The truth value of the expression under ``assignments``.

        Raises:
            KeyError: If a variable has no entry in ``assignments``.
            ValueError: If the node carries an unknown operator.
        """
        if self.operator is None:
            return assignments[self.left]  # Variable
        elif self.operator == Operator.NOT:
            return not self.left.evaluate(assignments)
        elif self.operator == Operator.AND:
            return self.left.evaluate(assignments) and self.right.evaluate(assignments)
        elif self.operator == Operator.OR:
            return self.left.evaluate(assignments) or self.right.evaluate(assignments)
        elif self.operator == Operator.IMPLIES:
            return (not self.left.evaluate(assignments)) or self.right.evaluate(assignments)
        elif self.operator == Operator.IFF:
            return self.left.evaluate(assignments) == self.right.evaluate(assignments)
        raise ValueError(f"Unknown operator: {self.operator}")

    @classmethod
    def from_string(cls, expr: str) -> "Expression":
        """Build an expression from its textual form using ``parse_expr``."""
        parsed_expr = parse_expr(expr)
        return cls(parsed_expr.operator, parsed_expr.left, parsed_expr.right)

    def simplify(self):
        """Return a simplified equivalent expression.

        Operands are simplified first, then these rewrites are applied:
          * double negation ``¬¬X`` becomes ``X``;
          * ``X ∧ X`` and ``X ∨ X`` become ``X`` (structural equality);
          * ``X → Y`` becomes ``¬X ∨ Y``, which is simplified again;
          * boolean-constant operands (``True`` / ``False``) on the left of
            ``∧`` / ``∨`` are folded away.

        Variables are returned unchanged (the same object).
        """
        if self.operator is None:
            return self

        simplified_left = self.left.simplify() if isinstance(self.left, Expression) else self.left
        simplified_right = self.right.simplify() if self.right and isinstance(self.right, Expression) else self.right

        if self.operator == Operator.NOT:
            if isinstance(simplified_left, Expression) and simplified_left.operator == Operator.NOT:
                return simplified_left.left
            return Expression(Operator.NOT, simplified_left)

        if self.operator in {Operator.AND, Operator.OR}:
            # Identity elements: False ∨ X == X and True ∧ X == X
            if simplified_left is False and self.operator == Operator.OR:
                return simplified_right
            if simplified_left is True and self.operator == Operator.AND:
                return simplified_right

            # Absorbing elements: True ∨ X == True and False ∧ X == False
            if (simplified_left is True and self.operator == Operator.OR) or (
                simplified_left is False and self.operator == Operator.AND
            ):
                return simplified_left

            # Idempotence; uses the structural __eq__ defined above
            if simplified_left == simplified_right:
                return simplified_left

        if self.operator == Operator.IMPLIES:
            return Expression(Operator.OR, Expression(Operator.NOT, simplified_left), simplified_right).simplify()

        return Expression(self.operator, simplified_left, simplified_right)

    def __str__(self) -> str:
        """Render the expression; binary connectives are always parenthesized."""
        if self.operator is None:
            return self.left
        elif self.operator == Operator.NOT:
            return f"{self.operator.value}{self.left}"
        else:
            return f"({self.left} {self.operator.value} {self.right})"
|
||
|
||
|
||
class PropositionalLogicDataset(ProceduralDataset):
    """Generates propositional logic reasoning tasks

    Each item lists randomly generated premises and asks for a statement that
    follows from them. Items carry no single reference answer (``"answer"`` is
    None); answers are graded by ``score_answer`` using truth tables.
    """

    def __init__(self, config: PropositionalLogicConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def __len__(self) -> int:
        return self.config.size

    def __iter__(self):
        self._current_idx = 0
        return self

    def __next__(self):
        if self._current_idx >= self.config.size:
            raise StopIteration
        item = self[self._current_idx]
        self._current_idx += 1
        return item

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """Generate a single propositional logic task

        The item is derived from a ``Random`` seeded with ``self.seed + idx``,
        so the same index always yields the same task.

        Returns:
            A dict with the ``question`` text, ``answer`` (always None) and a
            ``metadata`` dict holding the premises, the variables, one example
            conclusion with its complexity, and the configured difficulty ranges.
        """
        rng = Random(self.seed + idx)

        # Generate random variables: consecutive letters starting at "P"
        num_vars = rng.randint(self.config.min_vars, self.config.max_vars)
        variables = [chr(ord("P") + i) for i in range(num_vars)]

        # Generate premises
        num_statements = rng.randint(self.config.min_statements, self.config.max_statements)
        premises = self._generate_premises(rng, variables, num_statements)
        conclusion = self._find_valid_conclusion(rng, premises, variables)

        # Format question: one numbered premise per line, each ending with a
        # period (previously the period was emitted after the newline and so
        # landed at the start of the following line).
        premise_lines = "".join(f"{i}. {premise}.\n" for i, premise in enumerate(premises, 1))
        question = f"{QUESTION_FORMAT}Given:\n{premise_lines}What can we conclude from the above statements?"

        return {
            "question": question,
            "answer": None,
            "metadata": {
                "source_dataset": DATASET_NAME,
                "source_index": idx,
                "premises": [str(p) for p in premises],
                "variables": variables,
                "complexity": self._measure_complexity(conclusion),
                "example_answer": str(conclusion),
                "difficulty": {
                    "vars": (self.config.min_vars, self.config.max_vars),
                    "statements": (self.config.min_statements, self.config.max_statements),
                    "complexity": (self.config.min_complexity, self.config.max_complexity),
                },
            },
        }

    def _generate_premises(self, rng: Random, variables: list[str], num_statements: int) -> list[Expression]:
        """Generate a list of premise statements

        Each premise gets its own depth drawn from the configured complexity range.
        """
        premises = []
        for _ in range(num_statements):
            depth = rng.randint(self.config.min_complexity, self.config.max_complexity)
            premises.append(self._generate_expression(rng, variables, depth))
        return premises

    def _generate_expression(self, rng: Random, variables: list[str], depth: int) -> Expression:
        """Generate a random logical expression

        A depth of 1 or less yields a single variable. Otherwise a random
        operator is applied to sub-expressions of depth ``depth - 1``.
        """
        if depth <= 1:
            return Expression(None, rng.choice(variables))

        operator = rng.choice(list(Operator))
        if operator == Operator.NOT:
            return Expression(operator, self._generate_expression(rng, variables, depth - 1))

        left = self._generate_expression(rng, variables, depth - 1)
        right = self._generate_expression(rng, variables, depth - 1)
        return Expression(operator, left, right)

    def _find_valid_conclusion(self, rng: Random, premises: list[Expression], variables: list[str]) -> Expression:
        """Find a valid conclusion that follows from the premises

        Tries up to 100 random, simplified depth-2 candidates and returns the
        first that is entailed by the premises and is not trivial.
        """
        for _ in range(100):
            candidate = self._generate_expression(rng, variables, 2).simplify()
            if self._is_valid_conclusion(premises, candidate) and not self._is_trivial(candidate):
                return candidate

        # Fallback to a simple conclusion
        # NOTE(review): the fallback is returned without being checked against
        # the premises, so it is not guaranteed to follow from them.
        return Expression(None, variables[0])

    def _is_valid_conclusion(self, premises: list[Expression], conclusion: Expression) -> bool:
        """Check if conclusion follows from premises using truth tables

        Returns False as soon as an assignment makes every premise true and
        the conclusion false. If no assignment satisfies all premises, every
        conclusion is accepted.
        """
        variables = self._collect_variables(premises + [conclusion])

        # Check all possible assignments
        for assignment in self._generate_assignments(variables):
            # If premises are true but conclusion is false, invalid
            if all(p.evaluate(assignment) for p in premises) and not conclusion.evaluate(assignment):
                return False
        return True

    def _collect_variables(self, expressions: list[Expression]) -> set[str]:
        """Collect all variables used in expressions"""
        variables = set()
        for expr in expressions:
            if expr.operator is None:
                variables.add(expr.left)
                continue
            if isinstance(expr.left, Expression):
                variables.update(self._collect_variables([expr.left]))
            if expr.right and isinstance(expr.right, Expression):
                variables.update(self._collect_variables([expr.right]))
        return variables

    def _generate_assignments(self, variables: set[str]) -> list[dict[str, bool]]:
        """Generate all possible truth value assignments

        Returns ``2 ** len(variables)`` dicts. Bit ``j`` of the assignment
        index is the value of the ``j``-th variable in sorted order.
        """
        ordered = sorted(variables)  # sort once instead of once per assignment
        return [{var: bool((i >> j) & 1) for j, var in enumerate(ordered)} for i in range(2 ** len(ordered))]

    def _measure_complexity(self, expression: Expression) -> int:
        """Measure the complexity of an expression

        The measure is the number of nodes: one per variable and one per operator.
        """
        if expression.operator is None:
            return 1
        elif expression.operator == Operator.NOT:
            return 1 + self._measure_complexity(expression.left)
        else:
            return 1 + self._measure_complexity(expression.left) + self._measure_complexity(expression.right)

    def score_answer(self, answer: str | None, entry: dict[str, Any]) -> float:
        """Robust scoring implementation for propositional logic answers

        Returns:
            1.0  if the answer follows from the premises and is not trivial;
            0.25 if it follows but is trivial (a bare variable or a tautology);
            0.05 if it parses and evaluates but does not follow;
            0.0  if it is not a string, uses an uppercase letter that is not a
                 variable of the task, or cannot be parsed or evaluated.
        """
        if not isinstance(answer, str):
            return 0.0

        try:
            valid_vars = set(entry["metadata"]["variables"])
            answer_vars = re.findall(r"([A-Z])", answer)
            if any(var not in valid_vars for var in answer_vars):
                return 0.0

            premises = [Expression.from_string(p) for p in entry["metadata"]["premises"]]
            answer_expr = Expression.from_string(answer)

            if self._is_valid_conclusion(premises, answer_expr):
                if self._is_trivial(answer_expr):
                    return 0.25
                return 1.0
            return 0.05
        except (ValueError, KeyError, AttributeError, RecursionError):
            # RecursionError: a pathologically nested answer must score 0, not crash scoring
            return 0.0

    def _is_trivial(self, expr: Expression) -> bool:
        """Check for trivial tautologies like P ∨ ¬P

        A bare variable is also treated as trivial.
        """
        if expr.operator is None:
            return True
        variables = self._collect_variables([expr])
        for assignment in self._generate_assignments(variables):
            if not expr.evaluate(assignment):
                return False
        return True
|
||
|
||
|
||
class PropositionalLogicCurriculum(BaseCurriculum):
    """Curriculum describing the adjustable ranges of the propositional logic config"""

    def __init__(self):
        super().__init__(PropositionalLogicCurriculum.__name__, PropositionalLogicConfig)

        # Define attributes: (name, levels, description, lower field, upper field)
        attribute_specs = (
            (
                "vars",
                [2, 4, 6, 8, 10],
                "Number of variables in the logical expressions",
                "min_vars",
                "max_vars",
            ),
            (
                "statements",
                [2, 4, 6, 8, 10],
                "Number of premises in the logical expressions",
                "min_statements",
                "max_statements",
            ),
            (
                "complexity",
                [1, 2, 3, 4, 5],
                "Complexity of the logical expressions",
                "min_complexity",
                "max_complexity",
            ),
        )
        definitions = [
            RangeAttributeDefinition(
                name=attr_name,
                levels=attr_levels,
                description=attr_description,
                lower_field_name=lower_field,
                upper_field_name=upper_field,
            )
            for attr_name, attr_levels, attr_description, lower_field, upper_field in attribute_specs
        ]
        self._define_attributes(*definitions)
|
||
|
||
|
||
# Register the dataset class under DATASET_NAME, together with its config and curriculum classes.
register_dataset(DATASET_NAME, PropositionalLogicDataset, PropositionalLogicConfig, PropositionalLogicCurriculum)
|