![PyTorch logo](https://github.com/pytorch/pytorch/raw/master/docs/source/_static/img/pytorch-logo-dark.png)
VDCNN is a neural network that uses deep architectures of many convolutional layers, up to 49, to approach Text Classification and Sentiment Analysis. You can read the original paper at the following link. This repository is a personal implementation of that paper using PyTorch 1.13.
The overall architecture of this network is shown in the following figure:
The first block is a lookup table that generates a 2D tensor of size (f0, s) containing the embeddings of the s characters.
```python
class LookUpTable(nn.Module):
    def __init__(self, num_embedding, embedding_dim):
        super(LookUpTable, self).__init__()
        self.embeddings = nn.Embedding(num_embedding, embedding_dim)

    def forward(self, x):
        return self.embeddings(x).transpose(1, 2)
```
Note
The output dimension of the `nn.Embedding` layer is (s, f0). Use `.transpose` to obtain the correct (f0, s) output dimension.
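A quick, hedged shape check (the vocabulary size of 70 used here is an assumption; the embedding dimension of 16 follows the paper):

```python
import torch

lut = LookUpTable(num_embedding=70, embedding_dim=16)   # 70 indices is an assumption
tokens = torch.randint(0, 70, (4, 1024))                # (batch, s) character indices
print(lut(tokens).shape)                                # torch.Size([4, 16, 1024]) -> (f0, s) per sample
```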
The second layer is a convolutional layer with 64 output feature maps and a kernel of size 3.
```python
class FirstConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super(FirstConvLayer, self).__init__()
        self.sequential = nn.Sequential(nn.Conv1d(in_channels, out_channels, kernel_size))
```
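A hedged shape check for this layer (the forward pass in model.py presumably just applies `self.sequential`, so it is called directly here):

```python
import torch

first_conv = FirstConvLayer(in_channels=16, out_channels=64, kernel_size=3)
embedded = torch.randn(4, 16, 1024)                 # output of the lookup table: (batch, f0, s)
print(first_conv.sequential(embedded).shape)        # torch.Size([4, 64, 1022]): no padding here
```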
The third layer is a convolutional block, structured as shown in the following figure:
```python
class ConvolutionalBlock(nn.Module):
    def __init__(self, in_channels, out_channels, want_shortcut, downsample, last_layer, pool_type='vgg'):
        super(ConvolutionalBlock, self).__init__()
        self.want_shortcut = want_shortcut
        if self.want_shortcut:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=2, bias=False),
                nn.BatchNorm1d(out_channels)
            )
```
With the variable `want_shortcut` we can choose whether to add shortcut connections to the network.
```python
        self.sequential = nn.Sequential(
            nn.BatchNorm1d(in_channels),
            nn.ReLU(),
            nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding='same', bias=False),
            nn.BatchNorm1d(out_channels),
            nn.ReLU()
        )
        self.conv1 = nn.Conv1d(in_channels=in_channels, out_channels=in_channels,
                               kernel_size=3, stride=1, padding=1, bias=False)
```
In this piece of code we build the core part of the convolutional block, as shown in the previous figure. `self.conv1` cannot be added to `self.sequential` because its stride depends on the type of pooling we want to use.
```python
        if downsample:
            if last_layer:
                self.want_shortcut = False
                self.sequential.append(nn.AdaptiveMaxPool1d(8))
            else:
                if pool_type == 'convolution':
                    self.conv1 = nn.Conv1d(in_channels=in_channels, out_channels=in_channels,
                                           kernel_size=3, stride=2, padding=1, bias=False)
                elif pool_type == 'kmax':
                    channels = [64, 128, 256, 512]
                    dimension = [511, 256, 128]
                    index = channels.index(in_channels)
                    self.sequential.append(nn.AdaptiveMaxPool1d(dimension[index]))
                else:
                    self.sequential.append(nn.MaxPool1d(kernel_size=3, stride=2, padding=1))
        self.relu = nn.ReLU()
```
The final part of this layer manages the type of pooling we want to use; we select it with the variable `pool_type`. The last convolutional block always uses k-max pooling with dimension 8, and this difference from the previous layers is handled with the variable `last_layer`, as sketched below.
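For illustration, a hedged sketch of how the three `pool_type` options map onto the code above (the channel sizes are taken from the `channels` list in the snippet):

```python
# Downsampling 64 -> 128 block with plain max pooling: any pool_type other than
# 'convolution' or 'kmax' falls through to nn.MaxPool1d
vgg_block = ConvolutionalBlock(64, 128, want_shortcut=True,
                               downsample=True, last_layer=False, pool_type='vgg')

# Same block, but halving the temporal dimension with a stride-2 self.conv1
conv_block = ConvolutionalBlock(64, 128, want_shortcut=True,
                                downsample=True, last_layer=False, pool_type='convolution')

# Same block, but using k-max pooling (AdaptiveMaxPool1d) to fix the output length
kmax_block = ConvolutionalBlock(64, 128, want_shortcut=True,
                                downsample=True, last_layer=False, pool_type='kmax')
```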
```python
class FullyConnectedBlock(nn.Module):
    def __init__(self, n_class):
        super(FullyConnectedBlock, self).__init__()
        self.sequential = nn.Sequential(
            nn.Linear(4096, 2048),
            nn.ReLU(),
            nn.Linear(2048, 2048),
            nn.ReLU(),
            nn.Linear(2048, n_class),
            nn.Softmax(dim=1)
        )
```
After the sequence of convolutional blocks we have 3 fully connected layers, where we have to choose the number of output classes. Different tasks require different numbers of classes; we choose it with the variable `n_class`. Since we want the probability of each class given a text, we use the softmax.
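The 4096 input features correspond to flattening the output of the last convolutional block, which has 512 channels and a temporal dimension of 8 after `nn.AdaptiveMaxPool1d(8)` (512 * 8 = 4096). A hedged sketch (the flattening step is an assumption about how model.py wires the blocks together):

```python
import torch

fc = FullyConnectedBlock(n_class=10)
pooled = torch.randn(4, 512, 8)              # output of the last block: (batch, 512, 8)
probs = fc.sequential(pooled.flatten(1))     # flatten to (batch, 4096) -> (batch, 10)
print(probs.sum(dim=1))                      # each row sums to 1 thanks to the softmax
```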
```python
class VDCNN(nn.Module):
    def __init__(self, depth, n_classes, want_shortcut=True, pool_type='VGG'):
```
The last class, named `VDCNN`, builds all the layers in the right order, and with the variable `depth` we can choose how many layers to add to the network. The paper presents 4 different levels of depth: 9, 17, 29 and 49. You can find all these pieces of code inside the model.py file.
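A hypothetical usage sketch, assuming the classes above are assembled as in model.py and that the network takes a batch of character indices of length 1024:

```python
import torch

model = VDCNN(depth=9, n_classes=10, want_shortcut=True, pool_type='vgg')
tokens = torch.randint(0, 70, (4, 1024))   # (batch, s); the vocabulary size of 70 is an assumption
probs = model(tokens)                      # expected shape: (batch, 10) class probabilities
```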
The datasets used for the training part are the Yahoo! Answers Topic Classification dataset and a subset of the Amazon review data, which can be downloaded here. The vocabulary used is the same as in the paper: "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:’"/|_#$%^&*~‘+=<>()[]{} ". I chose to use 0 as the padding value and 69 as the unknown-token value. All these datasets are managed by the Dataset class inside the dataset.py file.
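A hedged sketch of the character-to-index mapping described above (the exact indexing is an assumption; the real mapping lives in dataset.py):

```python
# Assumed mapping: characters start at 1, 0 is the padding value, 69 the unknown token
VOCABULARY = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:’\"/|_#$%^&*~‘+=<>()[]{} "
CHAR_TO_INDEX = {c: i + 1 for i, c in enumerate(VOCABULARY)}

def encode(text, max_length=1024):
    ids = [CHAR_TO_INDEX.get(c, 69) for c in text.lower()[:max_length]]
    return ids + [0] * (max_length - len(ids))   # right-pad with 0 up to max_length
```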
The Yahoo! Answers topic classification dataset is constructed using the 10 largest main categories. Each class contains 140000 training samples and 6000 testing samples. Therefore, the total number of training samples is 1400000 and the total number of testing samples is 60000. The categories are:
- Society & Culture
- Science & Mathematics
- Health
- Education & Reference
- Computers & Internet
- Sports
- Business & Finance
- Entertainment & Music
- Family & Relationships
- Politics & Government
The Amazon Reviews dataset is constructed using 5 categories (star ratings).
Warning
Even though the device can be chosen between CPU and GPU, I used and tested the training part only on GPU.
First things first: at the beginning of the train.py file there are some useful global variables that manage the key settings of the training.
```python
LEARNING_RATE = 0.01
MOMENTUM = 0.9
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE = 128
MAX_LENGTH = 1024
NUM_EPOCHS = 1
PATIENCE = 40
NUM_WORKERS = 4
PIN_MEMORY = True
LOAD_MODEL = False
TRAIN_DIR = "dataset/amazon/train.csv"
TEST_DIR = "dataset/amazon/test.csv"
```
Note
Change `TRAIN_DIR` and `TEST_DIR` to the local paths of your datasets.
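A hedged sketch of how these constants could feed the data pipeline (the actual construction is done by `get_loaders` in utils.py, and the `Dataset` constructor arguments shown here are assumptions):

```python
from torch.utils.data import DataLoader

train_dataset = Dataset(TRAIN_DIR, max_length=MAX_LENGTH)   # hypothetical signature, see dataset.py
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
```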
The train_fn function is built to run one epoch and return the average loss and accuracy of that epoch.
```python
def train_fn(epoch, loader, model, optimizer, loss_fn, scaler):
    # a bunch of code
    return train_loss, train_accuracy
```
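The `scaler` argument suggests mixed-precision training; below is a hedged illustration of what one step inside that loop could look like, not the repository's actual code:

```python
# Hypothetical inner loop of train_fn, assuming torch.cuda.amp mixed precision
for data, targets in loader:
    data, targets = data.to(DEVICE), targets.to(DEVICE)
    with torch.cuda.amp.autocast():
        scores = model(data)
        loss = loss_fn(scores, targets)
    optimizer.zero_grad()
    scaler.scale(loss).backward()   # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)
    scaler.update()
```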
The main function is built to initialize and manage the training until the end.
```python
def main():
    model = VDCNN(depth=9, n_classes=5, want_shortcut=True, pool_type='vgg').to(DEVICE)
    # training settings
    for epoch in range(NUM_EPOCHS):
        # run 1 epoch
        # check accuracy
        # save model
        # manage patience for early stopping
    # save plot
    sys.exit()
```
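A hedged sketch of the optimizer and gradient scaler suggested by the constants and the train_fn signature (SGD with learning rate 0.01 and momentum 0.9 also matches the paper's setup; the actual choices live in train.py):

```python
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
scaler = torch.cuda.amp.GradScaler()   # passed to train_fn for mixed precision
```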
Note
Remember to change `n_classes` from 5 to 10 if you use the Yahoo! Answers dataset instead of the Amazon dataset.
`get_loaders`, `save_checkpoint`, `load_checkpoint`, `check_accuracy` and `save_plot` are functions used inside train.py that can be found inside utils.py.
Due to computational limitations I trained the models only with depth 9. The results shown below are the test errors of my implementation and of the paper's implementation.
Yahoo! Answers (test error, %):

| Pool Type | My Result | Paper Result |
|---|---|---|
| Convolution | 32.57 | 28.10 |
| KMaxPooling | 28.92 | 28.24 |
| MaxPooling | 28.40 | 27.60 |
Amazon Reviews (test error, %):

| Pool Type | My Result | Paper Result |
|---|---|---|
| Convolution | 40.35 | 38.52 |
| KMaxPooling | 38.58 | 39.19 |
| MaxPooling | 38.45 | 37.95 |
After training, run the main.py file, changing the variable `WEIGHT_DIR` to the local directory where the weights are saved.