Skip to content
This repository has been archived by the owner on Nov 16, 2021. It is now read-only.

Commit

Permalink
add v2 model
Browse files Browse the repository at this point in the history
Signed-off-by: Bedapudi Praneeth <[email protected]>
  • Loading branch information
bedapudi6788 committed Feb 25, 2021
1 parent 512ca5d commit d78d6b9
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 194 deletions.
13 changes: 5 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,14 @@ pip install --upgrade fastpunct
```

# Supported languages:
en - english
english

# Usage:

```python
from fastpunct import FastPunct
# The default language is 'en'
fastpunct = FastPunct('en')
fastpunct.punct(["oh i thought you were here", "in theory everyone knows what a comma is", "hey how are you doing", "my name is sheela i am in love with hrithik"], batch_size=32)
# ['Oh! I thought you were here.', 'In theory, everyone knows what a comma is.', 'Hey! How are you doing?', 'My name is Sheela. I am in love with Hrithik.']

# The default language is 'english'
fastpunct = FastPunct()
fastpunct.punct(["john smiths dog is creating a ruccus", "ys jagan is the chief minister of andhra pradesh", "we visted new york last year in may"])
# ["John Smith's dog is creating a ruccus.", 'Ys Jagan is the chief minister of Andhra Pradesh.', 'We visted New York last year in May.']
```
# Note:
maximum length of input currently supported - 400
268 changes: 88 additions & 180 deletions fastpunct/fastpunct.py
Original file line number Diff line number Diff line change
@@ -1,187 +1,95 @@
# -*- coding: utf-8 -*-
"""
Created on Sun May 10 15:46:01 2020
@author: harikodali
"""
import os
import pickle
import torch
import pydload

import numpy as np

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, LSTM, Dense, TimeDistributed, Activation, dot, concatenate, Bidirectional
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import to_categorical

def get_text_encodings(texts, parameters):

enc_seq = parameters["enc_token"].texts_to_sequences(texts)
pad_seq = pad_sequences(enc_seq, maxlen=parameters["max_encoder_seq_length"],
padding='post')
pad_seq = to_categorical(pad_seq, num_classes=parameters["enc_vocab_size"])
return pad_seq


def get_extra_chars(parameters):
allowed_extras = []
for d_c, d_i in parameters["dec_token"].word_index.items():
if d_c.lower() not in parameters["enc_token"].word_index:
allowed_extras.append(d_i)
return allowed_extras

def get_model_instance(parameters):

encoder_inputs = Input(shape=(None, parameters["enc_vocab_size"],))
encoder = Bidirectional(LSTM(128, return_sequences=True, return_state=True),
merge_mode='concat')
encoder_outputs, forward_h, forward_c, backward_h, backward_c = encoder(encoder_inputs)

encoder_h = concatenate([forward_h, backward_h])
encoder_c = concatenate([forward_c, backward_c])

decoder_inputs = Input(shape=(None, parameters["dec_vocab_size"],))
decoder_lstm = LSTM(256, return_sequences=True)
decoder_outputs = decoder_lstm(decoder_inputs, initial_state=[encoder_h, encoder_c])

attention = dot([decoder_outputs, encoder_outputs], axes=(2, 2))
attention = Activation('softmax', name='attention')(attention)
context = dot([attention, encoder_outputs], axes=(2, 1))
decoder_combined_context = concatenate([context, decoder_outputs])

output = TimeDistributed(Dense(128, activation="relu"))(decoder_combined_context)
output = TimeDistributed(Dense(parameters["dec_vocab_size"], activation="softmax"))(output)

model = Model([encoder_inputs, decoder_inputs], [output])
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

return model


def decode(model, parameters, input_texts, allowed_extras, batch_size):
input_texts_c = input_texts.copy()
out_dict = {}
input_sequences = get_text_encodings(input_texts, parameters)

parameters["reverse_dec_dict"][0] = "\n"
outputs = [""]*len(input_sequences)

target_text = "\t"
target_seq = parameters["dec_token"].texts_to_sequences([target_text]*len(input_sequences))
target_seq = pad_sequences(target_seq, maxlen=parameters["max_decoder_seq_length"],
padding="post")
target_seq_hot = to_categorical(target_seq, num_classes=parameters["dec_vocab_size"])

extra_char_count = [0]*len(input_texts)
prev_char_index = [0]*len(input_texts)
i = 0
while len(input_texts) != 0:
curr_char_index = [i - extra_char_count[j] for j in range(len(input_texts))]
input_encodings = np.argmax(input_sequences, axis=2)

cur_inp_list = [input_encodings[_][curr_char_index[_]] if curr_char_index[_] < len(input_texts[_]) else 0 for _ in range(len(input_texts))]
output_tokens = model.predict([input_sequences, target_seq_hot], batch_size=batch_size)
sampled_possible_indices = np.argsort(output_tokens[:, i, :])[:, ::-1].tolist()
sampled_token_indices = []
for j, per_char_list in enumerate(sampled_possible_indices):
for index in per_char_list:
if index in allowed_extras:
if parameters["reverse_dec_dict"][index] == '\n' and cur_inp_list[j] != 0:
continue
elif parameters["reverse_dec_dict"][index] != '\n' and prev_char_index[j] in allowed_extras:
continue
sampled_token_indices.append(index)
extra_char_count[j] += 1
break
elif parameters["enc_token"].word_index[parameters["reverse_dec_dict"][index].lower()] == cur_inp_list[j]:
sampled_token_indices.append(index)
break

sampled_chars = [parameters["reverse_dec_dict"][index] for index in sampled_token_indices]

outputs = [outputs[j] + sampled_chars[j] for j, output in enumerate(outputs)]
end_indices = sorted([index for index, char in enumerate(sampled_chars) if char == '\n'], reverse=True)
for index in end_indices:
out_dict[input_texts[index]] = outputs[index].strip()
del outputs[index]
del input_texts[index]
del extra_char_count[index]
del sampled_token_indices[index]
input_sequences = np.delete(input_sequences, index, axis=0)
target_seq = np.delete(target_seq, index, axis=0)
if i == parameters["max_decoder_seq_length"]-1 or len(input_texts) == 0:
break
target_seq[:,i+1] = sampled_token_indices
target_seq_hot = to_categorical(target_seq, num_classes=parameters["dec_vocab_size"])
prev_char_index = sampled_token_indices
i += 1
outputs = [out_dict[text] for text in input_texts_c]
return outputs


model_links = {
'en': {
'checkpoint': 'https://github.com/notAI-tech/fastPunct/releases/download/checkpoint-release/fastpunct_eng_weights.h5',
'params': 'https://github.com/notAI-tech/fastPunct/releases/download/checkpoint-release/parameter_dict.pkl'
},

}

lang_code_mapping = {
'english': 'en',
'french': 'fr',
'italian': 'it'
from transformers import T5Tokenizer, T5ForConditionalGeneration

MODEL_URLS = {
"english": {
"pytorch_model.bin": "https://github.com/notAI-tech/fastPunct/releases/download/v2/pytorch_model.bin",
"config.json": "https://github.com/notAI-tech/fastPunct/releases/download/v2/config.json",
"special_tokens_map.json": "https://github.com/notAI-tech/fastPunct/releases/download/v2/special_tokens_map.json",
"spiece.model": "https://github.com/notAI-tech/fastPunct/releases/download/v2/spiece.model",
"tokenizer_config.json": "https://github.com/notAI-tech/fastPunct/releases/download/v2/tokenizer_config.json",
},
}

class FastPunct():

class FastPunct:
tokenizer = None
model = None
parameters = None
def __init__(self, lang_code="en", weights_path=None, params_path=None):
if lang_code not in model_links and lang_code in lang_code_mapping:
lang_code = lang_code_mapping[lang_code]

if lang_code not in model_links:
print("fastPunct doesn't support '" + lang_code + "' yet.")
print("Please raise a issue at https://github.com/notai-tech/fastPunct/ to add this language into future checklist.")

def __init__(self, language='english', checkpoint_local_path=None):

model_name = language.lower()

if model_name not in MODEL_URLS:
print(f"model_name should be one of {list(MODEL_URLS.keys())}")
return None

home = os.path.expanduser("~")
lang_path = os.path.join(home, '.fastPunct_' + lang_code)
if weights_path is None:
weights_path = os.path.join(lang_path, 'checkpoint.h5')
if params_path is None:
params_path = os.path.join(lang_path, 'params.pkl')

#if either of the paths are not mentioned, then, make lang directory from home
if (params_path is None) or (weights_path is None):
if not os.path.exists(lang_path):
os.mkdir(lang_path)

if not os.path.exists(weights_path):
print('Downloading checkpoint', model_links[lang_code]['checkpoint'], 'to', weights_path)
pydload.dload(url=model_links[lang_code]['checkpoint'], save_to_path=weights_path, max_time=None)

if not os.path.exists(params_path):
print('Downloading model params', model_links[lang_code]['params'], 'to', params_path)
pydload.dload(url=model_links[lang_code]['params'], save_to_path=params_path, max_time=None)


with open(params_path, "rb") as file:
self.parameters = pickle.load(file)
self.parameters["reverse_enc_dict"] = {i:c for c, i in self.parameters["enc_token"].word_index.items()}
self.model = get_model_instance(self.parameters)
self.model.load_weights(weights_path)
self.allowed_extras = get_extra_chars(self.parameters)

def punct(self, input_texts, batch_size=32):
input_texts = [text.lower() for text in input_texts]
return decode(self.model, self.parameters, input_texts, self.allowed_extras, batch_size)

def fastpunct(self, input_texts, batch_size=32):
# To be implemented
return None

if __name__ == "__main__":
fastpunct = FastPunct()
print(fastpunct.punct(["oh i thought you were here", "in theory everyone knows what a comma is", "hey how are you doing", "my name is sheela i am in love with hrithik"]))
lang_path = os.path.join(home, ".FastPunct_" + model_name)

if checkpoint_local_path:
lang_path = checkpoint_local_path

if not os.path.exists(lang_path):
os.mkdir(lang_path)

for file_name, url in MODEL_URLS[model_name].items():
file_path = os.path.join(lang_path, file_name)
if os.path.exists(file_path):
continue
print(f"Downloading {file_name}")
pydload.dload(url=url, save_to_path=file_path, max_time=None)

self.tokenizer = T5Tokenizer.from_pretrained(lang_path)
self.model = T5ForConditionalGeneration.from_pretrained(
lang_path, return_dict=True
)

if torch.cuda.is_available():
print(f"Using GPU")
self.model = self.model.cuda()

def punct(
self, sentences, beam_size=1, max_len=None, correct=False
):
return_single = True
if isinstance(sentences, list):
return_single = False
else:
sentences = [sentences]

prefix = 'punctuate'
if correct:
beam_size = 8
prefix = 'correct'

input_ids = self.tokenizer(
[
f"{prefix}: {sentence}"
for sentence in sentences
],
return_tensors="pt",
padding=True,
).input_ids

if not max_len:
max_len = max([len(tokenized_input) for tokenized_input in input_ids]) + max([len(s.split()) for s in sentences]) + 4

if torch.cuda.is_available():
input_ids = input_ids.to("cuda")

output_ids = self.model.generate(
input_ids, num_beams=beam_size, max_length=max_len
)

outputs = [
self.tokenizer.decode(output_id, skip_special_tokens=True)
for output_id in output_ids
]

if return_single:
outputs = outputs[0]

return outputs
9 changes: 3 additions & 6 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,15 @@

# Package meta-data.
NAME = 'fastpunct'
DESCRIPTION = 'Punctuation restoration with sequence to sequence networks'
DESCRIPTION = 'Punctuation restoration and spell correction.'
URL = 'https://github.com/notAI-tech/fastPunct'
EMAIL = '[email protected]'
AUTHOR = 'Hari Krishna Sai Kodali'
REQUIRES_PYTHON = '>=3.5.0'
REQUIRES_PYTHON = '>=3.6.0'
VERSION = subprocess.run(['git', 'describe', '--tags'], stdout=subprocess.PIPE).stdout.decode("utf-8").strip()

# What packages are required for this module to be executed?
REQUIRED = [
'numpy',
'pydload'
]
REQUIRED = ["transformers>=4.0.0rc1", "pydload>=1.0.9", "torch>=1.5.0", "sentencepiece"]

# What packages are optional?
EXTRAS = {
Expand Down

0 comments on commit d78d6b9

Please sign in to comment.