Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions qa/L0_pytorch_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading_v1.xml $TE_PATH/tests/pytorch/test_cpu_offloading_v1.py || test_fail "test_cpu_offloading_v1.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_flex_attention.xml $TE_PATH/tests/pytorch/attention/test_flex_attention.py || test_fail "test_flex_attention.py"
NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention.py"
NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_linear_mxfp8_attention.xml $TE_PATH/tests/pytorch/attention/test_linear_mxfp8_attention.py || test_fail "test_linear_mxfp8_attention.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
Expand Down
15 changes: 13 additions & 2 deletions tests/pytorch/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import transformer_engine.pytorch as te

from utils import make_recipe
from transformer_engine.pytorch._extra_state import UNSAFE_PICKLE_EXTRA_STATE_ENV

# Check supported quantization schemes
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
Expand Down Expand Up @@ -131,8 +132,18 @@ def test_module(self, name: str) -> None:
raise FileNotFoundError(f"Could not find checkpoint file at {checkpoint_file}")
state_dict = torch.load(checkpoint_file, weights_only=False)

# Update module from checkpoint
module.load_state_dict(state_dict, strict=True)
# Update module from checkpoint. Delayed-scaling legacy extra state is unsafe by
# default and requires an explicit opt-in for trusted compatibility artifacts.
old_unsafe_extra_state = os.environ.get(UNSAFE_PICKLE_EXTRA_STATE_ENV)
if quantization == "fp8":
os.environ[UNSAFE_PICKLE_EXTRA_STATE_ENV] = "1"
try:
module.load_state_dict(state_dict, strict=True)
finally:
if old_unsafe_extra_state is None:
os.environ.pop(UNSAFE_PICKLE_EXTRA_STATE_ENV, None)
else:
os.environ[UNSAFE_PICKLE_EXTRA_STATE_ENV] = old_unsafe_extra_state


def main() -> None:
Expand Down
13 changes: 12 additions & 1 deletion tests/pytorch/test_fusible_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from collections.abc import Iterable, Sequence
import io
import os
import math
import random
from typing import Optional
Expand All @@ -18,6 +19,7 @@
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch._extra_state import UNSAFE_PICKLE_EXTRA_STATE_ENV

