{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "from transformers import BertTokenizer\n",
    "from torch.utils.data import TensorDataset\n",
    "\n",
    "from transformers import BertForSequenceClassification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_csv('data/title_conference.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Title</th>\n",
       "      <th>Conference</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>Innovation in Database Management: Computer Sc...</td>\n",
       "      <td>VLDB</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>High performance prime field multiplication fo...</td>\n",
       "      <td>ISCAS</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>enchanted scissors: a scissor interface for su...</td>\n",
       "      <td>SIGGRAPH</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>Detection of channel degradation attack by Int...</td>\n",
       "      <td>INFOCOM</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>Pinning a Complex Network through the Betweenn...</td>\n",
       "      <td>ISCAS</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                               Title Conference\n",
       "0  Innovation in Database Management: Computer Sc...       VLDB\n",
       "1  High performance prime field multiplication fo...      ISCAS\n",
       "2  enchanted scissors: a scissor interface for su...   SIGGRAPH\n",
       "3  Detection of channel degradation attack by Int...    INFOCOM\n",
       "4  Pinning a Complex Network through the Betweenn...      ISCAS"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ISCAS       864\n",
       "INFOCOM     515\n",
       "VLDB        423\n",
       "WWW         379\n",
       "SIGGRAPH    326\n",
       "Name: Conference, dtype: int64"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df['Conference'].value_counts()"
   ]
  },
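  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The five conferences are imbalanced: 2,507 titles in total, ranging from 864 for ISCAS down to 326 for SIGGRAPH. That is why the train/validation split below stratifies on the label and why evaluation reports a weighted F1 score."
   ]
  },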
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'VLDB': 0, 'ISCAS': 1, 'SIGGRAPH': 2, 'INFOCOM': 3, 'WWW': 4}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "possible_labels = df.Conference.unique()\n",
    "\n",
    "label_dict = {}\n",
    "for index, possible_label in enumerate(possible_labels):\n",
    "    label_dict[possible_label] = index\n",
    "label_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "df['label'] = df.Conference.replace(label_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Title</th>\n",
       "      <th>Conference</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>Innovation in Database Management: Computer Sc...</td>\n",
       "      <td>VLDB</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>High performance prime field multiplication fo...</td>\n",
       "      <td>ISCAS</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>enchanted scissors: a scissor interface for su...</td>\n",
       "      <td>SIGGRAPH</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>Detection of channel degradation attack by Int...</td>\n",
       "      <td>INFOCOM</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>Pinning a Complex Network through the Betweenn...</td>\n",
       "      <td>ISCAS</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                               Title Conference  label\n",
       "0  Innovation in Database Management: Computer Sc...       VLDB      0\n",
       "1  High performance prime field multiplication fo...      ISCAS      1\n",
       "2  enchanted scissors: a scissor interface for su...   SIGGRAPH      2\n",
       "3  Detection of channel degradation attack by Int...    INFOCOM      3\n",
       "4  Pinning a Complex Network through the Betweenn...      ISCAS      1"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "X_train, X_val, y_train, y_val = train_test_split(df.index.values,\n",
    "                                                  df.label.values,\n",
    "                                                  test_size=0.15,\n",
    "                                                  random_state=42,\n",
    "                                                  stratify=df.label.values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "df['data_type'] = ['not_set']*df.shape[0]\n",
    "\n",
    "df.loc[X_train, 'data_type'] = 'train'\n",
    "df.loc[X_val, 'data_type'] = 'val'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>Title</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Conference</th>\n",
       "      <th>label</th>\n",
       "      <th>data_type</th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td rowspan=\"2\" valign=\"top\">INFOCOM</td>\n",
       "      <td rowspan=\"2\" valign=\"top\">3</td>\n",
       "      <td>train</td>\n",
       "      <td>438</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>val</td>\n",
       "      <td>77</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td rowspan=\"2\" valign=\"top\">ISCAS</td>\n",
       "      <td rowspan=\"2\" valign=\"top\">1</td>\n",
       "      <td>train</td>\n",
       "      <td>734</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>val</td>\n",
       "      <td>130</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td rowspan=\"2\" valign=\"top\">SIGGRAPH</td>\n",
       "      <td rowspan=\"2\" valign=\"top\">2</td>\n",
       "      <td>train</td>\n",
       "      <td>277</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>val</td>\n",
       "      <td>49</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td rowspan=\"2\" valign=\"top\">VLDB</td>\n",
       "      <td rowspan=\"2\" valign=\"top\">0</td>\n",
       "      <td>train</td>\n",
       "      <td>359</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>val</td>\n",
       "      <td>64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td rowspan=\"2\" valign=\"top\">WWW</td>\n",
       "      <td rowspan=\"2\" valign=\"top\">4</td>\n",
       "      <td>train</td>\n",
       "      <td>322</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>val</td>\n",
       "      <td>57</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                            Title\n",
       "Conference label data_type       \n",
       "INFOCOM    3     train        438\n",
       "                 val           77\n",
       "ISCAS      1     train        734\n",
       "                 val          130\n",
       "SIGGRAPH   2     train        277\n",
       "                 val           49\n",
       "VLDB       0     train        359\n",
       "                 val           64\n",
       "WWW        4     train        322\n",
       "                 val           57"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.groupby(['Conference', 'label', 'data_type']).count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I0730 17:03:09.511292 140649029109568 tokenization_utils_base.py:1254] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/jupyter-susan/.cache/torch/transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084\n"
     ]
    }
   ],
   "source": [
    "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',\n",
    "                                          do_lower_case=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W0730 17:03:18.913378 140649029109568 tokenization_utils_base.py:1447] Truncation was not explicitely activated but `max_length` is provided a specific value, please use `truncation=True` to explicitely truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.\n",
      "W0730 17:03:19.518625 140649029109568 tokenization_utils_base.py:1447] Truncation was not explicitely activated but `max_length` is provided a specific value, please use `truncation=True` to explicitely truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.\n"
     ]
    }
   ],
   "source": [
    "encoded_data_train = tokenizer.batch_encode_plus(\n",
    "    df[df.data_type=='train'].Title.values,\n",
    "    add_special_tokens=True,\n",
    "    return_attention_mask=True,\n",
    "    pad_to_max_length=True,\n",
    "    max_length=256,\n",
    "    return_tensors='pt'\n",
    ")\n",
    "\n",
    "encoded_data_val = tokenizer.batch_encode_plus(\n",
    "    df[df.data_type=='val'].Title.values,\n",
    "    add_special_tokens=True,\n",
    "    return_attention_mask=True,\n",
    "    pad_to_max_length=True,\n",
    "    max_length=256,\n",
    "    return_tensors='pt'\n",
    ")\n",
    "\n",
    "\n",
    "input_ids_train = encoded_data_train['input_ids']\n",
    "attention_masks_train = encoded_data_train['attention_mask']\n",
    "labels_train = torch.tensor(df[df.data_type=='train'].label.values)\n",
    "\n",
    "input_ids_val = encoded_data_val['input_ids']\n",
    "attention_masks_val = encoded_data_val['attention_mask']\n",
    "labels_val = torch.tensor(df[df.data_type=='val'].label.values)"
   ]
  },
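  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check, added here as a sketch (it was not part of the original run): decoding the first encoded training example back to text confirms that the `[CLS]`/`[SEP]` special tokens and the padding to `max_length=256` were applied as expected."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sketch: decode one encoded example to verify special tokens and padding.\n",
    "print(tokenizer.decode(input_ids_train[0].tolist()))\n",
    "print(int(attention_masks_train[0].sum()), 'non-padding tokens')"
   ]
  },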
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_train = TensorDataset(input_ids_train, attention_masks_train, labels_train)\n",
    "dataset_val = TensorDataset(input_ids_val, attention_masks_val, labels_val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2130, 377)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(dataset_train), len(dataset_val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I0730 17:03:32.775971 140649029109568 configuration_utils.py:264] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json from cache at /home/jupyter-susan/.cache/torch/transformers/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.7156163d5fdc189c3016baca0775ffce230789d7fa2a42ef516483e4ca884517\n",
      "I0730 17:03:32.777010 140649029109568 configuration_utils.py:300] Model config BertConfig {\n",
      "  \"architectures\": [\n",
      "    \"BertForMaskedLM\"\n",
      "  ],\n",
      "  \"attention_probs_dropout_prob\": 0.1,\n",
      "  \"gradient_checkpointing\": false,\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.1,\n",
      "  \"hidden_size\": 768,\n",
      "  \"id2label\": {\n",
      "    \"0\": \"LABEL_0\",\n",
      "    \"1\": \"LABEL_1\",\n",
      "    \"2\": \"LABEL_2\",\n",
      "    \"3\": \"LABEL_3\",\n",
      "    \"4\": \"LABEL_4\"\n",
      "  },\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 3072,\n",
      "  \"label2id\": {\n",
      "    \"LABEL_0\": 0,\n",
      "    \"LABEL_1\": 1,\n",
      "    \"LABEL_2\": 2,\n",
      "    \"LABEL_3\": 3,\n",
      "    \"LABEL_4\": 4\n",
      "  },\n",
      "  \"layer_norm_eps\": 1e-12,\n",
      "  \"max_position_embeddings\": 512,\n",
      "  \"model_type\": \"bert\",\n",
      "  \"num_attention_heads\": 12,\n",
      "  \"num_hidden_layers\": 12,\n",
      "  \"pad_token_id\": 0,\n",
      "  \"type_vocab_size\": 2,\n",
      "  \"vocab_size\": 30522\n",
      "}\n",
      "\n",
      "I0730 17:03:32.815931 140649029109568 modeling_utils.py:667] loading weights file https://cdn.huggingface.co/bert-base-uncased-pytorch_model.bin from cache at /home/jupyter-susan/.cache/torch/transformers/f2ee78bdd635b758cc0a12352586868bef80e47401abe4c4fcc3832421e7338b.36ca03ab34a1a5d5fa7bc3d03d55c4fa650fed07220e2eeebc06ce58d0e9a157\n",
      "W0730 17:03:38.867964 140649029109568 modeling_utils.py:757] Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "W0730 17:03:38.868916 140649029109568 modeling_utils.py:768] Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "model = BertForSequenceClassification.from_pretrained(\"bert-base-uncased\",\n",
    "                                                      num_labels=len(label_dict),\n",
    "                                                      output_attentions=False,\n",
    "                                                      output_hidden_states=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader, RandomSampler, SequentialSampler\n",
    "\n",
    "batch_size = 3\n",
    "\n",
    "dataloader_train = DataLoader(dataset_train,\n",
    "                              sampler=RandomSampler(dataset_train),\n",
    "                              batch_size=batch_size)\n",
    "\n",
    "dataloader_validation = DataLoader(dataset_val,\n",
    "                                   sampler=SequentialSampler(dataset_val),\n",
    "                                   batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AdamW, get_linear_schedule_with_warmup\n",
    "\n",
    "optimizer = AdamW(model.parameters(),\n",
    "                  lr=1e-5,\n",
    "                  eps=1e-8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs = 5\n",
    "\n",
    "scheduler = get_linear_schedule_with_warmup(optimizer,\n",
    "                                            num_warmup_steps=0,\n",
    "                                            num_training_steps=len(dataloader_train)*epochs)"
   ]
  },
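  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With 2,130 training examples and a batch size of 3, each epoch runs 2130 / 3 = 710 optimisation steps, so the scheduler is configured for 710 × 5 = 3,550 steps in total (this matches the `max=710` progress bars in the training output below)."
   ]
  },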
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import f1_score\n",
    "\n",
    "def f1_score_func(preds, labels):\n",
    "    preds_flat = np.argmax(preds, axis=1).flatten()\n",
    "    labels_flat = labels.flatten()\n",
    "    return f1_score(labels_flat, preds_flat, average='weighted')\n",
    "\n",
    "def accuracy_per_class(preds, labels):\n",
    "    label_dict_inverse = {v: k for k, v in label_dict.items()}\n",
    "\n",
    "    preds_flat = np.argmax(preds, axis=1).flatten()\n",
    "    labels_flat = labels.flatten()\n",
    "\n",
    "    for label in np.unique(labels_flat):\n",
    "        y_preds = preds_flat[labels_flat==label]\n",
    "        y_true = labels_flat[labels_flat==label]\n",
    "        print(f'Class: {label_dict_inverse[label]}')\n",
    "        print(f'Accuracy: {len(y_preds[y_preds==label])}/{len(y_true)}\\n')"
   ]
  },
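  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A tiny illustration of the metric helpers, added as a sketch with made-up numbers (not part of the original run): `f1_score_func` takes raw logits and integer labels, argmaxes the logits per row, and returns a weighted F1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sketch with toy data, not part of the original run.\n",
    "toy_preds = np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])  # fake logits for 3 samples\n",
    "toy_labels = np.array([1, 0, 0])                            # third sample is misclassified\n",
    "f1_score_func(toy_preds, toy_labels)                        # weighted F1 below 1.0"
   ]
  },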
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "seed_val = 17\n",
    "random.seed(seed_val)\n",
    "np.random.seed(seed_val)\n",
    "torch.manual_seed(seed_val)\n",
    "torch.cuda.manual_seed_all(seed_val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cpu\n"
     ]
    }
   ],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "model.to(device)\n",
    "\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(dataloader_val):\n",
    "\n",
    "    model.eval()\n",
    "\n",
    "    loss_val_total = 0\n",
    "    predictions, true_vals = [], []\n",
    "\n",
    "    for batch in dataloader_val:\n",
    "\n",
    "        batch = tuple(b.to(device) for b in batch)\n",
    "\n",
    "        inputs = {'input_ids': batch[0],\n",
    "                  'attention_mask': batch[1],\n",
    "                  'labels': batch[2],\n",
    "                 }\n",
    "\n",
    "        with torch.no_grad():\n",
    "            outputs = model(**inputs)\n",
    "\n",
    "        loss = outputs[0]\n",
    "        logits = outputs[1]\n",
    "        loss_val_total += loss.item()\n",
    "\n",
    "        logits = logits.detach().cpu().numpy()\n",
    "        label_ids = inputs['labels'].cpu().numpy()\n",
    "        predictions.append(logits)\n",
    "        true_vals.append(label_ids)\n",
    "\n",
    "    loss_val_avg = loss_val_total/len(dataloader_val)\n",
    "\n",
    "    predictions = np.concatenate(predictions, axis=0)\n",
    "    true_vals = np.concatenate(true_vals, axis=0)\n",
    "\n",
    "    return loss_val_avg, predictions, true_vals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e64800612d024595906f0cd63e6bbf71",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Epoch 1', max=710.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 1\n",
      "Training loss: 0.9007002512753849\n",
      "Validation loss: 0.6143069127574563\n",
      "F1 Score (Weighted): 0.7791319217695921\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Epoch 2', max=710.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 2\n",
      "Training loss: 0.5381144283001613\n",
      "Validation loss: 0.6438471145765294\n",
      "F1 Score (Weighted): 0.8207824902152685\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Epoch 3', max=710.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 3\n",
      "Training loss: 0.35893184876292417\n",
      "Validation loss: 0.723008230609435\n",
      "F1 Score (Weighted): 0.8463474188661483\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Epoch 4', max=710.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 4\n",
      "Training loss: 0.2692523200199349\n",
      "Validation loss: 0.7796335518272365\n",
      "F1 Score (Weighted): 0.8341132163207956\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Epoch 5', max=710.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 5\n",
      "Training loss: 0.18156354463565766\n",
      "Validation loss: 0.8108082735081321\n",
      "F1 Score (Weighted): 0.8441012614273822\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for epoch in tqdm(range(1, epochs+1)):\n",
    "\n",
    "    model.train()\n",
    "\n",
    "    loss_train_total = 0\n",
    "\n",
    "    progress_bar = tqdm(dataloader_train, desc='Epoch {:1d}'.format(epoch), leave=False, disable=False)\n",
    "    for batch in progress_bar:\n",
    "\n",
    "        model.zero_grad()\n",
    "\n",
    "        batch = tuple(b.to(device) for b in batch)\n",
    "\n",
    "        inputs = {'input_ids': batch[0],\n",
    "                  'attention_mask': batch[1],\n",
    "                  'labels': batch[2],\n",
    "                 }\n",
    "\n",
    "        outputs = model(**inputs)\n",
    "\n",
    "        loss = outputs[0]\n",
    "        loss_train_total += loss.item()\n",
    "        loss.backward()\n",
    "\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
    "\n",
    "        optimizer.step()\n",
    "        scheduler.step()\n",
    "\n",
    "        # loss is already the mean over the batch, so log it directly\n",
    "        progress_bar.set_postfix({'training_loss': '{:.3f}'.format(loss.item())})\n",
    "\n",
    "    torch.save(model.state_dict(), f'data_volume/finetuned_BERT_epoch_{epoch}.model')\n",
    "\n",
    "    tqdm.write(f'\\nEpoch {epoch}')\n",
    "\n",
    "    loss_train_avg = loss_train_total/len(dataloader_train)\n",
    "    tqdm.write(f'Training loss: {loss_train_avg}')\n",
    "\n",
    "    val_loss, predictions, true_vals = evaluate(dataloader_validation)\n",
    "    val_f1 = f1_score_func(predictions, true_vals)\n",
    "    tqdm.write(f'Validation loss: {val_loss}')\n",
    "    tqdm.write(f'F1 Score (Weighted): {val_f1}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I0730 21:12:18.774542 140649029109568 configuration_utils.py:264] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json from cache at /home/jupyter-susan/.cache/torch/transformers/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.7156163d5fdc189c3016baca0775ffce230789d7fa2a42ef516483e4ca884517\n",
      "I0730 21:12:18.775610 140649029109568 configuration_utils.py:300] Model config BertConfig {\n",
      "  \"architectures\": [\n",
      "    \"BertForMaskedLM\"\n",
      "  ],\n",
      "  \"attention_probs_dropout_prob\": 0.1,\n",
      "  \"gradient_checkpointing\": false,\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.1,\n",
      "  \"hidden_size\": 768,\n",
      "  \"id2label\": {\n",
      "    \"0\": \"LABEL_0\",\n",
      "    \"1\": \"LABEL_1\",\n",
      "    \"2\": \"LABEL_2\",\n",
      "    \"3\": \"LABEL_3\",\n",
      "    \"4\": \"LABEL_4\"\n",
      "  },\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 3072,\n",
      "  \"label2id\": {\n",
      "    \"LABEL_0\": 0,\n",
      "    \"LABEL_1\": 1,\n",
      "    \"LABEL_2\": 2,\n",
      "    \"LABEL_3\": 3,\n",
      "    \"LABEL_4\": 4\n",
      "  },\n",
      "  \"layer_norm_eps\": 1e-12,\n",
      "  \"max_position_embeddings\": 512,\n",
      "  \"model_type\": \"bert\",\n",
      "  \"num_attention_heads\": 12,\n",
      "  \"num_hidden_layers\": 12,\n",
      "  \"pad_token_id\": 0,\n",
      "  \"type_vocab_size\": 2,\n",
      "  \"vocab_size\": 30522\n",
      "}\n",
      "\n",
      "I0730 21:12:18.891486 140649029109568 modeling_utils.py:667] loading weights file https://cdn.huggingface.co/bert-base-uncased-pytorch_model.bin from cache at /home/jupyter-susan/.cache/torch/transformers/f2ee78bdd635b758cc0a12352586868bef80e47401abe4c4fcc3832421e7338b.36ca03ab34a1a5d5fa7bc3d03d55c4fa650fed07220e2eeebc06ce58d0e9a157\n",
      "W0730 21:12:24.786196 140649029109568 modeling_utils.py:757] Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "W0730 21:12:24.787092 140649029109568 modeling_utils.py:768] Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "BertForSequenceClassification(\n",
       "  (bert): BertModel(\n",
       "    (embeddings): BertEmbeddings(\n",
       "      (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
       "      (position_embeddings): Embedding(512, 768)\n",
       "      (token_type_embeddings): Embedding(2, 768)\n",
       "      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "      (dropout): Dropout(p=0.1, inplace=False)\n",
       "    )\n",
       "    (encoder): BertEncoder(\n",
       "      (layer): ModuleList(\n",
       "        (0): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (1): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (2): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (3): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (4): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (5): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (6): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (7): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (8): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (9): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (10): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (11): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (pooler): BertPooler(\n",
       "      (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "      (activation): Tanh()\n",
       "    )\n",
       "  )\n",
       "  (dropout): Dropout(p=0.1, inplace=False)\n",
       "  (classifier): Linear(in_features=768, out_features=5, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = BertForSequenceClassification.from_pretrained(\"bert-base-uncased\",\n",
    "                                                      num_labels=len(label_dict),\n",
    "                                                      output_attentions=False,\n",
    "                                                      output_hidden_states=False)\n",
    "\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.load_state_dict(torch.load('data_volume/finetuned_BERT_epoch_1.model', map_location=torch.device('cpu')))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "_, predictions, true_vals = evaluate(dataloader_validation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Class: VLDB\n",
      "Accuracy: 45/64\n",
      "\n",
      "Class: ISCAS\n",
      "Accuracy: 124/130\n",
      "\n",
      "Class: SIGGRAPH\n",
      "Accuracy: 29/49\n",
      "\n",
      "Class: INFOCOM\n",
      "Accuracy: 65/77\n",
      "\n",
      "Class: WWW\n",
      "Accuracy: 33/57\n",
      "\n"
     ]
    }
   ],
   "source": [
    "accuracy_per_class(predictions, true_vals)"
   ]
  },
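  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, a sketch of single-title inference, added here for illustration (the title string below is made up, not from the dataset): encode one title, run it through the fine-tuned model, and map the argmax logit back to a conference name via the inverse of `label_dict`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sketch: predict the conference for one made-up title.\n",
    "inv_label_dict = {v: k for k, v in label_dict.items()}\n",
    "\n",
    "encoded = tokenizer.encode_plus('a scalable query optimizer for distributed databases',\n",
    "                                add_special_tokens=True,\n",
    "                                return_tensors='pt')\n",
    "\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    logits = model(**{k: v.to(device) for k, v in encoded.items()})[0]\n",
    "\n",
    "inv_label_dict[logits.argmax(dim=1).item()]"
   ]
  },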
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}