Allow an arbitrary mask to be used in the self attention (#8235)
### Description

The aim of this PR is to enable the use of an arbitrary mask in the self-attention
module, which is useful for handling missing data or masked modeling.

The official PyTorch implementation already accepts an arbitrary mask, and MONAI
already supports masking through the `causal` argument; this PR simply generalizes
that capability directly in the forward pass.
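
For reference, a minimal sketch (with illustrative, assumed shapes) of how an arbitrary boolean mask is passed to PyTorch's `torch.nn.functional.scaled_dot_product_attention`:

```python
import torch
import torch.nn.functional as F

B, H, L, E = 2, 4, 16, 32  # batch, heads, sequence length, head dim (toy values)
q = k = v = torch.randn(B, H, L, E)

# boolean mask broadcastable to (B, H, L, L): True = attend, False = ignore
key_mask = torch.ones(B, 1, 1, L, dtype=torch.bool)
key_mask[..., L // 2 :] = False  # e.g. the second half of the keys is padding

out = F.scaled_dot_product_attention(q, k, v, attn_mask=key_mask)  # (B, H, L, E)
```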

In `SABlock` and `TransformerBlock`, it is now possible to pass a boolean mask of
shape `(BS, Seq_length)`.
Only the columns of masked tokens are set to `-inf`, not the rows, since rows are
rarely masked in common implementations and masked tokens do not contribute to the
gradient anyway.
When causal attention is requested, passing a mask is not supported, to avoid
overlapping masks.
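
A minimal usage sketch of the new `attn_mask` argument; the layer sizes and mask pattern below are illustrative assumptions, not values from this PR:

```python
import torch
from monai.networks.blocks.selfattention import SABlock

block = SABlock(hidden_size=128, num_heads=4)
x = torch.randn(2, 16, 128)  # (BS, Seq_length, hidden_size)

# boolean mask of shape (BS, Seq_length): True = keep the token, False = mask it out
attn_mask = torch.ones(2, 16, dtype=torch.bool)
attn_mask[:, 8:] = False  # e.g. the last 8 tokens are missing or padded

out = block(x, attn_mask=attn_mask)  # same shape as x
```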

I haven't implemented an additive mask for the attention matrix, which would allow
values other than `-inf` to be added to the attention scores, as supported here:
https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html

If you think it's relevant, it could be added.
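
For comparison, a sketch of the additive (float) mask that torch supports but this PR does not implement; the bias value `-2.0` and the shapes are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

B, H, L, E = 1, 2, 8, 16
q = k = v = torch.randn(B, H, L, E)

# a float mask is *added* to the attention scores before softmax,
# allowing soft penalties instead of hard -inf masking
bias = torch.zeros(B, 1, L, L)
bias[..., L // 2 :] = -2.0  # down-weight, rather than fully hide, some keys

out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
```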

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->
- [ ] Non-breaking change (fix or new feature that would not break existing functionality).
- [x] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Lucas Robinet <[email protected]>
Co-authored-by: YunLiu <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
3 people authored Nov 26, 2024
1 parent 3ee4cd2 commit 649c7c8
Showing 3 changed files with 40 additions and 6 deletions.
22 changes: 18 additions & 4 deletions monai/networks/blocks/selfattention.py
@@ -11,7 +11,7 @@
 
 from __future__ import annotations
 
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -154,10 +154,12 @@ def __init__(
         )
         self.input_size = input_size
 
-    def forward(self, x):
+    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
         """
         Args:
             x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C
+            attn_mask (torch.Tensor, optional): mask to apply to the attention matrix.
+            B x (s_dim_1 * ... * s_dim_n). Defaults to None.
 
         Return:
             torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
@@ -176,7 +178,13 @@ def forward(self, x):
 
         if self.use_flash_attention:
             x = F.scaled_dot_product_attention(
-                query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal
+                query=q,
+                key=k,
+                value=v,
+                attn_mask=attn_mask,
+                scale=self.scale,
+                dropout_p=self.dropout_rate,
+                is_causal=self.causal,
             )
         else:
             att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
@@ -186,10 +194,16 @@
                 att_mat = self.rel_positional_embedding(x, att_mat, q)
 
             if self.causal:
+                if attn_mask is not None:
+                    raise ValueError("Causal attention does not support attention masks.")
                 att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[-2], : x.shape[-2]] == 0, float("-inf"))
 
-            att_mat = att_mat.softmax(dim=-1)
+            if attn_mask is not None:
+                attn_mask = attn_mask.unsqueeze(1).unsqueeze(2)
+                attn_mask = attn_mask.expand(-1, self.num_heads, -1, -1)
+                att_mat = att_mat.masked_fill(attn_mask == 0, float("-inf"))
 
+            att_mat = att_mat.softmax(dim=-1)
             if self.save_attn:
                 # no gradients and new tensor;
                 # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
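
To illustrate the non-flash path above, a small sketch (toy shapes, not part of the PR) of how the `(B, Seq)` boolean mask is broadcast over the attention matrix so that masked key columns end up with zero weight after the softmax:

```python
import torch

B, H, L = 1, 2, 6  # toy sizes
att_mat = torch.randn(B, H, L, L)                 # raw attention scores
mask = torch.tensor([[1, 1, 1, 0, 0, 1]]).bool()  # (B, L): False = masked token

mask4d = mask.unsqueeze(1).unsqueeze(2).expand(-1, H, -1, -1)  # (B, H, 1, L)
att_mat = att_mat.masked_fill(mask4d == 0, float("-inf"))
att_mat = att_mat.softmax(dim=-1)

# the columns of masked tokens carry zero attention weight
assert torch.allclose(att_mat[..., ~mask[0]], torch.zeros_like(att_mat[..., ~mask[0]]))
```
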
6 changes: 4 additions & 2 deletions monai/networks/blocks/transformerblock.py
@@ -90,8 +90,10 @@ def __init__(
             use_flash_attention=use_flash_attention,
         )
 
-    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
-        x = x + self.attn(self.norm1(x))
+    def forward(
+        self, x: torch.Tensor, context: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        x = x + self.attn(self.norm1(x), attn_mask=attn_mask)
         if self.with_cross_attention:
             x = x + self.cross_attn(self.norm_cross_attn(x), context=context)
         x = x + self.mlp(self.norm2(x))
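
A usage sketch for `TransformerBlock`, showing that the mask is simply forwarded to the inner `SABlock`; the constructor values are assumptions for illustration:

```python
import torch
from monai.networks.blocks.transformerblock import TransformerBlock

block = TransformerBlock(hidden_size=128, mlp_dim=256, num_heads=4)
x = torch.randn(2, 16, 128)
mask = torch.randint(0, 2, (2, 16)).bool()  # (BS, Seq_length) boolean mask

out = block(x, attn_mask=mask)  # attn_mask is passed to the self-attention layer
```
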
18 changes: 18 additions & 0 deletions tests/test_selfattention.py
@@ -122,6 +122,24 @@ def test_causal(self):
         # check upper triangular part of the attention matrix is zero
         assert torch.triu(block.att_mat, diagonal=1).sum() == 0
 
+    def test_masked_selfattention(self):
+        n = 64
+        block = SABlock(hidden_size=128, num_heads=1, dropout_rate=0.1, sequence_length=16, save_attn=True)
+        input_shape = (1, n, 128)
+        # generate a mask randomly with zeros and ones of shape (1, n)
+        mask = torch.randint(0, 2, (1, n)).bool()
+        block(torch.randn(input_shape), attn_mask=mask)
+        att_mat = block.att_mat.squeeze()
+        # ensure all masked columns are zeros
+        assert torch.allclose(att_mat[:, ~mask.squeeze(0)], torch.zeros_like(att_mat[:, ~mask.squeeze(0)]))
+
+    def test_causal_and_mask(self):
+        with self.assertRaises(ValueError):
+            block = SABlock(hidden_size=128, num_heads=1, causal=True, sequence_length=64)
+            inputs = torch.randn(2, 64, 128)
+            mask = torch.randint(0, 2, (2, 64)).bool()
+            block(inputs, attn_mask=mask)
+
     @skipUnless(has_einops, "Requires einops")
     def test_access_attn_matrix(self):
         # input format
