Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions transformer_engine/pytorch/tensor/_quantization_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,42 @@ def _stride_from_shape(shape: list[int]):
for d in reversed(shape[1:]):
rstride.append(rstride[-1] * d)
return list(reversed(rstride))


def safe_quantized_repr(obj, cls_name, extras=None, error=None):
    """Metadata-only repr fallback for quantized tensors whose data cannot be
    materialized for any reason.

    Every attribute access *and* every value-to-string conversion is guarded
    so that ``__repr__`` never raises, even when an attribute is a property
    that throws or when a value (including ``error``) has a broken
    ``__str__``/``__format__``.

    Parameters
    ----------
    obj : object
        The tensor-like object being described. ``_fp8_dtype``, ``shape`` and
        ``dtype`` are read from it; any of them that is missing or fails to
        load is left out of the result.
    cls_name : str
        Name placed in front of the parenthesized field list.
    extras : dict, optional
        Additional plain-Python (non-tensor) attributes to include, e.g.
        ``{"is_2D_scaled": self._is_2D_scaled}``. Values are inserted after
        ``fp8_dtype`` and before ``shape``.
    error : BaseException, optional
        The exception that triggered the fallback. When given, its type and
        message are included in the ``data=`` field so that it is visible *why*
        the data could not be materialized.

    Returns
    -------
    str
        ``"<cls_name>(<field>=<value>, ..., data=<unmaterialized...>)"``.
    """

    def _fmt(value):
        # Same conversion an f-string field performs, but a failing
        # __format__/__str__ yields a placeholder instead of propagating.
        try:
            return f"{value}"
        except Exception:  # pylint: disable=broad-except
            return "<unprintable>"

    parts = []

    # getattr's default only covers AttributeError; a property may raise
    # anything else, so guard the access itself as well.
    try:
        fp8_dtype = getattr(obj, "_fp8_dtype", None)
    except Exception:  # pylint: disable=broad-except
        fp8_dtype = None
    if fp8_dtype is not None:
        parts.append(f"fp8_dtype={_fmt(fp8_dtype)}")

    if extras:
        parts.extend(f"{_fmt(key)}={_fmt(value)}" for key, value in extras.items())

    # Best effort: a field that cannot be read is omitted rather than reported.
    try:
        parts.append(f"shape={tuple(obj.shape)}")
    except Exception:  # pylint: disable=broad-except
        pass
    try:
        parts.append(f"dtype={obj.dtype}")
    except Exception:  # pylint: disable=broad-except
        pass

    if error is not None:
        parts.append(f"data=<unmaterialized: {type(error).__name__}: {_fmt(error)}>")
    else:
        parts.append("data=<unmaterialized>")
    return f"{cls_name}({', '.join(parts)})"
20 changes: 14 additions & 6 deletions transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from transformer_engine.common.recipe import Float8BlockScaling, Recipe
from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ..quantized_tensor import QuantizedTensor, Quantizer
from ._quantization_helpers import _IdentityFunc
from ._quantization_helpers import _IdentityFunc, safe_quantized_repr
from ..constants import DType
from ..utils import devices_match, round_up_to_nearest_multiple

Expand Down Expand Up @@ -267,11 +267,19 @@ def __new__(
return instance

def __repr__(self, *, tensor_contents=None):
    """Return a debug string with the FP8 dtype, scaling mode and dequantized data.

    ``tensor_contents`` is accepted but not used by this implementation.
    If building the full representation raises, a metadata-only string
    produced by ``safe_quantized_repr`` is returned instead.
    """
    try:
        return (
            f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype},"
            f" is_2D_scaled={self._is_2D_scaled},"
            f" data={self.dequantize()})"
        )
    except Exception as exc:  # pylint: disable=broad-except
        # The failure may have come from reading ``_is_2D_scaled`` itself, so
        # it must not be read unguarded again inside the handler.
        is_2d_scaled = getattr(self, "_is_2D_scaled", None)
        extras = None if is_2d_scaled is None else {"is_2D_scaled": is_2d_scaled}
        return safe_quantized_repr(
            self,
            "Float8BlockwiseQTensor",
            extras=extras,
            error=exc,
        )

def quantize_(
self,
Expand Down
19 changes: 11 additions & 8 deletions transformer_engine/pytorch/tensor/float8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from ..utils import canonicalize_process_group, devices_match
from .storage.float8_tensor_storage import Float8TensorStorage, _FromFloat8Func
from ..quantized_tensor import QuantizedTensor, Quantizer
from ._quantization_helpers import _IdentityFunc
from ._quantization_helpers import _IdentityFunc, safe_quantized_repr
from ..constants import dist_group_type, DType

aten = torch.ops.aten
Expand Down Expand Up @@ -412,13 +412,16 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
"""

def __repr__(self, *, tensor_contents=None):
    """Describe the tensor by FP8 dtype, scalar ``scale_inv`` and dequantized data.

    ``tensor_contents`` is accepted but not used here. Any exception raised
    while gathering the fields is routed to ``safe_quantized_repr``, which
    returns a metadata-only description.
    """
    try:
        # Fields are evaluated left to right, exactly as they are printed.
        fields = (
            f"fp8_dtype={self._fp8_dtype}, "
            f"scale_inv={self._scale_inv.item()}, "
            f"data={self.dequantize()}"
        )
    except Exception as failure:  # pylint: disable=broad-except
        return safe_quantized_repr(self, "Float8Tensor", error=failure)
    return "Float8Tensor(" + fields + ")"

def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
"""
Comment thread
greptile-apps[bot] marked this conversation as resolved.
Expand Down
7 changes: 5 additions & 2 deletions transformer_engine/pytorch/tensor/mxfp8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from ..utils import devices_match, round_up_to_nearest_multiple
from .storage.mxfp8_tensor_storage import MXFP8TensorStorage, _FromMXFP8Func
from ..quantized_tensor import QuantizedTensor, Quantizer
from ._quantization_helpers import _IdentityFunc
from ._quantization_helpers import _IdentityFunc, safe_quantized_repr

aten = torch.ops.aten

Expand Down Expand Up @@ -233,7 +233,10 @@ def __new__(
)

def __repr__(self, *, tensor_contents=None):
    """Describe the tensor by FP8 dtype and dequantized data.

    ``tensor_contents`` is accepted but not used here. If anything raises
    while the fields are gathered, the metadata-only string from
    ``safe_quantized_repr`` is returned instead.
    """
    try:
        fp8_dtype = self._fp8_dtype
        values = self.dequantize()
        return f"MXFP8Tensor(fp8_dtype={fp8_dtype}, data={values})"
    except Exception as failure:  # pylint: disable=broad-except
        return safe_quantized_repr(self, "MXFP8Tensor", error=failure)

def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
"""
Expand Down
7 changes: 5 additions & 2 deletions transformer_engine/pytorch/tensor/nvfp4_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from .storage.nvfp4_tensor_storage import NVFP4TensorStorage, _FromNVFP4Func
from ..quantized_tensor import QuantizedTensor, Quantizer
from ._quantization_helpers import _IdentityFunc
from ._quantization_helpers import _IdentityFunc, safe_quantized_repr

aten = torch.ops.aten

Expand Down Expand Up @@ -398,7 +398,10 @@ def __new__(
return instance

def __repr__(self, *, tensor_contents=None):
    """Describe the tensor by its dequantized data.

    ``tensor_contents`` is accepted but not used here. If dequantization
    (or formatting its result) raises, the metadata-only string from
    ``safe_quantized_repr`` is returned instead.
    """
    try:
        # Open the parenthesis after the class name so the output is
        # balanced and matches the ``NVFP4Tensor(...)`` shape of the fallback.
        return f"NVFP4Tensor(data={self.dequantize()})"
    except Exception as exc:  # pylint: disable=broad-except
        return safe_quantized_repr(self, "NVFP4Tensor", error=exc)

def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import transformer_engine_torch as tex

from ...quantized_tensor import QuantizedTensorStorage, Quantizer
from .._quantization_helpers import safe_quantized_repr

from ...constants import TE_DType_To_Torch, DType

Expand Down Expand Up @@ -354,17 +355,25 @@ def _transpose_columnwise_data(self):
del _old_data

def __repr__(self):
    """Describe the storage by FP8 dtype and dequantized data.

    The data field is labelled ``rowwise_scaled_data`` when
    ``_rowwise_data`` is present and ``columnwise_scaled_data`` otherwise.
    If anything raises while the fields are gathered, the metadata-only
    string from ``safe_quantized_repr`` is returned instead.
    """
    try:
        # Only the label depends on which buffer is present; the
        # dequantize call is the same either way.
        descriptor = "rowwise" if self._rowwise_data is not None else "columnwise"
        data = self.dequantize()
        return (
            "Float8BlockwiseQTensorStorage("
            f"fp8_dtype={self._fp8_dtype}, "
            f"{descriptor}_scaled_data={data})"
        )
    except Exception as exc:  # pylint: disable=broad-except
        # Read the flag defensively so the handler itself cannot raise on a
        # partially initialized object.
        is_2d_scaled = getattr(self, "_is_2D_scaled", None)
        extras = None if is_2d_scaled is None else {"is_2D_scaled": is_2d_scaled}
        return safe_quantized_repr(
            self,
            "Float8BlockwiseQTensorStorage",
            extras=extras,
            error=exc,
        )

def update_usage(
self, rowwise_usage: Optional[bool] = None, columnwise_usage: Optional[bool] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import transformer_engine_torch as tex

from ...quantized_tensor import QuantizedTensorStorage, Quantizer
from .._quantization_helpers import safe_quantized_repr

from ...constants import TE_DType as torch_to_transformer_engine_dtype, TE_DType_To_Torch, DType

Expand Down Expand Up @@ -209,13 +210,16 @@ def view(self, shape: torch.Size):
)

def __repr__(self):
    """Describe the storage by FP8 dtype, scalar ``scale_inv`` and dequantized data.

    Any exception raised while gathering the fields is routed to
    ``safe_quantized_repr``, which returns a metadata-only description.
    """
    try:
        # Fields are evaluated left to right, exactly as they are printed.
        body = (
            f"fp8_dtype={self._fp8_dtype}, "
            f"scale_inv={self._scale_inv.item()}, "
            f"data={self.dequantize()}"
        )
    except Exception as failure:  # pylint: disable=broad-except
        return safe_quantized_repr(self, "Float8TensorStorage", error=failure)
    return "Float8TensorStorage(" + body + ")"

def _create_transpose(self):
"""Update FP8 transpose cache"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import transformer_engine_torch as tex

from ...quantized_tensor import QuantizedTensorStorage, Quantizer
from .._quantization_helpers import safe_quantized_repr

from ...constants import TE_DType as torch_to_transformer_engine_dtype, DType

Expand Down Expand Up @@ -257,15 +258,18 @@ def view(self, shape: torch.Size):
)

