From d96b3bc8480d7c4726aaea448509f2eb64d481f5 Mon Sep 17 00:00:00 2001 From: loubnabnl Date: Fri, 24 Mar 2023 17:59:12 +0000 Subject: [PATCH] add inference notebook --- pii/ner/notebook_inference.ipynb | 512 +++++++++++++++++++++++++++++++ 1 file changed, 512 insertions(+) create mode 100644 pii/ner/notebook_inference.ipynb diff --git a/pii/ner/notebook_inference.ipynb b/pii/ner/notebook_inference.ipynb new file mode 100644 index 0000000..6310fde --- /dev/null +++ b/pii/ner/notebook_inference.ipynb @@ -0,0 +1,512 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import time\n", + "from accelerate import Accelerator\n", + "from dataclasses import dataclass, field\n", + "from functools import partial\n", + "import numpy as np\n", + "from tqdm import tqdm\n", + "import torch\n", + "from torch.utils.data import DataLoader\n", + "import datasets \n", + "from transformers import AutoTokenizer, AutoModelForTokenClassification, HfArgumentParser, DataCollatorForTokenClassification" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "KEEP_LABELS = [\n", + " \"NAME\",\n", + " \"NAME_LICENSE\",\n", + " \"NAME_EXAMPLE\",\n", + " \"EMAIL\",\n", + " \"EMAIL_LICENSE\",\n", + " \"EMAIL_EXAMPLE\",\n", + " \"USERNAME\",\n", + " \"USERNAME_LICENSE\",\n", + " \"USERNAME_EXAMPLE\",\n", + " \"KEY\",\n", + " \"IP_ADDRESS\",\n", + " \"PASSWORD\",\n", + "]\n", + "\n", + "# Special tokens\n", + "MASK_TOKEN = \"\"\n", + "SEPARATOR_TOKEN = \"\"\n", + "PAD_TOKEN = \"\"\n", + "CLS_TOKEN = \"\"\n", + "\n", + "@dataclass\n", + "class NerArguments:\n", + "\n", + " \"\"\"configuration for running NER model inference\n", + " \"\"\"\n", + " model_name: str = field(\n", + " default=\"bigcode/deberta-v3-large-pii-ner-v2\",\n", + " metadata={\n", + " \"help\": \"Name of model to use for inference\"\n", + " }\n", + " )\n", + " num_workers: int = field(\n", + " default=16,\n", + " metadata={\n", + " \"help\": \"Number of processes to use for inference\"\n", + " }\n", + " )\n", + " batch_size: int = field(\n", + " default=64,\n", + " metadata={\n", + " \"help\": \"the batch size to use for inference\"\n", + " }\n", + " )\n", + " dataset_name: str = field(\n", + " default=\"bigcode/pii-annotated-toloka\",\n", + " metadata={\n", + " \"help\": \"Name of dataset to use for inference\"\n", + " }\n", + " )\n", + " dryrun: bool = field(\n", + " default=False,\n", + " metadata={\n", + " \"help\": \"Run a dryrun with a small subset of the data\"\n", + " }\n", + " )\n", + " output_path: str = field(\n", + " default=\"output.json\",\n", + " metadata={\n", + " \"help\": \"Path to save output entities\"\n", + " }\n", + " )\n", + "\n", + "# Adapted from: transformers.pipelines.token_classification\n", + "def group_sub_entities(entities, tokenizer):\n", + " first_entity, last_entity = entities[0], entities[-1]\n", + " entity = first_entity[\"entity\"].split(\"-\")[-1]\n", + " scores = np.nanmean([entity[\"score\"] for entity in entities])\n", + " tokens = [entity[\"word\"] for entity in entities]\n", + "\n", + " return {\n", + " \"entity\": entity,\n", + " \"score\": np.mean(scores),\n", + " \"word\": tokenizer.convert_tokens_to_string(tokens)\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "args = NerArguments" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## utilities" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "# Adapted from: transformers.pipelines.token_classification\n", + "def group_entities(entities, tokenizer):\n", + " entity_groups = []\n", + " entity_group_disagg = []\n", + "\n", + " if entities:\n", + " last_idx = entities[-1][\"index\"]\n", + "\n", + " for entity in entities:\n", + " is_last_idx = entity[\"index\"] == last_idx\n", + " if not entity_group_disagg:\n", + " entity_group_disagg += [entity]\n", + " if is_last_idx:\n", + " entity_groups += [group_sub_entities(entity_group_disagg, tokenizer)]\n", + " continue\n", + "\n", + " is_entity_start = entity[\"entity\"].split(\"-\")[0] == \"B\"\n", + " curr_entity_type = entity[\"entity\"].split(\"-\")[-1]\n", + " prev_entity_type = entity_group_disagg[-1][\"entity\"].split(\"-\")[-1]\n", + " is_adjacent_entity = entity[\"index\"] == entity_group_disagg[-1][\"index\"] + 1\n", + "\n", + " is_same_entity_as_previous = (\n", + " curr_entity_type == prev_entity_type and not is_entity_start\n", + " ) and is_adjacent_entity\n", + " if is_same_entity_as_previous:\n", + " entity_group_disagg += [entity]\n", + " if is_last_idx:\n", + " entity_groups += [group_sub_entities(entity_group_disagg, tokenizer)]\n", + " else:\n", + " entity_groups += [group_sub_entities(entity_group_disagg, tokenizer)]\n", + " entity_group_disagg = [entity]\n", + " if is_last_idx:\n", + " entity_groups += [group_sub_entities(entity_group_disagg, tokenizer)]\n", + "\n", + " return entity_groups\n", + "\n", + "\n", + "def prepare_tokenizer(tokenizer):\n", + " tokenizer.add_special_tokens({\"pad_token\": PAD_TOKEN})\n", + " tokenizer.add_special_tokens({\"sep_token\": SEPARATOR_TOKEN})\n", + " tokenizer.add_special_tokens({\"cls_token\": CLS_TOKEN})\n", + " tokenizer.add_special_tokens({\"mask_token\": MASK_TOKEN})\n", + " tokenizer.model_max_length = 1024\n", + " return tokenizer\n", + "\n", + "\n", + "def tokenize_function(entries, tokenizer):\n", + " list_inputs = {\n", + " k: [] for k in [\"input_ids\", \"attention_mask\", \"special_tokens_mask\"]\n", + " }\n", + " for text in entries[\"text\"]:\n", + " inputs = tokenizer(text, return_special_tokens_mask=True)\n", + " for k in list_inputs.keys():\n", + " list_inputs[k].append(inputs[k])\n", + " return list_inputs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Initializing dataset, model and accelerator" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using custom data configuration bigcode--pii-annotated-toloka-aa0ea1d4040d00a1\n", + "Found cached dataset json (/fsx/loubna/.cache/bigcode___json/bigcode--pii-annotated-toloka-aa0ea1d4040d00a1/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)\n" + ] + } + ], + "source": [ + "accelerator = Accelerator()\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "# load model and tokenizer\n", + "model = AutoModelForTokenClassification.from_pretrained(args.model_name).to(device)\n", + "tokenizer = AutoTokenizer.from_pretrained(args.model_name)\n", + "tokenizer = prepare_tokenizer(tokenizer)\n", + "# labels\n", + "IGNORE_LABELS_IDX = [i for l, i in model.config.label2id.items() if l not in KEEP_LABELS]\n", + "id2label = model.config.id2label\n", + "\n", + "# load and tokenize dataset\n", + "dataset = datasets.load_dataset(args.dataset_name, split=\"train\")\n", + "metadata_columns = [c for c in dataset.column_names if c != \"text\"]\n", + "if args.dryrun:\n", + " dataset = dataset.select(range(1000))" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{0: 'O',\n", + " 1: 'B-AMBIGUOUS',\n", + " 2: 'I-AMBIGUOUS',\n", + " 3: 'B-EMAIL',\n", + " 4: 'I-EMAIL',\n", + " 5: 'B-IP_ADDRESS',\n", + " 6: 'I-IP_ADDRESS',\n", + " 7: 'B-KEY',\n", + " 8: 'I-KEY',\n", + " 9: 'B-NAME',\n", + " 10: 'I-NAME',\n", + " 11: 'B-PASSWORD',\n", + " 12: 'I-PASSWORD',\n", + " 13: 'B-USERNAME',\n", + " 14: 'I-USERNAME'}" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "id2label" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Chunk dataset so we don't need to truncate long files and end up losing data" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "import itertools\n", + "from datasets import Dataset\n", + "from tqdm import tqdm\n", + "\n", + "def _chunked_seq(seq, length):\n", + " step = length\n", + "\n", + " for i in range(len(seq) // step + 1):\n", + " if i * step < len(seq):\n", + " yield seq[i * step : i * step + length]\n", + "\n", + "\n", + "def chunk_inputs(\n", + " input_ids,\n", + " attention_mask,\n", + " special_tokens_mask,\n", + " id,\n", + " *,\n", + " tokenizer,\n", + " max_length,\n", + " **kwargs\n", + "):\n", + " chunks = zip(\n", + " *[\n", + " _chunked_seq(seq, max_length)\n", + " for seq in (input_ids, attention_mask, special_tokens_mask)\n", + " ]\n", + " )\n", + " return [\n", + " dict(\n", + " input_ids=input_ids,\n", + " attention_mask=attention_mask,\n", + " special_tokens_mask=special_tokens_mask,\n", + " id=id,\n", + " chunk_id=i,\n", + " )\n", + " for i, (input_ids, attention_mask, special_tokens_mask) in enumerate(chunks)\n", + " ]\n", + "\n", + "\n", + "def chunk_dataset(dataset, tokenizer):\n", + " return Dataset.from_list(\n", + " list(\n", + " itertools.chain(\n", + " *(\n", + " chunk_inputs(\n", + " entry[\"input_ids\"],\n", + " entry[\"attention_mask\"],\n", + " entry[\"special_tokens_mask\"],\n", + " entry[\"id\"],\n", + " tokenizer=tokenizer,\n", + " max_length=tokenizer.model_max_length,\n", + " )\n", + " for entry in tqdm(list(dataset))\n", + " )\n", + " )\n", + " )\n", + " )\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)\n", + "with accelerator.main_process_first():\n", + " tokenized_data = dataset.map(\n", + " partial(tokenize_function, tokenizer=tokenizer),\n", + " batched=True,\n", + " num_proc=args.num_workers,\n", + " remove_columns=metadata_columns,\n", + " )\n", + " tokenized_data = tokenized_data.add_column(\"id\", range(len(tokenized_data)))\n", + " tokenized_data = tokenized_data.remove_columns(\"text\")\n", + " chunked_data = chunk_dataset(tokenized_data, tokenizer)\n", + "\n", + "dataloader = DataLoader(chunked_data, batch_size=args.batch_size, shuffle=False, collate_fn=data_collator)\n", + "print(\"length dataloader is\", len(dataloader))\n", + "model, dataloader = accelerator.prepare(model, dataloader)" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Dataset({\n", + " features: ['input_ids', 'attention_mask', 'special_tokens_mask', 'id'],\n", + " num_rows: 12171\n", + "})" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenized_data" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([64, 1024])\n", + "------------------------------------------------------------ Example 0 ------------------------------------------------------------\n", + "[CLS] { \"first_name\": \"Yance\", \"last_name\": \"Bugbee\", \"email_address\": \"ybugbee7@narod.ru\", \"age\": 22, }, { \"first_name\": \"Zita\", \"last_name\": \"Walak\", \"email_address\": \"zwalak8@ebay.com\", \"age\": 57, }, { \"first_name\": \"Davie\", \"last_name\": \"Garmans\", \"email_address\": \"dgarmans9@biblegateway.com\", \"age\": 53, }, ] return data def start_producer(service_uri: str, ca_path: str, cert_path: str, key_path: str): \"\"\"Start the Kafka producer\"\"\" producer = KafkaProducer( bootstrap_servers=service_uri, security_protocol=\"SSL\", ssl_cafile=ca_path, ssl_certfile=cert_path, ssl_keyfile=key_path, ) return producer def send_messages_to_consumer(producer, topic_name: str = \"sample_customer_profile\"): \"\"\"Send messages from Kafka producer to consumer\"\"\" data = get_fake_data() for message in data: print(f\"Sending message from producer: {message}\") producer.send(topic_name, dumps(message).encode(\"utf-8\")) # Wait for all messages to be sent print(f\"All producermessages sent to consumer for topic {topic_name}\") producer.flush()[SEP]\n", + "tensor([1, 1, 1, ..., 0, 0, 0])\n", + "tensor([1, 0, 0, ..., 1, 1, 1])\n", + "------------------------------------------------------------ Example 1 ------------------------------------------------------------\n", + "[CLS] #!/usr/bin/env python3 # -*- coding: utf-8 -*- from .. import TestUnitBase class TestStegoUnit(TestUnitBase): def test_simple(self): expected_rows = [ bytes.fromhex(r) for r in [ '2C 15 15 75 50 50 A2 51 51 C1 85 85 AC 5B 5B C9 95 95 CD 9E 9E 98 40 40 00 00 00 00 00 00', '82 71 71 AE A0 A0 BF 8F 8F E0 C4 C4 D1 A5 A5 E3 CC CC EB DB DB CB 9C 9C 5B 58 58 00 00 00', '27 27 27 41 41 41 A4 9E 9E C5 C3 C3 C4 C0 C0 B8 B6 B6 D3 D2 D2 EF EB EB CD CD CD A2 9D 9D', '01 01 01 0B 0B 0B 6A 6A 6A 68 68 68 59 59 59 4E 4E 4E 81 81 81 C1 C1 C1 77 45 45 7B 00 00', '26 26 26 6E 6E 6E C5 C5 C5 BD BD BD C1 BF BF BF BF BF DB DB DB F1 F1 F1 7F 03 03 7F 00 00', 'D7 D7 D7 DE DE DE 96 96 96 B8 B1 B1 C0 95 95 D1 C7 C7 F9 F9 F9 EF EF EF 85 25 25 7D 00 00', 'FC FC FC D2 D2 D2 76 71 71 93 6B 6B 86 24 24 7B 4E 4E D4 D1 D1 F6 F6 F6 B7 A9 A9 86 3A 3A', 'BB BB BB CF C9 C9 BB 9A 9A C4 A0 A0 A7 7D 7D 87 7E 7E DC DC DC F9 F6 F6 CC B2 B2 BF AE AE', '00 00 00 26 14 14 A1 5F 5F B8 78 78 A6 95 95 D7 D7 D7 FB FB FB D2 B9 B9 70 22 22 3F 02 02', '00 00 00 02 00 00 55 41 41 AD 9A 9A 3F 3C 3C B0 B0 B0 FD FD FD BC B6 B6 24 01 01 17 00 00', ] ] image = bytes.fromhex( '89504E470D0A1A0A0000000D494844520000000A0000000A0802000000025058EA000000017352474200AECE1CE900' '00000467414D410000B18F0BFC6105000000097048597300000EC300000EC301C76FA8640000014149444154285301' '3601C9FE012C1515493B3B2D01011F3434EBD6D61D3A3A040909CBA2A268C0C0000000036C6767334040171717203A' '3A0B1616162F2F1326260A0F0FF60A0AD3D4D403E6EFEFD7DEDE243636031212F90C0CE5F0F0020A0A203434282C2C' '3C3737010101010A0A0A5F5F5FFEFEFEF1F1F1F5F5F5333333404040B6848404BBBB01262626484848575757F8F8F8' '040202FE00001C1C1C1616168E121200FDFD02B1B1B1707070D1D1D1FBF4F4FFD6D61208081E1E1EFEFEFE062222FE' '000003919191E5E5E5C2BDBDFCDADADDA4A4D0D9D91A2E2E151616FA1C1CECE6E6033D3D3D09030319FDFD1D1E1E02' '1B1BF619192F3535100D0DF4E3E3163838010000002614147B4B4B171919EE1D1D314242242424D7BEBE9E6969CFE0' 'E002000000DCECECB4E2E2F5222299A7A7D9D9D9020202EAFDFDB4DFDFD8FEFE8A567CCFC3DE9AB90000000049454E' '44AE426082' ) stego = self.\n", + "tensor([1, 1, 1, ..., 1, 1, 1])\n", + "tensor([1, 0, 0, ..., 0, 0, 0])\n" + ] + } + ], + "source": [ + "res = next(iter(dataloader))\n", + "print(res[\"input_ids\"].shape)\n", + "print(\"-\" * 60, \"Example 0\", \"-\" * 60)\n", + "tokens = tokenizer.convert_ids_to_tokens(res[\"input_ids\"][0])\n", + "print(tokenizer.convert_tokens_to_string(tokens))\n", + "print(res[\"attention_mask\"][0])\n", + "print(res[\"special_tokens_mask\"][0])\n", + "\n", + "print(\"-\" * 60, \"Example 1\", \"-\" * 60)\n", + "tokens = tokenizer.convert_ids_to_tokens(res[\"input_ids\"][1])\n", + "print(tokenizer.convert_tokens_to_string(tokens))\n", + "print(res[\"attention_mask\"][1])\n", + "print(res[\"special_tokens_mask\"][1])" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [] + } + ], + "source": [ + "all_entities = []\n", + "t_start = time.time()\n", + "for step, batch in tqdm(enumerate(dataloader)):\n", + " t_1 = time.time()\n", + " with torch.no_grad():\n", + " outputs = model(\n", + " input_ids=batch[\"input_ids\"],\n", + " attention_mask=batch[\"attention_mask\"]\n", + " )\n", + " # warning: not very sure if this works with multiple GPU\n", + " predictions, input_ids, special_tokens_mask = accelerator.gather((\n", + " outputs.logits.squeeze(), batch[\"input_ids\"], batch['special_tokens_mask']\n", + " ))\n", + " predictions = predictions.cpu().numpy()\n", + " scores = np.exp(predictions) / np.exp(predictions).sum(-1, keepdims=True)\n", + " batch_labels_idx = scores.argmax(axis=-1)\n", + " forward_time = time.time() - t_1\n", + " t_1 = time.time()\n", + " batch_entities = []\n", + " for text_id, labels_idx in enumerate(batch_labels_idx):\n", + " entities = []\n", + " filtered_labels_idx = [\n", + " (id, label_id) \n", + " for id, label_id in enumerate(labels_idx) \n", + " if label_id not in IGNORE_LABELS_IDX and not special_tokens_mask[text_id][id]\n", + " ]\n", + " for id, label_id in filtered_labels_idx:\n", + " entity = {\n", + " \"word\": tokenizer.convert_ids_to_tokens(int(input_ids[text_id][id])),\n", + " \"index\": id,\n", + " \"score\": float(scores[text_id][id][label_id]),\n", + " \"entity\": id2label[label_id],\n", + " }\n", + " entities += [entity]\n", + " #print(f\"post-processing time {time.time() - t_1}\")\n", + " batch_entities.append(group_entities(entities, tokenizer))\n", + " all_entities += batch_entities\n", + " if args.dryrun:\n", + " print(f\"Step {step}\")\n", + " print(f\"forward time {forward_time}\")\n", + " print(f\"post-processing time {time.time() - t_1}\")\n", + "t_end = time.time()\n", + "\n", + "print(f\"total time: {t_end - t_start:.2f} seconds\")\n", + "all_entities = all_entities[:len(dataset)]\n", + "if accelerator.is_main_process:\n", + " print(all_entities[14])\n", + " with open(args.output_path, \"w\") as f:\n", + " json.dump(all_entities, f)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.10.9 ('eval-harness': conda)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "271972ab9158cd42175bc1ec5288153b91d150291a0b625c2babd1911356e891" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}