From c04356b482f0889d455340009dfd31708d5f0e1a Mon Sep 17 00:00:00 2001 From: loubnabnl Date: Mon, 27 Mar 2023 11:10:20 +0000 Subject: [PATCH] add redaction code --- pii/ner/pii_redaction/README.md | 5 + pii/ner/pii_redaction/main_redact.py | 288 +++++++++++++++++++++++ pii/ner/pii_redaction/manual_sharding.py | 54 +++++ pii/ner/pii_redaction/replacements.json | 1 + pii/ner/pii_redaction/utils.py | 175 ++++++++++++++ pii/ner/train.py | 15 +- 6 files changed, 531 insertions(+), 7 deletions(-) create mode 100644 pii/ner/pii_redaction/README.md create mode 100644 pii/ner/pii_redaction/main_redact.py create mode 100644 pii/ner/pii_redaction/manual_sharding.py create mode 100644 pii/ner/pii_redaction/replacements.json create mode 100644 pii/ner/pii_redaction/utils.py diff --git a/pii/ner/pii_redaction/README.md b/pii/ner/pii_redaction/README.md new file mode 100644 index 0000000..4fd1c28 --- /dev/null +++ b/pii/ner/pii_redaction/README.md @@ -0,0 +1,5 @@ +# PII redaction + +```bash +python main_redact.py --dataset_name /fsx/leandro/data/pii_result/ada --target_dataset ada-no-pii --save_path_disk ada-no-pii-local +``` \ No newline at end of file diff --git a/pii/ner/pii_redaction/main_redact.py b/pii/ner/pii_redaction/main_redact.py new file mode 100644 index 0000000..aae3e1e --- /dev/null +++ b/pii/ner/pii_redaction/main_redact.py @@ -0,0 +1,288 @@ +"""Mask detected PII in a dataset. +""" + +import argparse +import json +import logging +import random +import time +from functools import partial +from pprint import pformat + +from datasets import load_dataset +from datasets.utils.logging import set_verbosity_info + +from manual_sharding import save_manual_shards +from utils import get_replacements, redact_pii_batch + + +def parseArgs(): + parser = argparse.ArgumentParser(description="PII detection and redaction") + parser.add_argument( + "--dataset_name", + default="bigcode/pii-for-code", + type=str, + help="HF repo name/path of the dataset.", + ) + parser.add_argument( + "--num_load_proc", + default=64, + type=int, + help="Number of processes to use for loading the dataset", + ) + parser.add_argument( + "--lang", + default="ada", + type=str, + help="Language to redact PII in.", + ) + parser.add_argument( + "--text_column", + default="content", + type=str, + help="Text column to use, if will be renamed to content", + ) + parser.add_argument( + "--split", + default="train", + type=str, + help="Dataset split to process", + ) + parser.add_argument( + "--batch_size", + default=100, + type=int, + help="Batch size for the PII detection/redaction", + ) + parser.add_argument( + "--seed", + default=0, + type=int, + help="Seed for random", + ) + parser.add_argument( + "--num_proc", + default=96, + type=int, + help="Number of processes to use for the PII detection/redaction", + ) + parser.add_argument( + "--no_redaction", + action="store_true", + help="If set, we don't perform redaction", + ) + parser.add_argument( + "--load_replacements", + default=True, + help="If set, we load the replacements from file replacements.json", + ) + parser.add_argument( + "--add_reference_text", + default=False, + type=bool, + help="If True we add the reference text with PII between delimiters \ + in the redacted text -used for visualization-", + ) + parser.add_argument( + "--check_all_files", + action="store_true", + help="If set, we check all files, not only the ones that contain PII", + ) + parser.add_argument( + "--check_sampling_size", + default=0, + type=int, + help="Number of samples to check for PII", + ) + # for 
saving the dataset: either push to HF or save locally with datasets or save manual shards + parser.add_argument( + "--save_mode", + default="manual_shards", + type=str, + choices=["hub", "local", "manual_shards"], + help="How to save the dataset", + ) + parser.add_argument( + "--save_mode_checks", + default="hub", + type=str, + choices=["hub", "local", "manual_shards"], + help="How to save the checks dataset", + ) + # add argument for name of dataset on the hub + parser.add_argument( + "--target_dataset", + default="bigcode-pii2", + type=str, + help="HF repo name of the target dataset in save_mode=hub.", + ) + parser.add_argument( + "--hub_username", + default="loubnabnl", + type=str, + help="Username for the hub", + ) + parser.add_argument( + "--save_path_disk", + default="bigcode-pii2-local", + type=str, + help="Path to save the dataset on disk in save_mode=local.", + ) + return parser.parse_args() + + +def get_check_ds(ds, args): + if not args.check_all_files: + ds_checks = ds.filter( + lambda exs: exs["modified"], + batched=True, + batch_size=args.batch_size, + num_proc=args.num_proc, + ) + else: + ds_checks = ds + if not args.check_sampling_size: + sampling_size = len(ds_checks) + idx_samples = random.sample( + range(len(ds_checks)), min(len(ds_checks), sampling_size) + ) + ds_checks = ds_checks.select(idx_samples) + + return ds_checks + + +def check_uniques(example, uniques): + """Check if current id is still in set of unique id and remove if true.""" + if example["id"] in uniques: + uniques.remove(example["id"]) + return True + else: + return False + + +def main(): + set_verbosity_info() + logger = logging.getLogger(__name__) + logger.setLevel(logging.INFO) + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + handlers=[logging.FileHandler("pii.log"), logging.StreamHandler()], + ) + args = parseArgs() + logger.info( + f"** The job is running with the following arguments: **\n{args}\n **** " + ) + + logger.info(f" ===== Loading {args.dataset_name} =====") + ds = load_dataset( + args.dataset_name, + split=args.split, + use_auth_token=True, + num_proc=args.num_load_proc, + ) + if args.text_column != "content": + ds = ds.rename_column(args.text_column, "content") + + logger.info(f" ===== Deduplicating dataset =====") + # Deduplication based on ids + uniques = set(ds["id"]) + frac = len(uniques) / len(ds) + logger.info(f"Fraction of duplicates: {1-frac:.2%}") + logger.info(f"Dataset:\n{ds}") + # Deduplicate data and apply heuristics + t_start = time.time() + ds_pii = ds.filter( + check_uniques, fn_kwargs={"uniques": uniques}, num_proc=args.num_proc + ) + logger.info(f"Time to filter dataset: {time.time()-t_start:.2f}") + logger.info(f"Dataset after dedup:\n{ds_pii}") + + logger.info( + f"Number of samples that contained PII: {sum([1 if x['entities'] else 0 for x in ds_pii])}" + ) + logger.info( + f"Total number of secrets found: {sum([len(x['entities']) for x in ds_pii])}" + ) + + # redact PII in the dataset + logger.info(f" ===== Applying PII redaction =====") + random.seed(args.seed) + + replacements = get_replacements() + with open("replacements.json", "w") as f: + json.dump(replacements, f) + logging.info(f"Using the following replacements:\n{pformat(replacements)}") + ds_pii = ds_pii.map( + partial( + redact_pii_batch, + replacements=replacements, + add_references=args.add_reference_text, + ), + batched=True, + batch_size=args.batch_size, + num_proc=args.num_proc, + load_from_cache_file=False, + 
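        # redact_pii_batch (defined in utils.py) reads examples["content"] and
        # examples["entities"]; each entity is expected to be a dict with "start",
        # "end", "tag" and "value" keys, e.g.
        # {"tag": "EMAIL", "value": "a@b.com", "start": 10, "end": 17}
        # (values here are illustrative only). Entities whose tag is in IGNORE
        # (AMBIGUOUS, USERNAME) are skipped and left untouched in the text.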
) + logging.info(f"Dataset info after PII redaction:\n{ds_pii}") + + # check the dataset + logger.info( + f" ===== Checking {args.check_sampling_size} samples from those modified in the dataset =====" + ) + ds_checks = get_check_ds(ds_pii, args) + + # save checks dataset + if len(ds_checks) == 0: + logger.info("Dataset was empty. Not saving anything.") + else: + logger.info(f"Checks dataset info {ds_checks}") + if args.save_mode_checks == "hub": + logger.info( + f"Pushing the checks dataset to the Hub as {args.target_dataset}_checks" + ) + ds_checks.push_to_hub(args.target_dataset + "_checks") + + elif args.save_mode_checks == "local": + logger.info(f"Saving the checks dataset to disk") + ds_checks.save_to_disk(args.save_path_disk + "_checks") + + elif args.save_mode_checks == "manual_shards": + logger.info(f"Saving the checks dataset in manual shards") + save_manual_shards( + ds_checks, + user=args.hub_username, + remote_dataset_repo=args.target_dataset + "_checks", + ) + + logger.info("Removing columns that are not needed for the final dataset") + columns = ["content", "modified", "entities"] + if args.add_reference_text: + columns.append("references") + ds_pii = ds_pii.remove_columns(columns) + ds_pii = ds_pii.rename_column("new_content", "content") + logger.info(f"Dataset info after removing columns:\n{ds_pii}") + + # save the final dataset + if args.save_mode == "hub": + logger.info( + f" ===== Pushing the dataset to the Hub as: {args.target_dataset} =====" + ) + ds_pii.push_to_hub(args.target_dataset) + + elif args.save_mode == "local": + logger.info(f" ===== Saving the dataset to disk =====") + ds_pii.save_to_disk(args.save_path_disk) + + elif args.save_mode == "manual_shards": + logger.info(f" ===== Saving the dataset in manual shards =====") + save_manual_shards( + ds_pii, user=args.hub_username, remote_dataset_repo=args.target_dataset + ) + + logger.info(f" ===== Dataset saved successfully =====") + + +if __name__ == "__main__": + main() diff --git a/pii/ner/pii_redaction/manual_sharding.py b/pii/ner/pii_redaction/manual_sharding.py new file mode 100644 index 0000000..8edf034 --- /dev/null +++ b/pii/ner/pii_redaction/manual_sharding.py @@ -0,0 +1,54 @@ +import os +import time +from multiprocessing import Pool +from tqdm import tqdm + +from huggingface_hub import Repository + + +def save_shard(shard_tuple): + """Save shard""" + filename, shard = shard_tuple + # use to_json instead to save as json file + shard.to_parquet(filename) + +def save_manual_shards(ds, user="loubnabnl", remote_dataset_repo="bigcode-pii-pjj"): + """Save sharded data + Args: + ds (Dataset): dataset to be saved + user (str): user name + remote_dataset_repo (str): remote dataset repository + out_path (str): path to save the shards""" + # this will create a folder OUT_PATH that is a clone of REMOTE_DATASET_REPO + # you can save the shards inside it and do git add/commit/push to push data to the hub + out_path = remote_dataset_repo + # if out path doesnt already exist + if not os.path.exists(out_path): + repo = Repository( + local_dir=out_path, + clone_from=user + "/" + remote_dataset_repo, + repo_type="dataset", + use_auth_token=True, + git_user=user + ) + + # files will be numerous we save them in a folder called data inside out_path + os.mkdir(out_path + "/data") + SHARD_SIZE = 1000 << 20 + if ds._indices is not None: + dataset_nbytes = ds.data.nbytes * len(ds._indices) / len(ds.data) + else: + dataset_nbytes = ds.data.nbytes + num_shards = int(dataset_nbytes / SHARD_SIZE) + 1 + print(f"Number of shards: 
{num_shards}") + + print("sharding the dataset") + t_start = time.time() + shards = (ds.shard(num_shards=num_shards, index=i, contiguous=True) for i in range(num_shards)) + # use f"{OUT_PATH}/data/train-{index:05d}-of-{num_shards:05d}.json" instead for json files + filenames = (f"{out_path}/data/train-{index:05d}-of-{num_shards:05d}.parquet" for index in range(num_shards)) + + with Pool(16) as p: + list(tqdm(p.imap_unordered(save_shard, zip(filenames, shards), chunksize=4), total=num_shards)) + print(f"Time to save dataset: {time.time()-t_start:.2f}") + # to push dataset to hub do: git add/commit/push inside OUT_PATH \ No newline at end of file diff --git a/pii/ner/pii_redaction/replacements.json b/pii/ner/pii_redaction/replacements.json new file mode 100644 index 0000000..474eccc --- /dev/null +++ b/pii/ner/pii_redaction/replacements.json @@ -0,0 +1 @@ +{"EMAIL": [""], "KEY": [""], "NAME": [""], "PASSWORD": [""], "IP_ADDRESS": {"IPv4": ["172.16.31.10", "172.16.58.3", "172.16.17.32", "192.168.127.12", "192.168.3.11"], "IPv6": ["fd00:c2b6:b24b:be67:2827:688d:e6a1:6a3b", "fd00:a516:7c1b:17cd:6d81:2137:bd2a:2c5b", "fc00:e968:6179::de52:7100", "fc00:db20:35b:7399::5", "fdf8:f53e:61e4::18"]}} \ No newline at end of file diff --git a/pii/ner/pii_redaction/utils.py b/pii/ner/pii_redaction/utils.py new file mode 100644 index 0000000..a054b7c --- /dev/null +++ b/pii/ner/pii_redaction/utils.py @@ -0,0 +1,175 @@ +import ipaddress +import json +import random + +IGNORE = ["AMBIGUOUS", "USERNAME"] +# List of random private IP addresses to use as replacements +REPLACEMENTS_IP = { + "IPv4": [ + "172.16.31.10", + "172.16.58.3", + "172.16.17.32", + "192.168.127.12", + "192.168.3.11", + ], + "IPv6": [ + "fd00:c2b6:b24b:be67:2827:688d:e6a1:6a3b", + "fd00:a516:7c1b:17cd:6d81:2137:bd2a:2c5b", + "fc00:e968:6179::de52:7100", + "fc00:db20:35b:7399::5", + "fdf8:f53e:61e4::18", + ], +} + +# DNS to avoid masking +POPULAR_DNS_SERVERS = [ + "8.8.8.8", + "8.8.4.4", + "1.1.1.1", + "1.0.0.1", + "76.76.19.19", + "76.223.122.150", + "9.9.9.9", + "149.112.112.112", + "208.67.222.222", + "208.67.220.220", + "8.26.56.26", + "8.20.247.20", + "94.140.14.14", + "94.140.15.15", +] + + +def load_json(sample): + try: + return json.loads(sample) + except ValueError: + return [] + + +def get_replacements(n=10): + """Build dictionaries of replacements for PII (key, email, IP address, name, password)""" + ip_addresses = REPLACEMENTS_IP + return { + "EMAIL": [""], + "KEY": [""], + "NAME": [""], + "PASSWORD": [""], + "IP_ADDRESS": ip_addresses, + } + + +def replace_ip(value, replacements_dict): + """Replace an IP address with a synthetic IP address of the same format""" + try: + ipaddress.IPv4Address(value) + return random.choice(replacements_dict["IP_ADDRESS"]["IPv4"]) + except ValueError: + try: + ipaddress.IPv6Address(value) + return random.choice(replacements_dict["IP_ADDRESS"]["IPv6"]) + except ValueError: + # this doesn't happen if we already use ipaddress filter in the detection + print("Invalid IP address") + return value + + +def is_private_ip(ip): + """Check if an IP address is allocated for private networks (non internet facing), or is not an ip address at all""" + try: + ip = ipaddress.ip_address(ip) + except ValueError: + # not an ip address + return True + return ip.is_private + + +def redact_pii_text(text, secrets, replacements, add_references=False): + """Redact PII in a text + Args: + text (str): text to redact + secrets (list): list with the secrets to redact + replacements (dict): dictionary of replacements for each PII 
type + add_references (bool): whether to add references to the redacted text (delimiters to PII) + for vizualization + Returns: + text (str): new text with redacted secrets + """ + modified = False + if secrets: + secrets = sorted(secrets, key=lambda x: x["start"]) + # store the secrets that were replaced here with their replacements + replaced_secrets = {} + subparts = [] + references = [] + step = 0 + last_text = text + for secret in secrets: + if secret["tag"] in IGNORE: + continue + if secret["tag"] == "IP_ADDRESS": + # skip secret if it is not actual ip address, is apopular DNS server or private IP address + if is_private_ip(secret["value"]) or ( + secret["value"] in POPULAR_DNS_SERVERS + ): + continue + modified = True + subtext = text[step : secret["start"]] + subpart = subtext if subtext else " " + subparts.append(subpart) + # if secret is already in replaced_secrets, use the same replacement + if secret["value"] in replaced_secrets: + replacement = replaced_secrets[secret["value"]] + else: + if secret["tag"] == "IP_ADDRESS": + replacement = replace_ip(secret["value"], replacements) + else: + replacement = random.choice(replacements[secret["tag"]]) + replaced_secrets[secret["value"]] = replacement + subparts.append(replacement) + replaced_secrets[secret["value"]] = replacement + if add_references: + references.append(subpart) + references.append(f"PI:{secret['tag']}:{replacement}END_PI") + last_text = text[secret["end"] :] + step = secret["end"] + # if supbarpts are not empty join them (it can be empty when all secrets were skipped) + new_text = "".join(subparts) + last_text if subparts else last_text + if add_references: + references = "".join(references) + last_text if references else "" + else: + new_text = text + references = "" + result = ( + (new_text, references, modified) if add_references else (new_text, modified) + ) + return result + + +def redact_pii_batch(examples, replacements, add_references=True): + """Anonymize PII in a batch of examples from a dataset""" + new_contents = [] + references = [] + modified = [] + for text, secrets in zip( + examples["content"], + examples["entities"], + ): + if secrets: + if add_references: + new_text, reference, modif = redact_pii_text( + text, secrets, replacements, add_references + ) + references.append(reference) + else: + new_text, modif = redact_pii_text(text, secrets, replacements) + new_contents.append(new_text) + modified.append(modif) + else: + new_contents.append(text) + references.append(text) + modified.append(False) + result = {"new_content": new_contents, "modified": modified} + if add_references: + result.update({"references": references}) + return result diff --git a/pii/ner/train.py b/pii/ner/train.py index 312fcbe..4e764cc 100644 --- a/pii/ner/train.py +++ b/pii/ner/train.py @@ -141,7 +141,7 @@ def run_training(args, ner_dataset, model, tokenizer): eval_accumulation_steps=args.eval_accumulation_steps, fp16=args.fp16, bf16=args.bf16, - run_name=f"{args.prefix}-bs{args.train_batch_size}-lr{args.learning_rate}-wd{args.weight_decay}-ep{args.num_train_epochs}", + run_name=f"{args.prefix}-bs{args.train_batch_size}-lr{args.learning_rate}-wd{args.weight_decay}-ep{args.num_train_epochs}-last", report_to="wandb", ) @@ -157,26 +157,27 @@ def run_training(args, ner_dataset, model, tokenizer): compute_metrics=compute_metrics, callbacks=[ EarlyStoppingCallback( - early_stopping_patience=20, early_stopping_threshold=1e-2 + early_stopping_patience=15, early_stopping_threshold=1e-2 ) ], ) print("Training...") - trainer.train() + 
#trainer.train() print("Saving last checkpoint of the model") - model.save_pretrained(os.path.join(args.output_dir, "final_checkpoint/")) + #model.save_pretrained(os.path.join(args.output_dir, "final_checkpoint_last_exp/")) # evaluate on test set - #print("Evaluating on test set...") - #trainer.evaluate(ner_dataset["test"]) + print("Evaluating on test set...") + trainer.evaluate(ner_dataset["validation"]) def main(args): # load model and tokenizer model = AutoModelForTokenClassification.from_pretrained( - args.model_ckpt, + #args.model_ckpt, + "/fsx/loubna/code/bigcode-dataset/pii/ner/finetuned-encoder-pii/final_checkpoint-all-noexamples", num_labels=len(ID2LABEL), id2label=ID2LABEL, label2id=LABEL2ID,
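For reference, a minimal sketch of how the redaction helpers added in pii/ner/pii_redaction/utils.py can be exercised on their own (run from inside pii/ner/pii_redaction/); the sample text, offsets and entity dict below are invented for illustration and are not part of this patch:

```python
# Illustrative sketch only: the text, offsets and entity dict are made up.
# get_replacements and redact_pii_text are the helpers added in utils.py.
import random

from utils import get_replacements, redact_pii_text

random.seed(0)  # replacements are drawn with random.choice, seed for reproducibility

text = "Contact me at jane.doe@example.com for access."
secrets = [
    # character offsets into `text`, plus the entity tag and the matched value
    {"tag": "EMAIL", "value": "jane.doe@example.com", "start": 14, "end": 34},
]

new_text, modified = redact_pii_text(text, secrets, get_replacements())
print(modified)  # True: the email span was redacted
print(new_text)  # email replaced by a value drawn from replacements["EMAIL"]
```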