diff --git a/scripts/stat_utils/cal_lr.py b/scripts/stat_utils/cal_lr.py
index 23e16d831..a76d5827b 100644
--- a/scripts/stat_utils/cal_lr.py
+++ b/scripts/stat_utils/cal_lr.py
@@ -24,7 +24,7 @@
 from tqdm import tqdm
 from transformers import DataCollatorForLanguageModeling
 
-from llamafactory.data import get_dataset, get_template_and_fix_tokenizer, MultiModalDataCollatorForSeq2Seq
+from llamafactory.data import MultiModalDataCollatorForSeq2Seq, get_dataset, get_template_and_fix_tokenizer
 from llamafactory.extras.constants import IGNORE_INDEX
 from llamafactory.hparams import get_train_args
 from llamafactory.model import load_tokenizer
@@ -71,7 +71,9 @@ def calculate_lr(
     if stage == "pt":
         data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
     elif stage == "sft":
-        data_collator = MultiModalDataCollatorForSeq2Seq(template=template, tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
+        data_collator = MultiModalDataCollatorForSeq2Seq(
+            template=template, tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX
+        )
     else:
         raise NotImplementedError(f"Stage does not supported: {stage}.")
 
diff --git a/scripts/vllm_infer.py b/scripts/vllm_infer.py
index 1eae6842e..063f457c6 100644
--- a/scripts/vllm_infer.py
+++ b/scripts/vllm_infer.py
@@ -16,16 +16,25 @@
 
 import fire
 from transformers import Seq2SeqTrainingArguments
-from vllm import LLM, SamplingParams
-from vllm.lora.request import LoRARequest
 
 from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
 from llamafactory.extras.constants import IGNORE_INDEX
 from llamafactory.extras.misc import get_device_count
+from llamafactory.extras.packages import is_pillow_available, is_vllm_available
 from llamafactory.hparams import get_infer_args
 from llamafactory.model import load_tokenizer
 
 
+if is_pillow_available():
+    from PIL import Image
+    from PIL.Image import Image as ImageObject
+
+
+if is_vllm_available():
+    from vllm import LLM, SamplingParams
+    from vllm.lora.request import LoRARequest
+
+
 def vllm_infer(
     model_name_or_path: str,
     adapter_name_or_path: str = None,
@@ -64,15 +73,29 @@ def vllm_infer(
         )
     )
 
-    training_args = Seq2SeqTrainingArguments(output_dir="dummy_dir", predict_with_generate=True)
+    training_args = Seq2SeqTrainingArguments(output_dir="dummy_dir")
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    template = get_template_and_fix_tokenizer(tokenizer, data_args)
-    dataset = get_dataset(template, model_args, data_args, training_args, "ppo", **tokenizer_module)["train_dataset"]
+    template_obj = get_template_and_fix_tokenizer(tokenizer, data_args)
+    template_obj.mm_plugin.expand_mm_tokens = False  # for vllm generate
+    dataset_module = get_dataset(template_obj, model_args, data_args, training_args, "ppo", **tokenizer_module)
 
     inputs, prompts, labels = [], [], []
-    for sample in dataset:
-        inputs.append({"prompt_token_ids": sample["input_ids"]})
+    for sample in dataset_module["train_dataset"]:
+        if sample["images"]:
+            multi_modal_data = {"image": []}
+            for image in sample["images"]:
+                if not isinstance(image, (str, ImageObject)):
+                    raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
+
+                if isinstance(image, str):
+                    image = Image.open(image).convert("RGB")
+
+                multi_modal_data["image"].append(image)
+        else:
+            multi_modal_data = None
+
+        inputs.append({"prompt_token_ids": sample["input_ids"], "multi_modal_data": multi_modal_data})
         prompts.append(tokenizer.decode(sample["input_ids"], skip_special_tokens=False))
         labels.append(
             tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, sample["labels"])), skip_special_tokens=False)
@@ -100,6 +123,9 @@ def vllm_infer(
         "disable_log_stats": True,
         "enable_lora": model_args.adapter_name_or_path is not None,
     }
+    if template_obj.mm_plugin.__class__.__name__ != "BasePlugin":
+        engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2}
+
     if isinstance(model_args.vllm_config, dict):
         engine_args.update(model_args.vllm_config)
 
diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py
index cf567d5f5..a8f12faa3 100644
--- a/src/llamafactory/chat/vllm_engine.py
+++ b/src/llamafactory/chat/vllm_engine.py
@@ -19,7 +19,7 @@
 
 from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
-from ..extras.constants import IMAGE_PLACEHOLDER
+from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 from ..extras.misc import get_device_count
 from ..extras.packages import is_pillow_available, is_vllm_available
 from ..model import load_config, load_tokenizer
@@ -67,6 +67,7 @@ def __init__(
         self.processor = tokenizer_module["processor"]
         self.tokenizer.padding_side = "left"
         self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
+        self.template.mm_plugin.expand_mm_tokens = False  # for vllm generate
         self.generating_args = generating_args.to_dict()
 
         engine_args = {
@@ -83,6 +84,9 @@
             "enable_lora": model_args.adapter_name_or_path is not None,
             "max_lora_rank": model_args.vllm_max_lora_rank,
         }
+        if self.template.mm_plugin.__class__.__name__ != "BasePlugin":
+            engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2}
+
         if isinstance(model_args.vllm_config, dict):
             engine_args.update(model_args.vllm_config)
 
@@ -108,19 +112,21 @@ async def _generate(
         **input_kwargs,
     ) -> AsyncIterator["RequestOutput"]:
         request_id = f"chatcmpl-{uuid.uuid4().hex}"
+        mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
         if images is not None:
+            mm_input_dict.update({"images": images, "imglens": [len(images)]})
             if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
                 messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
 
-            if self.template.mm_plugin.__class__.__name__ == "Qwen2vlPlugin":  # temporary solution
-                image_str = f"<|vision_start|>{self.template.mm_plugin.image_token}<|vision_end|>"
-            else:
-                image_str = self.template.mm_plugin.image_token or ""
+        if videos is not None:
+            mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
+            if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
+                messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
 
-        paired_messages = [
-            {"role": message["role"], "content": message["content"].replace(IMAGE_PLACEHOLDER, image_str)}
-            for message in messages
-        ] + [{"role": "assistant", "content": ""}]
+        messages = self.template.mm_plugin.process_messages(
+            messages, mm_input_dict["images"], mm_input_dict["videos"], self.processor
+        )
+        paired_messages = messages + [{"role": "assistant", "content": ""}]
         system = system or self.generating_args["default_system"]
         prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         prompt_length = len(prompt_ids)
@@ -168,7 +174,7 @@ async def _generate(
         )
 
         if images is not None:  # add image features
-            image_data = []
+            multi_modal_data = {"image": []}
             for image in images:
                 if not isinstance(image, (str, ImageObject)):
                     raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
