{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from tqdm.notebook import tqdm\n",
"\n",
"from transformers import BertTokenizer\n",
"from torch.utils.data import TensorDataset\n",
"\n",
"from transformers import BertForSequenceClassification"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"df = pd.read_csv('data/title_conference.csv')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Title</th>\n",
" <th>Conference</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>Innovation in Database Management: Computer Sc...</td>\n",
" <td>VLDB</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>High performance prime field multiplication fo...</td>\n",
" <td>ISCAS</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>enchanted scissors: a scissor interface for su...</td>\n",
" <td>SIGGRAPH</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>Detection of channel degradation attack by Int...</td>\n",
" <td>INFOCOM</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>Pinning a Complex Network through the Betweenn...</td>\n",
" <td>ISCAS</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Title Conference\n",
"0 Innovation in Database Management: Computer Sc... VLDB\n",
"1 High performance prime field multiplication fo... ISCAS\n",
"2 enchanted scissors: a scissor interface for su... SIGGRAPH\n",
"3 Detection of channel degradation attack by Int... INFOCOM\n",
"4 Pinning a Complex Network through the Betweenn... ISCAS"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ISCAS 864\n",
"INFOCOM 515\n",
"VLDB 423\n",
"WWW 379\n",
"SIGGRAPH 326\n",
"Name: Conference, dtype: int64"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df['Conference'].value_counts()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'VLDB': 0, 'ISCAS': 1, 'SIGGRAPH': 2, 'INFOCOM': 3, 'WWW': 4}"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"possible_labels = df.Conference.unique()\n",
"\n",
"label_dict = {}\n",
"for index, possible_label in enumerate(possible_labels):\n",
" label_dict[possible_label] = index\n",
"label_dict"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"df['label'] = df.Conference.replace(label_dict)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Title</th>\n",
" <th>Conference</th>\n",
" <th>label</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>Innovation in Database Management: Computer Sc...</td>\n",
" <td>VLDB</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>High performance prime field multiplication fo...</td>\n",
" <td>ISCAS</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>enchanted scissors: a scissor interface for su...</td>\n",
" <td>SIGGRAPH</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>Detection of channel degradation attack by Int...</td>\n",
" <td>INFOCOM</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>Pinning a Complex Network through the Betweenn...</td>\n",
" <td>ISCAS</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Title Conference label\n",
"0 Innovation in Database Management: Computer Sc... VLDB 0\n",
"1 High performance prime field multiplication fo... ISCAS 1\n",
"2 enchanted scissors: a scissor interface for su... SIGGRAPH 2\n",
"3 Detection of channel degradation attack by Int... INFOCOM 3\n",
"4 Pinning a Complex Network through the Betweenn... ISCAS 1"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"X_train, X_val, y_train, y_val = train_test_split(df.index.values, \n",
" df.label.values, \n",
" test_size=0.15, \n",
" random_state=42, \n",
" stratify=df.label.values)"
]
},
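{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick, illustrative sanity check (not part of the original run): `stratify=df.label.values` makes both splits preserve the class proportions of the full dataset, which matters here because the conferences are imbalanced (ISCAS alone has 864 titles)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Compare per-class proportions in the two splits; with stratified\n",
"# sampling the two rows should be nearly identical.\n",
"for name, y in [('train', y_train), ('val', y_val)]:\n",
"    print(name, np.round(np.bincount(y) / len(y), 3))"
]
},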
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"df['data_type'] = ['not_set']*df.shape[0]\n",
"\n",
"df.loc[X_train, 'data_type'] = 'train'\n",
"df.loc[X_val, 'data_type'] = 'val'"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th>Title</th>\n",
" </tr>\n",
" <tr>\n",
" <th>Conference</th>\n",
" <th>label</th>\n",
" <th>data_type</th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td rowspan=\"2\" valign=\"top\">INFOCOM</td>\n",
" <td rowspan=\"2\" valign=\"top\">3</td>\n",
" <td>train</td>\n",
" <td>438</td>\n",
" </tr>\n",
" <tr>\n",
" <td>val</td>\n",
" <td>77</td>\n",
" </tr>\n",
" <tr>\n",
" <td rowspan=\"2\" valign=\"top\">ISCAS</td>\n",
" <td rowspan=\"2\" valign=\"top\">1</td>\n",
" <td>train</td>\n",
" <td>734</td>\n",
" </tr>\n",
" <tr>\n",
" <td>val</td>\n",
" <td>130</td>\n",
" </tr>\n",
" <tr>\n",
" <td rowspan=\"2\" valign=\"top\">SIGGRAPH</td>\n",
" <td rowspan=\"2\" valign=\"top\">2</td>\n",
" <td>train</td>\n",
" <td>277</td>\n",
" </tr>\n",
" <tr>\n",
" <td>val</td>\n",
" <td>49</td>\n",
" </tr>\n",
" <tr>\n",
" <td rowspan=\"2\" valign=\"top\">VLDB</td>\n",
" <td rowspan=\"2\" valign=\"top\">0</td>\n",
" <td>train</td>\n",
" <td>359</td>\n",
" </tr>\n",
" <tr>\n",
" <td>val</td>\n",
" <td>64</td>\n",
" </tr>\n",
" <tr>\n",
" <td rowspan=\"2\" valign=\"top\">WWW</td>\n",
" <td rowspan=\"2\" valign=\"top\">4</td>\n",
" <td>train</td>\n",
" <td>322</td>\n",
" </tr>\n",
" <tr>\n",
" <td>val</td>\n",
" <td>57</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Title\n",
"Conference label data_type \n",
"INFOCOM 3 train 438\n",
" val 77\n",
"ISCAS 1 train 734\n",
" val 130\n",
"SIGGRAPH 2 train 277\n",
" val 49\n",
"VLDB 0 train 359\n",
" val 64\n",
"WWW 4 train 322\n",
" val 57"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.groupby(['Conference', 'label', 'data_type']).count()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"I0730 17:03:09.511292 140649029109568 tokenization_utils_base.py:1254] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/jupyter-susan/.cache/torch/transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084\n"
]
}
],
"source": [
"tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', \n",
" do_lower_case=True)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"W0730 17:03:18.913378 140649029109568 tokenization_utils_base.py:1447] Truncation was not explicitely activated but `max_length` is provided a specific value, please use `truncation=True` to explicitely truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.\n",
"W0730 17:03:19.518625 140649029109568 tokenization_utils_base.py:1447] Truncation was not explicitely activated but `max_length` is provided a specific value, please use `truncation=True` to explicitely truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.\n"
]
}
],
"source": [
"encoded_data_train = tokenizer.batch_encode_plus(\n",
" df[df.data_type=='train'].Title.values, \n",
" add_special_tokens=True, \n",
" return_attention_mask=True, \n",
" pad_to_max_length=True, \n",
" max_length=256, \n",
" return_tensors='pt'\n",
")\n",
"\n",
"encoded_data_val = tokenizer.batch_encode_plus(\n",
" df[df.data_type=='val'].Title.values, \n",
" add_special_tokens=True, \n",
" return_attention_mask=True, \n",
" pad_to_max_length=True, \n",
" max_length=256, \n",
" return_tensors='pt'\n",
")\n",
"\n",
"\n",
"input_ids_train = encoded_data_train['input_ids']\n",
"attention_masks_train = encoded_data_train['attention_mask']\n",
"labels_train = torch.tensor(df[df.data_type=='train'].label.values)\n",
"\n",
"input_ids_val = encoded_data_val['input_ids']\n",
"attention_masks_val = encoded_data_val['attention_mask']\n",
"labels_val = torch.tensor(df[df.data_type=='val'].label.values)"
]
},
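{
"cell_type": "markdown",
"metadata": {},
"source": [
"The warnings above come from `pad_to_max_length`, which newer versions of `transformers` deprecate. An equivalent, warning-free call (a sketch, assuming transformers v3+) uses the explicit `padding` and `truncation` arguments, and likewise for the validation split:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Equivalent encoding with the non-deprecated arguments: pad every\n",
"# title to 256 tokens and truncate anything longer.\n",
"encoded_data_train = tokenizer.batch_encode_plus(\n",
"    df[df.data_type=='train'].Title.values,\n",
"    add_special_tokens=True,\n",
"    return_attention_mask=True,\n",
"    padding='max_length',\n",
"    truncation=True,\n",
"    max_length=256,\n",
"    return_tensors='pt'\n",
")"
]
},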
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"dataset_train = TensorDataset(input_ids_train, attention_masks_train, labels_train)\n",
"dataset_val = TensorDataset(input_ids_val, attention_masks_val, labels_val)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(2130, 377)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(dataset_train), len(dataset_val)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"I0730 17:03:32.775971 140649029109568 configuration_utils.py:264] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json from cache at /home/jupyter-susan/.cache/torch/transformers/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.7156163d5fdc189c3016baca0775ffce230789d7fa2a42ef516483e4ca884517\n",
"I0730 17:03:32.777010 140649029109568 configuration_utils.py:300] Model config BertConfig {\n",
" \"architectures\": [\n",
" \"BertForMaskedLM\"\n",
" ],\n",
" \"attention_probs_dropout_prob\": 0.1,\n",
" \"gradient_checkpointing\": false,\n",
" \"hidden_act\": \"gelu\",\n",
" \"hidden_dropout_prob\": 0.1,\n",
" \"hidden_size\": 768,\n",
" \"id2label\": {\n",
" \"0\": \"LABEL_0\",\n",
" \"1\": \"LABEL_1\",\n",
" \"2\": \"LABEL_2\",\n",
" \"3\": \"LABEL_3\",\n",
" \"4\": \"LABEL_4\"\n",
" },\n",
" \"initializer_range\": 0.02,\n",
" \"intermediate_size\": 3072,\n",
" \"label2id\": {\n",
" \"LABEL_0\": 0,\n",
" \"LABEL_1\": 1,\n",
" \"LABEL_2\": 2,\n",
" \"LABEL_3\": 3,\n",
" \"LABEL_4\": 4\n",
" },\n",
" \"layer_norm_eps\": 1e-12,\n",
" \"max_position_embeddings\": 512,\n",
" \"model_type\": \"bert\",\n",
" \"num_attention_heads\": 12,\n",
" \"num_hidden_layers\": 12,\n",
" \"pad_token_id\": 0,\n",
" \"type_vocab_size\": 2,\n",
" \"vocab_size\": 30522\n",
"}\n",
"\n",
"I0730 17:03:32.815931 140649029109568 modeling_utils.py:667] loading weights file https://cdn.huggingface.co/bert-base-uncased-pytorch_model.bin from cache at /home/jupyter-susan/.cache/torch/transformers/f2ee78bdd635b758cc0a12352586868bef80e47401abe4c4fcc3832421e7338b.36ca03ab34a1a5d5fa7bc3d03d55c4fa650fed07220e2eeebc06ce58d0e9a157\n",
"W0730 17:03:38.867964 140649029109568 modeling_utils.py:757] Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
"- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n",
"- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"W0730 17:03:38.868916 140649029109568 modeling_utils.py:768] Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"model = BertForSequenceClassification.from_pretrained(\"bert-base-uncased\",\n",
" num_labels=len(label_dict),\n",
" output_attentions=False,\n",
" output_hidden_states=False)"
]
},
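{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two warnings are expected: the `bert-base-uncased` checkpoint ships a masked-language-modeling head but no classifier, so `BertForSequenceClassification` discards the unused `cls.*` weights and randomly initializes a new `classifier` layer sized to `num_labels=5`. That head is what fine-tuning trains."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The freshly initialized head: a single Linear layer mapping the\n",
"# 768-dim pooled [CLS] representation to the 5 conference labels.\n",
"print(model.classifier)"
]
},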
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader, RandomSampler, SequentialSampler\n",
"\n",
"batch_size = 3\n",
"\n",
"dataloader_train = DataLoader(dataset_train, \n",
" sampler=RandomSampler(dataset_train), \n",
" batch_size=batch_size)\n",
"\n",
"dataloader_validation = DataLoader(dataset_val, \n",
" sampler=SequentialSampler(dataset_val), \n",
" batch_size=batch_size)"
]
},
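{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each batch the loaders yield is a tuple of `(input_ids, attention_mask, labels)`. A small illustrative peek (not part of the original run) confirms the shapes. Note that `batch_size=3` is tiny, suited to the CPU run below; 16 or 32 is more typical for BERT fine-tuning when a GPU is available."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Expect [3, 256], [3, 256], [3] for input ids, attention mask, labels.\n",
"sample_batch = next(iter(dataloader_train))\n",
"print([t.shape for t in sample_batch])"
]
},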
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AdamW, get_linear_schedule_with_warmup\n",
"\n",
"optimizer = AdamW(model.parameters(),\n",
" lr=1e-5, \n",
" eps=1e-8)"
]
},
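{
"cell_type": "markdown",
"metadata": {},
"source": [
"Newer versions of `transformers` deprecate `transformers.AdamW` in favor of PyTorch's built-in implementation. A drop-in alternative (a sketch; note PyTorch's version defaults to weight decay 0.01, whereas the transformers one defaults to 0.0):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Same learning rate and epsilon, using torch.optim.AdamW instead;\n",
"# left commented out so it does not replace the optimizer above.\n",
"# optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, eps=1e-8)"
]
},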
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"epochs = 5\n",
"\n",
"scheduler = get_linear_schedule_with_warmup(optimizer, \n",
" num_warmup_steps=0,\n",
" num_training_steps=len(dataloader_train)*epochs)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import f1_score\n",
"\n",
"def f1_score_func(preds, labels):\n",
" preds_flat = np.argmax(preds, axis=1).flatten()\n",
" labels_flat = labels.flatten()\n",
" return f1_score(labels_flat, preds_flat, average='weighted')\n",
"\n",
"def accuracy_per_class(preds, labels):\n",
" label_dict_inverse = {v: k for k, v in label_dict.items()}\n",
" \n",
" preds_flat = np.argmax(preds, axis=1).flatten()\n",
" labels_flat = labels.flatten()\n",
"\n",
" for label in np.unique(labels_flat):\n",
" y_preds = preds_flat[labels_flat==label]\n",
" y_true = labels_flat[labels_flat==label]\n",
" print(f'Class: {label_dict_inverse[label]}')\n",
" print(f'Accuracy: {len(y_preds[y_preds==label])}/{len(y_true)}\\n')"
]
},
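{
"cell_type": "markdown",
"metadata": {},
"source": [
"For a richer per-class view than raw hit counts, scikit-learn's `classification_report` works on the same flattened arrays. A sketch (assuming `label_dict` keeps its insertion order, which matches the 0-4 ids):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import classification_report\n",
"\n",
"def report_per_class(preds, labels):\n",
"    # precision/recall/F1 per conference, plus macro/weighted averages\n",
"    preds_flat = np.argmax(preds, axis=1).flatten()\n",
"    print(classification_report(labels.flatten(), preds_flat,\n",
"                                target_names=list(label_dict.keys())))"
]
},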
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"\n",
"seed_val = 17\n",
"random.seed(seed_val)\n",
"np.random.seed(seed_val)\n",
"torch.manual_seed(seed_val)\n",
"torch.cuda.manual_seed_all(seed_val)"
]
},
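{
"cell_type": "markdown",
"metadata": {},
"source": [
"Seeding alone does not make GPU runs fully reproducible, because cuDNN selects convolution/kernel algorithms nondeterministically. Optionally (a sketch, at some cost in speed):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Force cuDNN into deterministic mode for stricter reproducibility.\n",
"# torch.backends.cudnn.deterministic = True\n",
"# torch.backends.cudnn.benchmark = False"
]
},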
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cpu\n"
]
}
],
"source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"model.to(device)\n",
"\n",
"print(device)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"def evaluate(dataloader_val):\n",
"\n",
" model.eval()\n",
" \n",
" loss_val_total = 0\n",
" predictions, true_vals = [], []\n",
" \n",
" for batch in dataloader_val:\n",
" \n",
" batch = tuple(b.to(device) for b in batch)\n",
" \n",
" inputs = {'input_ids': batch[0],\n",
" 'attention_mask': batch[1],\n",
" 'labels': batch[2],\n",
" }\n",
"\n",
" with torch.no_grad(): \n",
" outputs = model(**inputs)\n",
" \n",
" loss = outputs[0]\n",
" logits = outputs[1]\n",
" loss_val_total += loss.item()\n",
"\n",
" logits = logits.detach().cpu().numpy()\n",
" label_ids = inputs['labels'].cpu().numpy()\n",
" predictions.append(logits)\n",
" true_vals.append(label_ids)\n",
" \n",
" loss_val_avg = loss_val_total/len(dataloader_val) \n",
" \n",
" predictions = np.concatenate(predictions, axis=0)\n",
" true_vals = np.concatenate(true_vals, axis=0)\n",
" \n",
" return loss_val_avg, predictions, true_vals"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e64800612d024595906f0cd63e6bbf71",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='Epoch 1', max=710.0, style=ProgressStyle(description_widt…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 1\n",
"Training loss: 0.9007002512753849\n",
"Validation loss: 0.6143069127574563\n",
"F1 Score (Weighted): 0.7791319217695921\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='Epoch 2', max=710.0, style=ProgressStyle(description_widt…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 2\n",
"Training loss: 0.5381144283001613\n",
"Validation loss: 0.6438471145765294\n",
"F1 Score (Weighted): 0.8207824902152685\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='Epoch 3', max=710.0, style=ProgressStyle(description_widt…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 3\n",
"Training loss: 0.35893184876292417\n",
"Validation loss: 0.723008230609435\n",
"F1 Score (Weighted): 0.8463474188661483\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='Epoch 4', max=710.0, style=ProgressStyle(description_widt…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 4\n",
"Training loss: 0.2692523200199349\n",
"Validation loss: 0.7796335518272365\n",
"F1 Score (Weighted): 0.8341132163207956\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='Epoch 5', max=710.0, style=ProgressStyle(description_widt…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 5\n",
"Training loss: 0.18156354463565766\n",
"Validation loss: 0.8108082735081321\n",
"F1 Score (Weighted): 0.8441012614273822\n",
"\n"
]
}
],
"source": [
"for epoch in tqdm(range(1, epochs+1)):\n",
" \n",
" model.train()\n",
" \n",
" loss_train_total = 0\n",
"\n",
" progress_bar = tqdm(dataloader_train, desc='Epoch {:1d}'.format(epoch), leave=False, disable=False)\n",
" for batch in progress_bar:\n",
"\n",
" model.zero_grad()\n",
" \n",
" batch = tuple(b.to(device) for b in batch)\n",
" \n",
" inputs = {'input_ids': batch[0],\n",
" 'attention_mask': batch[1],\n",
" 'labels': batch[2],\n",
" } \n",
"\n",
" outputs = model(**inputs)\n",
" \n",
" loss = outputs[0]\n",
" loss_train_total += loss.item()\n",
" loss.backward()\n",
"\n",
" torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
"\n",
" optimizer.step()\n",
" scheduler.step()\n",
" \n",
" progress_bar.set_postfix({'training_loss': '{:.3f}'.format(loss.item()/len(batch))})\n",
" \n",
" \n",
" torch.save(model.state_dict(), f'data_volume/finetuned_BERT_epoch_{epoch}.model')\n",
" \n",
" tqdm.write(f'\\nEpoch {epoch}')\n",
" \n",
" loss_train_avg = loss_train_total/len(dataloader_train) \n",
" tqdm.write(f'Training loss: {loss_train_avg}')\n",
" \n",
" val_loss, predictions, true_vals = evaluate(dataloader_validation)\n",
" val_f1 = f1_score_func(predictions, true_vals)\n",
" tqdm.write(f'Validation loss: {val_loss}')\n",
" tqdm.write(f'F1 Score (Weighted): {val_f1}')"
]
},
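{
"cell_type": "markdown",
"metadata": {},
"source": [
"Validation loss rises after epoch 1 while training loss keeps falling, a classic overfitting pattern on a dataset this small; weighted F1 nevertheless peaks at epoch 3 (0.846). Since a checkpoint is saved every epoch, any of them can be restored later. A sketch for picking one by its validation F1:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Restore whichever epoch scored best on the validation set.\n",
"# best_epoch = 3\n",
"# model.load_state_dict(torch.load(\n",
"#     f'data_volume/finetuned_BERT_epoch_{best_epoch}.model',\n",
"#     map_location=device))"
]
},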
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"I0730 21:12:18.774542 140649029109568 configuration_utils.py:264] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json from cache at /home/jupyter-susan/.cache/torch/transformers/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.7156163d5fdc189c3016baca0775ffce230789d7fa2a42ef516483e4ca884517\n",
"I0730 21:12:18.775610 140649029109568 configuration_utils.py:300] Model config BertConfig {\n",
" \"architectures\": [\n",
" \"BertForMaskedLM\"\n",
" ],\n",
" \"attention_probs_dropout_prob\": 0.1,\n",
" \"gradient_checkpointing\": false,\n",
" \"hidden_act\": \"gelu\",\n",
" \"hidden_dropout_prob\": 0.1,\n",
" \"hidden_size\": 768,\n",
" \"id2label\": {\n",
" \"0\": \"LABEL_0\",\n",
" \"1\": \"LABEL_1\",\n",
" \"2\": \"LABEL_2\",\n",
" \"3\": \"LABEL_3\",\n",
" \"4\": \"LABEL_4\"\n",
" },\n",
" \"initializer_range\": 0.02,\n",
" \"intermediate_size\": 3072,\n",
" \"label2id\": {\n",
" \"LABEL_0\": 0,\n",
" \"LABEL_1\": 1,\n",
" \"LABEL_2\": 2,\n",
" \"LABEL_3\": 3,\n",
" \"LABEL_4\": 4\n",
" },\n",
" \"layer_norm_eps\": 1e-12,\n",
" \"max_position_embeddings\": 512,\n",
" \"model_type\": \"bert\",\n",
" \"num_attention_heads\": 12,\n",
" \"num_hidden_layers\": 12,\n",
" \"pad_token_id\": 0,\n",
" \"type_vocab_size\": 2,\n",
" \"vocab_size\": 30522\n",
"}\n",
"\n",
"I0730 21:12:18.891486 140649029109568 modeling_utils.py:667] loading weights file https://cdn.huggingface.co/bert-base-uncased-pytorch_model.bin from cache at /home/jupyter-susan/.cache/torch/transformers/f2ee78bdd635b758cc0a12352586868bef80e47401abe4c4fcc3832421e7338b.36ca03ab34a1a5d5fa7bc3d03d55c4fa650fed07220e2eeebc06ce58d0e9a157\n",
"W0730 21:12:24.786196 140649029109568 modeling_utils.py:757] Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
"- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n",
"- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"W0730 21:12:24.787092 140649029109568 modeling_utils.py:768] Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
},
{
"data": {
"text/plain": [
"BertForSequenceClassification(\n",
" (bert): BertModel(\n",
" (embeddings): BertEmbeddings(\n",
" (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
" (position_embeddings): Embedding(512, 768)\n",
" (token_type_embeddings): Embedding(2, 768)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (encoder): BertEncoder(\n",
" (layer): ModuleList(\n",
" (0): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (1): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (2): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (3): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (4): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (5): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (6): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (7): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (8): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (9): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (10): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (11): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (pooler): BertPooler(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (activation): Tanh()\n",
" )\n",
" )\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (classifier): Linear(in_features=768, out_features=5, bias=True)\n",
")"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = BertForSequenceClassification.from_pretrained(\"bert-base-uncased\",\n",
" num_labels=len(label_dict),\n",
" output_attentions=False,\n",
" output_hidden_states=False)\n",
"\n",
"model.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.load_state_dict(torch.load('data_volume/finetuned_BERT_epoch_1.model', map_location=torch.device('cpu')))"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"_, predictions, true_vals = evaluate(dataloader_validation)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Class: VLDB\n",
"Accuracy: 45/64\n",
"\n",
"Class: ISCAS\n",
"Accuracy: 124/130\n",
"\n",
"Class: SIGGRAPH\n",
"Accuracy: 29/49\n",
"\n",
"Class: INFOCOM\n",
"Accuracy: 65/77\n",
"\n",
"Class: WWW\n",
"Accuracy: 33/57\n",
"\n"
]
}
],
"source": [
"accuracy_per_class(predictions, true_vals)"
]
},
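{
"cell_type": "markdown",
"metadata": {},
"source": [
"ISCAS, with the most training data, is classified almost perfectly (124/130), while SIGGRAPH and WWW lag behind. Finally, an inference sketch on a new, made-up title (the string below is an assumption, not from the dataset): encode it the same way the training data was encoded and take the argmax over the logits."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"inv_label_dict = {v: k for k, v in label_dict.items()}\n",
"\n",
"text = 'Scalable query optimization for distributed database systems'\n",
"encoded = tokenizer.encode_plus(text,\n",
"                                add_special_tokens=True,\n",
"                                return_attention_mask=True,\n",
"                                truncation=True,\n",
"                                max_length=256,\n",
"                                return_tensors='pt')\n",
"\n",
"with torch.no_grad():\n",
"    # without labels, the first element of the output tuple is the logits\n",
"    logits = model(encoded['input_ids'].to(device),\n",
"                   attention_mask=encoded['attention_mask'].to(device))[0]\n",
"\n",
"print(inv_label_dict[int(logits.argmax(dim=1))])"
]
},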
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
