Skip to content

Commit

Permalink
Merge pull request #69 from SkywardAI/fix/load_model
Browse files Browse the repository at this point in the history
fix download model issue
  • Loading branch information
Micost authored Apr 17, 2024
2 parents 37c1f01 + d8d2ba2 commit 1413acb
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 21 deletions.
4 changes: 0 additions & 4 deletions examples/download_examples.py

This file was deleted.

26 changes: 11 additions & 15 deletions src/kimchima/pkg/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,20 @@ def __init__(self):
@classmethod
def auto_model(cls, *args, **kwargs)-> AutoModel:
r"""
It is used to get the model from the Hugging Face Transformers AutoModel.
Args:
pretrained_model_name_or_path: pretrained model name or path
Here we will use AutoModel from Huggingface to load the model from local.
It supports a wider range of models beyond causal language models,
like BERT, RoBERTa, BART, T5 and more.
It returns the base model without a specific head, it does not directly
perform tasks like text generation or translation.
"""
pretrained_model_name_or_path=kwargs.pop("pretrained_model_name_or_path", None)
if pretrained_model_name_or_path is None:
raise ValueError("pretrained_model_name_or_path cannot be None")

quantization_config=kwargs.pop("quantization_config", None)
model = AutoModel.from_pretrained(
pretrained_model_name_or_path,
quantization_config,
**kwargs
)
logger.debug(f"Loaded model: {pretrained_model_name_or_path}")
Expand All @@ -58,21 +58,17 @@ def auto_model(cls, *args, **kwargs)-> AutoModel:
@classmethod
def auto_model_for_causal_lm(cls, *args, **kwargs)-> AutoModelForCausalLM:
r"""
It is used to get the model from the Hugging Face Transformers AutoModelForCausalLM.
Args:
pretrained_model_name_or_path: pretrained model name or path
Here we will use AutoModelForCausalLM to load the model from local,
like GPT-2, XLNet, etc.
It returns a model with a language modeling head, which can be used to generate text,
translate text, write content, answer questions in an informative way.
"""
pretrained_model_name_or_path=kwargs.pop("pretrained_model_name_or_path", None)
if pretrained_model_name_or_path is None:
raise ValueError("pretrained_model_name_or_path cannot be None")

quantization_config=kwargs.pop("quantization_config", None)
model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path,
quantization_config=quantization_config,
device_map='auto',
pretrained_model_name_or_path,
**kwargs
)
logger.debug(f"Loaded model: {pretrained_model_name_or_path}")
Expand Down
94 changes: 94 additions & 0 deletions src/kimchima/tests/test_downloadr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# coding=utf-8
# Copyright [2024] [SkywardAI]
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from kimchima.utils import Downloader
from kimchima.pipelines import PipelinesFactory
from kimchima.pkg import ModelFactory
from kimchima.pkg import TokenizerFactory


class TestDownloader(unittest.TestCase):
    """Download models with ``Downloader`` and reload them from the saved folder."""

    def setUp(self):
        # Model identifiers and the local folders they are saved into.
        self.model_name = "gpt2"
        self.folder_name = "gpt2"
        self.model_name_auto = "sentence-transformers/all-MiniLM-L6-v2"
        self.folder_name_auto = "encoder"

    @unittest.skip("skip test_model_downloader")
    def test_model_downloader(self):
        """Save a model via ``model_downloader`` and rebuild a pipeline from the folder."""
        Downloader.model_downloader(
            model_name=self.model_name,
            folder_name=self.folder_name,
        )

        # Reload from the folder that was just written.
        loaded_pipe = PipelinesFactory.customized_pipe(
            model=self.folder_name,
            device_map='auto',
        )

        self.assertIsNotNone(loaded_pipe)
        self.assertEqual(loaded_pipe.model.name_or_path, self.folder_name)

    # @unittest.skip("skip test_auto_downloader")
    def test_auto_downloader(self):
        """Save model and tokenizer via the auto downloaders, reload both, run one forward pass."""
        Downloader.auto_downloader(
            model_name=self.model_name_auto,
            folder_name=self.folder_name_auto,
        )
        Downloader.auto_token_downloader(
            model_name=self.model_name_auto,
            folder_name=self.folder_name_auto,
        )

        # Both artifacts are loaded back from the local folder.
        local_model = ModelFactory.auto_model(
            pretrained_model_name_or_path=self.folder_name_auto,
        )
        local_tokenizer = TokenizerFactory.auto_tokenizer(
            pretrained_model_name_or_path=self.folder_name_auto,
        )

        self.assertIsNotNone(local_model)
        self.assertEqual(local_model.name_or_path, self.folder_name_auto)

        self.assertIsNotNone(local_tokenizer)

        prompt = "test"
        encoded = local_tokenizer(prompt, return_tensors="pt")
        result = local_model(**encoded)

        self.assertIsNotNone(result[0])

    @unittest.skip("skip test_casual_downloader")
    def test_casual_downloader(self):
        """Save model and tokenizer via ``casual_downloader``, reload both, run one forward pass."""
        Downloader.casual_downloader(
            model_name=self.model_name_auto,
            folder_name=self.folder_name_auto,
        )
        Downloader.auto_token_downloader(
            model_name=self.model_name_auto,
            folder_name=self.folder_name_auto,
        )

        local_model = ModelFactory.auto_model_for_causal_lm(
            pretrained_model_name_or_path=self.folder_name_auto,
        )
        local_tokenizer = TokenizerFactory.auto_tokenizer(
            pretrained_model_name_or_path=self.folder_name_auto,
        )

        self.assertIsNotNone(local_model)
        self.assertEqual(local_model.name_or_path, self.folder_name_auto)

        self.assertIsNotNone(local_tokenizer)

        prompt = "test"
        encoded = local_tokenizer(prompt, return_tensors="pt")
        result = local_model(**encoded)
        self.assertIsNotNone(result[0])


