From 99c62660c67fd2567c094bb49ec7d108e5670730 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Thu, 5 Dec 2024 10:37:42 +0000
Subject: [PATCH] support qwen2vl train proj only

---
 examples/train_full/qwen2vl_full_sft.yaml    | 1 +
 src/llamafactory/model/model_utils/visual.py | 7 +++----
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/examples/train_full/qwen2vl_full_sft.yaml b/examples/train_full/qwen2vl_full_sft.yaml
index d3f8d2b30..3cf299bd6 100644
--- a/examples/train_full/qwen2vl_full_sft.yaml
+++ b/examples/train_full/qwen2vl_full_sft.yaml
@@ -6,6 +6,7 @@ stage: sft
 do_train: true
 finetuning_type: full
 freeze_vision_tower: true # choices: [true, false]
+train_mm_proj_only: false # choices: [true, false]
 deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json]
 
 ### dataset
diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py
index 659261979..246b90287 100644
--- a/src/llamafactory/model/model_utils/visual.py
+++ b/src/llamafactory/model/model_utils/visual.py
@@ -138,11 +138,10 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
             forbidden_modules.add("language_model")
 
     elif model_type == "qwen2_vl":
-        if finetuning_args.freeze_vision_tower:
-            forbidden_modules.add("visual")
-
         if finetuning_args.train_mm_proj_only:
-            raise ValueError("Qwen2-VL models do not support `train_mm_proj_only`.")
+            forbidden_modules.update({"visual.patch_embed", "visual.blocks", "model", "lm_head"})
+        elif finetuning_args.freeze_vision_tower:
+            forbidden_modules.add("visual")
 
     return forbidden_modules
 