try to get backends running on mac - part I
Signed-off-by: julianbollig <[email protected]>
julianbollig committed Jan 31, 2025
1 parent d7f4d98 commit d8969fd
Showing 21 changed files with 73 additions and 104 deletions.
2 changes: 1 addition & 1 deletion LlamaCPP/llama_web_api.py
@@ -1,5 +1,5 @@
import os
os.environ['PATH'] = os.path.abspath('../llama-cpp-env/Library/bin') + os.pathsep + os.environ['PATH']
os.environ['PATH'] = os.path.abspath('../llama-cpp-env-2/Library/bin') + os.pathsep + os.environ['PATH']

from apiflask import APIFlask
from flask import jsonify, request, Response, stream_with_context
2 changes: 1 addition & 1 deletion LlamaCPP/model_config.py
@@ -2,4 +2,4 @@
"ggufLLM": "../service/models/llm/ggufLLM",
}

device = "xpu"
device = "mps"
2 changes: 1 addition & 1 deletion WebUI/electron/subprocesses/aiBackendService.ts
@@ -33,7 +33,7 @@ export class AiBackendService extends LongLivedPythonApiService {
// lsLevelZero will ensure uv and pip are installed
await this.lsLevelZero.ensureInstalled()

const deviceArch: string = 'mac'
const deviceArch: string = 'mps'
yield {
serviceName: self.name,
step: `Detecting intel device`,
2 changes: 1 addition & 1 deletion WebUI/electron/subprocesses/comfyUIBackendService.ts
@@ -118,7 +118,7 @@ export class ComfyUiBackendService extends LongLivedPythonApiService {
status: 'executing',
debugMessage: `Trying to identify intel hardware`,
}
const deviceArch: string = 'mac'
const deviceArch: string = 'mps'
yield {
serviceName: self.name,
step: `Detecting intel device`,
2 changes: 1 addition & 1 deletion WebUI/electron/subprocesses/llamaCppBackendService.ts
@@ -39,7 +39,7 @@ export class LlamaCppBackendService extends LongLivedPythonApiService {
await this.lsLevelZero.ensureInstalled()
await this.uvPip.ensureInstalled()

const deviceArch: string = 'mac'
const deviceArch: string = 'mps'
yield {
serviceName: self.name,
step: `Detecting intel device`,
32 changes: 11 additions & 21 deletions WebUI/src/assets/js/store/globalSetup.ts
@@ -101,27 +101,20 @@ export const useGlobalSetup = defineStore('globalSetup', () => {
models.value.scheduler.push(...(await initWebSettings(postJson)))
models.value.scheduler.unshift('None')
break
} catch (error) {
} catch (_error: unknown) {
await util.delay(delay)
}
}
await reloadGraphics()
// if (graphicsList.value.length == 0) {
// await window.electronAPI.showMessageBoxSync({ message: useI18N().state.ERROR_UNFOUND_GRAPHICS, title: "error", icon: "error" });
// window.electronAPI.exitApp();
// }
await loadUserSettings()

// isComfyUiInstalled.value = await isComfyUIDownloaded()
// if (isComfyUiInstalled.value) {
// window.electronAPI.wakeupComfyUIService()
// setTimeout(() => {
// //requires proper feedback on server startup...
// useComfyUi().updateComfyState()
// loadingState.value = "running";
// }, 10000);
// } else {
// loadingState.value = "running";
if (graphicsList.value.length == 0) {
await window.electronAPI.showMessageBoxSync({
message: useI18N().state.ERROR_UNFOUND_GRAPHICS,
title: 'error',
icon: 'error',
})
window.electronAPI.exitApp()
}
loadUserSettings()
}

async function initWebSettings(postJson: string) {
@@ -266,10 +259,7 @@ export const useGlobalSetup = defineStore('globalSetup', () => {
modelSettings.lora = models.value.lora[0]
changeUserSetup = true
}
if (
!graphicsList.value.find((item) => item.index == modelSettings.graphics) &&
graphicsList.value.length != 0
) {
if (!graphicsList.value.find((item) => item.index == modelSettings.graphics)) {
modelSettings.graphics = graphicsList.value[0].index
}
if (changeUserSetup) {
16 changes: 3 additions & 13 deletions service/aipg_utils.py
@@ -234,21 +234,11 @@ def get_ESRGAN_size():
return int(response.headers.get("Content-Length"))


def get_support_graphics(env_type: str):

device_count = torch.xpu.device_count()
def get_support_graphics():
device_count = torch.mps.device_count()
graphics = list()
for i in range(device_count):
device_name = torch.xpu.get_device_name(i)
print('device_name', device_name)
if device_name == "Intel(R) Arc(TM) Graphics" or re.search("Intel\(R\) Arc\(TM\)", device_name) is not None:
graphics.append({"index": i, "name": device_name})
device_count = torch.cuda.device_count()
print('cuda device_count:', device_count)
service_config.env_type = env_type
for i in range(device_count):
device_name = torch.cuda.get_device_name(i)
print('device_name', device_name)
device_name = "mps"
graphics.append({"index": i, "name": device_name})
return graphics
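
The rewritten helper relies on torch.mps.device_count(), which is not available in every PyTorch build. A hedged alternative sketch that sticks to the longer-standing torch.backends.mps API (an assumption, not what this commit ships):

import torch

def get_support_graphics():
    # MPS exposes a single logical device; report it only when the backend is usable.
    graphics = []
    if torch.backends.mps.is_built() and torch.backends.mps.is_available():
        graphics.append({"index": 0, "name": "mps"})
    return graphics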

4 changes: 2 additions & 2 deletions service/attention.py
@@ -203,7 +203,7 @@ def torch_bmm_32_bit(input, mat2, *, out=None):
hidden_states[start_idx:end_idx] = original_torch_bmm(
input[start_idx:end_idx], mat2[start_idx:end_idx], out=out
)
torch.xpu.synchronize(input.device)
torch.mps.synchronize(input.device)
else:
return original_torch_bmm(input, mat2, out=out)
return hidden_states
@@ -314,7 +314,7 @@ def scaled_dot_product_attention_32_bit(
**kwargs,
)
)
torch.xpu.synchronize(query.device)
torch.mps.synchronize(query.device)
else:
return original_scaled_dot_product_attention(
query,
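
Unlike torch.xpu.synchronize(device), torch.mps.synchronize() takes no device argument in current PyTorch builds, so passing input.device or query.device as above may raise a TypeError. A backend-agnostic wrapper sketch (sync_device is a hypothetical helper):

import torch

def sync_device(device: torch.device) -> None:
    # Wait for queued work on whichever backend the tensor lives on.
    if device.type == "mps":
        torch.mps.synchronize()  # the MPS call accepts no device argument
    elif device.type == "cuda":
        torch.cuda.synchronize(device)
    # CPU tensors need no explicit synchronization.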
2 changes: 1 addition & 1 deletion service/lama.py
@@ -88,7 +88,7 @@ def scale_image(img, factor, interpolation=cv2.INTER_AREA):

class SimpleLama:
def __init__(self):
self.device = "xpu"
self.device = "mps"
model_path = "C:\\Users\\X\\Downloads\\big-lama.pt"
self.model = torch.jit.load(model_path)
self.model.eval()
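
SimpleLama now targets "mps" but still loads its TorchScript checkpoint from a hard-coded Windows path. A hedged sketch of loading directly onto the selected device (the model_path below is illustrative, not the project's actual location):

import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"
model_path = "models/big-lama.pt"  # illustrative path
model = torch.jit.load(model_path, map_location=device)  # place weights on the target device
model.eval()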
23 changes: 7 additions & 16 deletions service/llm_biz.py
@@ -15,6 +15,7 @@
AutoTokenizer,
PreTrainedModel,
PreTrainedTokenizer,
AutoModelForCausalLM
)

#from ipex_llm.transformers import AutoModelForCausalLM
@@ -26,16 +27,6 @@
)
import service_config


# import ipex_llm.transformers.models.mistral

# W/A for https://github.com/intel/AI-Playground/issues/94
# Disable decoding_fast_path to avoid calling forward_qkv() which is not supported by bigdl-core-xe-*-23
# ipex_llm.transformers.models.mistral.use_decoding_fast_path = (
# lambda *args, **kwargs: False
# )


class LLMParams:
prompt: List[Dict[str, str]]
device: int
@@ -182,8 +173,8 @@ def chat(
# if prev genera not finish, stop it
stop_generate()

torch.cuda.set_device(params.device)
service_config.device = f"cuda:{params.device}"
#torch.mps.set_device(params.device)
service_config.device = f"mps:{params.device}"
prompt = params.prompt
enable_rag = params.enable_rag
model_repo_id = params.model_repo_id
@@ -198,7 +189,7 @@
if _model is not None:
del _model
gc.collect()
torch.cuda.empty_cache()
torch.mps.empty_cache()

model_base_path = service_config.service_model_paths.get("llm")
model_name = model_repo_id.replace("/", "---")
@@ -215,7 +206,7 @@
model_path,
torch_dtype=torch.float16,
trust_remote_code=True,
load_in_low_bit=load_in_low_bit,
# load_in_low_bit=load_in_low_bit,
# load_in_4bit=True,
)

@@ -264,7 +255,7 @@ def chat(
text_out_callback(stream_output, 1)

last_token_time = time.time()
torch.xpu.empty_cache()
torch.mps.empty_cache()
if params.print_metrics:
logging.info(f"""
----------inference finish----------
@@ -302,7 +293,7 @@ def dispose():
del _model
_model = None
gc.collect()
torch.cuda.empty_cache()
torch.mps.empty_cache()


class StopGenerateException(Exception):
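
The disposal paths now call torch.mps.empty_cache(), and MPS is a single-device backend, so the index in f"mps:{params.device}" is effectively limited to 0. A small backend-agnostic cleanup sketch (free_memory is a hypothetical helper; assumes PyTorch >= 2.0):

import gc

import torch

def free_memory() -> None:
    # Drop Python references first, then release cached allocator memory.
    gc.collect()
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
    elif torch.cuda.is_available():
        torch.cuda.empty_cache()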
8 changes: 4 additions & 4 deletions service/main.py
@@ -38,12 +38,12 @@ def stream_chat_generate(model: PreTrainedModel, args: dict):
},
]
pipe.model.eval()
pipe.model.to("xpu")
model = ipex.optimize(pipe.model, dtype=torch.bfloat16)
pipe.model.to("mps")
#model = ipex.optimize(pipe.model, dtype=torch.bfloat16)
prompt = pipe.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True, return_tensors="pt"
)
encoding = pipe.tokenizer.encode_plus(prompt, return_tensors="pt").to("xpu")
encoding = pipe.tokenizer.encode_plus(prompt, return_tensors="pt").to("mps")
tensor: torch.Tensor = encoding.get("input_ids")
streamer = TextIteratorStreamer(
pipe.tokenizer,
@@ -60,7 +60,7 @@ def stream_chat_generate(model: PreTrainedModel, args: dict):
top_k=50,
top_p=0.95,
)
torch.xpu.synchronize()
torch.mps.synchronize()
Thread(target=stream_chat_generate, args=(pipe.model, generate_kwargs)).start()

for stream_output in streamer:
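
With ipex.optimize commented out, the pipeline runs in its default dtype on MPS. A hedged sketch of moving a Hugging Face pipeline to MPS in half precision (float16 tends to be better supported than bfloat16 on older macOS releases; the model id is illustrative):

import torch
from transformers import pipeline

pipe = pipeline("text-generation", model="gpt2")  # illustrative model id
pipe.model.eval()
pipe.model.to(device="mps", dtype=torch.float16)  # cast instead of ipex.optimize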
19 changes: 8 additions & 11 deletions service/paint_biz.py
@@ -30,10 +30,9 @@
import schedulers_util
from compel import Compel
from threading import Event
# from cuda_hijacks import ipex_hijacks

# ipex_hijacks()
print("workarounds applied")
# print("workarounds applied")


# region class define
@@ -242,7 +241,7 @@ def get_ext_pipe(params: TextImageParams, pipe_classes: List, init_class: any):
return _ext_model_pipe
del _ext_model_pipe
gc.collect()
torch.cuda.empty_cache()
torch.mps.empty_cache()

basic_model_pipe = get_basic_model(params.model_name)
_ext_model_pipe = init_class.from_pipe(basic_model_pipe)
@@ -484,7 +483,7 @@ def convet_compel_prompt(
# }
# )
else:
compel_proc = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)
compel_proc = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder, device="mps")
compel_prompt = convert_prompt_to_compel_format(prompt)
prompt_embeds = compel_proc(compel_prompt)
custom_inputs.update(
@@ -831,8 +830,6 @@ def generate(params: TextImageParams):

try:
stop_generate()
torch.cuda.set_device(params.device)
# service_config.device = f"cuda:{params.device}"
if _last_model_name != params.model_name:
# hange model dispose basic model
if _basic_model_pipe is not None:
@@ -857,7 +854,7 @@
text_to_image(params)
_last_mode = params.mode

torch.cuda.empty_cache()
torch.mps.empty_cache()
finally:
_generating = False

@@ -895,15 +892,15 @@ def dispose_basic_model():
_last_mode = None

gc.collect()
torch.cuda.empty_cache()
torch.mps.empty_cache()


def dispose_ext_model():
global _ext_model_pipe
del _ext_model_pipe
_ext_model_pipe = None
gc.collect()
torch.cuda.empty_cache()
torch.mps.empty_cache()


def dispose():
@@ -931,5 +928,5 @@ def assert_stop_generate():
raise StopGenerateException()


def clear_xpu_cache():
torch.xpu.empty_cache()
def clear_mps_cache():
torch.mps.empty_cache()
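
clear_mps_cache() mirrors the old XPU helper. When debugging memory pressure during image generation, the MPS allocator also exposes introspection calls (a sketch; log_mps_memory is a hypothetical helper, assumes PyTorch >= 2.0):

import torch

def log_mps_memory(tag: str) -> None:
    # Report how much memory the MPS allocator currently holds.
    allocated = torch.mps.current_allocated_memory()  # bytes owned by live tensors
    driver = torch.mps.driver_allocated_memory()  # bytes held by the Metal driver
    print(f"[{tag}] allocated={allocated / 1e6:.1f} MB, driver={driver / 1e6:.1f} MB")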
20 changes: 10 additions & 10 deletions service/rag.py
@@ -6,7 +6,7 @@
from typing import Any, List, Dict

# from sentence_transformers import SentenceTransformer
import intel_extension_for_pytorch as ipex # noqa: F401
# import intel_extension_for_pytorch as ipex # noqa: F401
import torch
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders.markdown import UnstructuredMarkdownLoader
@@ -53,14 +53,14 @@ def to(self, device: str):
self.model.to(device)

def embed_documents(self, texts: List[str]) -> List[List[float]]:
torch.xpu.synchronize()
torch.mps.synchronize()
t0 = time.time()
embeddings = [
self.model.encode(text, normalize_embeddings=True) for text in texts
]
# Convert embeddings from NumPy arrays to lists for serialization
embeddings_as_lists = [embedding.tolist() for embedding in embeddings]
torch.xpu.synchronize()
torch.mps.synchronize()
t1 = time.time()
print("-----------SentenceTransformer--embedding cost time(s): ", t1 - t0)
return embeddings_as_lists
@@ -231,10 +231,10 @@ def delete_index(self, md5: str):
def add_index_file(file: str):
global embedding_database
if re.search(".(txt|docx?|pptx?|md|pdf)$", file, re.IGNORECASE) is not None:
torch.xpu.synchronize()
torch.mps.synchronize()
start = time.time()
result = embedding_database.add_index_file(file)
torch.xpu.synchronize()
torch.mps.synchronize()
end = time.time()
print(f"add index file cost {end-start}s")
else:
@@ -249,12 +249,12 @@ def to(device: str):

def query(query: str):
global embedding_database
torch.xpu.synchronize()
torch.mps.synchronize()
start = time.time()
success, context, source_file = embedding_database.query_database(query)
end = time.time()
print(f'query by keyword "{query}" cost {end-start}s')
torch.xpu.synchronize()
torch.mps.synchronize()
return success, context, source_file


@@ -275,8 +275,8 @@ def get_index_list():

def init(repo_id: str, device: int):
global embedding_database, embedding_wrapper, Is_Inited
torch.xpu.set_device(device)
service_config.device = f"xpu:{device}"
# torch.mps.set_device(device)
service_config.device = f"mps:{device}"
embedding_wrapper = EmbeddingWrapper(repo_id)
embedding_database = EmbeddingDatabase(embedding_wrapper)
Is_Inited = True
@@ -293,4 +293,4 @@ def dispose():
embedding_database = None
Is_Inited = False
gc.collect()
torch.xpu.empty_cache()
torch.mps.empty_cache()
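
The timing blocks above bracket work with torch.mps.synchronize() so wall-clock measurements include queued GPU work. The same pattern as a reusable context manager (gpu_timer is a hypothetical helper, a sketch rather than part of this commit):

import time
from contextlib import contextmanager

import torch

@contextmanager
def gpu_timer(label: str):
    # Synchronize before and after so the interval covers queued MPS work.
    if torch.backends.mps.is_available():
        torch.mps.synchronize()
    start = time.time()
    try:
        yield
    finally:
        if torch.backends.mps.is_available():
            torch.mps.synchronize()
        print(f"{label} cost {time.time() - start:.3f}s")

Wrapping a call such as embedding_database.add_index_file(file) in gpu_timer("add index file") replaces the repeated synchronize/time pairs.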