-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathrun_vicuna_200k.py
74 lines (61 loc) · 2.75 KB
/
run_vicuna_200k.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# -*- coding:utf-8 -*-
try:
import fitz # PyMuPDF
except ImportError:
print("run: pip install PyMuPDF")
import os
import argparse
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM
from chunkllama_attn_replace import replace_with_chunkllama
from flash_attn_replace import replace_llama_attn_with_flash_attn
def load_model():
    """Load the causal LM named by ``model_path`` and return it in eval mode.

    Reads the module-level ``args`` (for the GPU index) and ``model_path``.
    The weights are loaded in bfloat16 with the ``flash_attention_2``
    attention implementation, then moved to ``cuda:<args.gpu>`` when CUDA is
    available and to the CPU otherwise.
    """
    target = f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(target)
    llama = LlamaForCausalLM.from_pretrained(
        model_path,
        attn_implementation="flash_attention_2",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    llama = llama.to(device)
    # eval() returns the module itself, so it can be returned directly.
    return llama.eval()
def parse_pdf2text(filename):
    """Extract the plain text of a PDF, tagging each page with its number.

    Args:
        filename: Path of the PDF file to open with PyMuPDF (``fitz``).

    Returns:
        A ``(sys_prompt, text)`` tuple on success, where ``text`` is the
        concatenation of every page's text, each prefixed by ``<Page N>: ``
        (N starting at 1). ``None`` when the file cannot be opened or parsed.
    """
    try:
        # The context manager closes the document even if a page fails to parse.
        with fitz.open(filename) as doc:
            pages = [
                f"<Page {page_no}>: " + page.get_text()
                for page_no, page in enumerate(doc, start=1)
            ]
    except Exception as err:
        # `Exception` rather than a bare `except` lets Ctrl-C / SystemExit
        # propagate; it still covers a missing `fitz` (NameError) and I/O or
        # parsing failures, and the reason is now reported.
        print("unable to parse", filename, f"({err!r})")
        return None

    # Join once instead of growing a string page by page.
    text = "".join(pages)
    print("read from: ", filename)
    sys_prompt = "You are given a long paper. Please read the paper and answer the question.\n\n"
    return sys_prompt, text
def add_argument(argv=None):
    """Parse the command-line options of the script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``gpu`` (str), ``scale`` (str),
        ``pdf`` (str) and ``max_length`` (int).
    """
    parser = argparse.ArgumentParser(
        description="Chat with a Vicuna model about the contents of a PDF.")
    parser.add_argument('--gpu', type=str, default="0",
                        help="CUDA device index to run the model on")
    parser.add_argument('--scale', type=str, default="7b",
                        help="model size used in the checkpoint name, e.g. 7b")
    parser.add_argument('--pdf', type=str, default="Popular_PDFs/chunkllama.pdf",
                        help="path of the PDF file to read")
    parser.add_argument('--max_length', type=int, default=64000,
                        help="maximum prompt length in tokens before truncation")
    return parser.parse_args(argv)
# Script body: parse options, read the PDF, load tokenizer and model, then
# answer up to 100 questions about the document interactively.
args = add_argument()
model_path = f"lmsys/vicuna-{args.scale}-v1.5-16k"

# Read the document first so a bad path fails before the expensive model load.
# parse_pdf2text returns None on failure, which cannot be unpacked.
parsed = parse_pdf2text(args.pdf)
if parsed is None:
    raise SystemExit(f"Could not read any text from {args.pdf!r}; nothing to chat about.")
sys_prompt, content = parsed

# truncation_side="left": over-long prompts lose tokens from the start,
# so the question at the end of the prompt is kept.
tokenizer = LlamaTokenizer.from_pretrained(
    model_path,
    model_max_length=args.max_length,
    truncation_side="left",
    trust_remote_code=True,
)
# chunk attention (must be patched in before the model is instantiated)
# NOTE(review): 16384 presumably matches the 16k context of the checkpoint - confirm.
replace_with_chunkllama(pretraining_length=16384)
# original flash attention
# replace_llama_attn_with_flash_attn()
model = load_model()

for i in range(100):
    try:
        question = input("User: ")
    except (EOFError, KeyboardInterrupt):
        # End the chat cleanly on Ctrl-D / Ctrl-C instead of a traceback.
        print()
        break
    message = sys_prompt + "USER: " + content + f"Question:\n{question}" + "\nASSISTANT:"
    # First pass without truncation, only to measure and warn about the length.
    prompt_length = tokenizer(message, return_tensors="pt").input_ids.size()[-1]
    if prompt_length > args.max_length:
        print("=" * 20)
        print(f"Your input length is {prompt_length}, and it will be truncated to {args.max_length}. You can set `--max_length` to a larger value ")
        print("=" * 20)
    inputs = tokenizer(message, truncation=True, return_tensors="pt").to(model.device)
    inp_length = inputs.input_ids.size()[-1]
    sample = model.generate(**inputs, do_sample=False, max_new_tokens=128)
    # Decode only the newly generated tokens and drop special markers such
    # as the end-of-sequence token from the printed reply.
    output = tokenizer.decode(sample[0][inp_length:], skip_special_tokens=True)
    print("Chatbot:", output)
    print(f"---------------End of round{i}------------------")