add emu3-chat and emu3-gen #738

Merged · 2 commits · Jan 23, 2025
8 changes: 6 additions & 2 deletions vlmeval/config.py
@@ -36,7 +36,6 @@
    'VisualGLM_6b': partial(VisualGLM, model_path='THUDM/visualglm-6b'),
    'mPLUG-Owl2': partial(mPLUG_Owl2, model_path='MAGAer13/mplug-owl2-llama2-7b'),
    'mPLUG-Owl3': partial(mPLUG_Owl3, model_path='mPLUG/mPLUG-Owl3-7B-240728'),
    'emu2_chat': partial(Emu, model_path='BAAI/Emu2-Chat'),
    'OmniLMM_12B': partial(OmniLMM12B, model_path='openbmb/OmniLMM-12B', root=OmniLMM_ROOT),
    'MGM_7B': partial(Mini_Gemini, model_path='YanweiLi/MGM-7B-HD', root=Mini_Gemini_ROOT),
    'Bunny-llama3-8B': partial(BunnyLLama3, model_path='BAAI/Bunny-v1_1-Llama-3-8B-V'),
@@ -113,6 +112,11 @@
    'Taichu-VL-2B': partial(TaichuVLAPI, model='Taichu-VL-2B', url='https://platform.wair.ac.cn/api/v1/infer/10381/v1/chat/completions'),
}

emu_series = {
    'emu2_chat': partial(Emu, model_path='BAAI/Emu2-Chat'),
    'Emu3_chat': partial(Emu3_chat, model_path='BAAI/Emu3-Chat'),
    'Emu3_gen': partial(Emu3_gen, model_path='BAAI/Emu3-Gen')
}
mmalaya_series = {
    'MMAlaya': partial(MMAlaya, model_path='DataCanvas/MMAlaya'),
    'MMAlaya2': partial(MMAlaya2, model_path='DataCanvas/MMAlaya2'),
@@ -428,7 +432,7 @@
    mantis_series, mmalaya_series, phi3_series, xgen_mm_series, qwen2vl_series,
    slime_series, eagle_series, moondream_series, llama_series, molmo_series,
    kosmos_series, points_series, nvlm_series, vintern_series, h2ovl_series, aria_series,
    smolvlm_series, sail_series, valley_series, vita_series, ross_series
    smolvlm_series, sail_series, valley_series, vita_series, ross_series, emu_series
]

for grp in model_groups:
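With this registration in place, the Emu entries become part of the model registry that VLMEvalKit builds by merging the groups in model_groups (the supported_VLM mapping). A minimal lookup sketch, assuming a GPU machine with the Emu3 dependencies installed; the lookup itself is illustrative, not part of this diff:

from vlmeval.config import supported_VLM

# each registry value is a functools.partial; calling it loads the checkpoint
assert 'Emu3_chat' in supported_VLM and 'Emu3_gen' in supported_VLM
emu3_chat = supported_VLM['Emu3_chat']()  # -> Emu3_chat(model_path='BAAI/Emu3-Chat')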
2 changes: 1 addition & 1 deletion vlmeval/vlm/__init__.py
@@ -5,7 +5,7 @@
from .aria import Aria
from .base import BaseModel
from .cogvlm import CogVlm, GLM4v
from .emu import Emu
from .emu import Emu, Emu3_chat, Emu3_gen
from .eagle_x import Eagle
from .idefics import IDEFICS, IDEFICS2
from .instructblip import InstructBLIP
185 changes: 185 additions & 0 deletions vlmeval/vlm/emu.py
@@ -1,9 +1,22 @@
import os
import sys
import torch
from PIL import Image
import os.path as osp
from .base import BaseModel
from ..smp import *
from huggingface_hub import snapshot_download


def get_local_root(repo_id):
    if osp.exists(repo_id) and osp.isdir(repo_id):
        return repo_id

    cache_path = get_cache_path(repo_id, repo_type='models')
    if cache_path is None:
        cache_path = snapshot_download(repo_id=repo_id)
    assert osp.exists(cache_path) and osp.isdir(cache_path)
    return cache_path
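
# Note: the BAAI/Emu3-* repos ship their custom Emu3Processor in a
# processing_emu3.py file inside the model repo, so the wrappers below resolve
# the repo to a local directory and put it on sys.path before importing it.
# Illustrative use of the helper (not part of this diff):
#
#     local_root = get_local_root('BAAI/Emu3-Chat')  # may call snapshot_download
#     sys.path.append(local_root)
#     from processing_emu3 import Emu3Processor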


class Emu(BaseModel):
@@ -87,3 +100,175 @@ def generate_inner(self, message, dataset=None):

        output_text = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return output_text[0]


class Emu3_chat(BaseModel):
    INSTALL_REQ = True
    INTERLEAVE = False

    def __init__(self, model_path='BAAI/Emu3-Chat', tokenizer_path='BAAI/Emu3-VisionTokenizer', **kwargs):
        assert model_path is not None
        assert tokenizer_path is not None
        try:
            from transformers import AutoTokenizer, AutoModel, AutoImageProcessor, AutoModelForCausalLM
            local_root = get_local_root(model_path)
            sys.path.append(local_root)
            from processing_emu3 import Emu3Processor
        except Exception as err:
            raise err

        # load model weights
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map='cuda',
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            trust_remote_code=True)
        model.eval()
        self.model = model
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side="left")
        self.image_processor = AutoImageProcessor.from_pretrained(tokenizer_path, trust_remote_code=True)
        self.image_tokenizer = AutoModel.from_pretrained(
            tokenizer_path, device_map='cuda', trust_remote_code=True).eval()
        self.processor = Emu3Processor(self.image_processor, self.image_tokenizer, self.tokenizer)
        self.kwargs = kwargs
        self.cuda = 'cuda'

    def generate_inner(self, message, dataset=None):
        query, images = '', []
        for item in message:
            if item['type'] == 'image':
                images.append(Image.open(item['value']).convert('RGB'))
            elif item['type'] == 'text':
                query += item['value']

        inputs = self.processor(
            text=[query],
            image=images,
            mode='U',
            return_tensors="pt",
            padding="longest",
        )
        from transformers.generation.configuration_utils import GenerationConfig
        # prepare hyper parameters
        GENERATION_CONFIG = GenerationConfig(
            pad_token_id=self.tokenizer.pad_token_id,
            bos_token_id=self.tokenizer.bos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            max_new_tokens=1024,
        )
        # generate
        outputs = self.model.generate(
            inputs.input_ids.to(self.cuda),
            GENERATION_CONFIG,
            attention_mask=inputs.attention_mask.to(self.cuda),
        )

        outputs = outputs[:, inputs.input_ids.shape[-1]:]
        response = self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
        return response
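
# Usage sketch (illustrative, not part of this diff): Emu3_chat consumes the
# standard VLMEvalKit message format; text items are concatenated into the
# query and image items are opened as PIL images, then passed to the Emu3
# processor in understanding mode ('U'):
#
#     model = Emu3_chat()  # defaults: BAAI/Emu3-Chat + BAAI/Emu3-VisionTokenizer
#     msg = [dict(type='image', value='demo.jpg'),  # hypothetical local image
#            dict(type='text', value='Describe the image.')]
#     print(model.generate_inner(msg))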


class Emu3_gen(BaseModel):
    INSTALL_REQ = True
    INTERLEAVE = False

    def __init__(self,
                 model_path='BAAI/Emu3-Gen',
                 tokenizer_path='BAAI/Emu3-VisionTokenizer',
                 output_path='',
                 **kwargs):

        assert model_path is not None
        assert tokenizer_path is not None
        try:
            from transformers import AutoTokenizer, AutoModel, AutoImageProcessor, AutoModelForCausalLM
            local_root = get_local_root(model_path)
            sys.path.append(local_root)
            from processing_emu3 import Emu3Processor
        except Exception as err:
            raise err

        # load model weights
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map='cuda',
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            trust_remote_code=True)
        model.eval()
        self.model = model

        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side="left")
        self.image_processor = AutoImageProcessor.from_pretrained(tokenizer_path, trust_remote_code=True)
        self.image_tokenizer = AutoModel.from_pretrained(
            tokenizer_path,
            device_map='cuda',
            trust_remote_code=True).eval()
        self.processor = Emu3Processor(self.image_processor, self.image_tokenizer, self.tokenizer)
        self.kwargs = kwargs
        self.cuda = 'cuda'
        self.output_path = output_path

    def generate_inner(self, message, dataset=None):
        query = ''
        for item in message:
            if item['type'] == 'text':
                query += item['value']
            else:
                raise ValueError('Emu3-Gen only accepts text input at the generation stage.')

        # prepare input
        POSITIVE_PROMPT = " masterpiece, film grained, best quality."
        NEGATIVE_PROMPT = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry."  # noqa: E501

        classifier_free_guidance = 3.0
        # use the text collected from the message as the generation prompt
        prompt = query + POSITIVE_PROMPT

        kwargs = dict(
            mode='G',
            ratio="1:1",
            image_area=self.model.config.image_area,
            return_tensors="pt",
            padding="longest")

        pos_inputs = self.processor(text=prompt, **kwargs)
        neg_inputs = self.processor(text=NEGATIVE_PROMPT, **kwargs)
        from transformers.generation.configuration_utils import GenerationConfig
        # prepare hyper parameters
        GENERATION_CONFIG = GenerationConfig(
            use_cache=True,
            eos_token_id=self.model.config.eos_token_id,
            pad_token_id=self.model.config.pad_token_id,
            max_new_tokens=40960,
            do_sample=True,
            top_k=2048,
        )

        h = pos_inputs.image_size[:, 0]
        w = pos_inputs.image_size[:, 1]
        constrained_fn = self.processor.build_prefix_constrained_fn(h, w)
        from transformers.generation import LogitsProcessorList, PrefixConstrainedLogitsProcessor, UnbatchedClassifierFreeGuidanceLogitsProcessor  # noqa: E501
        logits_processor = LogitsProcessorList([
            UnbatchedClassifierFreeGuidanceLogitsProcessor(
                classifier_free_guidance,
                self.model,
                unconditional_ids=neg_inputs.input_ids.to("cuda:0"),
            ),
            PrefixConstrainedLogitsProcessor(constrained_fn, num_beams=1),
        ])

        # generate
        outputs = self.model.generate(
            pos_inputs.input_ids.to("cuda:0"),
            GENERATION_CONFIG,
            logits_processor=logits_processor,
            attention_mask=pos_inputs.attention_mask.to("cuda:0"),
        )

        mm_list = self.processor.decode(outputs[0])
        for idx, im in enumerate(mm_list):
            if not isinstance(im, Image.Image):
                continue
            im.save(os.path.join(self.output_path, f"result_{idx}.png"))
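
A hedged usage sketch for the generation wrapper (the output directory and prompt below are assumptions, not part of this PR). Emu3_gen accepts text-only messages, appends the positive prompt, applies classifier-free guidance against the negative prompt, and writes each decoded image to output_path as result_<idx>.png:

import os
from vlmeval.vlm import Emu3_gen  # exported by vlmeval/vlm/__init__.py above

out_dir = './emu3_gen_outputs'  # assumed output directory
os.makedirs(out_dir, exist_ok=True)

model = Emu3_gen(output_path=out_dir)  # defaults: BAAI/Emu3-Gen + BAAI/Emu3-VisionTokenizer
message = [dict(type='text', value='a red vintage car parked on a coastal road')]
model.generate_inner(message)  # saves result_<idx>.png files into out_dir
print(os.listdir(out_dir))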