Conversation

@sudhakarsingh27
Collaborator

Description

FusedAttention has supported sliding window attention with window_size = (left, 0) for some time now. This PR adds support for SWA (left, right) with the FusedAttention backend in TE.
(changes cherry-picked from original PR: #1369)
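A minimal usage sketch (not part of this PR's diff), assuming the public DotProductAttention API from recent TE releases; with this change an arbitrary (left, right) window like the one below can be served by the FusedAttention backend rather than only (left, 0) windows:

```python
import torch
import transformer_engine.pytorch as te

# Bidirectional sliding window: each query attends to 128 past and 64 future positions.
attn = te.DotProductAttention(
    num_attention_heads=16,
    kv_channels=64,
    attn_mask_type="no_mask",
    window_size=(128, 64),
)

# Default qkv_format is "sbhd": [seq, batch, heads, head_dim].
q = torch.randn(512, 2, 16, 64, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
out = attn(q, k, v)
```

Whether FusedAttention is actually selected still depends on the backend filters (cuDNN version, layout, etc.) updated in this PR.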

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

transformer_engine

  • common

    • fused_attn
      • fused_attn.cpp
        • add bottom_right_diagonal parameter to the API
        • Edit the backend-selection filters so that sliding window configs can pick the arbitrary-seqlen fused attention backend
      • fused_attn_f16_arbitrary_seqlen.cu: add bottom_right_diagonal parameter to the API
      • fused_attn_fp8.cu: add bottom_right_diagonal parameter to the FADescriptor_v1 API
      • utils.h: add bottom_right_diagonal parameter to FADescriptor_v1 API
  • pytorch

    • transformer.py
      • plumb bottom_right_diagonal through the call stack: TransformerLayer --> SelfAttention/CrossAttention
    • attention
      • dot_product_attention
        • backends.py:
          • UnfusedDotProductAttention
            • add bottom_right_diagonal parameter to the forward API
              • why is it not used in the forward?
                • bottom_right_alignment is being used in the ALiBi call; perhaps this should be corrected
          • FusedAttn custom module
            • add bottom_right_diagonal parameter to the forward API
          • FusedAttention module
            • plumb bottom_right_diagonal through the call stack
        • dot_product_attention.py
          • DotProductAttention
            • Plumb bottom_right_diagonal through the call stack
            • Add calculation of bottom_right_diagonal if it's None
        • utils.py
          • AttentionParams
            • [x]
          • get_attention_backend
            • update sliding window filter section
            • update attention bias filter section
      • multi_head_attention.py
        • Add bottom_right_diagonal to forward API and call
        • Add calculation of bottom_right_diagonal if it's None
    • cpp_extensions
      • fused_attn.py
        • plumb bottom_right_diagonal in fused_attn_fwd/fused_attn_bwd
    • csrc
      • extension
        • attention.cpp
          • plumb bottom_right_diagonal through the call stack: fused_attn_fwd --> nvte_fused_attn_fwd
          • same as above for bwd
      • extensions.h
        • add bottom_right_diagonal to fused_attn_fwd and fused_attn_bwd API definitions

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

…IA#1369

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27
Collaborator Author

/te-ci pytorch L0

@greptile-apps
Contributor

greptile-apps bot commented Dec 4, 2025

Greptile Overview

Greptile Summary

This PR extends FusedAttention's sliding window attention (SWA) support from causal-style (left, 0) windows to arbitrary bidirectional window configurations by adding a bottom_right_diagonal parameter throughout the entire call stack. Previously, FusedAttention only supported configurations like window_size = (left, 0); this change enables arbitrary configurations like window_size = (left, right).

The parameter controls diagonal alignment in the attention matrix - when False, the sliding window and ALiBi diagonal align to the top-left corner (typical for causal masks); when True, they align to the bottom-right corner (for other mask types). The change leverages cuDNN 9.6+ features that added support for arbitrary sliding window configurations.

The implementation threads the parameter through all layers: from high-level PyTorch APIs (TransformerLayer, MultiheadAttention, DotProductAttention) down through C++ extensions to the underlying CUDA kernels. The change includes automatic default value logic based on mask types to maintain backward compatibility while enabling the new functionality.
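As a toy illustration of the alignment semantics described above (not code from this PR), the reference diagonal shifts by seqlen_kv - seqlen_q when bottom_right_diagonal is True:

```python
import torch

def swa_mask(sq, skv, left, right, bottom_right_diagonal):
    """Boolean mask of which keys each query may attend to (True = allowed).
    Infinite windows (-1) are not handled in this sketch.
    """
    offset = (skv - sq) if bottom_right_diagonal else 0  # 0 = top-left diagonal
    q = torch.arange(sq).unsqueeze(1)   # [sq, 1] query positions
    k = torch.arange(skv).unsqueeze(0)  # [1, skv] key positions
    return (k >= q + offset - left) & (k <= q + offset + right)

print(swa_mask(4, 6, left=1, right=1, bottom_right_diagonal=False).int())
print(swa_mask(4, 6, left=1, right=1, bottom_right_diagonal=True).int())
```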

Important Files Changed

  • transformer_engine/pytorch/attention/dot_product_attention/backends.py (2/5): Adds parameter support but has a critical bug: UnfusedDotProductAttention doesn't use the passed parameter and still uses hardcoded logic
  • transformer_engine/pytorch/transformer.py (2/5): Adds complex conditional logic for setting defaults that appears overly complicated and potentially error-prone
  • transformer_engine/common/fused_attn/fused_attn.cpp (4/5): Core backend selection with complex cuDNN version-dependent logic that needs careful validation
  • transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py (4/5): Main attention class changes with comprehensive parameter threading and automatic defaults
  • transformer_engine/pytorch/attention/dot_product_attention/utils.py (4/5): Backend selection logic updates with important filtering condition changes
  • transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu (5/5): Clean CUDA kernel implementation with proper parameter threading
  • transformer_engine/common/fused_attn/fused_attn_fp8.cu (5/5): FP8 implementation with intentionally hardcoded values for an incomplete feature
  • transformer_engine/common/include/transformer_engine/fused_attn.h (5/5): Comprehensive API extension with proper documentation across all function signatures
  • transformer_engine/pytorch/csrc/extensions/attention.cpp (5/5): Clean C++ extension implementation properly threading the parameter through the call stack
  • transformer_engine/pytorch/attention/multi_head_attention.py (5/5): Well-implemented parameter addition with sensible defaults based on mask types
  • transformer_engine/pytorch/cpp_extensions/fused_attn.py (5/5): Clean Python wrapper implementation with proper documentation and parameter threading
  • tests/pytorch/attention/test_attention.py (5/5): Simple test coverage extension for multiple QKV layouts
  • transformer_engine/common/fused_attn/utils.h (5/5): Straightforward struct extension with correct comparison operator updates
  • transformer_engine/pytorch/csrc/extensions.h (5/5): Clean header file extension maintaining API consistency
  • transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h (5/5): Simple header extension with consistent parameter addition

Confidence score: 3/5

  • This PR adds significant functionality but has critical implementation issues that need to be addressed before merging
  • Score lowered due to unused parameter in UnfusedDotProductAttention (critical bug), overly complex conditional logic in transformer.py, and complex version-dependent backend selection logic that increases maintenance burden
  • Pay close attention to transformer_engine/pytorch/attention/dot_product_attention/backends.py where the parameter is added but not actually used in the ALiBi computation

Contributor

@greptile-apps bot left a comment

Additional Comments (2)

  1. transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1281 (link)

    logic: Trailing comma creates a single-element tuple instead of a boolean (see the short demonstration after this list). Should this be just bottom_right_alignment = attn_mask_type not in ["causal", "padding_causal"]?

  2. transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1482 (link)

    style: Uses hardcoded mask type check instead of the new bottom_right_diagonal parameter for ALiBi alignment. Should this use bottom_right_diagonal parameter for consistency?

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
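The demonstration referenced in comment 1 above (illustration only, not code from the PR):

```python
attn_mask_type = "causal"

# With the trailing comma the right-hand side is a one-element tuple, which is always
# truthy, so the intended False is silently lost.
bottom_right_alignment = attn_mask_type not in ["causal", "padding_causal"],  # (False,)
print(bool(bottom_right_alignment))  # True

# Without the trailing comma the expression evaluates to the intended boolean.
bottom_right_alignment = attn_mask_type not in ["causal", "padding_causal"]   # False
print(bool(bottom_right_alignment))  # False
```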

15 files reviewed, 8 comments


Comment on lines +453 to 455
# (This should be replaced with `bottom_right_diagonal` which is passed from the arguments)
bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
)
Contributor

logic: The TODO comment indicates bottom_right_diagonal should replace bottom_right_alignment but the replacement wasn't implemented. This could cause incorrect behavior in alibi attention. Should bottom_right_alignment be replaced with bottom_right_diagonal parameter as the comment suggests?

Comment on lines +886 to +888
# (cyang: Why is window_size is being modified but then its value ignored
# in the following else block?)
# else:
Contributor

style: Commented-out code with developer question should be removed

Suggested change
- # (cyang: Why is window_size is being modified but then its value ignored
- # in the following else block?)
- # else:
+ if window_size is None:
+     window_size = check_set_window_size(attn_mask_type, window_size)

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

"Disabling FlashAttention as it only supports sliding window with bottom right"
" diagonal alignment for cross-attention"
)
use_flash_attention = False
Contributor

logic: Variable name mismatch: use_flash_attention is being set but should be use_flash_attention_2

Suggested change
- use_flash_attention = False
+ use_flash_attention_2 = False

"Disabling FlashAttention as it only supports ALiBi with bottom right diagonal"
" alignment for cross-attention"
)
use_flash_attention = False
Contributor

logic: Variable name mismatch: use_flash_attention is being set but should be use_flash_attention_2

Suggested change
- use_flash_attention = False
+ use_flash_attention_2 = False

Comment on lines +777 to +783
if self_attn_mask_type in {"causal", "padding_causal"}:
bottom_right_diagonal = False
if bottom_right_diagonal is None or self_attn_mask_type in {
"causal_bottom_right",
"padding_causal_bottom_right",
}:
bottom_right_diagonal = True
Contributor

logic: Logic overrides the instance variable even when explicitly set in forward call - should preserve user's explicit choice. Should the mask type check override an explicitly passed bottom_right_diagonal parameter, or only apply when it's None?
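A sketch of the restructuring the comment asks about, reusing the names from the quoted snippet (example inputs added so the snippet is self-contained):

```python
self_attn_mask_type = "causal"   # e.g. taken from the forward() arguments
bottom_right_diagonal = None     # None means the caller did not set it explicitly

# Apply the mask-type default only when no explicit value was passed.
if bottom_right_diagonal is None:
    # causal/padding_causal imply top-left alignment; other mask types (including the
    # *_bottom_right variants) default to bottom-right alignment.
    bottom_right_diagonal = self_attn_mask_type not in {"causal", "padding_causal"}
```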

Comment on lines +787 to +793
if enc_dec_attn_mask_type in {"causal", "padding_causal"}:
enc_dec_bottom_right_diagonal = False
if enc_dec_bottom_right_diagonal is None or enc_dec_attn_mask_type in {
"causal_bottom_right",
"padding_causal_bottom_right",
}:
enc_dec_bottom_right_diagonal = True
Contributor

logic: Same logic issue as above - mask type check overrides explicit parameter values

qkv_format == NVTE_QKV_Format::NVTE_SBHD)))) ||
// 9.6: SWA (left, 0) + top-left/bottom-right diagonal + {bshd, sbhd, thd}
(cudnn_runtime_version >= 90600 &&
((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) ||
Collaborator

We can probably remove this line? Because it's covered by the next line?

Collaborator

@cyanguwa Dec 5, 2025

Could you add a couple of SWA tests to the CP tests as well? I think it's just a matter of replacing (left, 0) with (left, right) and testing them out. Thanks!
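A hypothetical sketch of such an extension (the real CP tests' parametrization and helpers differ; run_cp_attention below is a stand-in):

```python
import pytest

def run_cp_attention(window_size):
    """Placeholder for the real context-parallel attention check."""
    ...

# Exercise bidirectional (left, right) windows alongside the existing (left, 0) cases.
@pytest.mark.parametrize("window_size", [(-1, 0), (512, 0), (512, 256)])
def test_cp_with_sliding_window(window_size):
    run_cp_attention(window_size=window_size)
```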


// NVTE fused attention FWD with packed QKV
// DEPRECATED: This API is deprecated.
// DEPRECATED: This API is deprecated. (Should there be a version by which this is going to be removed? @cyang)
Collaborator

I made some changes in #2272, but will see if I can make the 2.11 deadline.

sdpa_options.set_diagonal_band_left_bound(window_size_left + 1);
}
if (cudnn_runtime_version >= 90600 && window_size_right != -1) {
// (remove comment when reviewed) Should it be `window_size_right + 1` instead?
Collaborator

SWA(left, right) should describe an attention window of left + 1 + right elements, but cuDNN understands it as left - 1 + 1 + right elements, so we need to add the 1 here to left to make all three backends (FlashAttention, FusedAttention, UnfusedDPA) equivalent in terms of the SWA operation.
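A worked example of the off-by-one described above (arithmetic illustration only):

```python
left, right = 2, 1
te_window_elems = left + 1 + right                    # 4: keys q-2 .. q+1
# If `left` were passed directly as cuDNN's diagonal band left bound, the window
# would only cover (left - 1) + 1 + right elements:
elems_without_adjustment = (left - 1) + 1 + right     # 3
# Passing left + 1 restores the intended width:
elems_with_adjustment = ((left + 1) - 1) + 1 + right  # 4 == te_window_elems
print(te_window_elems, elems_without_adjustment, elems_with_adjustment)
```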

sdpa_backward_options.set_diagonal_band_left_bound(window_size_left + 1);
}
if (cudnn_runtime_version >= 90600 && window_size_right != -1) {
// (remove comment when reviewed) Should it be `window_size_right + 1` instead?
Collaborator

Same as above.

actual_seqlens_q=actual_seqlens_q if "padding" in attn_mask_type else None,
actual_seqlens_kv=actual_seqlens_kv if "padding" in attn_mask_type else None,
alibi_slopes=alibi_slopes,
# (This should be replaced with `bottom_right_diagonal` which is passed from the arguments)
Collaborator

@cyanguwa Dec 5, 2025

I wonder if the logic should be something like this: bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"] if bottom_right_alignment is None else bottom_right_alignment.
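Formatted as code, that suggestion would read roughly as follows (names taken from the comment; example inputs added so the sketch stands alone):

```python
attn_mask_type = "causal_bottom_right"  # example mask type
bottom_right_alignment = None           # value reaching the ALiBi call; may be None

# Fall back to the mask-type heuristic only when no explicit value was provided.
bottom_right_alignment = (
    attn_mask_type not in ["causal", "padding_causal"]
    if bottom_right_alignment is None
    else bottom_right_alignment
)
```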

`causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`}
window_size : Tuple[int, int], default = None
Sliding window attention size.
bottom_right_diagonal: bool, default = `True`
Collaborator

Is the default True here?

use_flash_attention_2 = False
# (cyang: Why is window_size is being modified but then its value ignored
# in the following else block?)
# else:
Collaborator

You are right - the code under else should come out of else because we need to factor in the changes in check_set_window_size. Thanks - well spotted!

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Contributor

@greptile-apps bot left a comment

15 files reviewed, no comments


Contributor

@greptile-apps bot left a comment

Additional Comments (4)

  1. transformer_engine/pytorch/transformer.py, line 777-783 (link)

    logic: Conditional logic overwrites explicitly passed parameter. When self_attn_mask_type in {"causal", "padding_causal"} at line 777, this unconditionally sets bottom_right_diagonal = False, even if the user explicitly passed a different value in the forward call. The check at line 779 (if bottom_right_diagonal is None) becomes unreachable in those cases.

    Consider restructuring to only apply defaults when bottom_right_diagonal is None.

  2. transformer_engine/pytorch/transformer.py, line 787-793 (link)

    logic: Same conditional logic issue for encoder-decoder attention. The mask type check at line 787 unconditionally overwrites enc_dec_bottom_right_diagonal, ignoring explicitly passed values.

  3. transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 920 (link)

    logic: Variable name mismatch: setting use_flash_attention instead of use_flash_attention_2

  4. transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 947 (link)

    logic: Variable name mismatch: setting use_flash_attention instead of use_flash_attention_2

15 files reviewed, 4 comments


@sudhakarsingh27
Collaborator Author

/te-ci pytorch L0

@sudhakarsingh27
Collaborator Author

/te-ci pytorch L0

Contributor

@greptile-apps bot left a comment

Additional Comments (1)

  1. transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1283 (link)

    syntax: Tuple has trailing comma causing it to be a single-element tuple instead of boolean

15 files reviewed, 1 comment


Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27
Collaborator Author

/te-ci pytorch L0

Contributor

@greptile-apps bot left a comment

15 files reviewed, no comments


@sudhakarsingh27
Collaborator Author

/te-ci pytorch L0
