
Commit

feat: Added Gemma support (needs hf auth token)
GaiZhenbiao committed Feb 25, 2024
1 parent 06281ff commit d3fb4a3
Showing 7 changed files with 196 additions and 62 deletions.
1 change: 1 addition & 0 deletions config_example.json
@@ -17,6 +17,7 @@
"claude_api_secret":"",// 你的 Claude API Secret,用于 Claude 对话模型
"ernie_api_key": "",// 你的文心一言在百度云中的API Key,用于文心一言对话模型
"ernie_secret_key": "",// 你的文心一言在百度云中的Secret Key,用于文心一言对话模型
"huggingface_auth_token": "", // 你的 Hugging Face API Token,用于访问有限制的模型


//== Azure ==
4 changes: 4 additions & 0 deletions modules/config.py
@@ -116,6 +116,10 @@ def load_config_to_environ(key_list):
    google_genai_api_key = config.get("google_genai_api_key", google_genai_api_key)
    os.environ["GOOGLE_GENAI_API_KEY"] = google_genai_api_key

    huggingface_auth_token = os.environ.get("HF_AUTH_TOKEN", "")
    # Read the same key that config_example.json defines ("huggingface_auth_token").
    huggingface_auth_token = config.get("huggingface_auth_token", huggingface_auth_token)
    os.environ["HF_AUTH_TOKEN"] = huggingface_auth_token

    xmchat_api_key = config.get("xmchat_api_key", "")
    os.environ["XMCHAT_API_KEY"] = xmchat_api_key

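For reference, a minimal sketch of the precedence this hunk gives the token: environment variable as the default, config-file value overriding it, then re-exported so GoogleGemma.py can read it. The inline config dict below is a stand-in for the parsed config file, not code from this commit.

import os

config = {"huggingface_auth_token": "hf_xxx"}        # stand-in for the parsed config file

token = os.environ.get("HF_AUTH_TOKEN", "")          # 1) environment variable as the default
token = config.get("huggingface_auth_token", token)  # 2) config-file value wins if present
os.environ["HF_AUTH_TOKEN"] = token                  # 3) re-exported for GoogleGemma.py to read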
101 changes: 101 additions & 0 deletions modules/models/GoogleGemma.py
@@ -0,0 +1,101 @@
import logging
import os
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from ..presets import *
from .base_model import BaseLLMModel


class GoogleGemmaClient(BaseLLMModel):
    def __init__(self, model_name, api_key, user_name="") -> None:
        super().__init__(model_name=model_name, user=user_name)

        global GEMMA_TOKENIZER, GEMMA_MODEL
        # self.deinitialize()
        self.max_generation_token = self.token_upper_limit
        if GEMMA_TOKENIZER is None or GEMMA_MODEL is None:
            # Prefer a locally downloaded copy under models/; otherwise fall back to the Hugging Face repo.
            model_path = None
            if os.path.exists("models"):
                model_dirs = os.listdir("models")
                if model_name in model_dirs:
                    model_path = f"models/{model_name}"
            if model_path is not None:
                model_source = model_path
            else:
                if os.path.exists(
                    os.path.join("models", MODEL_METADATA[model_name]["model_name"])
                ):
                    model_source = os.path.join(
                        "models", MODEL_METADATA[model_name]["model_name"]
                    )
                else:
                    try:
                        model_source = MODEL_METADATA[model_name]["repo_id"]
                    except KeyError:
                        model_source = model_name
            dtype = torch.bfloat16
            GEMMA_TOKENIZER = AutoTokenizer.from_pretrained(
                model_source, use_auth_token=os.environ["HF_AUTH_TOKEN"]
            )
            GEMMA_MODEL = AutoModelForCausalLM.from_pretrained(
                model_source,
                device_map="auto",
                torch_dtype=dtype,
                trust_remote_code=True,
                resume_download=True,
                use_auth_token=os.environ["HF_AUTH_TOKEN"],
            )

    def deinitialize(self):
        global GEMMA_TOKENIZER, GEMMA_MODEL
        GEMMA_TOKENIZER = None
        GEMMA_MODEL = None
        self.clear_cuda_cache()
        logging.info("GEMMA deinitialized")

    def _get_gemma_style_input(self):
        global GEMMA_TOKENIZER
        # messages = [{"role": "system", "content": self.system_prompt}, *self.history]  # system prompt is not supported
        messages = self.history
        prompt = GEMMA_TOKENIZER.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = GEMMA_TOKENIZER.encode(
            prompt, add_special_tokens=True, return_tensors="pt"
        )
        return inputs

    def get_answer_at_once(self):
        global GEMMA_TOKENIZER, GEMMA_MODEL
        inputs = self._get_gemma_style_input()
        outputs = GEMMA_MODEL.generate(
            input_ids=inputs.to(GEMMA_MODEL.device),
            max_new_tokens=self.max_generation_token,
        )
        generated_token_count = outputs.shape[1] - inputs.shape[1]
        outputs = GEMMA_TOKENIZER.decode(outputs[0], skip_special_tokens=True)
        # Trim the decoded output down to the model's reply.
        outputs = outputs.split("<start_of_turn>model\n")[-1][:-5]
        self.clear_cuda_cache()
        return outputs, generated_token_count

    def get_answer_stream_iter(self):
        global GEMMA_TOKENIZER, GEMMA_MODEL
        inputs = self._get_gemma_style_input()
        streamer = TextIteratorStreamer(
            GEMMA_TOKENIZER, timeout=10.0, skip_prompt=True, skip_special_tokens=True
        )
        input_kwargs = dict(
            input_ids=inputs.to(GEMMA_MODEL.device),
            max_new_tokens=self.max_generation_token,
            streamer=streamer,
        )
        # Run generation in a background thread so tokens can be yielded as they arrive.
        t = Thread(target=GEMMA_MODEL.generate, kwargs=input_kwargs)
        t.start()

        partial_text = ""
        for new_text in streamer:
            partial_text += new_text
            yield partial_text
        self.clear_cuda_cache()
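As a companion to the client above, here is a self-contained sketch of the same pattern (chat template in, TextIteratorStreamer out) using plain transformers. The repo id, prompt, and max_new_tokens value are illustrative; a CUDA GPU, accelerate, and an HF_AUTH_TOKEN with Gemma access are assumed.

import os
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

repo_id = "google/gemma-2b-it"                      # illustrative; any Gemma instruct checkpoint works
hf_token = os.environ.get("HF_AUTH_TOKEN")

tokenizer = AutoTokenizer.from_pretrained(repo_id, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(
    repo_id, device_map="auto", torch_dtype=torch.bfloat16, token=hf_token
)

# Gemma's chat template only accepts "user"/"model" turns; there is no system role.
messages = [{"role": "user", "content": "Explain what a context window is."}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer.encode(prompt, add_special_tokens=True, return_tensors="pt").to(model.device)

# Generation runs in a background thread; the streamer yields decoded text as it is produced.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
Thread(target=model.generate, kwargs=dict(input_ids=inputs, max_new_tokens=256, streamer=streamer)).start()

for chunk in streamer:
    print(chunk, end="", flush=True)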
34 changes: 2 additions & 32 deletions modules/models/LLaMA.py
@@ -2,14 +2,12 @@

import json
import os

from huggingface_hub import hf_hub_download
from llama_cpp import Llama

from ..index_func import *
from ..presets import *
from ..utils import *
from .base_model import BaseLLMModel
from .base_model import BaseLLMModel, download

SYS_PREFIX = "<<SYS>>\n"
SYS_POSTFIX = "\n<</SYS>>\n\n"
@@ -19,34 +17,6 @@
OUTPUT_POSTFIX = "</s>"


def download(repo_id, filename, retry=10):
    if os.path.exists("./models/downloaded_models.json"):
        with open("./models/downloaded_models.json", "r") as f:
            downloaded_models = json.load(f)
        if repo_id in downloaded_models:
            return downloaded_models[repo_id]["path"]
    else:
        downloaded_models = {}
    while retry > 0:
        try:
            model_path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                cache_dir="models",
                resume_download=True,
            )
            downloaded_models[repo_id] = {"path": model_path}
            with open("./models/downloaded_models.json", "w") as f:
                json.dump(downloaded_models, f)
            break
        except:
            print("Error downloading model, retrying...")
            retry -= 1
    if retry == 0:
        raise Exception("Error downloading model, please try again later.")
    return model_path


class LLaMA_Client(BaseLLMModel):
    def __init__(self, model_name, lora_path=None, user_name="") -> None:
        super().__init__(model_name=model_name, user=user_name)
@@ -115,7 +85,7 @@ def get_answer_stream_iter(self):
        iter = self.model(
            context,
            max_tokens=self.max_generation_token,
            stop=[SYS_PREFIX, SYS_POSTFIX, INST_PREFIX, OUTPUT_PREFIX,OUTPUT_POSTFIX],
            stop=[SYS_PREFIX, SYS_POSTFIX, INST_PREFIX, OUTPUT_PREFIX, OUTPUT_POSTFIX],
            echo=False,
            stream=True,
        )
89 changes: 61 additions & 28 deletions modules/models/base_model.py
@@ -1,43 +1,41 @@
from __future__ import annotations
from typing import TYPE_CHECKING, List

import logging
import asyncio
import gc
import json
import commentjson as cjson
import logging
import os
import sys
import requests
import urllib3
import traceback
import pathlib
import shutil
import sys
import traceback
from collections import deque
from enum import Enum
from itertools import islice
from threading import Condition, Thread
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from tqdm import tqdm
import aiohttp
import colorama
import commentjson as cjson
import requests
import torch
import urllib3
from duckduckgo_search import DDGS
from itertools import islice
import asyncio
import aiohttp
from enum import Enum

from huggingface_hub import hf_hub_download
from langchain.callbacks.base import BaseCallbackHandler, BaseCallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.callbacks.base import BaseCallbackManager

from typing import Any, Dict, List, Optional, Union

from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
from langchain.schema import AgentAction, AgentFinish, LLMResult
from threading import Thread, Condition
from collections import deque
from langchain.chat_models.base import BaseChatModel
from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage
from langchain.input import print_text
from langchain.schema import (AgentAction, AgentFinish, AIMessage, BaseMessage,
HumanMessage, LLMResult, SystemMessage)
from tqdm import tqdm

from ..presets import *
from ..index_func import *
from ..utils import *
from .. import shared
from ..config import retrieve_proxy
from ..index_func import *
from ..presets import *
from ..utils import *


class CallbackToIterator:
@@ -155,6 +153,7 @@ class ModelType(Enum):
    ERNIE = 17
    DALLE3 = 18
    GoogleGemini = 19
    GoogleGemma = 20

    @classmethod
    def get_type(cls, model_name: str):
@@ -201,11 +200,41 @@ def get_type(cls, model_name: str):
            model_type = ModelType.ERNIE
        elif "dall" in model_name_lower:
            model_type = ModelType.DALLE3
        elif "gemma" in model_name_lower:
            model_type = ModelType.GoogleGemma
        else:
            model_type = ModelType.LLaMA
        return model_type


def download(repo_id, filename, retry=10):
    if os.path.exists("./models/downloaded_models.json"):
        with open("./models/downloaded_models.json", "r") as f:
            downloaded_models = json.load(f)
        if repo_id in downloaded_models:
            return downloaded_models[repo_id]["path"]
    else:
        downloaded_models = {}
    while retry > 0:
        try:
            model_path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                cache_dir="models",
                resume_download=True,
            )
            downloaded_models[repo_id] = {"path": model_path}
            with open("./models/downloaded_models.json", "w") as f:
                json.dump(downloaded_models, f)
            break
        except:
            print("Error downloading model, retrying...")
            retry -= 1
    if retry == 0:
        raise Exception("Error downloading model, please try again later.")
    return model_path


class BaseLLMModel:
    def __init__(
        self,
@@ -371,10 +400,10 @@ def summarize_index(self, files, chatbot, language):
        status = i18n("总结完成")
        logging.info(i18n("生成内容总结中……"))
        os.environ["OPENAI_API_KEY"] = self.api_key
        from langchain.callbacks import StdOutCallbackHandler
        from langchain.chains.summarize import load_summarize_chain
        from langchain.prompts import PromptTemplate
        from langchain.chat_models import ChatOpenAI
        from langchain.callbacks import StdOutCallbackHandler
        from langchain.prompts import PromptTemplate

        prompt_template = (
            "Write a concise summary of the following:\n\n{text}\n\nCONCISE SUMMARY IN "
@@ -1055,6 +1084,10 @@ def deinitialize(self):
"""deinitialize the model, implement if needed"""
pass

def clear_cuda_cache(self):
gc.collect()
torch.cuda.empty_cache()


class Base_Chat_Langchain_Client(BaseLLMModel):
    def __init__(self, model_name, user_name=""):
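The download helper shown above simply moved from LLaMA.py into base_model.py so other model clients can reuse it. A hedged usage sketch follows; the repo id and GGUF filename are illustrative, not values from this commit.

from modules.models.base_model import download

# Downloads into the models/ cache dir and records the path in models/downloaded_models.json,
# so repeated calls for the same repo return the cached path immediately.
path = download(
    repo_id="TheBloke/Llama-2-7B-Chat-GGUF",
    filename="llama-2-7b-chat.Q4_K_M.gguf",
)
print(path)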
8 changes: 7 additions & 1 deletion modules/models/models.py
@@ -138,8 +138,14 @@ def get_model(
            from .DALLE3 import OpenAI_DALLE3_Client
            access_key = os.environ.get("OPENAI_API_KEY", access_key)
            model = OpenAI_DALLE3_Client(model_name, api_key=access_key, user_name=user_name)
        elif model_type == ModelType.GoogleGemma:
            from .GoogleGemma import GoogleGemmaClient
            model = GoogleGemmaClient(
                model_name, access_key, user_name=user_name)
        elif model_type == ModelType.Unknown:
            raise ValueError(f"未知模型: {model_name}")
            raise ValueError(f"Unknown model: {model_name}")
        else:
            raise ValueError(f"Unimplemented model type: {model_type}")
        logging.info(msg)
    except Exception as e:
        import traceback
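The dispatch above relies on ModelType.get_type, which matches on substrings of the lowercased model name. A simplified, stand-alone sketch of that routing; the route function is illustrative and not part of the codebase.

def route(model_name: str) -> str:
    """Simplified mirror of ModelType.get_type's substring matching."""
    name = model_name.lower()
    if "gemma" in name:
        return "GoogleGemma"
    if "dall" in name:
        return "DALLE3"
    return "LLaMA"  # fallback branch, as in the original dispatcher

assert route("Gemma 2B") == "GoogleGemma"
assert route("Gemma 7B") == "GoogleGemma"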
21 changes: 20 additions & 1 deletion modules/presets.py
@@ -10,6 +10,8 @@
CHATGLM_TOKENIZER = None
LLAMA_MODEL = None
LLAMA_INFERENCER = None
GEMMA_MODEL = None
GEMMA_TOKENIZER = None

# ChatGPT settings
INITIAL_SYSTEM_PROMPT = "You are a helpful assistant."
@@ -67,6 +69,8 @@
"Gemini Pro",
"Gemini Pro Vision",
"GooglePaLM",
"Gemma 2B",
"Gemma 7B",
"xmchat",
"Azure OpenAI",
"yuanai-1.0-base_10B",
@@ -178,6 +182,16 @@
"Gemini Pro Vision": {
"model_name": "gemini-pro-vision",
"token_limit": 30720,
},
"Gemma 2B": {
"repo_id": "google/gemma-2b-it",
"model_name": "gemma-2b-it",
"token_limit": 8192,
},
"Gemma 7B": {
"repo_id": "google/gemma-7b-it",
"model_name": "gemma-7b-it",
"token_limit": 8192,
}
}

@@ -193,7 +207,12 @@
os.makedirs("history", exist_ok=True)
for dir_name in os.listdir("models"):
    if os.path.isdir(os.path.join("models", dir_name)):
        if dir_name not in MODELS:
        display_name = None
        for model_name, metadata in MODEL_METADATA.items():
            if "model_name" in metadata and metadata["model_name"] == dir_name:
                display_name = model_name
                break
        if display_name is None:
            MODELS.append(dir_name)

TOKEN_OFFSET = 1000  # Subtracted from the model's token limit to get a soft cap; once the soft cap is reached, the app automatically tries to reduce token usage.
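The new loop above maps a locally downloaded folder (e.g. models/gemma-2b-it) back to its display name so the model is not listed twice. A small stand-alone sketch of that lookup, using only the metadata entries added in this commit; the display_name_for helper is illustrative.

MODEL_METADATA = {
    "Gemma 2B": {"repo_id": "google/gemma-2b-it", "model_name": "gemma-2b-it", "token_limit": 8192},
    "Gemma 7B": {"repo_id": "google/gemma-7b-it", "model_name": "gemma-7b-it", "token_limit": 8192},
}

def display_name_for(dir_name):
    for model_name, metadata in MODEL_METADATA.items():
        if "model_name" in metadata and metadata["model_name"] == dir_name:
            return model_name
    return None  # unknown folders keep their raw directory name, as in the loop above

print(display_name_for("gemma-2b-it"))  # -> Gemma 2B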
