From 71880012818d486cf12202c7d54f3e88727bb746 Mon Sep 17 00:00:00 2001
From: Agus
Date: Mon, 24 Feb 2025 19:54:44 +0100
Subject: [PATCH] Add script to decontaminate datasets against benchmark datasets (#416)

* Add script to decontaminate datasets against benchmark datasets

* Add docs for the decontamination script

* Update README.md

Co-authored-by: lewtun

* Update README.md

Co-authored-by: lewtun

* Update README.md

Co-authored-by: lewtun

* Update scripts/decontaminate.py

Co-authored-by: lewtun

* Update scripts/decontaminate.py

Co-authored-by: lewtun

* Update scripts/decontaminate.py

Co-authored-by: lewtun

* Update scripts/decontaminate.py

Co-authored-by: lewtun

* Update scripts/decontaminate.py

Co-authored-by: lewtun

* Add license header and attribution to the authors

---------

Co-authored-by: lewtun
---
 README.md                |  31 +++++++++
 scripts/decontaminate.py | 142 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 173 insertions(+)
 create mode 100644 scripts/decontaminate.py

diff --git a/README.md b/README.md
index 2c01d1fa..f778815b 100644
--- a/README.md
+++ b/README.md
@@ -211,6 +211,37 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con
     --config recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo_code.yaml
 ```
 
+#### Data decontamination
+
+Following [s1: Simple test-time scaling](https://arxiv.org/abs/2501.19393), the data can be decontaminated with the script at [scripts/decontaminate.py](./scripts/decontaminate.py), which checks a dataset for 8-gram overlap with a set of benchmark datasets and can remove the contaminated samples. Sample run:
+
+```shell
+python scripts/decontaminate.py \
+    --dataset "open-r1/verifiable-coding-problems-python" \
+    --problem_column problem \
+    --cleanup
+```
+
+It checks the dataset against the benchmark datasets and, when `--cleanup` is passed, removes the contaminated samples afterwards. If no `--new_dataset_name` argument is provided, the original dataset name is reused with a `_decontaminated` suffix. The check runs against the prompt, which for this dataset is the column `problem`; a different column can be passed via `--problem_column`. (A toy sketch of the underlying n-gram check is included after the patch.)
+
+Arguments for the script:
+
+```shell
+usage: decontaminate.py [-h] --dataset DATASET [--split SPLIT] [--ngram_size NGRAM_SIZE] [--problem_column PROBLEM_COLUMN] [--cleanup] [--new_dataset_name NEW_DATASET_NAME]
+
+options:
+  -h, --help            show this help message and exit
+  --dataset DATASET     Name of the dataset to check for contamination.
+  --split SPLIT         Split to check for contamination, defaults to `train`.
+  --ngram_size NGRAM_SIZE
+                        Size of n-grams to build, defaults to 8.
+  --problem_column PROBLEM_COLUMN
+                        Name of the column containing the problem (prompt).
+  --cleanup             Whether to remove the contaminated rows before pushing the dataset.
+  --new_dataset_name NEW_DATASET_NAME
+                        New name for the dataset. If not provided, the original name is reused with a `_decontaminated` suffix.
+```
+
 ### Launching jobs on a Slurm cluster
 
 If you have access to a Slurm cluster, we provide a `slurm/train.slurm` script that will automatically queue training jobs for you. Here's how you can use it:
diff --git a/scripts/decontaminate.py b/scripts/decontaminate.py
new file mode 100644
index 00000000..14feb5bd
--- /dev/null
+++ b/scripts/decontaminate.py
@@ -0,0 +1,142 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script is used to decontaminate a dataset by checking for n-gram overlap with other datasets.
+It uses the same approach presented in https://arxiv.org/abs/2501.19393,
+as found in: https://github.com/simplescaling/s1/blob/main/data/decontaminate_util.py
+
+python scripts/decontaminate.py \
+    --dataset "open-r1/verifiable-coding-problems-python" \
+    --split train \
+    --ngram_size 8 \
+    --problem_column problem \
+    --cleanup
+"""
+
+import collections
+
+from tqdm import tqdm
+
+
+def normalize_string(text: str) -> str:
+    """Basic string normalization."""
+    # Convert to lowercase and normalize whitespace
+    text = text.lower().strip()
+    # Replace multiple spaces with a single space
+    text = " ".join(text.split())
+    return text
+
+
+def word_ngrams(text: str, n: int) -> list[str]:
+    """Generate word-level n-grams from text."""
+    words = text.split()
+    return [" ".join(words[i : i + n]) for i in range(len(words) - n + 1)]
+
+
+def build_ngram_lookup(documents: list[str], ngram_size: int = 8) -> dict[str, set[int]]:
+    """Build an n-gram lookup mapping each n-gram to the ids of the documents that contain it."""
+    lookup = collections.defaultdict(set)
+
+    for doc_id, document in enumerate(tqdm(documents)):
+        normalized_text = normalize_string(document)
+        ngrams = word_ngrams(normalized_text, ngram_size)
+        for ngram in ngrams:
+            lookup[ngram].add(doc_id)
+
+    return lookup
+
+
+def build_ngram_single(document: str, ngram_size: int = 8) -> set[str]:
+    """Build the set of n-grams for a single document."""
+    normalized_text = normalize_string(document)
+    ngrams = word_ngrams(normalized_text, ngram_size)
+
+    return set(ngrams)
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--dataset", type=str, required=True, help="Name of the dataset to check for contamination.")
+    parser.add_argument("--split", type=str, default="train", help="Split to check for contamination, defaults to `train`.")
+    parser.add_argument("--ngram_size", type=int, default=8, help="Size of n-grams to build, defaults to 8.")
+    parser.add_argument(
+        "--problem_column", type=str, default="problem", help="Name of the column containing the problem (prompt)."
+    )
+    parser.add_argument(
+        "--cleanup",
+        action="store_true",
+        help="Whether to remove the contaminated rows before pushing the dataset.",
+    )
+    parser.add_argument(
+        "--new_dataset_name",
+        type=str,
+        default=None,
+        help="New name for the dataset. If not provided, the original name is reused with a `_decontaminated` suffix.",
+ ) + args = parser.parse_args() + + from datasets import load_dataset, Dataset + + # Load the dataset to check for contamination + ds = load_dataset(args.dataset, split=args.split) + + eval_datasets = { + "aime_2024": (load_dataset("HuggingFaceH4/aime_2024", split="train"), "problem"), + "aime_2025": (load_dataset("yentinglin/aime_2025", split="train"), "problem"), + "math_500": (load_dataset("HuggingFaceH4/MATH-500", split="test"), "problem"), + "gpqa": (load_dataset("Idavidrein/gpqa", "gpqa_diamond", split="train", trust_remote_code=True), "Question"), + "lcb": ( + load_dataset( + "livecodebench/code_generation_lite", split="test", version_tag="v4_v5", trust_remote_code=True + ), + "question_content", + ), + } + ngram_lookups = {} + for ds_name, (eval_dataset, problem_col) in eval_datasets.items(): + ngram_lookups[ds_name] = build_ngram_lookup(eval_dataset[problem_col], ngram_size=args.ngram_size) + + for eval_name, ngram_lookup in ngram_lookups.items(): + # Update the ngram_lookup variable for each dataset + def find_contaminated(row): + # For each example we have to build the ngrams and check for all of them on each row + ngrams = build_ngram_single(row[args.problem_column], ngram_size=args.ngram_size) + row[f"contaminated_{eval_name}"] = any(set(ngram in ngram_lookup for ngram in ngrams)) + return row + + ds = ds.map(find_contaminated, num_proc=8) + + # Allow cleaning up via CLI args (removing the contaminated examples and dropping the columns) + def cleanup(dataset: Dataset) -> Dataset: + initial_size = len(dataset) + contamination_cols = [col for col in dataset.column_names if col.startswith("contaminated_")] + for col in contamination_cols: + if col.startswith("contaminated_"): + size_prior = len(dataset) + dataset = dataset.filter(lambda x: not x[col], num_proc=8) + if len(dataset) < size_prior: + print(f"Removed {size_prior - len(dataset)} samples from '{col.replace('contaminated_', '')}'") + dataset = dataset.remove_columns(contamination_cols) + print(f"Initial size: {initial_size}, Final size: {len(dataset)}") + return dataset + + if args.cleanup: + ds = cleanup(ds) + + new_ds_name = args.new_dataset_name or f"{args.dataset}_decontaminated" + ds.push_to_hub(new_ds_name, split="train", private=False) + print(f"Decontaminated dataset: {new_ds_name}")