diff --git a/Text_Classification_With_BERT.ipynb b/Text_Classification_With_BERT.ipynb
new file mode 100644
index 0000000..3de2ab9
--- /dev/null
+++ b/Text_Classification_With_BERT.ipynb
@@ -0,0 +1,1343 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "\n",
+    "import torch\n",
+    "from tqdm.notebook import tqdm\n",
+    "\n",
+    "from transformers import BertTokenizer\n",
+    "from torch.utils.data import TensorDataset\n",
+    "\n",
+    "from transformers import BertForSequenceClassification"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "df = pd.read_csv('data/title_conference.csv')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "                                               Title Conference\n",
+       "0  Innovation in Database Management: Computer Sc...       VLDB\n",
+       "1  High performance prime field multiplication fo...      ISCAS\n",
+       "2  enchanted scissors: a scissor interface for su...   SIGGRAPH\n",
+       "3  Detection of channel degradation attack by Int...    INFOCOM\n",
+       "4  Pinning a Complex Network through the Betweenn...      ISCAS"
+      ]
+     },
+     "execution_count": 3,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "df.head()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "ISCAS       864\n",
+       "INFOCOM     515\n",
+       "VLDB        423\n",
+       "WWW         379\n",
+       "SIGGRAPH    326\n",
+       "Name: Conference, dtype: int64"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "df['Conference'].value_counts()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'VLDB': 0, 'ISCAS': 1, 'SIGGRAPH': 2, 'INFOCOM': 3, 'WWW': 4}"
+      ]
+     },
+     "execution_count": 5,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "possible_labels = df.Conference.unique()\n",
+    "\n",
+    "label_dict = {}\n",
+    "for index, possible_label in enumerate(possible_labels):\n",
+    "    label_dict[possible_label] = index\n",
+    "label_dict"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "df['label'] = df.Conference.replace(label_dict)"
+   ]
+  },
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
TitleConferencelabel
0Innovation in Database Management: Computer Sc...VLDB0
1High performance prime field multiplication fo...ISCAS1
2enchanted scissors: a scissor interface for su...SIGGRAPH2
3Detection of channel degradation attack by Int...INFOCOM3
4Pinning a Complex Network through the Betweenn...ISCAS1
\n", + "
" + ], + "text/plain": [ + " Title Conference label\n", + "0 Innovation in Database Management: Computer Sc... VLDB 0\n", + "1 High performance prime field multiplication fo... ISCAS 1\n", + "2 enchanted scissors: a scissor interface for su... SIGGRAPH 2\n", + "3 Detection of channel degradation attack by Int... INFOCOM 3\n", + "4 Pinning a Complex Network through the Betweenn... ISCAS 1" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.model_selection import train_test_split\n", + "\n", + "X_train, X_val, y_train, y_val = train_test_split(df.index.values, \n", + " df.label.values, \n", + " test_size=0.15, \n", + " random_state=42, \n", + " stratify=df.label.values)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "df['data_type'] = ['not_set']*df.shape[0]\n", + "\n", + "df.loc[X_train, 'data_type'] = 'train'\n", + "df.loc[X_val, 'data_type'] = 'val'" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Title
Conferencelabeldata_type
INFOCOM3train438
val77
ISCAS1train734
val130
SIGGRAPH2train277
val49
VLDB0train359
val64
WWW4train322
val57
\n", + "
" + ], + "text/plain": [ + " Title\n", + "Conference label data_type \n", + "INFOCOM 3 train 438\n", + " val 77\n", + "ISCAS 1 train 734\n", + " val 130\n", + "SIGGRAPH 2 train 277\n", + " val 49\n", + "VLDB 0 train 359\n", + " val 64\n", + "WWW 4 train 322\n", + " val 57" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.groupby(['Conference', 'label', 'data_type']).count()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "I0730 17:03:09.511292 140649029109568 tokenization_utils_base.py:1254] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/jupyter-susan/.cache/torch/transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084\n" + ] + } + ], + "source": [ + "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', \n", + " do_lower_case=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "W0730 17:03:18.913378 140649029109568 tokenization_utils_base.py:1447] Truncation was not explicitely activated but `max_length` is provided a specific value, please use `truncation=True` to explicitely truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.\n", + "W0730 17:03:19.518625 140649029109568 tokenization_utils_base.py:1447] Truncation was not explicitely activated but `max_length` is provided a specific value, please use `truncation=True` to explicitely truncate examples to max length. Defaulting to 'longest_first' truncation strategy. 
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "W0730 17:03:18.913378 140649029109568 tokenization_utils_base.py:1447] Truncation was not explicitely activated but `max_length` is provided a specific value, please use `truncation=True` to explicitely truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.\n",
+      "W0730 17:03:19.518625 140649029109568 tokenization_utils_base.py:1447] Truncation was not explicitely activated but `max_length` is provided a specific value, please use `truncation=True` to explicitely truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.\n"
+     ]
+    }
+   ],
+   "source": [
+    "# pad_to_max_length is deprecated in newer transformers releases;\n",
+    "# truncation=True with padding='max_length' is the modern equivalent\n",
+    "# (hence the truncation warnings recorded above).\n",
+    "encoded_data_train = tokenizer.batch_encode_plus(\n",
+    "    df[df.data_type=='train'].Title.values, \n",
+    "    add_special_tokens=True, \n",
+    "    return_attention_mask=True, \n",
+    "    pad_to_max_length=True, \n",
+    "    max_length=256, \n",
+    "    return_tensors='pt'\n",
+    ")\n",
+    "\n",
+    "encoded_data_val = tokenizer.batch_encode_plus(\n",
+    "    df[df.data_type=='val'].Title.values, \n",
+    "    add_special_tokens=True, \n",
+    "    return_attention_mask=True, \n",
+    "    pad_to_max_length=True, \n",
+    "    max_length=256, \n",
+    "    return_tensors='pt'\n",
+    ")\n",
+    "\n",
+    "\n",
+    "input_ids_train = encoded_data_train['input_ids']\n",
+    "attention_masks_train = encoded_data_train['attention_mask']\n",
+    "labels_train = torch.tensor(df[df.data_type=='train'].label.values)\n",
+    "\n",
+    "input_ids_val = encoded_data_val['input_ids']\n",
+    "attention_masks_val = encoded_data_val['attention_mask']\n",
+    "labels_val = torch.tensor(df[df.data_type=='val'].label.values)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset_train = TensorDataset(input_ids_train, attention_masks_train, labels_train)\n",
+    "dataset_val = TensorDataset(input_ids_val, attention_masks_val, labels_val)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(2130, 377)"
+      ]
+     },
+     "execution_count": 15,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "len(dataset_train), len(dataset_val)"
+   ]
+  },
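+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sanity check (not run here): decode the first training example back to\n",
+    "# text. [CLS]/[SEP] come from add_special_tokens and the [PAD] tail from\n",
+    "# padding to max_length=256; the attention-mask sum counts the\n",
+    "# non-padding positions.\n",
+    "print(tokenizer.decode(input_ids_train[0].tolist()))\n",
+    "print(attention_masks_train[0].sum().item())"
+   ]
+  },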
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "I0730 17:03:32.775971 140649029109568 configuration_utils.py:264] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json from cache at /home/jupyter-susan/.cache/torch/transformers/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.7156163d5fdc189c3016baca0775ffce230789d7fa2a42ef516483e4ca884517\n",
+      "I0730 17:03:32.777010 140649029109568 configuration_utils.py:300] Model config BertConfig {\n",
+      "  \"architectures\": [\n",
+      "    \"BertForMaskedLM\"\n",
+      "  ],\n",
+      "  \"attention_probs_dropout_prob\": 0.1,\n",
+      "  \"gradient_checkpointing\": false,\n",
+      "  \"hidden_act\": \"gelu\",\n",
+      "  \"hidden_dropout_prob\": 0.1,\n",
+      "  \"hidden_size\": 768,\n",
+      "  \"id2label\": {\n",
+      "    \"0\": \"LABEL_0\",\n",
+      "    \"1\": \"LABEL_1\",\n",
+      "    \"2\": \"LABEL_2\",\n",
+      "    \"3\": \"LABEL_3\",\n",
+      "    \"4\": \"LABEL_4\"\n",
+      "  },\n",
+      "  \"initializer_range\": 0.02,\n",
+      "  \"intermediate_size\": 3072,\n",
+      "  \"label2id\": {\n",
+      "    \"LABEL_0\": 0,\n",
+      "    \"LABEL_1\": 1,\n",
+      "    \"LABEL_2\": 2,\n",
+      "    \"LABEL_3\": 3,\n",
+      "    \"LABEL_4\": 4\n",
+      "  },\n",
+      "  \"layer_norm_eps\": 1e-12,\n",
+      "  \"max_position_embeddings\": 512,\n",
+      "  \"model_type\": \"bert\",\n",
+      "  \"num_attention_heads\": 12,\n",
+      "  \"num_hidden_layers\": 12,\n",
+      "  \"pad_token_id\": 0,\n",
+      "  \"type_vocab_size\": 2,\n",
+      "  \"vocab_size\": 30522\n",
+      "}\n",
+      "\n",
+      "I0730 17:03:32.815931 140649029109568 modeling_utils.py:667] loading weights file https://cdn.huggingface.co/bert-base-uncased-pytorch_model.bin from cache at /home/jupyter-susan/.cache/torch/transformers/f2ee78bdd635b758cc0a12352586868bef80e47401abe4c4fcc3832421e7338b.36ca03ab34a1a5d5fa7bc3d03d55c4fa650fed07220e2eeebc06ce58d0e9a157\n",
+      "W0730 17:03:38.867964 140649029109568 modeling_utils.py:757] Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
+      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n",
+      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
+      "W0730 17:03:38.868916 140649029109568 modeling_utils.py:768] Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
+      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
+     ]
+    }
+   ],
+   "source": [
+    "model = BertForSequenceClassification.from_pretrained(\"bert-base-uncased\",\n",
+    "                                                      num_labels=len(label_dict),\n",
+    "                                                      output_attentions=False,\n",
+    "                                                      output_hidden_states=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from torch.utils.data import DataLoader, RandomSampler, SequentialSampler\n",
+    "\n",
+    "# a small batch size keeps memory use modest; the run recorded in this\n",
+    "# notebook executes on CPU (see the device cell below)\n",
+    "batch_size = 3\n",
+    "\n",
+    "dataloader_train = DataLoader(dataset_train, \n",
+    "                              sampler=RandomSampler(dataset_train), \n",
+    "                              batch_size=batch_size)\n",
+    "\n",
+    "dataloader_validation = DataLoader(dataset_val, \n",
+    "                                   sampler=SequentialSampler(dataset_val), \n",
+    "                                   batch_size=batch_size)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from transformers import AdamW, get_linear_schedule_with_warmup\n",
+    "\n",
+    "optimizer = AdamW(model.parameters(),\n",
+    "                  lr=1e-5, \n",
+    "                  eps=1e-8)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "epochs = 5\n",
+    "\n",
+    "scheduler = get_linear_schedule_with_warmup(optimizer, \n",
+    "                                            num_warmup_steps=0,\n",
+    "                                            num_training_steps=len(dataloader_train)*epochs)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from sklearn.metrics import f1_score\n",
+    "\n",
+    "def f1_score_func(preds, labels):\n",
+    "    preds_flat = np.argmax(preds, axis=1).flatten()\n",
+    "    labels_flat = labels.flatten()\n",
+    "    return f1_score(labels_flat, preds_flat, average='weighted')\n",
+    "\n",
+    "def accuracy_per_class(preds, labels):\n",
+    "    label_dict_inverse = {v: k for k, v in label_dict.items()}\n",
+    "    \n",
+    "    preds_flat = np.argmax(preds, axis=1).flatten()\n",
+    "    labels_flat = labels.flatten()\n",
+    "\n",
+    "    for label in np.unique(labels_flat):\n",
+    "        y_preds = preds_flat[labels_flat==label]\n",
+    "        y_true = labels_flat[labels_flat==label]\n",
+    "        
print(f'Class: {label_dict_inverse[label]}')\n", + " print(f'Accuracy: {len(y_preds[y_preds==label])}/{len(y_true)}\\n')" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "import random\n", + "\n", + "seed_val = 17\n", + "random.seed(seed_val)\n", + "np.random.seed(seed_val)\n", + "torch.manual_seed(seed_val)\n", + "torch.cuda.manual_seed_all(seed_val)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "cpu\n" + ] + } + ], + "source": [ + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "model.to(device)\n", + "\n", + "print(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "def evaluate(dataloader_val):\n", + "\n", + " model.eval()\n", + " \n", + " loss_val_total = 0\n", + " predictions, true_vals = [], []\n", + " \n", + " for batch in dataloader_val:\n", + " \n", + " batch = tuple(b.to(device) for b in batch)\n", + " \n", + " inputs = {'input_ids': batch[0],\n", + " 'attention_mask': batch[1],\n", + " 'labels': batch[2],\n", + " }\n", + "\n", + " with torch.no_grad(): \n", + " outputs = model(**inputs)\n", + " \n", + " loss = outputs[0]\n", + " logits = outputs[1]\n", + " loss_val_total += loss.item()\n", + "\n", + " logits = logits.detach().cpu().numpy()\n", + " label_ids = inputs['labels'].cpu().numpy()\n", + " predictions.append(logits)\n", + " true_vals.append(label_ids)\n", + " \n", + " loss_val_avg = loss_val_total/len(dataloader_val) \n", + " \n", + " predictions = np.concatenate(predictions, axis=0)\n", + " true_vals = np.concatenate(true_vals, axis=0)\n", + " \n", + " return loss_val_avg, predictions, true_vals" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e64800612d024595906f0cd63e6bbf71", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, description='Epoch 1', max=710.0, style=ProgressStyle(description_widt…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Epoch 1\n", + "Training loss: 0.9007002512753849\n", + "Validation loss: 0.6143069127574563\n", + "F1 Score (Weighted): 0.7791319217695921\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, description='Epoch 2', max=710.0, style=ProgressStyle(description_widt…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Epoch 2\n", + "Training loss: 0.5381144283001613\n", + "Validation loss: 0.6438471145765294\n", + "F1 Score (Weighted): 0.8207824902152685\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, description='Epoch 3', 
max=710.0, style=ProgressStyle(description_widt…"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "Epoch 3\n",
+      "Training loss: 0.35893184876292417\n",
+      "Validation loss: 0.723008230609435\n",
+      "F1 Score (Weighted): 0.8463474188661483\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(FloatProgress(value=0.0, description='Epoch 4', max=710.0, style=ProgressStyle(description_widt…"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "Epoch 4\n",
+      "Training loss: 0.2692523200199349\n",
+      "Validation loss: 0.7796335518272365\n",
+      "F1 Score (Weighted): 0.8341132163207956\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(FloatProgress(value=0.0, description='Epoch 5', max=710.0, style=ProgressStyle(description_widt…"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "Epoch 5\n",
+      "Training loss: 0.18156354463565766\n",
+      "Validation loss: 0.8108082735081321\n",
+      "F1 Score (Weighted): 0.8441012614273822\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "for epoch in tqdm(range(1, epochs+1)):\n",
+    "    \n",
+    "    model.train()\n",
+    "    \n",
+    "    loss_train_total = 0\n",
+    "\n",
+    "    progress_bar = tqdm(dataloader_train, desc='Epoch {:1d}'.format(epoch), leave=False, disable=False)\n",
+    "    for batch in progress_bar:\n",
+    "\n",
+    "        model.zero_grad()\n",
+    "        \n",
+    "        batch = tuple(b.to(device) for b in batch)\n",
+    "        \n",
+    "        inputs = {'input_ids':      batch[0],\n",
+    "                  'attention_mask': batch[1],\n",
+    "                  'labels':         batch[2],\n",
+    "                 }\n",
+    "\n",
+    "        outputs = model(**inputs)\n",
+    "        \n",
+    "        loss = outputs[0]\n",
+    "        loss_train_total += loss.item()\n",
+    "        loss.backward()\n",
+    "\n",
+    "        # clip gradients to a max norm of 1.0 to stabilise training\n",
+    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
+    "\n",
+    "        optimizer.step()\n",
+    "        scheduler.step()\n",
+    "        \n",
+    "        # the model already averages the loss over the batch\n",
+    "        progress_bar.set_postfix({'training_loss': '{:.3f}'.format(loss.item())})\n",
+    "    \n",
+    "    # checkpoint after every epoch\n",
+    "    torch.save(model.state_dict(), f'data_volume/finetuned_BERT_epoch_{epoch}.model')\n",
+    "    \n",
+    "    tqdm.write(f'\\nEpoch {epoch}')\n",
+    "    \n",
+    "    loss_train_avg = loss_train_total/len(dataloader_train)\n",
+    "    tqdm.write(f'Training loss: {loss_train_avg}')\n",
+    "    \n",
+    "    val_loss, predictions, true_vals = evaluate(dataloader_validation)\n",
+    "    val_f1 = f1_score_func(predictions, true_vals)\n",
+    "    tqdm.write(f'Validation loss: {val_loss}')\n",
+    "    tqdm.write(f'F1 Score (Weighted): {val_f1}')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "metadata": {
+    "scrolled": false
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "I0730 21:12:18.774542 140649029109568 configuration_utils.py:264] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json from cache at /home/jupyter-susan/.cache/torch/transformers/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.7156163d5fdc189c3016baca0775ffce230789d7fa2a42ef516483e4ca884517\n",
+      "I0730 21:12:18.775610 140649029109568 configuration_utils.py:300] Model config BertConfig {\n",
+      "  \"architectures\": [\n",
+      " 
\"BertForMaskedLM\"\n", + " ],\n", + " \"attention_probs_dropout_prob\": 0.1,\n", + " \"gradient_checkpointing\": false,\n", + " \"hidden_act\": \"gelu\",\n", + " \"hidden_dropout_prob\": 0.1,\n", + " \"hidden_size\": 768,\n", + " \"id2label\": {\n", + " \"0\": \"LABEL_0\",\n", + " \"1\": \"LABEL_1\",\n", + " \"2\": \"LABEL_2\",\n", + " \"3\": \"LABEL_3\",\n", + " \"4\": \"LABEL_4\"\n", + " },\n", + " \"initializer_range\": 0.02,\n", + " \"intermediate_size\": 3072,\n", + " \"label2id\": {\n", + " \"LABEL_0\": 0,\n", + " \"LABEL_1\": 1,\n", + " \"LABEL_2\": 2,\n", + " \"LABEL_3\": 3,\n", + " \"LABEL_4\": 4\n", + " },\n", + " \"layer_norm_eps\": 1e-12,\n", + " \"max_position_embeddings\": 512,\n", + " \"model_type\": \"bert\",\n", + " \"num_attention_heads\": 12,\n", + " \"num_hidden_layers\": 12,\n", + " \"pad_token_id\": 0,\n", + " \"type_vocab_size\": 2,\n", + " \"vocab_size\": 30522\n", + "}\n", + "\n", + "I0730 21:12:18.891486 140649029109568 modeling_utils.py:667] loading weights file https://cdn.huggingface.co/bert-base-uncased-pytorch_model.bin from cache at /home/jupyter-susan/.cache/torch/transformers/f2ee78bdd635b758cc0a12352586868bef80e47401abe4c4fcc3832421e7338b.36ca03ab34a1a5d5fa7bc3d03d55c4fa650fed07220e2eeebc06ce58d0e9a157\n", + "W0730 21:12:24.786196 140649029109568 modeling_utils.py:757] Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n", + "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPretraining model).\n", + "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "W0730 21:12:24.787092 140649029109568 modeling_utils.py:768] Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + }, + { + "data": { + "text/plain": [ + "BertForSequenceClassification(\n", + " (bert): BertModel(\n", + " (embeddings): BertEmbeddings(\n", + " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n", + " (position_embeddings): Embedding(512, 768)\n", + " (token_type_embeddings): Embedding(2, 768)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (encoder): BertEncoder(\n", + " (layer): ModuleList(\n", + " (0): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (1): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (2): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): 
BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (3): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (4): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (5): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (6): 
BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (7): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (8): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (9): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): 
BertIntermediate(\n",
+       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+       "          )\n",
+       "          (output): BertOutput(\n",
+       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+       "            (dropout): Dropout(p=0.1, inplace=False)\n",
+       "          )\n",
+       "        )\n",
+       "        (10): BertLayer(\n",
+       "          (attention): BertAttention(\n",
+       "            (self): BertSelfAttention(\n",
+       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
+       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
+       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
+       "              (dropout): Dropout(p=0.1, inplace=False)\n",
+       "            )\n",
+       "            (output): BertSelfOutput(\n",
+       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+       "              (dropout): Dropout(p=0.1, inplace=False)\n",
+       "            )\n",
+       "          )\n",
+       "          (intermediate): BertIntermediate(\n",
+       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+       "          )\n",
+       "          (output): BertOutput(\n",
+       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+       "            (dropout): Dropout(p=0.1, inplace=False)\n",
+       "          )\n",
+       "        )\n",
+       "        (11): BertLayer(\n",
+       "          (attention): BertAttention(\n",
+       "            (self): BertSelfAttention(\n",
+       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
+       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
+       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
+       "              (dropout): Dropout(p=0.1, inplace=False)\n",
+       "            )\n",
+       "            (output): BertSelfOutput(\n",
+       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+       "              (dropout): Dropout(p=0.1, inplace=False)\n",
+       "            )\n",
+       "          )\n",
+       "          (intermediate): BertIntermediate(\n",
+       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+       "          )\n",
+       "          (output): BertOutput(\n",
+       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+       "            (dropout): Dropout(p=0.1, inplace=False)\n",
+       "          )\n",
+       "        )\n",
+       "      )\n",
+       "    )\n",
+       "    (pooler): BertPooler(\n",
+       "      (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+       "      (activation): Tanh()\n",
+       "    )\n",
+       "  )\n",
+       "  (dropout): Dropout(p=0.1, inplace=False)\n",
+       "  (classifier): Linear(in_features=768, out_features=5, bias=True)\n",
+       ")"
+      ]
+     },
+     "execution_count": 25,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "model = BertForSequenceClassification.from_pretrained(\"bert-base-uncased\",\n",
+    "                                                      num_labels=len(label_dict),\n",
+    "                                                      output_attentions=False,\n",
+    "                                                      output_hidden_states=False)\n",
+    "\n",
+    "model.to(device)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "<All keys matched successfully>"
+      ]
+     },
+     "execution_count": 26,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "model.load_state_dict(torch.load('data_volume/finetuned_BERT_epoch_1.model', map_location=torch.device('cpu')))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 27,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "_, predictions, true_vals = evaluate(dataloader_validation)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 28,
+   "metadata": {},
+   "outputs": 
[ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Class: VLDB\n", + "Accuracy: 45/64\n", + "\n", + "Class: ISCAS\n", + "Accuracy: 124/130\n", + "\n", + "Class: SIGGRAPH\n", + "Accuracy: 29/49\n", + "\n", + "Class: INFOCOM\n", + "Accuracy: 65/77\n", + "\n", + "Class: WWW\n", + "Accuracy: 33/57\n", + "\n" + ] + } + ], + "source": [ + "accuracy_per_class(predictions, true_vals)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}