68 changes: 66 additions & 2 deletions src/kimchima/utils/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
from __future__ import annotations

from kimchima.pkg import logging
from transformers import pipeline
from transformers import (
pipeline,
AutoModel,
AutoTokenizer,
AutoModelForCausalLM,
)

logger=logging.get_logger(__name__)

Expand All @@ -31,8 +36,10 @@ def __init__(self):
)

@classmethod
def model_downloader(cls, *args, **kwargs)->str:
def model_downloader(cls, *args, **kwargs):
r"""
Here we will use pipeline from Huggingface to download the model.
And save the model to the specified folder.
"""
model_name=kwargs.pop("model_name", None)
if model_name is None:
Expand All @@ -41,3 +48,60 @@ def model_downloader(cls, *args, **kwargs)->str:
folder_name=kwargs.pop("folder_name", None)
pipe=pipeline(model=model_name)
pipe.save_pretrained(folder_name if folder_name is not None else model_name)
logger.info(f"Model {model_name} has been downloaded successfully")


@classmethod
def auto_downloader(cls, *args, **kwargs):
    r"""
    Download a model with Hugging Face ``AutoModel`` and save it locally.

    ``AutoModel`` loads the base model without a task-specific head, so it
    covers more than causal language models (e.g. BERT, RoBERTa, BART, T5).

    Keyword Args:
        model_name: identifier passed to ``AutoModel.from_pretrained``; required.
        folder_name: target passed to ``save_pretrained``. When missing or
            ``None``, ``model_name`` is used as the target instead.

    Raises:
        ValueError: if ``model_name`` is missing or ``None``.
    """
    model_name = kwargs.pop("model_name", None)
    if model_name is None:
        raise ValueError("model_name is required")

    # Fall back to the model name when no explicit folder was given.
    save_dir = kwargs.pop("folder_name", None)
    if save_dir is None:
        save_dir = model_name

    AutoModel.from_pretrained(model_name).save_pretrained(save_dir)
    logger.info(f"Model {model_name} has been downloaded successfully")


@classmethod
def casual_downloader(cls, *args, **kwargs):
    r"""
    Download a model with Hugging Face ``AutoModelForCausalLM`` and save it locally.

    This loads the model together with a language modeling head, the variant
    used for text generation (e.g. GPT-2).

    Keyword Args:
        model_name: identifier passed to ``AutoModelForCausalLM.from_pretrained``;
            required.
        folder_name: target passed to ``save_pretrained``. When missing or
            ``None``, ``model_name`` is used as the target instead.

    Raises:
        ValueError: if ``model_name`` is missing or ``None``.
    """
    model_name = kwargs.pop("model_name", None)
    if model_name is None:
        raise ValueError("model_name is required")

    # Fall back to the model name when no explicit folder was given.
    save_dir = kwargs.pop("folder_name", None)
    if save_dir is None:
        save_dir = model_name

    # Related upstream discussions:
    # https://github.com/huggingface/transformers/issues/25296
    # https://github.com/huggingface/accelerate/issues/661
    AutoModelForCausalLM.from_pretrained(model_name).save_pretrained(save_dir)
    logger.info(f"Model {model_name} has been downloaded successfully")

@classmethod
def auto_token_downloader(cls, *args, **kwargs):
    r"""
    Download a tokenizer with Hugging Face ``AutoTokenizer`` and save it locally.

    Keyword Args:
        model_name: identifier passed to ``AutoTokenizer.from_pretrained``; required.
        folder_name: target passed to ``save_pretrained``. When missing or
            ``None``, ``model_name`` is used as the target instead.

    Raises:
        ValueError: if ``model_name`` is missing or ``None``.
    """
    model_name = kwargs.pop("model_name", None)
    if model_name is None:
        raise ValueError("model_name is required")

    # Fall back to the model name when no explicit folder was given.
    save_dir = kwargs.pop("folder_name", None)
    if save_dir is None:
        save_dir = model_name

    AutoTokenizer.from_pretrained(model_name).save_pretrained(save_dir)
    logger.info(f"Tokenizer {model_name} has been downloaded successfully")

0 comments on commit 1413acb

Please sign in to comment.