diff --git a/pii/ner/train.py b/pii/ner/train.py
new file mode 100644
index 0000000..c72677e
--- /dev/null
+++ b/pii/ner/train.py
@@ -0,0 +1,222 @@
+import argparse
+import os
+from functools import partial
+from pprint import pprint
+
+from datasets import DatasetDict, load_dataset
+from tqdm import tqdm
+from transformers import (
+    AutoModelForTokenClassification,
+    AutoTokenizer,
+    DataCollatorForTokenClassification,
+    EarlyStoppingCallback,
+    Trainer,
+    TrainingArguments,
+    logging,
+    set_seed,
+)
+
+from utils.eval import compute_metrics
+from utils.preprocessing import chunk_dataset, tokenize_and_label_batch
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_ckpt", type=str, default="bigcode/bigcode-encoder")
+    parser.add_argument(
+        "--dataset_name",
+        type=str,
+        default="bigcode/pii-annotated-toloka-donwsample-emails",
+    )
+    # optional flags need a leading "--"; otherwise argparse treats them as
+    # positionals and the defaults are never applied
+    parser.add_argument("--batch_size", type=int, default=16)
+    parser.add_argument("--learning_rate", type=float, default=1e-5)
+    parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
+    parser.add_argument("--num_warmup_steps", type=int, default=100)
+    parser.add_argument("--num_train_epochs", type=int, default=3)
+    parser.add_argument("--weight_decay", type=float, default=0.01)
+    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
+    parser.add_argument("--gradient_checkpointing", action="store_true")
+    parser.add_argument("--output_dir", type=str, default="finetuned-encoder-pii")
+    parser.add_argument("--seed", type=int, default=0)
+    parser.add_argument("--num_proc", type=int, default=8)
+    parser.add_argument("--max_length", type=int, default=1024)
+    parser.add_argument("--debug", action="store_true")
+    parser.add_argument("--bf16", action="store_true")
+    parser.add_argument("--fp16", action="store_true")
+    parser.add_argument("--eval_freq", type=int, default=100)
+    parser.add_argument("--save_freq", type=int, default=1000)
+    return parser.parse_args()
+
+
+def get_stats(data):
+    # count the number of B-<cat> labels per category for a data split
+    stats = {cat: 0 for cat in CATEGORIES}
+    for entry in tqdm(data):
+        for label in entry["labels"]:
+            # only count labels beginning with B-
+            if label > 0 and ID2LABEL[label].startswith("B-"):
+                stats[ID2LABEL[label][2:]] += 1
+    return stats
+
+
+def prepare_tokenizer(tokenizer, max_length=1024):
+    tokenizer.add_special_tokens({"pad_token": PAD_TOKEN})
+    tokenizer.add_special_tokens({"sep_token": SEPARATOR_TOKEN})
+    tokenizer.add_special_tokens({"cls_token": CLS_TOKEN})
+    tokenizer.add_special_tokens({"mask_token": MASK_TOKEN})
+    tokenizer.model_max_length = max_length
+    return tokenizer
+
+
+# Special tokens
+MASK_TOKEN = "<mask>"
+SEPARATOR_TOKEN = "<sep>"
+PAD_TOKEN = "<pad>"
+CLS_TOKEN = "<cls>"
+
+# NER tags
+CATEGORIES = [
+    "NAME",
+    "NAME_LICENSE",
+    "NAME_EXAMPLE",
+    "EMAIL",
+    "EMAIL_LICENSE",
+    "EMAIL_EXAMPLE",
+    "USERNAME",
+    "USERNAME_LICENSE",
+    "USERNAME_EXAMPLE",
+    "KEY",
+    "IP_ADDRESS",
+    "PASSWORD",
+]
+IGNORE_CLASS = ["AMBIGUOUS", "ID"]
+
+LABEL2ID = {"O": 0}
+for cat in CATEGORIES:
+    LABEL2ID[f"B-{cat}"] = len(LABEL2ID)
+    LABEL2ID[f"I-{cat}"] = len(LABEL2ID)
+ID2LABEL = {v: k for k, v in LABEL2ID.items()}
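+# For illustration, the loop above interleaves B-/I- tags per category, e.g.
+#   LABEL2ID == {"O": 0, "B-NAME": 1, "I-NAME": 2, "B-NAME_LICENSE": 3, ...}
+# so with 12 categories there are 2 * 12 + 1 = 25 labels in total.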
+
+
+def run_training(args, ner_dataset, model, tokenizer, data_collator):
+    print("Initializing Trainer...")
+
+    training_args = TrainingArguments(
+        output_dir=args.output_dir,
+        evaluation_strategy="steps",
+        num_train_epochs=args.num_train_epochs,
+        per_device_train_batch_size=args.batch_size,
+        per_device_eval_batch_size=args.batch_size,
+        eval_steps=args.eval_freq,
+        save_steps=args.save_freq,
+        logging_steps=10,
+        metric_for_best_model="f1",
+        load_best_model_at_end=True,
+        weight_decay=args.weight_decay,
+        learning_rate=args.learning_rate,
+        lr_scheduler_type=args.lr_scheduler_type,
+        warmup_steps=args.num_warmup_steps,
+        gradient_checkpointing=args.gradient_checkpointing,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        fp16=args.fp16,
+        bf16=args.bf16,
+        run_name=f"pii-bs{args.batch_size}-lr{args.learning_rate}"
+        f"-wd{args.weight_decay}-epochs{args.num_train_epochs}",
+        report_to="wandb",
+    )
+
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=ner_dataset["train"],
+        eval_dataset=ner_dataset["validation"],
+        data_collator=data_collator,
+        tokenizer=tokenizer,
+        # bind the label mapping so compute_metrics matches the Trainer's
+        # single-argument interface
+        compute_metrics=partial(compute_metrics, id2label=ID2LABEL),
+        callbacks=[
+            EarlyStoppingCallback(
+                early_stopping_patience=30, early_stopping_threshold=1e-3
+            )
+        ],
+    )
+
+    print("Training...")
+    trainer.train()
+
+    print("Saving last checkpoint of the model")
+    model.save_pretrained(os.path.join(args.output_dir, "final_checkpoint/"))
+
+
+def main(args):
+    # load model and tokenizer
+    model = AutoModelForTokenClassification.from_pretrained(
+        args.model_ckpt,
+        num_labels=len(ID2LABEL),
+        id2label=ID2LABEL,
+        label2id=LABEL2ID,
+        use_auth_token=True,
+        use_cache=not args.gradient_checkpointing,
+    )
+    tokenizer = AutoTokenizer.from_pretrained(args.model_ckpt, use_auth_token=True)
+    tokenizer = prepare_tokenizer(tokenizer, max_length=args.max_length)
+
+    # load dataset and re-index it
+    dataset = load_dataset(args.dataset_name, use_auth_token=True, split="train")
+    dataset = dataset.remove_columns(["id"])
+    dataset = dataset.add_column("id", list(range(len(dataset))))
+    data = dataset.map(
+        partial(
+            tokenize_and_label_batch,
+            tokenizer=tokenizer,
+            target_text="text",
+            pii_column="fragments",
+            LABEL2ID=LABEL2ID,
+            IGNORE_CLASS=IGNORE_CLASS,
+        ),
+        batched=True,
+        batch_size=1000,
+        num_proc=args.num_proc,
+    )
+
+    # split into train, validation and test
+    data = data.train_test_split(test_size=0.2, shuffle=True, seed=args.seed)
+    test_valid = data["test"].train_test_split(
+        test_size=0.6, shuffle=True, seed=args.seed
+    )
+    train_data = data["train"]
+    valid_data = test_valid["train"]
+    test_data = test_valid["test"]
+    test_data.to_json(f"{args.output_dir}/test_data.json")
+    print("Test data saved to test_data.json")
+
+    if args.debug:
+        print(
+            f"Train size {len(train_data)}\nValid size {len(valid_data)}\nTest size {len(test_data)}"
+        )
+        train_stats = get_stats(train_data)
+        valid_stats = get_stats(valid_data)
+        test_stats = get_stats(test_data)
+        print("Train low-resource stats")
+        # print stats for categories with few annotated entities
+        pprint({k: v for k, v in train_stats.items() if v < 300})
+        print("Valid low-resource stats")
+        pprint({k: v for k, v in valid_stats.items() if v < 100})
+        print("Test low-resource stats")
+        pprint({k: v for k, v in test_stats.items() if v < 100})
+
+    print("Chunking the dataset...")
+    data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
+    ner_dataset = DatasetDict(
+        train=chunk_dataset(train_data, tokenizer),
+        validation=chunk_dataset(valid_data, tokenizer),
+        test=chunk_dataset(test_data, tokenizer),
+    )
+    print(ner_dataset)
+
+    run_training(args, ner_dataset, model, tokenizer, data_collator)
+
+
+if __name__ == "__main__":
+    args = get_args()
+    set_seed(args.seed)
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    logging.set_verbosity_error()
+
+    main(args)
diff --git a/pii/ner/utils/eval.py b/pii/ner/utils/eval.py
new file mode 100644
index 0000000..44fa05f
--- /dev/null
+++ b/pii/ner/utils/eval.py
@@ -0,0 +1,46 @@
+# source: https://github.com/mponty/bigcode-dataset/tree/main/pii/ner_model_training/utils by @mponty
+import numpy as np
+from evaluate import load
+from scipy.special import softmax
+from sklearn.metrics import average_precision_score
+
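+# seqeval computes entity-level precision/recall/F1 from the BIO tags, while
+# compute_ap below scores token-level "entity vs. O" detection: it uses
+# 1 - softmax(logits)[..., 0] (one minus the probability of the "O" class) as a
+# per-token entity score for sklearn's average_precision_score.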
+_seqeval_metric = load("seqeval")
+
+
+def compute_ap(pred, truth):
+    # per-token probability of being any entity: 1 - P("O")
+    pred_proba = 1 - softmax(pred, axis=-1)[..., 0]
+    pred_proba, truth = pred_proba.flatten(), np.array(truth).flatten()
+    pred_proba = pred_proba[truth != -100]
+    truth = truth[truth != -100]
+
+    return average_precision_score(truth != 0, pred_proba)
+
+
+def compute_metrics(p, id2label):
+    # id2label is bound by the caller (train.py binds it with functools.partial),
+    # so this module does not depend on the training-time label constants
+    predictions, labels = p
+    avg_prec = compute_ap(predictions, labels)
+    predictions = np.argmax(predictions, axis=2)
+
+    # Remove ignored index (special tokens)
+    true_predictions = [
+        [id2label[p] for (p, l) in zip(prediction, label) if l != -100]
+        for prediction, label in zip(predictions, labels)
+    ]
+    true_labels = [
+        [id2label[l] for (p, l) in zip(prediction, label) if l != -100]
+        for prediction, label in zip(predictions, labels)
+    ]
+
+    results = _seqeval_metric.compute(
+        predictions=true_predictions, references=true_labels
+    )
+    agg_metrics = {
+        "Avg.Precision": avg_prec,
+        "precision": results.pop("overall_precision"),
+        "recall": results.pop("overall_recall"),
+        "f1": results.pop("overall_f1"),
+    }
+    results.pop("overall_accuracy")
+    per_cat_metrics = {name: metrics["f1"] for name, metrics in results.items()}
+
+    return dict(**agg_metrics, **per_cat_metrics)
diff --git a/pii/ner/utils/preprocessing.py b/pii/ner/utils/preprocessing.py
new file mode 100644
index 0000000..623d4e1
--- /dev/null
+++ b/pii/ner/utils/preprocessing.py
@@ -0,0 +1,143 @@
+# source: https://github.com/mponty/bigcode-dataset/tree/main/pii/ner_model_training/utils by @mponty
+import itertools
+
+from datasets import Dataset
+from tqdm import tqdm
+
+
+def is_overlap(span, reference_span):
+    l1, r1 = min(*span), max(*span)
+    l2, r2 = min(*reference_span), max(*reference_span)
+    return l1 <= l2 < r1 or l1 < r2 <= r1 or l2 <= l1 < r2 or l2 < r1 <= r2
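+
+# Sanity check (hypothetical values, not part of the original file): character
+# spans are half-open, so a PII fragment at (5, 9) overlaps a token at offsets
+# (8, 12) but not one starting at its right boundary:
+#   is_overlap((5, 9), (8, 12))  -> True
+#   is_overlap((5, 9), (9, 12))  -> False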
+
+
+def label_tokenized(
+    entry, target_text="text", pii_column="fragments", LABEL2ID=None, IGNORE_CLASS=None
+):
+    content, pii = entry[target_text], entry[pii_column]
+
+    # the tokenizer maps trailing special tokens to (0, 0); point the last
+    # offset at the end of the text instead
+    if entry["offset_mapping"][-1] == (0, 0):
+        entry["offset_mapping"][-1] = (len(content), len(content))
+
+    entry["labels"] = [LABEL2ID["O"]] * len(entry["offset_mapping"])
+    for entity in pii:
+        # IGNORE_CLASS is a list of categories to skip
+        if entity["category"] in IGNORE_CLASS:
+            continue
+        prefix = "B-"
+        entity_span = tuple(entity["position"])
+        for i, span in enumerate(entry["offset_mapping"]):
+            if is_overlap(entity_span, span):
+                label = prefix + entity["category"]
+                entry["labels"][i] = LABEL2ID[label]
+                prefix = "I-"
+
+    return entry
+
+
+def add_special_toks(entry, target_text, tokenizer):
+    content = entry[target_text]
+    entry["input_ids"] = (
+        [tokenizer.cls_token_id] + entry["input_ids"] + [tokenizer.sep_token_id]
+    )
+    entry["attention_mask"] = [1] + entry["attention_mask"] + [1]
+    entry["offset_mapping"] = (
+        [(0, 0)] + entry["offset_mapping"] + [(len(content), len(content))]
+    )
+    # special tokens are labeled -100 so the loss and metrics ignore them
+    entry["labels"] = [-100] + entry["labels"] + [-100]
+    return entry
+
+
+def tokenize_and_label_batch(
+    entries,
+    tokenizer,
+    target_text="text",
+    pii_column="fragments",
+    LABEL2ID=None,
+    IGNORE_CLASS=None,
+):
+    """Tokenize and label a batch of entries"""
+    list_inputs = {
+        k: [] for k in ["input_ids", "attention_mask", "offset_mapping", "labels"]
+    }
+    for text, fragments in zip(entries[target_text], entries[pii_column]):
+        entry = {"text": text, "fragments": fragments}
+        inputs = tokenizer.encode_plus(
+            text, return_offsets_mapping=True, add_special_tokens=False
+        )
+        entry.update(inputs)
+        entry = label_tokenized(
+            entry,
+            target_text=target_text,
+            pii_column=pii_column,
+            LABEL2ID=LABEL2ID,
+            IGNORE_CLASS=IGNORE_CLASS,
+        )
+        entry = add_special_toks(entry, target_text=target_text, tokenizer=tokenizer)
+        for k in list_inputs.keys():
+            list_inputs[k].append(entry[k])
+    return list_inputs
+
+
+# Chunking
+# we do all chunking with overlap_freq = 0
+
+
+def _get_chunking_step(length, overlap_freq):
+    step = length
+    if overlap_freq:
+        if overlap_freq > 1:
+            step = length // overlap_freq
+        else:
+            step = length // 2
+    return step
+
+
+def _chunked_seq(seq, length, overlap_freq=0):
+    step = _get_chunking_step(length, overlap_freq)
+
+    for i in range(len(seq) // step + 1):
+        if i * step < len(seq):
+            yield seq[i * step : i * step + length]
+
+
+def chunk_inputs(
+    input_ids,
+    attention_mask,
+    labels,
+    id,
+    *,
+    tokenizer,
+    max_length,
+    overlap_freq=0,
+    **kwargs
+):
+    chunks = zip(
+        *[
+            _chunked_seq(seq, max_length, overlap_freq)
+            for seq in (input_ids, attention_mask, labels)
+        ]
+    )
+    return [
+        dict(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            labels=labels,
+            id=id,
+            chunk_id=i,
+        )
+        for i, (input_ids, attention_mask, labels) in enumerate(chunks)
+    ]
+
+
+def chunk_dataset(dataset, tokenizer, overlap_freq=0):
+    return Dataset.from_list(
+        list(
+            itertools.chain(
+                *(
+                    chunk_inputs(
+                        **entry,
+                        tokenizer=tokenizer,
+                        max_length=tokenizer.model_max_length,
+                        overlap_freq=overlap_freq
+                    )
+                    for entry in tqdm(list(dataset))
+                )
+            )
+        )
+    )
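+
+
+# Example of the chunking behaviour (hypothetical numbers): with
+# tokenizer.model_max_length == 1024 and overlap_freq == 0, a 2500-token file
+# yields chunks of 1024, 1024 and 452 tokens (chunk_id 0, 1, 2); with
+# overlap_freq == 2 the step drops to 512, so consecutive chunks share half of
+# their tokens.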