This repository has been archived by the owner on Nov 16, 2021. It is now read-only.

add warmup
Signed-off-by: Bedapudi Praneeth <[email protected]>
bedapudi6788 committed Feb 25, 2021
1 parent dc66268 commit bc98220
Showing 1 changed file with 7 additions and 3 deletions.
fastpunct/fastpunct.py: 10 changes (7 additions, 3 deletions)
@@ -1,5 +1,6 @@
 import os
 import torch
+import logging
 import pydload
 from transformers import T5Tokenizer, T5ForConditionalGeneration

@@ -23,7 +24,7 @@ def __init__(self, language='english', checkpoint_local_path=None):
         model_name = language.lower()

         if model_name not in MODEL_URLS:
-            print(f"model_name should be one of {list(MODEL_URLS.keys())}")
+            logging.warn(f"model_name should be one of {list(MODEL_URLS.keys())}")
             return None

         home = os.path.expanduser("~")
@@ -39,7 +40,7 @@ def __init__(self, language='english', checkpoint_local_path=None):
             file_path = os.path.join(lang_path, file_name)
             if os.path.exists(file_path):
                 continue
-            print(f"Downloading {file_name}")
+            logging.info(f"Downloading {file_name}")
             pydload.dload(url=url, save_to_path=file_path, max_time=None)

         self.tokenizer = T5Tokenizer.from_pretrained(lang_path)
@@ -48,8 +49,11 @@ def __init__(self, language='english', checkpoint_local_path=None):
         )

         if torch.cuda.is_available():
-            print(f"Using GPU")
+            logging.info(f"Using GPU")
             self.model = self.model.cuda()

+        logging.info("Warming up")
+        self.punct(["i am batman"])
+
     def punct(
         self, sentences, beam_size=1, max_len=None, correct=False
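The change above adds a warmup step at the end of __init__, running the model once on a throwaway sentence so the first real punct() call does not pay the cold-start cost, and swaps the print statements for the logging module. A minimal usage sketch of the post-commit behaviour follows; the FastPunct import path follows the library's README, while the logging setup and the example sentence are illustrative assumptions, not part of this commit.

import logging

from fastpunct import FastPunct  # import path as documented in the library's README

# The commit logs through the root logger, so raise its level to see the
# "Downloading ...", "Using GPU" and "Warming up" messages emitted in __init__.
logging.basicConfig(level=logging.INFO)

# __init__ now ends with self.punct(["i am batman"]), so model loading and the
# first forward pass happen here rather than on the first user call.
fastpunct = FastPunct("english")

# Subsequent calls reuse the already-warmed model.
print(fastpunct.punct(["oh i thought you were here"]))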
