Add Ola Model #752

Open · wants to merge 2 commits into base: main
6 changes: 5 additions & 1 deletion vlmeval/config.py
@@ -458,6 +458,10 @@
'valley_eagle': partial(ValleyEagleChat, model_path='bytedance-research/Valley-Eagle-7B'),
}

ola_series = {
'ola': partial(Ola, model_path='THUdyh/Ola-7b'),
}

ross_series = {
'ross-qwen2-7b': partial(Ross, model_path='HaochenWang/ross-qwen2-7b'),
}
@@ -477,7 +481,7 @@
mantis_series, mmalaya_series, phi3_series, xgen_mm_series, qwen2vl_series,
slime_series, eagle_series, moondream_series, llama_series, molmo_series,
kosmos_series, points_series, nvlm_series, vintern_series, h2ovl_series, aria_series,
-    smolvlm_series, sail_series, valley_series, vita_series, ross_series, emu_series, ursa_series
+    smolvlm_series, sail_series, valley_series, vita_series, ross_series, emu_series, ola_series, ursa_series
]

for grp in model_groups:
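With the registry entry above, the model becomes constructible by name. A minimal smoke-test sketch for reviewers, assuming the aggregated `supported_VLM` dict that this config module populates from `model_groups`; the image path and prompt are illustrative:

```python
# Hedged sketch; assumes the `supported_VLM` registry assembled in config.py.
# Inputs below are illustrative, not part of this PR.
from vlmeval.config import supported_VLM

model = supported_VLM['ola']()  # invokes partial(Ola, model_path='THUdyh/Ola-7b')
response = model.generate(message=[
    dict(type='image', value='demo.jpg'),
    dict(type='text', value='Describe this image.'),
])
print(response)
```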
3 changes: 2 additions & 1 deletion vlmeval/vlm/__init__.py
@@ -64,4 +64,5 @@
from .sail_vl import SailVL
from .valley import ValleyEagleChat
from .ross import Ross
-from .ursa import UrsaChat
+from .ola import Ola
+from .ursa import UrsaChat
1 change: 1 addition & 0 deletions vlmeval/vlm/ola/__init__.py
@@ -0,0 +1 @@
from .ola_model import Ola
65 changes: 65 additions & 0 deletions vlmeval/vlm/ola/ola/arguments.py
@@ -0,0 +1,65 @@
import transformers

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
version: Optional[str] = field(default="v0")
freeze_backbone: bool = field(default=False)
tune_speech_projector: bool = field(default=False)
tune_speech_encoder: bool = field(default=False)
tune_speech_generator_only: bool = field(default=False)
speech_encoder_type: Optional[str] = field(default=None)
speech_encoder: Optional[str] = field(default=None)
pretrain_speech_projector: Optional[str] = field(default=None)
speech_projector_type: Optional[str] = field(default='linear')
speech_encoder_ds_rate: int = 5
speech_encoder_hidden_size: int = 1280


@dataclass
class DataArguments:
    data_path: Optional[str] = field(default=None,
                                     metadata={"help": "Path to the training data."})
is_multimodal: bool = False
input_type: str = field(default="mel")
speech_normalize: bool = False
mel_size: int = 128
has_tgt_units: bool = False


@dataclass
class TrainingArguments(transformers.TrainingArguments):
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw_torch")
freeze_speech_projector: bool = field(default=False)
model_max_length: int = field(
default=512,
metadata={
"help":
"Maximum sequence length. Sequences will be right padded (and possibly truncated)."
},
)
double_quant: bool = field(
default=True,
metadata={"help": "Compress the quantization statistics through double quantization."}
)
quant_type: str = field(
default="nf4",
metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
)
bits: int = field(
default=16,
metadata={"help": "How many bits to use."}
)
lora_enable: bool = False
lora_r: int = 64
lora_alpha: int = 16
lora_dropout: float = 0.05
lora_weight_path: str = ""
lora_bias: str = "none"
speech_projector_lr: Optional[float] = None
group_by_modality_length: bool = field(default=False)
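A short sketch of how these dataclasses are typically consumed, using the standard `transformers.HfArgumentParser` pattern; the flag values shown are illustrative, not taken from this PR:

```python
# Sketch only: the usual HfArgumentParser flow for these three dataclasses.
import transformers

parser = transformers.HfArgumentParser(
    (ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses(
    args=['--model_name_or_path', 'THUdyh/Ola-7b', '--output_dir', '/tmp/ola'])
print(model_args.speech_projector_type)  # 'linear'
```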
14 changes: 14 additions & 0 deletions vlmeval/vlm/ola/ola/constants.py
@@ -0,0 +1,14 @@
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

LOGDIR = "."

# Model Constants
IGNORE_INDEX = -100
SPEECH_TOKEN_INDEX = -200
DEFAULT_SPEECH_TOKEN = "<speech>"
IMAGE_TOKEN_INDEX = -300
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
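The negative indices never appear in a real vocabulary: in LLaVA-style pipelines they are placeholders spliced into `input_ids` and later replaced by projected multimodal features. A hedged sketch of that pattern; the helper is hypothetical and the tokenizer is any HF-style tokenizer, neither is part of this PR:

```python
# Illustrative LLaVA-style placeholder splicing; hypothetical helper.
DEFAULT_IMAGE_TOKEN = "<image>"
IMAGE_TOKEN_INDEX = -300  # out-of-vocabulary placeholder, as defined above

def tokenize_with_image_placeholder(prompt: str, tokenizer) -> list:
    # Tokenize the text on either side of each "<image>" marker, then join
    # the chunks with the placeholder index; the model later swaps these
    # positions for vision features. (Real code also deduplicates BOS tokens.)
    chunks = [tokenizer(chunk).input_ids
              for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)]
    input_ids = chunks[0]
    for chunk in chunks[1:]:
        input_ids += [IMAGE_TOKEN_INDEX] + chunk
    return input_ids
```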
137 changes: 137 additions & 0 deletions vlmeval/vlm/ola/ola/conversation.py
@@ -0,0 +1,137 @@
import dataclasses
from enum import auto, Enum
from typing import Any, List, Optional, Union


class SeparatorStyle(Enum):
"""Different separator style."""
TWO = auto()
PLAIN = auto()
CHATML = auto()
LLAMA_2 = auto()
LLAMA_3 = auto()
QWEN2 = auto()


@dataclasses.dataclass
class Conversation:
"""A class that keeps all conversation history."""
system: str
roles: List[str]
messages: List[List[str]]
offset: int
sep_style: SeparatorStyle = SeparatorStyle.PLAIN
sep: str = "###"
    sep2: Optional[str] = None
version: str = "Unknown"

tokenizer_id: str = ""
tokenizer: Any = None
    # Stop criteria (the default one is EOS token)
    stop_str: Union[str, List[str], None] = None
    # Stops generation if meeting any token in this list
    stop_token_ids: Optional[List[int]] = None

skip_next: bool = False

def get_prompt(self):
messages = self.messages

if self.sep_style == SeparatorStyle.TWO:
seps = [self.sep, self.sep2]
ret = self.system + seps[0]
for i, (role, message) in enumerate(messages):
if message:
if type(message) is tuple:
message = message[0]
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
elif self.sep_style == SeparatorStyle.QWEN2:
start = '<|im_start|>'
end = '<|im_end|>\n'
ret = start + 'system\n' + self.system + end
for i, (role, message) in enumerate(messages):
if message:
if type(message) is tuple:
message, _, _ = message

if message.endswith('<|endoftext|>'):
message = message.replace('<|endoftext|>', '')
ret += start + role + "\n" + message + end + '<|endoftext|>'
else:
                        assert '<|endoftext|>' not in message, f"Invalid message: {message}"
ret += start + role + "\n" + message + end
else:
ret += start + role + "\n"
else:
raise ValueError(f"Invalid style: {self.sep_style}")

return ret

def append_message(self, role, message):
self.messages.append([role, message])

def to_gradio_chatbot(self):
ret = []
for i, (role, msg) in enumerate(self.messages[self.offset:]):
if i % 2 == 0:
                if type(msg) is tuple:
                    msg, _ = msg  # drop the attached speech payload
                ret.append([msg, None])
else:
ret[-1][-1] = msg
return ret

def copy(self):
return Conversation(
system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
sep2=self.sep2,
version=self.version)

    def get_images(self):
        # Minimal helper assumed by dict() below: collect the image payloads
        # attached to tuple-form messages (the file otherwise omits it).
        return [msg[1] for _, msg in self.messages
                if type(msg) is tuple and len(msg) > 1]

    def dict(self):
        if len(self.get_images()) > 0:
return {
"system": self.system,
"roles": self.roles,
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
"offset": self.offset,
"sep": self.sep,
"sep2": self.sep2,
}
return {
"system": self.system,
"roles": self.roles,
"messages": self.messages,
"offset": self.offset,
"sep": self.sep,
"sep2": self.sep2,
}

conv_qwen_v1 = Conversation(
system="You are a helpful assistant.",
roles=("user", "assistant"),
version="v1",
messages=(),
offset=0,
sep_style=SeparatorStyle.QWEN2,
)

default_conversation = conv_qwen_v1
conv_templates = {
'v1_qwen2': conv_qwen_v1,
}


if __name__ == "__main__":
print(default_conversation.get_prompt())
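For reviewers, a short usage sketch of the template above, grounded in the class as written; the user text is illustrative:

```python
# Build a one-turn Qwen2-style prompt from the 'v1_qwen2' template.
conv = conv_templates['v1_qwen2'].copy()  # copy() turns the () default into lists
conv.append_message(conv.roles[0], 'Describe the image.')
conv.append_message(conv.roles[1], None)  # None leaves the assistant turn open
prompt = conv.get_prompt()
# -> '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n'
#    '<|im_start|>user\nDescribe the image.<|im_end|>\n<|im_start|>assistant\n'
```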