Add MXFP8 support with cuBLASMp #3145

Open
almogsegal wants to merge 2 commits into NVIDIA:main from almogsegal:add-mxfp8-and-nvfp4-with-cublasmp
Conversation

@almogsegal almogsegal commented Jun 25, 2026

Contributor

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add MXFP8 support in comm_gemm.cpp (cuBLASMp path).
  • Add cuBLASMp + MXFP8 tests in test_comm_gemm_overlap.py.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@github-actions github-actions Bot added the community-contribution label (PRs from external contributor outside the core maintainers, representing community-driven work.) Jun 25, 2026

almogsegal commented Jun 25, 2026

Contributor Author

MXFP8 comparison of cuBLASMp vs UB on DGX-B200:

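| GPUs | Op | UB time | cuBLASMp time | Speedup (faster vs. slower) | Faster |
|------|----|---------|---------------|-----------------------------|--------|
| 2 | AG | 0.0969 ms | 0.0802 ms | 1.21x | cuBLASMp |
| 2 | RS | 0.0867 ms | 0.0730 ms | 1.19x | cuBLASMp |
| 4 | AG | 0.1215 ms | 0.1255 ms | 1.03x | UB |
| 4 | RS | 0.1357 ms | 0.1049 ms | 1.29x | cuBLASMp |
| 8 | AG | 0.2380 ms | 0.2146 ms | 1.11x | cuBLASMp |
| 8 | RS | 0.2191 ms | 0.2141 ms | 1.02x | cuBLASMp |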
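The Speedup and Faster columns can be re-derived from the two timing columns. The sketch below is only an illustration of that arithmetic; the helper name `compare` is hypothetical and is not part of the PR.

```python
# Timings in milliseconds, copied from the table above:
# (GPUs, op, UB time, cuBLASMp time).
timings_ms = [
    (2, "AG", 0.0969, 0.0802),
    (2, "RS", 0.0867, 0.0730),
    (4, "AG", 0.1215, 0.1255),
    (4, "RS", 0.1357, 0.1049),
    (8, "AG", 0.2380, 0.2146),
    (8, "RS", 0.2191, 0.2141),
]


def compare(ub_ms, cublasmp_ms):
    """Return (speedup of the faster backend over the slower one, faster backend)."""
    if cublasmp_ms <= ub_ms:
        return ub_ms / cublasmp_ms, "cuBLASMp"
    return cublasmp_ms / ub_ms, "UB"


for gpus, op, ub_ms, cublasmp_ms in timings_ms:
    speedup, faster = compare(ub_ms, cublasmp_ms)
    print(f"{gpus} GPUs {op}: {speedup:.2f}x ({faster})")
```

Note that the speedup is always reported for whichever backend is faster, so the 4-GPU AG row (1.03x) favors UB rather than cuBLASMp.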

@almogsegal almogsegal force-pushed the add-mxfp8-and-nvfp4-with-cublasmp branch from c6517dd to 0c4c53a Compare June 25, 2026 10:19
greptile-apps Bot commented Jun 25, 2026

Contributor

Greptile Summary

This PR enables MXFP8 (block scaling) support for the cuBLASMp comm+GEMM overlap path by extending cublasmp_gemm to canonicalize MXFP8 inputs and configure the appropriate VEC32_UE8M0 scale mode. The corresponding test skip conditions are removed so the feature is now exercised in CI.

  • comm_gemm.cpp: Adds an MXFP8 branch inside canonicalize_input that selects rowwise vs. columnwise data/scales based on the transpose direction (keeping the transpose flag unchanged, since MXFP8 columnwise data preserves logical shape unlike tensor-FP8). Adds a compile-time guard (CUBLASMP_VERSION < 801) that injects a runtime error for unsupported library versions, and uses CUBLASMP_MATMUL_MATRIX_SCALE_VEC32_UE8M0 for scale mode setup.
  • test_comm_gemm_overlap.py: Removes the two pytest.skip calls that blocked cuBLASMp+MXFP8 combinations and introduces an explicit COMM_GEMM_QUANTIZATION_PARAMS list with human-readable test IDs to replace the two separate parametrize decorators.
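As a rough illustration of the test change described above, a single explicit parameter list with readable IDs might look like the sketch below. The actual contents of `COMM_GEMM_QUANTIZATION_PARAMS` are not shown in this conversation, so the backend and quantization names here are assumptions, as is the separate `COMM_GEMM_QUANTIZATION_IDS` list.

```python
# Hypothetical cases: (test id, use_cublasmp, quantization).
# The real list in test_comm_gemm_overlap.py may use different names and values.
_CASES = [
    ("ub-none", False, "none"),
    ("ub-fp8", False, "fp8"),
    ("ub-mxfp8", False, "mxfp8"),
    ("cublasmp-none", True, "none"),
    ("cublasmp-fp8", True, "fp8"),
    ("cublasmp-mxfp8", True, "mxfp8"),  # the combination that used to be skipped
]

# One list of argument tuples and one list of matching human-readable IDs.
COMM_GEMM_QUANTIZATION_PARAMS = [(use_mp, quant) for _, use_mp, quant in _CASES]
COMM_GEMM_QUANTIZATION_IDS = [case_id for case_id, _, _ in _CASES]

# With pytest, the two lists would be attached to a test roughly like this:
#
#   @pytest.mark.parametrize(
#       "use_cublasmp, quantization",
#       COMM_GEMM_QUANTIZATION_PARAMS,
#       ids=COMM_GEMM_QUANTIZATION_IDS,
#   )
#   def test_...(use_cublasmp, quantization): ...
```

Listing the combinations explicitly makes it obvious which backend and quantization pairs run, instead of relying on the cross product of two decorators plus in-test skips.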

Confidence Score: 5/5

Safe to merge; the MXFP8 canonicalization path follows the established FP8 pattern and is protected by both a compile-time version guard and runtime assertions on pointer validity.

The new MXFP8 branch in canonicalize_input correctly keeps the transpose flag unchanged (MXFP8 columnwise data is not a transposed view), mirrors the cublaslt_gemm.cu reference path, and guards the entire path with CUBLASMP_VERSION checks. Validation is thorough: mixed scaling modes are rejected, swizzled-scale format is asserted, and per-direction null checks are in place. Test infrastructure change is mechanical and correct.
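The operand selection described above can be modeled in a few lines. This is a simplified Python model, not the C++ code from `comm_gemm.cpp`: the `Mxfp8Tensor` class, its field names, and the `canonicalize_mxfp8` function are hypothetical, and only the selection rule (which layout each operand uses, with the transpose flag left unchanged) is taken from the review.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Mxfp8Tensor:
    """Hypothetical stand-in for a tensor that may hold both MXFP8 layouts."""
    rowwise_data: Optional[bytes] = None
    rowwise_scale_inv: Optional[bytes] = None
    columnwise_data: Optional[bytes] = None
    columnwise_scale_inv: Optional[bytes] = None


def canonicalize_mxfp8(tensor, is_a, trans):
    """Pick data and scales for one GEMM operand; return (data, scales, trans)."""
    # A uses row-wise data when transposed, B uses row-wise data when not.
    use_rowwise = trans if is_a else not trans
    if use_rowwise:
        data, scales, kind = tensor.rowwise_data, tensor.rowwise_scale_inv, "row-wise"
    else:
        data, scales, kind = tensor.columnwise_data, tensor.columnwise_scale_inv, "columnwise"
    # Per-direction null check: the layout that was selected must be present.
    if data is None or scales is None:
        raise ValueError(f"missing {kind} MXFP8 data or scales")
    # The transpose flag is returned unchanged: MXFP8 columnwise data keeps
    # the logical shape, unlike the tensor-scaled FP8 case.
    return data, scales, trans
```

For example, a tensor that only carries row-wise data can serve as operand A with `trans=True` or as operand B with `trans=False`, and is rejected in the other two cases.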

No files require special attention.

Important Files Changed

| Filename | Overview |
|----------|----------|
| transformer_engine/common/comm_gemm/comm_gemm.cpp | Extends cublasmp_gemm with MXFP8 block-scaling support: adds validation, a new canonicalize_input MXFP8 branch that picks rowwise/columnwise data without flipping the transpose flag, and a VEC32_UE8M0 scale mode with a compile-time CUBLASMP_VERSION guard. Logic is consistent with the cublaslt_gemm.cu reference path. |
| tests/pytorch/distributed/test_comm_gemm_overlap.py | Removes two pytest.skip guards that blocked cuBLASMp+MXFP8 test cases, and consolidates the separate use_cublasmp / quantization parametrize decorators into a single explicit COMM_GEMM_QUANTIZATION_PARAMS list with readable IDs. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[cublasmp_gemm called] --> B{scaling_mode A & B}
    B -->|both tensor_scaling| C[FP8 / BF16 path]
    B -->|both mxfp8_scaling| D[MXFP8 path]
    B -->|mixed / unknown| E[NVTE_CHECK fails]

    D --> F{CUBLASMP_VERSION < 801?}
    F -->|yes| G[NVTE_ERROR - runtime throw]
    F -->|no| H[Check with_gemm_swizzled_scales]

    H --> I[canonicalize_input A]
    I --> J{transa?}
    J -->|yes| K[use row-wise data, keep trans flag]
    J -->|no| L[use columnwise data, keep trans flag]

    H --> M[canonicalize_input B]
    M --> N{transb?}
    N -->|yes| O[use columnwise data, keep trans flag]
    N -->|no| P[use row-wise data, keep trans flag]

    K --> Q[Set A scale mode VEC32_UE8M0]
    L --> Q
    O --> R[Set B scale mode VEC32_UE8M0]
    P --> R

    Q --> U[cublasMpMatmul]
    R --> U
    C --> T[Set scale mode SCALAR_FP32]
    T --> U
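
The top of the flowchart above (scaling-mode validation, the version guard, and the scale-mode choice) can be read as a small dispatch function. The sketch below is a Python model under stated assumptions: `select_scale_mode` and its string arguments are hypothetical, and the version check, which is a compile-time `CUBLASMP_VERSION` guard in the PR, is modeled here as an ordinary runtime argument.

```python
MIN_MXFP8_CUBLASMP_VERSION = 801  # threshold taken from the flowchart


def select_scale_mode(mode_a, mode_b, cublasmp_version, swizzled_scales=True):
    """Return the scale mode name for a pair of operand scaling modes."""
    # Mixed scaling modes are rejected up front.
    if mode_a != mode_b:
        raise ValueError(f"mixed scaling modes: {mode_a} vs {mode_b}")
    if mode_a == "tensor_scaling":
        # FP8 / BF16 path.
        return "SCALAR_FP32"
    if mode_a == "mxfp8_scaling":
        if cublasmp_version < MIN_MXFP8_CUBLASMP_VERSION:
            raise RuntimeError("MXFP8 is not supported by this cuBLASMp version")
        if not swizzled_scales:
            raise ValueError("MXFP8 scales must be in the GEMM-swizzled format")
        return "VEC32_UE8M0"
    raise ValueError(f"unknown scaling mode: {mode_a}")


print(select_scale_mode("mxfp8_scaling", "mxfp8_scaling", cublasmp_version=801))
```

The same scale mode is applied to both A and B, which is why the model returns a single value once the two operands are known to agree.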

Reviews (2): Last reviewed commit: "Enable cuBLASMp MXFP8 overlap tests"

Comment thread tests/pytorch/distributed/run_gemm_with_overlap.py Outdated
Comment thread transformer_engine/common/comm_gemm/comm_gemm.cpp Outdated
Comment thread transformer_engine/common/comm_gemm/comm_gemm.cpp Outdated
@almogsegal almogsegal force-pushed the add-mxfp8-and-nvfp4-with-cublasmp branch 2 times, most recently from 85280d5 to d0559ca Compare June 25, 2026 10:54
@almogsegal almogsegal changed the title Add MXFP8 and NVFP4 support with cuBLASMp Add MXFP8 support with cuBLASMp Jun 25, 2026
@almogsegal almogsegal force-pushed the add-mxfp8-and-nvfp4-with-cublasmp branch 2 times, most recently from c0e28c9 to b7ce126 Compare June 25, 2026 11:04