-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathsample.py
33 lines (24 loc) · 849 Bytes
/
sample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import LanguageModel
import torch
import sampling
import argparse
import modules as M
def main(argv=None):
    """Load a model checkpoint, prime it with a prompt, then stream samples.

    The prompt ``'Hello!\\n'`` is first sent through a put request. After that,
    the script repeatedly issues get requests and writes each decoded sampled
    sequence to stdout. The loop runs until the user interrupts it (Ctrl-C),
    which exits cleanly instead of printing a traceback.

    Args:
        argv: Optional list of command-line arguments to parse instead of
            ``sys.argv[1:]`` (handy for calling ``main`` programmatically).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', default='models/test.json',
                        help='path of the JSON checkpoint passed to '
                             'LanguageModel.load_json')
    args = parser.parse_args(argv)

    model = LanguageModel.LanguageModel()
    model.load_json(args.checkpoint)
    model.eval()
    # NOTE(review): if Sampler does not already disable autograd internally,
    # wrapping sampling in torch.no_grad() may save memory — confirm in sampling.

    sampler = sampling.Sampler(model)
    stor = M.DefaultStateStore(model)
    pc = sampling.default_put_chains(stor)
    # endtoken is the vocabulary index of b'\n'; presumably each get request
    # stops sampling at a newline — confirm in sampling.default_get_chains.
    gc = sampling.default_get_chains(stor, endtoken=[model.token_to_idx[b'\n']])

    # Feed the fixed prompt through the put chains before sampling starts.
    sampler.run_requests([sampler.make_put_request(pc, model.encode_string('Hello!\n'))])
    print('ok!')

    try:
        while True:
            req = sampler.make_get_request(gc)
            sampler.run_requests([req])
            # decode_string yields a bytes-like value (it is .decode()d here);
            # undecodable bytes are shown as \x.. escapes instead of raising.
            text = model.decode_string(req.sampled_sequence).decode(errors='backslashreplace')
            # Flush so streamed text appears immediately even when stdout is a pipe.
            print(text, end="", flush=True)
    except KeyboardInterrupt:
        # Ctrl-C is the only way out of the loop; finish the line and exit quietly.
        print()


if __name__ == "__main__":
    main()