Skip to content

Commit

Permalink
hotfix
Browse files Browse the repository at this point in the history
  • Loading branch information
zzudongxiang committed Feb 26, 2025
1 parent 576f0ec commit 1bbabee
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions utils/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def batch_generate(model, prompt_tokens, eos_id, warmup=False) -> List[List[int]
tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long)
for i, t in enumerate(prompt_tokens):
tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long)
ttft = 0
prev_pos = 0
finished = torch.tensor([False] * len(prompt_tokens))
prompt_mask = tokens != -1
Expand Down Expand Up @@ -55,13 +56,13 @@ def batch_generate(model, prompt_tokens, eos_id, warmup=False) -> List[List[int]
recv_next_token[i: i + mini_batch_size] = next_token
if ttft_flag:
ttft_flag = False
ttft = datetime.now() - t0
ttft = (datetime.now() - t0).total_seconds()
t0 = datetime.now()
else:
recv_next_token[i: i + mini_batch_size] = next_token
if ttft_flag:
ttft_flag = False
ttft = datetime.now() - t0
ttft = (datetime.now() - t0).total_seconds()
t0 = datetime.now()
for handle in recv_handles:
if not handle.is_completed():
Expand All @@ -86,6 +87,6 @@ def batch_generate(model, prompt_tokens, eos_id, warmup=False) -> List[List[int]
completion_tokens.append(toks)
tpot = output_tokens / dur
stop_progress(thread_tokens)
log_rank0(f"TTFT: {ttft.total_seconds():.4f} seconds ({min(prompt_lens)} tokens)")
log_rank0(f"TTFT: {ttft:.4f} seconds ({min(prompt_lens)} tokens)")
log_rank0(f"Throughput: {tpot:.4f} tokens/s ({output_tokens} tokens for {dur:.4f} seconds)")
return completion_tokens

0 comments on commit 1bbabee

Please sign in to comment.