
Conversation

@jberchtold-nvidia (Collaborator) commented on Dec 2, 2025

Description

When using TE/JAX's DotProductAttention (DPA) with Q, K, or V sharded differently from one another, the failure only surfaces downstream in the inner abstract evaluation and is reported as a shape issue rather than a sharding mismatch.

For example, when sharded along the batch dim, the previous error message was:

AssertionError: Mismatched qkv batch size for q_batch_shape: [1], k_batch_shape: [1] and v_batch_shape: [4]

The new error message is:

AssertionError: Q, K, and V sharding specs must be identical but received q_spec=PartitionSpec('fsdp', None, None, None), k_spec=PartitionSpec('fsdp', None, None, None), v_spec=PartitionSpec(None, None, None, None)

The newly added test test_mismatching_qkv_sharding_separate_qkv takes only ~1 second (a rough sketch of such a test appears after the runtime summary below).

================================================================================
TEST RUNTIME SUMMARY (grouped by function)
================================================================================
test                                                         |  12x |    2.14s | avg:   0.18s
test_context_parallel_allgather_attn                         | 320x | 1134.35s | avg:   3.54s
test_context_parallel_allgather_attn_shardy                  |  40x |  155.82s | avg:   3.90s
test_context_parallel_ring_attn                              | 1280x | 1753.03s | avg:   1.37s
test_context_parallel_ring_attn_shardy                       |  40x |   50.99s | avg:   1.27s
test_cross_attn                                              |  18x |   30.76s | avg:   1.71s
test_mismatching_qkv_sharding_separate_qkv                   |   1x |    1.16s | avg:   1.16s
test_self_attn                                               |  54x |  123.97s | avg:   2.30s
test_self_attn_shardy                                        |  18x |   16.04s | avg:   0.89s
================================================================================
TOTAL RUNTIME                                                |      | 3268.24s |
================================================================================
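
A minimal sketch of the kind of test involved, assuming a 4-device mesh with an "fsdp" axis and using a placeholder attention_fn in place of the actual TE/JAX fused-attention call; this is not the actual code in tests/jax/test_distributed_fused_attn.py, only an illustration of the sharding mismatch:

import re

import jax
import jax.numpy as jnp
import pytest
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def test_mismatched_qkv_sharding(attention_fn):
    # attention_fn is a hypothetical stand-in for the TE/JAX fused-attention entry point.
    mesh = Mesh(jax.devices()[:4], axis_names=("fsdp",))
    batch, seqlen, heads, head_dim = 4, 128, 8, 64
    q = jnp.zeros((batch, seqlen, heads, head_dim), jnp.bfloat16)
    k = jnp.zeros_like(q)
    v = jnp.zeros_like(q)

    # Q and K are sharded along the batch dim via "fsdp"; V is left fully replicated.
    sharded = NamedSharding(mesh, PartitionSpec("fsdp", None, None, None))
    replicated = NamedSharding(mesh, PartitionSpec(None, None, None, None))
    q = jax.device_put(q, sharded)
    k = jax.device_put(k, sharded)
    v = jax.device_put(v, replicated)

    # The mismatch should now fail fast with the sharding-specific message
    # instead of the downstream "Mismatched qkv batch size" shape assertion.
    with mesh, pytest.raises(Exception, match=re.escape("sharding specs must be identical")):
        jax.jit(attention_fn)(q, k, v)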

Type of change

  • Documentation change (changes only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add assertion that Q, K, and V shardings match when qkv_layout.is_separate()
  • Add a test that runs with mismatched sharding (Q and K sharded along the fsdp axis in the batch dim, V replicated in the batch dim) and checks that the new error message is reported as expected

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

greptile-apps bot commented Dec 2, 2025

Greptile Overview

Greptile Summary

Added early validation to check that Q, K, and V have identical sharding specifications when using separate QKV layouts (BSHD_BSHD_BSHD or THD_THD_THD), replacing confusing downstream shape mismatch errors with clear, actionable error messages.

Key Changes:

  • Added assertion in infer_sharding_from_operands at transformer_engine/jax/cpp_extensions/attention.py:691-698 to validate Q/K/V sharding specs match
  • Comprehensive test coverage in TestMismatchingQKVSharding class that validates error message format
  • Error now caught early during sharding inference rather than later during abstract evaluation

Impact:

  • Significantly improves developer experience by providing clear diagnostics about sharding mismatches
  • Prevents confusing "batch size mismatch" errors that obscured the root cause

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk - it adds defensive validation that improves error messages without changing core logic
  • The change is a simple assertion that validates sharding specifications early in the pipeline. The fix is well-targeted (only applies when qkv_layout.is_separate()), includes comprehensive test coverage that validates both the error path and expected error message, and does not modify any computation logic.
  • No files require special attention

Important Files Changed

File Analysis

Filename | Score | Overview
transformer_engine/jax/cpp_extensions/attention.py | 5/5 | Added assertion to validate Q, K, V sharding specs match when using separate QKV layout, improving error messages for mismatched sharding
tests/jax/test_distributed_fused_attn.py | 5/5 | Added comprehensive test class to verify the new error message appears when Q, K, V have mismatched sharding specifications

Sequence Diagram

sequenceDiagram
    participant User
    participant JAX
    participant FusedAttnFwdPrimitive
    participant infer_sharding_from_operands
    participant get_padded_spec

    User->>JAX: Call DotProductAttention with mismatched Q,K,V sharding
    JAX->>FusedAttnFwdPrimitive: Process attention operation
    FusedAttnFwdPrimitive->>infer_sharding_from_operands: Infer sharding for operands
    
    infer_sharding_from_operands->>get_padded_spec: Get Q sharding spec (arg_infos[0])
    get_padded_spec-->>infer_sharding_from_operands: Return q_spec
    
    infer_sharding_from_operands->>get_padded_spec: Get K sharding spec (arg_infos[1])
    get_padded_spec-->>infer_sharding_from_operands: Return k_spec
    
    infer_sharding_from_operands->>get_padded_spec: Get V sharding spec (arg_infos[2])
    get_padded_spec-->>infer_sharding_from_operands: Return v_spec
    
    alt QKV layout is separate (BSHD_BSHD_BSHD or THD_THD_THD)
        infer_sharding_from_operands->>infer_sharding_from_operands: Check if q_spec == k_spec == v_spec
        
        alt Sharding specs do NOT match
            infer_sharding_from_operands->>JAX: Raise AssertionError with clear message
            JAX->>User: JaxRuntimeError: Q, K, and V sharding specs must be identical
        else Sharding specs match
            infer_sharding_from_operands-->>FusedAttnFwdPrimitive: Continue with valid sharding
            FusedAttnFwdPrimitive-->>User: Success
        end
    else QKV layout is packed
        infer_sharding_from_operands-->>FusedAttnFwdPrimitive: Skip check (not applicable)
        FusedAttnFwdPrimitive-->>User: Success
    end

@greptile-apps bot (Contributor) left a comment

2 files reviewed, no comments

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia (Collaborator, Author) commented:

/te-ci L1 jax

@greptile-apps bot (Contributor) left a comment

2 files reviewed, no comments

if config.qkv_layout.is_separate():
    q_spec = get_padded_spec(arg_infos[0])
    k_spec = get_padded_spec(arg_infos[1])
    v_spec = get_padded_spec(arg_infos[2])
@jberchtold-nvidia (Collaborator, Author) commented on these lines:

@KshitijLakhani does this assertion make sense? For the case I'm testing above, where the batch dimension should always have the same sharding, I think it does make sense. But I'm not yet familiar enough with other parallelism techniques like CP to know if this assertion is valid on the non-batch axes
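
For reference, a minimal sketch of the assertion in question, inferred from the error message quoted in the PR description (the conversion of the padded specs back to PartitionSpec for the message, and PartitionSpec being in scope, are assumptions):

    # ...continuing inside the `if config.qkv_layout.is_separate():` block shown above
    assert q_spec == k_spec == v_spec, (
        "Q, K, and V sharding specs must be identical but received "
        f"q_spec={PartitionSpec(*q_spec)}, k_spec={PartitionSpec(*k_spec)}, "
        f"v_spec={PartitionSpec(*v_spec)}"
    )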

A collaborator replied:

I can think of MQA/GQA needing different sharding for Q and K along the head dimension.
For example, with 8 query heads and 1 K/V head, we could have one PartitionSpec for query (say, split across 4 devices, so 2 query heads per device) and a different PartitionSpec for key (replicated across devices, so None), so I would not be as restrictive (see the sketch below).

I think the batch dim is fine, and probably the sequence dim as well, since I do not think any of the CP strategies require a different PartitionSpec for Q, K, and V. It should be okay along the hidden dimension, too.

cc: @mgoldfarb-nvidia @huanghua1994 @mingxu1067 to chime in
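
A minimal illustration of the GQA case described above, using jax.sharding.PartitionSpec with a hypothetical mesh axis "tp" over the head dimension of BSHD tensors:

from jax.sharding import PartitionSpec

# 8 query heads split across 4 devices (2 heads per device) along the head dim;
# the single K/V head cannot be split, so K and V stay replicated on that dim.
q_spec = PartitionSpec(None, None, "tp", None)   # [batch, seq, 8 heads, head_dim]
kv_spec = PartitionSpec(None, None, None, None)  # [batch, seq, 1 head, head_dim]

# Under a strict q_spec == k_spec == v_spec check, this legitimate MQA/GQA
# sharding would be rejected, which is the concern raised here.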

@jberchtold-nvidia (Collaborator, Author) commented:

/te-ci L1 jax

@greptile-apps bot (Contributor) left a comment

2 files reviewed, no comments
