feat: Added Gemma support (needs hf auth token)
Commit d3fb4a3 (parent: 06281ff)
Showing 7 changed files with 196 additions and 62 deletions.
@@ -0,0 +1,101 @@
import logging
import os
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from ..presets import *
from .base_model import BaseLLMModel

# Lazily initialized module-level singletons so the weights are loaded once
# and shared across client instances.
GEMMA_TOKENIZER = None
GEMMA_MODEL = None


class GoogleGemmaClient(BaseLLMModel):
    def __init__(self, model_name, api_key, user_name="") -> None:
        super().__init__(model_name=model_name, user=user_name)

        global GEMMA_TOKENIZER, GEMMA_MODEL
        # self.deinitialize()
        self.max_generation_token = self.token_upper_limit
        if GEMMA_TOKENIZER is None or GEMMA_MODEL is None:
            # Prefer a local copy under models/ if one exists; otherwise fall
            # back to the Hugging Face repo id from MODEL_METADATA.
            model_path = None
            if os.path.exists("models"):
                model_dirs = os.listdir("models")
                if model_name in model_dirs:
                    model_path = f"models/{model_name}"
            if model_path is not None:
                model_source = model_path
            else:
                if os.path.exists(
                    os.path.join("models", MODEL_METADATA[model_name]["model_name"])
                ):
                    model_source = os.path.join(
                        "models", MODEL_METADATA[model_name]["model_name"]
                    )
                else:
                    try:
                        model_source = MODEL_METADATA[model_name]["repo_id"]
                    except KeyError:
                        model_source = model_name
            dtype = torch.bfloat16
            # Gemma checkpoints are gated, so a Hugging Face access token is
            # required via the HF_AUTH_TOKEN environment variable.
            GEMMA_TOKENIZER = AutoTokenizer.from_pretrained(
                model_source, use_auth_token=os.environ["HF_AUTH_TOKEN"]
            )
            GEMMA_MODEL = AutoModelForCausalLM.from_pretrained(
                model_source,
                device_map="auto",
                torch_dtype=dtype,
                trust_remote_code=True,
                resume_download=True,
                use_auth_token=os.environ["HF_AUTH_TOKEN"],
            )

    def deinitialize(self):
        # Drop the shared tokenizer/model and free GPU memory.
        global GEMMA_TOKENIZER, GEMMA_MODEL
        GEMMA_TOKENIZER = None
        GEMMA_MODEL = None
        self.clear_cuda_cache()
        logging.info("GEMMA deinitialized")

    def _get_gemma_style_input(self):
        global GEMMA_TOKENIZER
        # messages = [{"role": "system", "content": self.system_prompt}, *self.history]  # system prompt is not supported
        messages = self.history
        # Render the history with Gemma's chat template
        # (<start_of_turn> ... <end_of_turn> markers) and tokenize it.
        prompt = GEMMA_TOKENIZER.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = GEMMA_TOKENIZER.encode(
            prompt, add_special_tokens=True, return_tensors="pt"
        )
        return inputs

    def get_answer_at_once(self):
        global GEMMA_TOKENIZER, GEMMA_MODEL
        inputs = self._get_gemma_style_input()
        outputs = GEMMA_MODEL.generate(
            input_ids=inputs.to(GEMMA_MODEL.device),
            max_new_tokens=self.max_generation_token,
        )
        generated_token_count = outputs.shape[1] - inputs.shape[1]
        outputs = GEMMA_TOKENIZER.decode(outputs[0], skip_special_tokens=True)
        # Keep only the text after the last model-turn marker and trim the
        # trailing end-of-turn characters.
        outputs = outputs.split("<start_of_turn>model\n")[-1][:-5]
        self.clear_cuda_cache()
        return outputs, generated_token_count

    def get_answer_stream_iter(self):
        global GEMMA_TOKENIZER, GEMMA_MODEL
        inputs = self._get_gemma_style_input()
        # Run generation in a background thread and yield the accumulated
        # text as the streamer produces new tokens.
        streamer = TextIteratorStreamer(
            GEMMA_TOKENIZER, timeout=10.0, skip_prompt=True, skip_special_tokens=True
        )
        input_kwargs = dict(
            input_ids=inputs.to(GEMMA_MODEL.device),
            max_new_tokens=self.max_generation_token,
            streamer=streamer,
        )
        t = Thread(target=GEMMA_MODEL.generate, kwargs=input_kwargs)
        t.start()

        partial_text = ""
        for new_text in streamer:
            partial_text += new_text
            yield partial_text
        self.clear_cuda_cache()
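For reference, a minimal usage sketch of the streaming path, assuming the surrounding app populates self.history with role/content dicts in the form apply_chat_template expects; the model name, api_key value, and history below are illustrative and not part of this commit:

client = GoogleGemmaClient("gemma-2b-it", api_key=None)          # illustrative model name
client.history = [{"role": "user", "content": "Hello, Gemma!"}]  # assumed history format
for partial_answer in client.get_answer_stream_iter():
    print(partial_answer)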