@@ -176,9 +182,7 @@ async def _generate(
                 if isinstance(image, str):
                     image = Image.open(image).convert("RGB")
 
-                image_data.append(image)
-
-            multi_modal_data = {"image": image_data}
+                multi_modal_data["image"].append(image)
         else:
             multi_modal_data = None
 
diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 389e27b16..4e1f418af 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -62,6 +62,7 @@ class BasePlugin:
     def __init__(self, image_token: Optional[str], video_token: Optional[str]) -> None:
         self.image_token = image_token
         self.video_token = video_token
+        self.expand_mm_tokens = True
 
     def _validate_input(
         self,
@@ -259,7 +260,7 @@ def process_messages(
     ) -> List[Dict[str, str]]:
         self._validate_input(images, videos)
         num_image_tokens = 0
-        image_seqlen = getattr(processor, "image_seqlen")
+        image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1
         messages = deepcopy(messages)
         for message in messages:
             content = message["content"]
@@ -310,11 +311,13 @@
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                image_size = next(image_sizes)
-                orig_height, orig_width = image_size
-                image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
-                if getattr(processor, "vision_feature_select_strategy") == "default":
-                    image_seqlen -= 1
+                if self.expand_mm_tokens:
+                    orig_height, orig_width = next(image_sizes)
+                    image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
+                    if getattr(processor, "vision_feature_select_strategy") == "default":
+                        image_seqlen -= 1
+                else:
+                    image_seqlen = 1
 
                 num_image_tokens += 1
                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
@@ -359,11 +362,13 @@
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                image_size = next(image_sizes)
-                orig_height, orig_width = image_size
-                image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
-                if getattr(processor, "vision_feature_select_strategy") == "default":
-                    image_seqlen -= 1
+                if self.expand_mm_tokens:
+                    orig_height, orig_width = next(image_sizes)
+                    image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
+                    if getattr(processor, "vision_feature_select_strategy") == "default":
+                        image_seqlen -= 1
+                else:
+                    image_seqlen = 1
 
                 num_image_tokens += 1
                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
@@ -376,6 +381,7 @@
             num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
             image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
            video_seqlen = image_seqlen // 4 * num_frames  # divide by 4 needed for avg pooling layer
+            video_seqlen = video_seqlen if self.expand_mm_tokens else 1
             for message in messages:
                 content = message["content"]
                 while VIDEO_PLACEHOLDER in content:
@@ -443,7 +449,7 @@ def process_token_ids(
     ) -> Tuple[List[int], Optional[List[int]]]:
         self._validate_input(images, videos)
         num_images = len(images)
-        image_seqlen = num_images * getattr(processor, "image_seqlen")
+        image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0  # skip mm token
         image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
         input_ids = [image_token_id] * image_seqlen + input_ids
         if labels is not None:
@@ -493,14 +499,18 @@ def process_messages(
                 if image_input_sizes is None:
                     raise ValueError("Cannot get image input sizes.")
 
-                image_size = image_input_sizes[0][num_image_tokens]
-                height, width = image_size
-                num_height_tokens = height // patch_size
-                num_width_tokens = width // patch_size
-                replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
-                replace_tokens = [item for sublist in replace_tokens for item in sublist]  # flatten list
-                replace_tokens[-1] = image_end_token
-                replace_str = "".join(replace_tokens)
+                if self.expand_mm_tokens:
+                    image_size = image_input_sizes[0][num_image_tokens]
+                    height, width = image_size
+                    num_height_tokens = height // patch_size
+                    num_width_tokens = width // patch_size
+                    replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
+                    replace_tokens = [item for sublist in replace_tokens for item in sublist]  # flatten list
+                    replace_tokens[-1] = image_end_token
+                    replace_str = "".join(replace_tokens)
+                else:
+                    replace_str = image_token
+
                 content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
                 num_image_tokens += 1
 
@@ -549,10 +559,27 @@ def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
         return image
 
     @override
-    def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
-        sample_frames = super()._get_video_sample_frames(video_stream, **kwargs)
-        sample_frames = sample_frames // 2 * 2
-        return sample_frames
+    def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
+        results = []
+        for video in videos:
+            container = av.open(video, "r")
+            video_stream = next(stream for stream in container.streams if stream.type == "video")
+            total_frames = video_stream.frames
+            sample_frames = self._get_video_sample_frames(video_stream, **kwargs)
+            sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
+            frames: List["ImageObject"] = []
+            container.seek(0)
+            for frame_idx, frame in enumerate(container.decode(video_stream)):
+                if frame_idx in sample_indices:
+                    frames.append(frame.to_image())
+
+            if len(frames) % 2 != 0:  # qwen2-vl requires even number of frames
+                frames.append(frames[-1])
+
+            frames = self._regularize_images(frames, **kwargs)
+            results.append(frames)
+
+        return results
 
     @override
     def process_messages(
@@ -577,12 +604,9 @@
                 if num_image_tokens >= len(image_grid_thw):
                     raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
 
+                image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                 content = content.replace(
-                    IMAGE_PLACEHOLDER,
-                    "<|vision_start|>{}<|vision_end|>".format(
-                        self.image_token * (image_grid_thw[num_image_tokens].prod() // merge_length)
-                    ),
-                    1,
+                    IMAGE_PLACEHOLDER, f"<|vision_start|>{self.image_token * image_seqlen}<|vision_end|>", 1
                 )
                 num_image_tokens += 1
 
@@ -590,12 +614,9 @@
                 if num_video_tokens >= len(video_grid_thw):
                     raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
 
+                video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                 content = content.replace(
-                    VIDEO_PLACEHOLDER,
-                    "<|vision_start|>{}<|vision_end|>".format(
-                        self.video_token * (video_grid_thw[num_video_tokens].prod() // merge_length)
-                    ),
-                    1,
+                    VIDEO_PLACEHOLDER, f"<|vision_start|>{self.video_token * video_seqlen}<|vision_end|>", 1
                 )
                 num_video_tokens += 1
 
@@ -640,19 +661,22 @@ def process_messages(
         has_images = "pixel_values_images" in mm_inputs
         has_videos = "pixel_values_videos" in mm_inputs
         if has_images or has_videos:
-            if has_images:
-                height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
-                num_frames = 1
-
-            if has_videos:
-                pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
-                height, width = get_image_size(pixel_values_video[0])
-                num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
-
-            image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
-            video_seqlen = image_seqlen * num_frames
-            if getattr(processor, "vision_feature_select_strategy") == "default":
-                image_seqlen -= 1
+            if self.expand_mm_tokens:
+                if has_images:
+                    height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
+                    num_frames = 1
+
+                if has_videos:
+                    pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
+                    height, width = get_image_size(pixel_values_video[0])
+                    num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
+
+                image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
+                video_seqlen = image_seqlen * num_frames
+                if getattr(processor, "vision_feature_select_strategy") == "default":
+                    image_seqlen -= 1
+            else:
+                image_seqlen, video_seqlen = 1, 1
 
             for message in messages:
                 content = message["content"]