Add script to decontaminate datasets against benchmark datasets (#416)
* Add script to decontaminate datasets against benchmark datasets

* Add docs for the decontamination script

* Update README.md

Co-authored-by: lewtun <[email protected]>

* Update README.md

Co-authored-by: lewtun <[email protected]>

* Update README.md

Co-authored-by: lewtun <[email protected]>

* Update scripts/decontaminate.py

Co-authored-by: lewtun <[email protected]>

* Update scripts/decontaminate.py

Co-authored-by: lewtun <[email protected]>

* Update scripts/decontaminate.py

Co-authored-by: lewtun <[email protected]>

* Update scripts/decontaminate.py

Co-authored-by: lewtun <[email protected]>

* Update scripts/decontaminate.py

Co-authored-by: lewtun <[email protected]>

* Add license header and attribution to the authors

---------

Co-authored-by: lewtun <[email protected]>
plaguss and lewtun authored Feb 24, 2025
1 parent 0c3ef83 commit 7188001
Showing 2 changed files with 173 additions and 0 deletions.
31 changes: 31 additions & 0 deletions README.md
@@ -211,6 +211,37 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con
    --config recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo_code.yaml
```

#### Data decontamination

Following [s1: Simple test-time scaling](https://arxiv.org/abs/2501.19393), the data can be decontaminated with the script at [scripts/decontaminate.py](./scripts/decontaminate.py), which decontaminates a dataset using 8-grams and deduplicates the data. Sample run:

```shell
python scripts/decontaminate.py \
    --dataset "open-r1/verifiable-coding-problems-python" \
    --problem_column problem \
    --cleanup
```

It will decontaminate against the benchmark datasets and remove the contaminated samples afterwards. If no `--new_dataset_name` argument is provided, the same dataset name is reused with a `_decontaminated` suffix. The check runs against the prompt, which for this dataset is the `problem` column, but a different one can be provided.
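
For intuition, a sample is flagged as contaminated when any normalized 8-word n-gram from its prompt also appears in one of the benchmark problems. Below is a condensed, illustrative sketch of that check, simplified from the script's helpers and using toy strings:

```python
# Condensed illustration of the 8-gram overlap check (toy data, not the script's exact code)
def word_ngrams(text: str, n: int = 8) -> set[str]:
    # Lowercase and split on whitespace, then build word-level n-grams
    words = text.lower().split()
    return {" ".join(words[i : i + n]) for i in range(len(words) - n + 1)}


# Illustrative benchmark problem and candidate training prompt
benchmark_problems = ["Find the sum of all positive integers n such that n squared is less than 2025."]
prompt = "Find the sum of all positive integers n such that n squared is less than 100."

# A prompt is contaminated if it shares at least one 8-gram with any benchmark problem
benchmark_ngrams = set().union(*(word_ngrams(p) for p in benchmark_problems))
print(any(ngram in benchmark_ngrams for ngram in word_ngrams(prompt)))  # True
```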

Arguments for the script:

```shell
usage: decontaminate.py [-h] --dataset DATASET [--split SPLIT] [--ngram_size NGRAM_SIZE] [--problem_column PROBLEM_COLUMN] [--cleanup] [--new_dataset_name NEW_DATASET_NAME]

options:
  -h, --help            show this help message and exit
  --dataset DATASET     Name of the dataset to check for contamination.
  --split SPLIT         Split to check for contamination, defaults to `train`.
  --ngram_size NGRAM_SIZE
                        Size of n-grams to build, defaults to 8.
  --problem_column PROBLEM_COLUMN
                        Name of the column containing the problem (prompt).
  --cleanup             Whether to remove the contaminated rows before pushing the dataset.
  --new_dataset_name NEW_DATASET_NAME
                        New name for the dataset. If not provided, will reuse the name and add a `_decontaminated` to the name.
```
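
For example, to push the cleaned dataset to a separate repository instead of relying on the default `_decontaminated` suffix (the target repository name below is purely illustrative):

```shell
python scripts/decontaminate.py \
    --dataset "open-r1/verifiable-coding-problems-python" \
    --problem_column problem \
    --new_dataset_name "your-org/verifiable-coding-problems-python-clean" \
    --cleanup
```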

### Launching jobs on a Slurm cluster

If you have access to a Slurm cluster, we provide a `slurm/train.slurm` script that will automatically queue training jobs for you. Here's how you can use it:
142 changes: 142 additions & 0 deletions scripts/decontaminate.py
@@ -0,0 +1,142 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script is used to decontaminate a dataset by checking for n-gram overlap with other datasets.
It uses the same approach presented in https://arxiv.org/abs/2501.19393,
as found in: https://github.com/simplescaling/s1/blob/main/data/decontaminate_util.py
python scripts/decontaminate.py \
--dataset "open-r1/verifiable-coding-problems-python" \
--split train \
--ngram_size 8 \
--problem_column problem \
--cleanup
"""

import collections

from tqdm import tqdm


def normalize_string(text: str) -> str:
    """Basic string normalization."""
    # Convert to lowercase and normalize whitespace
    text = text.lower().strip()
    # Replace multiple spaces with single space
    text = " ".join(text.split())
    return text


def word_ngrams(text: str, n: int) -> list:
    """Generate word-level n-grams from text."""
    words = text.split()
    return [" ".join(words[i : i + n]) for i in range(len(words) - n + 1)]


def build_ngram_lookup(documents: list[str], ngram_size: int = 8) -> dict[str, set[int]]:
    """Build ngram lookup for documents."""
    lookup = collections.defaultdict(set)

    for doc_id, document in enumerate(tqdm(documents)):
        normalized_text = normalize_string(document)
        ngrams = word_ngrams(normalized_text, ngram_size)
        for ngram in ngrams:
            lookup[ngram].add(doc_id)

    return lookup


def build_ngram_single(document: str, ngram_size: int = 8) -> set[str]:
    """Build the set of n-grams for a single document."""
    normalized_text = normalize_string(document)
    ngrams = word_ngrams(normalized_text, ngram_size)

    return set(ngrams)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, required=True, help="Name of the dataset to check for contamination.")
    parser.add_argument("--split", type=str, default="train", help="Split to check for contamination, defaults to `train`.")
    parser.add_argument("--ngram_size", type=int, default=8, help="Size of n-grams to build, defaults to 8.")
    parser.add_argument(
        "--problem_column", type=str, default="problem", help="Name of the column containing the problem (prompt)."
    )
    parser.add_argument(
        "--cleanup",
        action="store_true",
        help="Whether to remove the contaminated rows before pushing the dataset.",
    )
    parser.add_argument(
        "--new_dataset_name",
        type=str,
        default=None,
        help="New name for the dataset. If not provided, will reuse the name and add a `_decontaminated` to the name.",
    )
    args = parser.parse_args()

    from datasets import load_dataset, Dataset

    # Load the dataset to check for contamination
    ds = load_dataset(args.dataset, split=args.split)

    eval_datasets = {
        "aime_2024": (load_dataset("HuggingFaceH4/aime_2024", split="train"), "problem"),
        "aime_2025": (load_dataset("yentinglin/aime_2025", split="train"), "problem"),
        "math_500": (load_dataset("HuggingFaceH4/MATH-500", split="test"), "problem"),
        "gpqa": (load_dataset("Idavidrein/gpqa", "gpqa_diamond", split="train", trust_remote_code=True), "Question"),
        "lcb": (
            load_dataset(
                "livecodebench/code_generation_lite", split="test", version_tag="v4_v5", trust_remote_code=True
            ),
            "question_content",
        ),
    }
    ngram_lookups = {}
    for ds_name, (eval_dataset, problem_col) in eval_datasets.items():
        ngram_lookups[ds_name] = build_ngram_lookup(eval_dataset[problem_col], ngram_size=args.ngram_size)

    for eval_name, ngram_lookup in ngram_lookups.items():
        # Redefine the function on each iteration so it closes over the current benchmark's ngram_lookup
        def find_contaminated(row):
            # Build the n-grams for the row's prompt and flag the row if any of them appears in the benchmark
            ngrams = build_ngram_single(row[args.problem_column], ngram_size=args.ngram_size)
            row[f"contaminated_{eval_name}"] = any(ngram in ngram_lookup for ngram in ngrams)
            return row

        ds = ds.map(find_contaminated, num_proc=8)

    # Allow cleaning up via CLI args (removing the contaminated examples and dropping the columns)
    def cleanup(dataset: Dataset) -> Dataset:
        initial_size = len(dataset)
        contamination_cols = [col for col in dataset.column_names if col.startswith("contaminated_")]
        for col in contamination_cols:
            size_prior = len(dataset)
            dataset = dataset.filter(lambda x: not x[col], num_proc=8)
            if len(dataset) < size_prior:
                print(f"Removed {size_prior - len(dataset)} samples from '{col.replace('contaminated_', '')}'")
        dataset = dataset.remove_columns(contamination_cols)
        print(f"Initial size: {initial_size}, Final size: {len(dataset)}")
        return dataset

    if args.cleanup:
        ds = cleanup(ds)

    new_ds_name = args.new_dataset_name or f"{args.dataset}_decontaminated"
    ds.push_to_hub(new_ds_name, split="train", private=False)
    print(f"Decontaminated dataset: {new_ds_name}")
