Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 25 additions & 6 deletions tests/pytorch/test_grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -1496,6 +1496,7 @@ def test_fp8_grouped_gemm(shape, accumulate):

_FUSED_GROUPED_GEMM_ENV = "NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM"
_ALL_BOOLEAN = all_boolean
_fp8_available, _reason_for_no_fp8 = fp8_available, reason_for_no_fp8
_mxfp8_available, _reason_for_no_mxfp8 = mxfp8_available, reason_for_no_mxfp8
_nvfp4_available, _reason_for_no_nvfp4 = nvfp4_available, reason_for_no_nvfp4

Expand Down Expand Up @@ -1577,6 +1578,10 @@ def _run_grouped_linear_path(
"fp8_recipe",
[
None,
pytest.param(
recipe.Float8CurrentScaling(),
marks=pytest.mark.skipif(not _fp8_available, reason=_reason_for_no_fp8),
),
pytest.param(
recipe.MXFP8BlockScaling(),
marks=pytest.mark.skipif(not _mxfp8_available, reason=_reason_for_no_mxfp8),
Expand All @@ -1586,7 +1591,7 @@ def _run_grouped_linear_path(
marks=pytest.mark.skipif(not _nvfp4_available, reason=_reason_for_no_nvfp4),
),
],
ids=["bf16", "mxfp8", "nvfp4"],
ids=["bf16", "fp8_current_scaling", "mxfp8", "nvfp4"],
)
@pytest.mark.parametrize("bias", _ALL_BOOLEAN)
@pytest.mark.parametrize("fp8_model_params", _ALL_BOOLEAN)
Expand All @@ -1600,8 +1605,13 @@ def test_grouped_linear_grouped_tensor_path_matches_legacy(
pytest.skip(
"GroupedTensor grouped GEMM path requires Hopper (SM90) or Blackwell (SM10x and SM110)."
)
if use_fp8 and device_capability < (10, 0):
pytest.skip("Quantized GroupedTensor grouped GEMM path requires Blackwell (SM100+).")
# MXFP8/NVFP4 grouped quantization kernels require Blackwell, but FP8 per-tensor
# current scaling also runs on the Hopper grouped GEMM path.
is_current_scaling = use_fp8 and fp8_recipe.float8_current_scaling()
if use_fp8 and not is_current_scaling and device_capability < (10, 0):
pytest.skip(
"Quantized GroupedTensor grouped GEMM path (MXFP8/NVFP4) requires Blackwell (SM100+)."
)
cublaslt_version = tex.get_cublasLt_version()
if device_capability < (10, 0) and cublaslt_version < 130400:
pytest.skip("Grouped GEMM on Hopper requires cuBLAS 13.4+.")
Expand Down Expand Up @@ -1786,6 +1796,10 @@ def test_grouped_linear_grouped_tensor_path_skips_non_rht_nvfp4(monkeypatch):
"fp8_recipe",
[
None,
pytest.param(
recipe.Float8CurrentScaling(),
marks=pytest.mark.skipif(not _fp8_available, reason=_reason_for_no_fp8),
),
pytest.param(
recipe.MXFP8BlockScaling(),
marks=pytest.mark.skipif(not _mxfp8_available, reason=_reason_for_no_mxfp8),
Expand All @@ -1795,7 +1809,7 @@ def test_grouped_linear_grouped_tensor_path_skips_non_rht_nvfp4(monkeypatch):
marks=pytest.mark.skipif(not _nvfp4_available, reason=_reason_for_no_nvfp4),
),
],
ids=["bf16", "mxfp8", "nvfp4"],
ids=["bf16", "fp8_current_scaling", "mxfp8", "nvfp4"],
)
@pytest.mark.parametrize("bias", _ALL_BOOLEAN)
def test_grouped_linear_fused_path_cuda_graph_safe(fp8_recipe, bias, monkeypatch):
Expand All @@ -1806,8 +1820,13 @@ def test_grouped_linear_fused_path_cuda_graph_safe(fp8_recipe, bias, monkeypatch
pytest.skip(
"GroupedTensor grouped GEMM path requires Hopper (SM90) or Blackwell (SM10x and SM110)."
)
if use_fp8 and device_capability < (10, 0):
pytest.skip("Quantized GroupedTensor grouped GEMM path requires Blackwell (SM100+).")
# MXFP8/NVFP4 grouped quantization kernels require Blackwell, but FP8 per-tensor
# current scaling also runs on the Hopper grouped GEMM path.
is_current_scaling = use_fp8 and fp8_recipe.float8_current_scaling()
if use_fp8 and not is_current_scaling and device_capability < (10, 0):
pytest.skip(
"Quantized GroupedTensor grouped GEMM path (MXFP8/NVFP4) requires Blackwell (SM100+)."
)
cublaslt_version = tex.get_cublasLt_version()
if device_capability < (10, 0) and cublaslt_version < 130400:
pytest.skip("Grouped GEMM on Hopper requires cuBLAS 13.4+.")
Expand Down
11 changes: 10 additions & 1 deletion tests/pytorch/test_grouped_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,10 @@ def test_grouped_linear(
@pytest.mark.parametrize("dtype", (torch.bfloat16, torch.float16))
@pytest.mark.parametrize(
"quantization",
[None] + (["mxfp8"] if mxfp8_available else []),
[None]
+ (["fp8_current_scaling"] if fp8_available else [])
+ (["mxfp8"] if mxfp8_available else [])
+ (["nvfp4_rht"] if nvfp4_available else []),
)
@pytest.mark.parametrize("quantized_weight", (False, True))
@pytest.mark.parametrize("bias", (False, True))
Expand Down Expand Up @@ -479,6 +482,12 @@ def test_grouped_linear_cuda_graph_safe(
pytest.skip("Grouped GEMM CUDA-graph-safe path requires SM100+ (Blackwell)")
if quantization is None and quantized_weight:
pytest.skip("quantized_weight requires a quantization recipe")
if (
quantization is not None
and quantization.startswith("nvfp4")
and dtype != torch.bfloat16
):
pytest.skip("NVFP4 grouped GEMM only supports BF16 output")

single_grouped_bias = bias and single_grouped_weight

Expand Down
17 changes: 11 additions & 6 deletions transformer_engine/pytorch/module/grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,10 @@ def _is_grouped_tensor_path_supported(
and be incompatible with CUDA Graphs.

Supported Compute Capability (CC) and precisions:
* Hopper (CC 9.0): BF16/FP16.
* Blackwell (CC 10.x and 11.0): BF16/FP16/MXFP8/NVFP4 with RHT.
FP8 delayed / current scaling, and FP8 block scaling are not supported because the
* Hopper (CC 9.0): BF16/FP16 and FP8 per-tensor current scaling.
* Blackwell (CC 10.x and 11.0): BF16/FP16/MXFP8/NVFP4 with RHT and FP8
per-tensor current scaling.
FP8 delayed scaling and FP8 block scaling are not supported because the
corresponding grouped quantization kernels are missing.
Non-RHT NVFP4 falls back to the legacy path because graph-safe grouped quantization
currently requires RHT.
Expand Down Expand Up @@ -133,6 +134,9 @@ def _is_grouped_tensor_path_supported(
return False
# 5. Filter by quantization recipes.
if fp8:
if all(isinstance(q, Float8CurrentScalingQuantizer) for q in input_quantizers):
return True
# MXFP8 and NVFP4 grouped quantization kernels require Blackwell.
if not (10, 0) <= get_device_compute_capability() <= (11, 0):
return False
return all(isinstance(q, MXFP8Quantizer) for q in input_quantizers) or all(
Expand Down Expand Up @@ -328,9 +332,10 @@ def _forward_grouped_tensor(

if is_grad_enabled:
if weight_requires_grad:
if fp8:
grouped_x.rowwise_data = None
grouped_x.scale_inv = None
# Free Rowwise Data if columnwise data is available for backward pass
# (For FP8 per tensor current scaling on Hopper) if fp8 and grouped_x.columnwise_data is not None:
grouped_x.rowwise_data = None
grouped_x.scale_inv = None
Comment on lines +335 to +338

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Conditional guard accidentally embedded in comment — `rowwise_data` cleared unconditionally

The `if fp8 and grouped_x.columnwise_data is not None:` guard was intended to precede the two assignments on lines 337–338, but it was appended to the end of the preceding comment on line 336. Because Python ignores everything after `#`, both `grouped_x.rowwise_data = None` and `grouped_x.scale_inv = None` now execute unconditionally whenever `is_grad_enabled` and `weight_requires_grad` are both `True`.

For the non-FP8 (BF16/FP16) grouped-tensor path, `grouped_x.rowwise_data` holds the packed activation buffer that is saved for backward and used to compute the weight gradient. Clearing it to `None` before `ctx.save_for_backward` destroys the activation data, causing the wgrad computation to operate on `None` — resulting in a crash or silently incorrect gradients.

The equivalent change in `ops/basic/grouped_linear.py` (line 1335) correctly places the condition on its own line: `if with_quantized_compute and grouped_x.columnwise_data is not None:`.

else:
grouped_x = None

Expand Down
25 changes: 20 additions & 5 deletions transformer_engine/pytorch/ops/basic/grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,13 @@
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, start_offload
from ...quantization import FP8GlobalStateManager, QuantizerRole, Recipe
from ...quantized_tensor import QuantizedTensorStorage
from ...tensor import MXFP8Quantizer, MXFP8Tensor, NVFP4Quantizer, Quantizer
from ...tensor import (
Float8CurrentScalingQuantizer,
MXFP8Quantizer,
MXFP8Tensor,
NVFP4Quantizer,
Quantizer,
)
from ...utils import (
canonicalize_device,
canonicalize_dtype,
Expand Down Expand Up @@ -768,8 +774,11 @@ def _is_graph_safe_path_supported(
requirement without duplicating its cuBLAS version checks.
* Quantized compute supports MXFP8 and NVFP4 on Blackwell GPUs with Compute Capability (CC)
10.x and 11.0. NVFP4 requires RHT because graph-safe grouped quantization currently
requires RHT;
Every other quantization recipe (fp8 delayed / current scaling, fp8 block scaling, ...)
requires RHT.
* FP8 per-tensor current scaling is backed by grouped current-scaling quantization
(``tex.group_quantize``) and cuBLASLt grouped GEMM with per-batch scalar FP8 scaling,
which are supported on Hopper (CC 9.0) and Blackwell (CC 10.x and 11.0).
Every other quantization recipe (fp8 delayed scaling, fp8 block scaling, ...)
falls back to the legacy flow because the corresponding grouped quantization kernels are
missing.
* Unquantized compute supports BF16/FP16 on Hopper (CC 9.0) and Blackwell (CC 10.x and 11.0)
Expand All @@ -780,6 +789,11 @@ def _is_graph_safe_path_supported(
if not (9, 0) <= get_device_compute_capability() <= (11, 0):
return False
if with_quantized_compute:
# FP8 per-tensor current scaling runs on the Hopper and Blackwell grouped GEMM
# path; the compute-capability range was already checked above.
if all(isinstance(q, Float8CurrentScalingQuantizer) for q in input_quantizers):
return True
# MXFP8 and NVFP4 grouped quantization kernels require Blackwell.
if not (10, 0) <= get_device_compute_capability() <= (11, 0):
return False
return all(isinstance(q, MXFP8Quantizer) for q in input_quantizers) or all(
Expand Down Expand Up @@ -1318,8 +1332,9 @@ def _fuser_forward_grouped_tensor(
# [split_sizes, base_split_offsets, split_points,
# (scales if _scale_bias), grouped_x, *weights]
if grouped_x is not None:
if with_quantized_compute:
# only columnwise data is needed for wgrad
# Free Rowwise Data if columnwise data is available for backward pass
# (For FP8 per tensor current scaling on Hopper)
if with_quantized_compute and grouped_x.columnwise_data is not None:
grouped_x.rowwise_data = None
grouped_x.scale_inv = None
saved: list[Optional[torch.Tensor]] = [split_sizes, base_split_offsets, split_points]
Expand Down
Loading