from transformer_engine.pytorch.ops.fused import (
BackwardActivationBias,
Expand Down Expand Up @@ -3217,7 +3219,16 @@ def test_linear(
)
optim_load = torch.optim.SGD(model_load.parameters(), lr=0.25)
state_dict = torch.load(io.BytesIO(checkpoint_bytes), weights_only=False)
model_load.load_state_dict(state_dict["model"])
old_unsafe_extra_state = os.environ.get(UNSAFE_PICKLE_EXTRA_STATE_ENV)
if quantization in ("fp8", "fp8_delayed_scaling"):
os.environ[UNSAFE_PICKLE_EXTRA_STATE_ENV] = "1"
try:
model_load.load_state_dict(state_dict["model"])
finally:
if old_unsafe_extra_state is None:
os.environ.pop(UNSAFE_PICKLE_EXTRA_STATE_ENV, None)
else:
os.environ[UNSAFE_PICKLE_EXTRA_STATE_ENV] = old_unsafe_extra_state
optim_load.load_state_dict(state_dict["optim"])

# Training steps with loaded model
Expand Down
11 changes: 10 additions & 1 deletion tests/pytorch/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from transformer_engine.pytorch.quantization import (
FP8GlobalStateManager,
)
from transformer_engine.pytorch._extra_state import UNSAFE_PICKLE_EXTRA_STATE_ENV
from transformer_engine.pytorch.utils import (
init_method_normal,
scaled_init_method_normal,
Expand Down Expand Up @@ -847,7 +848,15 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=

del block
block = _test_e2e_checkpointing_get_model(config, dtype)
block.load_state_dict(torch.load(path, weights_only=False))
loaded_state_dict = torch.load(path, weights_only=False)
old_unsafe_extra_state = os.environ.get(UNSAFE_PICKLE_EXTRA_STATE_ENV)
try:
block.load_state_dict(loaded_state_dict)
finally:
if old_unsafe_extra_state is None:
os.environ.pop(UNSAFE_PICKLE_EXTRA_STATE_ENV, None)
else:
os.environ[UNSAFE_PICKLE_EXTRA_STATE_ENV] = old_unsafe_extra_state
Comment on lines +852 to +859

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Missing env-var set makes the save/restore a no-op

The try/finally block saves and unconditionally restores UNSAFE_PICKLE_EXTRA_STATE_ENV, but never actually sets it to "1" before calling load_state_dict. Every other test fixed in this PR (test_checkpoint.py line 136, test_fusible_ops.py line 3222) follows the pattern: save → conditionally set to "1" → try/finally restore. Here the "set" step is absent, so the entire save/restore is a no-op.

_test_e2e_checkpointing_get_model creates a plain TransformerLayer without FP8, so fp8_checkpoint is False and the extra state is an empty tensor today, which avoids the runtime error. If this helper is ever extended with FP8 delayed-scaling (a natural step), load_state_dict will raise a RuntimeError because the env var will never be set.

torch.set_rng_state(_cpu_rng_state)
torch.cuda.set_rng_state(_cuda_rng_state)

Expand Down
89 changes: 89 additions & 0 deletions tests/pytorch/test_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from typing import Optional

import pickle

import pytest
import torch
import warnings
Expand Down Expand Up @@ -31,10 +33,19 @@
)
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.common.recipe import (
CustomRecipe,
DelayedScaling,
Float8CurrentScaling,
Float8BlockScaling,
MXFP8BlockScaling,
NVFP4BlockScaling,
Recipe,
)
from transformer_engine.pytorch._extra_state import (
CheckpointExtraStatePolicy,
UNSAFE_PICKLE_EXTRA_STATE_ENV,
_RECIPE_POLICIES,
should_load_extra_state_pickle,
)

# Check if FP8 is supported
Expand Down Expand Up @@ -691,3 +702,81 @@ def test_fp4_dequantize(dtype, row_scaled_nvfp4, use_4over6, M, N):
)
new_dequantized_tensor = new_tensor.dequantize()
torch.testing.assert_close(dequantized_tensor, new_dequantized_tensor)


def _custom_recipe_qfactory(_role):
    """Placeholder quantizer factory for ``CustomRecipe``; ignores the role, returns ``None``.

    Defined as a named module-level function because the recipe objects built
    with it in this file are subsequently serialized with ``pickle.dumps``.
    """


def _recipe_subclasses(cls):
    """Lazily yield every direct and indirect subclass of ``cls``.

    Classes come out in depth-first pre-order: a subclass is yielded, then all
    of its own descendants, then its next sibling. ``cls`` itself is not
    yielded. A class reachable through several bases is yielded once per path.
    """
    # Explicit stack; children are pushed reversed so the first one pops first.
    pending = cls.__subclasses__()
    pending.reverse()
    while pending:
        current = pending.pop()
        yield current
        pending.extend(reversed(current.__subclasses__()))


def _pickled_extra_state_payload(recipe_obj, *, include_delayed_state=False):
    """Build a pickled extra-state dict for the checkpoint-policy tests.

    Args:
        recipe_obj: Object stored under the ``"recipe"`` key.
        include_delayed_state: When true, also add forward/backward scale and
            amax-history tensors to the dict.

    Returns:
        The ``pickle.dumps`` bytes of the assembled dict.
    """
    payload = {"recipe": recipe_obj, "extra_fp8_variables": {}}
    if not include_delayed_state:
        return pickle.dumps(payload)

    # Insertion order is kept deliberately so the serialized bytes are stable.
    payload["scale_fwd"] = torch.ones(1)
    payload["amax_history_fwd"] = torch.zeros(1, 1)
    payload["scale_bwd"] = torch.ones(1)
    payload["amax_history_bwd"] = torch.zeros(1, 1)
    return pickle.dumps(payload)


