-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathstatsrequest.py
48 lines (45 loc) · 1.69 KB
/
statsrequest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import time
import sampling
import logging
def decodeseq(model, seq):
    """Return a printable ``repr`` of ``seq`` decoded through ``model``.

    ``model.decode_string(seq)`` is called and its result is decoded as
    UTF-8, with undecodable bytes rendered as backslash escapes.  The
    result of ``decode_string`` must therefore expose a bytes-style
    ``.decode`` method.

    This is a best-effort helper for debug output: any exception raised
    while decoding is logged with its traceback and a ``'??? <seq>'``
    placeholder is used instead, so the caller never sees the failure.

    Args:
        model: object providing ``decode_string(seq)``.
        seq: the sequence to decode (passed to the model unchanged).

    Returns:
        ``repr()`` of the decoded text, or of the placeholder on failure.
    """
    try:
        raw = model.decode_string(seq)
        text = raw.decode('utf8', errors='backslashreplace')
    except Exception:
        # Deliberately broad: a report line must be produced no matter
        # what the model does with this sequence.
        logging.exception("Failed to decode sequence")
        return repr('??? %s\n ' % (str(seq)))
    return repr(text)
def req2str(model, reqs):
    """Render a human-readable, multi-line report of ``reqs``.

    Requests whose ``key`` is ``None`` are skipped.  For every other
    request the report contains a ``---- <key>`` header, one line per
    instance attribute, and one ``---- sample <i>`` section per entry of
    ``req.samples`` listing that sample's instance attributes.

    Attribute handling:
      * token sequences (``forced_input`` on requests, ``sampled_sequence``
        and ``input_tokens`` on samples) are decoded with ``decodeseq``;
      * bulky per-step sample attributes are summarised by their length;
      * a fixed set of internal attributes is omitted entirely;
      * everything else is formatted with ``%s``.

    Args:
        model: passed through to ``decodeseq`` for decoding sequences.
        reqs: iterable of request objects exposing ``key``, ``samples``
            and a ``__dict__``.

    Returns:
        The report as a single string (``''`` when nothing is reported).
    """
    hidden_request_fields = ('initial_state', 'samples', 'key', 'on_finish', 'chains')
    size_only_sample_fields = ('model_output_scores', 'states', 'model_output_probs',
                               'probs', 'model_next_states')
    decoded_sample_fields = ('sampled_sequence', 'input_tokens')
    hidden_sample_fields = ('model_input_token', 'model_input_state')

    # Collect fragments and join once instead of growing a string with +=.
    parts = []
    for req in reqs:
        if req.key is None:
            continue
        # Wrap the key in a one-element tuple: with a bare right operand,
        # a tuple-valued key would be unpacked as the format arguments and
        # either raise TypeError or print only its single element.
        parts.append('---- %s\n' % (req.key,))
        for name, value in vars(req).items():
            if name == 'forced_input':
                parts.append(' forced_input : %s\n' % (decodeseq(model, value),))
            elif name not in hidden_request_fields:
                parts.append(' %s : %s\n' % (name, value))
        for index, sample in enumerate(req.samples):
            parts.append(' ---- sample %d\n' % (index,))
            for name, value in vars(sample).items():
                if name in size_only_sample_fields:
                    # Too large to print; report only how many entries exist.
                    parts.append(' %s : [%d values]\n' % (name, len(value)))
                elif name in decoded_sample_fields:
                    parts.append(' %s : %s\n' % (name, decodeseq(model, value)))
                elif name not in hidden_sample_fields:
                    parts.append(' %s : %s\n' % (name, value))
    return ''.join(parts)
class StatsRequestModule:
    """Chain module that times a request and attaches a sampler report.

    ``forward`` stamps the request with a start time; ``backward`` stamps
    the end time, stores the elapsed seconds, and stores a textual dump
    of the sampler's current requests on the request object.
    """

    def __init__(self, sampler):
        # Object exposing ``.sampler.model`` and ``.requests`` (both are
        # read in ``backward``).
        self.sampler = sampler

    def forward(self, request):
        """Record the wall-clock start time on ``request.start_time``."""
        request.start_time = time.time()

    def backward(self, request):
        """Record end time, elapsed time and the request report.

        Sets ``request.end_time``, ``request.elapsed`` (seconds since
        ``request.start_time``) and ``request.requestinfo`` (output of
        ``req2str`` for the sampler's requests).
        """
        request.end_time = time.time()
        request.elapsed = request.end_time - request.start_time
        model = self.sampler.sampler.model
        active_requests = self.sampler.requests
        request.requestinfo = req2str(model, active_requests)
class StatsRequest(sampling.SamplerRequest):
    """Sampler request whose only job is to collect timing and stats.

    Its chains contain a single ``StatsRequestModule`` bound to
    ``sampler``; it has no key and starts with no samples.
    """

    def __init__(self, sampler):
        # NOTE(review): the base-class initializer is not invoked here;
        # only the three attributes below are set.  Confirm this is
        # intentional against sampling.SamplerRequest.
        stats_module = StatsRequestModule(sampler)
        self.chains = sampling.SamplerChains([stats_module], [], [])
        # key is None, so req2str leaves this request out of its report.
        self.key = None
        self.samples = []