Allow an arbitrary mask to be used in the self attention (#8235)
### Description
The aim of this PR is to enable the use of an arbitrary mask in the self-attention module, which is very useful in the case of missing data or masked modeling. Official torch implementations allow the use of an arbitrary mask, and in MONAI masking is already possible via the `causal` argument; here it is simply generalized directly in the forward pass.

In the `SABlock` and `TransformerBlock`, it is now possible to pass a boolean mask of size `(BS, Seq_length)`. Only the columns corresponding to masked tokens are set to `-inf`, not the rows, which is rarely done in common implementations; masked tokens don't contribute to the gradient anyway. When causal attention is required, passing a mask is not supported, to avoid overlapping masks. A usage sketch is given below, after the sign-offs.

I haven't implemented an additive mask on the attention matrix, which would allow values other than `-inf` to be used in certain cases, as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html. If you think it's relevant, it could be added.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->
- [ ] Non-breaking change (fix or new feature that would not break existing functionality).
- [x] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Lucas Robinet <[email protected]>
Signed-off-by: Lucas Robinet <[email protected]>
Co-authored-by: YunLiu <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
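A minimal usage sketch of the new behaviour, not taken from this PR's diff: it builds an `SABlock` and passes a boolean padding mask of shape `(BS, Seq_length)`. The keyword name `attn_mask` and the convention that `True` marks tokens participating in attention are assumptions for illustration; check the updated docstrings for the exact signature.

```python
import torch
from monai.networks.blocks.selfattention import SABlock

batch_size, seq_len, hidden_size = 2, 16, 64

# Plain (non-causal) self-attention block; hidden_size must be divisible by num_heads.
block = SABlock(hidden_size=hidden_size, num_heads=4)

x = torch.randn(batch_size, seq_len, hidden_size)

# Boolean mask of shape (BS, Seq_length): True = token attends / is attended to,
# False = token is masked out (e.g. padding or missing data).
attn_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)
attn_mask[:, -4:] = False  # mask out the last four tokens of every sequence

# Assumed keyword name; the mask is broadcast so that only the masked *columns*
# of the attention matrix are set to -inf before the softmax.
out = block(x, attn_mask=attn_mask)
print(out.shape)  # torch.Size([2, 16, 64]), same shape as x
```

Note that combining this mask with `causal=True` is intentionally not supported, so a causal block should be constructed without passing a mask.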