[common] Add support for cuBLASLt GEMM for GroupedTensor #2502
base: main
Conversation
for more information, see https://pre-commit.ci
- Add FP8 scale_inv pointer handling in nvte_grouped_gemm for proper FP8 GEMM
- Fix random padding in tests to ensure 16-byte alignment for all dtypes
- Reorder GroupedGemmSetupWorkspace members for natural alignment
- Remove debug prints

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
/te-ci L0
Greptile Overview

Greptile Summary

This PR adds an nvte_grouped_gemm API that uses cuBLASLt grouped matmul for batched GEMM on tensors with varying shapes.

Key changes:
Quality assessment:

Confidence Score: 4/5
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant API as nvte_grouped_gemm
    participant Validator
    participant OpSelector as select_grouped_operand
    participant SetupKernel as setup_grouped_gemm_kernel
    participant cuBLASLt
    User->>API: Call with A, B, C, D, alpha, beta
    API->>Validator: validate_grouped_gemm_inputs()
    Validator->>Validator: Check num_tensors match
    Validator->>Validator: Check alpha/beta dimensions
    Validator->>Validator: Check data types (FP8/BF16/FP16)
    API->>OpSelector: select_grouped_operand(A, transa)
    OpSelector->>OpSelector: Check row-wise vs column-wise data
    OpSelector->>OpSelector: Handle FP8 TN-only layout on Hopper
    OpSelector-->>API: Return A_sel (base ptr, dtype, trans)
    API->>OpSelector: select_grouped_operand(B, transb)
    OpSelector-->>API: Return B_sel (base ptr, dtype, trans)
    API->>API: Allocate setup workspace (pointer arrays)
    API->>API: Allocate cuBLAS workspace (32 MiB)
    API->>SetupKernel: Launch setup kernel on GPU
    SetupKernel->>SetupKernel: For each tensor in group:
    SetupKernel->>SetupKernel: - Compute A/B/C/D pointers from offsets
    SetupKernel->>SetupKernel: - Compute M, N, K dimensions
    SetupKernel->>SetupKernel: - Set alpha/beta pointers
    SetupKernel-->>API: Workspace arrays populated
    API->>cuBLASLt: Create matrix layout descriptors
    API->>cuBLASLt: Create matmul descriptor
    API->>cuBLASLt: Set FP8 scale_inv pointers (if FP8)
    API->>cuBLASLt: Get algorithm heuristic
    API->>cuBLASLt: cublasLtMatmul (grouped GEMM)
    cuBLASLt->>cuBLASLt: Execute batched matmul
    cuBLASLt-->>API: Results in D
    API-->>User: Return
```
Additional Comments (4)
- tests/cpp/operator/test_grouped_gemm.cu, line 75 (link). logic: missing `columnwise_data` in the move assignment.
- tests/cpp/operator/test_grouped_gemm.cu, lines 336-337 (link). logic: the kSameFirst and kSameLast test cases use identical shape tuples. Should kSameFirst and kSameLast have different shape patterns to properly test their respective scenarios?
- tests/cpp/operator/test_grouped_gemm.cu, lines 367-378 (link). logic: missing case for InputCase::kFP8Delayed.
- transformer_engine/common/include/transformer_engine/gemm.h, lines 265-266 (link). style: the avg_m, avg_n, avg_k parameters are not documented in the function comment. What do these average dimension parameters represent, and how should they be computed?
4 files reviewed, 4 comments
- Remove unused alignment parameter from GroupedGemmSetupWorkspace::from_buffers
- Simplify select_grouped_operand by removing dead code branches
- Add GroupedOperandSelection.tensor field to avoid passing the tensor separately
- Extract set_fp8_scale_pointers and init_matrix_layouts helpers
- Add safety check for the FP8 column-wise fallback on Hopper
- Support NULL C tensor when beta=0 (uses D as placeholder)
- Remove unused get_scale_inv() from the test
- Add use_null_c test parameter and test case
- Fix documentation: alpha/beta are single-element tensors only

Signed-off-by: Piotr Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Additional Comments (1)
- transformer_engine/common/include/transformer_engine/gemm.h, line 266 (link). syntax: the documentation incorrectly states `alpha[i]` and `beta[i]`. The implementation uses a single `alpha` and `beta` value for all matrices in the group (the batch stride is 1 in `init_matmul_desc` at cublaslt_gemm.cu:1404), not per-matrix scaling.
4 files reviewed, 1 comment
- Change alpha/beta from single values to per-matrix arrays
- Validate that alpha/beta have exactly num_tensors elements
- Update the kernel to index alpha_ptr[idx] and beta_ptr[idx]
- Move alpha/beta validation to validate_grouped_gemm_inputs
- Update tests to use per-matrix alpha/beta arrays
- Update documentation

Signed-off-by: Piotr Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Piotr Gadzinski <pgadzinski@nvidia.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
/te-ci
Additional Comments (1)
- transformer_engine/common/gemm/cublaslt_gemm.cu, line 1565 (link). style: consider adding explicit synchronization for better error detection. The kernel launch uses `cudaGetLastError()` to check launch errors, but kernel execution errors won't be caught until later. Consider adding `cudaStreamSynchronize(stream)` in debug builds after the kernel launch to catch execution errors early during development.

```cpp
#ifndef NDEBUG
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
#endif
```
4 files reviewed, 1 comment
Description

Adds an nvte_grouped_gemm API using cuBLASLt grouped matmul for batched GEMM on tensors with varying shapes. A GPU kernel (setup_grouped_gemm_kernel) converts the NVTEGroupedTensor format (contiguous buffer + offsets) into what cuBLAS requires (pointer arrays + per-matrix M/N/K).

New API

Computes D = alpha * op(A) @ op(B) + beta * C for groups of matrices with potentially different shapes.

Type of change

Changes

- GroupedGemmSetupWorkspace struct for the cuBLAS workspace layout
- test_grouped_gemm.cu comparing against nvte_multi_tensor_gemm (FP8/BF16, various shapes and transpose layouts)

Checklist: