-
Notifications
You must be signed in to change notification settings - Fork 48
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #39 from joenorton/palindrome_generation
feat: add palindrome_generation
- Loading branch information
Showing
3 changed files
with
210 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
import random | ||
import string | ||
from dataclasses import dataclass | ||
from typing import Any, Dict, Optional | ||
|
||
from ..factory import ProceduralDataset, register_dataset | ||
|
||
|
||
@dataclass | ||
class PalindromeConfig: | ||
""" | ||
Configuration for the palindrome task. | ||
- min_length: Minimum length of the palindrome. | ||
- max_length: Maximum length of the palindrome. | ||
- seed: Optional seed for reproducibility. | ||
- size: Number of palindrome samples in the virtual dataset. | ||
""" | ||
|
||
min_length: int = 3 | ||
max_length: int = 10 | ||
seed: Optional[int] = None | ||
size: int = 50 | ||
|
||
def validate(self) -> None: | ||
"""Validate configuration parameters.""" | ||
assert self.min_length >= 1, "min_length must be >= 1" | ||
assert self.max_length >= self.min_length, "max_length must be >= min_length" | ||
|
||
|
||
class PalindromeDataset(ProceduralDataset): | ||
""" | ||
Generates a set of letters that can be assembled into a palindrome. | ||
""" | ||
|
||
def __init__(self, config: PalindromeConfig): | ||
super().__init__(config=config, seed=config.seed, size=config.size) | ||
|
||
def __getitem__(self, idx: int) -> dict: | ||
""" | ||
Generate a single palindrome task. | ||
Returns: | ||
dict with: | ||
- "question": Set of letters to form a palindrome. | ||
- "answer": A correct palindrome. | ||
- "metadata": Includes letter set and generated palindrome. | ||
""" | ||
rng = random.Random(self.seed + idx) | ||
length = rng.randint(self.config.min_length, self.config.max_length) | ||
letters = self._generate_palindrome_letters(rng, length) | ||
scrambled_letters = rng.sample(letters, len(letters)) # Scramble the order | ||
palindrome = self._assemble_palindrome(letters) | ||
|
||
question_str = ( | ||
"Rearrange these letters to form a palindrome. A palindrome is a word, phrase, or sequence that reads the same forward and backward.\n\n" | ||
"For example, if the letters are: a, a, b — a valid palindrome is: aba.\n\n" | ||
f"Your letters: {', '.join(scrambled_letters)}\n\n" | ||
"What palindrome can you form from these letters?" | ||
) | ||
|
||
return { | ||
"question": question_str, | ||
"answer": palindrome, | ||
"metadata": { | ||
"letters": scrambled_letters, | ||
"generated_palindrome": palindrome, | ||
}, | ||
} | ||
|
||
def _generate_palindrome_letters(self, rng: random.Random, length: int) -> list[str]: | ||
"""Generate a set of letters that can form a palindrome.""" | ||
half_length = length // 2 | ||
letters = rng.choices(string.ascii_lowercase, k=half_length) | ||
if length % 2 == 1: | ||
middle_letter = rng.choice(string.ascii_lowercase) | ||
return letters + [middle_letter] + letters[::-1] | ||
return letters + letters[::-1] | ||
|
||
def _assemble_palindrome(self, letters: list[str]) -> str: | ||
"""Return the palindrome string from the letter set.""" | ||
return "".join(letters) | ||
|
||
def score_answer(self, answer: Optional[str], metadata: Dict[str, Any]) -> float: | ||
"""Determine if the solution provided is a valid palindrome. | ||
The answer is expected to be a single string | ||
Expected behavior: | ||
- Correct answer (palindrome with only correct letters in the correct quantities) gives 1.0 | ||
- An answer that is a palindrome, but not with the same letters as provided, gives 0.05 | ||
- An answer that is a string, but not a palindrome gives 0.02 | ||
- An empty string gives 0.01. | ||
- None gives 0.0. | ||
""" | ||
if answer is None or not isinstance(answer, str): | ||
return 0.0 # No answer given | ||
|
||
if answer == "": | ||
return 0.01 | ||
|
||
answer = answer.strip().lower() | ||
expected_letters = metadata["letters"] | ||
|
||
# Check if the answer is a palindrome | ||
if answer != answer[::-1]: | ||
return 0.02 | ||
|
||
# Check if answer contains the same letters as provided (ignoring order) | ||
if sorted(answer) != sorted(expected_letters): | ||
return 0.05 | ||
|
||
return 1.0 # Correct solution | ||
|
||
|
||
register_dataset("palindrome", PalindromeDataset, PalindromeConfig) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import pytest | ||
|
||
from reasoning_gym.algorithmic.palindrome_generation import PalindromeConfig, PalindromeDataset | ||
|
||
|
||
def test_palindrome_config_validation(): | ||
"""Test that invalid configs raise appropriate errors""" | ||
with pytest.raises(AssertionError): | ||
config = PalindromeConfig(min_length=0) # Too short | ||
config.validate() | ||
|
||
with pytest.raises(AssertionError): | ||
config = PalindromeConfig(min_length=5, max_length=3) # Invalid range | ||
config.validate() | ||
|
||
|
||
def test_palindrome_deterministic(): | ||
"""Test that dataset generates same items with same seed""" | ||
config = PalindromeConfig(seed=42, size=10) | ||
dataset1 = PalindromeDataset(config) | ||
dataset2 = PalindromeDataset(config) | ||
|
||
for i in range(len(dataset1)): | ||
assert dataset1[i] == dataset2[i] | ||
|
||
|
||
def test_palindrome_items(): | ||
"""Test basic properties of generated items""" | ||
config = PalindromeConfig(min_length=3, max_length=7, size=10, seed=42) | ||
dataset = PalindromeDataset(config) | ||
|
||
for item in dataset: | ||
assert isinstance(item, dict) | ||
assert "question" in item | ||
assert "answer" in item | ||
assert "metadata" in item | ||
|
||
# Check metadata contains required fields | ||
assert "letters" in item["metadata"] | ||
assert "generated_palindrome" in item["metadata"] | ||
|
||
# Verify answer is a palindrome | ||
palindrome = item["answer"] | ||
assert palindrome == palindrome[::-1], f"{palindrome} is not a palindrome" | ||
|
||
|
||
def test_palindrome_randomization(): | ||
"""Test letter randomization in the question""" | ||
config = PalindromeConfig(min_length=4, max_length=4, size=10, seed=42) | ||
dataset = PalindromeDataset(config) | ||
|
||
for item in dataset: | ||
letters = item["metadata"]["letters"] | ||
palindrome = item["metadata"]["generated_palindrome"] | ||
|
||
# Ensure the same letters are present but in different order | ||
assert sorted(letters) == sorted(palindrome) | ||
|
||
|
||
def test_score_answer(): | ||
"""Test the scoring mechanism for palindrome answers. | ||
Expected behavior: | ||
- Correct answer (palindrome with only correct letters in the correct quantities) gives 1.0 | ||
- An answer that is a palindrome, but not with the same letters as provided, gives 0.05 | ||
- An answer that is a string, but not a palindrome gives 0.02 | ||
- An empty string gives 0.01. | ||
- None gives 0.0. | ||
""" | ||
config = PalindromeConfig(min_length=4, max_length=6, size=10, seed=42) | ||
dataset = PalindromeDataset(config) | ||
|
||
for item in dataset: | ||
correct_answer = item["answer"] | ||
metadata = item["metadata"] | ||
|
||
# Correct answer should score 1.0 | ||
assert dataset.score_answer(correct_answer, metadata) == 1.0 | ||
|
||
# Incorrect answer (palindrome, but not correct one) should score 0.05 | ||
pal_letters = "racecar" if "racecar" != correct_answer else "aba" | ||
assert dataset.score_answer(pal_letters, metadata) == 0.05 | ||
|
||
# Incorrect answer (not palindrome) should score 0.02 | ||
wrong_letters = "abcd" if "abcd" != correct_answer else "efgh" | ||
assert dataset.score_answer(wrong_letters, metadata) == 0.02 | ||
|
||
# Empty String input should score 0.01 | ||
assert dataset.score_answer("", metadata) == 0.01 | ||
|
||
# Empty input should score 0.0 | ||
assert dataset.score_answer(None, metadata) == 0.0 |