diff --git a/pii/README.md b/pii/README.md index 6f4b093..d883924 100644 --- a/pii/README.md +++ b/pii/README.md @@ -2,17 +2,16 @@ We provide code to detect Names, Emails, IP addresses, Passwords API/SSH keys in text datasets (in particular datasets of source code). ## NER approach -For the **NER** model based approach go to the `ner_model` folder. +For the **NER** model based approach (e.g [StarPII](https://huggingface.co/bigcode/starpii)), please go to the `ner` folder. + +We provide the code used for training a PII NER model to detect : Names, Emails, Keys, Passwords & IP addresses (more details in our paper: [StarCoder: May The Source Be With You](https://drive.google.com/file/d/1cN-b9GnWtHzQRoE7M7gAEyivY0kl4BYs/view)). You will also find the code (and `slurm` scripts) used for running PII Inference on [StarCoderData](https://huggingface.co/datasets/bigcode/starcoderdata), we were able to detect PII in ~800GB of text in 800 GPU-hours on A100 80GB. To replace secrets we used teh following tokens: +, , , +To mask IP addresses, we randomly selected an IP address from 5~synthetic, private, non-internet-facing IP addresses of the same type. ## Regex approach Below we explain the regex based approach to dectect Emails, IP addresses adn keys only: We use regexes for emails and IP addresses (they are adapted from [BigScience PII pipeline](https://github.com/bigscience-workshop/data-preparation/tree/main/preprocessing/training/02_pii)). And we use [detect-secrets](https://github.com/Yelp/detect-secrets) for finding secrets keys. We additionally implement some filters on top to reduce the number of false positives. There is also some evaluation code to test the pipeline on a PII benchmark we annotated. - -We also provide the code used for training and running [StarPII](https://huggingface.co/bigcode/starpii) in `ner_model` and NER model for PII detection on: Names, Emails, Keys, Passwords & IP addresses (more details in our paper: [StarCoder: May The Source Be With You](https://drive.google.com/file/d/1cN-b9GnWtHzQRoE7M7gAEyivY0kl4BYs/view)). We provide the code (and `slurm` scripts) used for running Inference on [StarCoderData](https://huggingface.co/datasets/bigcode/starcoderdata), we were able to detect PII in ~800GB of text in 800 GPU-hours on A100 80GB. To replace secrets we used teh following tokens: -, , , -To mask IP addresses, we randomly selected an IP address from 5~synthetic, private, non-internet-facing IP addresses of the same type. - ## Usage of the regex approach ``` pip install -r requirements.txt diff --git a/pii/ner/README.md b/pii/ner/README.md new file mode 100644 index 0000000..ca39f42 --- /dev/null +++ b/pii/ner/README.md @@ -0,0 +1,7 @@ +# PII detection and Redaction using an NER model +Here we provide code to: +- fine-tune an encoder model (like [StarEncoder](https://huggingface.co/bigcode/starencoder)) for the task of PII detection (NER): see folder `pii_train_ner` +- run inference with our fine-tuned [StarPII](https://huggingface.co/bigcode/starpii) for PII detection on multiple GPUs: see folder `pii_inference` +- redact/mask PII detected with the model: see folder `pii_redaction` + +This is the code we used for PII anonymization in the 800GB dataset [StarCoderData](https://huggingface.co/datasets/bigcode/starcoderdata). \ No newline at end of file diff --git a/pii/ner_model/README.md b/pii/ner/pii_inference/README.md similarity index 100% rename from pii/ner_model/README.md rename to pii/ner/pii_inference/README.md diff --git a/pii/ner_model/__init__.py b/pii/ner/pii_inference/__init__.py similarity index 100% rename from pii/ner_model/__init__.py rename to pii/ner/pii_inference/__init__.py diff --git a/pii/ner_model/infer.slurm b/pii/ner/pii_inference/infer.slurm similarity index 100% rename from pii/ner_model/infer.slurm rename to pii/ner/pii_inference/infer.slurm diff --git a/pii/ner_model/infer_special.slurm b/pii/ner/pii_inference/infer_special.slurm similarity index 100% rename from pii/ner_model/infer_special.slurm rename to pii/ner/pii_inference/infer_special.slurm diff --git a/pii/ner_model/ner_inference.py b/pii/ner/pii_inference/ner_inference.py similarity index 100% rename from pii/ner_model/ner_inference.py rename to pii/ner/pii_inference/ner_inference.py diff --git a/pii/ner_model/notebooks/EDA of labeled-python-data-pii-detection.ipynb b/pii/ner/pii_inference/notebooks/EDA of labeled-python-data-pii-detection.ipynb similarity index 100% rename from pii/ner_model/notebooks/EDA of labeled-python-data-pii-detection.ipynb rename to pii/ner/pii_inference/notebooks/EDA of labeled-python-data-pii-detection.ipynb diff --git a/pii/ner_model/notebooks/Filter labeled-python-data-pii-detection.ipynb b/pii/ner/pii_inference/notebooks/Filter labeled-python-data-pii-detection.ipynb similarity index 100% rename from pii/ner_model/notebooks/Filter labeled-python-data-pii-detection.ipynb rename to pii/ner/pii_inference/notebooks/Filter labeled-python-data-pii-detection.ipynb diff --git a/pii/ner_model/notebooks/Finetune DeBERTa-v3-base on pii-for-code.ipynb b/pii/ner/pii_inference/notebooks/Finetune DeBERTa-v3-base on pii-for-code.ipynb similarity index 100% rename from pii/ner_model/notebooks/Finetune DeBERTa-v3-base on pii-for-code.ipynb rename to pii/ner/pii_inference/notebooks/Finetune DeBERTa-v3-base on pii-for-code.ipynb diff --git a/pii/ner_model/notebooks/Pipeline with sliding-window.ipynb b/pii/ner/pii_inference/notebooks/Pipeline with sliding-window.ipynb similarity index 100% rename from pii/ner_model/notebooks/Pipeline with sliding-window.ipynb rename to pii/ner/pii_inference/notebooks/Pipeline with sliding-window.ipynb diff --git a/pii/ner_model/notebooks/Train DeBERTa-v3-base on pseudo-labeled data.ipynb b/pii/ner/pii_inference/notebooks/Train DeBERTa-v3-base on pseudo-labeled data.ipynb similarity index 100% rename from pii/ner_model/notebooks/Train DeBERTa-v3-base on pseudo-labeled data.ipynb rename to pii/ner/pii_inference/notebooks/Train DeBERTa-v3-base on pseudo-labeled data.ipynb diff --git a/pii/ner_model/start_jobs.sh b/pii/ner/pii_inference/start_jobs.sh similarity index 100% rename from pii/ner_model/start_jobs.sh rename to pii/ner/pii_inference/start_jobs.sh diff --git a/pii/ner_model/start_jobs_special.sh b/pii/ner/pii_inference/start_jobs_special.sh similarity index 100% rename from pii/ner_model/start_jobs_special.sh rename to pii/ner/pii_inference/start_jobs_special.sh diff --git a/pii/ner_model/train.py b/pii/ner/pii_inference/train.py similarity index 100% rename from pii/ner_model/train.py rename to pii/ner/pii_inference/train.py diff --git a/pii/ner_model/utils/__init__.py b/pii/ner/pii_inference/utils/__init__.py similarity index 100% rename from pii/ner_model/utils/__init__.py rename to pii/ner/pii_inference/utils/__init__.py diff --git a/pii/ner_model/utils/chunking.py b/pii/ner/pii_inference/utils/chunking.py similarity index 100% rename from pii/ner_model/utils/chunking.py rename to pii/ner/pii_inference/utils/chunking.py diff --git a/pii/ner_model/utils/misc.py b/pii/ner/pii_inference/utils/misc.py similarity index 100% rename from pii/ner_model/utils/misc.py rename to pii/ner/pii_inference/utils/misc.py diff --git a/pii/ner_model/utils/pipeline.py b/pii/ner/pii_inference/utils/pipeline.py similarity index 100% rename from pii/ner_model/utils/pipeline.py rename to pii/ner/pii_inference/utils/pipeline.py diff --git a/pii/ner_model/utils/postprocessing.py b/pii/ner/pii_inference/utils/postprocessing.py similarity index 100% rename from pii/ner_model/utils/postprocessing.py rename to pii/ner/pii_inference/utils/postprocessing.py diff --git a/pii/ner_model/utils/span_ops.py b/pii/ner/pii_inference/utils/span_ops.py similarity index 100% rename from pii/ner_model/utils/span_ops.py rename to pii/ner/pii_inference/utils/span_ops.py diff --git a/pii/ner/pii_redaction/README.md b/pii/ner/pii_redaction/README.md new file mode 100644 index 0000000..3f950d0 --- /dev/null +++ b/pii/ner/pii_redaction/README.md @@ -0,0 +1,14 @@ +# PII redaction +<<<<<<< HEAD +To run PII redaction on a dataset that went though PII detection with StarPII using the code in `./pii_inference` folder: +```bash +mkdir ./logs +LANG=python +python main_redact.py --dataset_name $DATA_PATH --target_dataset $LANG-no-pii --save_path_disk $LANG-no-pii-local +``` + +To run multiple `slurm` jobs for each programming language + +```bash +python run_pii_slurm.py --start 0 --end 88 +``` diff --git a/pii/ner/pii_redaction/main_redact.py b/pii/ner/pii_redaction/main_redact.py new file mode 100644 index 0000000..a94c7a0 --- /dev/null +++ b/pii/ner/pii_redaction/main_redact.py @@ -0,0 +1,340 @@ +"""Mask detected PII in a dataset. +""" + +import argparse +import json +import logging +import random +import time +import numpy as np +from functools import partial +from pprint import pformat + +from datasets import load_dataset +from datasets.utils.logging import set_verbosity_info + +from manual_sharding import save_manual_shards +from utils import get_replacements, redact_pii_batch + + +REPONAME_TOKEN = "" +FILENAME_TOKEN = "" +STARS_TOKEN = "" + + +def get_num_stars_bucket(num_stars: int) -> str: + if num_stars is None or num_stars == 0: + return "0" + elif num_stars <= 10: + return "1-10" + elif num_stars <= 100: + return "10-100" + elif num_stars <= 1000: + return "100-1000" + else: + return "1000+" + + +def content_with_meta(example): + # TODO + res = "" + # repo-name + if np.random.binomial(n=1, p=0.2): + res += f"{REPONAME_TOKEN}{example['max_stars_repo_name']}" + # file-name + if np.random.binomial(n=1, p=0.2): + res += f"{FILENAME_TOKEN}{example['max_stars_repo_path']}" + # number of stars + if np.random.binomial(n=1, p=0.2): + num_stars = get_num_stars_bucket(example["max_stars_count"]) + res += f"{STARS_TOKEN}{num_stars}" + if len(res) > 0: + res += "\n" + res += example["content"] + + return {"content_with_meta": res} + + +def parseArgs(): + parser = argparse.ArgumentParser(description="PII detection and redaction") + parser.add_argument( + "--dataset_name", + default="bigcode/pii-for-code", + type=str, + help="HF repo name/path of the dataset.", + ) + # add arg true add metadata + parser.add_argument( + "--add_metadata", + action="store_true", + help="If set, we add metadata to the text", + ) + parser.add_argument( + "--num_load_proc", + default=64, + type=int, + help="Number of processes to use for loading the dataset", + ) + parser.add_argument( + "--text_column", + default="content", + type=str, + help="Text column to use, if will be renamed to content", + ) + parser.add_argument( + "--split", + default="train", + type=str, + help="Dataset split to process", + ) + parser.add_argument( + "--batch_size", + default=100, + type=int, + help="Batch size for the PII detection/redaction", + ) + parser.add_argument( + "--seed", + default=0, + type=int, + help="Seed for random", + ) + parser.add_argument( + "--num_proc", + default=96, + type=int, + help="Number of processes to use for the PII detection/redaction", + ) + parser.add_argument( + "--no_redaction", + action="store_true", + help="If set, we don't perform redaction", + ) + parser.add_argument( + "--load_replacements", + default=True, + help="If set, we load the replacements from file replacements.json", + ) + parser.add_argument( + "--add_reference_text", + default=True, + type=bool, + help="If True we add the reference text with PII between delimiters \ + in the redacted text -used for visualization-", + ) + parser.add_argument( + "--check_all_files", + action="store_true", + help="If set, we check all files, not only the ones that contain PII", + ) + parser.add_argument( + "--check_sampling_size", + default=0, + type=int, + help="Number of samples to check for PII", + ) + # for saving the dataset: either push to HF or save locally with datasets or save manual shards + parser.add_argument( + "--save_mode", + default="manual_shards", + type=str, + choices=["hub", "local", "manual_shards"], + help="How to save the dataset", + ) + parser.add_argument( + "--save_mode_checks", + default="hub", + type=str, + choices=["hub", "local", "manual_shards"], + help="How to save the checks dataset", + ) + # add argument for name of dataset on the hub + parser.add_argument( + "--target_dataset", + default="bigcode-pii2", + type=str, + help="HF repo name of the target dataset in save_mode=hub.", + ) + parser.add_argument( + "--hub_username", + default="loubnabnl", + type=str, + help="Username for the hub", + ) + parser.add_argument( + "--save_path_disk", + default="/fsx/loubna/data/the-stack-march-no-pii", + type=str, + help="Path to save the dataset on disk in save_mode=local.", + ) + return parser.parse_args() + + +def get_check_ds(ds, args): + if not args.check_all_files: + ds_checks = ds.filter( + lambda exs: exs["modified"], + batched=True, + batch_size=args.batch_size, + num_proc=args.num_proc, + ) + else: + ds_checks = ds + if not args.check_sampling_size: + sampling_size = len(ds_checks) + idx_samples = random.sample( + range(len(ds_checks)), min(len(ds_checks), sampling_size) + ) + ds_checks = ds_checks.select(idx_samples) + + return ds_checks + + +def check_uniques(example, uniques): + """Check if current id is still in set of unique id and remove if true.""" + if example["id"] in uniques: + uniques.remove(example["id"]) + return True + else: + return False + + +def main(): + set_verbosity_info() + args = parseArgs() + logger = logging.getLogger(__name__) + logger.setLevel(logging.INFO) + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + handlers=[ + logging.FileHandler(f"logs/pii-{args.dataset_name.split('/')[-1]}.log"), + logging.StreamHandler(), + ], + ) + logger.info( + f"** The job is running with the following arguments: **\n{args}\n **** " + ) + + logger.info(f" ===== Loading {args.dataset_name} =====") + ds = load_dataset( + args.dataset_name, + split=args.split, + use_auth_token=True, + num_proc=args.num_load_proc, + ) + if args.text_column != "content": + ds = ds.rename_column(args.text_column, "content") + + logger.info(f" ===== Deduplicating dataset =====") + # Deduplication based on ids + uniques = set(ds["id"]) + frac = len(uniques) / len(ds) + logger.info(f"Fraction of duplicates: {1-frac:.2%}") + logger.info(f"Dataset:\n{ds}") + # Deduplicate data and apply heuristics + t_start = time.time() + ds_pii = ds.filter(check_uniques, fn_kwargs={"uniques": uniques}) + logger.info(f"Time to filter dataset: {time.time()-t_start:.2f}") + logger.info(f"Dataset after dedup:\n{ds_pii}") + + logger.info( + f"Number of samples that contained PII: {sum([1 if x['entities'] else 0 for x in ds_pii])}" + ) + # logger.info( + # f"Total number of secrets found: {sum([len(x['entities']) for x in ds_pii])}" + # ) + + # redact PII in the dataset + logger.info(f" ===== Applying PII redaction =====") + random.seed(args.seed) + + replacements = get_replacements() + with open("replacements.json", "w") as f: + json.dump(replacements, f) + logging.info(f"Using the following replacements:\n{pformat(replacements)}") + ds_pii = ds_pii.map( + partial( + redact_pii_batch, + replacements=replacements, + add_references=args.add_reference_text, + ), + batched=True, + batch_size=args.batch_size, + num_proc=args.num_proc, + ) + logging.info(f"Dataset info after PII redaction:\n{ds_pii}") + + # check the dataset + logger.info( + f" ===== Checking {args.check_sampling_size} samples from those modified in the dataset =====" + ) + ds_checks = get_check_ds(ds_pii, args) + + # save checks dataset + if len(ds_checks) == 0: + logger.info("Dataset was empty. Not saving anything.") + else: + logger.info(f"Checks dataset info {ds_checks}") + if args.save_mode_checks == "hub": + logger.info( + f"Pushing the checks dataset to the Hub as {args.target_dataset}_checks" + ) + ds_checks.push_to_hub(args.target_dataset + "_checks", private=True) + + elif args.save_mode_checks == "local": + logger.info(f"Saving the checks dataset to disk") + ds_checks.save_to_disk(args.save_path_disk + "_checks") + + elif args.save_mode_checks == "manual_shards": + logger.info(f"Saving the checks dataset in manual shards") + save_manual_shards( + ds_checks, + user=args.hub_username, + remote_dataset_repo=args.target_dataset + "_checks", + local_dir="/fsx/loubna/data/the-stack-march-no-pii_checks", + ) + + logger.info("Removing columns that are not needed for the final dataset") + columns = ["content", "modified", "entities"] + if args.add_reference_text: + columns.append("references") + ds_pii = ds_pii.remove_columns(columns) + ds_pii = ds_pii.rename_column("new_content", "content") + logger.info(f"Dataset info after removing columns:\n{ds_pii}") + + if args.add_metadata: + logger.info(f" ===== Adding metadata =====") + ds_pii = ds_pii.map( + content_with_meta, remove_columns=["content"], num_proc=args.num_proc + ) + ds_pii = ds_pii.rename_column("content_with_meta", "content") + + # save the final dataset + if args.save_mode == "hub": + logger.info( + f" ===== Pushing the dataset to the Hub as: {args.target_dataset} =====" + ) + ds_pii.push_to_hub(args.target_dataset, private=True) + + elif args.save_mode == "local": + logger.info(f" ===== Saving the dataset to disk =====") + ds_pii.save_to_disk(args.save_path_disk) + + elif args.save_mode == "manual_shards": + logger.info( + f" ===== Saving the dataset in manual shards to {args.save_path_disk} =====" + ) + save_manual_shards( + ds_pii, + user=args.hub_username, + remote_dataset_repo="the-stack-no-pii-march", + local_dir=args.save_path_disk, + ) + + logger.info(f" ===== Dataset saved successfully =====") + + +if __name__ == "__main__": + main() diff --git a/pii/ner/pii_redaction/manual_sharding.py b/pii/ner/pii_redaction/manual_sharding.py new file mode 100644 index 0000000..8e9f894 --- /dev/null +++ b/pii/ner/pii_redaction/manual_sharding.py @@ -0,0 +1,70 @@ +import os +import time +from multiprocessing import Pool +from tqdm import tqdm + +from huggingface_hub import Repository + + +def save_shard(shard_tuple): + """Save shard""" + filename, shard = shard_tuple + # use to_json instead to save as json file + shard.to_parquet(filename) + + +def save_manual_shards( + ds, + user="loubnabnl", + remote_dataset_repo="bigcode-pii-pjj", + local_dir="/fsx/loubna/data/the-stack-march-no-pii", +): + """Save sharded data + Args: + ds (Dataset): dataset to be saved + user (str): user name + out_path (str): path to save the shards""" + # this will create a folder OUT_PATH that is a clone of REMOTE_DATASET_REPO + # you can save the shards inside it and do git add/commit/push to push data to the hub + out_path = remote_dataset_repo if local_dir is None else local_dir + # if out path doesnt already exist + if not os.path.exists(out_path): + repo = Repository( + local_dir=out_path, + clone_from=user + "/" + remote_dataset_repo, + repo_type="dataset", + use_auth_token=True, + git_user=user, + ) + + # files will be numerous we save them in a folder called data inside out_path + os.mkdir(out_path + "/data") + SHARD_SIZE = 1000 << 20 + if ds._indices is not None: + dataset_nbytes = ds.data.nbytes * len(ds._indices) / len(ds.data) + else: + dataset_nbytes = ds.data.nbytes + num_shards = int(dataset_nbytes / SHARD_SIZE) + 1 + print(f"Number of shards: {num_shards}") + + print("sharding the dataset") + t_start = time.time() + shards = ( + ds.shard(num_shards=num_shards, index=i, contiguous=True) + for i in range(num_shards) + ) + # use f"{OUT_PATH}/data/train-{index:05d}-of-{num_shards:05d}.json" instead for json files + filenames = ( + f"{out_path}/data/train-{index:05d}-of-{num_shards:05d}.parquet" + for index in range(num_shards) + ) + + with Pool(16) as p: + list( + tqdm( + p.imap_unordered(save_shard, zip(filenames, shards), chunksize=4), + total=num_shards, + ) + ) + print(f"Time to save dataset: {time.time()-t_start:.2f}") + # to push dataset to hub do: git add/commit/push inside OUT_PATH diff --git a/pii/ner/pii_redaction/replacements.json b/pii/ner/pii_redaction/replacements.json new file mode 100644 index 0000000..474eccc --- /dev/null +++ b/pii/ner/pii_redaction/replacements.json @@ -0,0 +1 @@ +{"EMAIL": [""], "KEY": [""], "NAME": [""], "PASSWORD": [""], "IP_ADDRESS": {"IPv4": ["172.16.31.10", "172.16.58.3", "172.16.17.32", "192.168.127.12", "192.168.3.11"], "IPv6": ["fd00:c2b6:b24b:be67:2827:688d:e6a1:6a3b", "fd00:a516:7c1b:17cd:6d81:2137:bd2a:2c5b", "fc00:e968:6179::de52:7100", "fc00:db20:35b:7399::5", "fdf8:f53e:61e4::18"]}} \ No newline at end of file diff --git a/pii/ner/pii_redaction/run_pii_slurm.py b/pii/ner/pii_redaction/run_pii_slurm.py new file mode 100644 index 0000000..5aee4ab --- /dev/null +++ b/pii/ner/pii_redaction/run_pii_slurm.py @@ -0,0 +1,206 @@ +import os +import argparse +import subprocess + +SCRIPT_DIR = "/fsx/loubna/code/bigcode-dataset/pii/ner/pii_redaction/jobs/scripts" +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--start", type=int, default=0) + parser.add_argument("--end", type=int, default=10) + parser.add_argument("--text_column", type=str, default="content") + args = parser.parse_args() + return args + + +def submit_job(job, job_name="job"): + with open(f"{SCRIPT_DIR}/{job_name}.sbatch", "w") as fp: + fp.write(job) + #os.system(f"{SCRIPT_DIR}/{job_name}.sbatch") + subprocess.run(["sbatch", f"{SCRIPT_DIR}/{job_name}.sbatch"]) + + +def makejob(JOB_NAME="pii-redaction", LANG=None, TEXT_COLUMN="content"): + return f"""#!/bin/bash + +#SBATCH --job-name={JOB_NAME} +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 # crucial - only 1 task per dist per node! +#SBATCH --cpus-per-task=96 +#SBATCH --gres=gpu:8 +#SBATCH --partition=production-cluster +#SBATCH -o /fsx/loubna/code/bigcode-dataset/pii/ner/pii_redaction/jobs/logs/%x-%j.out +#SBATCH -e /fsx/loubna/code/bigcode-dataset/pii/ner/pii_redaction/jobs/logs/%x-%j.err + +set -x -e +source /admin/home/loubna/.bashrc +conda activate eval-harness + +# File Path setup +echo "START TIME: $(date)" + +# Experiment parameters +LANG={LANG} + +# Training Setup +GPUS_PER_NODE=8 +# so processes know who to talk to +MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) +MASTER_PORT=6000 +NNODES=$SLURM_NNODES +NODE_RANK=$SLURM_PROCID +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + + + +CMD=" \ + /fsx/loubna/code/bigcode-dataset/pii/ner/pii_redaction/main_redact.py \ + --add_metadata \ + --text_column {TEXT_COLUMN} \ + --dataset_name /fsx/leandro/data/pii_result/{LANG} \ + --target_dataset the-stack-no-pii-{LANG} \ + --save_path_disk /fsx/loubna/data/the-stack-march-no-pii-test/{LANG} + " + +export LAUNCHER="python \ + " + +# force crashing on nccl issues like hanging broadcast +export NCCL_ASYNC_ERROR_HANDLING=1 +# export NCCL_DEBUG=INFO +# export NCCL_DEBUG_SUBSYS=COLL +# export NCCL_SOCKET_NTHREADS=1 +# export NCCL_NSOCKS_PERTHREAD=1 +# export CUDA_LAUNCH_BLOCKING=1 + +# AWS specific +export NCCL_PROTO=simple +export RDMAV_FORK_SAFE=1 +export FI_EFA_FORK_SAFE=1 +export FI_EFA_USE_DEVICE_RDMA=1 +export FI_PROVIDER=efa +export FI_LOG_LEVEL=1 +export NCCL_IB_DISABLE=1 +export NCCL_SOCKET_IFNAME=ens + +echo $CMD + +# srun error handling: +# --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks +# --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code +SRUN_ARGS=" \ + --wait=60 \ + --kill-on-bad-exit=1 \ + " + +# py-spy top -s -i -n -- $LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD +clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER $CMD" 2>&1 | tee $LOG_PATH + +echo "END TIME: $(date)" +""" + + +if __name__ == "__main__": + args = get_args() + # 88 PLs + languages = [ + "markdown", + "cpp", + "java", + "c-sharp", + "php", + "assembly", + "html", + "c", + "javascript", + "python", + "haskell", + "fortran", + "typescript", + "sparql", + "antlr", + "tex", + "lean", + "literate-haskell", + "elm", + "standard-ml", + "powershell", + "stan", + "matlab", + "solidity", + "smalltalk", + "tcsh", + "idris", + "julia", + "bluespec", + "visual-basic", + "java-server-pages", + "cuda", + "yacc", + "racket", + "thrift", + "sql", + "protocol-buffer", + "elixir", + "kotlin", + "vhdl", + "scheme", + "tcl", + "isabelle", + "prolog", + "json", + "restructuredtext", + "ada", + "rmarkdown", + "clojure", + "r", + "zig", + "ruby", + "batchfile", + "erlang", + "stata", + "xslt", + "css", + "augeas", + "agda", + "awk", + "groovy", + "coffeescript", + "lua", + "systemverilog", + "common-lisp", + "scala", + "verilog", + "dart", + "maple", + "shell", + "alloy", + "rust", + "sas", + "ocaml", + "go", + "literate-coffeescript", + "emacs-lisp", + "literate-agda", + "f-sharp", + "pascal", + "applescript", + "glsl", + "yaml", + "makefile", + "perl", + "mathematica", + "dockerfile", + "cmake", + ] + for i in range(args.start, args.end + 1): + language = languages[i] + print(f"Submitting jobs for experiment on language {language}") + job_name = f"{language}-pii-redaction-idx_{i}" + job = makejob( + JOB_NAME=job_name, + LANG=language, + TEXT_COLUMN=args.text_column, + ) + # submit the job + print(f"Job for lang {language} ready and saved at jobs/{job_name}.sbatch") + submit_job(job, job_name) diff --git a/pii/ner/pii_redaction/utils.py b/pii/ner/pii_redaction/utils.py new file mode 100644 index 0000000..bd4ed42 --- /dev/null +++ b/pii/ner/pii_redaction/utils.py @@ -0,0 +1,195 @@ +import ipaddress +import random +from gibberish_detector import detector + +IGNORE = ["AMBIGUOUS", "USERNAME"] +# List of random private IP addresses to use as replacements +REPLACEMENTS_IP = { + "IPv4": [ + "172.16.31.10", + "172.16.58.3", + "172.16.17.32", + "192.168.127.12", + "192.168.3.11", + ], + "IPv6": [ + "fd00:c2b6:b24b:be67:2827:688d:e6a1:6a3b", + "fd00:a516:7c1b:17cd:6d81:2137:bd2a:2c5b", + "fc00:e968:6179::de52:7100", + "fc00:db20:35b:7399::5", + "fdf8:f53e:61e4::18", + ], +} + +# DNS to avoid masking +POPULAR_DNS_SERVERS = [ + "8.8.8.8", + "8.8.4.4", + "1.1.1.1", + "1.0.0.1", + "76.76.19.19", + "76.223.122.150", + "9.9.9.9", + "149.112.112.112", + "208.67.222.222", + "208.67.220.220", + "8.26.56.26", + "8.20.247.20", + "94.140.14.14", + "94.140.15.15", +] + + +def is_key(matched_str): + """Checks to make sure the PII span is long enough and is gibberish and not word like""" + # pip install gibberish-detector + # download the training corpora from https://raw.githubusercontent.com/domanchi/gibberish-detector/master/examples/big.txt + # run gibberish-detector train big.txt > big.model to generate the model (it takes 3 seconds) + Detector = detector.create_from_model( + "/bigcode-dataset/pii/gibberish_data/big.model" + ) + is_gibberish = Detector.is_gibberish(matched_str.lower()) + return is_gibberish and len(matched_str) > 8 + + +def is_secret(matched_str): + """Checks to make sure the PII span is long enough""" + return len(matched_str) > 3 + + +def is_full_name(matched_str): + """Checks if detected name is a full names and not just first or last name""" + return len(matched_str.split()) > 1 + + +def get_replacements(): + """Build dictionaries of replacements for PII (key, email, IP address, name, password)""" + ip_addresses = REPLACEMENTS_IP + return { + "EMAIL": [""], + "KEY": [""], + "NAME": [""], + "PASSWORD": [""], + "IP_ADDRESS": ip_addresses, + } + + +def replace_ip(value, replacements_dict): + """Replace an IP address with a synthetic IP address of the same format""" + try: + ipaddress.IPv4Address(value) + return random.choice(replacements_dict["IP_ADDRESS"]["IPv4"]) + except ValueError: + try: + ipaddress.IPv6Address(value) + return random.choice(replacements_dict["IP_ADDRESS"]["IPv6"]) + except ValueError: + # this doesn't happen if we already use ipaddress filter in the detection + print("Invalid IP address") + return value + + +def is_secret_ip(ip): + """Check if an IP address is allocated for private networks (non internet facing), or is not an ip address at all""" + try: + ip = ipaddress.ip_address(ip) + except ValueError: + # not an ip address + return True + return ip.is_private + + +def redact_pii_text(text, secrets, replacements, add_references=False): + """Redact PII in a text + Args: + text (str): text to redact + secrets (list): list with the secrets to redact + replacements (dict): dictionary of replacements for each PII type + add_references (bool): whether to add references to the redacted text (delimiters to PII) + for vizualization + Returns: + text (str): new text with redacted secrets + """ + modified = False + if secrets: + secrets = sorted(secrets, key=lambda x: x["start"]) + # store the secrets that were replaced here with their replacements + replaced_secrets = {} + subparts = [] + references = [] + step = 0 + last_text = text + for secret in secrets: + # some post-processing + if secret["tag"] in IGNORE or not is_secret(secret["value"]): + continue + if secret["tag"] == "IP_ADDRESS": + # skip if it's not actual ip address, is a popular DNS server or private IP address + if is_secret_ip(secret["value"]) or ( + secret["value"] in POPULAR_DNS_SERVERS + ): + continue + if secret["tag"] == "KEY" and not is_key(secret["value"]): + continue + if secret["tag"] == "NAME" and not is_full_name(secret["value"]): + continue + modified = True + subtext = text[step : secret["start"]] + subpart = subtext if subtext else " " + subparts.append(subpart) + # if secret is already in replaced_secrets, use the same replacement + if secret["value"] in replaced_secrets: + replacement = replaced_secrets[secret["value"]] + else: + if secret["tag"] == "IP_ADDRESS": + replacement = replace_ip(secret["value"], replacements) + else: + replacement = random.choice(replacements[secret["tag"]]) + replaced_secrets[secret["value"]] = replacement + subparts.append(replacement) + replaced_secrets[secret["value"]] = replacement + if add_references: + references.append(subpart) + references.append(f"PI:{secret['tag']}:{replacement}END_PI") + last_text = text[secret["end"] :] + step = secret["end"] + # if supbarpts are not empty join them (it can be empty when all secrets were skipped) + new_text = "".join(subparts) + last_text if subparts else last_text + if add_references: + references = "".join(references) + last_text if references else "" + else: + new_text = text + references = "" + result = ( + (new_text, references, modified) if add_references else (new_text, modified) + ) + return result + + +def redact_pii_batch(examples, replacements, add_references=True): + """Anonymize PII in a batch of examples from a dataset""" + new_contents = [] + references = [] + modified = [] + for text, secrets in zip( + examples["content"], + examples["entities"], + ): + if secrets: + if add_references: + new_text, reference, modif = redact_pii_text( + text, secrets, replacements, add_references + ) + references.append(reference) + else: + new_text, modif = redact_pii_text(text, secrets, replacements) + new_contents.append(new_text) + modified.append(modif) + else: + new_contents.append(text) + references.append(text) + modified.append(False) + result = {"new_content": new_contents, "modified": modified} + if add_references: + result.update({"references": references}) + return result diff --git a/pii/ner/pii_train_ner/README.md b/pii/ner/pii_train_ner/README.md new file mode 100644 index 0000000..1a97624 --- /dev/null +++ b/pii/ner/pii_train_ner/README.md @@ -0,0 +1,14 @@ +# Fine-tuning StarEncoder on an NER task for PII detection + +To run the training on an annotated PII dataset (`bigcode/pii-full-ds` in our case, you might need to adpat the code to fit your dataset), use the following command: +```bash +python -m torch.distributed.launch \ + --nproc_per_node number_of_gpus train.py \ + --dataset_name bigcode/pii-full-ds \ + --debug \ + --learning_rate 2e-5 \ + --train_batch_size 8 \ + --bf16 \ + --add_not_curated +``` +Note that we use a global batch size of 64 (8*8 GPUs). To use only curated dataset remove the flag `--add_not_curated`. \ No newline at end of file diff --git a/pii/ner/pii_train_ner/train.py b/pii/ner/pii_train_ner/train.py new file mode 100644 index 0000000..4e764cc --- /dev/null +++ b/pii/ner/pii_train_ner/train.py @@ -0,0 +1,251 @@ +import argparse +import os +from pprint import pprint + +from datasets import DatasetDict, load_dataset +from tqdm import tqdm +from functools import partial +from transformers import ( + AutoModelForTokenClassification, + AutoTokenizer, + DataCollatorForTokenClassification, + EarlyStoppingCallback, + Trainer, + TrainingArguments, + set_seed, + logging +) + +from utils.preprocessing import chunk_dataset, tokenize_and_label_batch +from utils.eval import compute_metrics + + +# Special tokens +MASK_TOKEN = "" +SEPARATOR_TOKEN = "" +PAD_TOKEN = "" +CLS_TOKEN = "" + +# NER tags +CATEGORIES = [ + "NAME", + "EMAIL", + "EMAIL_EXAMPLE", + "USERNAME", + "KEY", + "IP_ADDRESS", + "PASSWORD", +] +IGNORE_CLASS = ["AMBIGUOUS", "ID", "NAME_EXAMPLE", "USERNAME_EXAMPLE"] + +LABEL2ID = {"O": 0} +for cat in CATEGORIES: + LABEL2ID[f"B-{cat}"] = len(LABEL2ID) + LABEL2ID[f"I-{cat}"] = len(LABEL2ID) +ID2LABEL = {v: k for k, v in LABEL2ID.items()} + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--model_ckpt", type=str, default="bigcode/bigcode-encoder") + parser.add_argument( + "--dataset_name", + type=str, + default="bigcode/pii-full-ds" + ) + # addprefix to wandb run + parser.add_argument("--prefix", type=str, default="") + parser.add_argument("--add_not_curated", action="store_true") + parser.add_argument("--train_batch_size", type=int, default=4) + parser.add_argument("--eval_batch_size", type=int, default=4) + parser.add_argument("--num_train_epochs", type=int, default=100) + + parser.add_argument("--learning_rate", type=float, default=1e-5) + parser.add_argument("--lr_scheduler_type", type=str, default="cosine") + parser.add_argument("--weight_decay", type=float, default=0.01) + parser.add_argument("--warmup_steps", type=int, default=100) + + parser.add_argument("--gradient_checkpointing", action="store_true") + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument("--eval_accumulation_steps", type=int, default=1) + parser.add_argument("--num_proc", type=int, default=8) + parser.add_argument("--bf16", action="store_true") + parser.add_argument("--fp16", action="store_true") + + parser.add_argument("--local_rank", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--num_workers", type=int, default=8) + parser.add_argument("--eval_freq", type=int, default=100) + parser.add_argument("--save_freq", type=int, default=1000) + parser.add_argument("--debug", action="store_true") + parser.add_argument("--output_dir", type=str, default="finetuned-encoder-pii") + return parser.parse_args() + + +def get_stats(data): + # get number of B-cat for cat in categories for each data split + stats = {cat: 0 for cat in CATEGORIES} + for entry in tqdm(data): + for label in entry["labels"]: + # only add labels for beginning with B- + if label > 0 and ID2LABEL[label].startswith("B-"): + stats[ID2LABEL[label][2:]] += 1 + return stats + + +def prepare_tokenizer(tokenizer): + tokenizer.add_special_tokens({"pad_token": PAD_TOKEN}) + tokenizer.add_special_tokens({"sep_token": SEPARATOR_TOKEN}) + tokenizer.add_special_tokens({"cls_token": CLS_TOKEN}) + tokenizer.add_special_tokens({"mask_token": MASK_TOKEN}) + tokenizer.model_max_length = 1024 + return tokenizer + + +def prepare_dataset(dataset, tokenizer, args): + # tokenize and label + dataset = dataset.map( + partial( + tokenize_and_label_batch, + tokenizer=tokenizer, + target_text="text", + pii_column="fragments", + LABEL2ID=LABEL2ID, + IGNORE_CLASS=IGNORE_CLASS, + ), + batched=True, + batch_size=1000, + num_proc=args.num_workers, + ) + return dataset + +def run_training(args, ner_dataset, model, tokenizer): + print(f"Initializing Trainer...") + + training_args = TrainingArguments( + output_dir=args.output_dir, + evaluation_strategy="steps", + num_train_epochs=args.num_train_epochs, + per_device_train_batch_size=args.train_batch_size, + per_device_eval_batch_size=args.eval_batch_size, + eval_steps=args.eval_freq, + save_steps=args.save_freq, + logging_steps=10, + metric_for_best_model="f1", + load_best_model_at_end=True, + weight_decay=args.weight_decay, + learning_rate=args.learning_rate, + lr_scheduler_type=args.lr_scheduler_type, + warmup_steps=args.warmup_steps, + gradient_checkpointing=args.gradient_checkpointing, + gradient_accumulation_steps=args.gradient_accumulation_steps, + eval_accumulation_steps=args.eval_accumulation_steps, + fp16=args.fp16, + bf16=args.bf16, + run_name=f"{args.prefix}-bs{args.train_batch_size}-lr{args.learning_rate}-wd{args.weight_decay}-ep{args.num_train_epochs}-last", + report_to="wandb", + ) + + + data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer) + trainer = Trainer( + model=model, + args=training_args, + train_dataset=ner_dataset["train"], + eval_dataset=ner_dataset["validation"], + data_collator=data_collator, + tokenizer=tokenizer, + compute_metrics=compute_metrics, + callbacks=[ + EarlyStoppingCallback( + early_stopping_patience=15, early_stopping_threshold=1e-2 + ) + ], + ) + + print("Training...") + #trainer.train() + + print("Saving last checkpoint of the model") + #model.save_pretrained(os.path.join(args.output_dir, "final_checkpoint_last_exp/")) + + # evaluate on test set + print("Evaluating on test set...") + trainer.evaluate(ner_dataset["validation"]) + + +def main(args): + # load model and tokenizer + model = AutoModelForTokenClassification.from_pretrained( + #args.model_ckpt, + "/fsx/loubna/code/bigcode-dataset/pii/ner/finetuned-encoder-pii/final_checkpoint-all-noexamples", + num_labels=len(ID2LABEL), + id2label=ID2LABEL, + label2id=LABEL2ID, + use_auth_token=True, + use_cache=not args.gradient_checkpointing, + output_hidden_states = False, + ) + tokenizer = AutoTokenizer.from_pretrained(args.model_ckpt, use_auth_token=True) + tokenizer = prepare_tokenizer(tokenizer) + + # load dataset + dataset = load_dataset(args.dataset_name, use_auth_token=True) + train_data = dataset["train"].shuffle(seed=args.seed) + test_data = dataset["test"] + valid_data = dataset["valid"] + + from datasets import concatenate_datasets + train_data = concatenate_datasets([train_data, test_data]) + print(f"Concatenated train and test data, new train size: {len(train_data)}") + + + if args.dataset_name == "bigcode/pii-full-ds": + if not args.add_not_curated: + print("Removing not curated data (-400 long files)...") + # keep only curated data + train_data = train_data.filter(lambda x: x["data_origin"] == "curated") + else: + print("Keeping not curated data...") + + + train_data = prepare_dataset(train_data, tokenizer, args) + test_data = prepare_dataset(test_data, tokenizer, args) + valid_data = prepare_dataset(valid_data, tokenizer, args) + print( + f"After tokenization:\nTrain size {len(train_data)}\nValid size {len(valid_data)}\nTest size {len(test_data)}" + ) + + if args.debug: + train_stats = get_stats(train_data) + valid_stats = get_stats(valid_data) + test_stats = get_stats(test_data) + print("Train low-resource stats") + # print stats for keys with less than 100 in teh value + pprint({k: v for k, v in train_stats.items() if v < 300}) + print("Valid low-resource stats") + pprint({k: v for k, v in valid_stats.items() if v < 100}) + print("Test low-resource stats") + pprint({k: v for k, v in test_stats.items() if v < 100}) + + + print("Chunking the dataset...") + ner_dataset = DatasetDict( + train=chunk_dataset(train_data, tokenizer), + validation=chunk_dataset(valid_data, tokenizer), + test=chunk_dataset(test_data, tokenizer), + ) + # remove columns + ner_dataset = ner_dataset.remove_columns(["id", "chunk_id"]) + print(ner_dataset) + + run_training(args, ner_dataset, model, tokenizer) + + +if __name__ == "__main__": + args = get_args() + set_seed(args.seed) + os.makedirs(args.output_dir, exist_ok=True) + + logging.set_verbosity_info() + + main(args) \ No newline at end of file diff --git a/pii/ner/requirements.txt b/pii/ner/requirements.txt new file mode 100644 index 0000000..4901e0b --- /dev/null +++ b/pii/ner/requirements.txt @@ -0,0 +1,4 @@ +datasets +transformers +evaluate +seqeval \ No newline at end of file diff --git a/pii/ner/utils/eval.py b/pii/ner/utils/eval.py new file mode 100644 index 0000000..fbb3034 --- /dev/null +++ b/pii/ner/utils/eval.py @@ -0,0 +1,65 @@ +# source: https://github.com/mponty/bigcode-dataset/tree/main/pii/ner_model_training/utils by @mponty +import numpy as np +from evaluate import load +from scipy.special import softmax +from sklearn.metrics import average_precision_score + +_seqeval_metric = load("seqeval") + + +# NER tags +CATEGORIES = [ + "NAME", + "EMAIL", + "EMAIL_EXAMPLE", + "USERNAME", + "KEY", + "IP_ADDRESS", + "PASSWORD", +] +IGNORE_CLASS = ["AMBIGUOUS", "ID", "NAME_EXAMPLE", "USERNAME_EXAMPLE"] + +LABEL2ID = {"O": 0} +for cat in CATEGORIES: + LABEL2ID[f"B-{cat}"] = len(LABEL2ID) + LABEL2ID[f"I-{cat}"] = len(LABEL2ID) +ID2LABEL = {v: k for k, v in LABEL2ID.items()} + + +def compute_ap(pred, truth): + pred_proba = 1 - softmax(pred, axis=-1)[..., 0] + pred_proba, truth = pred_proba.flatten(), np.array(truth).flatten() + pred_proba = pred_proba[truth != -100] + truth = truth[truth != -100] + + return average_precision_score(truth != 0, pred_proba) + + +def compute_metrics(p): + predictions, labels = p + avg_prec = compute_ap(predictions, labels) + predictions = np.argmax(predictions, axis=2) + + # Remove ignored index (special tokens) + true_predictions = [ + [ID2LABEL[p] for (p, l) in zip(prediction, label) if l != -100] + for prediction, label in zip(predictions, labels) + ] + true_labels = [ + [ID2LABEL[l] for (p, l) in zip(prediction, label) if l != -100] + for prediction, label in zip(predictions, labels) + ] + + results = _seqeval_metric.compute( + predictions=true_predictions, references=true_labels, zero_division=0, + ) + agg_metrics = { + "Avg.Precision": avg_prec, + "precision": results.pop("overall_precision"), + "recall": results.pop("overall_recall"), + "f1": results.pop("overall_f1"), + } + results.pop("overall_accuracy") + per_cat_metrics = {name: metrics["f1"] for name, metrics in results.items()} + + return dict(**agg_metrics, **per_cat_metrics) diff --git a/pii/ner/utils/preprocessing.py b/pii/ner/utils/preprocessing.py new file mode 100644 index 0000000..203e627 --- /dev/null +++ b/pii/ner/utils/preprocessing.py @@ -0,0 +1,145 @@ +# source: https://github.com/mponty/bigcode-dataset/tree/main/pii/ner_model_training/utils by @mponty +import itertools +from tqdm import tqdm +from datasets import Dataset + +def is_overlap(span, reference_span): + l1, r1 = min(*span), max(*span) + l2, r2 = min(*reference_span), max(*reference_span) + return l1 <= l2 < r1 or l1 < r2 <= r1 or l2 <= l1 < r2 or l2 < r1 <= r2 + + +def label_tokenized( + entry, target_text="text", pii_column="fragments", LABEL2ID=None, IGNORE_CLASS=None +): + content, pii = entry[target_text], entry[pii_column] + + if entry["offset_mapping"][-1] == (0, 0): + entry["offset_mapping"][-1] = (len(content), len(content)) + + entry["labels"] = [LABEL2ID["O"]] * len(entry["offset_mapping"]) + for entity in pii: + if entity["category"] in IGNORE_CLASS: + continue + prefix = "B-" + entity_span = tuple(entity["position"]) + for i, span in enumerate(entry["offset_mapping"]): + if is_overlap(entity_span, span): + label = prefix + entity["category"] + entry["labels"][i] = LABEL2ID[label] + prefix = "I-" + + return entry + + +def add_special_toks(entry, target_text, tokenizer): + content = entry[target_text] + entry["input_ids"] = ( + [tokenizer.cls_token_id] + entry["input_ids"] + [tokenizer.sep_token_id] + ) + entry["attention_mask"] = [1] + entry["attention_mask"] + [1] + entry["offset_mapping"] = ( + [(0, 0)] + entry["offset_mapping"] + [(len(content), len(content))] + ) + entry["labels"] = [-100] + entry["labels"] + [-100] + return entry + + +def tokenize_and_label_batch( + entries, + tokenizer, + target_text="text", + pii_column="fragments", + LABEL2ID=None, + IGNORE_CLASS=None, +): + """Tokenize and label a batch of entries""" + list_inputs = { + k: [] for k in ["input_ids", "attention_mask", "offset_mapping", "labels"] + } + for text, fragments in zip(entries[target_text], entries[pii_column]): + entry = {"text": text, "fragments": fragments} + inputs = tokenizer.encode_plus( + text, return_offsets_mapping=True, add_special_tokens=False + ) + entry.update(inputs) + entry = label_tokenized( + entry, + target_text=target_text, + pii_column=pii_column, + LABEL2ID=LABEL2ID, + IGNORE_CLASS=IGNORE_CLASS, + ) + entry = add_special_toks(entry, target_text=target_text, tokenizer=tokenizer) + for k in list_inputs.keys(): + list_inputs[k].append(entry[k]) + return list_inputs + + +# Chunking +# we do all chunking with overlap_freq = 0 + + +def _get_chunking_step(length, overlap_freq): + step = length + if overlap_freq: + if overlap_freq > 1: + step = length // overlap_freq + else: + step = length // 2 + return step + + +def _chunked_seq(seq, length, overlap_freq=0): + step = _get_chunking_step(length, overlap_freq) + + for i in range(len(seq) // step + 1): + if i * step < len(seq): + yield seq[i * step : i * step + length] + + +def chunk_inputs( + input_ids, + attention_mask, + labels, + id, + *, + tokenizer, + max_length, + overlap_freq=0, + **kwargs +): + chunks = zip( + *[ + _chunked_seq(seq, max_length, overlap_freq) + for seq in (input_ids, attention_mask, labels) + ] + ) + return [ + dict( + input_ids=input_ids, + attention_mask=attention_mask, + labels=labels, + id=id, + chunk_id=i, + ) + for i, (input_ids, attention_mask, labels) in enumerate(chunks) + ] + + +def chunk_dataset(dataset, tokenizer, overlap_freq=0): + return Dataset.from_list( + list( + itertools.chain( + *( + chunk_inputs( + **entry, + tokenizer=tokenizer, + max_length=tokenizer.model_max_length, + overlap_freq=overlap_freq + ) + for entry in tqdm(list(dataset)) + ) + ) + ) + )