diff --git a/pii/ner/notebook_inference.ipynb b/pii/ner/notebook_inference.ipynb deleted file mode 100644 index 0b62081..0000000 --- a/pii/ner/notebook_inference.ipynb +++ /dev/null @@ -1,523 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Notebook for inference:\n", - "\n", - "Todo:\n", - "- check correctness of predictions\n", - "- improve chunking method to avoid splitting files in the middle" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Setup" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "import json\n", - "import time\n", - "from accelerate import Accelerator\n", - "from dataclasses import dataclass, field\n", - "from functools import partial\n", - "import numpy as np\n", - "from tqdm import tqdm\n", - "import torch\n", - "from torch.utils.data import DataLoader\n", - "import datasets\n", - "from transformers import AutoTokenizer, AutoModelForTokenClassification, HfArgumentParser, DataCollatorForTokenClassification" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "KEEP_LABELS = [\n", - " \"NAME\",\n", - " \"NAME_LICENSE\",\n", - " \"NAME_EXAMPLE\",\n", - " \"EMAIL\",\n", - " \"EMAIL_LICENSE\",\n", - " \"EMAIL_EXAMPLE\",\n", - " \"USERNAME\",\n", - " \"USERNAME_LICENSE\",\n", - " \"USERNAME_EXAMPLE\",\n", - " \"KEY\",\n", - " \"IP_ADDRESS\",\n", - " \"PASSWORD\",\n", - "]\n", - "\n", - "# Special tokens\n", - "MASK_TOKEN = \"<mask>\"\n", - "SEPARATOR_TOKEN = \"<sep>\"\n", - "PAD_TOKEN = \"<pad>\"\n", - "CLS_TOKEN = \"<cls>\"\n", - "\n", - "@dataclass\n", - "class NerArguments:\n", - "\n", - " \"\"\"Configuration for running NER model inference\n", - " \"\"\"\n", - " model_name: str = field(\n", - " default=\"bigcode/deberta-v3-large-pii-ner-v2\",\n", - " metadata={\n", - " \"help\": \"Name of model to use for inference\"\n", - " }\n", - " )\n", - " num_workers: int = field(\n", - " default=16,\n", - " metadata={\n", - " \"help\": \"Number of processes to use for inference\"\n", - " }\n", - " )\n", - " batch_size: int = field(\n", - " default=64,\n", - " metadata={\n", - " \"help\": \"The batch size to use for inference\"\n", - " }\n", - " )\n", - " dataset_name: str = field(\n", - " default=\"bigcode/pii-annotated-toloka\",\n", - " metadata={\n", - " \"help\": \"Name of dataset to use for inference\"\n", - " }\n", - " )\n", - " dryrun: bool = field(\n", - " default=False,\n", - " metadata={\n", - " \"help\": \"Run a dryrun with a small subset of the data\"\n", - " }\n", - " )\n", - " output_path: str = field(\n", - " default=\"output.json\",\n", - " metadata={\n", - " \"help\": \"Path to save output entities\"\n", - " }\n", - " )\n", - "\n", - "# Adapted from: transformers.pipelines.token_classification\n", - "def group_sub_entities(entities, tokenizer):\n", - " first_entity, last_entity = entities[0], entities[-1]\n", - " entity = first_entity[\"entity\"].split(\"-\")[-1]\n", - " scores = np.nanmean([entity[\"score\"] for entity in entities])\n", - " tokens = [entity[\"word\"] for entity in entities]\n", - "\n", - " return {\n", - " \"entity\": entity,\n", - " \"score\": np.mean(scores),\n", - " \"word\": tokenizer.convert_tokens_to_string(tokens)\n", - " }" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "args = NerArguments()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Utilities" - ] - }, - { -
"cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "# Adapted from: transformers.pipelines.token_classification\n", - "def group_entities(entities, tokenizer):\n", - " entity_groups = []\n", - " entity_group_disagg = []\n", - "\n", - " if entities:\n", - " last_idx = entities[-1][\"index\"]\n", - "\n", - " for entity in entities:\n", - " is_last_idx = entity[\"index\"] == last_idx\n", - " if not entity_group_disagg:\n", - " entity_group_disagg += [entity]\n", - " if is_last_idx:\n", - " entity_groups += [group_sub_entities(entity_group_disagg, tokenizer)]\n", - " continue\n", - "\n", - " is_entity_start = entity[\"entity\"].split(\"-\")[0] == \"B\"\n", - " curr_entity_type = entity[\"entity\"].split(\"-\")[-1]\n", - " prev_entity_type = entity_group_disagg[-1][\"entity\"].split(\"-\")[-1]\n", - " is_adjacent_entity = entity[\"index\"] == entity_group_disagg[-1][\"index\"] + 1\n", - "\n", - " is_same_entity_as_previous = (\n", - " curr_entity_type == prev_entity_type and not is_entity_start\n", - " ) and is_adjacent_entity\n", - " if is_same_entity_as_previous:\n", - " entity_group_disagg += [entity]\n", - " if is_last_idx:\n", - " entity_groups += [group_sub_entities(entity_group_disagg, tokenizer)]\n", - " else:\n", - " entity_groups += [group_sub_entities(entity_group_disagg, tokenizer)]\n", - " entity_group_disagg = [entity]\n", - " if is_last_idx:\n", - " entity_groups += [group_sub_entities(entity_group_disagg, tokenizer)]\n", - "\n", - " return entity_groups\n", - "\n", - "\n", - "def prepare_tokenizer(tokenizer):\n", - " tokenizer.add_special_tokens({\"pad_token\": PAD_TOKEN})\n", - " tokenizer.add_special_tokens({\"sep_token\": SEPARATOR_TOKEN})\n", - " tokenizer.add_special_tokens({\"cls_token\": CLS_TOKEN})\n", - " tokenizer.add_special_tokens({\"mask_token\": MASK_TOKEN})\n", - " tokenizer.model_max_length = 1024\n", - " return tokenizer\n", - "\n", - "\n", - "def tokenize_function(entries, tokenizer):\n", - " list_inputs = {\n", - " k: [] for k in [\"input_ids\", \"attention_mask\", \"special_tokens_mask\"]\n", - " }\n", - " for text in entries[\"text\"]:\n", - " inputs = tokenizer(text, return_special_tokens_mask=True)\n", - " for k in list_inputs.keys():\n", - " list_inputs[k].append(inputs[k])\n", - " return list_inputs" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Initializing dataset, model and accelerator" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Using custom data configuration bigcode--pii-annotated-toloka-aa0ea1d4040d00a1\n", - "Found cached dataset json (/fsx/loubna/.cache/bigcode___json/bigcode--pii-annotated-toloka-aa0ea1d4040d00a1/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)\n" - ] - } - ], - "source": [ - "accelerator = Accelerator()\n", - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "\n", - "# load model and tokenizer\n", - "model = AutoModelForTokenClassification.from_pretrained(args.model_name).to(device)\n", - "tokenizer = AutoTokenizer.from_pretrained(args.model_name)\n", - "tokenizer = prepare_tokenizer(tokenizer)\n", - "# labels\n", - "IGNORE_LABELS_IDX = [i for l, i in model.config.label2id.items() if l not in KEEP_LABELS]\n", - "id2label = model.config.id2label\n", - "\n", - "# load and tokenize dataset\n", - "dataset = datasets.load_dataset(args.dataset_name, split=\"train\")\n", - "metadata_columns = 
[c for c in dataset.column_names if c != \"text\"]\n", - "if args.dryrun:\n", - " dataset = dataset.select(range(1000))" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{0: 'O',\n", - " 1: 'B-AMBIGUOUS',\n", - " 2: 'I-AMBIGUOUS',\n", - " 3: 'B-EMAIL',\n", - " 4: 'I-EMAIL',\n", - " 5: 'B-IP_ADDRESS',\n", - " 6: 'I-IP_ADDRESS',\n", - " 7: 'B-KEY',\n", - " 8: 'I-KEY',\n", - " 9: 'B-NAME',\n", - " 10: 'I-NAME',\n", - " 11: 'B-PASSWORD',\n", - " 12: 'I-PASSWORD',\n", - " 13: 'B-USERNAME',\n", - " 14: 'I-USERNAME'}" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "id2label" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Chunk dataset so we don't need to truncate long files and end up losing data" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [], - "source": [ - "import itertools\n", - "from datasets import Dataset\n", - "from tqdm import tqdm\n", - "\n", - "def _chunked_seq(seq, length):\n", - " step = length\n", - "\n", - " for i in range(len(seq) // step + 1):\n", - " if i * step < len(seq):\n", - " yield seq[i * step : i * step + length]\n", - "\n", - "\n", - "def chunk_inputs(\n", - " input_ids,\n", - " attention_mask,\n", - " special_tokens_mask,\n", - " id,\n", - " *,\n", - " tokenizer,\n", - " max_length,\n", - " **kwargs\n", - "):\n", - " chunks = zip(\n", - " *[\n", - " _chunked_seq(seq, max_length)\n", - " for seq in (input_ids, attention_mask, special_tokens_mask)\n", - " ]\n", - " )\n", - " return [\n", - " dict(\n", - " input_ids=input_ids,\n", - " attention_mask=attention_mask,\n", - " special_tokens_mask=special_tokens_mask,\n", - " id=id,\n", - " chunk_id=i,\n", - " )\n", - " for i, (input_ids, attention_mask, special_tokens_mask) in enumerate(chunks)\n", - " ]\n", - "\n", - "\n", - "def chunk_dataset(dataset, tokenizer):\n", - " return Dataset.from_list(\n", - " list(\n", - " itertools.chain(\n", - " *(\n", - " chunk_inputs(\n", - " entry[\"input_ids\"],\n", - " entry[\"attention_mask\"],\n", - " entry[\"special_tokens_mask\"],\n", - " entry[\"id\"],\n", - " tokenizer=tokenizer,\n", - " max_length=tokenizer.model_max_length,\n", - " )\n", - " for entry in tqdm(list(dataset))\n", - " )\n", - " )\n", - " )\n", - " )\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)\n", - "with accelerator.main_process_first():\n", - " tokenized_data = dataset.map(\n", - " partial(tokenize_function, tokenizer=tokenizer),\n", - " batched=True,\n", - " num_proc=args.num_workers,\n", - " remove_columns=metadata_columns,\n", - " )\n", - " tokenized_data = tokenized_data.add_column(\"id\", range(len(tokenized_data)))\n", - " tokenized_data = tokenized_data.remove_columns(\"text\")\n", - " chunked_data = chunk_dataset(tokenized_data, tokenizer)\n", - "\n", - "dataloader = DataLoader(chunked_data, batch_size=args.batch_size, shuffle=False, collate_fn=data_collator)\n", - "print(\"length dataloader is\", len(dataloader))\n", - "model, dataloader = accelerator.prepare(model, dataloader)" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Dataset({\n", - " features: ['input_ids', 'attention_mask', 'special_tokens_mask', 'id'],\n", - " num_rows: 12171\n", - "})" - ] - }, - 
"execution_count": 29, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tokenized_data" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([64, 1024])\n", - "------------------------------------------------------------ Example 0 ------------------------------------------------------------\n", - "[CLS] { \"first_name\": \"Yance\", \"last_name\": \"Bugbee\", \"email_address\": \"ybugbee7@narod.ru\", \"age\": 22, }, { \"first_name\": \"Zita\", \"last_name\": \"Walak\", \"email_address\": \"zwalak8@ebay.com\", \"age\": 57, }, { \"first_name\": \"Davie\", \"last_name\": \"Garmans\", \"email_address\": \"dgarmans9@biblegateway.com\", \"age\": 53, }, ] return data def start_producer(service_uri: str, ca_path: str, cert_path: str, key_path: str): \"\"\"Start the Kafka producer\"\"\" producer = KafkaProducer( bootstrap_servers=service_uri, security_protocol=\"SSL\", ssl_cafile=ca_path, ssl_certfile=cert_path, ssl_keyfile=key_path, ) return producer def send_messages_to_consumer(producer, topic_name: str = \"sample_customer_profile\"): \"\"\"Send messages from Kafka producer to consumer\"\"\" data = get_fake_data() for message in data: print(f\"Sending message from producer: {message}\") producer.send(topic_name, dumps(message).encode(\"utf-8\")) # Wait for all messages to be sent print(f\"All producermessages sent to consumer for topic {topic_name}\") producer.flush()[SEP]\n", - "tensor([1, 1, 1, ..., 0, 0, 0])\n", - "tensor([1, 0, 0, ..., 1, 1, 1])\n", - "------------------------------------------------------------ Example 1 ------------------------------------------------------------\n", - "[CLS] #!/usr/bin/env python3 # -*- coding: utf-8 -*- from .. 
import TestUnitBase class TestStegoUnit(TestUnitBase): def test_simple(self): expected_rows = [ bytes.fromhex(r) for r in [ '2C 15 15 75 50 50 A2 51 51 C1 85 85 AC 5B 5B C9 95 95 CD 9E 9E 98 40 40 00 00 00 00 00 00', '82 71 71 AE A0 A0 BF 8F 8F E0 C4 C4 D1 A5 A5 E3 CC CC EB DB DB CB 9C 9C 5B 58 58 00 00 00', '27 27 27 41 41 41 A4 9E 9E C5 C3 C3 C4 C0 C0 B8 B6 B6 D3 D2 D2 EF EB EB CD CD CD A2 9D 9D', '01 01 01 0B 0B 0B 6A 6A 6A 68 68 68 59 59 59 4E 4E 4E 81 81 81 C1 C1 C1 77 45 45 7B 00 00', '26 26 26 6E 6E 6E C5 C5 C5 BD BD BD C1 BF BF BF BF BF DB DB DB F1 F1 F1 7F 03 03 7F 00 00', 'D7 D7 D7 DE DE DE 96 96 96 B8 B1 B1 C0 95 95 D1 C7 C7 F9 F9 F9 EF EF EF 85 25 25 7D 00 00', 'FC FC FC D2 D2 D2 76 71 71 93 6B 6B 86 24 24 7B 4E 4E D4 D1 D1 F6 F6 F6 B7 A9 A9 86 3A 3A', 'BB BB BB CF C9 C9 BB 9A 9A C4 A0 A0 A7 7D 7D 87 7E 7E DC DC DC F9 F6 F6 CC B2 B2 BF AE AE', '00 00 00 26 14 14 A1 5F 5F B8 78 78 A6 95 95 D7 D7 D7 FB FB FB D2 B9 B9 70 22 22 3F 02 02', '00 00 00 02 00 00 55 41 41 AD 9A 9A 3F 3C 3C B0 B0 B0 FD FD FD BC B6 B6 24 01 01 17 00 00', ] ] image = bytes.fromhex( '89504E470D0A1A0A0000000D494844520000000A0000000A0802000000025058EA000000017352474200AECE1CE900' '00000467414D410000B18F0BFC6105000000097048597300000EC300000EC301C76FA8640000014149444154285301' '3601C9FE012C1515493B3B2D01011F3434EBD6D61D3A3A040909CBA2A268C0C0000000036C6767334040171717203A' '3A0B1616162F2F1326260A0F0FF60A0AD3D4D403E6EFEFD7DEDE243636031212F90C0CE5F0F0020A0A203434282C2C' '3C3737010101010A0A0A5F5F5FFEFEFEF1F1F1F5F5F5333333404040B6848404BBBB01262626484848575757F8F8F8' '040202FE00001C1C1C1616168E121200FDFD02B1B1B1707070D1D1D1FBF4F4FFD6D61208081E1E1EFEFEFE062222FE' '000003919191E5E5E5C2BDBDFCDADADDA4A4D0D9D91A2E2E151616FA1C1CECE6E6033D3D3D09030319FDFD1D1E1E02' '1B1BF619192F3535100D0DF4E3E3163838010000002614147B4B4B171919EE1D1D314242242424D7BEBE9E6969CFE0' 'E002000000DCECECB4E2E2F5222299A7A7D9D9D9020202EAFDFDB4DFDFD8FEFE8A567CCFC3DE9AB90000000049454E' '44AE426082' ) stego = self.\n", - "tensor([1, 1, 1, ..., 1, 1, 1])\n", - "tensor([1, 0, 0, ..., 0, 0, 0])\n" - ] - } - ], - "source": [ - "res = next(iter(dataloader))\n", - "print(res[\"input_ids\"].shape)\n", - "print(\"-\" * 60, \"Example 0\", \"-\" * 60)\n", - "tokens = tokenizer.convert_ids_to_tokens(res[\"input_ids\"][0])\n", - "print(tokenizer.convert_tokens_to_string(tokens))\n", - "print(res[\"attention_mask\"][0])\n", - "print(res[\"special_tokens_mask\"][0])\n", - "\n", - "print(\"-\" * 60, \"Example 1\", \"-\" * 60)\n", - "tokens = tokenizer.convert_ids_to_tokens(res[\"input_ids\"][1])\n", - "print(tokenizer.convert_tokens_to_string(tokens))\n", - "print(res[\"attention_mask\"][1])\n", - "print(res[\"special_tokens_mask\"][1])" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [] - } - ], - "source": [ - "all_entities = []\n", - "t_start = time.time()\n", - "for step, batch in tqdm(enumerate(dataloader)):\n", - " t_1 = time.time()\n", - " with torch.no_grad():\n", - " outputs = model(\n", - " input_ids=batch[\"input_ids\"],\n", - " attention_mask=batch[\"attention_mask\"]\n", - " )\n", - " # warning: not very sure if this works with multiple GPU\n", - " predictions, input_ids, special_tokens_mask = accelerator.gather((\n", - " outputs.logits.squeeze(), batch[\"input_ids\"], batch['special_tokens_mask']\n", - " ))\n", - " predictions = predictions.cpu().numpy()\n", - " scores = np.exp(predictions) / np.exp(predictions).sum(-1, keepdims=True)\n", - " 
batch_labels_idx = scores.argmax(axis=-1)\n", - " forward_time = time.time() - t_1\n", - " t_1 = time.time()\n", - " batch_entities = []\n", - " for text_id, labels_idx in enumerate(batch_labels_idx):\n", - " entities = []\n", - " filtered_labels_idx = [\n", - " (id, label_id)\n", - " for id, label_id in enumerate(labels_idx)\n", - " if label_id not in IGNORE_LABELS_IDX and not special_tokens_mask[text_id][id]\n", - " ]\n", - " for id, label_id in filtered_labels_idx:\n", - " entity = {\n", - " \"word\": tokenizer.convert_ids_to_tokens(int(input_ids[text_id][id])),\n", - " \"index\": id,\n", - " \"score\": float(scores[text_id][id][label_id]),\n", - " \"entity\": id2label[label_id],\n", - " }\n", - " entities += [entity]\n", - " #print(f\"post-processing time {time.time() - t_1}\")\n", - " batch_entities.append(group_entities(entities, tokenizer))\n", - " all_entities += batch_entities\n", - " if args.dryrun:\n", - " print(f\"Step {step}\")\n", - " print(f\"forward time {forward_time}\")\n", - " print(f\"post-processing time {time.time() - t_1}\")\n", - "t_end = time.time()\n", - "\n", - "print(f\"total time: {t_end - t_start:.2f} seconds\")\n", - "all_entities = all_entities[:len(dataset)]\n", - "if accelerator.is_main_process:\n", - " print(all_entities[14])\n", - " with open(args.output_path, \"w\") as f:\n", - " json.dump(all_entities, f)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3.10.9 ('eval-harness': conda)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.9" - }, - "orig_nbformat": 4, - "vscode": { - "interpreter": { - "hash": "271972ab9158cd42175bc1ec5288153b91d150291a0b625c2babd1911356e891" - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/pii/ner/train.py b/pii/ner/train.py deleted file mode 100644 index 4e764cc..0000000 --- a/pii/ner/train.py +++ /dev/null @@ -1,251 +0,0 @@ -import argparse -import os -from pprint import pprint - -from datasets import DatasetDict, load_dataset -from tqdm import tqdm -from functools import partial -from transformers import ( - AutoModelForTokenClassification, - AutoTokenizer, - DataCollatorForTokenClassification, - EarlyStoppingCallback, - Trainer, - TrainingArguments, - set_seed, - logging -) - -from utils.preprocessing import chunk_dataset, tokenize_and_label_batch -from utils.eval import compute_metrics - - -# Special tokens -MASK_TOKEN = "<mask>" -SEPARATOR_TOKEN = "<sep>" -PAD_TOKEN = "<pad>" -CLS_TOKEN = "<cls>" - -# NER tags -CATEGORIES = [ - "NAME", - "EMAIL", - "EMAIL_EXAMPLE", - "USERNAME", - "KEY", - "IP_ADDRESS", - "PASSWORD", -] -IGNORE_CLASS = ["AMBIGUOUS", "ID", "NAME_EXAMPLE", "USERNAME_EXAMPLE"] -
-LABEL2ID = {"O": 0} -for cat in CATEGORIES: - LABEL2ID[f"B-{cat}"] = len(LABEL2ID) - LABEL2ID[f"I-{cat}"] = len(LABEL2ID) -ID2LABEL = {v: k for k, v in LABEL2ID.items()} - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--model_ckpt", type=str, default="bigcode/bigcode-encoder") - parser.add_argument( - "--dataset_name", - type=str, - default="bigcode/pii-full-ds" - ) - # add a prefix to the wandb run name - parser.add_argument("--prefix", type=str, default="") - parser.add_argument("--add_not_curated", action="store_true") - parser.add_argument("--train_batch_size", type=int, default=4) - parser.add_argument("--eval_batch_size", type=int,
default=4) - parser.add_argument("--num_train_epochs", type=int, default=100) - - parser.add_argument("--learning_rate", type=float, default=1e-5) - parser.add_argument("--lr_scheduler_type", type=str, default="cosine") - parser.add_argument("--weight_decay", type=float, default=0.01) - parser.add_argument("--warmup_steps", type=int, default=100) - - parser.add_argument("--gradient_checkpointing", action="store_true") - parser.add_argument("--gradient_accumulation_steps", type=int, default=1) - parser.add_argument("--eval_accumulation_steps", type=int, default=1) - parser.add_argument("--num_proc", type=int, default=8) - parser.add_argument("--bf16", action="store_true") - parser.add_argument("--fp16", action="store_true") - - parser.add_argument("--local_rank", type=int, default=0) - parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--num_workers", type=int, default=8) - parser.add_argument("--eval_freq", type=int, default=100) - parser.add_argument("--save_freq", type=int, default=1000) - parser.add_argument("--debug", action="store_true") - parser.add_argument("--output_dir", type=str, default="finetuned-encoder-pii") - return parser.parse_args() - - -def get_stats(data): - # get number of B-cat for cat in categories for each data split - stats = {cat: 0 for cat in CATEGORIES} - for entry in tqdm(data): - for label in entry["labels"]: - # only add labels for beginning with B- - if label > 0 and ID2LABEL[label].startswith("B-"): - stats[ID2LABEL[label][2:]] += 1 - return stats - - -def prepare_tokenizer(tokenizer): - tokenizer.add_special_tokens({"pad_token": PAD_TOKEN}) - tokenizer.add_special_tokens({"sep_token": SEPARATOR_TOKEN}) - tokenizer.add_special_tokens({"cls_token": CLS_TOKEN}) - tokenizer.add_special_tokens({"mask_token": MASK_TOKEN}) - tokenizer.model_max_length = 1024 - return tokenizer - - -def prepare_dataset(dataset, tokenizer, args): - # tokenize and label - dataset = dataset.map( - partial( - tokenize_and_label_batch, - tokenizer=tokenizer, - target_text="text", - pii_column="fragments", - LABEL2ID=LABEL2ID, - IGNORE_CLASS=IGNORE_CLASS, - ), - batched=True, - batch_size=1000, - num_proc=args.num_workers, - ) - return dataset - -def run_training(args, ner_dataset, model, tokenizer): - print(f"Initializing Trainer...") - - training_args = TrainingArguments( - output_dir=args.output_dir, - evaluation_strategy="steps", - num_train_epochs=args.num_train_epochs, - per_device_train_batch_size=args.train_batch_size, - per_device_eval_batch_size=args.eval_batch_size, - eval_steps=args.eval_freq, - save_steps=args.save_freq, - logging_steps=10, - metric_for_best_model="f1", - load_best_model_at_end=True, - weight_decay=args.weight_decay, - learning_rate=args.learning_rate, - lr_scheduler_type=args.lr_scheduler_type, - warmup_steps=args.warmup_steps, - gradient_checkpointing=args.gradient_checkpointing, - gradient_accumulation_steps=args.gradient_accumulation_steps, - eval_accumulation_steps=args.eval_accumulation_steps, - fp16=args.fp16, - bf16=args.bf16, - run_name=f"{args.prefix}-bs{args.train_batch_size}-lr{args.learning_rate}-wd{args.weight_decay}-ep{args.num_train_epochs}-last", - report_to="wandb", - ) - - - data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer) - trainer = Trainer( - model=model, - args=training_args, - train_dataset=ner_dataset["train"], - eval_dataset=ner_dataset["validation"], - data_collator=data_collator, - tokenizer=tokenizer, - compute_metrics=compute_metrics, - callbacks=[ - EarlyStoppingCallback( - 
early_stopping_patience=15, early_stopping_threshold=1e-2 - ) - ], - ) - - print("Training...") - #trainer.train() - - print("Saving last checkpoint of the model") - #model.save_pretrained(os.path.join(args.output_dir, "final_checkpoint_last_exp/")) - - # evaluate on the validation set - print("Evaluating on validation set...") - trainer.evaluate(ner_dataset["validation"]) - - -def main(args): - # load model and tokenizer - model = AutoModelForTokenClassification.from_pretrained( - #args.model_ckpt, - "/fsx/loubna/code/bigcode-dataset/pii/ner/finetuned-encoder-pii/final_checkpoint-all-noexamples", - num_labels=len(ID2LABEL), - id2label=ID2LABEL, - label2id=LABEL2ID, - use_auth_token=True, - use_cache=not args.gradient_checkpointing, - output_hidden_states=False, - ) - tokenizer = AutoTokenizer.from_pretrained(args.model_ckpt, use_auth_token=True) - tokenizer = prepare_tokenizer(tokenizer) - - # load dataset - dataset = load_dataset(args.dataset_name, use_auth_token=True) - train_data = dataset["train"].shuffle(seed=args.seed) - test_data = dataset["test"] - valid_data = dataset["valid"] - - from datasets import concatenate_datasets - train_data = concatenate_datasets([train_data, test_data]) - print(f"Concatenated train and test data, new train size: {len(train_data)}") - - - if args.dataset_name == "bigcode/pii-full-ds": - if not args.add_not_curated: - print("Removing not curated data (-400 long files)...") - # keep only curated data - train_data = train_data.filter(lambda x: x["data_origin"] == "curated") - else: - print("Keeping not curated data...") - - - train_data = prepare_dataset(train_data, tokenizer, args) - test_data = prepare_dataset(test_data, tokenizer, args) - valid_data = prepare_dataset(valid_data, tokenizer, args) - print( - f"After tokenization:\nTrain size {len(train_data)}\nValid size {len(valid_data)}\nTest size {len(test_data)}" - ) - - if args.debug: - train_stats = get_stats(train_data) - valid_stats = get_stats(valid_data) - test_stats = get_stats(test_data) - print("Train low-resource stats") - # print stats for categories with few occurrences - pprint({k: v for k, v in train_stats.items() if v < 300}) - print("Valid low-resource stats") - pprint({k: v for k, v in valid_stats.items() if v < 100}) - print("Test low-resource stats") - pprint({k: v for k, v in test_stats.items() if v < 100}) - - - print("Chunking the dataset...") - ner_dataset = DatasetDict( - train=chunk_dataset(train_data, tokenizer), - validation=chunk_dataset(valid_data, tokenizer), - test=chunk_dataset(test_data, tokenizer), - ) - # remove columns - ner_dataset = ner_dataset.remove_columns(["id", "chunk_id"]) - print(ner_dataset) - - run_training(args, ner_dataset, model, tokenizer) - - -if __name__ == "__main__": - args = get_args() - set_seed(args.seed) - os.makedirs(args.output_dir, exist_ok=True) - - logging.set_verbosity_info() - - main(args) \ No newline at end of file
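
Note on the chunking TODO in the deleted notebook: _chunked_seq / chunk_inputs split each tokenized file into consecutive, non-overlapping windows of tokenizer.model_max_length tokens, which is why long files get split in the middle (an entity that straddles a window boundary is cut in two). Below is a minimal standalone sketch of that behaviour; the toy token IDs and the max_length value are hypothetical and chosen only for illustration, not taken from the removed files.

# Toy reproduction of the fixed-window chunking removed above: a sequence
# longer than max_length is split into back-to-back windows, so anything
# spanning a window boundary ends up separated across two chunks.
def _chunked_seq(seq, length):
    step = length
    for i in range(len(seq) // step + 1):
        if i * step < len(seq):
            yield seq[i * step : i * step + length]

input_ids = list(range(10))            # stand-in for one tokenized file
attention_mask = [1] * len(input_ids)
max_length = 4                         # hypothetical window size

for chunk_id, (ids, mask) in enumerate(
    zip(_chunked_seq(input_ids, max_length), _chunked_seq(attention_mask, max_length))
):
    print(chunk_id, ids, mask)
# 0 [0, 1, 2, 3] [1, 1, 1, 1]
# 1 [4, 5, 6, 7] [1, 1, 1, 1]
# 2 [8, 9] [1, 1]

One possible direction for the TODO would be to chunk with an overlap (a stride smaller than the window) and reconcile predictions on the overlapping tokens, so that no position is only ever seen at a window edge.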