Skip to content

Commit

Permalink
[Model] Add emu3-chat and emu3-gen (#738)
Browse files Browse the repository at this point in the history
* add emu3-chat and emu3-gen

* update emu

---------

Co-authored-by: kennymckormick <[email protected]>
  • Loading branch information
OliverLeeXZ and kennymckormick authored Jan 23, 2025
1 parent ddd3252 commit 982c34b
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 3 deletions.
8 changes: 6 additions & 2 deletions vlmeval/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
'VisualGLM_6b': partial(VisualGLM, model_path='THUDM/visualglm-6b'),
'mPLUG-Owl2': partial(mPLUG_Owl2, model_path='MAGAer13/mplug-owl2-llama2-7b'),
'mPLUG-Owl3': partial(mPLUG_Owl3, model_path='mPLUG/mPLUG-Owl3-7B-240728'),
'emu2_chat': partial(Emu, model_path='BAAI/Emu2-Chat'),
'OmniLMM_12B': partial(OmniLMM12B, model_path='openbmb/OmniLMM-12B', root=OmniLMM_ROOT),
'MGM_7B': partial(Mini_Gemini, model_path='YanweiLi/MGM-7B-HD', root=Mini_Gemini_ROOT),
'Bunny-llama3-8B': partial(BunnyLLama3, model_path='BAAI/Bunny-v1_1-Llama-3-8B-V'),
Expand Down Expand Up @@ -114,6 +113,11 @@
'Taichu-VL-2B': partial(TaichuVLAPI, model='Taichu-VL-2B', url='https://platform.wair.ac.cn/api/v1/infer/10381/v1/chat/completions'),
}

emu_series = {
'emu2_chat': partial(Emu, model_path='BAAI/Emu2-Chat'),
'Emu3_chat': partial(Emu3_chat, model_path='BAAI/Emu3-Chat'),
'Emu3_gen': partial(Emu3_gen, model_path='BAAI/Emu3-Gen')
}
mmalaya_series = {
'MMAlaya': partial(MMAlaya, model_path='DataCanvas/MMAlaya'),
'MMAlaya2': partial(MMAlaya2, model_path='DataCanvas/MMAlaya2'),
Expand Down Expand Up @@ -462,7 +466,7 @@
mantis_series, mmalaya_series, phi3_series, xgen_mm_series, qwen2vl_series,
slime_series, eagle_series, moondream_series, llama_series, molmo_series,
kosmos_series, points_series, nvlm_series, vintern_series, h2ovl_series, aria_series,
smolvlm_series, sail_series, valley_series, vita_series, ross_series
smolvlm_series, sail_series, valley_series, vita_series, ross_series, emu_series
]

for grp in model_groups:
Expand Down
2 changes: 1 addition & 1 deletion vlmeval/vlm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .aria import Aria
from .base import BaseModel
from .cogvlm import CogVlm, GLM4v
from .emu import Emu
from .emu import Emu, Emu3_chat, Emu3_gen
from .eagle_x import Eagle
from .idefics import IDEFICS, IDEFICS2
from .instructblip import InstructBLIP
Expand Down
185 changes: 185 additions & 0 deletions vlmeval/vlm/emu.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
import os
import sys
import torch
from PIL import Image
import os.path as osp
from .base import BaseModel
from ..smp import *
from huggingface_hub import snapshot_download


def get_local_root(repo_id):
if osp.exists(repo_id) and osp.isdir(repo_id):
return repo_id

cache_path = get_cache_path(repo_id, repo_type='models')
if cache_path is None:
cache_path = snapshot_download(repo_id=repo_id)
assert osp.exists(cache_path) and osp.isdir(cache_path)
return cache_path


class Emu(BaseModel):
Expand Down Expand Up @@ -87,3 +100,175 @@ def generate_inner(self, message, dataset=None):

output_text = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
return output_text[0]


class Emu3_chat(BaseModel):
INSTALL_REQ = True
INTERLEAVE = False

def __init__(self, model_path='BAAI/Emu3-Chat', tokenizer_path='BAAI/Emu3-VisionTokenizer', **kwargs):
assert model_path is not None
assert tokenizer_path is not None
try:
from transformers import AutoTokenizer, AutoModel, AutoImageProcessor, AutoModelForCausalLM
local_root = get_local_root(model_path)
sys.path.append(local_root)
from processing_emu3 import Emu3Processor
except Exception as err:
raise err

# load model wights
model = AutoModelForCausalLM.from_pretrained(
model_path,
device_map='cuda',
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
trust_remote_code=True)
model.eval()
self.model = model
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side="left")
self.image_processor = AutoImageProcessor.from_pretrained(tokenizer_path, trust_remote_code=True)
self.image_tokenizer = AutoModel.from_pretrained(
tokenizer_path, device_map='cuda', trust_remote_code=True).eval()
self.processor = Emu3Processor(self.image_processor, self.image_tokenizer, self.tokenizer)
self.kwargs = kwargs
self.cuda = cuda

def generate_inner(self, message, dataset=None):
query, images = '', []
for item in message:
if item['type'] == 'image':
images.append(Image.open(item['value']).convert('RGB'))
elif item['type'] == 'text':
query += item['value']

inputs = self.processor(
text=[query],
image=images,
mode='U',
return_tensors="pt",
padding="longest",
)
from transformers.generation.configuration_utils import GenerationConfig
# prepare hyper parameters
GENERATION_CONFIG = GenerationConfig(
pad_token_id=self.tokenizer.pad_token_id,
bos_token_id=self.tokenizer.bos_token_id,
eos_token_id=self.tokenizer.eos_token_id,
max_new_tokens=1024,
)
# generate
outputs = self.model.generate(
inputs.input_ids.to(self.cuda),
GENERATION_CONFIG,
attention_mask=inputs.attention_mask.to(self.cuda),
)

outputs = outputs[:, inputs.input_ids.shape[-1]:]
response = self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
return response


class Emu3_gen(BaseModel):
INSTALL_REQ = True
INTERLEAVE = False

def __init__(self,
model_path='BAAI/Emu3-Gen',
tokenizer_path='BAAI/Emu3-VisionTokenizer',
output_path='',
**kwargs):

assert model_path is not None
assert tokenizer_path is not None
try:
from transformers import AutoTokenizer, AutoModel, AutoImageProcessor, AutoModelForCausalLM
local_root = get_local_root(model_path)
sys.path.append(local_root)
from processing_emu3 import Emu3Processor
except Exception as err:
raise err

# load model wights
model = AutoModelForCausalLM.from_pretrained(
model_path,
device_map='cuda',
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
trust_remote_code=True)
model.eval()
self.model = model

self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side="left")
self.image_processor = AutoImageProcessor.from_pretrained(tokenizer_path, trust_remote_code=True)
self.image_tokenizer = AutoModel.from_pretrained(
tokenizer_path,
device_map='cuda',
trust_remote_code=True).eval()
self.processor = Emu3Processor(self.image_processor, self.image_tokenizer, self.tokenizer)
self.kwargs = kwargs
self.cuda = cuda
self.output_path = output_path

def generate_inner(self, message, dataset=None):
query = ''
for item in message:
if item['type'] == 'text':
query += item['value']
else:
raise ValueError('Please input the text in generation stage.')

# prepare input
POSITIVE_PROMPT = " masterpiece, film grained, best quality."
NEGATIVE_PROMPT = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry." # noqa: E501

classifier_free_guidance = 3.0
prompt = "a portrait of young girl."
prompt += POSITIVE_PROMPT

kwargs = dict(
mode='G',
ratio="1:1",
image_area=self.model.config.image_area,
return_tensors="pt",
padding="longest") # noqa: E501

pos_inputs = self.processor(text=prompt, **kwargs)
neg_inputs = self.processor(text=NEGATIVE_PROMPT, **kwargs)
from transformers.generation.configuration_utils import GenerationConfig
# prepare hyper parameters
GENERATION_CONFIG = GenerationConfig(
use_cache=True,
eos_token_id=self.model.config.eos_token_id,
pad_token_id=self.model.config.pad_token_id,
max_new_tokens=40960,
do_sample=True,
top_k=2048,
)

h = pos_inputs.image_size[:, 0]
w = pos_inputs.image_size[:, 1]
constrained_fn = self.processor.build_prefix_constrained_fn(h, w)
from transformers.generation import LogitsProcessorList, PrefixConstrainedLogitsProcessor, UnbatchedClassifierFreeGuidanceLogitsProcessor # noqa: E501
logits_processor = LogitsProcessorList([
UnbatchedClassifierFreeGuidanceLogitsProcessor(
classifier_free_guidance,
self.model,
unconditional_ids=neg_inputs.input_ids.to("cuda:0"),
),
PrefixConstrainedLogitsProcessor(constrained_fn, num_beams=1),
])

# generate
outputs = self.model.generate(
pos_inputs.input_ids.to("cuda:0"),
GENERATION_CONFIG,
logits_processor=logits_processor,
attention_mask=pos_inputs.attention_mask.to("cuda:0"),
)

mm_list = self.processor.decode(outputs[0])
for idx, im in enumerate(mm_list):
if not isinstance(im, Image.Image):
continue
im.save(os.path.join(self.output_path, f"result_{idx}.png"))

0 comments on commit 982c34b

Please sign in to comment.