Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing device mismatch for InternVL2_5-78B rotary embeddings #35312

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

MorenoLaQuatra
Copy link
Contributor

@MorenoLaQuatra MorenoLaQuatra commented Dec 17, 2024

Fixing problem with Multi-GPU management of InternVL2_5-78B (https://huggingface.co/OpenGVLab/InternVL2_5-78B)

What does this PR do?

Fixes # (issue)

No specific open issue fixing. I was working on inference using the documentation provided by the official model card of InternVL2_5-78B for multiple GPUs here. I got the error of mismatching devices GPU:0 and cpu, I traced back the error to this line.

It may happen to other models, maybe to newer llama vision models (3.2) but I've no access to these models in Europe (see circleci "copies" error).

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@amyeroberts, @qubvel, @ArthurZucker (being Text+Vision, I mentioned all the related ones)

Fixing problem with Multi-GPU management of InternVL2_5-78B (https://huggingface.co/OpenGVLab/InternVL2_5-78B)
@MorenoLaQuatra MorenoLaQuatra changed the title Update modeling_qwen2.py Fixing device mismatch for InternVL2_5-78B rotary embeddings Dec 17, 2024
qubvel
qubvel previously approved these changes Dec 17, 2024
Copy link
Member

@qubvel qubvel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @MorenoLaQuatra, thanks for the fix! Please run make fix-copies to update the code for other models.

@qubvel qubvel dismissed their stale review December 17, 2024 22:08

Test fails

@MorenoLaQuatra
Copy link
Contributor Author

MorenoLaQuatra commented Dec 18, 2024

I think that make fix-copies will "override" the changes I did in modeling_qwen2.py rather than propagating the changes from there to the other models. Does it make sense?

Here some context:

python utils/check_copies.py --fix_and_overwrite
Detected changes, rewriting src/transformers/models/qwen2/modeling_qwen2.py.
python utils/check_modular_conversion.py  --fix_and_overwrite
No differences found for src/transformers/models/gemma/configuration_gemma.py.
No differences found for src/transformers/models/gemma/tokenization_gemma.py.
No differences found for src/transformers/models/gemma/modeling_gemma.py.

I'm not sure the same problem happen when dividing llama 3.2 other kind of models on multiple GPUs and I want to avoid breaking other models. Someone is able to check maybe?

@qubvel
Copy link
Member

qubvel commented Dec 18, 2024

I think that make fix-copies will "override" the changes I did in modeling_qwen2.py rather than propagating the changes from there to the other models.

You are right, so you have to make changes to the origin. I suppose it should be safe for other models as well. We can also run slow tests in CI to make sure it hasn't broken. To make this, please, push an empty commit with [run-slow] qwen2, llama3, <changed_model_names> message (should be the last in a row, so I can approve the run).

@qubvel qubvel added Big Model Inference Problems related to the Big Model Inference capabilities provided by Accelerate run-slow labels Dec 18, 2024
Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey! would you mind providing a reproducer, the exact same function is used for llama, for which we did not have issues! 🤗

@qubvel
Copy link
Member

qubvel commented Jan 7, 2025

Similar issues/PRs

It seems like some issue exists, reproducing example would be great

@MorenoLaQuatra
Copy link
Contributor Author

I reproduced the issue with "minimal" code:

import torch
import torchvision.transforms as T
from PIL import Image
import requests
from io import BytesIO
from transformers import AutoModel, AutoTokenizer
from torchvision.transforms.functional import InterpolationMode
import math
from typing import Dict, List, Tuple

class ImageProcessor:
    def __init__(self):
        self.input_size = 448
        self.max_num_patches = 12
        self.min_num_patches = 1
        self.use_thumbnail = True
        self.transform = self._build_transform()

    def _build_transform(self):
        return T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize((self.input_size, self.input_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    def _find_closest_aspect_ratio(
        self,
        aspect_ratio: float,
        width: int,
        height: int,
        image_size: int
    ) -> Tuple[int, int]:
        target_ratios = set(
            (i, j) 
            for n in range(self.min_num_patches, self.max_num_patches + 1)
            for i in range(1, n + 1) 
            for j in range(1, n + 1)
            if i * j <= self.max_num_patches and i * j >= self.min_num_patches
        )
        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

        best_ratio_diff = float('inf')
        best_ratio = (1, 1)
        area = width * height

        for ratio in target_ratios:
            target_aspect_ratio = ratio[0] / ratio[1]
            ratio_diff = abs(aspect_ratio - target_aspect_ratio)
            if ratio_diff < best_ratio_diff:
                best_ratio_diff = ratio_diff
                best_ratio = ratio
            elif ratio_diff == best_ratio_diff:
                if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                    best_ratio = ratio
        return best_ratio

    def _dynamic_preprocess(self, image: Image.Image) -> List[Image.Image]:
        orig_width, orig_height = image.size
        aspect_ratio = orig_width / orig_height
        image_size = self.input_size

        target_aspect_ratio = self._find_closest_aspect_ratio(
            aspect_ratio, orig_width, orig_height, image_size
        )

        target_width = image_size * target_aspect_ratio[0]
        target_height = image_size * target_aspect_ratio[1]
        blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

        resized_img = image.resize((target_width, target_height))
        processed_images = []
        
        for i in range(blocks):
            box = (
                (i % (target_width // image_size)) * image_size,
                (i // (target_width // image_size)) * image_size,
                ((i % (target_width // image_size)) + 1) * image_size,
                ((i // (target_width // image_size)) + 1) * image_size
            )
            split_img = resized_img.crop(box)
            processed_images.append(split_img)

        if self.use_thumbnail and len(processed_images) != 1:
            thumbnail_img = image.resize((image_size, image_size))
            processed_images.append(thumbnail_img)

        return processed_images

    def process_image(self, image_url: str) -> torch.Tensor:
        response = requests.get(image_url)
        image = Image.open(BytesIO(response.content)).convert('RGB')
        processed_images = self._dynamic_preprocess(image)
        pixel_values = [self.transform(img) for img in processed_images]
        return torch.stack(pixel_values).to(torch.bfloat16).cuda()

class ModelManager:
    def __init__(self):
        self.model_name = "OpenGVLab/InternVL2_5-78B"
        self.model = None
        self.tokenizer = None
        self.generation_config = {
            'max_new_tokens': 1024,
            'do_sample': True
        }

    def _split_model(self) -> Dict[str, int]:
        device_map = {}
        world_size = torch.cuda.device_count()
        print(f"Found {world_size} GPUs")
        
        num_layers = 80  # InternVL2_5-78B has 80 layers
        
        num_layers_per_gpu = math.ceil(num_layers / (world_size - 0.5))
        num_layers_per_gpu = [num_layers_per_gpu] * world_size
        num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.5)
        
        layer_cnt = 0
        for i, num_layer in enumerate(num_layers_per_gpu):
            for _ in range(num_layer):
                device_map[f'language_model.model.layers.{layer_cnt}'] = i
                layer_cnt += 1

        base_components = {
            'vision_model': 0,
            'mlp1': 0,
            'language_model.model.tok_embeddings': 0,
            'language_model.model.embed_tokens': 0,
            'language_model.output': 0,
            'language_model.model.norm': 0,
            'language_model.lm_head': 0
        }
        device_map.update(base_components)
        device_map[f'language_model.model.layers.{num_layers - 1}'] = 0
        return device_map

    def initialize_model(self):
        device_map = self._split_model()
        self.model = AutoModel.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16,
            load_in_8bit=False,
            low_cpu_mem_usage=True,
            use_flash_attn=True,
            trust_remote_code=True,
            device_map=device_map
        ).eval()
        
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name,
            trust_remote_code=True,
            use_fast=False
        )
        self.model.generation_config.pad_token_id = self.tokenizer.eos_token_id

    def generate_caption(self, pixel_values: torch.Tensor) -> str:
        prompt = "Please describe this image."
        return self.model.chat(
            tokenizer=self.tokenizer,
            pixel_values=pixel_values,
            question=prompt,
            generation_config=self.generation_config
        )

def main():
    # Initialize components
    image_processor = ImageProcessor()
    model_manager = ModelManager()
    
    # Initialize model
    print("Initializing model...")
    model_manager.initialize_model()
    
    # Process image
    image_url = "https://images.unsplash.com/photo-1507146426996-ef05306b995a?fm=jpg&q=60&w=3000&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxzZWFyY2h8M3x8cHVwcHl8ZW58MHx8MHx8fDA%3D"
    print("Processing image...")
    pixel_values = image_processor.process_image(image_url)
    
    # Generate caption
    print("Generating caption...")
    caption = model_manager.generate_caption(pixel_values)
    print(f"\nGenerated caption: {caption}")

if __name__ == "__main__":
    main()

I run this on a machine with 4xA100 80GB

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Big Model Inference Problems related to the Big Model Inference capabilities provided by Accelerate run-slow
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants