Skip to content

Conversation

@pggPL
Copy link
Collaborator

@pggPL pggPL commented Dec 10, 2025

Description

Adds nvte_grouped_gemm API using cuBLASLt grouped matmul for batched GEMM on tensors with varying shapes. A GPU kernel (setup_grouped_gemm_kernel) converts NVTEGroupedTensor format (contiguous buffer + offsets) to cuBLAS requirements (pointer arrays + per-matrix M/N/K).

New API

void nvte_grouped_gemm(int transa, int transb, 
                       const NVTETensor alpha, 
                       const NVTEGroupedTensor A,
                       const NVTEGroupedTensor B, 
                       const NVTETensor beta, 
                       const NVTEGroupedTensor C,
                       NVTEGroupedTensor D, 
                       NVTETensor workspace_setup, 
                       NVTETensor workspace_cublas,
                       NVTEMatmulConfig config, 
                       cudaStream_t stream, 
                       const int64_t *avg_m,
                       const int64_t *avg_n, 
                       const int64_t *avg_k);

Computes D = alpha * op(A) @ op(B) + beta * C for groups of matrices with potentially different shapes.

Type of change

  • New feature (non-breaking change which adds functionality)

Changes

  • GPU setup kernel computing pointers/dims from grouped tensor metadata
  • FP8 support with scale_inv handling and TN layout selection on Hopper
  • GroupedGemmSetupWorkspace struct for cuBLAS workspace layout
  • Tests in test_grouped_gemm.cu comparing against nvte_multi_tensor_gemm (FP8/BF16, various shapes and transpose layouts)

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL pggPL changed the title [common] Add support for cublasLt GEMM for GroupedTensor [common] Add support for cuBLASLt GEMM for GroupedTensor Dec 10, 2025
pre-commit-ci bot and others added 3 commits December 10, 2025 14:32
- Add FP8 scale_inv pointer handling in nvte_grouped_gemm for proper FP8 GEMM
- Fix random padding in tests to ensure 16-byte alignment for all dtypes
- Reorder GroupedGemmSetupWorkspace members for natural alignment
- Remove debug prints

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@ptrendx ptrendx added the MoE label Dec 10, 2025
@ptrendx ptrendx linked an issue Dec 10, 2025 that may be closed by this pull request
pggPL and others added 2 commits December 10, 2025 22:34
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL
Copy link
Collaborator Author

pggPL commented Dec 10, 2025

/te-ci L0

@pggPL pggPL marked this pull request as ready for review December 10, 2025 21:43
@greptile-apps
Copy link
Contributor

greptile-apps bot commented Dec 10, 2025

Greptile Overview

Greptile Summary

This PR adds nvte_grouped_gemm, a new API for batched matrix multiplication on tensors with varying shapes using cuBLASLt grouped matmul. The implementation efficiently converts the NVTEGroupedTensor format (contiguous buffer with offsets and metadata) into cuBLASLt's required format (pointer arrays and per-matrix dimensions) using a GPU setup kernel.

Key changes:

  • New API in gemm.h with comprehensive documentation for per-matrix alpha/beta scaling and optional dimension hints
  • Implementation handles FP8 with scale_inv, enforces TN-only layout on Hopper, and supports both row-wise and column-wise data
  • GPU kernel (setup_grouped_gemm_kernel) efficiently computes pointer arrays and M/N/K dimensions for each matrix in the group
  • Comprehensive tests validate against nvte_multi_tensor_gemm with FP8/BF16, various transpose configurations, and shape variations

Quality assessment:

  • Implementation correctly handles dimension computation, offset calculations, and workspace management
  • FP8 layout selection logic properly mirrors the non-grouped GEMM behavior
  • Tests provide good coverage across input types, transpose modes, and shape cases including uniform and varying dimensions
  • One minor suggestion: consider adding explicit stream synchronization in debug builds after the setup kernel for better error detection during development

Confidence Score: 4/5

  • This PR is safe to merge with minimal risk
  • The implementation is well-structured with proper validation, comprehensive tests, and correct logic for dimension computation and workspace management. Score of 4 (not 5) reflects that this is a new feature with complex GPU kernel and cuBLAS integration that would benefit from additional runtime testing on various hardware configurations, though no actual bugs were found in the code review.
  • No files require special attention

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/common/include/transformer_engine/gemm.h 5/5 Added well-documented API for nvte_grouped_gemm with clear parameter descriptions and requirements. No issues found.
transformer_engine/common/gemm/cublaslt_gemm.cu 4/5 Implements grouped GEMM with proper workspace management, FP8 support, and layout selection logic. One minor issue: missing kernel error synchronization after setup kernel launch.
tests/cpp/operator/test_grouped_gemm.cu 5/5 Comprehensive test coverage with multiple input types (FP8, BF16), transpose configurations, and shape variations. Tests correctly validate against nvte_multi_tensor_gemm.

Sequence Diagram

sequenceDiagram
    participant User
    participant API as nvte_grouped_gemm
    participant Validator
    participant OpSelector as select_grouped_operand
    participant SetupKernel as setup_grouped_gemm_kernel
    participant cuBLASLt
    
    User->>API: Call with A, B, C, D, alpha, beta
    API->>Validator: validate_grouped_gemm_inputs()
    Validator->>Validator: Check num_tensors match
    Validator->>Validator: Check alpha/beta dimensions
    Validator->>Validator: Check data types (FP8/BF16/FP16)
    
    API->>OpSelector: select_grouped_operand(A, transa)
    OpSelector->>OpSelector: Check row-wise vs column-wise data
    OpSelector->>OpSelector: Handle FP8 TN-only layout on Hopper
    OpSelector-->>API: Return A_sel (base ptr, dtype, trans)
    
    API->>OpSelector: select_grouped_operand(B, transb)
    OpSelector-->>API: Return B_sel (base ptr, dtype, trans)
    
    API->>API: Allocate setup workspace (pointer arrays)
    API->>API: Allocate cuBLAS workspace (32 MiB)
    
    API->>SetupKernel: Launch setup kernel on GPU
    SetupKernel->>SetupKernel: For each tensor in group:
    SetupKernel->>SetupKernel: - Compute A/B/C/D pointers from offsets
    SetupKernel->>SetupKernel: - Compute M, N, K dimensions
    SetupKernel->>SetupKernel: - Set alpha/beta pointers
    SetupKernel-->>API: Workspace arrays populated
    
    API->>cuBLASLt: Create matrix layout descriptors
    API->>cuBLASLt: Create matmul descriptor
    API->>cuBLASLt: Set FP8 scale_inv pointers (if FP8)
    API->>cuBLASLt: Get algorithm heuristic
    API->>cuBLASLt: cublasLtMatmul (grouped GEMM)
    cuBLASLt->>cuBLASLt: Execute batched matmul
    cuBLASLt-->>API: Results in D
    API-->>User: Return
Loading

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additional Comments (4)

  1. tests/cpp/operator/test_grouped_gemm.cu, line 75 (link)

    logic: missing columnwise_data in move assignment

  2. tests/cpp/operator/test_grouped_gemm.cu, line 336-337 (link)

    logic: kSameFirst and kSameLast test cases use identical shape tuples. Should kSameFirst and kSameLast have different shape patterns to properly test the respective scenarios?

  3. tests/cpp/operator/test_grouped_gemm.cu, line 367-378 (link)

    logic: missing case for InputCase::kFP8Delayed

  4. transformer_engine/common/include/transformer_engine/gemm.h, line 265-266 (link)

    style: The avg_m, avg_n, avg_k parameters are not documented in the function comment

    What do these average dimension parameters represent and how should they be computed?

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

4 files reviewed, 4 comments

Edit Code Review Agent Settings | Greptile

@ptrendx ptrendx requested a review from timmoon10 December 10, 2025 22:35
pggPL and others added 2 commits December 11, 2025 11:56
- Remove unused alignment parameter from GroupedGemmSetupWorkspace::from_buffers
- Simplify select_grouped_operand by removing dead code branches
- Add GroupedOperandSelection.tensor field to avoid passing tensor separately
- Extract set_fp8_scale_pointers and init_matrix_layouts helpers
- Add safety check for FP8 on Hopper column-wise fallback
- Support NULL C tensor when beta=0 (uses D as placeholder)
- Remove unused get_scale_inv() from test
- Add use_null_c test parameter and test case
- Fix documentation: alpha/beta are single element tensors only

Signed-off-by: Piotr Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additional Comments (1)

  1. transformer_engine/common/include/transformer_engine/gemm.h, line 266 (link)

    syntax: Documentation incorrectly states alpha[i] and beta[i]. The implementation uses a single alpha and beta value for all matrices in the group (batch stride is 1 in init_matmul_desc at cublaslt_gemm.cu:1404), not per-matrix scaling.

4 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

pggPL and others added 4 commits December 11, 2025 12:16
- Change alpha/beta from single values to per-matrix arrays
- Validate alpha/beta have exactly num_tensors elements
- Update kernel to index alpha_ptr[idx] and beta_ptr[idx]
- Move alpha/beta validation to validate_grouped_gemm_inputs
- Update tests to use per-matrix alpha/beta arrays
- Update documentation

Signed-off-by: Piotr Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Piotr Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL
Copy link
Collaborator Author

pggPL commented Dec 11, 2025

/te-ci

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additional Comments (1)

  1. transformer_engine/common/gemm/cublaslt_gemm.cu, line 1565 (link)

    style: consider adding explicit synchronization for better error detection

    The kernel launch uses cudaGetLastError() to check launch errors, but kernel execution errors won't be caught until later. Consider adding cudaStreamSynchronize(stream) in debug builds after the kernel launch to catch execution errors early during development.

    #ifndef NDEBUG
      NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
    #endif

4 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

GroupedGemm: FP8 per-tensor via cuBLAS

2 participants