Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions build_tools/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

"""PyTorch related extensions."""

import importlib.util
import os
from pathlib import Path
from importlib import metadata
Expand Down Expand Up @@ -58,6 +59,25 @@ def setup_pytorch_extension(
]
)

# apache-tvm-ffi: headers for the C++ API (Module / Function / TensorView)
# and libtvm_ffi.so for symbol resolution. Used by tvm_ffi_bridge.h /
# applyTVMFunction. Python registers AOT-compiled CuTeDSL kernels into
# the global registry; TE C++ looks them up via Function::GetGlobalRequired.
tvm_ffi_spec = importlib.util.find_spec("tvm_ffi")
if tvm_ffi_spec is None or not tvm_ffi_spec.submodule_search_locations:
raise RuntimeError(
"apache-tvm-ffi package not found; install it (e.g. "
"`pip install apache-tvm-ffi`) — required for the TVM FFI bridge."
)
tvm_ffi_root = Path(tvm_ffi_spec.submodule_search_locations[0])
tvm_ffi_include = tvm_ffi_root / "include"
tvm_ffi_lib_dir = tvm_ffi_root / "lib"
if not tvm_ffi_include.is_dir() or not (tvm_ffi_lib_dir / "libtvm_ffi.so").exists():
raise RuntimeError(
f"apache-tvm-ffi assets missing at {tvm_ffi_root} (need include/ and lib/libtvm_ffi.so)"
)
include_dirs.append(tvm_ffi_include)