def test_checkpoint_extra_state_policy_classifier_map_covers_all_recipes():
    """Every ``Recipe`` subclass needs a policy-map entry holding a valid policy.

    Entries are looked up by ``(module path, class name)`` with the module
    path fixed to ``transformer_engine.common.recipe``.
    """
    recipe_module = "transformer_engine.common.recipe"
    for recipe_cls in _recipe_subclasses(Recipe):
        lookup_key = (recipe_module, recipe_cls.__name__)
        assert lookup_key in _RECIPE_POLICIES
        policy = _RECIPE_POLICIES[lookup_key]
        assert policy in CheckpointExtraStatePolicy


@pytest.mark.parametrize(
    "recipe_obj",
    [
        Float8CurrentScaling(),
        MXFP8BlockScaling(),
        Float8BlockScaling(),
        NVFP4BlockScaling(),
    ],
)
def test_stateless_pickled_extra_state_is_ignored(recipe_obj):
    """A payload carrying only one of these recipes (no scale/amax tensors) is not loaded."""
    serialized = _pickled_extra_state_payload(recipe_obj)
    should_load = should_load_extra_state_pickle(serialized, "test")
    assert not should_load


def test_stateless_custom_pickled_extra_state_is_ignored():
    """A ``CustomRecipe`` payload without scale/amax tensors is not loaded."""
    custom_recipe = CustomRecipe(qfactory=_custom_recipe_qfactory)
    serialized = _pickled_extra_state_payload(custom_recipe)
    should_load = should_load_extra_state_pickle(serialized, "test")
    assert not should_load


@pytest.mark.parametrize(
    "payload",
    [
        pickle.dumps({}),
        pickle.dumps({"extra_fp8_variables": {}}),
    ],
)
def test_global_free_pickled_extra_state_is_ignored(payload):
    """Payloads that reference no globals are skipped without the unsafe opt-in.

    Older stateless checkpoints serialized an empty dict. Such a payload
    resolves no globals and cannot execute code, so it must load without the
    unsafe opt-in.
    """
    should_load = should_load_extra_state_pickle(payload, "test")
    assert not should_load


@pytest.mark.parametrize(
    "payload",
    [
        _pickled_extra_state_payload(DelayedScaling(), include_delayed_state=True),
        _pickled_extra_state_payload(
            CustomRecipe(qfactory=_custom_recipe_qfactory), include_delayed_state=True
        ),
        pickle.dumps({"scale_inv_fwd": torch.ones(1), "extra_fp8_variables": {}}),
        pickle.dumps({"recipe": object(), "extra_fp8_variables": {}}),
        b"not a pickle",
    ],
)
def test_stateful_unknown_or_malformed_pickled_extra_state_requires_opt_in(payload, monkeypatch):
    """Payloads that cannot be classified as safe must be gated behind the opt-in.

    The cases are: payloads with scale/amax tensors, a payload with a
    ``scale_inv_fwd`` tensor and no recipe, a payload whose recipe is a bare
    ``object()``, and bytes that are not a pickle at all.

    Without the opt-in environment variable, ``should_load_extra_state_pickle``
    must raise ``RuntimeError`` with a message naming that variable. With the
    variable set to ``"1"`` it must return a truthy value.
    """
    # Start from a known-clean environment. Without this, an ambient opt-in
    # (e.g. exported by a CI job or the developer's shell) would stop the
    # RuntimeError from being raised and fail the test for the wrong reason.
    # ``monkeypatch`` restores the original value at teardown.
    monkeypatch.delenv(UNSAFE_PICKLE_EXTRA_STATE_ENV, raising=False)
    with pytest.raises(RuntimeError, match=UNSAFE_PICKLE_EXTRA_STATE_ENV):
        should_load_extra_state_pickle(payload, "test")

    # Explicit opt-in: the same payload is now accepted for loading.
    monkeypatch.setenv(UNSAFE_PICKLE_EXTRA_STATE_ENV, "1")
    assert should_load_extra_state_pickle(payload, "test")
Loading
Loading