update generator
Jinjiarui committed Jul 10, 2023
1 parent cf388dc commit 4fdbdc6
Showing 3 changed files with 29 additions and 10 deletions.
12 changes: 6 additions & 6 deletions system/converter.py
@@ -41,30 +41,30 @@ def load_converter(name: str) -> BaseConverter:

 @register_converter("sst2")
 class SST2Converter(BaseConverter):
-    INTRO = "analyze the sentiment of the following text excerpts, categorizing them as either 'positive', or 'negative'." + "\n"
+    INTRO = "analyze the sentiment of the following text excerpts, categorizing them as either 'positive', or 'negative' after 'target:'." + "\n"
     OUTPUT2LABEL = {0: "negative", 1: "positive"}

 @register_converter("sst5")
 class SST5Converter(BaseConverter):
-    INTRO = "analyze the sentiment of the following text excerpts, categorizing them as either 'positive', or 'negative'." + "\n"
+    INTRO = "analyze the sentiment of the following text excerpts, categorizing them as either 'positive', or 'negative' after 'target:'." + "\n"
     OUTPUT2LABEL = {0: "negative", 1: "positive"}

 @register_converter("fpb")
 class FPBConverter(BaseConverter):
-    INTRO = "analyze the sentiment of the following text excerpts, categorizing them as one label from following choices 'positive', 'neutral' and 'negative'." + "\n"
+    INTRO = "analyze the sentiment of the following text excerpts, categorizing them as one label from following choices 'positive', 'neutral' and 'negative' after 'target:'." + "\n"
     OUTPUT2LABEL = {0: "negative", 1: "neutral", 2: "positive"}

 @register_converter("cola")
 class COLAConverter(BaseConverter):
-    INTRO = "analyze the linguistic acceptability of the following text excerpts, categorizing them as either 'positive', or 'negative'." + "\n"
+    INTRO = "analyze the linguistic acceptability of the following text excerpts, categorizing them as either 'positive', or 'negative' after 'target:'." + "\n"
     OUTPUT2LABEL = {0: "negative", 1: "positive"}

 @register_converter("trec")
 class TRECConverter(BaseConverter):
-    INTRO = "analyze the topic of the following text excerpts, categorizing them as one label from following 6 choices 'abbreviation', 'entity', 'description', 'human', 'location' and 'number'." + "\n"
+    INTRO = "analyze the topic of the following text excerpts, categorizing them as one label from following 6 choices 'abbreviation', 'entity', 'description', 'human', 'location' and 'number' after 'target:'." + "\n"
     OUTPUT2LABEL = {0: "abbreviation", 1: "entity", 2: "description", 3: "human", 4: "location", 5: "number"}

 @register_converter("subj")
 class SUBJConverter(BaseConverter):
-    INTRO = "determine whether the following text excerpts is 'subjective' or 'objective'." + "\n"
+    INTRO = "determine whether the following text excerpts is 'subjective' or 'objective' after 'target:'." + "\n"
     OUTPUT2LABEL = {0: "objective", 1: "subjective"}
25 changes: 22 additions & 3 deletions system/generator.py
@@ -87,10 +87,10 @@ def load(self):
         else:
             self.max_context_len = 4096

-    def set(self, decode_method: str = None, num_generate: int = None, num_return_sequence: int = None,
+    def config(self, decode_method: str = None, num_generate: int = None, num_return_sequence: int = None,
               add_score: bool = None, temperature: float = None, max_new_tokens: int = None,
               num_batch: int = None, max_source_len: int = None, max_target_len: int = None):
-        self.cfg.set(
+        self.cfg.config(
             decode_method=decode_method, add_score=add_score, num_generate=num_generate, max_new_tokens=max_new_tokens,
             num_batch=num_batch, num_return_sequence=num_return_sequence, temperature=temperature,
             max_source_len=max_source_len, max_target_len=max_target_len
@@ -232,4 +232,23 @@ def __getitem__(self, idx):
"attention_mask": torch.stack([f[1] for f in data]),
"labels": torch.stack([f[2] for f in data])})
trainer.train()
self.tokenizer.save_pretrained(self.model_path)
self.tokenizer.save_pretrained(self.model_path)

def interact(self):
input_text = input("> ")
max_new_tokens = int(input("set max_new_tokens>"))
self.config(max_new_tokens=max_new_tokens)

while input_text != "exit":
if self.is_autoreg:
input_text = input_text #+ AutoregLMDataset.IO_SEP
input_text = self.tokenizer.encode(input_text, return_tensors='pt').cuda()
print(f"> input text: {input_text}")
outputs = self.act(input_text=input_text)
print(f"> output tokens: {outputs}")
print(
f"> output text: {self.tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)}"
)
print("=============================================")
input_text = input("> ")
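
The new interact() method is a small REPL: it reads a prompt, tokenizes it, calls self.act(...) to generate, decodes and prints the result, and repeats until the user types "exit". As a rough standalone illustration of the same loop, a sketch using Hugging Face transformers directly is shown below; the "gpt2" model and CPU-only setup are placeholders, not this repository's loading code.

```python
# Standalone sketch of the interact() pattern; "gpt2" and CPU execution are
# illustrative stand-ins for whatever model system/generator.py actually loads.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

text = input("> ")
max_new_tokens = int(input("set max_new_tokens> "))
while text != "exit":
    input_ids = tokenizer.encode(text, return_tensors="pt")
    # generate() returns the prompt plus the continuation as token ids
    output_ids = model.generate(input_ids, max_new_tokens=max_new_tokens)
    print(tokenizer.batch_decode(output_ids, skip_special_tokens=True,
                                 clean_up_tokenization_spaces=False))
    print("=" * 45)
    text = input("> ")
```

The committed version additionally moves the encoded input to CUDA and, for autoregressive models, keeps a commented-out hook for appending AutoregLMDataset.IO_SEP to the prompt.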

2 changes: 1 addition & 1 deletion system/utils.py
@@ -128,7 +128,7 @@ def __init__(self, decode_method: str = "beam", add_score: bool = False, num_gen
         self.max_source_len = max_source_len
         self.max_target_len = max_target_len

-    def set(self, decode_method=None, add_score=None, num_generate=None, max_new_tokens=None, num_batch=None,
+    def config(self, decode_method=None, add_score=None, num_generate=None, max_new_tokens=None, num_batch=None,
               num_return_sequence=None, temperature=None, max_source_len=None, max_target_len=None):
         self.decode_method = decode_method if decode_method else self.decode_method
         self.add_score = add_score if add_score else self.add_score
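
Both set methods are renamed to config, and the config class in system/utils.py keeps its previous value for any argument left at None. A minimal, self-contained sketch of that partial-update pattern follows; the class name and the reduced field set are illustrative, not the repository's actual config class.

```python
# Sketch of the partial-update pattern used by config(); the class name and the
# reduced set of fields are assumptions for illustration only.
class GenerationConfig:
    def __init__(self, decode_method="beam", add_score=False, temperature=1.0):
        self.decode_method = decode_method
        self.add_score = add_score
        self.temperature = temperature

    def config(self, decode_method=None, add_score=None, temperature=None):
        # "x if x else self.x" keeps the stored value when None is passed; note
        # that falsy values (False, 0, 0.0) are also treated as "not provided".
        self.decode_method = decode_method if decode_method else self.decode_method
        self.add_score = add_score if add_score else self.add_score
        self.temperature = temperature if temperature else self.temperature

cfg = GenerationConfig()
cfg.config(temperature=0.7)   # only temperature changes; other fields keep defaults
print(cfg.decode_method, cfg.add_score, cfg.temperature)  # beam False 0.7
```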
