Add support for 4D custom attention masks in GPT-2 #35517

Open · wants to merge 14 commits into main
Conversation

sambhavnoobcoder

Problem Statement

Currently, GPT-2's attention mechanism only supports 2D attention masks, which limits its flexibility for advanced use cases such as packed sequence processing. When users pass a 4D attention mask (shape [batch_size, num_heads, seq_length, seq_length]), the model raises a dimension-mismatch error.

Issue #35290 demonstrates this limitation when trying to process packed sequences with custom attention patterns.
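
For reference, a minimal reproduction of the failure might look like the sketch below (the checkpoint, mask convention, and exact error depend on the installed transformers version; the 1 = attend / 0 = block convention is an assumption, not taken from the issue):

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

inputs = tokenizer("Packed sequences need custom masks", return_tensors="pt")
batch_size, seq_len = inputs["input_ids"].shape

# Custom causal 4D mask, broadcast over a single head dimension:
# 1.0 = attend, 0.0 = block.
mask_4d = torch.tril(torch.ones(seq_len, seq_len)).view(1, 1, seq_len, seq_len)
mask_4d = mask_4d.expand(batch_size, 1, seq_len, seq_len)

# Prior to this change, GPT-2's mask handling assumed a 2D
# [batch_size, seq_len] padding mask and could fail with a
# shape-mismatch error when handed a 4D tensor like this one.
outputs = model(input_ids=inputs["input_ids"], attention_mask=mask_4d)
```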

Proposed Solution

Extend GPT-2's attention mechanism to properly handle both 2D and 4D attention masks while maintaining backward compatibility. This allows for:

  • Direct support for packed sequence processing (see the mask-construction sketch after this list)
  • More flexible attention patterns
  • Compatibility with existing 2D mask implementations
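
As an illustration of the packed-sequence use case, a block-diagonal causal 4D mask keeps each packed segment from attending across segment boundaries. The helper `build_packed_causal_mask` below is hypothetical and not part of this PR:

```python
import torch

def build_packed_causal_mask(segment_lengths, dtype=torch.float32):
    """Build a causal, block-diagonal 4D mask for sequences packed into one row.

    Each packed segment may only attend to earlier tokens within its own
    segment. Returns a tensor of shape [1, 1, total_len, total_len] with
    1.0 where attention is allowed and 0.0 where it is blocked.
    """
    total_len = sum(segment_lengths)
    mask = torch.zeros(total_len, total_len, dtype=dtype)
    start = 0
    for length in segment_lengths:
        end = start + length
        # Causal mask restricted to this segment's block on the diagonal.
        mask[start:end, start:end] = torch.tril(torch.ones(length, length, dtype=dtype))
        start = end
    return mask.view(1, 1, total_len, total_len)

# Two sequences of lengths 3 and 2 packed into a single row of length 5.
packed_mask = build_packed_causal_mask([3, 2])
print(packed_mask[0, 0])
```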

Implementation Details

The changes focus on the GPT2Attention class, specifically:

  1. Updated attention mask handling in the forward pass to accept both 2D and 4D masks (a rough sketch follows this list)
  2. Maintained compatibility with existing 2D attention masks
  3. Preserved causal attention behavior when required
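
The actual changes live in modeling_gpt2.py; the snippet below is only a rough sketch of the dimension-aware mask handling described in point 1, assuming the 1 = attend / 0 = block convention. Function and variable names are illustrative, not the real diff:

```python
import torch

def prepare_attention_mask(attention_mask, dtype):
    """Illustrative sketch of dimension-aware mask handling (not the actual diff).

    A 2D [batch, seq_len] padding mask is expanded to the broadcastable
    4D shape [batch, 1, 1, seq_len]; a 4D mask is used as-is. Either way,
    the 0/1 mask is converted to the additive form added to the logits.
    """
    if attention_mask is None:
        return None
    if attention_mask.dim() == 2:
        # Classic padding mask: broadcast over heads and query positions.
        attention_mask = attention_mask[:, None, None, :]
    elif attention_mask.dim() != 4:
        raise ValueError(f"Expected a 2D or 4D attention mask, got {attention_mask.dim()}D")
    # Convert 1 -> 0.0 (keep) and 0 -> large negative (mask out).
    attention_mask = attention_mask.to(dtype)
    return (1.0 - attention_mask) * torch.finfo(dtype).min
```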

Testing Strategy

Added a comprehensive test suite (test_modeling_4D_attention_mask.py) that verifies:

  • Shape compatibility with 4D masks
  • Correctness of attention patterns
  • Consistency between 2D and 4D mask results (an illustrative version of this check is sketched after this list)
  • Causal attention preservation
  • Batch processing consistency
  • Edge cases (empty sequences, single tokens, maximum length)
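
For instance, the 2D/4D consistency check could be written along these lines. This is not the contents of test_modeling_4D_attention_mask.py, just a sketch of the idea, and it assumes that a 0/1 causal 4D mask is accepted after the fix:

```python
import torch
from transformers import GPT2LMHeadModel

def test_2d_and_4d_masks_agree():
    """A full-attention 2D padding mask and its causal 4D equivalent
    should produce (near-)identical logits."""
    model = GPT2LMHeadModel.from_pretrained("gpt2").eval()
    input_ids = torch.randint(0, model.config.vocab_size, (1, 8))
    seq_len = input_ids.shape[1]

    mask_2d = torch.ones(1, seq_len)  # no padding
    # Equivalent causal 4D mask, broadcast over the head dimension.
    mask_4d = torch.tril(torch.ones(seq_len, seq_len)).view(1, 1, seq_len, seq_len)

    with torch.no_grad():
        out_2d = model(input_ids=input_ids, attention_mask=mask_2d).logits
        out_4d = model(input_ids=input_ids, attention_mask=mask_4d).logits

    torch.testing.assert_close(out_2d, out_4d, atol=1e-4, rtol=1e-4)
```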

Test Results

All tests pass locally. [Screenshot of test results attached, 2025-01-06]

Impact and Benefits

This enhancement:

  1. Enables efficient packed sequence processing
  2. Provides more flexibility in attention pattern design
  3. Maintains backward compatibility
  4. Improves model versatility without performance overhead

Validation

  • ✅ New test suite validates 4D mask functionality
  • ✅ Backward compatible with existing 2D masks
  • ✅ No performance regression

Related Issues

Closes #35290 - Support for 4D attention masks in GPT-2

Additional Notes

  • No breaking changes introduced
  • Existing model weights remain compatible
  • Performance impact is negligible

Requested reviewer: @ArthurZucker

sambhavnoobcoder changed the title from "Add support for 4D attention masks in GPT-2" to "Add support for 4D custom attention masks in GPT-2" on Jan 5, 2025