def __repr__(self):
    """Describe the storage by FP8 dtype, dequantized data and row-wise scale-inverse.

    If dequantization or formatting raises, the metadata-only string from
    ``safe_quantized_repr`` is returned instead.
    """
    try:
        data_rowwise = self.dequantize()

        # Separator goes *between* the last two fields (it previously
        # trailed the final field, fusing the two before it).
        return (
            "MXFP8TensorStorage("
            f"fp8_dtype={self._fp8_dtype}, "
            f"rowwise_scaled_data={data_rowwise}, "
            f"rowwise_scale_inv={self._rowwise_scale_inv}"
            ")"
        )
    except Exception as exc:  # pylint: disable=broad-except
        return safe_quantized_repr(self, "MXFP8TensorStorage", error=exc)

def update_usage(
self,
Expand Down
24 changes: 14 additions & 10 deletions transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import transformer_engine_torch as tex

from ...quantized_tensor import QuantizedTensorStorage, Quantizer
from .._quantization_helpers import safe_quantized_repr

from ...constants import TE_DType as torch_to_transformer_engine_dtype, DType
from ...utils import _empty_tensor
Expand Down Expand Up @@ -340,16 +341,19 @@ def view(self, shape: torch.Size):
)

def __repr__(self):
    """Describe the storage by dequantized data, row-wise scale-inverse and amax values.

    Any exception raised while gathering the fields is routed to
    ``safe_quantized_repr``, which returns a metadata-only description.
    """
    try:
        # Dequantize first; the remaining fields are plain attribute reads.
        dequantized = self.dequantize()
        body = (
            f"rowwise_scaled_data={dequantized},"
            f"rowwise_scale_inv={self._rowwise_scale_inv},"
            f"amax_rowwise={self._amax_rowwise},"
            f"amax_columnwise={self._amax_columnwise},"
        )
    except Exception as failure:  # pylint: disable=broad-except
        return safe_quantized_repr(self, "NVFP4TensorStorage", error=failure)
    return "NVFP4TensorStorage(" + body + ")"

def update_usage(
self,
Expand Down
Loading