[JAX] Better error message when Q, K, V are sharded differently #2440
Conversation
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Greptile Overview
Greptile Summary
Added early validation to check that Q, K, and V have identical sharding specifications when using separate QKV layouts (BSHD_BSHD_BSHD or THD_THD_THD).
Key Changes:
Impact:
Confidence Score: 5/5
Important Files Changed
File Analysis
Sequence Diagram
sequenceDiagram
participant User
participant JAX
participant FusedAttnFwdPrimitive
participant infer_sharding_from_operands
participant get_padded_spec
User->>JAX: Call DotProductAttention with mismatched Q,K,V sharding
JAX->>FusedAttnFwdPrimitive: Process attention operation
FusedAttnFwdPrimitive->>infer_sharding_from_operands: Infer sharding for operands
infer_sharding_from_operands->>get_padded_spec: Get Q sharding spec (arg_infos[0])
get_padded_spec-->>infer_sharding_from_operands: Return q_spec
infer_sharding_from_operands->>get_padded_spec: Get K sharding spec (arg_infos[1])
get_padded_spec-->>infer_sharding_from_operands: Return k_spec
infer_sharding_from_operands->>get_padded_spec: Get V sharding spec (arg_infos[2])
get_padded_spec-->>infer_sharding_from_operands: Return v_spec
alt QKV layout is separate (BSHD_BSHD_BSHD or THD_THD_THD)
infer_sharding_from_operands->>infer_sharding_from_operands: Check if q_spec == k_spec == v_spec
alt Sharding specs do NOT match
infer_sharding_from_operands->>JAX: Raise AssertionError with clear message
JAX->>User: JaxRuntimeError: Q, K, and V sharding specs must be identical
else Sharding specs match
infer_sharding_from_operands-->>FusedAttnFwdPrimitive: Continue with valid sharding
FusedAttnFwdPrimitive-->>User: Success
end
else QKV layout is packed
infer_sharding_from_operands-->>FusedAttnFwdPrimitive: Skip check (not applicable)
FusedAttnFwdPrimitive-->>User: Success
end
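For reference, a minimal, self-contained sketch of the check the diagram describes. The real code lives in FusedAttnFwdPrimitive.infer_sharding_from_operands and uses TE-internal helpers; the function name and message text here are assumptions, not the literal implementation.

```python
# Hypothetical standalone version of the early validation described above.
# The actual implementation compares padded specs obtained from arg_infos via
# get_padded_spec; here the specs are passed in directly to keep it runnable.
def check_qkv_sharding(q_spec, k_spec, v_spec, layout_is_separate: bool) -> None:
    if layout_is_separate:  # separate layouts, e.g. BSHD_BSHD_BSHD or THD_THD_THD
        assert q_spec == k_spec == v_spec, (
            "Q, K, and V sharding specs must be identical for separate QKV "
            f"layouts, got q={q_spec}, k={k_spec}, v={v_spec}"
        )

# A batch-dim mismatch now fails fast with a clear message:
# check_qkv_sharding(("data", None, None, None), (None,) * 4, (None,) * 4, True)
```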
2 files reviewed, no comments
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
/te-ci L1 jax
2 files reviewed, no comments
if config.qkv_layout.is_separate():
    q_spec = get_padded_spec(arg_infos[0])
    k_spec = get_padded_spec(arg_infos[1])
    v_spec = get_padded_spec(arg_infos[2])
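For context, get_padded_spec is assumed here to pad a PartitionSpec with None entries up to the tensor rank, so the specs being compared have equal length; a hypothetical illustration of that assumed behavior:

```python
# Assumed behavior of get_padded_spec: pad a PartitionSpec with None up to ndim
# so that specs can be compared element-wise. Illustration only.
from jax.sharding import PartitionSpec

def padded_spec(spec: PartitionSpec, ndim: int) -> tuple:
    entries = tuple(spec)
    return entries + (None,) * (ndim - len(entries))

q = padded_spec(PartitionSpec("data"), 4)  # ('data', None, None, None)
k = padded_spec(PartitionSpec(), 4)        # (None, None, None, None)
assert q != k  # a batch-dim mismatch is visible as a spec mismatch
```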
@KshitijLakhani does this assertion make sense? For the case I'm testing above, where the batch dimension should always have the same sharding, I think it does. But I'm not yet familiar enough with other parallelism techniques like CP to know whether this assertion is valid on the non-batch axes.
I can think of MQA/GQA needing different sharding for Q and K along the head dimension.
For example, if we have 8 query heads and 1 K/V head, we could have one PartitionSpec for query (say, split across 4 devices so there are 2 query heads per device) and a different PartitionSpec for key (repeated across devices, so None), so I would not be as restrictive.
I think the check is fine along the batch dim, and probably along the seq dim too, as I do not think any of the CP strategies require a different PartitionSpec for Q, K, and V. It should be okay along the hidden dimension as well.
cc: @mgoldfarb-nvidia @huanghua1994 @mingxu1067 to chime in
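A hypothetical illustration of the GQA case described above, using jax.sharding.PartitionSpec (the axis names "data" and "model" and the head counts are assumptions for this example):

```python
# GQA example: 8 query heads sharded over a 4-device "model" axis, but a single
# KV head that must be replicated, so the Q and KV specs legitimately differ.
from jax.sharding import PartitionSpec as P

q_pspec = P("data", None, "model", None)   # [b, s, h_q=8, d] -> 2 query heads/device
kv_pspec = P("data", None, None, None)     # [b, s, h_kv=1, d] -> KV head replicated
# Batch and sequence dims still agree; only the head dim differs, so a strict
# q == k == v spec check would reject this otherwise valid layout.
assert q_pspec != kv_pspec
```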
/te-ci L1 jax
2 files reviewed, no comments
Description
When using TE/JAX's DPA and Q, K, or V are sharded differently, the error only surfaces later, in the inner abstract evaluation, and reports a shape mismatch rather than the actual sharding mismatch.
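For concreteness, a hypothetical setup that produces this mismatch, using standard jax.sharding APIs; the DotProductAttention call itself is omitted since its exact signature is not part of this PR:

```python
# Hypothetical repro sketch: Q sharded along the batch dim, K and V replicated.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
mesh = Mesh(devices, ("data",))
b, s, h, d = 2 * len(devices), 128, 8, 64

q = jax.device_put(jnp.zeros((b, s, h, d)), NamedSharding(mesh, P("data")))
k = jax.device_put(jnp.zeros((b, s, h, d)), NamedSharding(mesh, P()))  # replicated
v = jax.device_put(jnp.zeros((b, s, h, d)), NamedSharding(mesh, P()))  # replicated
# Feeding q, k, v with these shardings into TE/JAX's fused attention previously
# failed deep inside the abstract evaluation with a shape error.
```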
For example, when sharded along the batch dim, the previous error message was:
The new error message is:
The newly added test test_mismatching_qkv_sharding_separate_qkv only takes ~1 second.
Type of change
Changes
Added an assertion that the Q, K, and V sharding specs match when qkv_layout.is_separate()
Checklist: