forked from HomebrewNLP/HomebrewNLP
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
71 lines (55 loc) · 2.27 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import argh
import pathlib
import typing
import yaml
from src.dataclass import Context
from src.utils.utils import setup_torch
from src.utils.formatting import syntax_print
from src.utils.preprocess import preprocess_data
from src.train.train import train_model
from src.inference.inference import complete
def get_context(config_path: typing.Optional[str] = None) -> 'Context':
    '''
    Build the run Context, optionally from a YAML config file.

    :param config_path: Path to a ``.yaml`` config file. When ``None``, a
        default ``Context()`` is constructed instead.
    :return: The constructed ``Context``.
    :raises ValueError: If ``config_path`` is given but does not have a
        ``.yaml`` suffix.
    '''
    if config_path is None:
        return Context()
    config = pathlib.Path(config_path)
    # Explicit raise instead of ``assert``: assertions are stripped under
    # ``python -O``, which would silently skip this user-input check.
    if config.suffix != '.yaml':
        raise ValueError(f'Expected a .yaml file for config_path, got {config_path!r}')
    return Context(config_path=config)
@argh.arg('-i', '--in_path', default='data.txt', help='Path for data to be preprocessed')
@argh.arg('-o', '--out_path', default='out.tensor', help='Path where the preprocessed data is written')
def preprocess(in_path: str = 'data.txt', out_path: str = "out.tensor"):
    '''
    Preprocess the raw data at ``in_path`` and write the result to ``out_path``.

    :param in_path: Path of the raw input data (``data.txt`` by default).
    :param out_path: Destination path for the preprocessed output
        (``out.tensor`` by default).
    '''
    preprocess_data(in_path, out_path)
@argh.arg('-c', '--config_path', default='configs/small.yaml', help='Path for the config file')
def train(config_path: typing.Optional[str] = None):
    '''
    Train a model as described by a config file.

    The resolved configuration is pretty-printed as YAML before training
    begins. From the command line ``config_path`` defaults to
    ``configs/small.yaml``; when called directly with ``None`` the default
    context is used.

    :param config_path: Path to a ``.yaml`` config file, or ``None``.
    '''
    context = get_context(config_path)
    setup_torch(0)
    syntax_print(yaml.dump(context.serialize(), indent=4), "yaml", title="Config")
    train_model(context)
@argh.arg('prompt', help='Input text to the model')
# Numeric defaults with an explicit ``type``: a string default such as '20'
# makes argh guess ``type=str``, so user-supplied values would arrive as text.
@argh.arg('-g', '--generated_tokens', type=int, default=20, help='Number of tokens to be generated after prompt')
@argh.arg('-t', '--temp', type=float, default=0.7,
          help='Temperature of the model: lower = consistency, higher = "creativity"')
@argh.arg('-c', '--config_path', help='Path for the config file')
def inference(prompt: str, generated_tokens: int = 20, temp: float = 0.7,
              config_path: typing.Optional[str] = None):
    '''
    Run inference on ``prompt`` with the model described by ``config_path``.

    Not implemented yet: after validating the arguments and loading the
    context, this always raises ``NotImplementedError``.

    :param prompt: Input text to the model.
    :param generated_tokens: Number of tokens to generate after the prompt.
    :param temp: Sampling temperature (lower = more consistent,
        higher = more "creative").
    :param config_path: Path to a ``.yaml`` config file; required.
    :raises ValueError: If ``config_path`` is not given.
    :raises NotImplementedError: Always, until model loading is implemented.
    '''
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if config_path is None:
        raise ValueError("Expected Config file!")
    ctx = get_context(config_path)  # kept: needed once model loading lands
    # TODO: Load model (pretrained)
    # complete(ctx, model, prompt, temp, generated_tokens)
    raise NotImplementedError('inference: loading a pretrained model is not implemented yet')
if __name__ == '__main__':
    # Register each command function as a CLI sub-command, then run the one
    # selected on the command line.
    cli = argh.ArghParser()
    cli.add_commands([preprocess, train, inference])
    cli.dispatch()