# Compiler flags
cxx_flags = ["-O3", "-fvisibility=hidden"]
if debug_build_enabled():
Expand All @@ -77,8 +97,11 @@ def setup_pytorch_extension(

setup_mpi_flags(include_dirs, cxx_flags)

library_dirs = []
libraries = []
library_dirs = [tvm_ffi_lib_dir]
libraries = ["tvm_ffi"]
# rpath pinned to the pip install dir so the loader finds libtvm_ffi.so
# without LD_LIBRARY_PATH at runtime.
extra_link_args = [f"-Wl,-rpath,{tvm_ffi_lib_dir}"]
if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", 0))):
assert (
os.getenv("NVSHMEM_HOME") is not None
Expand All @@ -102,6 +125,7 @@ def setup_pytorch_extension(
sources=[str(src) for src in sources],
include_dirs=[str(inc) for inc in include_dirs],
extra_compile_args={"cxx": cxx_flags},
extra_link_args=extra_link_args,
libraries=[str(lib) for lib in libraries],
library_dirs=[str(lib_dir) for lib_dir in library_dirs],
)
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
# See LICENSE for license information.

[build-system]
requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"]
# apache-tvm-ffi is required at configure/compile/link time: the common C++
# library finds it via find_package(tvm_ffi) and links libtvm_ffi.so (the
# CuTeDSL quant backend bridge). It is also a runtime dependency (see setup.py).
requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1", "apache-tvm-ffi>=0.1.12"]

# Use legacy backend to import local packages in setup.py
build-backend = "setuptools.build_meta:__legacy__"
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ def setup_requirements() -> Tuple[List[str], List[str]]:
"importlib-metadata>=1.0",
"packaging",
cusolvermp_pypi_package_name(),
# The core C++ library links libtvm_ffi.so (CuTeDSL quant backend bridge),
# so apache-tvm-ffi is required at runtime by every TE install.
"apache-tvm-ffi>=0.1.12",
]
test_reqs: List[str] = ["pytest>=8.2.1"]

Expand Down
34 changes: 34 additions & 0 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,24 @@ set(CUTLASS_TOOLS_INCLUDE_DIR
# Python
find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)

# tvm-ffi: the quantize dispatch layer bridges to JIT-compiled CuTeDSL kernels
# through tvm-ffi (see common/tvm_ffi_bridge.h). Locate the tvm_ffi package that
# ships with the Python install and use its exported CMake config (provides the
# tvm_ffi::shared imported target with headers + libtvm_ffi.so).
execute_process(
COMMAND ${Python_EXECUTABLE} -c "import tvm_ffi.libinfo as li; print(li.find_cmake_path())"
OUTPUT_VARIABLE TVM_FFI_CMAKE_DIR
OUTPUT_STRIP_TRAILING_WHITESPACE
RESULT_VARIABLE TVM_FFI_CMAKE_QUERY)
if(NOT TVM_FFI_CMAKE_QUERY EQUAL 0)
message(FATAL_ERROR
"Could not import the tvm_ffi Python package (with '${Python_EXECUTABLE}'), "
"which Transformer Engine requires to build the CuTeDSL quantize backend "
"bridge (common/tvm_ffi_bridge.h). Install it into this Python environment: "
"`pip install apache-tvm-ffi`.")
endif()
find_package(tvm_ffi CONFIG REQUIRED PATHS "${TVM_FFI_CMAKE_DIR}")

function(find_nccl_version OUT_VERSION OUT_INCLUDE_DIR)
find_path(_nvte_nccl_include_dir
NAMES nccl.h
Expand Down Expand Up @@ -360,6 +378,22 @@ target_link_libraries(transformer_engine PUBLIC
CUDA::cudart
CUDNN::cudnn_all)

# CuTeDSL quantize backend bridge. PRIVATE: tvm_ffi_bridge.h is an internal
# header (not in the installed public include dir), so the symbols and headers
# are only needed to compile transformer_engine itself, not by downstream
# consumers. The INTERFACE include dirs of tvm_ffi::shared still apply to our
# own TUs, which is what fixes the <tvm/ffi/*.h> not-found error.
target_link_libraries(transformer_engine PRIVATE tvm_ffi::shared)

# libtvm_ffi.so ships inside the tvm_ffi Python package (not a system lib dir),
# so add its directory to the RPATH; otherwise the runtime loader can't satisfy
# the DT_NEEDED on libtvm_ffi.so and dlopen fails with "cannot open shared
# object file". Applied to both the build tree and the installed library.
get_target_property(TVM_FFI_SHARED_LOCATION tvm_ffi::shared IMPORTED_LOCATION)
get_filename_component(TVM_FFI_LIB_DIR "${TVM_FFI_SHARED_LOCATION}" DIRECTORY)
set_property(TARGET transformer_engine APPEND PROPERTY BUILD_RPATH "${TVM_FFI_LIB_DIR}")
set_property(TARGET transformer_engine APPEND PROPERTY INSTALL_RPATH "${TVM_FFI_LIB_DIR}")

target_include_directories(transformer_engine PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(transformer_engine SYSTEM PRIVATE
Expand Down
19 changes: 19 additions & 0 deletions transformer_engine/common/CuTeDSL/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""CuTeDSL kernels for Transformer Engine.

Importing this package has a side effect: it registers the CuTeDSL kernel
entrypoints (e.g. ``get_mxfp8_quantization_function``) as TVM-FFI global
functions. The C++ dispatcher probes for those names via
``tvm::ffi::Function::GetGlobal`` — finding one means the process is running
inside a Python environment with the CuTeDSL toolchain available, so the kernel
may be compiled on demand; not finding it means a plain C++ environment, and
the dispatcher falls back to the CUDA C++ kernel.

Importing requires the optional CuTeDSL toolchain (cutlass, tvm_ffi). Callers
that want graceful degradation should guard the import in a try/except.
"""

from . import cast # noqa: F401 (import side effect: registers global funcs)
8 changes: 8 additions & 0 deletions transformer_engine/common/CuTeDSL/cast/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""CuTeDSL cast/quantization kernels. Importing pulls in each kernel module so
its TVM-FFI entrypoint is registered."""

from . import mxfp8 # noqa: F401 (import side effect: registers global funcs)
8 changes: 8 additions & 0 deletions transformer_engine/common/CuTeDSL/cast/mxfp8/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""MXFP8 CuTeDSL kernels. Importing ``quantize_mxfp8`` runs its module body,
which registers the ``get_mxfp8_quantization_function`` TVM-FFI global func."""

from . import quantize_mxfp8 # noqa: F401 (import side effect: registers the global func)
Loading
Loading