Add support for SWA (left, right) with FusedAttention #2477
base: main
Conversation
…IA#1369 Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
for more information, see https://pre-commit.ci
/te-ci pytorch L0
Greptile Overview

Greptile Summary

This PR extends FusedAttention's sliding window attention (SWA) support from right-side-only windows to arbitrary bidirectional window configurations by adding a `bottom_right_diagonal` parameter. The parameter controls diagonal alignment in the attention matrix: when it is `True`, the sliding window diagonal is aligned to the bottom-right corner of the matrix (as with the `*_bottom_right` mask types and cross-attention); when it is `False`, the diagonal is aligned to the top-left corner.

The implementation threads the parameter through all layers: from the high-level PyTorch APIs (`TransformerLayer`, `MultiheadAttention`, `DotProductAttention`) through the C++ extensions down to the common cuDNN fused-attention code.
Confidence score: 3/5
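To make the alignment described in the summary concrete, here is a minimal sketch (not Transformer Engine code; the function name and shapes are purely illustrative) of how a top-left versus bottom-right aligned causal diagonal differs when query and key lengths are not equal:

```python
# Illustration only: top-left vs. bottom-right diagonal alignment of a causal
# mask when seqlen_q != seqlen_kv. Not TE code; names are made up.
import torch

def causal_mask(seqlen_q: int, seqlen_kv: int, bottom_right_diagonal: bool) -> torch.Tensor:
    i = torch.arange(seqlen_q).unsqueeze(1)   # query positions
    j = torch.arange(seqlen_kv).unsqueeze(0)  # key positions
    # Bottom-right alignment shifts the diagonal so the last query attends to the
    # last key (the usual convention for cross-attention / KV-cache inference).
    offset = (seqlen_kv - seqlen_q) if bottom_right_diagonal else 0
    return j <= i + offset                    # True = may attend

print(causal_mask(2, 4, bottom_right_diagonal=False).int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0]])
print(causal_mask(2, 4, bottom_right_diagonal=True).int())
# tensor([[1, 1, 1, 0],
#         [1, 1, 1, 1]])
```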
Additional Comments (2)
- transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1281 (link)
  logic: Trailing comma creates a single-element tuple instead of a boolean - should this be just `bottom_right_alignment = attn_mask_type not in ["causal", "padding_causal"]`?
- transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1482 (link)
  style: Uses a hardcoded mask type check instead of the new `bottom_right_diagonal` parameter for ALiBi alignment. Should this use the `bottom_right_diagonal` parameter for consistency?
  Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
15 files reviewed, 8 comments
# (This should be replaced with `bottom_right_diagonal` which is passed from the arguments)
bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
)
logic: The TODO comment indicates `bottom_right_diagonal` should replace `bottom_right_alignment`, but the replacement wasn't implemented. This could cause incorrect behavior in ALiBi attention. Should `bottom_right_alignment` be replaced with the `bottom_right_diagonal` parameter, as the comment suggests?
# (cyang: Why is window_size is being modified but then its value ignored
#  in the following else block?)
# else:
style: Commented-out code with developer question should be removed
Suggested change:
- # (cyang: Why is window_size is being modified but then its value ignored
- #  in the following else block?)
- # else:
  if window_size is None:
      window_size = check_set_window_size(attn_mask_type, window_size)
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
| "Disabling FlashAttention as it only supports sliding window with bottom right" | ||
| " diagonal alignment for cross-attention" | ||
| ) | ||
| use_flash_attention = False |
logic: Variable name mismatch: use_flash_attention is being set but should be use_flash_attention_2
Suggested change:
- use_flash_attention = False
+ use_flash_attention_2 = False
| "Disabling FlashAttention as it only supports ALiBi with bottom right diagonal" | ||
| " alignment for cross-attention" | ||
| ) | ||
| use_flash_attention = False |
logic: Variable name mismatch: use_flash_attention is being set but should be use_flash_attention_2
Suggested change:
- use_flash_attention = False
+ use_flash_attention_2 = False
if self_attn_mask_type in {"causal", "padding_causal"}:
    bottom_right_diagonal = False
if bottom_right_diagonal is None or self_attn_mask_type in {
    "causal_bottom_right",
    "padding_causal_bottom_right",
}:
    bottom_right_diagonal = True
logic: This logic overrides the instance variable even when it is explicitly set in the forward call - the user's explicit choice should be preserved. Should the mask type check override an explicitly passed `bottom_right_diagonal` parameter, or only apply when it's `None`?
if enc_dec_attn_mask_type in {"causal", "padding_causal"}:
    enc_dec_bottom_right_diagonal = False
if enc_dec_bottom_right_diagonal is None or enc_dec_attn_mask_type in {
    "causal_bottom_right",
    "padding_causal_bottom_right",
}:
    enc_dec_bottom_right_diagonal = True
logic: Same logic issue as above - mask type check overrides explicit parameter values
qkv_format == NVTE_QKV_Format::NVTE_SBHD)))) ||
// 9.6: SWA (left, 0) + top-left/bottom-right diagonal + {bshd, sbhd, thd}
(cudnn_runtime_version >= 90600 &&
 ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) ||
We can probably remove this line? Because it's covered by the next line?
Could you add a couple of SWA tests to the CP tests as well? I think it's just a matter of replacing (left, 0) with (left, right) and testing them out. Thanks!
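For illustration only, a rough sketch of what such an extension could look like - the config dict and test name below are hypothetical placeholders, not the actual CP test file, which has its own config classes and runners:

```python
# Hypothetical sketch: extend the CP SWA cases from (left, 0) to (left, right).
# The real context-parallel tests use their own config objects and compare the
# CP path against a single-GPU reference; this only shows the parametrization idea.
import pytest

SWA_CP_CONFIGS = {
    "swa_causal_left_only": dict(attn_mask_type="causal", window_size=(128, 0)),
    # new bidirectional windows exercising SWA (left, right):
    "swa_no_mask_left_right": dict(attn_mask_type="no_mask", window_size=(128, 64)),
    "swa_padding_left_right": dict(attn_mask_type="padding", window_size=(64, 32)),
}

@pytest.mark.parametrize("name,cfg", SWA_CP_CONFIGS.items())
def test_cp_attention_with_swa(name, cfg):
    # Placeholder body: a real test would run the CP attention path with `cfg`
    # and compare outputs/gradients against the non-CP reference.
    left, right = cfg["window_size"]
    assert left >= 0 and right >= 0
```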
// NVTE fused attention FWD with packed QKV
-// DEPRECATED: This API is deprecated.
+// DEPRECATED: This API is deprecated. (Should there be a version by which this is going to be removed? @cyang)
I made some changes in #2272, but will see if I can make the 2.11 deadline.
  sdpa_options.set_diagonal_band_left_bound(window_size_left + 1);
}
if (cudnn_runtime_version >= 90600 && window_size_right != -1) {
  // (remove comment when reviewed) Should it be `window_size_right + 1` instead?
SWA(left, right) should describe an attention window of left + 1 + right elements, but cuDNN understands it as left - 1 + 1 + right elements, so we need to add the 1 here to left to make all three backends (FlashAttention, FusedAttention, UnfusedDPA) equivalent in terms of the SWA operation.
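As a small sanity check of that convention (illustrative Python only, not the cuDNN call itself), the intended SWA (left, right) semantics give each interior query row exactly left + 1 + right attendable keys, which is why the band bound passed to cuDNN is `window_size_left + 1`:

```python
# Illustration only: intended SWA (left, right) semantics - query i may attend
# keys j with i - left <= j <= i + right, i.e. left + 1 + right positions per
# interior row. cuDNN's diagonal band bound counts differently, hence the +1.
import torch

def swa_mask(seqlen: int, left: int, right: int) -> torch.Tensor:
    i = torch.arange(seqlen).unsqueeze(1)
    j = torch.arange(seqlen).unsqueeze(0)
    full = torch.ones(seqlen, seqlen, dtype=torch.bool)
    left_ok = (j >= i - left) if left >= 0 else full     # -1 means unbounded
    right_ok = (j <= i + right) if right >= 0 else full
    return left_ok & right_ok

# An interior row has exactly left + 1 + right allowed positions:
assert swa_mask(16, left=3, right=2)[8].sum().item() == 3 + 1 + 2
```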
  sdpa_backward_options.set_diagonal_band_left_bound(window_size_left + 1);
}
if (cudnn_runtime_version >= 90600 && window_size_right != -1) {
  // (remove comment when reviewed) Should it be `window_size_right + 1` instead?
Same as above.
actual_seqlens_q=actual_seqlens_q if "padding" in attn_mask_type else None,
actual_seqlens_kv=actual_seqlens_kv if "padding" in attn_mask_type else None,
alibi_slopes=alibi_slopes,
# (This should be replaced with `bottom_right_diagonal` which is passed from the arguments)
I wonder if the logic should be something like this: `bottom_right_alignment = attn_mask_type not in ["causal", "padding_causal"] if bottom_right_alignment is None else bottom_right_alignment`.
`causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`}
window_size : Tuple[int, int], default = None
             Sliding window attention size.
bottom_right_diagonal: bool, default = `True`
Is the default True here?
use_flash_attention_2 = False
# (cyang: Why is window_size is being modified but then its value ignored
#  in the following else block?)
# else:
You are right - the code under the `else` should come out of the `else` block because we need to factor in the changes from `check_set_window_size`. Thanks - well spotted!
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
15 files reviewed, no comments
Additional Comments (4)
- transformer_engine/pytorch/transformer.py, line 777-783 (link)
  logic: Conditional logic overwrites an explicitly passed parameter. When `self_attn_mask_type in {"causal", "padding_causal"}` at line 777, this unconditionally sets `bottom_right_diagonal = False`, even if the user explicitly passed a different value in the forward call. The check at line 779 (`if bottom_right_diagonal is None`) becomes unreachable in those cases. Consider restructuring to only apply defaults when `bottom_right_diagonal is None` (see the sketch after this list).
- transformer_engine/pytorch/transformer.py, line 787-793 (link)
  logic: Same conditional logic issue for encoder-decoder attention. The mask type check at line 787 unconditionally overwrites `enc_dec_bottom_right_diagonal`, ignoring explicitly passed values.
- transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 920 (link)
  logic: Variable name mismatch: setting `use_flash_attention` instead of `use_flash_attention_2`
- transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 947 (link)
  logic: Variable name mismatch: setting `use_flash_attention` instead of `use_flash_attention_2`
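A minimal sketch of the restructuring suggested in the first two items above (the helper below is illustrative only, not the code in transformer.py): resolve the mask-type default only when no explicit value was passed, so a caller's choice is preserved.

```python
# Illustrative helper only: fall back to a mask-type-based default exclusively
# when the caller did not pass bottom_right_diagonal explicitly.
def resolve_bottom_right_diagonal(attn_mask_type: str, bottom_right_diagonal=None) -> bool:
    if bottom_right_diagonal is None:
        # plain causal masks use a top-left diagonal; everything else bottom-right
        bottom_right_diagonal = attn_mask_type not in {"causal", "padding_causal"}
    return bottom_right_diagonal

assert resolve_bottom_right_diagonal("causal") is False
assert resolve_bottom_right_diagonal("causal_bottom_right") is True
assert resolve_bottom_right_diagonal("causal", bottom_right_diagonal=True) is True  # explicit value kept
```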
15 files reviewed, 4 comments
/te-ci pytorch L0
/te-ci pytorch L0
Additional Comments (1)
- transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1283 (link)
  syntax: Tuple has a trailing comma, causing it to be a single-element tuple instead of a boolean
15 files reviewed, 1 comment
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…ansformerEngine into swa_padding_brcm_try2
/te-ci pytorch L0
15 files reviewed, no comments
/te-ci pytorch L0 |
Description
FusedAttention has supported "right"-side sliding window attention for some time now. This PR adds support for SWA (left, right) with the FusedAttention backend in TE.
(changes cherry-picked from original PR: #1369)
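A minimal usage sketch of what this enables (shapes, sizes, and the environment-variable toggles below are illustrative; `bottom_right_diagonal` is the new parameter added by this PR):

```python
# Illustrative usage sketch: bidirectional sliding window with the fused backend.
import os
import torch
from transformer_engine.pytorch import DotProductAttention

os.environ["NVTE_FLASH_ATTN"] = "0"   # prefer FusedAttention over FlashAttention
os.environ["NVTE_FUSED_ATTN"] = "1"

dpa = DotProductAttention(num_attention_heads=16, kv_channels=64)
q = torch.randn(128, 2, 16, 64, dtype=torch.bfloat16, device="cuda")  # [s, b, h, d]
k, v = torch.randn_like(q), torch.randn_like(q)

out = dpa(
    q, k, v,
    attn_mask_type="no_mask",
    window_size=(64, 16),            # SWA (left, right): right side no longer limited to 0/-1
    # bottom_right_diagonal=True,    # new in this PR; default resolution when None is
    #                                # discussed in the review comments above
)
```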
Type of change
Changes
Please list the changes introduced in this PR:
- transformer_engine
  - common
    - fused_attn
      - fused_attn.cpp: add `bottom_right_diagonal` parameter to the API
      - fused_attn_f16_arbitrary_seqlen.cu: add `bottom_right_diagonal` parameter to the API
      - fused_attn_fp8.cu: add `bottom_right_diagonal` parameter to the `FADescriptor_v1` API
      - utils.h: add `bottom_right_diagonal` parameter to the `FADescriptor_v1` API
  - pytorch
    - transformer.py: pass `bottom_right_diagonal` through the call stack: `TransformerLayer` --> `SelfAttention`/`CrossAttention`
    - attention
      - dot_product_attention
        - backends.py:
          - `UnfusedDotProductAttention`: add `bottom_right_diagonal` parameter to the `forward` API (in `forward`, `bottom_right_alignment` is being used in the ALiBi call; perhaps this should be corrected)
          - `FusedAttn` custom module: add `bottom_right_diagonal` parameter to the `forward` API
          - `FusedAttention` module: pass `bottom_right_diagonal` through the call stack
        - dot_product_attention.py: `DotProductAttention`: pass `bottom_right_diagonal` through the call stack; set a default for `bottom_right_diagonal` if it's `None`
        - utils.py: add `bottom_right_diagonal` to `AttentionParams` and `get_attention_backend`
      - multi_head_attention.py: add `bottom_right_diagonal` to the forward API and call; set a default for `bottom_right_diagonal` if it's `None`
    - cpp_extensions
      - fused_attn.py: add `bottom_right_diagonal` in `fused_attn_fwd`/`fused_attn_bwd`
    - csrc
      - extensions
        - attention.cpp: pass `bottom_right_diagonal` through the call stack: `fused_attn_fwd` --> `nvte_fused_attn_fwd`
      - extensions.h: add `bottom_right_diagonal` to the `fused_attn_fwd` and `fused_attn_bwd` API definitions

Checklist: