is_causal hints for transformer (pytorch#106143)
Summary:
Make is_causal hint flags available for the top-level transformer module.

It's debatable whether this is useful -- at present we autodetect causal masks for the src and tgt masks in the transformer encoder and decoder, respectively. Making is_causal flags available would enable users to short-cut this check by asserting whether their mask is causal or not.

I am putting this diff up for discussion, not as a solution. Not doing anything may be the right solution, unless there is strong (data-driven) user demand. -- It appears the consensus is to move ahead with this, as per the discussions below.
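For illustration, here is a minimal usage sketch of the proposed hints based on the forward signature added in this diff; the tensor sizes and the tgt_mask construction are assumptions for the sketch, not part of the change:

```python
import torch
import torch.nn as nn

# Illustrative sizes -- assumptions for the sketch, not taken from the diff.
seq_len, tgt_len, bsz, d_model, nhead = 16, 16, 4, 512, 8

model = nn.Transformer(d_model=d_model, nhead=nhead)
src = torch.randn(seq_len, bsz, d_model)
tgt = torch.randn(tgt_len, bsz, d_model)

# Square causal mask for the target sequence.
tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_len)

# Without a hint, the module compares tgt_mask against a generated causal
# mask to decide whether the causal fast path can be used. Passing
# tgt_is_causal=True asserts causality up front and short-cuts that check.
out = model(src, tgt, tgt_mask=tgt_mask, tgt_is_causal=True)
```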

@cpuhrsch @mikaylagawarecki @jbschlosser @janEbert

Test Plan: sandcastle

Differential Revision: D47373260

Pull Request resolved: pytorch#106143
Approved by: https://github.com/mikaylagawarecki
Michael Gschwind authored and pytorchmergebot committed Aug 4, 2023
1 parent e421edf commit 63d4527
Showing 3 changed files with 47 additions and 28 deletions.
2 changes: 1 addition & 1 deletion test/functorch/test_aotdispatch.py
@@ -2960,7 +2960,7 @@ def test_aot_autograd_symbolic_exhaustive(self, device, dtype, op):
# of a tracing tensor with aten._local_scalar_dense.default -
# erroring out! It's likely that this is caused by data-dependent
# control flow or similar.
torch.nn.TransformerEncoder, # DataDependentOutputException: aten.equal compares a mask input
torch.nn.TransformerEncoder, # DataDependentOutputException: aten.eq compares a mask input
# to a causal mask tensor, to see if Boolean is_causal should be set
# for TrnasformerEncoder layers, MHA and sdp custom kernels
torch.nn.Transformer, # DataDependentOutputException: aten.equal compares a mask input
37 changes: 15 additions & 22 deletions test/test_nn.py
@@ -3839,7 +3839,9 @@ def test(encoder_input_shape, decoder_input_shape,
src_mask_len=None, tgt_mask_len=None, memory_mask_size=None,
src_key_padding_mask_size=None, tgt_key_padding_mask_size=None,
memory_key_padding_mask_size=None,
raises=False):
src_is_causal=False, tgt_is_causal=False,
memory_is_causal=False):

encoder_input = torch.randn(encoder_input_shape)
decoder_input = torch.randn(decoder_input_shape)
model = getattr(nn, model_name)(d_model, nhead, num_encoder_layers,
@@ -3875,23 +3877,17 @@ def test(encoder_input_shape, decoder_input_shape,
else:
memory_key_padding_mask = None

if raises:
with self.assertRaises(RuntimeError):
model(encoder_input, decoder_input,
src_mask=src_mask,
tgt_mask=tgt_mask,
memory_mask=memory_task,
src_key_padding_mask=src_key_padding_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask)
else:
with self.assertRaises(RuntimeError):
model(encoder_input, decoder_input,
src_mask=src_mask,
tgt_mask=tgt_mask,
memory_mask=memory_task,
src_key_padding_mask=src_key_padding_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask)
memory_key_padding_mask=memory_key_padding_mask,
src_is_causal=src_is_causal,
tgt_is_causal=tgt_is_causal,
memory_is_causal=memory_is_causal)


correct_encoder_input_shape = (seq_len, bsz, d_model)
@@ -3905,22 +3901,22 @@ def update_shape(shape, dim, new_dim_size):
# Incorrect encoder_input batch size
encoder_input_shape = update_shape(correct_encoder_input_shape, 1, wrong_bsz)
decoder_input_shape = correct_decoder_input_shape
test(encoder_input_shape, decoder_input_shape, raises=True)
test(encoder_input_shape, decoder_input_shape)

# Incorrect decoder_input batch size
encoder_input_shape = correct_encoder_input_shape
decoder_input_shape = update_shape(correct_decoder_input_shape, 1, wrong_bsz)
test(encoder_input_shape, decoder_input_shape, raises=True)
test(encoder_input_shape, decoder_input_shape)

# Incorrect encoder_input input size
encoder_input_shape = update_shape(correct_encoder_input_shape, 2, wrong_d_model)
decoder_input_shape = correct_decoder_input_shape
test(encoder_input_shape, decoder_input_shape, raises=True)
test(encoder_input_shape, decoder_input_shape)

# Incorrect decoder_input input size
encoder_input_shape = correct_encoder_input_shape
decoder_input_shape = update_shape(correct_decoder_input_shape, 2, wrong_d_model)
test(encoder_input_shape, decoder_input_shape, raises=True)
test(encoder_input_shape, decoder_input_shape)

# Incorrect nhead
encoder_input_shape = correct_encoder_input_shape
@@ -3933,23 +3929,20 @@ def update_shape(shape, dim, new_dim_size):
encoder_input_shape = correct_encoder_input_shape
decoder_input_shape = correct_decoder_input_shape
wrong_src_mask_size = seq_len + 1
test(encoder_input_shape, decoder_input_shape, src_mask_len=wrong_src_mask_size,
raises=True)
test(encoder_input_shape, decoder_input_shape, src_mask_len=wrong_src_mask_size)

# Incorrect tgt_mask
encoder_input_shape = correct_encoder_input_shape
decoder_input_shape = correct_decoder_input_shape
wrong_tgt_mask_size = tgt_len + 1
test(encoder_input_shape, decoder_input_shape, tgt_mask_len=wrong_tgt_mask_size,
raises=True)
test(encoder_input_shape, decoder_input_shape, tgt_mask_len=wrong_tgt_mask_size)

# Incorrect memory_mask
encoder_input_shape = correct_encoder_input_shape
decoder_input_shape = correct_decoder_input_shape
wrong_tgt_mask_size = tgt_len + 1
test(encoder_input_shape, decoder_input_shape,
memory_mask_size=(wrong_tgt_mask_size, wrong_src_mask_size),
raises=True)
memory_mask_size=(wrong_tgt_mask_size, wrong_src_mask_size))

# Incorrect src_key_padding_mask
encoder_input_shape = correct_encoder_input_shape
36 changes: 31 additions & 5 deletions torch/nn/modules/transformer.py
@@ -120,7 +120,9 @@ def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int =

def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None,
src_is_causal: Optional[bool] = None, tgt_is_causal: Optional[bool] = None,
memory_is_causal: bool = False) -> Tensor:
r"""Take in and process masked source/target sequences.
Args:
@@ -132,6 +134,28 @@ def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, t
src_key_padding_mask: the Tensor mask for src keys per batch (optional).
tgt_key_padding_mask: the Tensor mask for tgt keys per batch (optional).
memory_key_padding_mask: the Tensor mask for memory keys per batch (optional).
src_is_causal: If specified, applies a causal mask as ``src_mask``.
Default: ``None``; try to detect a causal mask.
Warning:
``src_is_causal`` provides a hint that ``src_mask`` is
the causal mask. Providing incorrect hints can result in
incorrect execution, including forward and backward
compatibility.
tgt_is_causal: If specified, applies a causal mask as ``tgt_mask``.
Default: ``None``; try to detect a causal mask.
Warning:
``tgt_is_causal`` provides a hint that ``tgt_mask`` is
the causal mask. Providing incorrect hints can result in
incorrect execution, including forward and backward
compatibility.
memory_is_causal: If specified, applies a causal mask as
``memory_mask``.
Default: ``False``.
Warning:
``memory_is_causal`` provides a hint that
``memory_mask`` is the causal mask. Providing incorrect
hints can result in incorrect execution, including
forward and backward compatibility.
Shape:
- src: :math:`(S, E)` for unbatched input, :math:`(S, N, E)` if `batch_first=False` or
@@ -177,10 +201,12 @@ def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, t
if src.size(-1) != self.d_model or tgt.size(-1) != self.d_model:
raise RuntimeError("the feature number of src and tgt must be equal to d_model")

memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask,
is_causal=src_is_causal)
output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask)
memory_key_padding_mask=memory_key_padding_mask,
tgt_is_causal=tgt_is_causal, memory_is_causal=memory_is_causal)
return output

@staticmethod
@@ -865,8 +891,8 @@ def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:

def _detect_is_causal_mask(
mask: Optional[Tensor],
is_causal: Optional[bool],
size: Optional[int]
is_causal: Optional[bool] = None,
size: Optional[int] = None,
) -> bool:
"""Return whether the given attention mask is causal.
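For reference, the data-dependent check that these hints let callers skip looks roughly like the following. This is a simplified sketch of what _detect_is_causal_mask does, written under stated assumptions rather than copied from transformer.py:

```python
from typing import Optional

import torch
from torch import Tensor
import torch.nn as nn


def detect_is_causal_mask_sketch(mask: Optional[Tensor],
                                 is_causal: Optional[bool] = None,
                                 size: Optional[int] = None) -> bool:
    # An explicit hint from the caller wins and skips detection entirely.
    if is_causal is not None:
        return is_causal
    if mask is None:
        return False
    sz = size if size is not None else mask.size(-2)
    causal = nn.Transformer.generate_square_subsequent_mask(sz).to(
        dtype=mask.dtype, device=mask.device)
    # Data-dependent comparison against a freshly generated causal mask;
    # this is the step that trips DataDependentOutputException under tracing
    # and that the new *_is_causal flags allow callers to short-circuit.
    return mask.size() == causal.size() and bool((mask == causal).all())
```

When a hint is supplied, the equality comparison never runs, which also avoids the data-dependent control flow noted in the test_aotdispatch.py comment above.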
