From 3ee9599b6d75852426d1ca55c2daa7c5b851a751 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 8 Jun 2026 13:29:33 -0700 Subject: [PATCH 01/22] jax: add EP bindings on pointer-keyed cache with EpLayerConfig and bf16 max_token_dtype Signed-off-by: Phuong Nguyen --- examples/jax/ep/ep_moe.py | 393 +++++++ examples/jax/ep/run_test_ep.sh | 85 ++ tests/jax/multi_process_launch_ep.sh | 67 ++ tests/jax/test_multi_process_ep.py | 748 ++++++++++++ .../jax/cpp_extensions/__init__.py | 1 + transformer_engine/jax/cpp_extensions/base.py | 11 + transformer_engine/jax/cpp_extensions/ep.py | 1017 +++++++++++++++++ transformer_engine/jax/csrc/extensions.h | 22 + transformer_engine/jax/csrc/extensions/ep.cpp | 541 +++++++++ .../jax/csrc/extensions/pybind.cpp | 31 + transformer_engine/jax/ep.py | 311 +++++ transformer_engine/jax/sharding.py | 12 +- 12 files changed, 3238 insertions(+), 1 deletion(-) create mode 100644 examples/jax/ep/ep_moe.py create mode 100755 examples/jax/ep/run_test_ep.sh create mode 100755 tests/jax/multi_process_launch_ep.sh create mode 100644 tests/jax/test_multi_process_ep.py create mode 100644 transformer_engine/jax/cpp_extensions/ep.py create mode 100644 transformer_engine/jax/csrc/extensions/ep.cpp create mode 100644 transformer_engine/jax/ep.py diff --git a/examples/jax/ep/ep_moe.py b/examples/jax/ep/ep_moe.py new file mode 100644 index 0000000000..7b3601fb60 --- /dev/null +++ b/examples/jax/ep/ep_moe.py @@ -0,0 +1,393 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""End-to-end MoE example: dispatch -> batched expert linear -> combine, fwd + bwd. + +One process per GPU. Run via run_test_ep.sh. +""" + +import argparse +import sys + +import jax +import jax.numpy as jnp +import numpy as np +from jax.sharding import Mesh, NamedSharding, PartitionSpec + +from transformer_engine.jax.ep import EpLayerConfig, ep_bootstrap, ep_dispatch, ep_combine +from transformer_engine.jax.sharding import MeshResource, global_shard_guard + + +# ── Setup ─────────────────────────────────────────────────────────────────── + + +def _parse_args(): + p = argparse.ArgumentParser(description="TE-JAX EP MoE example (fwd + bwd)") + p.add_argument("--coordinator-address", required=True) + p.add_argument("--process-id", type=int, required=True) + p.add_argument("--num-processes", type=int, required=True) + p.add_argument("--num-tokens", type=int, default=8, help="Per-rank token count.") + p.add_argument("--top-k", type=int, default=2) + p.add_argument("--hidden", type=int, default=32) + p.add_argument("--hidden-out", type=int, default=32) + p.add_argument( + "--num-experts", + type=int, + default=None, + help="Total experts across the EP group. Default: num_processes.", + ) + p.add_argument("--dp-size", type=int, default=None, help="Default: num_procs // ep_size.") + p.add_argument( + "--check", + action="store_true", + default=True, + help="Verify fwd+bwd against a single-rank numpy reference.", + ) + return p.parse_args() + + +def _distributed_init(args): + jax.distributed.initialize( + coordinator_address=args.coordinator_address, + num_processes=args.num_processes, + process_id=args.process_id, + local_device_ids=[args.process_id], + ) + assert ( + jax.local_device_count() == 1 + ), f"EP example requires 1 GPU per process; got {jax.local_device_count()}" + + +def _build_mesh_and_resource(args): + """Pick a (2, 2) mesh by default. Override via --dp-size.""" + n = args.num_processes + if n < 4: + raise ValueError(f"num_processes ({n}) must be >= 4 for NCCL EP") + if args.dp_size is None: + if n != 4: + raise ValueError( + f"default mesh expects exactly 4 ranks (got {n}); pass --dp-size to override" + ) + args.dp_size = 2 + assert n % args.dp_size == 0, f"num_processes={n} not divisible by dp_size={args.dp_size}" + args.ep_size = n // args.dp_size + if args.num_experts is None: + args.num_experts = args.num_processes + assert args.num_experts % args.ep_size == 0 + args.num_local_experts = args.num_experts // args.ep_size + args.recv_capacity_per_rank = args.ep_size * args.num_tokens * args.top_k + + devs = np.asarray(jax.devices()).reshape(args.dp_size, args.ep_size) + mesh = Mesh(devs, ("dp", "ep")) + mr = MeshResource(dp_resource="dp", ep_resource="ep") + return mesh, mr + + +def _make_routing(dp_color, num_tokens, top_k, num_experts, num_local_experts): + """Deterministic routing: topk_idx[t, k] = (dp_color*NLE + t*K + k) % E.""" + topk_idx = np.empty((num_tokens, top_k), dtype=np.int32) + for t in range(num_tokens): + for k in range(top_k): + topk_idx[t, k] = (dp_color * num_local_experts + t * top_k + k) % num_experts + return topk_idx + + +def _make_inputs(args): + """Build 3D ``[B, S, H]`` arrays sharded ``(("dp","ep"), None, None)``. + + B = num_processes (sharded across the compound (dp,ep) axis so each rank + holds one slot); S = args.num_tokens. Global numpy views (rank-0 + reference) are kept 2D for the legacy reference implementation. + """ + T, K, H, H_out = args.num_tokens, args.top_k, args.hidden, args.hidden_out + E = args.num_experts + dp_size = args.dp_size + ep_size = args.ep_size + num_procs = args.num_processes + dp_color = args.process_id // ep_size + + rng_dp = np.random.default_rng(seed=42 + dp_color) + tokens_np = (rng_dp.standard_normal((T, H), dtype=np.float32) * 0.5).astype(np.float32) + topk_idx_np = _make_routing(dp_color, T, K, E, args.num_local_experts) + w_np = np.full((T, K), 1.0 / K, dtype=np.float32) + + tokens_global_np = np.concatenate( + [ + ( + np.random.default_rng(seed=42 + c).standard_normal((T, H), dtype=np.float32) * 0.5 + ).astype(np.float32) + for c in range(dp_size) + ], + axis=0, + ) + topk_idx_global_np = np.concatenate( + [_make_routing(c, T, K, E, args.num_local_experts) for c in range(dp_size)], axis=0 + ) + w_global_np = np.full((dp_size * T, K), 1.0 / K, dtype=np.float32) + + # Same seed on every rank → identical kernel array everywhere. + rng = np.random.default_rng(seed=42) + kernels_np = (rng.standard_normal((E, H, H_out), dtype=np.float32) * (1.0 / np.sqrt(H))).astype( + np.float32 + ) + + # Each rank contributes one [1, T, ...] slab; the global shape is + # [num_procs, T, ...] sharded on the first dim across (dp, ep). + mesh = args.mesh + dpep_spec = NamedSharding(mesh, PartitionSpec(("dp", "ep"), None, None)) + tokens = jax.make_array_from_process_local_data( + dpep_spec, tokens_np[None, :, :].astype(np.float32), (num_procs, T, H) + ).astype(jnp.bfloat16) + topk_idx = jax.make_array_from_process_local_data( + dpep_spec, topk_idx_np[None, :, :], (num_procs, T, K) + ) + topk_w = jax.make_array_from_process_local_data(dpep_spec, w_np[None, :, :], (num_procs, T, K)) + kernels = jnp.asarray(kernels_np, dtype=jnp.bfloat16) + return ( + tokens_global_np, + topk_idx_global_np, + w_global_np, + kernels_np, + tokens, + topk_idx, + topk_w, + kernels, + ) + + +# ── MoE step ──────────────────────────────────────────────────────────────── + + +def _batched_expert_linear(recv_tokens, kernels, num_local_experts, dp_size, ep_size): + """Per-expert linear. ``recv_tokens`` is 3D ``[num_procs, recv_pr, H]`` + (compound (dp,ep) leading); ``kernels`` is 4D ``[ep_size, NLE, H, H_out]``, + broadcast over the dp axis. Output matches ``recv_tokens``' 3D layout + with ``H_out`` in place of ``H``.""" + num_procs, recv_pr, H = recv_tokens.shape + H_out = kernels.shape[-1] + slots_per_expert = recv_pr // num_local_experts + # [num_procs, recv_pr, H] -> [dp, ep, NLE, slots, H] + grouped = recv_tokens.reshape(dp_size, ep_size, num_local_experts, slots_per_expert, H) + # Contract H; batch over (ep, NLE) which are present on both sides. + out = jax.lax.dot_general( + grouped, + kernels.astype(grouped.dtype), + dimension_numbers=(((4,), (2,)), ((1, 2), (0, 1))), + ) + # Output dim order from dot_general: batch dims first, then remaining lhs, rhs. + # batch=(ep,NLE), lhs_remaining=(dp,slots), rhs_remaining=(H_out,) + # → shape [ep, NLE, dp, slots, H_out]. Permute to [dp, ep, NLE, slots, H_out]. + out = jnp.transpose(out, (2, 0, 1, 3, 4)) + return out.reshape(num_procs, recv_pr, H_out) + + +def _moe_step(args, topk_idx, tokens, topk_w, kernels): + """Jit'd MoE step: dispatch -> batched per-expert linear -> combine. + + Inputs are 3D ``[B, S, H]`` with the first dim compound-sharded across + ``("dp","ep")``. Combine returns the same 3D shape. + """ + B = args.num_processes + S = args.num_tokens + NLE = args.num_local_experts + dp_size, ep_size = args.dp_size, args.ep_size + mesh = args.mesh + in_spec = PartitionSpec(("dp", "ep"), None, None) # [B, S, ...] + ep3 = PartitionSpec(("dp", "ep"), None, None) # [num_procs, recv_pr, H] + ep2 = PartitionSpec(("dp", "ep"), None) # [num_procs, recv_pr] + # Kernels are EP-replicated across dp colors; shard only the ep-rank axis. + kernel_spec = PartitionSpec("ep", None, None, None) + + kernels = kernels.reshape(ep_size, NLE, *kernels.shape[1:]) + ep_handle = EpLayerConfig(top_k=args.top_k, dispatch_output_per_expert_alignment=16) + + @jax.jit + def step(topk_idx, tokens, topk_w, local_kernels): + topk_idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(mesh, in_spec)) + tokens = jax.lax.with_sharding_constraint(tokens, NamedSharding(mesh, in_spec)) + topk_w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(mesh, in_spec)) + local_kernels = jax.lax.with_sharding_constraint( + local_kernels, NamedSharding(mesh, kernel_spec) + ) + recv_tokens, recv_topk_w, handle_mem, _tc = ep_dispatch( + ep_handle, topk_idx, tokens, topk_w, args.recv_capacity_per_rank + ) + recv_tokens = jax.lax.with_sharding_constraint(recv_tokens, NamedSharding(mesh, ep3)) + recv_topk_w = jax.lax.with_sharding_constraint(recv_topk_w, NamedSharding(mesh, ep2)) + expert_out = _batched_expert_linear(recv_tokens, local_kernels, NLE, dp_size, ep_size) + expert_out = jax.lax.with_sharding_constraint(expert_out, NamedSharding(mesh, ep3)) + return ep_combine( + ep_handle, + handle_mem, + _tc, + expert_out, + recv_topk_w, + num_local_tokens=(B, S), + out_sharding=(("dp", "ep"), None, None), + ) + + return step(topk_idx, tokens, topk_w, kernels) + + +# ── Reference (numerical check) ───────────────────────────────────────────── + + +def _reference_moe(tokens, topk_idx, topk_w, kernels): + """Single-rank dense MoE reference. tokens [T, H], output [T, H_out].""" + T, K = topk_idx.shape + H_out = kernels.shape[-1] + out = np.zeros((T, H_out), dtype=np.float32) + for t in range(T): + tok = tokens[t].astype(np.float32) + for k in range(K): + e = int(topk_idx[t, k]) + out[t] += float(topk_w[t, k]) * (tok @ kernels[e].astype(np.float32)) + return out + + +def _reference_grad(tokens, topk_idx, topk_w, kernels): + """d/dtokens of 0.5 * sum(ref_out**2) — used by --check to validate bwd.""" + T, K = topk_idx.shape + H = tokens.shape[-1] + ref_out = _reference_moe(tokens, topk_idx, topk_w, kernels) + grad = np.zeros((T, H), dtype=np.float32) + for t in range(T): + mixed = np.zeros_like(kernels[0]) + for k in range(K): + mixed = mixed + float(topk_w[t, k]) * kernels[int(topk_idx[t, k])] + grad[t] = ref_out[t] @ mixed.T + return ref_out, grad + + +# ── Main ──────────────────────────────────────────────────────────────────── + + +def main(): + args = _parse_args() + _distributed_init(args) + + dev = jax.local_devices()[0] + cap = getattr(dev, "compute_capability", None) + if cap is not None: + major, minor = (int(x) for x in str(cap).split(".")) + if major * 10 + minor < 90: + print(f"[ep_moe] SKIPPED: NCCL EP requires SM>=90 (got SM{major}{minor})") + return + + args.mesh, args.mr = _build_mesh_and_resource(args) + + with args.mesh, global_shard_guard(args.mr): + ep_bootstrap( + world_size=args.num_processes, + rank=args.process_id, + ep_size=args.ep_size, + num_experts=args.num_experts, + max_tokens_per_rank=args.num_tokens, + recv_capacity_per_rank=args.recv_capacity_per_rank, + hidden_dim=args.hidden, + ) + + ( + tokens_global_np, + topk_idx_global_np, + w_global_np, + kernels_np, + tokens, + topk_idx, + topk_w, + kernels, + ) = _make_inputs(args) + + def loss_fn(toks, idx, w, kern): + out = _moe_step(args, idx, toks, w, kern) + return 0.5 * (out.astype(jnp.float32) ** 2).sum(), out + + (loss, out_fwd), grad_tokens = jax.jit(jax.value_and_grad(loss_fn, has_aux=True))( + tokens, topk_idx, topk_w, kernels + ) + grad_tokens.block_until_ready() + out_fwd.block_until_ready() + + if args.process_id == 0: + print( + f"[ep_moe] loss={float(loss):.4f} grad_tokens.shape={grad_tokens.shape} " + f"dp={args.dp_size} ep={args.ep_size} " + f"num_experts={args.num_experts} recv_pr={args.recv_capacity_per_rank}" + ) + + if args.check: + + def _norm(spec, ndim): + return tuple(spec) + (None,) * (ndim - len(spec)) + + # JAX may collapse a size-1 mesh axis: when dp_size==1 the spec can + # appear as ``(("dp","ep"),...)`` or ``("ep",...)``. Accept both. + if args.dp_size > 1: + acceptable_specs = ((("dp", "ep"), None, None),) + else: + acceptable_specs = ((("dp", "ep"), None, None), ("ep", None, None)) + assert ( + _norm(out_fwd.sharding.spec, out_fwd.ndim) in acceptable_specs + ), f"out_fwd.sharding.spec={out_fwd.sharding.spec} (expected one of {acceptable_specs})" + assert _norm(grad_tokens.sharding.spec, grad_tokens.ndim) in acceptable_specs, ( + f"grad_tokens.sharding.spec={grad_tokens.sharding.spec}" + f" (expected one of {acceptable_specs})" + ) + + replicated = NamedSharding(args.mesh, jax.sharding.PartitionSpec()) + out_global = jax.jit(lambda x: jax.lax.with_sharding_constraint(x, replicated))(out_fwd) + grad_global = jax.jit(lambda x: jax.lax.with_sharding_constraint(x, replicated))( + grad_tokens + ) + out_global.block_until_ready() + grad_global.block_until_ready() + + ref_out, ref_grad = _reference_grad( + tokens_global_np, topk_idx_global_np, w_global_np, kernels_np + ) + ref_loss = 0.5 * float((ref_out.astype(np.float32) ** 2).sum()) + # 3D global ``[num_procs, S, H]`` with num_procs = dp * ep. Each EP + # column in a DP color sees identical inputs (and produces identical + # outputs), so collapse the ep dim to one replica before flattening + # to 2D against the dp-only reference. + dp_size, ep_size = args.dp_size, args.ep_size + global_out = ( + np.asarray(out_global.addressable_shards[0].data.astype(jnp.float32)) + .reshape(dp_size, ep_size, -1, ref_out.shape[-1])[:, 0] + .reshape(-1, ref_out.shape[-1]) + ) + global_grad = ( + np.asarray(grad_global.addressable_shards[0].data.astype(jnp.float32)) + .reshape(dp_size, ep_size, -1, ref_grad.shape[-1])[:, 0] + .reshape(-1, ref_grad.shape[-1]) + ) + if args.process_id == 0: + fwd_diff = np.abs(global_out - ref_out) + grad_diff = np.abs(global_grad - ref_grad) + print( + f"[ep_moe] DEBUG loss={float(loss):.4f} ref_loss(global)={ref_loss:.4f} " + f"ratio={float(loss) / max(ref_loss, 1e-9):.4f} (expected ~1.0)" + ) + print(f"[ep_moe] DEBUG fwd max-abs-diff per row: {fwd_diff.max(axis=1)}") + print(f"[ep_moe] DEBUG grad max-abs-diff per row: {grad_diff.max(axis=1)}") + np.testing.assert_allclose( + global_out, + ref_out, + rtol=5e-2, + atol=5e-2, + err_msg=f"rank {args.process_id}: fwd mismatch", + ) + np.testing.assert_allclose( + global_grad, + ref_grad, + rtol=5e-2, + atol=5e-2, + err_msg=f"rank {args.process_id}: bwd mismatch", + ) + if args.process_id == 0: + print(f"[ep_moe] --check PASSED (ref_out.sum()={float(ref_out.sum()):.4f})") + + +if __name__ == "__main__": + main() + sys.exit(0) diff --git a/examples/jax/ep/run_test_ep.sh b/examples/jax/ep/run_test_ep.sh new file mode 100755 index 0000000000..55b958f146 --- /dev/null +++ b/examples/jax/ep/run_test_ep.sh @@ -0,0 +1,85 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +#!/bin/bash + +NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)} + +if [ "${NUM_GPUS}" -lt 4 ]; then + echo "NCCL EP requires at least 4 GPUs (found ${NUM_GPUS}); SKIPPING." + exit 0 +fi +# Default mesh is (2, 2); use exactly 4 ranks even on larger boxes. +NUM_GPUS="${NVTE_EP_NUM_RANKS:-4}" + +: ${TE_PATH:=/opt/transformerengine} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" + +# NCCL EP requires NVLink P2P among ranks on the node. +echo "*** Checking NVLINK support ***" +NVLINK_OUTPUT=$(nvidia-smi nvlink --status 2>&1) +NVLINK_EXIT_CODE=$? +if [ $NVLINK_EXIT_CODE -ne 0 ] || [[ "$NVLINK_OUTPUT" == *"not supported"* ]] \ + || [[ "$NVLINK_OUTPUT" == *"No devices"* ]] || [ -z "$NVLINK_OUTPUT" ]; then + echo "NVLINK is not supported on this platform — EP example requires NVLINK; SKIPPING" + exit 0 +fi +echo "NVLINK support detected" + +SCRIPT="$TE_PATH/examples/jax/ep/ep_moe.py" +export PYTHONPATH="${TE_PATH}${PYTHONPATH:+:${PYTHONPATH}}" +COORD="${COORD:-127.0.0.1:12345}" +TEST_TIMEOUT_S="${TEST_TIMEOUT_S:-300}" + +XLA_BASE_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true + --xla_gpu_graph_min_graph_size=1" +export XLA_FLAGS="${XLA_BASE_FLAGS}" + +# Stage NCCL EP JIT cubins on tmpfs to keep build/iteration fast. +: ${NCCL_EP_JIT_CACHE_DIR:="${TMPDIR:-/tmp}/nccl_ep_jit_cache_$(id -u)"} +export NCCL_EP_JIT_CACHE_DIR +mkdir -p "$NCCL_EP_JIT_CACHE_DIR" + +echo +echo "*** Executing ep_moe.py across $NUM_GPUS GPUs ***" + +PIDS=() +cleanup() { + for pid in "${PIDS[@]}"; do + kill -0 "$pid" 2>/dev/null && kill -KILL "$pid" 2>/dev/null || true + done +} +trap cleanup EXIT INT TERM + +EXTRA_ARGS=${EXTRA_ARGS:-"--check"} + +for ((i=1; i "stdout_rank_${i}.txt" 2>&1 & + PIDS+=($!) +done +timeout --foreground --signal=KILL "${TEST_TIMEOUT_S}" \ + python -u "$SCRIPT" \ + --coordinator-address "$COORD" --process-id "0" --num-processes "$NUM_GPUS" \ + $EXTRA_ARGS 2>&1 | tee stdout_rank_0.txt +wait + +HAS_FAILURE=0 +if grep -qE "FAILED|Traceback|ERROR" stdout_rank_0.txt; then + echo "... ep_moe FAILED" + HAS_FAILURE=1 +elif ! grep -qE "\[ep_moe\]" stdout_rank_0.txt; then + echo "... ep_moe INVALID (rank 0 produced no summary line)" + for ((i=1; i/dev/null + done + HAS_FAILURE=1 +else + echo "... ep_moe PASSED" +fi +rm -f stdout_rank_*.txt +exit $HAS_FAILURE diff --git a/tests/jax/multi_process_launch_ep.sh b/tests/jax/multi_process_launch_ep.sh new file mode 100755 index 0000000000..a37ffc2952 --- /dev/null +++ b/tests/jax/multi_process_launch_ep.sh @@ -0,0 +1,67 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +#!/bin/bash + +SCRIPT_NAMES="${SCRIPT_NAMES:-test_multi_process_ep.py}" +TEST_TIMEOUT_S="${TEST_TIMEOUT_S:-180}" + + +XLA_BASE_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true + --xla_gpu_graph_min_graph_size=1" + +export XLA_FLAGS="${XLA_BASE_FLAGS}" + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +TE_REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)" +export PYTHONPATH="${TE_REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +NUM_RUNS=$(nvidia-smi -L | wc -l) + +if [ "${NUM_RUNS}" -lt 4 ]; then + echo "NCCL EP requires at least 4 GPUs (found ${NUM_RUNS}); SKIPPING." + exit 0 +fi +# Default test mesh is (2, 2); use exactly 4 ranks even on larger boxes. +NUM_RUNS="${NVTE_TEST_EP_NUM_RANKS:-4}" + +OVERALL_RET=0 + +for SCRIPT_NAME in $SCRIPT_NAMES; do + echo "=== Running ${SCRIPT_NAME} ===" + for ((i=1; i stdout_rank_${i}.txt 2>&1 & + done + + timeout --foreground --signal=KILL "${TEST_TIMEOUT_S}" \ + python $SCRIPT_NAME 127.0.0.1:12345 0 $NUM_RUNS 2>&1 | tee stdout_multi_process.txt + + wait + + RET=0 + if grep -q "FAILED" stdout_multi_process.txt; then + RET=1 + fi + # Treat missing test summary on rank 0 as hang/crash rather than silent success. + if ! grep -qE "Ran [0-9]+ test|^OK$|PASSED" stdout_multi_process.txt; then + echo "ERROR: rank 0 produced no test summary for ${SCRIPT_NAME} — likely a hang or early crash." + echo " NCCL EP requires NVLS multicast; check NCCL_DEBUG=INFO output." + RET=1 + fi + if [ "$RET" -ne 0 ]; then + for ((i=1; i/dev/null || echo "(no log)" + done + fi + + rm -f stdout_multi_process.txt stdout_rank_*.txt + if [ "$RET" -ne 0 ]; then + OVERALL_RET=1 + fi +done + +exit "$OVERALL_RET" diff --git a/tests/jax/test_multi_process_ep.py b/tests/jax/test_multi_process_ep.py new file mode 100644 index 0000000000..edfac0f82c --- /dev/null +++ b/tests/jax/test_multi_process_ep.py @@ -0,0 +1,748 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Multi-process unit tests for the TE-JAX Expert Parallelism (EP) primitives. + +Default mesh is (dp=2, ep=2); override via ``NVTE_TEST_EP_MESH=DPxEP``. +Coverage: + + - ``ep_bootstrap`` rejects when ``ep_resource`` is unset. + - Individual primitives (``ep_prepare``, ``ep_dispatch_fwd``, ``ep_combine_fwd``) + round-trip an identity expert → output ≈ tokens. + - ``ep_dispatch`` custom_vjp: ``grad_tokens ≈ TOP_K · tokens`` (closed form). + - ``ep_combine`` custom_vjp: ``max|grad_eo| ≈ eo_const / TOP_K`` (closed form). + - ``ep_dispatch`` custom_vjp: exact per-(t, k) ``grad_topk_weights`` under + skewed upstream gradients (no k-axis averaging). + - HLO reshard guard: compile-only, no XLA collectives outside the EP FFI. + +Launch via tests/jax/multi_process_launch_ep.sh (one process per GPU). +""" + +import os +import sys +import unittest + +import jax +import jax.experimental.multihost_utils as jmu +import jax.numpy as jnp +import numpy as np +from jax.sharding import Mesh, NamedSharding, PartitionSpec + +from transformer_engine.jax.sharding import MeshResource, global_shard_guard +from transformer_engine.jax.ep import EpLayerConfig, ep_bootstrap, ep_dispatch, ep_combine +from transformer_engine.jax.cpp_extensions.ep import ( + ep_prepare, + ep_dispatch_fwd, + ep_combine_fwd, +) + + +# ── Test config ───────────────────────────────────────────────────────────── +# NCCL EP requires NUM_LOCAL_EXPERTS*ep % 4 == 0 (TMA alignment in +# device/hybridep_adapter.cu:511). With NUM_LOCAL_EXPERTS=2, ep must be even. + +NUM_LOCAL_EXPERTS = 2 # per-rank → num_experts = NLE * EP +HIDDEN_DIM = 32 +TOP_K = 2 +TOKENS_PER_DP_SHARD = 4 # per device along dp + + +def _factor_dp_ep(num_procs): + """Default to a (2, 2) mesh. Override via ``NVTE_TEST_EP_MESH=DPxEP``. + + NUM_LOCAL_EXPERTS*ep must be a multiple of 4 for NCCL EP's TMA alignment. + """ + override = os.environ.get("NVTE_TEST_EP_MESH") + if override: + dp_str, ep_str = override.lower().split("x") + dp, ep = int(dp_str), int(ep_str) + if dp * ep != num_procs: + raise ValueError( + f"NVTE_TEST_EP_MESH={override!r} does not multiply to num_procs={num_procs}" + ) + if (NUM_LOCAL_EXPERTS * ep) % 4 != 0: + raise ValueError( + f"NUM_LOCAL_EXPERTS*ep ({NUM_LOCAL_EXPERTS}*{ep}) must be a multiple of 4 " + "for NCCL EP TMA alignment" + ) + return dp, ep + if num_procs != 4: + raise ValueError( + f"default mesh expects exactly 4 ranks (got {num_procs}); set " + "NVTE_TEST_EP_MESH=DPxEP to override" + ) + return 2, 2 + + +def _build_mesh(dp, ep): + devs = np.asarray(jax.devices()).reshape(dp, ep) + return Mesh(devs, ("dp", "ep")) + + +def _local_device_sm(): + """Return SM major*10+minor of the first local CUDA device, or None.""" + try: + dev = jax.local_devices()[0] + cap = getattr(dev, "compute_capability", None) + if cap is None: + return None + major, minor = (int(x) for x in str(cap).split(".")) + return major * 10 + minor + except Exception: + return None + + +class TestEP(unittest.TestCase): + @classmethod + def setUpClass(cls): + sm = _local_device_sm() + if sm is not None and sm < 90: + raise unittest.SkipTest(f"NCCL EP requires SM>=90 (got SM{sm})") + cls.num_procs = jax.process_count() + cls.rank = jax.process_index() + cls.dp, cls.ep = _factor_dp_ep(cls.num_procs) + cls.num_experts = NUM_LOCAL_EXPERTS * cls.ep + # recv_capacity is per-DP-group (NCCL EP comms isolated per DP color). + # Under PartitionSpec(("dp","ep"), None) each EP group sees + # T_global/dp = TOKENS_PER_DP_SHARD tokens total; pad for routing skew. + T_per_ep_group = TOKENS_PER_DP_SHARD + active_experts = min(cls.num_experts, T_per_ep_group * TOP_K) + overconc = cls.num_experts // active_experts + cls.recv_capacity_per_rank = ( + NUM_LOCAL_EXPERTS * max(T_per_ep_group * TOP_K, 16) * overconc * 2 + ) + cls.mesh = _build_mesh(cls.dp, cls.ep) + cls.mr = MeshResource(dp_resource="dp", ep_resource="ep") + with cls.mesh, global_shard_guard(cls.mr): + ep_bootstrap( + world_size=cls.num_procs, + rank=cls.rank, + ep_size=cls.ep, + num_experts=cls.num_experts, + max_tokens_per_rank=TOKENS_PER_DP_SHARD, + recv_capacity_per_rank=cls.recv_capacity_per_rank, + hidden_dim=HIDDEN_DIM, + ) + # One layer config shared by all single-layer tests below; non-zero + # alignment exercises dispatch_output_per_expert_alignment end-to-end. + cls.hk = EpLayerConfig(top_k=TOP_K, dispatch_output_per_expert_alignment=16) + + # ── Bootstrap precondition ──────────────────────────────────────────── + + def test_bootstrap_rejects_missing_ep_axis(self): + """ep_bootstrap raises when MeshResource has no ep_resource.""" + with self.mesh, global_shard_guard(MeshResource()): + with self.assertRaisesRegex(ValueError, "ep_resource"): + ep_bootstrap( + world_size=self.num_procs, + rank=self.rank, + ep_size=self.ep, + num_experts=self.num_experts, + max_tokens_per_rank=TOKENS_PER_DP_SHARD, + recv_capacity_per_rank=self.recv_capacity_per_rank, + hidden_dim=HIDDEN_DIM, + ) + + # ── Helpers ─────────────────────────────────────────────────────────── + + def _make_identity_inputs(self, nonuniform=False): + """Identity routing + uniform weights — combined output ≈ tokens. + + ``nonuniform=False``: ``(t*TOP_K+k) % E`` (round-robin, near-balanced). + ``nonuniform=True``: ``top1=0`` for every token, ``top2=1+(t%(E-1))`` — + expert 0 absorbs the entire batch while the others split the second + slot evenly. Exercises a skewed per-expert load. + """ + T_global = TOKENS_PER_DP_SHARD * self.dp + E = self.num_experts + topk_idx = np.empty((T_global, TOP_K), dtype=np.int32) + if nonuniform: + assert TOP_K == 2, "non-uniform pattern assumes top_k=2" + for t in range(T_global): + topk_idx[t, 0] = 0 + topk_idx[t, 1] = 1 + (t % (E - 1)) + else: + for t in range(T_global): + for k in range(TOP_K): + topk_idx[t, k] = (t * TOP_K + k) % E + topk_idx = jnp.asarray(topk_idx) + topk_weights = jnp.full((T_global, TOP_K), 1.0 / TOP_K, dtype=jnp.float32) + tokens = jnp.asarray( + np.linspace(0.1, 0.9, T_global * HIDDEN_DIM, dtype=np.float32).reshape( + T_global, HIDDEN_DIM + ), + dtype=jnp.bfloat16, + ) + return T_global, topk_idx, tokens, topk_weights + + def _make_random_inputs(self, seed=42, nonuniform=True): + """Random tokens + skewed top-2 routing (top1=0 always; top2 varies). + + Non-uniform load by default — guarantees expert 0 receives every token + while the rest of the experts split the second slot. Use + ``nonuniform=False`` for a balanced (t%E, (t+1)%E) pattern. + """ + T_dp = TOKENS_PER_DP_SHARD * self.dp + E = self.num_experts + rng = np.random.default_rng(seed=seed) + tokens = jnp.asarray( + rng.standard_normal((T_dp, HIDDEN_DIM), dtype=np.float32) * 0.5, + dtype=jnp.bfloat16, + ) + topk_idx_np = np.empty((T_dp, TOP_K), dtype=np.int32) + if nonuniform: + assert TOP_K == 2, "non-uniform pattern assumes top_k=2" + for t in range(T_dp): + topk_idx_np[t, 0] = 0 + topk_idx_np[t, 1] = 1 + (t % (E - 1)) + else: + for t in range(T_dp): + a, b = t % E, (t + 1) % E + topk_idx_np[t, 0], topk_idx_np[t, 1] = (a, b) if a < b else (b, a) + topk_idx = jnp.asarray(topk_idx_np) + topk_weights = jnp.asarray(np.full((T_dp, TOP_K), 1.0 / TOP_K, dtype=np.float32)) + return T_dp, tokens, topk_idx, topk_weights + + # ── Individual primitives (cpp_extensions level) ────────────────────── + + def test_two_handle_mems_no_aliasing(self): + """Two ``ep_prepare`` calls in one jit must produce distinct handle_mem + buffers; the pointer-keyed C++ cache must not alias HandleEntries + across distinct logical layers.""" + _T, topk_idx, _tokens, _w = self._make_identity_inputs() + ka, kb = ( + EpLayerConfig(top_k=TOP_K, dispatch_output_per_expert_alignment=16), + EpLayerConfig(top_k=TOP_K, dispatch_output_per_expert_alignment=16), + ) + dp_spec = PartitionSpec(("dp", "ep"), None) + with self.mesh, global_shard_guard(self.mr): + idx_s = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) + + @jax.jit + def run(idx): + _tc_a, ha = ep_prepare(ka, idx) + _tc_b, hb = ep_prepare(kb, idx) + return ha, hb + + hm_a, hm_b = run(idx_s) + hm_a.block_until_ready() + hm_b.block_until_ready() + self.assertNotEqual(hm_a.unsafe_buffer_pointer(), hm_b.unsafe_buffer_pointer()) + + def test_two_layer_dispatch_no_handle_aliasing(self): + """Two ep_dispatch calls in one jit with distinct ``EpLayerConfig``s must + not clobber each other's routing state. Different inputs per layer with + identity routing + uniform weights => both recv buffers must independently + identity-round-trip via ep_combine.""" + T_global, topk_idx, tokens, topk_w = self._make_identity_inputs(nonuniform=False) + tokens_b = (tokens.astype(jnp.float32) * -1.0 + 0.25).astype(tokens.dtype) + ka, kb = ( + EpLayerConfig(top_k=TOP_K, dispatch_output_per_expert_alignment=16), + EpLayerConfig(top_k=TOP_K, dispatch_output_per_expert_alignment=16), + ) + dp_spec = PartitionSpec(("dp", "ep"), None) + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + ep_spec_2d = PartitionSpec(("dp", "ep"), None) + with self.mesh, global_shard_guard(self.mr): + idx_s = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) + ta = jax.lax.with_sharding_constraint(tokens, NamedSharding(self.mesh, dp_spec)) + tb = jax.lax.with_sharding_constraint(tokens_b, NamedSharding(self.mesh, dp_spec)) + w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec)) + + def one_layer(hk, idx, toks, w_): + recv_t, recv_w, hm, tc = ep_dispatch( + hk, idx, toks, w_, self.recv_capacity_per_rank + ) + recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(self.mesh, ep_spec_3d)) + recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(self.mesh, ep_spec_2d)) + return ep_combine( + hk, hm, tc, recv_t, recv_w, T_global, out_sharding=(("dp", "ep"), None) + ) + + @jax.jit + def run(idx, ta_, tb_, w_): + return one_layer(ka, idx, ta_, w_), one_layer(kb, idx, tb_, w_) + + out_a, out_b = run(idx_s, ta, tb, w) + out_a.block_until_ready() + out_b.block_until_ready() + out_a_g = jmu.process_allgather(out_a, tiled=True) + out_b_g = jmu.process_allgather(out_b, tiled=True) + + if self.rank == 0: + np.testing.assert_allclose( + np.asarray(out_a_g.astype(jnp.float32)), + np.asarray(tokens.astype(jnp.float32)), + atol=5e-2, rtol=5e-2, + ) + np.testing.assert_allclose( + np.asarray(out_b_g.astype(jnp.float32)), + np.asarray(tokens_b.astype(jnp.float32)), + atol=5e-2, rtol=5e-2, + ) + + def test_primitive_prepare(self): + """ep_prepare returns token_counts and handle_mem of the expected shapes.""" + T_global, topk_idx, _tokens, _w = self._make_identity_inputs() + del T_global + dp_spec = PartitionSpec(("dp", "ep"), None) + with self.mesh, global_shard_guard(self.mr): + idx_s = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) + + @jax.jit + def run(idx): + tc, hm = ep_prepare(self.hk, idx) + return tc, hm + + tc, hm = run(idx_s) + tc.block_until_ready() + self.assertEqual(tc.shape, (self.dp * self.ep, NUM_LOCAL_EXPERTS)) + self.assertEqual(hm.shape[0], self.dp * self.ep) + self.assertGreater(hm.shape[1], 0) + + def _run_identity_round_trip(self, nonuniform): + T_global, topk_idx, tokens, topk_w = self._make_identity_inputs(nonuniform=nonuniform) + dp_spec = PartitionSpec(("dp", "ep"), None) + with self.mesh, global_shard_guard(self.mr): + idx_s = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) + tok_s = jax.lax.with_sharding_constraint(tokens, NamedSharding(self.mesh, dp_spec)) + w_s = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec)) + + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + ep_spec_2d = PartitionSpec(("dp", "ep"), None) + + @jax.jit + def run(idx, toks, w): + _tc, hm = ep_prepare(self.hk, idx) + recv_t, recv_w = ep_dispatch_fwd( + self.hk, hm, idx, toks, w, self.recv_capacity_per_rank + ) + recv_t = jax.lax.with_sharding_constraint( + recv_t, NamedSharding(self.mesh, ep_spec_3d) + ) + recv_w = jax.lax.with_sharding_constraint( + recv_w, NamedSharding(self.mesh, ep_spec_2d) + ) + # Apply the weighted hadamard inline (combine FFI is unweighted). + mask = (recv_w != 0).astype(jnp.float32)[..., None] + weighted = (recv_t.astype(jnp.float32) * recv_w[..., None] * mask).astype( + recv_t.dtype + ) + weighted = jax.lax.with_sharding_constraint( + weighted, NamedSharding(self.mesh, ep_spec_3d) + ) + out = ep_combine_fwd( + self.hk, hm, weighted, T_global, + out_partition_spec=(("dp", "ep"), None), + ) + return jax.lax.with_sharding_constraint(out, NamedSharding(self.mesh, dp_spec)) + + out = run(idx_s, tok_s, w_s) + out.block_until_ready() + # Allgather so the rank-0 numpy comparison sees the full global tensor. + out_global = jmu.process_allgather(out, tiled=True) + + # Identity expert + uniform weights → out ≈ tokens (rank-0 check). + if self.rank == 0: + np.testing.assert_allclose( + np.asarray(out_global.astype(jnp.float32)), + np.asarray(tokens.astype(jnp.float32)), + atol=5e-2, + rtol=5e-2, + ) + + def test_primitive_dispatch_combine_identity_uniform(self): + """Round-robin routing → identity round-trip via the primitive layer.""" + self._run_identity_round_trip(nonuniform=False) + + def test_primitive_dispatch_combine_identity_nonuniform(self): + """Skewed routing (top1=0 always) → identity round-trip via the primitive layer.""" + self._run_identity_round_trip(nonuniform=True) + + def test_primitive_dispatch_combine_identity_bwd_uniform(self): + """Bwd through identity round-trip: ∇(0.5 ||out||²) w.r.t. tokens ≈ tokens. + + Identity routing + uniform top-k weights ⇒ dispatch∘combine is the + identity, so loss = 0.5||tokens||² and ∇_tokens loss = tokens. + """ + T_global, topk_idx, tokens, topk_w = self._make_identity_inputs(nonuniform=False) + dp_spec = PartitionSpec(("dp", "ep"), None) + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + ep_spec_2d = PartitionSpec(("dp", "ep"), None) + + with self.mesh, global_shard_guard(self.mr): + + def loss_fn(toks): + toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec)) + idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) + w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec)) + recv_t, recv_w, hm, tc = ep_dispatch(self.hk, idx, toks, w, self.recv_capacity_per_rank) + recv_t = jax.lax.with_sharding_constraint( + recv_t, NamedSharding(self.mesh, ep_spec_3d) + ) + recv_w = jax.lax.with_sharding_constraint( + recv_w, NamedSharding(self.mesh, ep_spec_2d) + ) + out = ep_combine( + self.hk, hm, tc, recv_t, recv_w, T_global, out_sharding=(("dp", "ep"), None) + ) + return 0.5 * (out.astype(jnp.float32) ** 2).sum() + + grad = jax.jit(jax.grad(loss_fn))(tokens) + grad.block_until_ready() + grad_global = jmu.process_allgather(grad, tiled=True) + + if self.rank == 0: + np.testing.assert_allclose( + np.asarray(grad_global.astype(jnp.float32)), + np.asarray(tokens.astype(jnp.float32)), + atol=5e-2, + rtol=5e-2, + ) + + def test_dispatch_combine_3d_input_output(self): + """3D input ``[B, S, H]`` sharded on the first dim only — + ``(("dp","ep"), None, None)`` here — dispatch accepts the rank-3 shape + and combine returns a matching 3D ``[B, S, H]`` output. End-to-end + round trip recovers the original tokens under identity routing + + uniform top-k weights.""" + T_global, topk_idx, tokens, topk_w = self._make_identity_inputs(nonuniform=False) + # B is sharded across all (dp*ep) ranks; S held in one piece per rank. + B, S, H = T_global, 1, tokens.shape[-1] + tokens_3d = tokens.reshape(B, S, H) + topk_idx_3d = topk_idx.reshape(B, S, -1) + topk_w_3d = topk_w.reshape(B, S, -1) + spec_3d = PartitionSpec(("dp", "ep"), None, None) + out_spec_3d = (("dp", "ep"), None, None) + with self.mesh, global_shard_guard(self.mr): + idx_s = jax.lax.with_sharding_constraint(topk_idx_3d, NamedSharding(self.mesh, spec_3d)) + tok_s = jax.lax.with_sharding_constraint(tokens_3d, NamedSharding(self.mesh, spec_3d)) + w_s = jax.lax.with_sharding_constraint(topk_w_3d, NamedSharding(self.mesh, spec_3d)) + + ep_t = PartitionSpec(("dp", "ep"), None, None) + ep_w = PartitionSpec(("dp", "ep"), None) + + @jax.jit + def run(idx, toks, w): + recv_t, recv_w, hm, _tc = ep_dispatch(self.hk, idx, toks, w, self.recv_capacity_per_rank) + recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(self.mesh, ep_t)) + recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(self.mesh, ep_w)) + out = ep_combine( + self.hk, + hm, + _tc, + recv_t, + recv_w, + num_local_tokens=(B, S), + out_sharding=out_spec_3d, + ) + return out + + out = run(idx_s, tok_s, w_s) + out.block_until_ready() + out_global = jmu.process_allgather(out, tiled=True) + + if self.rank == 0: + self.assertEqual(out_global.shape, (B, S, H)) + np.testing.assert_allclose( + np.asarray(out_global.astype(jnp.float32)), + np.asarray(tokens_3d.astype(jnp.float32)), + atol=5e-2, + rtol=5e-2, + ) + + def test_dispatch_combine_dp_only_first_dim(self): + """Input sharded ``("dp", None)`` (no ep on leading) — dispatch must + accept it. JAX SPMD slices the missing ep axis locally so the kernel + still sees ``T/(dp*ep)`` tokens per rank.""" + T_global, topk_idx, tokens, topk_w = self._make_identity_inputs(nonuniform=False) + dp_only = PartitionSpec("dp", None) + with self.mesh, global_shard_guard(self.mr): + idx_s = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_only)) + tok_s = jax.lax.with_sharding_constraint(tokens, NamedSharding(self.mesh, dp_only)) + w_s = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_only)) + + ep_t = PartitionSpec(("dp", "ep"), None, None) + ep_w = PartitionSpec(("dp", "ep"), None) + + @jax.jit + def run(idx, toks, w): + recv_t, recv_w, hm, _tc = ep_dispatch(self.hk, idx, toks, w, self.recv_capacity_per_rank) + recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(self.mesh, ep_t)) + recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(self.mesh, ep_w)) + out = ep_combine( + self.hk, + hm, + _tc, + recv_t, + recv_w, + num_local_tokens=T_global, + out_sharding=(("dp", "ep"), None), + ) + return out + + out = run(idx_s, tok_s, w_s) + out.block_until_ready() + out_global = jmu.process_allgather(out, tiled=True) + + if self.rank == 0: + np.testing.assert_allclose( + np.asarray(out_global.astype(jnp.float32)), + np.asarray(tokens.astype(jnp.float32)), + atol=5e-2, + rtol=5e-2, + ) + + # ── Custom-VJP tests ───────────────────────────────────────────────── + + def test_dispatch_vjp_fwd_bwd(self): + """ep_dispatch fwd + jax.grad w.r.t. tokens. + + Identity routing + loss = 0.5||recv_tokens||² ⇒ each token appears + TOP_K times in recv_tokens (all routes fit recv_capacity), so + grad_tokens = TOP_K * tokens (closed form). + """ + T_global, topk_idx, tokens, topk_w = self._make_identity_inputs() + del T_global + dp_spec = PartitionSpec(("dp", "ep"), None) + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + + with self.mesh, global_shard_guard(self.mr): + + def loss_fn(toks): + toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec)) + idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) + w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec)) + recv_tokens, _recv_w, _hm, _tc = ep_dispatch( + self.hk, idx, toks, w, self.recv_capacity_per_rank + ) + recv_tokens = jax.lax.with_sharding_constraint( + recv_tokens, NamedSharding(self.mesh, ep_spec_3d) + ) + return 0.5 * (recv_tokens.astype(jnp.float32) ** 2).sum() + + loss, grad_tokens = jax.jit(jax.value_and_grad(loss_fn))(tokens) + grad_tokens.block_until_ready() + grad_global = jmu.process_allgather(grad_tokens, tiled=True) + + self.assertTrue(np.isfinite(float(loss))) + self.assertEqual(grad_tokens.shape, tokens.shape) + if self.rank == 0: + np.testing.assert_allclose( + np.asarray(grad_global.astype(jnp.float32)), + np.asarray(tokens.astype(jnp.float32)) * float(TOP_K), + atol=5e-2, + rtol=5e-2, + ) + + def test_combine_vjp_fwd_bwd(self): + """ep_combine fwd + jax.grad w.r.t. expert_out. + + Identity routing + constant eo=c + uniform topk_w ⇒ combined[t] = c + (sum_k topk_w = 1) and grad_eo[e, s, h] = recv_w[e, s] * c at filled + slots — so max|grad_eo| ≈ c / TOP_K. + """ + T_global, topk_idx, tokens, topk_w = self._make_identity_inputs() + eo_const = 0.5 + expert_out = jnp.full( + (self.dp * self.ep, self.recv_capacity_per_rank, HIDDEN_DIM), + eo_const, + dtype=jnp.bfloat16, + ) + dp_spec = PartitionSpec(("dp", "ep"), None) + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + + with self.mesh, global_shard_guard(self.mr): + + def loss_fn(eo): + eo = jax.lax.with_sharding_constraint(eo, NamedSharding(self.mesh, ep_spec_3d)) + toks = jax.lax.with_sharding_constraint(tokens, NamedSharding(self.mesh, dp_spec)) + idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) + w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec)) + _recv_tokens, recv_w, hm, tc = ep_dispatch( + self.hk, idx, toks, w, self.recv_capacity_per_rank + ) + recv_w = jax.lax.with_sharding_constraint( + recv_w, NamedSharding(self.mesh, PartitionSpec(("dp", "ep"), None)) + ) + combined = ep_combine(self.hk, hm, tc, eo, recv_w, T_global) + # Pin combined to dp-sharded so autodiff transpose feeds + # ep_combine_bwd a per-shard cotangent. + combined = jax.lax.with_sharding_constraint( + combined, NamedSharding(self.mesh, dp_spec) + ) + return 0.5 * (combined.astype(jnp.float32) ** 2).sum() + + loss, grad_eo = jax.jit(jax.value_and_grad(loss_fn))(expert_out) + grad_eo.block_until_ready() + + self.assertTrue(np.isfinite(float(loss))) + self.assertEqual(grad_eo.shape, expert_out.shape) + for shard in grad_eo.addressable_shards: + arr = np.asarray(shard.data.astype(jnp.float32)) + self.assertTrue(np.all(np.isfinite(arr))) + self.assertGreater(arr.max(), 0.0, "grad_eo has no positive entry on filled slots") + np.testing.assert_allclose( + arr.max(), + eo_const / float(TOP_K), + atol=5e-2, + rtol=5e-2, + ) + + def test_dispatch_bwd_exact_per_k_topk_weights(self): + """Distinct per-(t, k) upstream grads ⇒ grad[t, 0] != grad[t, 1] for all t. + + Guards against a regression where the bwd would average across the k + axis (per-token mean instead of per-slot exact recovery). + """ + T_dp, tokens, topk_idx, topk_w = self._make_random_inputs() + dp_spec = PartitionSpec(("dp", "ep"), None) + + with self.mesh, global_shard_guard(self.mr): + + def loss_fn(idx_in, tok_in, w_in): + idx_in = jax.lax.with_sharding_constraint(idx_in, NamedSharding(self.mesh, dp_spec)) + tok_in = jax.lax.with_sharding_constraint(tok_in, NamedSharding(self.mesh, dp_spec)) + w_in = jax.lax.with_sharding_constraint(w_in, NamedSharding(self.mesh, dp_spec)) + _recv_t, recv_w, _h, _tc = ep_dispatch( + self.hk, idx_in, tok_in, w_in, self.recv_capacity_per_rank + ) + # Per-slot index scale ⇒ each slot's contribution differs. + scale = jnp.asarray( + np.arange(recv_w.size, dtype=np.float32).reshape(recv_w.shape) + 1.0 + ) + return jnp.sum(recv_w * scale) + + grad_topk_w = jax.jit(jax.grad(loss_fn, argnums=2))(topk_idx, tokens, topk_w) + grad_topk_w.block_until_ready() + grad_global = jmu.process_allgather(grad_topk_w, tiled=True) + + if self.rank == 0: + grad_np = np.asarray(grad_global).astype(np.float32) + mismatch = sum(int(abs(grad_np[t, 0] - grad_np[t, 1]) < 1e-6) for t in range(T_dp)) + self.assertEqual( + mismatch, + 0, + f"Expected grad[t, 0] != grad[t, 1] for all {T_dp} tokens under skewed " + f"upstream scaling; got {mismatch} tokens with grad[t, 0] == grad[t, 1].", + ) + + # ── HLO reshard guard ──────────────────────────────────────────────── + # Compile-only: assert XLA inserts no cross-device collectives outside + # the EP FFI. EP-axis flux is carried by the FFI itself. + + def test_z_no_unexpected_reshard_in_hlo_fwd(self): + """Compiled fwd HLO must not insert XLA collectives outside the EP FFI.""" + T_dp, tokens, topk_idx, topk_w = self._make_random_inputs() + dp_spec = PartitionSpec(("dp", "ep"), None) + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + ep_spec_2d = PartitionSpec(("dp", "ep"), None) + + with self.mesh, global_shard_guard(self.mr): + + @jax.jit + def run(idx, toks, w): + idx = jax.lax.with_sharding_constraint(idx, NamedSharding(self.mesh, dp_spec)) + toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec)) + w = jax.lax.with_sharding_constraint(w, NamedSharding(self.mesh, dp_spec)) + recv_t, recv_w, hm, tc = ep_dispatch(self.hk, idx, toks, w, self.recv_capacity_per_rank) + recv_t = jax.lax.with_sharding_constraint( + recv_t, NamedSharding(self.mesh, ep_spec_3d) + ) + recv_w = jax.lax.with_sharding_constraint( + recv_w, NamedSharding(self.mesh, ep_spec_2d) + ) + out = ep_combine( + self.hk, hm, tc, recv_t, recv_w, T_dp, out_sharding=(("dp", "ep"), None) + ) + return jax.lax.with_sharding_constraint(out, NamedSharding(self.mesh, dp_spec)) + + compiled = run.lower(topk_idx, tokens, topk_w).compile() + hlo = compiled.as_text() + # Match instruction names; "all-gather-start" and "all-gather-done" + # bracket a single async all-gather. + for op in ("all-gather-start", "all-to-all", "collective-permute"): + self.assertEqual(hlo.count(op), 0, f"unexpected XLA {op} in fwd HLO:\n{hlo}") + # XLA drops trailing-None entries from the spec; compare as a tuple. + # JAX collapses size-1 mesh axes, so dp=1 reduces ("dp","ep") to "ep". + expected = (("dp", "ep"),) if self.dp > 1 else ("ep",) + self.assertEqual(tuple(compiled.output_shardings.spec), expected) + + def test_z_no_unexpected_reshard_in_hlo_bwd(self): + """Compiled bwd HLO must not insert XLA collectives outside the EP FFI.""" + T_dp, tokens, topk_idx, topk_w = self._make_random_inputs() + rng = np.random.default_rng(seed=44) + expert_out = jnp.asarray( + rng.standard_normal( + (self.dp * self.ep, self.recv_capacity_per_rank, HIDDEN_DIM), dtype=np.float32 + ) + * 0.5, + dtype=jnp.bfloat16, + ) + dp_spec = PartitionSpec(("dp", "ep"), None) + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + ep_spec_2d = PartitionSpec(("dp", "ep"), None) + + with self.mesh, global_shard_guard(self.mr): + + def fwd(eo, toks, idx, w): + eo = jax.lax.with_sharding_constraint(eo, NamedSharding(self.mesh, ep_spec_3d)) + toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec)) + idx = jax.lax.with_sharding_constraint(idx, NamedSharding(self.mesh, dp_spec)) + w = jax.lax.with_sharding_constraint(w, NamedSharding(self.mesh, dp_spec)) + _rt, rw, hm, tc = ep_dispatch(self.hk, idx, toks, w, self.recv_capacity_per_rank) + rw = jax.lax.with_sharding_constraint(rw, NamedSharding(self.mesh, ep_spec_2d)) + combined = ep_combine(self.hk, hm, tc, eo, rw, T_dp, out_sharding=(("dp", "ep"), None)) + return jax.lax.with_sharding_constraint(combined, NamedSharding(self.mesh, dp_spec)) + + # jax.vjp + pinned cotangent feeds ep_combine_bwd/ep_dispatch_bwd + # the expected sharding without relying on XLA-transpose propagation. + def bwd_only(eo, toks, idx, w, g): + _y, vjp_fn = jax.vjp(fwd, eo, toks, idx, w) + g = jax.lax.with_sharding_constraint(g, NamedSharding(self.mesh, dp_spec)) + grads = vjp_fn(g) + return ( + jax.lax.with_sharding_constraint( + grads[0], NamedSharding(self.mesh, ep_spec_3d) + ), + jax.lax.with_sharding_constraint(grads[1], NamedSharding(self.mesh, dp_spec)), + ) + + g_seed = jnp.ones((T_dp, HIDDEN_DIM), dtype=jnp.bfloat16) + compiled = ( + jax.jit(bwd_only).lower(expert_out, tokens, topk_idx, topk_w, g_seed).compile() + ) + hlo = compiled.as_text() + for op in ("all-gather-start", "all-to-all", "collective-permute"): + self.assertEqual(hlo.count(op), 0, f"unexpected XLA {op} in bwd HLO:\n{hlo}") + + +# ── Entry point ────────────────────────────────────────────────────────────── + + +if __name__ == "__main__": + if len(sys.argv) < 4: + print("Usage: python test_multi_process_ep.py ") + sys.exit(1) + + coord_addr = sys.argv[1] + proc_id = int(sys.argv[2]) + num_procs = int(sys.argv[3]) + + jax.distributed.initialize( + coordinator_address=coord_addr, + num_processes=num_procs, + process_id=proc_id, + local_device_ids=[proc_id], + ) + + loader = unittest.TestLoader() + target = os.environ.get("TARGET_TEST") + if target: + name = target.split(".")[-1] + suite = loader.loadTestsFromName(name, TestEP) + else: + suite = loader.loadTestsFromTestCase(TestEP) + runner = unittest.TextTestRunner(verbosity=2) + result = runner.run(suite) + sys.exit(0 if result.wasSuccessful() else 1) diff --git a/transformer_engine/jax/cpp_extensions/__init__.py b/transformer_engine/jax/cpp_extensions/__init__.py index b42c909740..c9647afb82 100644 --- a/transformer_engine/jax/cpp_extensions/__init__.py +++ b/transformer_engine/jax/cpp_extensions/__init__.py @@ -11,4 +11,5 @@ from .softmax import * from .gemm import * from .router import * +from .ep import * from .topk import * diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 6eb588c849..2cdef4bfe7 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -266,6 +266,17 @@ def _gspmd_wrapper(*args, **kwargs): for _name, _value in transformer_engine_jax.registrations().items(): ffi.register_ffi_target(_name, _value, platform="CUDA") +# Register EpInstanceState (no-op when TE is built without NCCL EP). +if hasattr(transformer_engine_jax, "get_ep_instance_state_type_id"): + ffi.register_ffi_type( + "EpInstanceState", + { + "type_id": transformer_engine_jax.get_ep_instance_state_type_id(), + "type_info": transformer_engine_jax.get_ep_instance_state_type_info(), + }, + platform="CUDA", + ) + def manage_primitives(enable_names=None, disable_names=None, disable_all_first=False): """ diff --git a/transformer_engine/jax/cpp_extensions/ep.py b/transformer_engine/jax/cpp_extensions/ep.py new file mode 100644 index 0000000000..5263b33ba9 --- /dev/null +++ b/transformer_engine/jax/cpp_extensions/ep.py @@ -0,0 +1,1017 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""JAX/TE custom ops for Expert Parallelism (EP). + +Sharding model: + - EpPrepare / EpDispatch outputs carry a single leading ``num_procs`` dim. + Sharded compound ``(dp_resource, ep_resource)`` when DP is set, else + ``ep_resource`` alone. + - EpDispatch inputs are 2D ``[T, H]`` or 3D ``[B, S, H]``; only the first + dim may be sharded, with axis in {ep, (dp, ep), dp, None}. Trailing dims + must be replicated. ``dp`` alone gets ``ep`` folded in locally. + - EpCombine output sharding comes from ``out_sharding`` or defaults to the + compound ``(dp, ep)`` axis on the leading dim. +""" + +from dataclasses import dataclass + +import jax +import jax.numpy as jnp +from jax import dtypes, ffi +from jax.sharding import NamedSharding, PartitionSpec + +import transformer_engine_jax +from .base import BasePrimitive, register_primitive +from ..sharding import global_mesh_resource + +__all__ = [ + "EpConfig", + "EpLayerConfig", + "set_ep_config", + "get_ep_config", + "get_ep_num_local_experts", + "ep_handle_mem_size", + "ep_prepare", + "ep_dispatch_fwd", + "ep_combine_fwd", + "ep_dispatch_bwd", + "ep_combine_bwd", +] + + +# ── Module-level EP config ────────────────────────────────────────────────── + + +@dataclass(frozen=True) +class EpConfig: + """Immutable Python view of the EP bootstrap config (see ep_bootstrap).""" + + world_size: int + rank: int + ep_size: int + num_experts: int + num_local_experts: int + max_tokens_per_rank: int + recv_capacity_per_rank: int + hidden_dim: int + + +_ep_config: EpConfig = None + + +def set_ep_config(config: EpConfig) -> None: + """Cache the EP config for abstract-eval / sharding helpers. Call once.""" + global _ep_config + _ep_config = config + + +def get_ep_config() -> EpConfig: + if _ep_config is None: + raise RuntimeError("EpConfig has not been set. Did you call ep_bootstrap()?") + return _ep_config + + +def get_ep_num_local_experts() -> int: + return get_ep_config().num_local_experts + + +@dataclass(frozen=True) +class EpLayerConfig: + """Per-layer EP config; mirrors C ``NVTEEpLayerConfig``. + + Threaded through every per-step op so the pointer-keyed C++ cache can + validate consistency across a handle_mem's prepare / dispatch / combine. + Reserved for future per-call fields (fp8 scale, overflow policy, ...). + """ + + top_k: int + dispatch_output_per_expert_alignment: int = 0 + + +def ep_handle_mem_size(cfg: EpLayerConfig) -> int: + """Return the handle_mem byte size for ``cfg``. Host-only; cheap.""" + return int( + transformer_engine_jax.ep_handle_mem_size( + int(cfg.top_k), int(cfg.dispatch_output_per_expert_alignment) + ) + ) + + +def _leading_axis_ok(spec, ep_axis, outer_axes=()): + # Only the first dim may carry sharding; remaining dims must be replicated. + # The first dim's axis must be one of: + # ``ep_axis`` alone, + # a tuple of dp/fsdp axes (no ep — ep gets sliced in locally), + # a tuple ending in ``ep_axis`` with dp/fsdp axes before it. + # Examples on a (dp, ep) mesh: 2D ``(ep, None)``, ``(("dp","ep"), None)``, + # ``("dp", None)``; 3D ``(ep, None, None)``, ``(("dp","ep"), None, None)``, + # ``("dp", None, None)``. + if len(spec) < 2 or ep_axis is None: + return False + if any(ax is not None for ax in spec[1:]): + return False # only first dim sharded + leading = spec[0] + allowed_outers = {a for a in outer_axes if a is not None} + allowed = allowed_outers | {ep_axis, None} + elts = leading if isinstance(leading, tuple) else (leading,) + return all(a in allowed for a in elts) + + +def _canonical_input_spec(spec, ndim): + """Canonical input PartitionSpec the primitive demands JAX deliver. + + Sharding lives entirely on the first dim. If ``spec[0]`` already includes + ``ep_resource``, returned unchanged. Otherwise ``ep_resource`` is folded + into the first-dim axis tuple, e.g. ``"dp"`` → ``("dp","ep")``. The added + ep axis is a local slice (the missing dim was replicated), no cross-device + comm. + """ + gsr = global_mesh_resource() + ep = gsr.ep_resource + leading = spec[0] + present = leading if isinstance(leading, tuple) else (leading,) if leading is not None else () + if ep in present: + return PartitionSpec(*spec) + if leading is None: + new_leading = ep + elif isinstance(leading, tuple): + new_leading = (*leading, ep) + else: + new_leading = (leading, ep) + return PartitionSpec(new_leading, *([None] * (ndim - 1))) + + +def _dispatch_input_outer_axes(): + """dp/fsdp axes allowed as outer companions to ep_resource on dispatch input.""" + gsr = global_mesh_resource() + return tuple(a for a in (gsr.dp_resource, gsr.fsdp_resource) if a is not None) + + +def _ep_outer_axis(): + """The single dp/fsdp axis (if any) sitting outside ep on EP-output tensors. + + When set, EP-output globals carry an extra leading ``dp_size`` dim so SPMD + sees each DP color's slab as distinct (rather than replicated across DP). + """ + gsr = global_mesh_resource() + return gsr.dp_resource or gsr.fsdp_resource + + +def _ep_leading_dims(is_outer): + """Single leading dim of an EP-output tensor: ``(dp*ep,)`` (or ``(ep,)`` when + DP is unset) globally; ``(1,)`` per shard.""" + cfg = get_ep_config() + outer = _ep_outer_axis() + if not is_outer: + return (1,) + return (cfg.world_size,) if outer is not None else (cfg.ep_size,) + + +def _ep_output_spec(*trailing): + """PartitionSpec for an EP-output tensor: ``(("dp","ep"), *trailing)`` when + DP is set (compound leading axis on a single dim), else ``("ep",*trailing)``.""" + gsr = global_mesh_resource() + outer = _ep_outer_axis() + if outer is None: + return PartitionSpec(gsr.ep_resource, *trailing) + return PartitionSpec((outer, gsr.ep_resource), *trailing) + + +def _ep_spec_ok(spec, trailing_count): + """Accept ``(ep, *[None])`` (no DP) or ``((dp,ep), *[None])`` / + ``(("dp",), *[None])`` / ``("dp", *[None])`` / ``(None, *[None])`` (with DP) + on an EP-output tensor's single leading dim. JAX may collapse a size-1 + mesh axis to ``None`` (matters for dp_size=1 like 1x4).""" + gsr = global_mesh_resource() + ep_axis = gsr.ep_resource + outer = _ep_outer_axis() + expected_len = 1 + trailing_count + if len(spec) != expected_len: + return False + if any(ax is not None for ax in spec[1:]): + return False + leading = spec[0] + if outer is None: + return leading == ep_axis + allowed = {ep_axis, outer, None} + elts = leading if isinstance(leading, tuple) else (leading,) + return all(a in allowed for a in elts) + + +# ── ep_prepare ────────────────────────────────────────────────────────────── + + +class EpPreparePrimitive(BasePrimitive): + name = "te_ep_prepare_ffi" + multiple_results = True + impl_static_args = (1, 2, 3) # top_k, dispatch_output_per_expert_alignment, is_outer + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(topk_idx_aval, *, top_k, dispatch_output_per_expert_alignment, is_outer): + # is_outer=True: global leading dim = (world_size,) (or (ep_size,) with + # no DP); False: per-shard = (1,). + cfg = get_ep_config() + num_local_experts = cfg.num_local_experts + assert ( + len(topk_idx_aval.shape) >= 2 + ), f"topk_idx must be at least 2D [..., top_k], got shape {topk_idx_aval.shape}" + handle_mem_size = int( + transformer_engine_jax.ep_handle_mem_size( + int(top_k), int(dispatch_output_per_expert_alignment) + ) + ) + leading = _ep_leading_dims(is_outer) + token_counts_aval = jax.core.ShapedArray(leading + (num_local_experts,), jnp.int32) + handle_mem_aval = jax.core.ShapedArray(leading + (handle_mem_size,), jnp.uint8) + # FFI scratch for the int32 -> int64 topk_idx upcast. int32 with last + # dim doubled to keep the int64 byte count without JAX_ENABLE_X64. + # TODO(phuong): drop once NCCL EP supports int32 topk_idx. + workspace_shape = topk_idx_aval.shape[:-1] + (topk_idx_aval.shape[-1] * 2,) + workspace_aval = jax.core.ShapedArray(workspace_shape, jnp.int32) + return token_counts_aval, handle_mem_aval, workspace_aval + + @staticmethod + def outer_abstract(topk_idx_aval, *, top_k, dispatch_output_per_expert_alignment, is_outer): + del is_outer + avals = EpPreparePrimitive.abstract( + topk_idx_aval, + top_k=top_k, + dispatch_output_per_expert_alignment=dispatch_output_per_expert_alignment, + is_outer=True, + ) + return avals[:2] + + @staticmethod + def lowering(ctx, topk_idx, *, top_k, dispatch_output_per_expert_alignment, is_outer): + del is_outer + return ffi.ffi_lowering(EpPreparePrimitive.name)( + ctx, + topk_idx, + top_k=int(top_k), + dispatch_output_per_expert_alignment=int(dispatch_output_per_expert_alignment), + ) + + @staticmethod + def impl(topk_idx, top_k, dispatch_output_per_expert_alignment, is_outer): + assert EpPreparePrimitive.inner_primitive is not None + token_counts, handle_mem, _workspace = EpPreparePrimitive.inner_primitive.bind( + topk_idx, + top_k=top_k, + dispatch_output_per_expert_alignment=dispatch_output_per_expert_alignment, + is_outer=is_outer, + ) + return token_counts, handle_mem + + @staticmethod + def batcher( + batched_args, batch_dims, *, top_k, dispatch_output_per_expert_alignment, is_outer + ): + raise NotImplementedError("EpPreparePrimitive does not support vmap") + + @staticmethod + def partition( + top_k, dispatch_output_per_expert_alignment, is_outer, mesh, arg_infos, result_infos + ): + del is_outer, result_infos + gsr = global_mesh_resource() + ep_axis = gsr.ep_resource + outer_axes = _dispatch_input_outer_axes() + idx_spec = arg_infos[0].sharding.spec + if not _leading_axis_ok(idx_spec, ep_axis, outer_axes): + raise NotImplementedError( + "EpPrepare: topk_idx leading dims must shard on ep_resource" + f" ('{ep_axis}') and/or {outer_axes}, with the topk dim replicated;" + f" got spec={idx_spec}." + ) + idx_ndim = len(arg_infos[0].shape) + arg_shardings = (NamedSharding(mesh, _canonical_input_spec(idx_spec, idx_ndim)),) + tc_sharding = NamedSharding(mesh, _ep_output_spec(None)) + hm_sharding = NamedSharding(mesh, _ep_output_spec(None)) + + def sharded_impl(topk_idx): + return EpPreparePrimitive.impl( + topk_idx, top_k, dispatch_output_per_expert_alignment, False + ) + + return mesh, sharded_impl, (tc_sharding, hm_sharding), arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + # Signature: (*static_args, mesh, value_types, result_types). Static args + # for this primitive are (top_k, dispatch_alignment, is_outer). + value_types = args[-2] + topk_idx_rank = len(value_types[0].shape) + in_axes = " ".join(f"L{i}" for i in range(topk_idx_rank - 1)) + " topk" + return f"{in_axes} -> EPL nle, EPL hm" + + +register_primitive(EpPreparePrimitive) + + +# ── ep_dispatch ───────────────────────────────────────────────────────────── + + +class EpDispatchPrimitive(BasePrimitive): + name = "te_ep_dispatch_ffi" + multiple_results = True + impl_static_args = (4, 5, 6, 7) # top_k, dispatch_output_per_expert_alignment, + # recv_capacity_per_rank, is_outer + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + handle_mem_aval, + topk_idx_aval, + tokens_aval, + topk_weights_aval, + *, + top_k, + dispatch_output_per_expert_alignment, + recv_capacity_per_rank, + is_outer, + ): + # is_outer=True: global leading dim = (world_size,) (or (ep_size,) with + # no DP); False: per-shard = (1,). + del topk_weights_aval, top_k, dispatch_output_per_expert_alignment, handle_mem_aval + assert ( + len(tokens_aval.shape) >= 2 + ), f"tokens must be at least 2D [..., H], got shape {tokens_aval.shape}" + recv_pr = recv_capacity_per_rank + tok_dtype = dtypes.canonicalize_dtype(tokens_aval.dtype) + hidden_dim = tokens_aval.shape[-1] + leading = _ep_leading_dims(is_outer) + recv_tokens_aval = jax.core.ShapedArray(leading + (recv_pr, hidden_dim), tok_dtype) + recv_topk_weights_aval = jax.core.ShapedArray(leading + (recv_pr,), jnp.float32) + # int32 with last dim doubled to keep the int64 byte count without JAX_ENABLE_X64. + workspace_shape = topk_idx_aval.shape[:-1] + (topk_idx_aval.shape[-1] * 2,) + workspace_aval = jax.core.ShapedArray(workspace_shape, jnp.int32) + return (recv_tokens_aval, recv_topk_weights_aval, workspace_aval) + + @staticmethod + def outer_abstract(*args, **kwargs): + kwargs = dict(kwargs) + kwargs["is_outer"] = True + avals = EpDispatchPrimitive.abstract(*args, **kwargs) + return avals[:2] + + @staticmethod + def lowering( + ctx, + handle_mem, + topk_idx, + tokens, + topk_weights, + *, + top_k, + dispatch_output_per_expert_alignment, + recv_capacity_per_rank, + is_outer, + ): + del recv_capacity_per_rank, is_outer + return ffi.ffi_lowering(EpDispatchPrimitive.name)( + ctx, + handle_mem, + topk_idx, + tokens, + topk_weights, + top_k=int(top_k), + dispatch_output_per_expert_alignment=int(dispatch_output_per_expert_alignment), + ) + + @staticmethod + def impl( + handle_mem, + topk_idx, + tokens, + topk_weights, + top_k, + dispatch_output_per_expert_alignment, + recv_capacity_per_rank, + is_outer, + ): + assert EpDispatchPrimitive.inner_primitive is not None + recv_tokens, recv_topk_weights, _workspace = EpDispatchPrimitive.inner_primitive.bind( + handle_mem, + topk_idx, + tokens, + topk_weights, + top_k=top_k, + dispatch_output_per_expert_alignment=dispatch_output_per_expert_alignment, + recv_capacity_per_rank=recv_capacity_per_rank, + is_outer=is_outer, + ) + return recv_tokens, recv_topk_weights + + @staticmethod + def batcher( + batched_args, + batch_dims, + *, + top_k, + dispatch_output_per_expert_alignment, + recv_capacity_per_rank, + is_outer, + ): + raise NotImplementedError("EpDispatchPrimitive does not support vmap") + + @staticmethod + def partition( + top_k, + dispatch_output_per_expert_alignment, + recv_capacity_per_rank, + is_outer, + mesh, + arg_infos, + result_infos, + ): + del is_outer, result_infos + gsr = global_mesh_resource() + ep_axis = gsr.ep_resource + outer_axes = _dispatch_input_outer_axes() + tokens_spec = arg_infos[2].sharding.spec + if not _leading_axis_ok(tokens_spec, ep_axis, outer_axes): + raise NotImplementedError( + "EpDispatch: tokens leading dims must shard on ep_resource" + f" ('{ep_axis}') and/or {outer_axes}, hidden dim replicated;" + f" got spec={tokens_spec}." + ) + idx_spec = arg_infos[1].sharding.spec + tw_spec = arg_infos[3].sharding.spec + arg_shardings = ( + arg_infos[0].sharding, + NamedSharding(mesh, _canonical_input_spec(idx_spec, len(arg_infos[1].shape))), + NamedSharding(mesh, _canonical_input_spec(tokens_spec, len(arg_infos[2].shape))), + NamedSharding(mesh, _canonical_input_spec(tw_spec, len(arg_infos[3].shape))), + ) + out_shardings = ( + NamedSharding(mesh, _ep_output_spec(None, None)), + NamedSharding(mesh, _ep_output_spec(None)), + ) + + def sharded_impl(handle_mem, topk_idx, tokens, topk_weights): + return EpDispatchPrimitive.impl( + handle_mem, + topk_idx, + tokens, + topk_weights, + top_k, + dispatch_output_per_expert_alignment, + recv_capacity_per_rank, + False, + ) + + return mesh, sharded_impl, out_shardings, arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + # Signature: (*static_args, mesh, value_types, result_types). Static args + # for this primitive are (top_k, dispatch_alignment, recv_capacity_per_rank, is_outer). + value_types = args[-2] + # Inputs: handle_mem, topk_idx, tokens, topk_weights. + idx_rank = len(value_types[1].shape) + tok_rank = len(value_types[2].shape) + tw_rank = len(value_types[3].shape) + idx_axes = " ".join(f"I{i}" for i in range(idx_rank - 1)) + " topk_in" + tok_axes = " ".join(f"T{i}" for i in range(tok_rank - 1)) + " H" + tw_axes = " ".join(f"W{i}" for i in range(tw_rank - 1)) + " topk" + return f"EPL hm, {idx_axes}, {tok_axes}, {tw_axes} -> EPL recv_pr H, EPL recv_pr" + + +register_primitive(EpDispatchPrimitive) + + +# ── ep_combine ────────────────────────────────────────────────────────────── +# `expert_out` here is the post-weight buffer; ep.ep_combine applies the +# hadamard before calling. + + +def _normalize_leading_shape(s): + return s if isinstance(s, tuple) else (int(s),) + + +def _prod(seq): + p = 1 + for x in seq: + p *= int(x) + return p + + +def _resolve_out_partition_spec(out_partition_spec, num_leading): + """Pick the combine output PartitionSpec. + + Defaults to a compound leading axis ``(dp_resource, ep_resource)`` when a + DP/FSDP axis is set on the active MeshResource, else just ``ep_resource``. + This matches the input sharding so XLA does not need collective-permutes + in the bwd path. + """ + if out_partition_spec is not None: + assert len(out_partition_spec) == num_leading + 1, ( + f"out_partition_spec length {len(out_partition_spec)} must equal num_leading" + f" + 1 ({num_leading + 1})" + ) + return tuple(out_partition_spec) + gsr = global_mesh_resource() + if gsr.ep_resource is None: + raise ValueError( + "ep_combine: ep_resource is not set on the active MeshResource;" + " pass out_sharding=... explicitly." + ) + outer = gsr.dp_resource or gsr.fsdp_resource + leading = (outer, gsr.ep_resource) if outer is not None else gsr.ep_resource + return (leading,) + (None,) * num_leading + + +def _per_shard_leading(out_leading_shape, resolved_spec, mesh): + """Per-shard leading shape given resolved partition spec and mesh.""" + per_shard = list(out_leading_shape) + for i, ax in enumerate(resolved_spec[: len(out_leading_shape)]): + if ax is None: + continue + axes = ax if isinstance(ax, tuple) else (ax,) + factor = 1 + for a in axes: + factor *= mesh.shape[a] + assert ( + per_shard[i] % factor == 0 + ), f"leading dim {per_shard[i]} not divisible by shard factor {factor} on axes {axes}" + per_shard[i] //= factor + return tuple(per_shard) + + +class EpCombinePrimitive(BasePrimitive): + name = "te_ep_combine_ffi" + multiple_results = False + impl_static_args = (2, 3, 4, 5) # top_k, dispatch_output_per_expert_alignment, + # out_leading_shape, out_partition_spec + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + handle_mem_aval, + expert_out_aval, + *, + top_k, + dispatch_output_per_expert_alignment, + out_leading_shape, + out_partition_spec, + ): + del top_k, dispatch_output_per_expert_alignment, out_partition_spec, handle_mem_aval + assert ( + len(expert_out_aval.shape) == 3 + ), f"expert_out must be 3D [num_procs, recv_pr, H], got shape {expert_out_aval.shape}" + eo_dtype = dtypes.canonicalize_dtype(expert_out_aval.dtype) + hidden_dim = expert_out_aval.shape[-1] + out_shape = tuple(out_leading_shape) + (hidden_dim,) + return jax.core.ShapedArray(out_shape, eo_dtype) + + @staticmethod + def lowering( + ctx, + handle_mem, + expert_out, + *, + top_k, + dispatch_output_per_expert_alignment, + out_leading_shape, + out_partition_spec, + ): + del out_partition_spec + return ffi.ffi_lowering(EpCombinePrimitive.name)( + ctx, + handle_mem, + expert_out, + top_k=int(top_k), + dispatch_output_per_expert_alignment=int(dispatch_output_per_expert_alignment), + num_local_tokens=_prod(out_leading_shape), + ) + + @staticmethod + def impl( + handle_mem, + expert_out, + top_k, + dispatch_output_per_expert_alignment, + out_leading_shape, + out_partition_spec, + ): + assert EpCombinePrimitive.inner_primitive is not None + return EpCombinePrimitive.inner_primitive.bind( + handle_mem, + expert_out, + top_k=top_k, + dispatch_output_per_expert_alignment=dispatch_output_per_expert_alignment, + out_leading_shape=out_leading_shape, + out_partition_spec=out_partition_spec, + ) + + @staticmethod + def batcher( + batched_args, + batch_dims, + *, + top_k, + dispatch_output_per_expert_alignment, + out_leading_shape, + out_partition_spec, + ): + raise NotImplementedError("EpCombinePrimitive does not support vmap") + + @staticmethod + def partition( + top_k, + dispatch_output_per_expert_alignment, + out_leading_shape, + out_partition_spec, + mesh, + arg_infos, + result_infos, + ): + del result_infos + eo_spec = arg_infos[1].sharding.spec + if not _ep_spec_ok(eo_spec, trailing_count=2): + raise NotImplementedError( + "EpCombine: expert_out must be sharded as PartitionSpec(ep_resource," + " None, None) (or ((dp, ep), None, None) when dp/fsdp is set)" + f" over [num_procs, recv_pr, H]; got spec={eo_spec}." + ) + resolved = _resolve_out_partition_spec(out_partition_spec, len(out_leading_shape)) + per_shard_leading = _per_shard_leading(out_leading_shape, resolved, mesh) + arg_shardings = tuple(a.sharding for a in arg_infos) + out_sharding = NamedSharding(mesh, PartitionSpec(*resolved)) + + def sharded_impl(handle_mem, expert_out): + return EpCombinePrimitive.impl( + handle_mem, + expert_out, + top_k, + dispatch_output_per_expert_alignment, + per_shard_leading, + out_partition_spec, + ) + + return mesh, sharded_impl, out_sharding, arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + # Signature: (*static_args, mesh, value_types, result_types). Static args: + # (top_k, dispatch_alignment, out_leading_shape, out_partition_spec). + result_types = args[-1] + out_rank = len(result_types[0].shape) + out_axes = " ".join(f"O{i}" for i in range(out_rank - 1)) + " H" + return f"EPL hm, EPL recv_pr H -> {out_axes}" + + +register_primitive(EpCombinePrimitive) + + +# ── ep_dispatch_bwd ───────────────────────────────────────────────────────── + + +class EpDispatchBwdPrimitive(BasePrimitive): + name = "te_ep_dispatch_bwd_ffi" + multiple_results = True + impl_static_args = (3, 4, 5, 6) # top_k, dispatch_output_per_expert_alignment, + # out_leading_shape, out_partition_spec + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + handle_mem_aval, + grad_aval, + g_recv_topk_weights_aval, + *, + top_k, + dispatch_output_per_expert_alignment, + out_leading_shape, + out_partition_spec, + ): + del dispatch_output_per_expert_alignment + del g_recv_topk_weights_aval, out_partition_spec, handle_mem_aval + assert ( + len(grad_aval.shape) == 3 + ), f"grad must be 3D [num_procs, recv_pr, H], got shape {grad_aval.shape}" + g_dtype = dtypes.canonicalize_dtype(grad_aval.dtype) + hidden_dim = grad_aval.shape[-1] + result_aval = jax.core.ShapedArray(tuple(out_leading_shape) + (hidden_dim,), g_dtype) + grad_topk_weights_aval = jax.core.ShapedArray( + tuple(out_leading_shape) + (top_k,), jnp.float32 + ) + return result_aval, grad_topk_weights_aval + + @staticmethod + def lowering( + ctx, + handle_mem, + grad, + g_recv_topk_weights, + *, + top_k, + dispatch_output_per_expert_alignment, + out_leading_shape, + out_partition_spec, + ): + del out_partition_spec + return ffi.ffi_lowering(EpDispatchBwdPrimitive.name)( + ctx, + handle_mem, + grad, + g_recv_topk_weights, + top_k=int(top_k), + dispatch_output_per_expert_alignment=int(dispatch_output_per_expert_alignment), + num_local_tokens=_prod(out_leading_shape), + ) + + @staticmethod + def impl( + handle_mem, + grad, + g_recv_topk_weights, + top_k, + dispatch_output_per_expert_alignment, + out_leading_shape, + out_partition_spec, + ): + assert EpDispatchBwdPrimitive.inner_primitive is not None + return EpDispatchBwdPrimitive.inner_primitive.bind( + handle_mem, + grad, + g_recv_topk_weights, + top_k=top_k, + dispatch_output_per_expert_alignment=dispatch_output_per_expert_alignment, + out_leading_shape=out_leading_shape, + out_partition_spec=out_partition_spec, + ) + + @staticmethod + def batcher( + batched_args, + batch_dims, + *, + top_k, + dispatch_output_per_expert_alignment, + out_leading_shape, + out_partition_spec, + ): + raise NotImplementedError("EpDispatchBwdPrimitive does not support vmap") + + @staticmethod + def partition( + top_k, + dispatch_output_per_expert_alignment, + out_leading_shape, + out_partition_spec, + mesh, + arg_infos, + result_infos, + ): + del result_infos + g_spec = arg_infos[1].sharding.spec + if not _ep_spec_ok(g_spec, trailing_count=2): + raise NotImplementedError( + "EpDispatchBwd: grad must be sharded as PartitionSpec(ep_resource," + " None, None) (or ((dp, ep), None, None) when dp/fsdp is set)" + f" over [num_procs, recv_pr, H]; got spec={g_spec}." + ) + gw_spec = arg_infos[2].sharding.spec + if not _ep_spec_ok(gw_spec, trailing_count=1): + raise NotImplementedError( + "EpDispatchBwd: g_recv_topk_weights must be sharded as" + " PartitionSpec(ep_resource, None) (or ((dp, ep), None) when dp/fsdp is set)" + f" over [num_procs, recv_pr]; got spec={gw_spec}." + ) + resolved = _resolve_out_partition_spec(out_partition_spec, len(out_leading_shape)) + per_shard_leading = _per_shard_leading(out_leading_shape, resolved, mesh) + arg_shardings = tuple(a.sharding for a in arg_infos) + out_shardings = [ + NamedSharding(mesh, PartitionSpec(*resolved)), + NamedSharding(mesh, PartitionSpec(*resolved, None)), + ] + + def sharded_impl(handle_mem, grad, g_recv_topk_weights): + return EpDispatchBwdPrimitive.impl( + handle_mem, + grad, + g_recv_topk_weights, + top_k, + dispatch_output_per_expert_alignment, + per_shard_leading, + out_partition_spec, + ) + + return mesh, sharded_impl, out_shardings, arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + # Signature: (*static_args, mesh, value_types, result_types). Result rank + # follows out_leading_shape (static arg #2): rank = len(out_leading) + 1. + result_types = args[-1] + out_rank = len(result_types[0].shape) + out_axes = " ".join(f"O{i}" for i in range(out_rank - 1)) + return f"EPL hm, EPL recv_pr H, EPL recv_pr -> {out_axes} H, {out_axes} k" + + +register_primitive(EpDispatchBwdPrimitive) + + +# ── ep_combine_bwd ────────────────────────────────────────────────────────── + + +class EpCombineBwdPrimitive(BasePrimitive): + name = "te_ep_combine_bwd_ffi" + multiple_results = False + impl_static_args = (2, 3, 4, 5) # top_k, dispatch_output_per_expert_alignment, + # recv_capacity_per_rank, is_outer + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + handle_mem_aval, + grad_aval, + *, + top_k, + dispatch_output_per_expert_alignment, + recv_capacity_per_rank, + is_outer, + ): + # is_outer=True: global leading dim = (world_size,) (or (ep_size,) with + # no DP); False: per-shard = (1,). + del top_k, dispatch_output_per_expert_alignment, handle_mem_aval + assert ( + len(grad_aval.shape) >= 2 + ), f"grad must be at least 2D [..., H], got shape {grad_aval.shape}" + g_dtype = dtypes.canonicalize_dtype(grad_aval.dtype) + hidden_dim = grad_aval.shape[-1] + leading = _ep_leading_dims(is_outer) + return jax.core.ShapedArray(leading + (recv_capacity_per_rank, hidden_dim), g_dtype) + + @staticmethod + def outer_abstract(*args, **kwargs): + kwargs = dict(kwargs) + kwargs["is_outer"] = True + return EpCombineBwdPrimitive.abstract(*args, **kwargs) + + @staticmethod + def lowering( + ctx, + handle_mem, + grad, + *, + top_k, + dispatch_output_per_expert_alignment, + recv_capacity_per_rank, + is_outer, + ): + del recv_capacity_per_rank, is_outer + return ffi.ffi_lowering(EpCombineBwdPrimitive.name)( + ctx, + handle_mem, + grad, + top_k=int(top_k), + dispatch_output_per_expert_alignment=int(dispatch_output_per_expert_alignment), + ) + + @staticmethod + def impl( + handle_mem, + grad, + top_k, + dispatch_output_per_expert_alignment, + recv_capacity_per_rank, + is_outer, + ): + assert EpCombineBwdPrimitive.inner_primitive is not None + return EpCombineBwdPrimitive.inner_primitive.bind( + handle_mem, + grad, + top_k=top_k, + dispatch_output_per_expert_alignment=dispatch_output_per_expert_alignment, + recv_capacity_per_rank=recv_capacity_per_rank, + is_outer=is_outer, + ) + + @staticmethod + def batcher( + batched_args, + batch_dims, + *, + top_k, + dispatch_output_per_expert_alignment, + recv_capacity_per_rank, + is_outer, + ): + raise NotImplementedError("EpCombineBwdPrimitive does not support vmap") + + @staticmethod + def partition( + top_k, + dispatch_output_per_expert_alignment, + recv_capacity_per_rank, + is_outer, + mesh, + arg_infos, + result_infos, + ): + del is_outer, result_infos + arg_shardings = tuple(a.sharding for a in arg_infos) + out_sharding = NamedSharding(mesh, _ep_output_spec(None, None)) + + def sharded_impl(handle_mem, grad): + return EpCombineBwdPrimitive.impl( + handle_mem, + grad, + top_k, + dispatch_output_per_expert_alignment, + recv_capacity_per_rank, + False, + ) + + return mesh, sharded_impl, out_sharding, arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + # T axes are dynamic-rank based on the actual cotangent shape. + value_types = args[-2] + g_rank = len(value_types[1].shape) + g_axes = " ".join(f"T{i}" for i in range(g_rank - 1)) + " H" + return f"EPL hm, {g_axes} -> EPL recv_pr H" + + +register_primitive(EpCombineBwdPrimitive) + + +# ── Public-ish helpers (used by jax/ep.py) ────────────────────────────────── + + +def ep_prepare(cfg: EpLayerConfig, topk_idx): + """Exchange routing metadata for ``cfg``; return ``(token_counts, handle_mem)``.""" + return EpPreparePrimitive.outer_primitive.bind( + topk_idx, + top_k=int(cfg.top_k), + dispatch_output_per_expert_alignment=int(cfg.dispatch_output_per_expert_alignment), + is_outer=True, + ) + + +def ep_dispatch_fwd(cfg: EpLayerConfig, handle_mem, topk_idx, tokens, topk_weights, + recv_capacity_per_rank): + """Scatter tokens and weights to expert ranks; returns (recv_tokens, recv_topk_weights).""" + return EpDispatchPrimitive.outer_primitive.bind( + handle_mem, + topk_idx, + tokens, + topk_weights, + top_k=int(cfg.top_k), + dispatch_output_per_expert_alignment=int(cfg.dispatch_output_per_expert_alignment), + recv_capacity_per_rank=recv_capacity_per_rank, + is_outer=True, + ) + + +def ep_combine_fwd(cfg: EpLayerConfig, handle_mem, expert_out, num_local_tokens, + out_partition_spec=None): + """Gather expert outputs back to home ranks. expert_out is pre-weighted.""" + out_leading = _normalize_leading_shape(num_local_tokens) + return EpCombinePrimitive.outer_primitive.bind( + handle_mem, + expert_out, + top_k=int(cfg.top_k), + dispatch_output_per_expert_alignment=int(cfg.dispatch_output_per_expert_alignment), + out_leading_shape=out_leading, + out_partition_spec=out_partition_spec, + ) + + +def ep_dispatch_bwd( + cfg: EpLayerConfig, handle_mem, grad, g_recv_topk_weights, num_local_tokens, + out_partition_spec=None, +): + """Backward of dispatch; returns (grad_tokens, grad_topk_weights).""" + out_leading = _normalize_leading_shape(num_local_tokens) + return EpDispatchBwdPrimitive.outer_primitive.bind( + handle_mem, + grad, + g_recv_topk_weights, + top_k=int(cfg.top_k), + dispatch_output_per_expert_alignment=int(cfg.dispatch_output_per_expert_alignment), + out_leading_shape=out_leading, + out_partition_spec=out_partition_spec, + ) + + +def ep_combine_bwd(cfg: EpLayerConfig, handle_mem, grad, recv_capacity_per_rank): + """Backward of combine; returns grad_expert_out [num_procs, recv_capacity_per_rank, H].""" + return EpCombineBwdPrimitive.outer_primitive.bind( + handle_mem, + grad, + top_k=int(cfg.top_k), + dispatch_output_per_expert_alignment=int(cfg.dispatch_output_per_expert_alignment), + recv_capacity_per_rank=recv_capacity_per_rank, + is_outer=True, + ) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index c0fa3acaeb..b9c7c849f2 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -204,6 +204,28 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedTopkWithScoreFunctionBackwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossForwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossBackwardHandler); +// Bootstrap EP (eager NCCL comm init); anchor released by ReleaseEpResources. +// max_token_dtype is the NVTEDType enum value (int) for the widest token dtype +// the group will dispatch. +void SetEpBootstrapParams(pybind11::bytes unique_id_bytes, int ep_size, int rank_within_group, + int num_experts, int max_tokens_per_rank, int max_recv_tokens_per_rank, + int hidden_dim, int max_num_sms, int max_token_dtype); +void ReleaseEpResources(); +// Return the handle_mem byte size for a layer config. +size_t EpHandleMemSize(int top_k, size_t dispatch_output_per_expert_alignment); + +// EpInstanceState type_id / type_info capsules for jax.ffi.register_ffi_type. +pybind11::capsule GetEpInstanceStateTypeIdCapsule(); +pybind11::capsule GetEpInstanceStateTypeInfoCapsule(); + +// EP FFI handlers +XLA_FFI_DECLARE_HANDLER_SYMBOL(EpInstantiateHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(EpPrepareHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(EpDispatchHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(EpCombineHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(EpDispatchBwdHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(EpCombineBwdHandler); + // TopK XLA_FFI_DECLARE_HANDLER_SYMBOL(TopkHandler); pybind11::tuple GetTopkWorkspaceSizes(int batch_size, int seq_len, int k); diff --git a/transformer_engine/jax/csrc/extensions/ep.cpp b/transformer_engine/jax/csrc/extensions/ep.cpp new file mode 100644 index 0000000000..e727eadce9 --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/ep.cpp @@ -0,0 +1,541 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifdef NVTE_WITH_NCCL_EP + +#include "transformer_engine/ep.h" + +#include + +#include +#include +#include +#include +#include + +#include "../extensions.h" +#include "common.h" +#include "transformer_engine/gemm.h" + +namespace transformer_engine { +namespace jax { + +// NCCL comm + EPBackend lifetime tracks live JAX executables via XLA stateful FFI. + +struct EpBootstrapParams { + std::array uid_bytes{}; + int ep_size = 0; + int rank_within_group = 0; + int num_experts = 0; + int max_tokens_per_rank = 0; + int max_recv_tokens_per_rank = 0; + int hidden_dim = 0; + int max_num_sms = 0; + NVTEDType max_token_dtype = kNVTEBFloat16; +}; + +class EpResources { + public: + explicit EpResources(const EpBootstrapParams& p) { + ncclUniqueId uid; + std::memcpy(&uid, p.uid_bytes.data(), sizeof(uid)); + NVTE_CHECK_NCCL(ncclCommInitRank(&comm_, p.ep_size, uid, p.rank_within_group)); + NVTEEpGroupConfig cfg{.ep_size = p.ep_size, + .num_experts = p.num_experts, + .max_tokens_per_rank = p.max_tokens_per_rank, + .max_recv_tokens_per_rank = p.max_recv_tokens_per_rank, + .hidden_dim = p.hidden_dim, + .max_num_sms = p.max_num_sms, + .max_token_dtype = p.max_token_dtype}; + try { + nvte_ep_initialize(static_cast(comm_), cfg); + } catch (...) { + ncclCommDestroy(comm_); + comm_ = nullptr; + throw; + } + } + + ~EpResources() { + if (comm_ == nullptr) return; + nvte_ep_shutdown(); + ncclCommDestroy(comm_); + } + + EpResources(const EpResources&) = delete; + EpResources& operator=(const EpResources&) = delete; + + ncclComm_t comm() const { return comm_; } + + private: + ncclComm_t comm_{nullptr}; +}; + +struct EpInstanceState { + static ::xla::ffi::TypeId id; + static ::xla::ffi::TypeInfo info; + std::shared_ptr resources; +}; + +::xla::ffi::TypeId EpInstanceState::id = {}; +::xla::ffi::TypeInfo EpInstanceState::info = ::xla::ffi::MakeTypeInfo(); + +namespace { + +std::mutex g_ep_mu; +EpBootstrapParams g_ep_params; +bool g_ep_params_set = false; +std::weak_ptr g_ep_resources_weak; +// Python-held anchor so trace-time handle_mem allocs find EPBackend ready. +std::shared_ptr g_ep_resources_anchor; + +std::shared_ptr AcquireEpResources() { + std::lock_guard lock(g_ep_mu); + NVTE_CHECK(g_ep_params_set, + "EP bootstrap params not set; call transformer_engine_jax." + "set_ep_bootstrap_params() (typically via ep_bootstrap) first."); + auto sp = g_ep_resources_weak.lock(); + if (sp) return sp; + sp = std::make_shared(g_ep_params); + g_ep_resources_weak = sp; + return sp; +} + +} // namespace + +// top_k and dispatch_output_per_expert_alignment are baked as static FFI +// attributes; prepare passes them to the C API as NVTEEpLayerConfig, and the +// per-step ops carry top_k only to validate the topk_idx last dim. + +struct EpPrepareConfig { + int64_t top_k; + int64_t dispatch_output_per_expert_alignment; +}; + +struct EpDispatchConfig { + int64_t top_k; + int64_t dispatch_output_per_expert_alignment; +}; + +struct EpCombineConfig { + int64_t top_k; + int64_t dispatch_output_per_expert_alignment; + int64_t num_local_tokens; +}; + +struct EpDispatchBwdConfig { + int64_t top_k; + int64_t dispatch_output_per_expert_alignment; + int64_t num_local_tokens; +}; + +struct EpCombineBwdConfig { + int64_t top_k; + int64_t dispatch_output_per_expert_alignment; +}; + +// ── Bootstrap helpers ───────────────────────────────────────────────────────── + +// Caches uid + group config and eagerly creates the NCCL comm (ranks +// synchronize via the UID broadcast). +void SetEpBootstrapParams(pybind11::bytes unique_id_bytes_obj, int ep_size, int rank_within_group, + int num_experts, int max_tokens_per_rank, int max_recv_tokens_per_rank, + int hidden_dim, int max_num_sms, int max_token_dtype) { + std::string uid_str = unique_id_bytes_obj; + NVTE_CHECK(static_cast(uid_str.size()) >= 128, + "unique_id_bytes must be at least 128 bytes (ncclUniqueId size)."); + std::shared_ptr anchor; + { + std::lock_guard lock(g_ep_mu); + NVTE_CHECK(!g_ep_resources_anchor, + "EP bootstrap already initialized; call release_ep_resources() before re-init."); + std::memcpy(g_ep_params.uid_bytes.data(), uid_str.data(), 128); + g_ep_params.ep_size = ep_size; + g_ep_params.rank_within_group = rank_within_group; + g_ep_params.num_experts = num_experts; + g_ep_params.max_tokens_per_rank = max_tokens_per_rank; + g_ep_params.max_recv_tokens_per_rank = max_recv_tokens_per_rank; + g_ep_params.hidden_dim = hidden_dim; + g_ep_params.max_num_sms = max_num_sms; + g_ep_params.max_token_dtype = static_cast(max_token_dtype); + g_ep_params_set = true; + } + // Acquire outside the lock: EpResources ctor runs ncclCommInitRank which is + // a collective and may block on peer ranks. + anchor = AcquireEpResources(); + std::lock_guard lock(g_ep_mu); + g_ep_resources_anchor = std::move(anchor); +} + +// Drops the anchor; comm tears down once the last executable also releases. +void ReleaseEpResources() { + std::shared_ptr to_drop; + { + std::lock_guard lock(g_ep_mu); + to_drop = std::move(g_ep_resources_anchor); + } + // to_drop dtor runs outside the lock. +} + +size_t EpHandleMemSize(int top_k, size_t dispatch_output_per_expert_alignment) { + NVTEEpLayerConfig layer_cfg{top_k, dispatch_output_per_expert_alignment}; + return nvte_ep_handle_mem_size(layer_cfg); +} + +pybind11::capsule GetEpInstanceStateTypeIdCapsule() { + return pybind11::capsule(static_cast(&EpInstanceState::id), "xla.ffi.type_id"); +} + +pybind11::capsule GetEpInstanceStateTypeInfoCapsule() { + return pybind11::capsule(static_cast(&EpInstanceState::info), "xla.ffi.type_info"); +} + +// ── Instantiate handler ───────────────────────────────────────────────────── + +static ::xla::ffi::ErrorOr> EpInstantiateImpl() { + auto state = std::make_unique(); + try { + state->resources = AcquireEpResources(); + } catch (const std::exception& e) { + return ::xla::ffi::Unexpected( + ::xla::ffi::Error::Internal(std::string("EP instantiate failed: ") + e.what())); + } + return state; +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(EpInstantiateHandler, EpInstantiateImpl, FFI::BindInstantiate()); + +// ── ep_prepare ──────────────────────────────────────────────────────────────── + +Error_Type EpPrepareFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type topk_idx, + Result_Type token_counts, Result_Type handle_mem, Result_Type workspace, + EpPrepareConfig config) { + (void)ep_state; // lifetime only. + auto topk_dims = topk_idx.dimensions(); + NVTE_CHECK(topk_dims.size() >= 2, + "topk_idx must be at least 2D [..., top_k], got ndim=", topk_dims.size()); + auto idx_etype = topk_idx.element_type(); + NVTE_CHECK(idx_etype == ::xla::ffi::DataType::S64 || idx_etype == ::xla::ffi::DataType::S32, + "topk_idx must be int32 or int64; got element_type=", static_cast(idx_etype)); + + std::vector topk_shape = {product(topk_dims, 0, topk_dims.size() - 1), + static_cast(topk_dims.back())}; + // NCCL EP currently requires int64 topk_idx; upcast int32 on-stream. + // TODO(phuong): drop once NCCL EP accepts int32. + void* topk_idx_data = topk_idx.untyped_data(); + if (idx_etype == ::xla::ffi::DataType::S32) { + const size_t n = topk_shape[0] * topk_shape[1]; + NVTE_CHECK(static_cast(workspace->element_count()) >= n, + "workspace too small for int32 → int64 upcast: element_count=", + workspace->element_count(), " < required ", n); + int64_t* ws = reinterpret_cast(workspace->untyped_data()); + nvte_convert_int32_to_int64(reinterpret_cast(topk_idx_data), ws, n, stream); + topk_idx_data = ws; + } + auto topk_idx_ = TensorWrapper(topk_idx_data, topk_shape, DType::kInt64); + + std::vector tc_shape = {static_cast(token_counts->element_count())}; + auto token_counts_ = TensorWrapper(token_counts->untyped_data(), tc_shape, DType::kInt32); + + std::vector hm_shape = {static_cast(handle_mem->element_count())}; + auto handle_mem_ = TensorWrapper(handle_mem->untyped_data(), hm_shape, DType::kByte); + + NVTEEpLayerConfig layer_cfg{static_cast(config.top_k), + static_cast(config.dispatch_output_per_expert_alignment)}; + nvte_ep_prepare(handle_mem_.data(), topk_idx_.data(), token_counts_.data(), layer_cfg, stream); + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(EpPrepareHandler, EpPrepareFFI, + FFI::Bind() + .Ctx() // stream + .Ctx<::xla::ffi::State>() // EP state + .Arg() // topk_idx + .Ret() // token_counts + .Ret() // handle_mem + .Ret() // workspace (FFI scratch) + .Attrs(), + FFI_CudaGraph_Traits); + +// ── ep_dispatch ─────────────────────────────────────────────────────────────── + +Error_Type EpDispatchFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type handle_mem, + Buffer_Type topk_idx, Buffer_Type tokens, Buffer_Type topk_weights, + Result_Type recv_tokens, Result_Type recv_topk_weights, + Result_Type workspace, EpDispatchConfig config) { + (void)ep_state; + auto token_dims = tokens.dimensions(); + NVTE_CHECK(token_dims.size() >= 2, + "tokens must be at least 2D [..., H], got ndim=", token_dims.size()); + + std::vector hm_shape = {static_cast(handle_mem.element_count())}; + auto handle_mem_ = TensorWrapper(handle_mem.untyped_data(), hm_shape, DType::kByte); + + auto idx_dims = topk_idx.dimensions(); + NVTE_CHECK(idx_dims.size() >= 2, + "topk_idx must be at least 2D [..., top_k], got ndim=", idx_dims.size()); + auto idx_etype = topk_idx.element_type(); + NVTE_CHECK(idx_etype == ::xla::ffi::DataType::S64 || idx_etype == ::xla::ffi::DataType::S32, + "topk_idx must be int32 or int64; got element_type=", static_cast(idx_etype)); + NVTE_CHECK(static_cast(idx_dims.back()) == config.top_k, "top_k attr (", config.top_k, + ") must match topk_idx last dim (", idx_dims.back(), ")"); + std::vector idx_shape = {product(idx_dims, 0, idx_dims.size() - 1), + static_cast(idx_dims.back())}; + // NCCL EP currently requires int64 topk_idx; upcast int32 on-stream. + // TODO(phuong): drop once NCCL EP accepts int32. + void* topk_idx_data = topk_idx.untyped_data(); + if (idx_etype == ::xla::ffi::DataType::S32) { + const size_t n = idx_shape[0] * idx_shape[1]; + NVTE_CHECK(static_cast(workspace->element_count()) >= n, + "workspace too small for int32 → int64 upcast: element_count=", + workspace->element_count(), " < required ", n); + int64_t* ws = reinterpret_cast(workspace->untyped_data()); + nvte_convert_int32_to_int64(reinterpret_cast(topk_idx_data), ws, n, stream); + topk_idx_data = ws; + } + auto topk_idx_ = TensorWrapper(topk_idx_data, idx_shape, DType::kInt64); + + const size_t T_flat = product(token_dims, 0, token_dims.size() - 1); + const size_t H = static_cast(token_dims.back()); + std::vector tok_shape = {T_flat, H}; + auto token_dtype = convert_ffi_datatype_to_te_dtype(tokens.element_type()); + auto tokens_ = TensorWrapper(tokens.untyped_data(), tok_shape, token_dtype); + + auto tw_dims = topk_weights.dimensions(); + NVTE_CHECK(tw_dims.size() >= 2, + "topk_weights must be at least 2D [..., top_k], got ndim=", tw_dims.size()); + std::vector tw_shape = {product(tw_dims, 0, tw_dims.size() - 1), + static_cast(tw_dims.back())}; + auto topk_weights_ = TensorWrapper(topk_weights.untyped_data(), tw_shape, DType::kFloat32); + + // recv_tokens: flatten any leading dims into recv_capacity_per_rank. + auto recv_dims = recv_tokens->dimensions(); + NVTE_CHECK(recv_dims.size() >= 2, + "recv_tokens must be at least 2D [..., recv_pr, H]; got ndim=", recv_dims.size()); + const size_t recv_capacity_per_rank = product(recv_dims, 0, recv_dims.size() - 1); + std::vector recv_shape = {recv_capacity_per_rank, H}; + auto recv_tokens_ = TensorWrapper(recv_tokens->untyped_data(), recv_shape, token_dtype); + + auto recv_w_dims = recv_topk_weights->dimensions(); + NVTE_CHECK(recv_w_dims.size() >= 1, + "recv_topk_weights must be at least 1D; got ndim=", recv_w_dims.size()); + const size_t recv_w_total = product(recv_w_dims, 0, recv_w_dims.size()); + NVTE_CHECK(recv_w_total == recv_capacity_per_rank, "recv_topk_weights total (", recv_w_total, + ") must match recv_tokens recv_pr (", recv_capacity_per_rank, ")"); + std::vector recv_w_shape = {recv_capacity_per_rank}; + auto recv_topk_weights_ = + TensorWrapper(recv_topk_weights->untyped_data(), recv_w_shape, DType::kFloat32); + + NVTECommWindow no_win{nullptr, 0}; + nvte_ep_dispatch(handle_mem_.data(), topk_idx_.data(), tokens_.data(), no_win, + topk_weights_.data(), no_win, recv_tokens_.data(), no_win, + recv_topk_weights_.data(), no_win, stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(EpDispatchHandler, EpDispatchFFI, + FFI::Bind() + .Ctx() // stream + .Ctx<::xla::ffi::State>() // EP state + .Arg() // handle_mem + .Arg() // topk_idx + .Arg() // tokens + .Arg() // topk_weights + .Ret() // recv_tokens + .Ret() // recv_topk_weights + .Ret() // workspace (FFI scratch) + .Attrs(), + FFI_CudaGraph_Traits); + +// ── ep_combine ──────────────────────────────────────────────────────────────── + +Error_Type EpCombineFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type handle_mem, + Buffer_Type expert_out, Result_Type result, EpCombineConfig config) { + (void)ep_state; + auto eo_dims = expert_out.dimensions(); + NVTE_CHECK(eo_dims.size() >= 2, + "expert_out must be at least 2D [..., recv_pr, H]; got ndim=", eo_dims.size()); + + std::vector hm_shape = {static_cast(handle_mem.element_count())}; + auto handle_mem_ = TensorWrapper(handle_mem.untyped_data(), hm_shape, DType::kByte); + + const size_t recv_capacity_per_rank = product(eo_dims, 0, eo_dims.size() - 1); + const size_t H = static_cast(eo_dims.back()); + std::vector eo_shape = {recv_capacity_per_rank, H}; + auto eo_dtype = convert_ffi_datatype_to_te_dtype(expert_out.element_type()); + auto expert_out_ = TensorWrapper(expert_out.untyped_data(), eo_shape, eo_dtype); + + auto res_dims = result->dimensions(); + NVTE_CHECK(res_dims.size() >= 2, + "result must be at least 2D [..., H]; got ndim=", res_dims.size()); + const size_t res_T_flat = product(res_dims, 0, res_dims.size() - 1); + NVTE_CHECK(static_cast(res_T_flat) == config.num_local_tokens, + "result leading-dim product (", res_T_flat, ") must equal num_local_tokens (", + config.num_local_tokens, ")"); + std::vector res_shape = {res_T_flat, H}; + auto result_ = TensorWrapper(result->untyped_data(), res_shape, eo_dtype); + + NVTECommWindow no_win{nullptr, 0}; + nvte_ep_combine(handle_mem_.data(), expert_out_.data(), no_win, result_.data(), stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(EpCombineHandler, EpCombineFFI, + FFI::Bind() + .Ctx() // stream + .Ctx<::xla::ffi::State>() // EP state + .Arg() // handle_mem + .Arg() // expert_out + .Ret() // result + .Attrs(), + FFI_CudaGraph_Traits); + +// ── ep_dispatch_bwd ─────────────────────────────────────────────────────────── + +Error_Type EpDispatchBwdFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type handle_mem, + Buffer_Type grad, Buffer_Type g_recv_topk_weights, + Result_Type grad_tokens, Result_Type grad_topk_weights, + EpDispatchBwdConfig config) { + (void)ep_state; + auto grad_dims = grad.dimensions(); + NVTE_CHECK(grad_dims.size() >= 2, + "grad must be at least 2D [..., recv_pr, H]; got ndim=", grad_dims.size()); + + std::vector hm_shape = {static_cast(handle_mem.element_count())}; + auto handle_mem_ = TensorWrapper(handle_mem.untyped_data(), hm_shape, DType::kByte); + + const size_t recv_capacity_per_rank = product(grad_dims, 0, grad_dims.size() - 1); + const size_t H = static_cast(grad_dims.back()); + std::vector g_shape = {recv_capacity_per_rank, H}; + auto g_dtype = convert_ffi_datatype_to_te_dtype(grad.element_type()); + auto grad_ = TensorWrapper(grad.untyped_data(), g_shape, g_dtype); + + auto gw_dims = g_recv_topk_weights.dimensions(); + NVTE_CHECK( + gw_dims.size() >= 1, + "g_recv_topk_weights rank must flatten to recv_capacity_per_rank; got ndim=", gw_dims.size()); + const size_t gw_total = product(gw_dims, 0, gw_dims.size()); + NVTE_CHECK(gw_total == recv_capacity_per_rank, "g_recv_topk_weights total (", gw_total, + ") must match grad recv_pr (", recv_capacity_per_rank, ")"); + std::vector gw_shape = {recv_capacity_per_rank}; + auto g_recv_topk_weights_ = + TensorWrapper(g_recv_topk_weights.untyped_data(), gw_shape, DType::kFloat32); + + auto out_dims = grad_tokens->dimensions(); + NVTE_CHECK(out_dims.size() >= 2, + "grad_tokens must be at least 2D [..., H], got ndim=", out_dims.size()); + const size_t T_flat = product(out_dims, 0, out_dims.size() - 1); + NVTE_CHECK(static_cast(T_flat) == config.num_local_tokens, + "grad_tokens leading-dim product (", T_flat, ") must equal num_local_tokens (", + config.num_local_tokens, ")"); + std::vector out_shape = {T_flat, H}; + auto grad_tokens_ = TensorWrapper(grad_tokens->untyped_data(), out_shape, g_dtype); + + auto gtw_dims = grad_topk_weights->dimensions(); + NVTE_CHECK(gtw_dims.size() >= 2, + "grad_topk_weights must be at least 2D [..., top_k]; got ndim=", gtw_dims.size()); + const size_t gtw_T_flat = product(gtw_dims, 0, gtw_dims.size() - 1); + NVTE_CHECK(gtw_T_flat == T_flat, "grad_topk_weights leading-dim product (", gtw_T_flat, + ") must equal grad_tokens leading-dim product (", T_flat, ")"); + const size_t top_k = static_cast(gtw_dims.back()); + NVTE_CHECK(static_cast(top_k) == config.top_k, "top_k attr (", config.top_k, + ") must match grad_topk_weights last dim (", top_k, ")"); + std::vector gtw_shape = {T_flat, top_k}; + auto grad_topk_weights_ = + TensorWrapper(grad_topk_weights->untyped_data(), gtw_shape, DType::kFloat32); + + NVTECommWindow no_win{nullptr, 0}; + nvte_ep_dispatch_bwd(handle_mem_.data(), grad_.data(), no_win, g_recv_topk_weights_.data(), + no_win, grad_tokens_.data(), grad_topk_weights_.data(), stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(EpDispatchBwdHandler, EpDispatchBwdFFI, + FFI::Bind() + .Ctx() // stream + .Ctx<::xla::ffi::State>() // EP state + .Arg() // handle_mem + .Arg() // grad (w.r.t. recv_tokens) + .Arg() // g_recv_topk_weights + .Ret() // grad_tokens + .Ret() // grad_topk_weights + .Attrs(), + FFI_CudaGraph_Traits); + +// ── ep_combine_bwd ──────────────────────────────────────────────────────────── + +Error_Type EpCombineBwdFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type handle_mem, + Buffer_Type grad, Result_Type grad_expert_out, + EpCombineBwdConfig config) { + (void)ep_state; + auto grad_dims = grad.dimensions(); + NVTE_CHECK(grad_dims.size() >= 2, + "grad must be at least 2D [..., H], got ndim=", grad_dims.size()); + + std::vector hm_shape = {static_cast(handle_mem.element_count())}; + auto handle_mem_ = TensorWrapper(handle_mem.untyped_data(), hm_shape, DType::kByte); + + const size_t T_flat = product(grad_dims, 0, grad_dims.size() - 1); + const size_t H = static_cast(grad_dims.back()); + std::vector g_shape = {T_flat, H}; + auto g_dtype = convert_ffi_datatype_to_te_dtype(grad.element_type()); + auto grad_ = TensorWrapper(grad.untyped_data(), g_shape, g_dtype); + + auto out_dims = grad_expert_out->dimensions(); + NVTE_CHECK(out_dims.size() >= 2, + "grad_expert_out must be at least 2D [..., recv_pr, H]; got ndim=", out_dims.size()); + const size_t recv_capacity_per_rank = product(out_dims, 0, out_dims.size() - 1); + const size_t out_H = static_cast(out_dims.back()); + NVTE_CHECK(out_H == H, "grad_expert_out hidden dim (", out_H, ") must match grad H (", H, ")"); + std::vector out_shape = {recv_capacity_per_rank, H}; + auto grad_expert_out_ = TensorWrapper(grad_expert_out->untyped_data(), out_shape, g_dtype); + + NVTECommWindow no_win{nullptr, 0}; + nvte_ep_combine_bwd(handle_mem_.data(), grad_.data(), no_win, grad_expert_out_.data(), no_win, + stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(EpCombineBwdHandler, EpCombineBwdFFI, + FFI::Bind() + .Ctx() // stream + .Ctx<::xla::ffi::State>() // EP state + .Arg() // handle_mem + .Arg() // grad (w.r.t. result) + .Ret() // grad_expert_out + .Attrs(), + FFI_CudaGraph_Traits); + +} // namespace jax +} // namespace transformer_engine + +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( + transformer_engine::jax::EpPrepareConfig, ::xla::ffi::StructMember("top_k"), + ::xla::ffi::StructMember("dispatch_output_per_expert_alignment")); + +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( + transformer_engine::jax::EpDispatchConfig, ::xla::ffi::StructMember("top_k"), + ::xla::ffi::StructMember("dispatch_output_per_expert_alignment")); + +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( + transformer_engine::jax::EpCombineConfig, ::xla::ffi::StructMember("top_k"), + ::xla::ffi::StructMember("dispatch_output_per_expert_alignment"), + ::xla::ffi::StructMember("num_local_tokens")); + +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( + transformer_engine::jax::EpDispatchBwdConfig, ::xla::ffi::StructMember("top_k"), + ::xla::ffi::StructMember("dispatch_output_per_expert_alignment"), + ::xla::ffi::StructMember("num_local_tokens")); + +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( + transformer_engine::jax::EpCombineBwdConfig, ::xla::ffi::StructMember("top_k"), + ::xla::ffi::StructMember("dispatch_output_per_expert_alignment")); + +#endif // NVTE_WITH_NCCL_EP diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 2432f65005..db5468afe6 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -107,6 +107,25 @@ pybind11::dict Registrations() { dict["te_fused_moe_aux_loss_forward_ffi"] = EncapsulateFFI(FusedMoEAuxLossForwardHandler); dict["te_fused_moe_aux_loss_backward_ffi"] = EncapsulateFFI(FusedMoEAuxLossBackwardHandler); +#ifdef NVTE_WITH_NCCL_EP + // Expert Parallelism (instantiate handler pins NCCL comm to executable lifetime). + dict["te_ep_prepare_ffi"] = + pybind11::dict(pybind11::arg("instantiate") = EncapsulateFFI(EpInstantiateHandler), + pybind11::arg("execute") = EncapsulateFFI(EpPrepareHandler)); + dict["te_ep_dispatch_ffi"] = + pybind11::dict(pybind11::arg("instantiate") = EncapsulateFFI(EpInstantiateHandler), + pybind11::arg("execute") = EncapsulateFFI(EpDispatchHandler)); + dict["te_ep_combine_ffi"] = + pybind11::dict(pybind11::arg("instantiate") = EncapsulateFFI(EpInstantiateHandler), + pybind11::arg("execute") = EncapsulateFFI(EpCombineHandler)); + dict["te_ep_dispatch_bwd_ffi"] = + pybind11::dict(pybind11::arg("instantiate") = EncapsulateFFI(EpInstantiateHandler), + pybind11::arg("execute") = EncapsulateFFI(EpDispatchBwdHandler)); + dict["te_ep_combine_bwd_ffi"] = + pybind11::dict(pybind11::arg("instantiate") = EncapsulateFFI(EpInstantiateHandler), + pybind11::arg("execute") = EncapsulateFFI(EpCombineBwdHandler)); +#endif // NVTE_WITH_NCCL_EP + // TopK dict["te_topk_ffi"] = EncapsulateFFI(TopkHandler); @@ -136,6 +155,18 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("is_collective_gemm_with_cublasmp", &IsCollectiveGemmWithCublasmp); m.def("get_cgemm_num_max_streams", &GetCgemmNumMaxStreams); m.def("get_grouped_gemm_setup_workspace_size", &nvte_get_grouped_gemm_setup_workspace_size); +#ifdef NVTE_WITH_NCCL_EP + m.def("set_ep_bootstrap_params", &SetEpBootstrapParams, pybind11::arg("unique_id_bytes"), + pybind11::arg("ep_size"), pybind11::arg("rank_within_group"), pybind11::arg("num_experts"), + pybind11::arg("max_tokens_per_rank"), pybind11::arg("max_recv_tokens_per_rank"), + pybind11::arg("hidden_dim"), pybind11::arg("max_num_sms"), + pybind11::arg("max_token_dtype")); + m.def("release_ep_resources", &ReleaseEpResources); + m.def("ep_handle_mem_size", &EpHandleMemSize, pybind11::arg("top_k"), + pybind11::arg("dispatch_output_per_expert_alignment") = 0); + m.def("get_ep_instance_state_type_id", &GetEpInstanceStateTypeIdCapsule); + m.def("get_ep_instance_state_type_info", &GetEpInstanceStateTypeInfoCapsule); +#endif // NVTE_WITH_NCCL_EP pybind11::enum_(m, "DType", pybind11::module_local()) .value("kByte", DType::kByte) diff --git a/transformer_engine/jax/ep.py b/transformer_engine/jax/ep.py new file mode 100644 index 0000000000..7b8f638ceb --- /dev/null +++ b/transformer_engine/jax/ep.py @@ -0,0 +1,311 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""JAX Expert Parallelism (EP) API.""" + +import atexit +import ctypes +from functools import partial + +import jax +import jax.numpy as jnp +import jax.experimental.multihost_utils as jmu +import numpy as np + +import transformer_engine_jax +import transformer_engine.jax.cpp_extensions as tex +from transformer_engine.jax.cpp_extensions.misc import jax_dtype_to_te_dtype +from transformer_engine.jax.sharding import global_mesh_resource, get_mesh_axis_size + +ep_prepare = tex.ep_prepare +EpLayerConfig = tex.EpLayerConfig +ep_handle_mem_size = tex.ep_handle_mem_size + +__all__ = [ + "EpLayerConfig", + "ep_bootstrap", + "ep_handle_mem_size", + "ep_prepare", + "ep_dispatch", + "ep_combine", +] + +_atexit_registered = False + + +def _allgather_uid(uid_arr, world_size, uid_size): + """Allgather UID bytes across all processes. + + Tries ``jax.experimental.multihost_utils.process_allgather`` first; + falls back to an XLA collective (process-local sharded global array + replicated via ``jax.jit``) when the multihost helper returns a + short buffer, which has been observed under some launchers. + """ + try: + gathered = jmu.process_allgather(uid_arr, tiled=True) + if gathered.size == world_size * uid_size: + return np.asarray(gathered).reshape(world_size, uid_size) + except Exception: # pylint: disable=broad-except + pass + devices = np.asarray(jax.devices()) + if devices.size != world_size: + raise RuntimeError( + f"_allgather_uid fallback expected {world_size} global devices," + f" got {devices.size}." + ) + mesh = jax.sharding.Mesh(devices, ("_uid_all",)) + sharded = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("_uid_all", None)) + replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) + local = np.asarray(uid_arr).reshape(1, uid_size) + g_in = jax.make_array_from_process_local_data(sharded, local, (world_size, uid_size)) + g_out = jax.jit(lambda x: x, out_shardings=replicated)(g_in) + return np.asarray(g_out).reshape(world_size, uid_size) + + +# ── Bootstrap ──────────────────────────────────────────────────────────────── + + +def ep_bootstrap( + world_size, + rank, + ep_size, + num_experts, + max_tokens_per_rank, + recv_capacity_per_rank, + hidden_dim, + max_token_dtype=jnp.bfloat16, + max_num_sms=0, +): + """Initialize the EP communicator. Call once per process before any EP op. + + max_token_dtype is the widest jnp dtype the group will dispatch; tensors + passed to ``ep_dispatch`` may use any narrower dtype. + max_num_sms caps the SMs allotted to EP kernels (0 = auto). + """ + if jnp.dtype(max_token_dtype) != jnp.bfloat16: + raise NotImplementedError( + f"ep_bootstrap: only max_token_dtype=jnp.bfloat16 is supported today, got" + f" {jnp.dtype(max_token_dtype)}." + ) + if world_size < 2: + raise ValueError( + f"ep_bootstrap requires world_size >= 2 (got {world_size}); NCCL EP needs" + " at least 2 ranks to form a group." + ) + if world_size % ep_size != 0: + raise ValueError( + f"world_size ({world_size}) must be divisible by ep_size ({ep_size}); otherwise" + " some EP groups would have fewer than ep_size ranks and ncclCommInitRank would hang." + ) + if num_experts % ep_size != 0: + raise ValueError(f"num_experts ({num_experts}) must be divisible by ep_size ({ep_size}).") + if jax.local_device_count() != 1: + raise ValueError( + "ep_bootstrap requires one local device per process (got" + f" jax.local_device_count() = {jax.local_device_count()}); NCCL EP does not" + " support single-process multi-device setups." + ) + UID_SIZE = 128 + dp_color = rank // ep_size + rank_within_group = rank % ep_size + is_color_root = rank_within_group == 0 + if is_color_root: + try: + from nccl import get_unique_id + + uid_bytes = bytes(get_unique_id())[:UID_SIZE] + except ImportError: + libnccl = ctypes.CDLL("libnccl.so.2", use_errno=True) + uid_arr = (ctypes.c_uint8 * UID_SIZE)() + ret = libnccl.ncclGetUniqueId(ctypes.cast(uid_arr, ctypes.c_void_p)) + assert ret == 0, f"ncclGetUniqueId failed with code {ret}" + uid_bytes = bytes(uid_arr) + else: + uid_bytes = bytes(UID_SIZE) + + uid_arr = jnp.frombuffer(uid_bytes, dtype=jnp.uint8) + all_uids = _allgather_uid(uid_arr, world_size, UID_SIZE) + uid_bytes = bytes(np.asarray(all_uids[dp_color * ep_size]).tolist()) + + ep_resource = global_mesh_resource().ep_resource + if ep_resource is None: + raise ValueError( + "ep_bootstrap requires MeshResource.ep_resource to be set; enter a" + " global_shard_guard(MeshResource(..., ep_resource=)) before bootstrap." + ) + mesh_ep_size = get_mesh_axis_size(ep_resource) + if mesh_ep_size != ep_size: + raise ValueError( + f"ep_bootstrap: EpConfig.ep_size ({ep_size}) does not match mesh axis" + f" '{ep_resource}' size ({mesh_ep_size})." + ) + + # Eager NCCL init while ranks are barrier-synced by the UID broadcast above. + transformer_engine_jax.set_ep_bootstrap_params( + uid_bytes, + ep_size, + rank_within_group, + num_experts, + max_tokens_per_rank, + recv_capacity_per_rank, + hidden_dim, + max_num_sms=int(max_num_sms), + max_token_dtype=int(jax_dtype_to_te_dtype(max_token_dtype)), + ) + + # Release the C++ anchor at interpreter shutdown so RAII can tear down NCCL. + global _atexit_registered + if not _atexit_registered: + atexit.register(transformer_engine_jax.release_ep_resources) + _atexit_registered = True + + tex.ep.set_ep_config( + tex.ep.EpConfig( + world_size=world_size, + rank=rank, + ep_size=ep_size, + num_experts=num_experts, + num_local_experts=num_experts // ep_size, + max_tokens_per_rank=max_tokens_per_rank, + recv_capacity_per_rank=recv_capacity_per_rank, + hidden_dim=hidden_dim, + ) + ) + + +# ── ep_dispatch (custom_vjp) ───────────────────────────────────────────────── + + +@partial(jax.custom_vjp, nondiff_argnums=(0, 4)) +def ep_dispatch(cfg, topk_idx, tokens, topk_weights, recv_capacity_per_rank): + """Scatter tokens and weights to expert ranks. + + ``cfg`` is a per-layer ``EpLayerConfig``; distinct layers may share a + ``cfg`` (the pointer-keyed C++ cache keys on handle_mem, not on cfg). + Inputs are 2D ``[T, H]`` or 3D ``[B, S, H]`` with only the leading dim + sharded (axis in {ep, (dp, ep), dp, None}). Returns + ``(recv_tokens, recv_topk_weights, handle_mem, token_counts)``; pass + ``handle_mem`` and ``token_counts`` to the matching ``ep_combine``. + """ + return _dispatch_fwd(cfg, topk_idx, tokens, topk_weights, recv_capacity_per_rank)[0] + + +def _dispatch_fwd(cfg, topk_idx, tokens, topk_weights, recv_capacity_per_rank): + token_counts, handle_mem = tex.ep_prepare(cfg, topk_idx) + recv_tokens, recv_topk_weights = tex.ep_dispatch_fwd( + cfg, handle_mem, topk_idx, tokens, topk_weights, recv_capacity_per_rank + ) + out_leading = tuple(tokens.shape[:-1]) + primal = (recv_tokens, recv_topk_weights, handle_mem, token_counts) + return primal, (handle_mem, out_leading) + + +def _dispatch_bwd(cfg, recv_capacity_per_rank, res, g_outputs): + del recv_capacity_per_rank + handle_mem, out_leading = res + # Re-pin cotangent sharding: XLA transpose can drop the EP axis on a + # single-fwd-output cotangent, landing a global tensor in the FFI. + gsr = global_mesh_resource() + ep_axis = gsr.ep_resource + outer = gsr.dp_resource or gsr.fsdp_resource + leading = (outer, ep_axis) if outer is not None else ep_axis + g_recv_tokens = jax.lax.with_sharding_constraint( + g_outputs[0], jax.sharding.PartitionSpec(leading, None, None) + ) + g_recv_topk_weights = jax.lax.with_sharding_constraint( + g_outputs[1], jax.sharding.PartitionSpec(leading, None) + ) + grad_tokens, grad_topk_weights = tex.ep_dispatch_bwd( + cfg, handle_mem, g_recv_tokens, g_recv_topk_weights, out_leading + ) + return (None, grad_tokens, grad_topk_weights) + + +ep_dispatch.defvjp(_dispatch_fwd, _dispatch_bwd) + + +# ── ep_combine (custom_vjp) ────────────────────────────────────────────────── + + +@partial(jax.custom_vjp, nondiff_argnums=(0, 5, 6)) +def ep_combine( + cfg, handle_mem, token_counts, expert_out, recv_topk_weights, + num_local_tokens, out_sharding=None, +): + """Reduce weighted expert outputs back to source ranks. + + Args: + cfg: ``EpLayerConfig`` matching the ``ep_dispatch`` call. + handle_mem: Routing-state buffer returned by ``ep_dispatch``. + token_counts: ``[num_procs, num_local_experts]`` int32 (passed through). + expert_out: ``[num_procs, recv_capacity_per_rank, H]`` post-FFN activations. + recv_topk_weights: ``[num_procs, recv_capacity_per_rank]`` float32 weights + returned by ``ep_dispatch``. + num_local_tokens: STATIC int or tuple. int -> 2D output ``[T, H]``; + tuple -> N-D output ``[*tuple, H]``. + out_sharding: STATIC optional ``PartitionSpec`` tuple for the + output. Defaults to ``(("dp","ep"), *None)`` when + DP is set, else ``("ep", *None)``. Only the leading + dim may be sharded. + + Returns: + ``[..., H]`` combined output shaped per ``num_local_tokens``. + """ + return _combine_fwd( + cfg, handle_mem, token_counts, expert_out, recv_topk_weights, + num_local_tokens, out_sharding, + )[0] + + +def _make_valid_mask(recv_topk_weights, dtype): + # recv_topk_weights == 0 marks a padded slot. + return (recv_topk_weights != 0).astype(dtype)[..., None] + + +def _combine_fwd( + cfg, handle_mem, token_counts, expert_out, recv_topk_weights, + num_local_tokens, out_sharding, +): + del token_counts + w = recv_topk_weights[..., None] + mask = _make_valid_mask(recv_topk_weights, jnp.float32) + weighted = (expert_out.astype(jnp.float32) * w * mask).astype(expert_out.dtype) + result = tex.ep_combine_fwd( + cfg, handle_mem, weighted, num_local_tokens, out_partition_spec=out_sharding + ) + return result, (handle_mem, recv_topk_weights, expert_out) + + +def _combine_bwd(cfg, _num_local_tokens, _out_sharding, res, g_result): + handle_mem, recv_topk_weights, expert_out = res + # expert_out is [..., recv_pr, H]; pull recv_pr from the second-to-last dim. + recv_capacity_per_rank = expert_out.shape[-2] + # Re-pin cotangent sharding: same XLA-transpose workaround as _dispatch_bwd. + gsr = global_mesh_resource() + if _out_sharding is not None: + spec = jax.sharding.PartitionSpec(*_out_sharding) + else: + ep_axis = gsr.ep_resource + outer = gsr.dp_resource or gsr.fsdp_resource + leading = (outer, ep_axis) if outer is not None and ep_axis is not None else ep_axis + spec = ( + jax.sharding.PartitionSpec(leading, *([None] * (g_result.ndim - 1))) + if leading is not None + else None + ) + if spec is not None: + g_result = jax.lax.with_sharding_constraint(g_result, spec) + grad_weighted = tex.ep_combine_bwd(cfg, handle_mem, g_result, recv_capacity_per_rank) + w = recv_topk_weights[..., None] + mask = _make_valid_mask(recv_topk_weights, jnp.float32) + grad_weighted_f32 = grad_weighted.astype(jnp.float32) + grad_expert_out = (grad_weighted_f32 * w * mask).astype(grad_weighted.dtype) + grad_recv_topk_weights = ( + (grad_weighted_f32 * expert_out.astype(jnp.float32) * mask) + .sum(axis=-1) + .astype(recv_topk_weights.dtype) + ) + return (None, None, grad_expert_out, grad_recv_topk_weights) + + +ep_combine.defvjp(_combine_fwd, _combine_bwd) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index da527fdf18..2e8e611fa3 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -331,7 +331,12 @@ class MeshResource: fsdp_resource: Axis name for full-sharded data parallelism, default is None pp_resource: Axis name for pipeline parallelism (layer sharding), default is None cp_resource: Axis name for context parallelism (sequence sharding), default is None - ep_resource: Axis name for expert parallelism (MoE expert sharding), default is None + ep_resource: Axis name for expert parallelism. Dispatch input tokens + must be sharded on their leading dim by ``ep_resource`` (alone or + compound with ``dp_resource`` / ``fsdp_resource`` as outer, e.g. + ``PartitionSpec(("dp", "ep"), None, None)``). Dispatch output + ``[ep_size, recv_capacity, H]`` is always sharded by ``ep_resource`` + on the leading ``ep_size`` dim. """ dp_resource: str = None @@ -474,3 +479,8 @@ def dp_or_fsdp_axis_size(): dp_size = get_mesh_axis_size(global_mesh_resource().dp_resource) fsdp_size = get_mesh_axis_size(global_mesh_resource().fsdp_resource) return dp_size if dp_size > 1 else fsdp_size + + +def ep_axis_size(): + """Get the size of the dispatch/EP axis (ep_resource). Returns 1 if unset.""" + return get_mesh_axis_size(global_mesh_resource().ep_resource) From fb3300fe724d157ed760fa1394315f6dd01bb0d2 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 8 Jun 2026 14:31:50 -0700 Subject: [PATCH 02/22] jax/ep: drop topk_weights from ep_combine; caller must pre-multiply Signed-off-by: Phuong Nguyen --- examples/jax/ep/ep_moe.py | 10 ++++-- tests/jax/test_multi_process_ep.py | 30 ++++++++++++------ transformer_engine/jax/ep.py | 50 +++++++++++------------------- 3 files changed, 47 insertions(+), 43 deletions(-) diff --git a/examples/jax/ep/ep_moe.py b/examples/jax/ep/ep_moe.py index 7b3601fb60..8a81ccb788 100644 --- a/examples/jax/ep/ep_moe.py +++ b/examples/jax/ep/ep_moe.py @@ -216,12 +216,18 @@ def step(topk_idx, tokens, topk_w, local_kernels): recv_topk_w = jax.lax.with_sharding_constraint(recv_topk_w, NamedSharding(mesh, ep2)) expert_out = _batched_expert_linear(recv_tokens, local_kernels, NLE, dp_size, ep_size) expert_out = jax.lax.with_sharding_constraint(expert_out, NamedSharding(mesh, ep3)) + # ep_combine is unweighted: pre-multiply by recv_topk_w and zero + # padded slots (recv_topk_w == 0) before the scatter-sum. + mask = (recv_topk_w != 0).astype(jnp.float32)[..., None] + weighted = ( + expert_out.astype(jnp.float32) * recv_topk_w[..., None] * mask + ).astype(expert_out.dtype) + weighted = jax.lax.with_sharding_constraint(weighted, NamedSharding(mesh, ep3)) return ep_combine( ep_handle, handle_mem, _tc, - expert_out, - recv_topk_w, + weighted, num_local_tokens=(B, S), out_sharding=(("dp", "ep"), None, None), ) diff --git a/tests/jax/test_multi_process_ep.py b/tests/jax/test_multi_process_ep.py index edfac0f82c..5500ae13a7 100644 --- a/tests/jax/test_multi_process_ep.py +++ b/tests/jax/test_multi_process_ep.py @@ -203,6 +203,13 @@ def _make_random_inputs(self, seed=42, nonuniform=True): topk_weights = jnp.asarray(np.full((T_dp, TOP_K), 1.0 / TOP_K, dtype=np.float32)) return T_dp, tokens, topk_idx, topk_weights + @staticmethod + def _preweight_expert_out(expert_out, recv_topk_weights): + """ep_combine is unweighted; mirror the caller-side weighting + mask.""" + mask = (recv_topk_weights != 0).astype(jnp.float32)[..., None] + w = recv_topk_weights[..., None] + return (expert_out.astype(jnp.float32) * w * mask).astype(expert_out.dtype) + # ── Individual primitives (cpp_extensions level) ────────────────────── def test_two_handle_mems_no_aliasing(self): @@ -255,8 +262,9 @@ def one_layer(hk, idx, toks, w_): ) recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(self.mesh, ep_spec_3d)) recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(self.mesh, ep_spec_2d)) + weighted = self._preweight_expert_out(recv_t, recv_w) return ep_combine( - hk, hm, tc, recv_t, recv_w, T_global, out_sharding=(("dp", "ep"), None) + hk, hm, tc, weighted, T_global, out_sharding=(("dp", "ep"), None) ) @jax.jit @@ -383,8 +391,9 @@ def loss_fn(toks): recv_w = jax.lax.with_sharding_constraint( recv_w, NamedSharding(self.mesh, ep_spec_2d) ) + weighted = self._preweight_expert_out(recv_t, recv_w) out = ep_combine( - self.hk, hm, tc, recv_t, recv_w, T_global, out_sharding=(("dp", "ep"), None) + self.hk, hm, tc, weighted, T_global, out_sharding=(("dp", "ep"), None) ) return 0.5 * (out.astype(jnp.float32) ** 2).sum() @@ -427,12 +436,12 @@ def run(idx, toks, w): recv_t, recv_w, hm, _tc = ep_dispatch(self.hk, idx, toks, w, self.recv_capacity_per_rank) recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(self.mesh, ep_t)) recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(self.mesh, ep_w)) + weighted = self._preweight_expert_out(recv_t, recv_w) out = ep_combine( self.hk, hm, _tc, - recv_t, - recv_w, + weighted, num_local_tokens=(B, S), out_sharding=out_spec_3d, ) @@ -470,12 +479,12 @@ def run(idx, toks, w): recv_t, recv_w, hm, _tc = ep_dispatch(self.hk, idx, toks, w, self.recv_capacity_per_rank) recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(self.mesh, ep_t)) recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(self.mesh, ep_w)) + weighted = self._preweight_expert_out(recv_t, recv_w) out = ep_combine( self.hk, hm, _tc, - recv_t, - recv_w, + weighted, num_local_tokens=T_global, out_sharding=(("dp", "ep"), None), ) @@ -565,7 +574,8 @@ def loss_fn(eo): recv_w = jax.lax.with_sharding_constraint( recv_w, NamedSharding(self.mesh, PartitionSpec(("dp", "ep"), None)) ) - combined = ep_combine(self.hk, hm, tc, eo, recv_w, T_global) + weighted = self._preweight_expert_out(eo, recv_w) + combined = ep_combine(self.hk, hm, tc, weighted, T_global) # Pin combined to dp-sharded so autodiff transpose feeds # ep_combine_bwd a per-shard cotangent. combined = jax.lax.with_sharding_constraint( @@ -652,8 +662,9 @@ def run(idx, toks, w): recv_w = jax.lax.with_sharding_constraint( recv_w, NamedSharding(self.mesh, ep_spec_2d) ) + weighted = self._preweight_expert_out(recv_t, recv_w) out = ep_combine( - self.hk, hm, tc, recv_t, recv_w, T_dp, out_sharding=(("dp", "ep"), None) + self.hk, hm, tc, weighted, T_dp, out_sharding=(("dp", "ep"), None) ) return jax.lax.with_sharding_constraint(out, NamedSharding(self.mesh, dp_spec)) @@ -692,7 +703,8 @@ def fwd(eo, toks, idx, w): w = jax.lax.with_sharding_constraint(w, NamedSharding(self.mesh, dp_spec)) _rt, rw, hm, tc = ep_dispatch(self.hk, idx, toks, w, self.recv_capacity_per_rank) rw = jax.lax.with_sharding_constraint(rw, NamedSharding(self.mesh, ep_spec_2d)) - combined = ep_combine(self.hk, hm, tc, eo, rw, T_dp, out_sharding=(("dp", "ep"), None)) + weighted = self._preweight_expert_out(eo, rw) + combined = ep_combine(self.hk, hm, tc, weighted, T_dp, out_sharding=(("dp", "ep"), None)) return jax.lax.with_sharding_constraint(combined, NamedSharding(self.mesh, dp_spec)) # jax.vjp + pinned cotangent feeds ep_combine_bwd/ep_dispatch_bwd diff --git a/transformer_engine/jax/ep.py b/transformer_engine/jax/ep.py index 7b8f638ceb..47ef4d89ed 100644 --- a/transformer_engine/jax/ep.py +++ b/transformer_engine/jax/ep.py @@ -227,20 +227,25 @@ def _dispatch_bwd(cfg, recv_capacity_per_rank, res, g_outputs): # ── ep_combine (custom_vjp) ────────────────────────────────────────────────── -@partial(jax.custom_vjp, nondiff_argnums=(0, 5, 6)) +@partial(jax.custom_vjp, nondiff_argnums=(0, 4, 5)) def ep_combine( - cfg, handle_mem, token_counts, expert_out, recv_topk_weights, + cfg, handle_mem, token_counts, expert_out, num_local_tokens, out_sharding=None, ): - """Reduce weighted expert outputs back to source ranks. + """Scatter-sum expert outputs back to source ranks. **Unweighted.** + + ``ep_combine`` does not apply ``recv_topk_weights`` or any padded-slot + mask. The caller must pre-multiply ``expert_out`` by the dispatched + weights (and zero padded slots) before calling. Gradients w.r.t. + ``recv_topk_weights`` therefore flow through the caller's hadamard, not + through this op. Args: cfg: ``EpLayerConfig`` matching the ``ep_dispatch`` call. handle_mem: Routing-state buffer returned by ``ep_dispatch``. token_counts: ``[num_procs, num_local_experts]`` int32 (passed through). - expert_out: ``[num_procs, recv_capacity_per_rank, H]`` post-FFN activations. - recv_topk_weights: ``[num_procs, recv_capacity_per_rank]`` float32 weights - returned by ``ep_dispatch``. + expert_out: ``[num_procs, recv_capacity_per_rank, H]`` pre-weighted + post-FFN activations. num_local_tokens: STATIC int or tuple. int -> 2D output ``[T, H]``; tuple -> N-D output ``[*tuple, H]``. out_sharding: STATIC optional ``PartitionSpec`` tuple for the @@ -252,34 +257,24 @@ def ep_combine( ``[..., H]`` combined output shaped per ``num_local_tokens``. """ return _combine_fwd( - cfg, handle_mem, token_counts, expert_out, recv_topk_weights, + cfg, handle_mem, token_counts, expert_out, num_local_tokens, out_sharding, )[0] -def _make_valid_mask(recv_topk_weights, dtype): - # recv_topk_weights == 0 marks a padded slot. - return (recv_topk_weights != 0).astype(dtype)[..., None] - - def _combine_fwd( - cfg, handle_mem, token_counts, expert_out, recv_topk_weights, + cfg, handle_mem, token_counts, expert_out, num_local_tokens, out_sharding, ): del token_counts - w = recv_topk_weights[..., None] - mask = _make_valid_mask(recv_topk_weights, jnp.float32) - weighted = (expert_out.astype(jnp.float32) * w * mask).astype(expert_out.dtype) result = tex.ep_combine_fwd( - cfg, handle_mem, weighted, num_local_tokens, out_partition_spec=out_sharding + cfg, handle_mem, expert_out, num_local_tokens, out_partition_spec=out_sharding ) - return result, (handle_mem, recv_topk_weights, expert_out) + return result, (handle_mem, expert_out.shape[-2]) def _combine_bwd(cfg, _num_local_tokens, _out_sharding, res, g_result): - handle_mem, recv_topk_weights, expert_out = res - # expert_out is [..., recv_pr, H]; pull recv_pr from the second-to-last dim. - recv_capacity_per_rank = expert_out.shape[-2] + handle_mem, recv_capacity_per_rank = res # Re-pin cotangent sharding: same XLA-transpose workaround as _dispatch_bwd. gsr = global_mesh_resource() if _out_sharding is not None: @@ -295,17 +290,8 @@ def _combine_bwd(cfg, _num_local_tokens, _out_sharding, res, g_result): ) if spec is not None: g_result = jax.lax.with_sharding_constraint(g_result, spec) - grad_weighted = tex.ep_combine_bwd(cfg, handle_mem, g_result, recv_capacity_per_rank) - w = recv_topk_weights[..., None] - mask = _make_valid_mask(recv_topk_weights, jnp.float32) - grad_weighted_f32 = grad_weighted.astype(jnp.float32) - grad_expert_out = (grad_weighted_f32 * w * mask).astype(grad_weighted.dtype) - grad_recv_topk_weights = ( - (grad_weighted_f32 * expert_out.astype(jnp.float32) * mask) - .sum(axis=-1) - .astype(recv_topk_weights.dtype) - ) - return (None, None, grad_expert_out, grad_recv_topk_weights) + grad_expert_out = tex.ep_combine_bwd(cfg, handle_mem, g_result, recv_capacity_per_rank) + return (None, None, grad_expert_out) ep_combine.defvjp(_combine_fwd, _combine_bwd) From fa907ca3f15d9f5d6c6d34d9920189c67535a286 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 8 Jun 2026 14:31:59 -0700 Subject: [PATCH 03/22] tests/jax/ep: mask uninitialized recv_tokens tail in dispatch_vjp Signed-off-by: Phuong Nguyen --- tests/jax/test_multi_process_ep.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/jax/test_multi_process_ep.py b/tests/jax/test_multi_process_ep.py index 5500ae13a7..bf251682de 100644 --- a/tests/jax/test_multi_process_ep.py +++ b/tests/jax/test_multi_process_ep.py @@ -518,17 +518,27 @@ def test_dispatch_vjp_fwd_bwd(self): with self.mesh, global_shard_guard(self.mr): + align = max(int(self.hk.dispatch_output_per_expert_alignment), 1) + def loss_fn(toks): toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec)) idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec)) - recv_tokens, _recv_w, _hm, _tc = ep_dispatch( + recv_tokens, _recv_w, _hm, tc = ep_dispatch( self.hk, idx, toks, w, self.recv_capacity_per_rank ) recv_tokens = jax.lax.with_sharding_constraint( recv_tokens, NamedSharding(self.mesh, ep_spec_3d) ) - return 0.5 * (recv_tokens.astype(jnp.float32) ** 2).sum() + # ep_dispatch fills only slots [0, sum(padded_per_expert)); + # the tail is uninitialized. Mask with jnp.where (NaN-safe; + # multiply would propagate NaN*0=NaN). + padded = ((tc + align - 1) // align) * align + total_recv = jnp.sum(padded, axis=-1, keepdims=True).astype(jnp.int32) + slot_idx = jnp.arange(self.recv_capacity_per_rank, dtype=jnp.int32) + mask = slot_idx[None, :] < total_recv + rt32 = jnp.where(mask[..., None], recv_tokens.astype(jnp.float32), 0.0) + return 0.5 * (rt32 ** 2).sum() loss, grad_tokens = jax.jit(jax.value_and_grad(loss_fn))(tokens) grad_tokens.block_until_ready() From dedbf8653c37ad4a312877edeb0e9b43915aa186 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 8 Jun 2026 14:32:05 -0700 Subject: [PATCH 04/22] examples/jax/ep: add ep_bench.py + run_ep_bench.sh Signed-off-by: Phuong Nguyen --- examples/jax/ep/bench/ep_bench.py | 327 ++++++++++++++++++++++++ examples/jax/ep/bench/run_ep_bench.sh | 352 ++++++++++++++++++++++++++ 2 files changed, 679 insertions(+) create mode 100644 examples/jax/ep/bench/ep_bench.py create mode 100755 examples/jax/ep/bench/run_ep_bench.sh diff --git a/examples/jax/ep/bench/ep_bench.py b/examples/jax/ep/bench/ep_bench.py new file mode 100644 index 0000000000..01713da990 --- /dev/null +++ b/examples/jax/ep/bench/ep_bench.py @@ -0,0 +1,327 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""JAX EP perf bench — dispatch / combine (raw fwd + custom_vjp wrapper) on a 1DP x EP mesh. + +One process per GPU. Run via run_ep_bench.sh. + +Measured per kernel (separate jits): + * tex_ep.ep_dispatch_fwd (stage: dispatch_fwd) + * ep_dispatch (stage: ep_dispatch_vjp -- custom_vjp wrapper, fwd-only) + * tex_ep.ep_combine_fwd (stage: combine_fwd) + * ep_combine (stage: ep_combine_vjp -- custom_vjp wrapper, fwd-only) +Prepare runs once outside the timed loops. + +Timing: wall-clock (perf_counter) around each iter with NVTX ranges, so +nsys can attribute kernels per stage. Rank-0 prints mean wall in us. +Per-stage kernel breakdown comes from `nsys stats --report nvtx_kern_sum`. +Profiling: if --xplane DIR is set, jax.profiler captures the timed region. +nsys profiling is driven from the shell launcher (see run_ep_bench.sh). +""" + +import argparse +import os +import sys +import time + +import jax +import jax.numpy as jnp +import numpy as np +from jax.sharding import Mesh, NamedSharding, PartitionSpec + +from transformer_engine.jax.cpp_extensions import ep as tex_ep +from transformer_engine.jax.ep import EpLayerConfig, ep_bootstrap, ep_dispatch, ep_combine +from transformer_engine.jax.sharding import MeshResource, global_shard_guard + + +def _parse_args(): + p = argparse.ArgumentParser(description="TE-JAX EP perf bench (dispatch_fwd + combine_fwd)") + p.add_argument("--coordinator-address", required=True) + p.add_argument("--process-id", type=int, required=True) + p.add_argument("--num-processes", type=int, required=True) + p.add_argument("--tokens-per-rank", type=int, default=8192) + p.add_argument("--hidden", type=int, default=7168) + p.add_argument("--top-k", type=int, default=8) + p.add_argument("--num-experts", type=int, default=256) + p.add_argument("--dp-size", type=int, default=1) + p.add_argument("--warmup", type=int, default=2) + p.add_argument("--iters", type=int, default=10) + p.add_argument( + "--max-num-sms", + type=int, + default=0, + help="Max SMs for dispatch / combine / preprocess kernels (0 = auto).", + ) + p.add_argument( + "--mode-label", + default=None, + help="Optional label suffix for NVTX range names so nsys can partition kernels.", + ) + p.add_argument( + "--second-step", + action="store_true", + help=( + "Time only the 2nd step (1 warmup iter, 1 timed iter). Use to isolate " + "JIT-cache-warm-but-no-steady-state-batching overhead from steady-state perf." + ), + ) + p.add_argument( + "--xplane", + default=None, + help="If set, jax.profiler dumps an XPlane trace into this dir (rank 0 only).", + ) + return p.parse_args() + + +def _distributed_init(args): + jax.distributed.initialize( + coordinator_address=args.coordinator_address, + num_processes=args.num_processes, + process_id=args.process_id, + local_device_ids=[args.process_id], + ) + + +def _build_mesh(args): + n = args.num_processes + assert n % args.dp_size == 0 + ep = n // args.dp_size + devs = np.asarray(jax.devices()).reshape(args.dp_size, ep) + return Mesh(devs, ("dp", "ep")), ep + + +def _make_inputs(args, ep_size): + """Identity-style routing (round-robin), uniform top-k weights. + + Globals: ``B = num_processes`` (sharded on compound (dp,ep)), so each rank + sees ``args.tokens_per_rank`` tokens. Tokens/weights are bf16 / fp32; idx + is int32. Rank shards land via with_sharding_constraint inside the jit. + """ + n = args.num_processes + T = args.tokens_per_rank + H = args.hidden + K = args.top_k + E = args.num_experts + del ep_size + + topk_idx = np.empty((n * T, K), dtype=np.int32) + for t in range(n * T): + for k in range(K): + topk_idx[t, k] = (t * K + k) % E + topk_idx = jnp.asarray(topk_idx) + topk_w = jnp.full((n * T, K), 1.0 / K, dtype=jnp.float32) + tokens = jnp.asarray( + np.random.default_rng(0).standard_normal((n * T, H), dtype=np.float32) * 0.5, + dtype=jnp.bfloat16, + ) + return tokens, topk_idx, topk_w + + +def main(): + args = _parse_args() + _distributed_init(args) + mesh, ep_size = _build_mesh(args) + mr = MeshResource(dp_resource="dp", ep_resource="ep") + rank = args.process_id + + local_experts = args.num_experts // ep_size + recv_capacity_per_rank = args.num_processes * args.tokens_per_rank * args.top_k // 2 + + if rank == 0: + print( + f"[ep_bench] world={args.num_processes} dp={args.dp_size} ep={ep_size}" + f" T={args.tokens_per_rank} H={args.hidden} K={args.top_k}" + f" E={args.num_experts} (local={local_experts}) recv_pr={recv_capacity_per_rank}" + + (f" mode={args.mode_label}" if args.mode_label else ""), + flush=True, + ) + + nvtx_suffix = f"[{args.mode_label}]" if args.mode_label else "" + + in_spec = PartitionSpec(("dp", "ep"), None) + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + ep_spec_2d = PartitionSpec(("dp", "ep"), None) + out_spec = (("dp", "ep"), None) + T_global = args.num_processes * args.tokens_per_rank + + with mesh, global_shard_guard(mr): + ep_bootstrap( + world_size=args.num_processes, + rank=rank, + ep_size=ep_size, + num_experts=args.num_experts, + max_tokens_per_rank=args.tokens_per_rank, + recv_capacity_per_rank=recv_capacity_per_rank, + hidden_dim=args.hidden, + max_num_sms=args.max_num_sms, + ) + + tokens, topk_idx, topk_w = _make_inputs(args, ep_size) + idx_s = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(mesh, in_spec)) + tok_s = jax.lax.with_sharding_constraint(tokens, NamedSharding(mesh, in_spec)) + w_s = jax.lax.with_sharding_constraint(topk_w, NamedSharding(mesh, in_spec)) + + cfg = EpLayerConfig(top_k=args.top_k, dispatch_output_per_expert_alignment=16) + + @jax.jit + def run_prepare(idx): + tc, hm = tex_ep.ep_prepare(cfg, idx) + return tc, hm + + @jax.jit + def run_dispatch(hm, idx, toks, w): + recv_t, recv_w = tex_ep.ep_dispatch_fwd( + cfg, hm, idx, toks, w, recv_capacity_per_rank + ) + recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(mesh, ep_spec_3d)) + recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(mesh, ep_spec_2d)) + return recv_t, recv_w + + @jax.jit + def run_dispatch_vjp(idx, toks, w): + recv_t, recv_w, _hm, _tc = ep_dispatch(cfg, idx, toks, w, recv_capacity_per_rank) + recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(mesh, ep_spec_3d)) + recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(mesh, ep_spec_2d)) + return recv_t, recv_w + + @jax.jit + def run_combine(hm, recv_t): + out = tex_ep.ep_combine_fwd( + cfg, + hm, + recv_t, + T_global, + out_partition_spec=out_spec, + ) + return out + + @jax.jit + def run_combine_vjp(hm, tc, recv_t): + # ep_combine is unweighted; bench feeds expert_out directly (caller + # would otherwise pre-multiply by recv_topk_weights + mask). + out = ep_combine(cfg, hm, tc, recv_t, T_global, out_sharding=out_spec) + return out + + tc, handle_mem = run_prepare(idx_s) + tc.block_until_ready() + handle_mem.block_until_ready() + + recv_t0, recv_w0 = run_dispatch(handle_mem, idx_s, tok_s, w_s) + recv_t0.block_until_ready() + recv_w0.block_until_ready() + + warmup_n = 1 if args.second_step else args.warmup + iters_n = 1 if args.second_step else args.iters + + for _ in range(warmup_n): + r, _rw = run_dispatch(handle_mem, idx_s, tok_s, w_s) + r.block_until_ready() + o = run_combine(handle_mem, r) + o.block_until_ready() + run_dispatch_vjp(idx_s, tok_s, w_s)[0].block_until_ready() + run_combine_vjp(handle_mem, tc, recv_t0).block_until_ready() + + if args.xplane and rank == 0: + os.makedirs(args.xplane, exist_ok=True) + jax.profiler.start_trace(args.xplane) + + try: + import nvtx as _nvtx + + def _push(name): + _nvtx.push_range(message=name) + + def _pop(): + _nvtx.pop_range() + + except ImportError: + + def _push(name): + pass + + def _pop(): + pass + + def _time_stage_wall_us(name, fn): + # First timed iter still carries an autotune outlier even after JIT + # warmup; run iters_n + 1, drop iter 0 from the average, and push + # the NVTX range AFTER iter 0 so nsys' nvtx_kern_sum excludes the + # outlier too. + total_ns = 0 + counted = 0 + for i in range(iters_n + 1): + if i == 1: + _push(f"{name}{nvtx_suffix}") + t0 = time.perf_counter_ns() + fn() + dt = time.perf_counter_ns() - t0 + if i == 0: + continue + total_ns += dt + counted += 1 + _pop() + return total_ns / 1e3 / counted + + def _do_dispatch(): + r, _ = run_dispatch(handle_mem, idx_s, tok_s, w_s) + r.block_until_ready() + + def _do_dispatch_vjp(): + r, _ = run_dispatch_vjp(idx_s, tok_s, w_s) + r.block_until_ready() + + def _do_combine(): + o = run_combine(handle_mem, recv_t0) + o.block_until_ready() + + def _do_combine_vjp(): + o = run_combine_vjp(handle_mem, tc, recv_t0) + o.block_until_ready() + + d_wall_us = _time_stage_wall_us("dispatch_fwd", _do_dispatch) + dv_wall_us = _time_stage_wall_us("ep_dispatch_vjp", _do_dispatch_vjp) + c_wall_us = _time_stage_wall_us("combine_fwd", _do_combine) + cv_wall_us = _time_stage_wall_us("ep_combine_vjp", _do_combine_vjp) + + if args.xplane and rank == 0: + jax.profiler.stop_trace() + + if rank == 0: + label = f" [{args.mode_label}]" if args.mode_label else "" + print("", flush=True) + print(f"| stage | mean wall (us){label} |", flush=True) + print("|-------------------|---------------:|", flush=True) + print(f"| dispatch_fwd | {d_wall_us:14.1f} |", flush=True) + print(f"| ep_dispatch_vjp | {dv_wall_us:14.1f} |", flush=True) + print(f"| combine_fwd | {c_wall_us:14.1f} |", flush=True) + print(f"| ep_combine_vjp | {cv_wall_us:14.1f} |", flush=True) + print(f"| (dispatch vjp-fwd)| {dv_wall_us - d_wall_us:14.1f} |", flush=True) + print(f"| (combine vjp-fwd)| {cv_wall_us - c_wall_us:14.1f} |", flush=True) + print("", flush=True) + print( + "[ep_bench] kernel breakout: see nsys nvtx_kern_sum output below " + "(produced by run_ep_bench.sh --nsys).", + flush=True, + ) + + # Under nsys: force cudaDeviceReset() to drain CUPTI's in-process kernel + # records into the .nsys-rep, then os._exit to skip JAX's coord-service + # watchdog. The reset crashes during NCCL EP context teardown, so we only + # take this path when the launcher opts in via EP_BENCH_FLUSH_CUPTI=1. + if os.environ.get("EP_BENCH_FLUSH_CUPTI", "0") == "1": + try: + import ctypes + + cudart = ctypes.CDLL("libcudart.so") + cudart.cudaDeviceSynchronize() + cudart.cudaDeviceReset() + except Exception: + pass + time.sleep(0.5) + sys.stdout.flush() + sys.stderr.flush() + os._exit(0) + + +if __name__ == "__main__": + main() diff --git a/examples/jax/ep/bench/run_ep_bench.sh b/examples/jax/ep/bench/run_ep_bench.sh new file mode 100755 index 0000000000..1531dfd5cf --- /dev/null +++ b/examples/jax/ep/bench/run_ep_bench.sh @@ -0,0 +1,352 @@ +#!/usr/bin/env bash +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +# +# 4-rank launcher for ep_bench.py. +# Examples: +# bash run_ep_bench.sh # plain run, stdout only +# bash run_ep_bench.sh --cuda-graph # enable XLA command-buffer (cudaGraph), min_size=1 +# bash run_ep_bench.sh --nsys # nsys on rank 0 -> results/jax_nsys.nsys-rep +# bash run_ep_bench.sh --xplane # jax.profiler on rank 0 -> results/xplane/ +# +# Notes: +# * nsys + xplane cannot be combined (both attach CUPTI -> MULTIPLE_SUBSCRIBERS). +# * nsys + --cuda-graph is rejected: cudaGraph fires kernels via cuGraphLaunch +# and detaches the host NVTX context, breaking per-stage attribution. +# * stdout per rank lands in results/stdout__rank_.txt. + +set -uo pipefail + +NSYS=0; XPLANE=0; CGRAPH=0; SECOND_STEP=0 +for a in "$@"; do + case "$a" in + --nsys) NSYS=1 ;; + --xplane) XPLANE=1 ;; + --cuda-graph) CGRAPH=1 ;; + --second-step) SECOND_STEP=1 ;; + *) echo "unknown arg: $a" >&2; exit 2 ;; + esac +done +if [ "${NSYS}" -eq 1 ] && [ "${XPLANE}" -eq 1 ]; then + echo "--nsys and --xplane both attach CUPTI; pick one." >&2; exit 2 +fi +if [ "${NSYS}" -eq 1 ] && [ "${CGRAPH}" -eq 1 ]; then + echo "--nsys and --cuda-graph cannot be combined: cudaGraph launches detach the" \ + "host NVTX context, so nvtx_kern_sum cannot attribute kernels to our ranges." >&2 + exit 2 +fi + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +TE_REPO_ROOT="$(cd "${SCRIPT_DIR}/../../../.." && pwd)" +RESULTS="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS}" +export PYTHONPATH="${TE_REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +NUM_GPUS=$(nvidia-smi -L 2>/dev/null | wc -l) +if [ "${NUM_GPUS}" -lt 4 ]; then + echo "EP bench requires >=4 GPUs (found ${NUM_GPUS}); SKIPPING."; exit 0 +fi +NUM=4 +COORD="${COORD:-127.0.0.1:23457}" +TIMEOUT_S="${TIMEOUT_S:-1800}" + +XLA_BASE="${XLA_BASE:---xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_graph_min_graph_size=1}" + +if [ "${CGRAPH}" -eq 1 ]; then + TAG="cudagraph" + export XLA_FLAGS="${XLA_BASE} --xla_gpu_enable_command_buffer=FUSION,CUSTOM_CALL --xla_gpu_graph_min_graph_size=1" +else + TAG="vanilla" + export XLA_FLAGS="${XLA_BASE} --xla_gpu_enable_command_buffer=" +fi +[ "${SECOND_STEP}" -eq 1 ] && TAG="${TAG}_step2" + +: "${NCCL_EP_JIT_CACHE_DIR:=${TMPDIR:-/tmp}/nccl_ep_jit_cache_$(id -u)}" +export NCCL_EP_JIT_CACHE_DIR +mkdir -p "${NCCL_EP_JIT_CACHE_DIR}" + +# JAX/XLA persistent compilation cache: first run pays full compile cost +# (cudaGraph capture + EP custom_calls is minutes); subsequent runs reuse it. +: "${JAX_COMPILATION_CACHE_DIR:=${TMPDIR:-/tmp}/jax_cache_$(id -u)}" +export JAX_COMPILATION_CACHE_DIR +mkdir -p "${JAX_COMPILATION_CACHE_DIR}" + +export XLA_PYTHON_CLIENT_MEM_FRACTION="${XLA_PYTHON_CLIENT_MEM_FRACTION:-0.2}" +export NVTE_EP_SILENCE_NONSYMM_WARN="${NVTE_EP_SILENCE_NONSYMM_WARN:-1}" + +ALL_RANKS_ARGS=() +R0_ONLY_ARGS=() +NSYS_PREFIX=() +SUFFIX="" +if [ "${SECOND_STEP}" -eq 1 ]; then + ALL_RANKS_ARGS+=(--second-step) +fi +if [ "${XPLANE}" -eq 1 ]; then + R0_ONLY_ARGS+=(--xplane "${RESULTS}/xplane_${TAG}") + SUFFIX="_xplane" +fi +if [ "${NSYS}" -eq 1 ]; then + SUFFIX="_nsys" + export EP_BENCH_FLUSH_CUPTI=1 + NSYS_PREFIX=(nsys profile + --output "${RESULTS}/jax_${TAG}_nsys" + --force-overwrite=true + --trace=cuda,nvtx + --gpu-metrics-devices=none + --cuda-um-cpu-page-faults=false + --cuda-um-gpu-page-faults=false) +fi + +OUT_PREFIX="stdout_${TAG}${SUFFIX}_rank" + +for f in "${RESULTS}/${OUT_PREFIX}_"*.txt \ + "${RESULTS}/jax_${TAG}_nsys.nsys-rep" \ + "${RESULTS}/jax_${TAG}_nsys.sqlite" \ + "${RESULTS}/jax_${TAG}_nsys_nvtx_kern_sum.csv" \ + "${RESULTS}/jax_${TAG}_nsys_kern_sum.csv" \ + "${RESULTS}/summary_${TAG}${SUFFIX}.md"; do + [ -f "$f" ] && mv -f "$f" "$f.prev" +done + +PIDS=() +cleanup() { for pid in "${PIDS[@]}"; do kill -KILL "$pid" 2>/dev/null || true; done; } +trap cleanup EXIT INT TERM + +for ((i=1; i "${RESULTS}/${OUT_PREFIX}_${i}.txt" 2>&1 & + PIDS+=($!) +done + +R0_CMD=(python -u "${SCRIPT_DIR}/ep_bench.py" + --coordinator-address "${COORD}" --process-id 0 --num-processes "${NUM}" + "${ALL_RANKS_ARGS[@]}" "${R0_ONLY_ARGS[@]}") +if [ "${NSYS}" -eq 1 ]; then + R0_CMD=("${NSYS_PREFIX[@]}" "${R0_CMD[@]}") +fi + +WATCHDOG_PID="" +if [ "${NSYS}" -eq 1 ]; then + ( while ! grep -q "kernel breakout" "${RESULTS}/${OUT_PREFIX}_0.txt" 2>/dev/null; do + sleep 2 + done + sleep 20 + pkill -INT -f "nsys profile --output ${RESULTS}/jax_${TAG}_nsys" 2>/dev/null || true + ) & + WATCHDOG_PID=$! +fi + +timeout --foreground --signal=KILL "${TIMEOUT_S}" "${R0_CMD[@]}" 2>&1 | tee "${RESULTS}/${OUT_PREFIX}_0.txt" +if [ -n "${WATCHDOG_PID}" ]; then + kill "${WATCHDOG_PID}" 2>/dev/null || true +fi +wait + +SUMMARY="${RESULTS}/summary_${TAG}${SUFFIX}.md" +RANK0_LOG="${RESULTS}/${OUT_PREFIX}_0.txt" + +{ + echo "# JAX EP bench summary — tag=${TAG}${SUFFIX}" + echo "" + echo "Generated: $(date -Iseconds)" + echo "Rank-0 log: \`${RANK0_LOG}\`" + echo "" + echo "## Per-stage runtime (rank 0)" + echo "" + echo '```' + awk '/^\| stage / {flag=1} flag {print; if (/combine[ ]+vjp-fwd/) {flag=0}}' "${RANK0_LOG}" || true + echo '```' +} > "${SUMMARY}" + +if [ "${NSYS}" -eq 1 ]; then + NSYS_REP="${RESULTS}/jax_${TAG}_nsys.nsys-rep" + NVTX_CSV="${RESULTS}/jax_${TAG}_nsys_nvtx_kern_sum.csv" + KERN_CSV="${RESULTS}/jax_${TAG}_nsys_kern_sum.csv" + if [ -f "${NSYS_REP}" ] && command -v nsys >/dev/null 2>&1; then + PROJ_CSV="${RESULTS}/jax_${TAG}_nsys_nvtx_gpu_proj_sum.csv" + echo "Extracting NVTX-range + kernel summaries from ${NSYS_REP} ..." + nsys stats --report nvtx_kern_sum --format csv \ + --output - "${NSYS_REP}" > "${NVTX_CSV}" 2>&1 || true + nsys stats --report cuda_gpu_kern_sum --format csv \ + --output - "${NSYS_REP}" > "${KERN_CSV}" 2>&1 || true + nsys stats --report nvtx_gpu_proj_sum --format csv \ + --output - "${NSYS_REP}" > "${PROJ_CSV}" 2>&1 || true + + BREAKOUT=$(python3 - "${NVTX_CSV}" "${PROJ_CSV}" <<'PYEOF' +import csv, sys, collections, re +path = sys.argv[1] + +STAGE_PATTERNS = { + "dispatch_fwd": re.compile(r"(^|:)dispatch_fwd(\[[^\]]*\])?$"), + "ep_dispatch_vjp": re.compile(r"(^|:)ep_dispatch_vjp(\[[^\]]*\])?$"), + "combine_fwd": re.compile(r"(^|:)combine_fwd(\[[^\]]*\])?$"), + "ep_combine_vjp": re.compile(r"(^|:)ep_combine_vjp(\[[^\]]*\])?$"), +} +STAGE_ORDER = ("dispatch_fwd", "ep_dispatch_vjp", "combine_fwd", "ep_combine_vjp") + +stages = collections.defaultdict(list) +try: + with open(path) as f: + lines = [ln for ln in f] + header_idx = next((i for i, ln in enumerate(lines) + if ln.lstrip().startswith("NVTX Range,")), -1) + if header_idx < 0: + print("(NVTX header not found)"); sys.exit(0) + reader = csv.reader(lines[header_idx:]) + header = next(reader, None) + def col(name): + for i, h in enumerate(header): + if h.strip().lower() == name.lower(): + return i + return -1 + i_range = col("NVTX Range") + i_total = col("Total Time (ns)") + i_inst = col("Kern Inst") + i_name = col("Kernel Name") + if min(i_range, i_total, i_inst, i_name) < 0: + print(f"(missing expected columns; got {header})"); sys.exit(0) + for row in reader: + if len(row) <= i_name: continue + rname = row[i_range].strip() + try: + total_ns = int(row[i_total].replace(',', '')) + inst = int(row[i_inst].replace(',', '')) + except ValueError: + continue + kname = row[i_name].strip() + for stage, pat in STAGE_PATTERNS.items(): + if pat.search(rname): + stages[stage].append((total_ns, inst, kname)) + break +except FileNotFoundError: + print("(nvtx_kern_sum CSV not found)"); sys.exit(0) + +if not stages: + print("(no kernels matched expected NVTX ranges)") + sys.exit(0) + +proj_csv = sys.argv[2] if len(sys.argv) > 2 else None +proj = {} +if proj_csv: + try: + with open(proj_csv) as f: + plines = list(f) + hidx = next((i for i, ln in enumerate(plines) + if ln.lstrip().startswith("Range,")), -1) + if hidx >= 0: + pr = csv.reader(plines[hidx:]) + ph = next(pr, None) + def pcol(n): + for i, h in enumerate(ph): + if h.strip().lower() == n.lower(): return i + return -1 + pi_range = pcol("Range") + pi_total = pcol("Total Proj Time (ns)") + pi_inst = pcol("Range Instances") + pi_gpuops = pcol("Total GPU Ops") + for row in pr: + if len(row) <= max(pi_range, pi_total, pi_inst): continue + rname = row[pi_range].strip() + for stage, pat in STAGE_PATTERNS.items(): + if pat.search(rname): + try: + t = int(row[pi_total].replace(',', '')) + n = int(row[pi_inst].replace(',', '')) + ops = int(row[pi_gpuops].replace(',', '')) if pi_gpuops >= 0 else 0 + except ValueError: + continue + proj[stage] = (t / 1e3, n) + break + except FileNotFoundError: + pass + +print("### Per-stage GPU activity (kernels + memops, from nvtx_gpu_proj_sum)") +print() +print("| stage | iters | GPU activity total (us) | per-iter (us) | kernel sum (us) | per-iter (us) | gap = memops+idle (us) |") +print("|------|-----:|----------------------:|------------:|--------------:|------------:|---------------------:|") +for stage in STAGE_ORDER: + rows = stages.get(stage, []) + kern_total_us = sum(r[0] for r in rows) / 1e3 + iters = max(rows, key=lambda r: r[0])[1] if rows else 0 + gpu_total_us, _ = proj.get(stage, (0.0, 0)) + per_iter_gpu = gpu_total_us / iters if iters else 0 + per_iter_kern = kern_total_us / iters if iters else 0 + gap = per_iter_gpu - per_iter_kern + print(f"| `{stage}` | {iters} | {gpu_total_us:18.1f} | {per_iter_gpu:11.1f} | {kern_total_us:13.1f} | {per_iter_kern:11.1f} | {gap:20.1f} |") +print() + +def _kern_per_iter(rows, needle): + tot_ns = 0; inst = 0 + for tns, n, kname in rows: + if needle in kname: + tot_ns += tns; inst += n + return (tot_ns / inst / 1e3) if inst else None + +KEY_KERNELS = { + "dispatch_fwd": [("dispatch", "nccl_ep_jit_ht_dispatch_kernel"), + ("permute", "nccl_ep_jit_ht_permute_kernel")], + "ep_dispatch_vjp": [("dispatch", "nccl_ep_jit_ht_dispatch_kernel"), + ("permute", "nccl_ep_jit_ht_permute_kernel")], + "combine_fwd": [("combine", "nccl_ep_jit_ht_combine_kernel"), + ("local_reduce", "nccl_ep_jit_ht_local_reduce_kernel")], + "ep_combine_vjp": [("combine", "nccl_ep_jit_ht_combine_kernel"), + ("local_reduce", "nccl_ep_jit_ht_local_reduce_kernel")], +} + +print("### Key NCCL EP kernel time per iter (us)") +print() +print("| stage | primary kernel (us/iter) | secondary kernel (us/iter) | kernel sum/iter (us) |") +print("|------|--------------------:|-----------------------:|------------------:|") +for stage in STAGE_ORDER: + rows = stages.get(stage, []) + iters = max(rows, key=lambda r: r[0])[1] if rows else 0 + per_iter_kern = (sum(r[0] for r in rows) / 1e3 / iters) if iters else 0.0 + keys = KEY_KERNELS.get(stage, []) + cells = [] + for label, needle in keys: + v = _kern_per_iter(rows, needle) + cells.append(f"{label}: {v:.1f}" if v is not None else f"{label}: -") + while len(cells) < 2: + cells.append("-") + print(f"| `{stage}` | {cells[0]:>20} | {cells[1]:>22} | {per_iter_kern:17.1f} |") +print() + +for stage in STAGE_ORDER: + rows = stages.get(stage, []) + if not rows: + print(f"### Stage `{stage}` top kernels — none"); print(); continue + agg = collections.defaultdict(lambda: [0, 0]) + for tns, inst, kname in rows: + agg[kname][0] += tns + agg[kname][1] += inst + items = sorted(([k, v[0], v[1]] for k, v in agg.items()), key=lambda x: -x[1]) + total_us = sum(v[1] for v in items) / 1e3 + print(f"### Stage `{stage}` — top 20 kernels ({len(items)} distinct; kernel-sum {total_us:.1f} us)") + print() + print("| # | total (us) | inst | avg (us) | kernel |") + print("|--:|-----------:|-----:|---------:|--------|") + for i, (kname, tns, inst) in enumerate(items[:20], 1): + avg_us = (tns / inst) / 1e3 if inst else 0 + short = kname if len(kname) <= 80 else kname[:77] + "..." + print(f"| {i} | {tns/1e3:10.1f} | {inst:4d} | {avg_us:8.2f} | `{short}` |") + print() +PYEOF +) + { + echo "" + echo "## Kernel breakout per NVTX range (rank 0)" + echo "" + echo "${BREAKOUT}" + echo "Full CSVs:" + echo "- per-range: \`${NVTX_CSV}\`" + echo "- overall: \`${KERN_CSV}\`" + } | tee -a "${RANK0_LOG}" >> "${SUMMARY}" + fi +fi + +echo "Done. Logs in ${RESULTS}/${OUT_PREFIX}_*.txt" +echo "Summary: ${SUMMARY}" From 453b39904ec5f98c46378050dc47da40fd630dae Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 8 Jun 2026 15:15:05 -0700 Subject: [PATCH 05/22] examples/jax/ep: ep_moe.py runs --iters fwd+bwd steps (default 3) Signed-off-by: Phuong Nguyen --- examples/jax/ep/ep_moe.py | 34 ++++++++++++++++++++++------------ 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/examples/jax/ep/ep_moe.py b/examples/jax/ep/ep_moe.py index 8a81ccb788..77b9531ff8 100644 --- a/examples/jax/ep/ep_moe.py +++ b/examples/jax/ep/ep_moe.py @@ -43,6 +43,12 @@ def _parse_args(): default=True, help="Verify fwd+bwd against a single-rank numpy reference.", ) + p.add_argument( + "--iters", + type=int, + default=3, + help="Number of fwd+bwd iterations to run (same compiled jit, same handle_mem).", + ) return p.parse_args() @@ -308,18 +314,22 @@ def loss_fn(toks, idx, w, kern): out = _moe_step(args, idx, toks, w, kern) return 0.5 * (out.astype(jnp.float32) ** 2).sum(), out - (loss, out_fwd), grad_tokens = jax.jit(jax.value_and_grad(loss_fn, has_aux=True))( - tokens, topk_idx, topk_w, kernels - ) - grad_tokens.block_until_ready() - out_fwd.block_until_ready() - - if args.process_id == 0: - print( - f"[ep_moe] loss={float(loss):.4f} grad_tokens.shape={grad_tokens.shape} " - f"dp={args.dp_size} ep={args.ep_size} " - f"num_experts={args.num_experts} recv_pr={args.recv_capacity_per_rank}" - ) + step_jit = jax.jit(jax.value_and_grad(loss_fn, has_aux=True)) + + # Run --iters fwd+bwd steps on the same compiled jit. With identical + # inputs every iter, the pointer-keyed handle_mem cache must keep + # producing identical loss/grad. + for it in range(args.iters): + (loss, out_fwd), grad_tokens = step_jit(tokens, topk_idx, topk_w, kernels) + grad_tokens.block_until_ready() + out_fwd.block_until_ready() + if args.process_id == 0: + print( + f"[ep_moe] iter={it} loss={float(loss):.4f}" + f" grad_tokens.shape={grad_tokens.shape}" + f" dp={args.dp_size} ep={args.ep_size}" + f" num_experts={args.num_experts} recv_pr={args.recv_capacity_per_rank}" + ) if args.check: From 4ee095c641bb3e193733f0b1307803cf58cd5172 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 8 Jun 2026 16:29:08 -0700 Subject: [PATCH 06/22] jax/ep: tighten sharding contract, drop helpers, route bwd through TE with_sharding_constraint Signed-off-by: Phuong Nguyen --- examples/jax/ep/bench/ep_bench.py | 28 +-- examples/jax/ep/ep_moe.py | 4 +- tests/jax/test_multi_process_ep.py | 42 ---- transformer_engine/jax/cpp_extensions/ep.py | 201 +++++++------------- transformer_engine/jax/ep.py | 95 +++++---- 5 files changed, 123 insertions(+), 247 deletions(-) diff --git a/examples/jax/ep/bench/ep_bench.py b/examples/jax/ep/bench/ep_bench.py index 01713da990..27842dc834 100644 --- a/examples/jax/ep/bench/ep_bench.py +++ b/examples/jax/ep/bench/ep_bench.py @@ -1,22 +1,11 @@ # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. -"""JAX EP perf bench — dispatch / combine (raw fwd + custom_vjp wrapper) on a 1DP x EP mesh. - -One process per GPU. Run via run_ep_bench.sh. - -Measured per kernel (separate jits): - * tex_ep.ep_dispatch_fwd (stage: dispatch_fwd) - * ep_dispatch (stage: ep_dispatch_vjp -- custom_vjp wrapper, fwd-only) - * tex_ep.ep_combine_fwd (stage: combine_fwd) - * ep_combine (stage: ep_combine_vjp -- custom_vjp wrapper, fwd-only) -Prepare runs once outside the timed loops. - -Timing: wall-clock (perf_counter) around each iter with NVTX ranges, so -nsys can attribute kernels per stage. Rank-0 prints mean wall in us. -Per-stage kernel breakdown comes from `nsys stats --report nvtx_kern_sum`. -Profiling: if --xplane DIR is set, jax.profiler captures the timed region. -nsys profiling is driven from the shell launcher (see run_ep_bench.sh). +"""JAX EP perf bench — dispatch/combine (raw fwd + custom_vjp wrapper) on a 1DP x EP mesh. + +One process per GPU; launch via run_ep_bench.sh. Each stage is jitted and +timed separately with NVTX ranges (prepare runs once outside the loop). +Rank-0 prints mean wall in us; nsys / --xplane attribute kernels per stage. """ import argparse @@ -91,12 +80,7 @@ def _build_mesh(args): def _make_inputs(args, ep_size): - """Identity-style routing (round-robin), uniform top-k weights. - - Globals: ``B = num_processes`` (sharded on compound (dp,ep)), so each rank - sees ``args.tokens_per_rank`` tokens. Tokens/weights are bf16 / fp32; idx - is int32. Rank shards land via with_sharding_constraint inside the jit. - """ + """Round-robin routing, uniform top-k weights; each rank sees ``args.tokens_per_rank`` tokens.""" n = args.num_processes T = args.tokens_per_rank H = args.hidden diff --git a/examples/jax/ep/ep_moe.py b/examples/jax/ep/ep_moe.py index 77b9531ff8..a6b0ba6545 100644 --- a/examples/jax/ep/ep_moe.py +++ b/examples/jax/ep/ep_moe.py @@ -316,9 +316,7 @@ def loss_fn(toks, idx, w, kern): step_jit = jax.jit(jax.value_and_grad(loss_fn, has_aux=True)) - # Run --iters fwd+bwd steps on the same compiled jit. With identical - # inputs every iter, the pointer-keyed handle_mem cache must keep - # producing identical loss/grad. + # Same jit + same inputs each iter: handle_mem cache must give identical loss/grad. for it in range(args.iters): (loss, out_fwd), grad_tokens = step_jit(tokens, topk_idx, topk_w, kernels) grad_tokens.block_until_ready() diff --git a/tests/jax/test_multi_process_ep.py b/tests/jax/test_multi_process_ep.py index bf251682de..ad1216642b 100644 --- a/tests/jax/test_multi_process_ep.py +++ b/tests/jax/test_multi_process_ep.py @@ -460,48 +460,6 @@ def run(idx, toks, w): rtol=5e-2, ) - def test_dispatch_combine_dp_only_first_dim(self): - """Input sharded ``("dp", None)`` (no ep on leading) — dispatch must - accept it. JAX SPMD slices the missing ep axis locally so the kernel - still sees ``T/(dp*ep)`` tokens per rank.""" - T_global, topk_idx, tokens, topk_w = self._make_identity_inputs(nonuniform=False) - dp_only = PartitionSpec("dp", None) - with self.mesh, global_shard_guard(self.mr): - idx_s = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_only)) - tok_s = jax.lax.with_sharding_constraint(tokens, NamedSharding(self.mesh, dp_only)) - w_s = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_only)) - - ep_t = PartitionSpec(("dp", "ep"), None, None) - ep_w = PartitionSpec(("dp", "ep"), None) - - @jax.jit - def run(idx, toks, w): - recv_t, recv_w, hm, _tc = ep_dispatch(self.hk, idx, toks, w, self.recv_capacity_per_rank) - recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(self.mesh, ep_t)) - recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(self.mesh, ep_w)) - weighted = self._preweight_expert_out(recv_t, recv_w) - out = ep_combine( - self.hk, - hm, - _tc, - weighted, - num_local_tokens=T_global, - out_sharding=(("dp", "ep"), None), - ) - return out - - out = run(idx_s, tok_s, w_s) - out.block_until_ready() - out_global = jmu.process_allgather(out, tiled=True) - - if self.rank == 0: - np.testing.assert_allclose( - np.asarray(out_global.astype(jnp.float32)), - np.asarray(tokens.astype(jnp.float32)), - atol=5e-2, - rtol=5e-2, - ) - # ── Custom-VJP tests ───────────────────────────────────────────────── def test_dispatch_vjp_fwd_bwd(self): diff --git a/transformer_engine/jax/cpp_extensions/ep.py b/transformer_engine/jax/cpp_extensions/ep.py index 5263b33ba9..55b204efdc 100644 --- a/transformer_engine/jax/cpp_extensions/ep.py +++ b/transformer_engine/jax/cpp_extensions/ep.py @@ -23,7 +23,7 @@ import transformer_engine_jax from .base import BasePrimitive, register_primitive -from ..sharding import global_mesh_resource +from ..sharding import global_mesh_resource, get_mesh_axis_size __all__ = [ "EpConfig", @@ -98,54 +98,25 @@ def ep_handle_mem_size(cfg: EpLayerConfig) -> int: ) -def _leading_axis_ok(spec, ep_axis, outer_axes=()): - # Only the first dim may carry sharding; remaining dims must be replicated. - # The first dim's axis must be one of: - # ``ep_axis`` alone, - # a tuple of dp/fsdp axes (no ep — ep gets sliced in locally), - # a tuple ending in ``ep_axis`` with dp/fsdp axes before it. - # Examples on a (dp, ep) mesh: 2D ``(ep, None)``, ``(("dp","ep"), None)``, - # ``("dp", None)``; 3D ``(ep, None, None)``, ``(("dp","ep"), None, None)``, - # ``("dp", None, None)``. - if len(spec) < 2 or ep_axis is None: - return False - if any(ax is not None for ax in spec[1:]): - return False # only first dim sharded - leading = spec[0] - allowed_outers = {a for a in outer_axes if a is not None} - allowed = allowed_outers | {ep_axis, None} - elts = leading if isinstance(leading, tuple) else (leading,) - return all(a in allowed for a in elts) - +def _leading_axis_ok(spec): + """Validate an EP input spec; return ``(ok, ep_axis, outer_axes)``. -def _canonical_input_spec(spec, ndim): - """Canonical input PartitionSpec the primitive demands JAX deliver. - - Sharding lives entirely on the first dim. If ``spec[0]`` already includes - ``ep_resource``, returned unchanged. Otherwise ``ep_resource`` is folded - into the first-dim axis tuple, e.g. ``"dp"`` → ``("dp","ep")``. The added - ep axis is a local slice (the missing dim was replicated), no cross-device - comm. + Leading dim is ``ep`` or a tuple ending in ``ep`` (outer dp/fsdp axes + first); all other dims must be replicated. """ gsr = global_mesh_resource() - ep = gsr.ep_resource + ep_axis = gsr.ep_resource + outer_axes = tuple(a for a in (gsr.dp_resource, gsr.fsdp_resource) if a is not None) + if len(spec) < 2 or ep_axis is None: + return False, ep_axis, outer_axes + if any(ax is not None for ax in spec[1:]): + return False, ep_axis, outer_axes leading = spec[0] - present = leading if isinstance(leading, tuple) else (leading,) if leading is not None else () - if ep in present: - return PartitionSpec(*spec) - if leading is None: - new_leading = ep - elif isinstance(leading, tuple): - new_leading = (*leading, ep) - else: - new_leading = (leading, ep) - return PartitionSpec(new_leading, *([None] * (ndim - 1))) - - -def _dispatch_input_outer_axes(): - """dp/fsdp axes allowed as outer companions to ep_resource on dispatch input.""" - gsr = global_mesh_resource() - return tuple(a for a in (gsr.dp_resource, gsr.fsdp_resource) if a is not None) + elts = leading if isinstance(leading, tuple) else (leading,) + if ep_axis not in elts: + return False, ep_axis, outer_axes + allowed = set(outer_axes) | {ep_axis} + return all(a in allowed for a in elts), ep_axis, outer_axes def _ep_outer_axis(): @@ -165,7 +136,9 @@ def _ep_leading_dims(is_outer): outer = _ep_outer_axis() if not is_outer: return (1,) - return (cfg.world_size,) if outer is not None else (cfg.ep_size,) + if outer is None: + return (cfg.ep_size,) + return (get_mesh_axis_size(outer) * cfg.ep_size,) def _ep_output_spec(*trailing): @@ -211,8 +184,8 @@ class EpPreparePrimitive(BasePrimitive): @staticmethod def abstract(topk_idx_aval, *, top_k, dispatch_output_per_expert_alignment, is_outer): - # is_outer=True: global leading dim = (world_size,) (or (ep_size,) with - # no DP); False: per-shard = (1,). + # is_outer=True: global leading dim = (dp*ep,) (or (ep,) with no DP); + # False: per-shard = (1,). cfg = get_ep_config() num_local_experts = cfg.num_local_experts assert ( @@ -276,20 +249,19 @@ def partition( top_k, dispatch_output_per_expert_alignment, is_outer, mesh, arg_infos, result_infos ): del is_outer, result_infos - gsr = global_mesh_resource() - ep_axis = gsr.ep_resource - outer_axes = _dispatch_input_outer_axes() idx_spec = arg_infos[0].sharding.spec - if not _leading_axis_ok(idx_spec, ep_axis, outer_axes): + ok, ep_axis, outer_axes = _leading_axis_ok(idx_spec) + if not ok: raise NotImplementedError( - "EpPrepare: topk_idx leading dims must shard on ep_resource" - f" ('{ep_axis}') and/or {outer_axes}, with the topk dim replicated;" - f" got spec={idx_spec}." + "EpPrepare: topk_idx leading dim must include ep_resource" + f" ('{ep_axis}'), optionally tupled with {outer_axes}," + f" with the topk dim replicated; got spec={idx_spec}." ) - idx_ndim = len(arg_infos[0].shape) - arg_shardings = (NamedSharding(mesh, _canonical_input_spec(idx_spec, idx_ndim)),) - tc_sharding = NamedSharding(mesh, _ep_output_spec(None)) - hm_sharding = NamedSharding(mesh, _ep_output_spec(None)) + arg_shardings = tuple(a.sharding for a in arg_infos) + # token_counts / handle_mem inherit the input's leading axis (trailing dims auto-pad to None). + leading_spec = PartitionSpec(idx_spec[0]) + tc_sharding = NamedSharding(mesh, leading_spec) + hm_sharding = NamedSharding(mesh, leading_spec) def sharded_impl(topk_idx): return EpPreparePrimitive.impl( @@ -334,8 +306,8 @@ def abstract( recv_capacity_per_rank, is_outer, ): - # is_outer=True: global leading dim = (world_size,) (or (ep_size,) with - # no DP); False: per-shard = (1,). + # is_outer=True: global leading dim = (dp*ep,) (or (ep,) with no DP); + # False: per-shard = (1,). del topk_weights_aval, top_k, dispatch_output_per_expert_alignment, handle_mem_aval assert ( len(tokens_aval.shape) >= 2 @@ -429,27 +401,27 @@ def partition( result_infos, ): del is_outer, result_infos - gsr = global_mesh_resource() - ep_axis = gsr.ep_resource - outer_axes = _dispatch_input_outer_axes() tokens_spec = arg_infos[2].sharding.spec - if not _leading_axis_ok(tokens_spec, ep_axis, outer_axes): + ok, ep_axis, outer_axes = _leading_axis_ok(tokens_spec) + if not ok: raise NotImplementedError( - "EpDispatch: tokens leading dims must shard on ep_resource" - f" ('{ep_axis}') and/or {outer_axes}, hidden dim replicated;" - f" got spec={tokens_spec}." + "EpDispatch: tokens leading dim must include ep_resource" + f" ('{ep_axis}'), optionally tupled with {outer_axes}," + f" hidden dim replicated; got spec={tokens_spec}." ) idx_spec = arg_infos[1].sharding.spec tw_spec = arg_infos[3].sharding.spec - arg_shardings = ( - arg_infos[0].sharding, - NamedSharding(mesh, _canonical_input_spec(idx_spec, len(arg_infos[1].shape))), - NamedSharding(mesh, _canonical_input_spec(tokens_spec, len(arg_infos[2].shape))), - NamedSharding(mesh, _canonical_input_spec(tw_spec, len(arg_infos[3].shape))), - ) + if idx_spec[0] != tokens_spec[0] or tw_spec[0] != tokens_spec[0]: + raise NotImplementedError( + "EpDispatch: topk_idx, tokens, topk_weights must share the leading" + f" axis; got topk_idx={idx_spec}, tokens={tokens_spec}, topk_weights={tw_spec}." + ) + # Recv outputs share the tokens leading-only spec (trailing dims auto-pad to None). + leading_spec = PartitionSpec(tokens_spec[0]) + arg_shardings = tuple(a.sharding for a in arg_infos) out_shardings = ( - NamedSharding(mesh, _ep_output_spec(None, None)), - NamedSharding(mesh, _ep_output_spec(None)), + NamedSharding(mesh, leading_spec), + NamedSharding(mesh, leading_spec), ) def sharded_impl(handle_mem, topk_idx, tokens, topk_weights): @@ -500,46 +472,17 @@ def _prod(seq): return p -def _resolve_out_partition_spec(out_partition_spec, num_leading): - """Pick the combine output PartitionSpec. - - Defaults to a compound leading axis ``(dp_resource, ep_resource)`` when a - DP/FSDP axis is set on the active MeshResource, else just ``ep_resource``. - This matches the input sharding so XLA does not need collective-permutes - in the bwd path. - """ - if out_partition_spec is not None: - assert len(out_partition_spec) == num_leading + 1, ( - f"out_partition_spec length {len(out_partition_spec)} must equal num_leading" - f" + 1 ({num_leading + 1})" - ) - return tuple(out_partition_spec) - gsr = global_mesh_resource() - if gsr.ep_resource is None: - raise ValueError( - "ep_combine: ep_resource is not set on the active MeshResource;" - " pass out_sharding=... explicitly." - ) - outer = gsr.dp_resource or gsr.fsdp_resource - leading = (outer, gsr.ep_resource) if outer is not None else gsr.ep_resource - return (leading,) + (None,) * num_leading - - -def _per_shard_leading(out_leading_shape, resolved_spec, mesh): - """Per-shard leading shape given resolved partition spec and mesh.""" - per_shard = list(out_leading_shape) - for i, ax in enumerate(resolved_spec[: len(out_leading_shape)]): - if ax is None: - continue - axes = ax if isinstance(ax, tuple) else (ax,) - factor = 1 - for a in axes: - factor *= mesh.shape[a] - assert ( - per_shard[i] % factor == 0 - ), f"leading dim {per_shard[i]} not divisible by shard factor {factor} on axes {axes}" - per_shard[i] //= factor - return tuple(per_shard) +def _leading_per_shard(out_leading_shape, leading_axis, mesh): + """Per-shard leading shape: divide ``out_leading_shape[0]`` by the mesh factor on ``leading_axis``.""" + axes = leading_axis if isinstance(leading_axis, tuple) else (leading_axis,) + factor = 1 + for a in axes: + factor *= mesh.shape[a] + assert out_leading_shape[0] % factor == 0, ( + f"leading dim {out_leading_shape[0]} not divisible by shard factor" + f" {factor} on axes {axes}" + ) + return (out_leading_shape[0] // factor,) + tuple(out_leading_shape[1:]) class EpCombinePrimitive(BasePrimitive): @@ -639,10 +582,9 @@ def partition( " None, None) (or ((dp, ep), None, None) when dp/fsdp is set)" f" over [num_procs, recv_pr, H]; got spec={eo_spec}." ) - resolved = _resolve_out_partition_spec(out_partition_spec, len(out_leading_shape)) - per_shard_leading = _per_shard_leading(out_leading_shape, resolved, mesh) + per_shard_leading = _leading_per_shard(out_leading_shape, out_partition_spec[0], mesh) arg_shardings = tuple(a.sharding for a in arg_infos) - out_sharding = NamedSharding(mesh, PartitionSpec(*resolved)) + out_sharding = NamedSharding(mesh, PartitionSpec(*out_partition_spec)) def sharded_impl(handle_mem, expert_out): return EpCombinePrimitive.impl( @@ -785,13 +727,15 @@ def partition( " PartitionSpec(ep_resource, None) (or ((dp, ep), None) when dp/fsdp is set)" f" over [num_procs, recv_pr]; got spec={gw_spec}." ) - resolved = _resolve_out_partition_spec(out_partition_spec, len(out_leading_shape)) - per_shard_leading = _per_shard_leading(out_leading_shape, resolved, mesh) + if gw_spec[0] != g_spec[0]: + raise NotImplementedError( + "EpDispatchBwd: grad and g_recv_topk_weights must share the leading" + f" axis; got grad={g_spec}, g_recv_topk_weights={gw_spec}." + ) + per_shard_leading = _leading_per_shard(out_leading_shape, out_partition_spec[0], mesh) arg_shardings = tuple(a.sharding for a in arg_infos) - out_shardings = [ - NamedSharding(mesh, PartitionSpec(*resolved)), - NamedSharding(mesh, PartitionSpec(*resolved, None)), - ] + out_sharding = NamedSharding(mesh, PartitionSpec(*out_partition_spec)) + out_shardings = [out_sharding, out_sharding] def sharded_impl(handle_mem, grad, g_recv_topk_weights): return EpDispatchBwdPrimitive.impl( @@ -840,8 +784,8 @@ def abstract( recv_capacity_per_rank, is_outer, ): - # is_outer=True: global leading dim = (world_size,) (or (ep_size,) with - # no DP); False: per-shard = (1,). + # is_outer=True: global leading dim = (dp*ep,) (or (ep,) with no DP); + # False: per-shard = (1,). del top_k, dispatch_output_per_expert_alignment, handle_mem_aval assert ( len(grad_aval.shape) >= 2 @@ -920,7 +864,8 @@ def partition( ): del is_outer, result_infos arg_shardings = tuple(a.sharding for a in arg_infos) - out_sharding = NamedSharding(mesh, _ep_output_spec(None, None)) + # EP-output leading (trailing dims auto-pad to None). + out_sharding = NamedSharding(mesh, _ep_output_spec()) def sharded_impl(handle_mem, grad): return EpCombineBwdPrimitive.impl( diff --git a/transformer_engine/jax/ep.py b/transformer_engine/jax/ep.py index 47ef4d89ed..f9dd03f032 100644 --- a/transformer_engine/jax/ep.py +++ b/transformer_engine/jax/ep.py @@ -15,7 +15,11 @@ import transformer_engine_jax import transformer_engine.jax.cpp_extensions as tex from transformer_engine.jax.cpp_extensions.misc import jax_dtype_to_te_dtype -from transformer_engine.jax.sharding import global_mesh_resource, get_mesh_axis_size +from transformer_engine.jax.sharding import ( + global_mesh_resource, + get_mesh_axis_size, + with_sharding_constraint, +) ep_prepare = tex.ep_prepare EpLayerConfig = tex.EpLayerConfig @@ -173,6 +177,19 @@ def ep_bootstrap( ) +def _default_out_partition_spec(): + """Leading-axis default: ``(("dp","ep"),)`` if DP/FSDP is set, else ``("ep",)``.""" + gsr = global_mesh_resource() + if gsr.ep_resource is None: + raise ValueError( + "ep_resource is not set on the active MeshResource;" + " pass out_sharding=... explicitly." + ) + outer = gsr.dp_resource or gsr.fsdp_resource + leading = (outer, gsr.ep_resource) if outer is not None else gsr.ep_resource + return (leading,) + + # ── ep_dispatch (custom_vjp) ───────────────────────────────────────────────── @@ -182,8 +199,8 @@ def ep_dispatch(cfg, topk_idx, tokens, topk_weights, recv_capacity_per_rank): ``cfg`` is a per-layer ``EpLayerConfig``; distinct layers may share a ``cfg`` (the pointer-keyed C++ cache keys on handle_mem, not on cfg). - Inputs are 2D ``[T, H]`` or 3D ``[B, S, H]`` with only the leading dim - sharded (axis in {ep, (dp, ep), dp, None}). Returns + Inputs are ``[..., H]`` with only the leading dim sharded as ``ep`` or + ``(dp, ep)``. Returns ``(recv_tokens, recv_topk_weights, handle_mem, token_counts)``; pass ``handle_mem`` and ``token_counts`` to the matching ``ep_combine``. """ @@ -191,6 +208,10 @@ def ep_dispatch(cfg, topk_idx, tokens, topk_weights, recv_capacity_per_rank): def _dispatch_fwd(cfg, topk_idx, tokens, topk_weights, recv_capacity_per_rank): + if not jnp.issubdtype(topk_weights.dtype, jnp.floating): + raise TypeError( + f"ep_dispatch: topk_weights must be a floating dtype; got {topk_weights.dtype}." + ) token_counts, handle_mem = tex.ep_prepare(cfg, topk_idx) recv_tokens, recv_topk_weights = tex.ep_dispatch_fwd( cfg, handle_mem, topk_idx, tokens, topk_weights, recv_capacity_per_rank @@ -203,20 +224,14 @@ def _dispatch_fwd(cfg, topk_idx, tokens, topk_weights, recv_capacity_per_rank): def _dispatch_bwd(cfg, recv_capacity_per_rank, res, g_outputs): del recv_capacity_per_rank handle_mem, out_leading = res - # Re-pin cotangent sharding: XLA transpose can drop the EP axis on a - # single-fwd-output cotangent, landing a global tensor in the FFI. - gsr = global_mesh_resource() - ep_axis = gsr.ep_resource - outer = gsr.dp_resource or gsr.fsdp_resource - leading = (outer, ep_axis) if outer is not None else ep_axis - g_recv_tokens = jax.lax.with_sharding_constraint( - g_outputs[0], jax.sharding.PartitionSpec(leading, None, None) - ) - g_recv_topk_weights = jax.lax.with_sharding_constraint( - g_outputs[1], jax.sharding.PartitionSpec(leading, None) - ) + # Re-pin cotangent: XLA transpose can drop the EP axis and feed the FFI a global tensor. + out_spec = _default_out_partition_spec() + spec = jax.sharding.PartitionSpec(*out_spec) + g_recv_tokens = with_sharding_constraint(g_outputs[0], spec) + g_recv_topk_weights = with_sharding_constraint(g_outputs[1], spec) grad_tokens, grad_topk_weights = tex.ep_dispatch_bwd( - cfg, handle_mem, g_recv_tokens, g_recv_topk_weights, out_leading + cfg, handle_mem, g_recv_tokens, g_recv_topk_weights, out_leading, + out_partition_spec=out_spec, ) return (None, grad_tokens, grad_topk_weights) @@ -234,27 +249,11 @@ def ep_combine( ): """Scatter-sum expert outputs back to source ranks. **Unweighted.** - ``ep_combine`` does not apply ``recv_topk_weights`` or any padded-slot - mask. The caller must pre-multiply ``expert_out`` by the dispatched - weights (and zero padded slots) before calling. Gradients w.r.t. - ``recv_topk_weights`` therefore flow through the caller's hadamard, not - through this op. - - Args: - cfg: ``EpLayerConfig`` matching the ``ep_dispatch`` call. - handle_mem: Routing-state buffer returned by ``ep_dispatch``. - token_counts: ``[num_procs, num_local_experts]`` int32 (passed through). - expert_out: ``[num_procs, recv_capacity_per_rank, H]`` pre-weighted - post-FFN activations. - num_local_tokens: STATIC int or tuple. int -> 2D output ``[T, H]``; - tuple -> N-D output ``[*tuple, H]``. - out_sharding: STATIC optional ``PartitionSpec`` tuple for the - output. Defaults to ``(("dp","ep"), *None)`` when - DP is set, else ``("ep", *None)``. Only the leading - dim may be sharded. - - Returns: - ``[..., H]`` combined output shaped per ``num_local_tokens``. + Caller must pre-multiply ``expert_out`` by ``recv_topk_weights`` (and + zero padded slots); gradients w.r.t. weights flow through that hadamard, + not through this op. ``num_local_tokens`` is STATIC: int -> ``[T, H]``, + tuple -> ``[*tuple, H]``. ``out_sharding`` defaults via + ``_default_out_partition_spec``; only the leading dim may be sharded. """ return _combine_fwd( cfg, handle_mem, token_counts, expert_out, @@ -267,6 +266,8 @@ def _combine_fwd( num_local_tokens, out_sharding, ): del token_counts + if out_sharding is None: + out_sharding = _default_out_partition_spec() result = tex.ep_combine_fwd( cfg, handle_mem, expert_out, num_local_tokens, out_partition_spec=out_sharding ) @@ -275,21 +276,11 @@ def _combine_fwd( def _combine_bwd(cfg, _num_local_tokens, _out_sharding, res, g_result): handle_mem, recv_capacity_per_rank = res - # Re-pin cotangent sharding: same XLA-transpose workaround as _dispatch_bwd. - gsr = global_mesh_resource() - if _out_sharding is not None: - spec = jax.sharding.PartitionSpec(*_out_sharding) - else: - ep_axis = gsr.ep_resource - outer = gsr.dp_resource or gsr.fsdp_resource - leading = (outer, ep_axis) if outer is not None and ep_axis is not None else ep_axis - spec = ( - jax.sharding.PartitionSpec(leading, *([None] * (g_result.ndim - 1))) - if leading is not None - else None - ) - if spec is not None: - g_result = jax.lax.with_sharding_constraint(g_result, spec) + # Re-pin cotangent (same XLA-transpose workaround as _dispatch_bwd). + if _out_sharding is None: + _out_sharding = _default_out_partition_spec() + spec = jax.sharding.PartitionSpec(*_out_sharding) + g_result = with_sharding_constraint(g_result, spec) grad_expert_out = tex.ep_combine_bwd(cfg, handle_mem, g_result, recv_capacity_per_rank) return (None, None, grad_expert_out) From 761d5571cbc5338e444589a3f92dda258f2fd762 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 9 Jun 2026 14:33:48 -0700 Subject: [PATCH 07/22] jax/ep: derive ep_size and num_ep_groups from active mesh in ep_bootstrap Signed-off-by: Phuong Nguyen --- examples/jax/ep/bench/ep_bench.py | 1 - examples/jax/ep/ep_moe.py | 1 - tests/jax/test_multi_process_ep.py | 7 ++- transformer_engine/jax/cpp_extensions/ep.py | 18 +++--- transformer_engine/jax/csrc/extensions/ep.cpp | 5 +- transformer_engine/jax/ep.py | 57 ++++++++++++------- 6 files changed, 54 insertions(+), 35 deletions(-) diff --git a/examples/jax/ep/bench/ep_bench.py b/examples/jax/ep/bench/ep_bench.py index 27842dc834..6b96cbeb9a 100644 --- a/examples/jax/ep/bench/ep_bench.py +++ b/examples/jax/ep/bench/ep_bench.py @@ -132,7 +132,6 @@ def main(): ep_bootstrap( world_size=args.num_processes, rank=rank, - ep_size=ep_size, num_experts=args.num_experts, max_tokens_per_rank=args.tokens_per_rank, recv_capacity_per_rank=recv_capacity_per_rank, diff --git a/examples/jax/ep/ep_moe.py b/examples/jax/ep/ep_moe.py index a6b0ba6545..5fd705a734 100644 --- a/examples/jax/ep/ep_moe.py +++ b/examples/jax/ep/ep_moe.py @@ -292,7 +292,6 @@ def main(): ep_bootstrap( world_size=args.num_processes, rank=args.process_id, - ep_size=args.ep_size, num_experts=args.num_experts, max_tokens_per_rank=args.num_tokens, recv_capacity_per_rank=args.recv_capacity_per_rank, diff --git a/tests/jax/test_multi_process_ep.py b/tests/jax/test_multi_process_ep.py index ad1216642b..98f8575372 100644 --- a/tests/jax/test_multi_process_ep.py +++ b/tests/jax/test_multi_process_ep.py @@ -34,6 +34,7 @@ ep_prepare, ep_dispatch_fwd, ep_combine_fwd, + get_ep_config, ) @@ -117,12 +118,15 @@ def setUpClass(cls): ep_bootstrap( world_size=cls.num_procs, rank=cls.rank, - ep_size=cls.ep, num_experts=cls.num_experts, max_tokens_per_rank=TOKENS_PER_DP_SHARD, recv_capacity_per_rank=cls.recv_capacity_per_rank, hidden_dim=HIDDEN_DIM, ) + # Bootstrap must snapshot ep_size and num_ep_groups onto EpConfig so + # abstract-eval never needs the active mesh. + assert get_ep_config().ep_size == cls.ep + assert get_ep_config().num_ep_groups == cls.dp # One layer config shared by all single-layer tests below; non-zero # alignment exercises dispatch_output_per_expert_alignment end-to-end. cls.hk = EpLayerConfig(top_k=TOP_K, dispatch_output_per_expert_alignment=16) @@ -136,7 +140,6 @@ def test_bootstrap_rejects_missing_ep_axis(self): ep_bootstrap( world_size=self.num_procs, rank=self.rank, - ep_size=self.ep, num_experts=self.num_experts, max_tokens_per_rank=TOKENS_PER_DP_SHARD, recv_capacity_per_rank=self.recv_capacity_per_rank, diff --git a/transformer_engine/jax/cpp_extensions/ep.py b/transformer_engine/jax/cpp_extensions/ep.py index 55b204efdc..946f8ea2cb 100644 --- a/transformer_engine/jax/cpp_extensions/ep.py +++ b/transformer_engine/jax/cpp_extensions/ep.py @@ -23,7 +23,7 @@ import transformer_engine_jax from .base import BasePrimitive, register_primitive -from ..sharding import global_mesh_resource, get_mesh_axis_size +from ..sharding import global_mesh_resource __all__ = [ "EpConfig", @@ -45,11 +45,16 @@ @dataclass(frozen=True) class EpConfig: - """Immutable Python view of the EP bootstrap config (see ep_bootstrap).""" + """Snapshot of the EP bootstrap config (see ep_bootstrap). + + num_ep_groups is the size of the outer dp/fsdp mesh axis (1 if neither + is set), captured at bootstrap so abstract-eval never reads the mesh. + """ world_size: int rank: int ep_size: int + num_ep_groups: int num_experts: int num_local_experts: int max_tokens_per_rank: int @@ -130,15 +135,12 @@ def _ep_outer_axis(): def _ep_leading_dims(is_outer): - """Single leading dim of an EP-output tensor: ``(dp*ep,)`` (or ``(ep,)`` when - DP is unset) globally; ``(1,)`` per shard.""" + """Leading dim of an EP-output tensor: num_ep_groups*ep_size globally, + 1 per shard. Read from EpConfig so abstract-eval needs no active mesh.""" cfg = get_ep_config() - outer = _ep_outer_axis() if not is_outer: return (1,) - if outer is None: - return (cfg.ep_size,) - return (get_mesh_axis_size(outer) * cfg.ep_size,) + return (cfg.num_ep_groups * cfg.ep_size,) def _ep_output_spec(*trailing): diff --git a/transformer_engine/jax/csrc/extensions/ep.cpp b/transformer_engine/jax/csrc/extensions/ep.cpp index e727eadce9..8bb9083159 100644 --- a/transformer_engine/jax/csrc/extensions/ep.cpp +++ b/transformer_engine/jax/csrc/extensions/ep.cpp @@ -43,13 +43,16 @@ class EpResources { ncclUniqueId uid; std::memcpy(&uid, p.uid_bytes.data(), sizeof(uid)); NVTE_CHECK_NCCL(ncclCommInitRank(&comm_, p.ep_size, uid, p.rank_within_group)); + // zero_copy=0: JAX EP path always stages payloads; the zero-copy fast path + // requires NVTECommWindow-backed tensors, which JAX bindings don't expose. NVTEEpGroupConfig cfg{.ep_size = p.ep_size, .num_experts = p.num_experts, .max_tokens_per_rank = p.max_tokens_per_rank, .max_recv_tokens_per_rank = p.max_recv_tokens_per_rank, .hidden_dim = p.hidden_dim, .max_num_sms = p.max_num_sms, - .max_token_dtype = p.max_token_dtype}; + .max_token_dtype = p.max_token_dtype, + .zero_copy = 0}; try { nvte_ep_initialize(static_cast(comm_), cfg); } catch (...) { diff --git a/transformer_engine/jax/ep.py b/transformer_engine/jax/ep.py index f9dd03f032..495ac0d94f 100644 --- a/transformer_engine/jax/ep.py +++ b/transformer_engine/jax/ep.py @@ -72,7 +72,6 @@ def _allgather_uid(uid_arr, world_size, uid_size): def ep_bootstrap( world_size, rank, - ep_size, num_experts, max_tokens_per_rank, recv_capacity_per_rank, @@ -82,8 +81,12 @@ def ep_bootstrap( ): """Initialize the EP communicator. Call once per process before any EP op. + Must run inside the active JAX Mesh and a global_shard_guard; ep_size and + num_ep_groups are read from the mesh axes named by MeshResource.ep_resource + and MeshResource.dp_resource/fsdp_resource. + max_token_dtype is the widest jnp dtype the group will dispatch; tensors - passed to ``ep_dispatch`` may use any narrower dtype. + passed to ep_dispatch may use any narrower dtype. max_num_sms caps the SMs allotted to EP kernels (0 = auto). """ if jnp.dtype(max_token_dtype) != jnp.bfloat16: @@ -96,19 +99,41 @@ def ep_bootstrap( f"ep_bootstrap requires world_size >= 2 (got {world_size}); NCCL EP needs" " at least 2 ranks to form a group." ) - if world_size % ep_size != 0: - raise ValueError( - f"world_size ({world_size}) must be divisible by ep_size ({ep_size}); otherwise" - " some EP groups would have fewer than ep_size ranks and ncclCommInitRank would hang." - ) - if num_experts % ep_size != 0: - raise ValueError(f"num_experts ({num_experts}) must be divisible by ep_size ({ep_size}).") if jax.local_device_count() != 1: raise ValueError( "ep_bootstrap requires one local device per process (got" f" jax.local_device_count() = {jax.local_device_count()}); NCCL EP does not" " support single-process multi-device setups." ) + + gsr = global_mesh_resource() + ep_resource = gsr.ep_resource + if ep_resource is None: + raise ValueError( + "ep_bootstrap requires MeshResource.ep_resource to be set; enter a" + " global_shard_guard(MeshResource(..., ep_resource=)) before bootstrap." + ) + ep_size = get_mesh_axis_size(ep_resource) + outer_axis = gsr.dp_resource or gsr.fsdp_resource + if outer_axis is None: + if world_size != ep_size: + raise ValueError( + f"ep_bootstrap: world_size ({world_size}) > ep_size ({ep_size}) but neither" + " MeshResource.dp_resource nor fsdp_resource is set; name the outer axis so" + " EP-output tensors can shard across EP groups." + ) + num_ep_groups = 1 + else: + num_ep_groups = get_mesh_axis_size(outer_axis) + if num_ep_groups * ep_size != world_size: + raise ValueError( + f"ep_bootstrap: num_ep_groups*ep_size ({num_ep_groups}*{ep_size}=" + f"{num_ep_groups * ep_size}) must equal world_size ({world_size}); check that" + f" the '{outer_axis}' and '{ep_resource}' mesh axes cover all ranks." + ) + if num_experts % ep_size != 0: + raise ValueError(f"num_experts ({num_experts}) must be divisible by ep_size ({ep_size}).") + UID_SIZE = 128 dp_color = rank // ep_size rank_within_group = rank % ep_size @@ -131,19 +156,6 @@ def ep_bootstrap( all_uids = _allgather_uid(uid_arr, world_size, UID_SIZE) uid_bytes = bytes(np.asarray(all_uids[dp_color * ep_size]).tolist()) - ep_resource = global_mesh_resource().ep_resource - if ep_resource is None: - raise ValueError( - "ep_bootstrap requires MeshResource.ep_resource to be set; enter a" - " global_shard_guard(MeshResource(..., ep_resource=)) before bootstrap." - ) - mesh_ep_size = get_mesh_axis_size(ep_resource) - if mesh_ep_size != ep_size: - raise ValueError( - f"ep_bootstrap: EpConfig.ep_size ({ep_size}) does not match mesh axis" - f" '{ep_resource}' size ({mesh_ep_size})." - ) - # Eager NCCL init while ranks are barrier-synced by the UID broadcast above. transformer_engine_jax.set_ep_bootstrap_params( uid_bytes, @@ -168,6 +180,7 @@ def ep_bootstrap( world_size=world_size, rank=rank, ep_size=ep_size, + num_ep_groups=num_ep_groups, num_experts=num_experts, num_local_experts=num_experts // ep_size, max_tokens_per_rank=max_tokens_per_rank, From 82c90637a539b92dff9685fb188a741899badc1d Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 9 Jun 2026 20:23:13 -0700 Subject: [PATCH 08/22] examples/jax/ep: rename ep_handle to layer_cfg in ep_moe.py (matches EpLayerConfig type) Signed-off-by: Phuong Nguyen --- examples/jax/ep/ep_moe.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/examples/jax/ep/ep_moe.py b/examples/jax/ep/ep_moe.py index 5fd705a734..e25bb34fd7 100644 --- a/examples/jax/ep/ep_moe.py +++ b/examples/jax/ep/ep_moe.py @@ -205,7 +205,7 @@ def _moe_step(args, topk_idx, tokens, topk_w, kernels): kernel_spec = PartitionSpec("ep", None, None, None) kernels = kernels.reshape(ep_size, NLE, *kernels.shape[1:]) - ep_handle = EpLayerConfig(top_k=args.top_k, dispatch_output_per_expert_alignment=16) + layer_cfg = EpLayerConfig(top_k=args.top_k, dispatch_output_per_expert_alignment=16) @jax.jit def step(topk_idx, tokens, topk_w, local_kernels): @@ -216,7 +216,7 @@ def step(topk_idx, tokens, topk_w, local_kernels): local_kernels, NamedSharding(mesh, kernel_spec) ) recv_tokens, recv_topk_w, handle_mem, _tc = ep_dispatch( - ep_handle, topk_idx, tokens, topk_w, args.recv_capacity_per_rank + layer_cfg, topk_idx, tokens, topk_w, args.recv_capacity_per_rank ) recv_tokens = jax.lax.with_sharding_constraint(recv_tokens, NamedSharding(mesh, ep3)) recv_topk_w = jax.lax.with_sharding_constraint(recv_topk_w, NamedSharding(mesh, ep2)) @@ -230,7 +230,7 @@ def step(topk_idx, tokens, topk_w, local_kernels): ).astype(expert_out.dtype) weighted = jax.lax.with_sharding_constraint(weighted, NamedSharding(mesh, ep3)) return ep_combine( - ep_handle, + layer_cfg, handle_mem, _tc, weighted, @@ -358,7 +358,6 @@ def _norm(spec, ndim): ref_out, ref_grad = _reference_grad( tokens_global_np, topk_idx_global_np, w_global_np, kernels_np ) - ref_loss = 0.5 * float((ref_out.astype(np.float32) ** 2).sum()) # 3D global ``[num_procs, S, H]`` with num_procs = dp * ep. Each EP # column in a DP color sees identical inputs (and produces identical # outputs), so collapse the ep dim to one replica before flattening @@ -374,15 +373,6 @@ def _norm(spec, ndim): .reshape(dp_size, ep_size, -1, ref_grad.shape[-1])[:, 0] .reshape(-1, ref_grad.shape[-1]) ) - if args.process_id == 0: - fwd_diff = np.abs(global_out - ref_out) - grad_diff = np.abs(global_grad - ref_grad) - print( - f"[ep_moe] DEBUG loss={float(loss):.4f} ref_loss(global)={ref_loss:.4f} " - f"ratio={float(loss) / max(ref_loss, 1e-9):.4f} (expected ~1.0)" - ) - print(f"[ep_moe] DEBUG fwd max-abs-diff per row: {fwd_diff.max(axis=1)}") - print(f"[ep_moe] DEBUG grad max-abs-diff per row: {grad_diff.max(axis=1)}") np.testing.assert_allclose( global_out, ref_out, From a314a45bf80554bac936e416f701d961d2f053ce Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 9 Jun 2026 20:31:26 -0700 Subject: [PATCH 09/22] jax/ep: add primitive docstrings and silence missing-kwoa false positives (lint 10.00) Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/ep.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/ep.py b/transformer_engine/jax/cpp_extensions/ep.py index 946f8ea2cb..cef88d0937 100644 --- a/transformer_engine/jax/cpp_extensions/ep.py +++ b/transformer_engine/jax/cpp_extensions/ep.py @@ -72,12 +72,14 @@ def set_ep_config(config: EpConfig) -> None: def get_ep_config() -> EpConfig: + """Return the process-wide EpConfig set by ep_bootstrap.""" if _ep_config is None: raise RuntimeError("EpConfig has not been set. Did you call ep_bootstrap()?") return _ep_config def get_ep_num_local_experts() -> int: + """Number of experts owned by this EP rank.""" return get_ep_config().num_local_experts @@ -178,6 +180,8 @@ def _ep_spec_ok(spec, trailing_count): class EpPreparePrimitive(BasePrimitive): + """FFI primitive for nvte_ep_prepare: routing setup and per-expert token counts.""" + name = "te_ep_prepare_ffi" multiple_results = True impl_static_args = (1, 2, 3) # top_k, dispatch_output_per_expert_alignment, is_outer @@ -289,6 +293,8 @@ def shardy_sharding_rule(*args): class EpDispatchPrimitive(BasePrimitive): + """FFI primitive for nvte_ep_dispatch (forward).""" + name = "te_ep_dispatch_ffi" multiple_results = True impl_static_args = (4, 5, 6, 7) # top_k, dispatch_output_per_expert_alignment, @@ -329,7 +335,7 @@ def abstract( def outer_abstract(*args, **kwargs): kwargs = dict(kwargs) kwargs["is_outer"] = True - avals = EpDispatchPrimitive.abstract(*args, **kwargs) + avals = EpDispatchPrimitive.abstract(*args, **kwargs) # pylint: disable=missing-kwoa return avals[:2] @staticmethod @@ -488,6 +494,8 @@ def _leading_per_shard(out_leading_shape, leading_axis, mesh): class EpCombinePrimitive(BasePrimitive): + """FFI primitive for nvte_ep_combine (forward).""" + name = "te_ep_combine_ffi" multiple_results = False impl_static_args = (2, 3, 4, 5) # top_k, dispatch_output_per_expert_alignment, @@ -617,6 +625,8 @@ def shardy_sharding_rule(*args): class EpDispatchBwdPrimitive(BasePrimitive): + """FFI primitive for the backward of nvte_ep_dispatch.""" + name = "te_ep_dispatch_bwd_ffi" multiple_results = True impl_static_args = (3, 4, 5, 6) # top_k, dispatch_output_per_expert_alignment, @@ -769,6 +779,8 @@ def shardy_sharding_rule(*args): class EpCombineBwdPrimitive(BasePrimitive): + """FFI primitive for the backward of nvte_ep_combine.""" + name = "te_ep_combine_bwd_ffi" multiple_results = False impl_static_args = (2, 3, 4, 5) # top_k, dispatch_output_per_expert_alignment, @@ -801,7 +813,7 @@ def abstract( def outer_abstract(*args, **kwargs): kwargs = dict(kwargs) kwargs["is_outer"] = True - return EpCombineBwdPrimitive.abstract(*args, **kwargs) + return EpCombineBwdPrimitive.abstract(*args, **kwargs) # pylint: disable=missing-kwoa @staticmethod def lowering( From c0b280d0e51217e3013f8ec6f5d1b0866c9f9fca Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 9 Jun 2026 20:41:22 -0700 Subject: [PATCH 10/22] jax/ep: apply black formatting (pre-commit hook output) Signed-off-by: Phuong Nguyen --- examples/jax/ep/bench/ep_bench.py | 4 +- examples/jax/ep/ep_moe.py | 6 +-- tests/jax/test_multi_process_ep.py | 49 +++++++++++++-------- transformer_engine/jax/cpp_extensions/ep.py | 27 +++++++----- transformer_engine/jax/ep.py | 38 +++++++++++----- 5 files changed, 75 insertions(+), 49 deletions(-) diff --git a/examples/jax/ep/bench/ep_bench.py b/examples/jax/ep/bench/ep_bench.py index 6b96cbeb9a..27ad8ca146 100644 --- a/examples/jax/ep/bench/ep_bench.py +++ b/examples/jax/ep/bench/ep_bench.py @@ -153,9 +153,7 @@ def run_prepare(idx): @jax.jit def run_dispatch(hm, idx, toks, w): - recv_t, recv_w = tex_ep.ep_dispatch_fwd( - cfg, hm, idx, toks, w, recv_capacity_per_rank - ) + recv_t, recv_w = tex_ep.ep_dispatch_fwd(cfg, hm, idx, toks, w, recv_capacity_per_rank) recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(mesh, ep_spec_3d)) recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(mesh, ep_spec_2d)) return recv_t, recv_w diff --git a/examples/jax/ep/ep_moe.py b/examples/jax/ep/ep_moe.py index e25bb34fd7..a23a0b33c9 100644 --- a/examples/jax/ep/ep_moe.py +++ b/examples/jax/ep/ep_moe.py @@ -225,9 +225,9 @@ def step(topk_idx, tokens, topk_w, local_kernels): # ep_combine is unweighted: pre-multiply by recv_topk_w and zero # padded slots (recv_topk_w == 0) before the scatter-sum. mask = (recv_topk_w != 0).astype(jnp.float32)[..., None] - weighted = ( - expert_out.astype(jnp.float32) * recv_topk_w[..., None] * mask - ).astype(expert_out.dtype) + weighted = (expert_out.astype(jnp.float32) * recv_topk_w[..., None] * mask).astype( + expert_out.dtype + ) weighted = jax.lax.with_sharding_constraint(weighted, NamedSharding(mesh, ep3)) return ep_combine( layer_cfg, diff --git a/tests/jax/test_multi_process_ep.py b/tests/jax/test_multi_process_ep.py index 98f8575372..1f986adbe8 100644 --- a/tests/jax/test_multi_process_ep.py +++ b/tests/jax/test_multi_process_ep.py @@ -144,7 +144,7 @@ def test_bootstrap_rejects_missing_ep_axis(self): max_tokens_per_rank=TOKENS_PER_DP_SHARD, recv_capacity_per_rank=self.recv_capacity_per_rank, hidden_dim=HIDDEN_DIM, - ) + ) # ── Helpers ─────────────────────────────────────────────────────────── @@ -260,15 +260,15 @@ def test_two_layer_dispatch_no_handle_aliasing(self): w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec)) def one_layer(hk, idx, toks, w_): - recv_t, recv_w, hm, tc = ep_dispatch( - hk, idx, toks, w_, self.recv_capacity_per_rank + recv_t, recv_w, hm, tc = ep_dispatch(hk, idx, toks, w_, self.recv_capacity_per_rank) + recv_t = jax.lax.with_sharding_constraint( + recv_t, NamedSharding(self.mesh, ep_spec_3d) ) - recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(self.mesh, ep_spec_3d)) - recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(self.mesh, ep_spec_2d)) - weighted = self._preweight_expert_out(recv_t, recv_w) - return ep_combine( - hk, hm, tc, weighted, T_global, out_sharding=(("dp", "ep"), None) + recv_w = jax.lax.with_sharding_constraint( + recv_w, NamedSharding(self.mesh, ep_spec_2d) ) + weighted = self._preweight_expert_out(recv_t, recv_w) + return ep_combine(hk, hm, tc, weighted, T_global, out_sharding=(("dp", "ep"), None)) @jax.jit def run(idx, ta_, tb_, w_): @@ -284,12 +284,14 @@ def run(idx, ta_, tb_, w_): np.testing.assert_allclose( np.asarray(out_a_g.astype(jnp.float32)), np.asarray(tokens.astype(jnp.float32)), - atol=5e-2, rtol=5e-2, + atol=5e-2, + rtol=5e-2, ) np.testing.assert_allclose( np.asarray(out_b_g.astype(jnp.float32)), np.asarray(tokens_b.astype(jnp.float32)), - atol=5e-2, rtol=5e-2, + atol=5e-2, + rtol=5e-2, ) def test_primitive_prepare(self): @@ -343,7 +345,10 @@ def run(idx, toks, w): weighted, NamedSharding(self.mesh, ep_spec_3d) ) out = ep_combine_fwd( - self.hk, hm, weighted, T_global, + self.hk, + hm, + weighted, + T_global, out_partition_spec=(("dp", "ep"), None), ) return jax.lax.with_sharding_constraint(out, NamedSharding(self.mesh, dp_spec)) @@ -387,7 +392,9 @@ def loss_fn(toks): toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec)) idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec)) - recv_t, recv_w, hm, tc = ep_dispatch(self.hk, idx, toks, w, self.recv_capacity_per_rank) + recv_t, recv_w, hm, tc = ep_dispatch( + self.hk, idx, toks, w, self.recv_capacity_per_rank + ) recv_t = jax.lax.with_sharding_constraint( recv_t, NamedSharding(self.mesh, ep_spec_3d) ) @@ -436,7 +443,9 @@ def test_dispatch_combine_3d_input_output(self): @jax.jit def run(idx, toks, w): - recv_t, recv_w, hm, _tc = ep_dispatch(self.hk, idx, toks, w, self.recv_capacity_per_rank) + recv_t, recv_w, hm, _tc = ep_dispatch( + self.hk, idx, toks, w, self.recv_capacity_per_rank + ) recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(self.mesh, ep_t)) recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(self.mesh, ep_w)) weighted = self._preweight_expert_out(recv_t, recv_w) @@ -499,7 +508,7 @@ def loss_fn(toks): slot_idx = jnp.arange(self.recv_capacity_per_rank, dtype=jnp.int32) mask = slot_idx[None, :] < total_recv rt32 = jnp.where(mask[..., None], recv_tokens.astype(jnp.float32), 0.0) - return 0.5 * (rt32 ** 2).sum() + return 0.5 * (rt32**2).sum() loss, grad_tokens = jax.jit(jax.value_and_grad(loss_fn))(tokens) grad_tokens.block_until_ready() @@ -626,7 +635,9 @@ def run(idx, toks, w): idx = jax.lax.with_sharding_constraint(idx, NamedSharding(self.mesh, dp_spec)) toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec)) w = jax.lax.with_sharding_constraint(w, NamedSharding(self.mesh, dp_spec)) - recv_t, recv_w, hm, tc = ep_dispatch(self.hk, idx, toks, w, self.recv_capacity_per_rank) + recv_t, recv_w, hm, tc = ep_dispatch( + self.hk, idx, toks, w, self.recv_capacity_per_rank + ) recv_t = jax.lax.with_sharding_constraint( recv_t, NamedSharding(self.mesh, ep_spec_3d) ) @@ -634,9 +645,7 @@ def run(idx, toks, w): recv_w, NamedSharding(self.mesh, ep_spec_2d) ) weighted = self._preweight_expert_out(recv_t, recv_w) - out = ep_combine( - self.hk, hm, tc, weighted, T_dp, out_sharding=(("dp", "ep"), None) - ) + out = ep_combine(self.hk, hm, tc, weighted, T_dp, out_sharding=(("dp", "ep"), None)) return jax.lax.with_sharding_constraint(out, NamedSharding(self.mesh, dp_spec)) compiled = run.lower(topk_idx, tokens, topk_w).compile() @@ -675,7 +684,9 @@ def fwd(eo, toks, idx, w): _rt, rw, hm, tc = ep_dispatch(self.hk, idx, toks, w, self.recv_capacity_per_rank) rw = jax.lax.with_sharding_constraint(rw, NamedSharding(self.mesh, ep_spec_2d)) weighted = self._preweight_expert_out(eo, rw) - combined = ep_combine(self.hk, hm, tc, weighted, T_dp, out_sharding=(("dp", "ep"), None)) + combined = ep_combine( + self.hk, hm, tc, weighted, T_dp, out_sharding=(("dp", "ep"), None) + ) return jax.lax.with_sharding_constraint(combined, NamedSharding(self.mesh, dp_spec)) # jax.vjp + pinned cotangent feeds ep_combine_bwd/ep_dispatch_bwd diff --git a/transformer_engine/jax/cpp_extensions/ep.py b/transformer_engine/jax/cpp_extensions/ep.py index cef88d0937..233a4f4314 100644 --- a/transformer_engine/jax/cpp_extensions/ep.py +++ b/transformer_engine/jax/cpp_extensions/ep.py @@ -245,9 +245,7 @@ def impl(topk_idx, top_k, dispatch_output_per_expert_alignment, is_outer): return token_counts, handle_mem @staticmethod - def batcher( - batched_args, batch_dims, *, top_k, dispatch_output_per_expert_alignment, is_outer - ): + def batcher(batched_args, batch_dims, *, top_k, dispatch_output_per_expert_alignment, is_outer): raise NotImplementedError("EpPreparePrimitive does not support vmap") @staticmethod @@ -486,10 +484,9 @@ def _leading_per_shard(out_leading_shape, leading_axis, mesh): factor = 1 for a in axes: factor *= mesh.shape[a] - assert out_leading_shape[0] % factor == 0, ( - f"leading dim {out_leading_shape[0]} not divisible by shard factor" - f" {factor} on axes {axes}" - ) + assert ( + out_leading_shape[0] % factor == 0 + ), f"leading dim {out_leading_shape[0]} not divisible by shard factor {factor} on axes {axes}" return (out_leading_shape[0] // factor,) + tuple(out_leading_shape[1:]) @@ -918,8 +915,9 @@ def ep_prepare(cfg: EpLayerConfig, topk_idx): ) -def ep_dispatch_fwd(cfg: EpLayerConfig, handle_mem, topk_idx, tokens, topk_weights, - recv_capacity_per_rank): +def ep_dispatch_fwd( + cfg: EpLayerConfig, handle_mem, topk_idx, tokens, topk_weights, recv_capacity_per_rank +): """Scatter tokens and weights to expert ranks; returns (recv_tokens, recv_topk_weights).""" return EpDispatchPrimitive.outer_primitive.bind( handle_mem, @@ -933,8 +931,9 @@ def ep_dispatch_fwd(cfg: EpLayerConfig, handle_mem, topk_idx, tokens, topk_weigh ) -def ep_combine_fwd(cfg: EpLayerConfig, handle_mem, expert_out, num_local_tokens, - out_partition_spec=None): +def ep_combine_fwd( + cfg: EpLayerConfig, handle_mem, expert_out, num_local_tokens, out_partition_spec=None +): """Gather expert outputs back to home ranks. expert_out is pre-weighted.""" out_leading = _normalize_leading_shape(num_local_tokens) return EpCombinePrimitive.outer_primitive.bind( @@ -948,7 +947,11 @@ def ep_combine_fwd(cfg: EpLayerConfig, handle_mem, expert_out, num_local_tokens, def ep_dispatch_bwd( - cfg: EpLayerConfig, handle_mem, grad, g_recv_topk_weights, num_local_tokens, + cfg: EpLayerConfig, + handle_mem, + grad, + g_recv_topk_weights, + num_local_tokens, out_partition_spec=None, ): """Backward of dispatch; returns (grad_tokens, grad_topk_weights).""" diff --git a/transformer_engine/jax/ep.py b/transformer_engine/jax/ep.py index 495ac0d94f..ed22ad5be7 100644 --- a/transformer_engine/jax/ep.py +++ b/transformer_engine/jax/ep.py @@ -54,8 +54,7 @@ def _allgather_uid(uid_arr, world_size, uid_size): devices = np.asarray(jax.devices()) if devices.size != world_size: raise RuntimeError( - f"_allgather_uid fallback expected {world_size} global devices," - f" got {devices.size}." + f"_allgather_uid fallback expected {world_size} global devices, got {devices.size}." ) mesh = jax.sharding.Mesh(devices, ("_uid_all",)) sharded = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("_uid_all", None)) @@ -91,7 +90,7 @@ def ep_bootstrap( """ if jnp.dtype(max_token_dtype) != jnp.bfloat16: raise NotImplementedError( - f"ep_bootstrap: only max_token_dtype=jnp.bfloat16 is supported today, got" + "ep_bootstrap: only max_token_dtype=jnp.bfloat16 is supported today, got" f" {jnp.dtype(max_token_dtype)}." ) if world_size < 2: @@ -195,8 +194,7 @@ def _default_out_partition_spec(): gsr = global_mesh_resource() if gsr.ep_resource is None: raise ValueError( - "ep_resource is not set on the active MeshResource;" - " pass out_sharding=... explicitly." + "ep_resource is not set on the active MeshResource; pass out_sharding=... explicitly." ) outer = gsr.dp_resource or gsr.fsdp_resource leading = (outer, gsr.ep_resource) if outer is not None else gsr.ep_resource @@ -243,7 +241,11 @@ def _dispatch_bwd(cfg, recv_capacity_per_rank, res, g_outputs): g_recv_tokens = with_sharding_constraint(g_outputs[0], spec) g_recv_topk_weights = with_sharding_constraint(g_outputs[1], spec) grad_tokens, grad_topk_weights = tex.ep_dispatch_bwd( - cfg, handle_mem, g_recv_tokens, g_recv_topk_weights, out_leading, + cfg, + handle_mem, + g_recv_tokens, + g_recv_topk_weights, + out_leading, out_partition_spec=out_spec, ) return (None, grad_tokens, grad_topk_weights) @@ -257,8 +259,12 @@ def _dispatch_bwd(cfg, recv_capacity_per_rank, res, g_outputs): @partial(jax.custom_vjp, nondiff_argnums=(0, 4, 5)) def ep_combine( - cfg, handle_mem, token_counts, expert_out, - num_local_tokens, out_sharding=None, + cfg, + handle_mem, + token_counts, + expert_out, + num_local_tokens, + out_sharding=None, ): """Scatter-sum expert outputs back to source ranks. **Unweighted.** @@ -269,14 +275,22 @@ def ep_combine( ``_default_out_partition_spec``; only the leading dim may be sharded. """ return _combine_fwd( - cfg, handle_mem, token_counts, expert_out, - num_local_tokens, out_sharding, + cfg, + handle_mem, + token_counts, + expert_out, + num_local_tokens, + out_sharding, )[0] def _combine_fwd( - cfg, handle_mem, token_counts, expert_out, - num_local_tokens, out_sharding, + cfg, + handle_mem, + token_counts, + expert_out, + num_local_tokens, + out_sharding, ): del token_counts if out_sharding is None: From 348cc36ddd3f72ff115c23f3fa1e795ab81a44c0 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 9 Jun 2026 20:42:32 -0700 Subject: [PATCH 11/22] build_tools/jax: gate NCCL EP on NVTE_BUILD_WITH_NCCL_EP (default on); define NVTE_WITH_NCCL_EP Signed-off-by: Phuong Nguyen --- build_tools/jax.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/build_tools/jax.py b/build_tools/jax.py index 35ed62f832..4ad0442055 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -120,6 +120,9 @@ def setup_jax_extension( if bool(int(os.getenv("NVTE_WITH_CUBLASMP", 0))): cxx_flags.append("-DNVTE_WITH_CUBLASMP") + if bool(int(os.getenv("NVTE_BUILD_WITH_NCCL_EP", "1"))): + cxx_flags.append("-DNVTE_WITH_NCCL_EP") + # Define TE/JAX as a Pybind11Extension from pybind11.setup_helpers import Pybind11Extension From bf811b81aa27bf0d3470263586992bb95b53b549 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 9 Jun 2026 22:19:50 -0700 Subject: [PATCH 12/22] jax/ep: collapse 5 FFI attr structs into single EpConfig Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/ep.py | 2 - transformer_engine/jax/csrc/extensions/ep.cpp | 70 ++++--------------- 2 files changed, 12 insertions(+), 60 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/ep.py b/transformer_engine/jax/cpp_extensions/ep.py index 233a4f4314..54fec2045b 100644 --- a/transformer_engine/jax/cpp_extensions/ep.py +++ b/transformer_engine/jax/cpp_extensions/ep.py @@ -537,7 +537,6 @@ def lowering( expert_out, top_k=int(top_k), dispatch_output_per_expert_alignment=int(dispatch_output_per_expert_alignment), - num_local_tokens=_prod(out_leading_shape), ) @staticmethod @@ -675,7 +674,6 @@ def lowering( g_recv_topk_weights, top_k=int(top_k), dispatch_output_per_expert_alignment=int(dispatch_output_per_expert_alignment), - num_local_tokens=_prod(out_leading_shape), ) @staticmethod diff --git a/transformer_engine/jax/csrc/extensions/ep.cpp b/transformer_engine/jax/csrc/extensions/ep.cpp index 8bb9083159..9cd1422d37 100644 --- a/transformer_engine/jax/csrc/extensions/ep.cpp +++ b/transformer_engine/jax/csrc/extensions/ep.cpp @@ -113,29 +113,7 @@ std::shared_ptr AcquireEpResources() { // attributes; prepare passes them to the C API as NVTEEpLayerConfig, and the // per-step ops carry top_k only to validate the topk_idx last dim. -struct EpPrepareConfig { - int64_t top_k; - int64_t dispatch_output_per_expert_alignment; -}; - -struct EpDispatchConfig { - int64_t top_k; - int64_t dispatch_output_per_expert_alignment; -}; - -struct EpCombineConfig { - int64_t top_k; - int64_t dispatch_output_per_expert_alignment; - int64_t num_local_tokens; -}; - -struct EpDispatchBwdConfig { - int64_t top_k; - int64_t dispatch_output_per_expert_alignment; - int64_t num_local_tokens; -}; - -struct EpCombineBwdConfig { +struct EpConfig { int64_t top_k; int64_t dispatch_output_per_expert_alignment; }; @@ -215,7 +193,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(EpInstantiateHandler, EpInstantiateImpl, FFI::Bind Error_Type EpPrepareFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type topk_idx, Result_Type token_counts, Result_Type handle_mem, Result_Type workspace, - EpPrepareConfig config) { + EpConfig config) { (void)ep_state; // lifetime only. auto topk_dims = topk_idx.dimensions(); NVTE_CHECK(topk_dims.size() >= 2, @@ -260,7 +238,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(EpPrepareHandler, EpPrepareFFI, .Ret() // token_counts .Ret() // handle_mem .Ret() // workspace (FFI scratch) - .Attrs(), + .Attrs(), FFI_CudaGraph_Traits); // ── ep_dispatch ─────────────────────────────────────────────────────────────── @@ -268,7 +246,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(EpPrepareHandler, EpPrepareFFI, Error_Type EpDispatchFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type handle_mem, Buffer_Type topk_idx, Buffer_Type tokens, Buffer_Type topk_weights, Result_Type recv_tokens, Result_Type recv_topk_weights, - Result_Type workspace, EpDispatchConfig config) { + Result_Type workspace, EpConfig config) { (void)ep_state; auto token_dims = tokens.dimensions(); NVTE_CHECK(token_dims.size() >= 2, @@ -351,13 +329,13 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(EpDispatchHandler, EpDispatchFFI, .Ret() // recv_tokens .Ret() // recv_topk_weights .Ret() // workspace (FFI scratch) - .Attrs(), + .Attrs(), FFI_CudaGraph_Traits); // ── ep_combine ──────────────────────────────────────────────────────────────── Error_Type EpCombineFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type handle_mem, - Buffer_Type expert_out, Result_Type result, EpCombineConfig config) { + Buffer_Type expert_out, Result_Type result, EpConfig config) { (void)ep_state; auto eo_dims = expert_out.dimensions(); NVTE_CHECK(eo_dims.size() >= 2, @@ -376,9 +354,6 @@ Error_Type EpCombineFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_T NVTE_CHECK(res_dims.size() >= 2, "result must be at least 2D [..., H]; got ndim=", res_dims.size()); const size_t res_T_flat = product(res_dims, 0, res_dims.size() - 1); - NVTE_CHECK(static_cast(res_T_flat) == config.num_local_tokens, - "result leading-dim product (", res_T_flat, ") must equal num_local_tokens (", - config.num_local_tokens, ")"); std::vector res_shape = {res_T_flat, H}; auto result_ = TensorWrapper(result->untyped_data(), res_shape, eo_dtype); @@ -395,7 +370,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(EpCombineHandler, EpCombineFFI, .Arg() // handle_mem .Arg() // expert_out .Ret() // result - .Attrs(), + .Attrs(), FFI_CudaGraph_Traits); // ── ep_dispatch_bwd ─────────────────────────────────────────────────────────── @@ -403,7 +378,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(EpCombineHandler, EpCombineFFI, Error_Type EpDispatchBwdFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type handle_mem, Buffer_Type grad, Buffer_Type g_recv_topk_weights, Result_Type grad_tokens, Result_Type grad_topk_weights, - EpDispatchBwdConfig config) { + EpConfig config) { (void)ep_state; auto grad_dims = grad.dimensions(); NVTE_CHECK(grad_dims.size() >= 2, @@ -433,9 +408,6 @@ Error_Type EpDispatchBwdFFI(cudaStream_t stream, EpInstanceState* ep_state, Buff NVTE_CHECK(out_dims.size() >= 2, "grad_tokens must be at least 2D [..., H], got ndim=", out_dims.size()); const size_t T_flat = product(out_dims, 0, out_dims.size() - 1); - NVTE_CHECK(static_cast(T_flat) == config.num_local_tokens, - "grad_tokens leading-dim product (", T_flat, ") must equal num_local_tokens (", - config.num_local_tokens, ")"); std::vector out_shape = {T_flat, H}; auto grad_tokens_ = TensorWrapper(grad_tokens->untyped_data(), out_shape, g_dtype); @@ -468,14 +440,14 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(EpDispatchBwdHandler, EpDispatchBwdFFI, .Arg() // g_recv_topk_weights .Ret() // grad_tokens .Ret() // grad_topk_weights - .Attrs(), + .Attrs(), FFI_CudaGraph_Traits); // ── ep_combine_bwd ──────────────────────────────────────────────────────────── Error_Type EpCombineBwdFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type handle_mem, Buffer_Type grad, Result_Type grad_expert_out, - EpCombineBwdConfig config) { + EpConfig config) { (void)ep_state; auto grad_dims = grad.dimensions(); NVTE_CHECK(grad_dims.size() >= 2, @@ -513,32 +485,14 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(EpCombineBwdHandler, EpCombineBwdFFI, .Arg() // handle_mem .Arg() // grad (w.r.t. result) .Ret() // grad_expert_out - .Attrs(), + .Attrs(), FFI_CudaGraph_Traits); } // namespace jax } // namespace transformer_engine XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( - transformer_engine::jax::EpPrepareConfig, ::xla::ffi::StructMember("top_k"), - ::xla::ffi::StructMember("dispatch_output_per_expert_alignment")); - -XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( - transformer_engine::jax::EpDispatchConfig, ::xla::ffi::StructMember("top_k"), - ::xla::ffi::StructMember("dispatch_output_per_expert_alignment")); - -XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( - transformer_engine::jax::EpCombineConfig, ::xla::ffi::StructMember("top_k"), - ::xla::ffi::StructMember("dispatch_output_per_expert_alignment"), - ::xla::ffi::StructMember("num_local_tokens")); - -XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( - transformer_engine::jax::EpDispatchBwdConfig, ::xla::ffi::StructMember("top_k"), - ::xla::ffi::StructMember("dispatch_output_per_expert_alignment"), - ::xla::ffi::StructMember("num_local_tokens")); - -XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( - transformer_engine::jax::EpCombineBwdConfig, ::xla::ffi::StructMember("top_k"), + transformer_engine::jax::EpConfig, ::xla::ffi::StructMember("top_k"), ::xla::ffi::StructMember("dispatch_output_per_expert_alignment")); #endif // NVTE_WITH_NCCL_EP From c0306183f39ac20335a7b3a3404d02ab92238eec Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 9 Jun 2026 22:20:03 -0700 Subject: [PATCH 13/22] jax/ep: dedup _ep_outer_axis, normalize _ep_spec_ok, unify outer_abstract, drop dead helpers Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/ep.py | 41 ++++++--------------- transformer_engine/jax/ep.py | 20 ++++------ 2 files changed, 20 insertions(+), 41 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/ep.py b/transformer_engine/jax/cpp_extensions/ep.py index 54fec2045b..88e7c2abcd 100644 --- a/transformer_engine/jax/cpp_extensions/ep.py +++ b/transformer_engine/jax/cpp_extensions/ep.py @@ -30,7 +30,6 @@ "EpLayerConfig", "set_ep_config", "get_ep_config", - "get_ep_num_local_experts", "ep_handle_mem_size", "ep_prepare", "ep_dispatch_fwd", @@ -78,11 +77,6 @@ def get_ep_config() -> EpConfig: return _ep_config -def get_ep_num_local_experts() -> int: - """Number of experts owned by this EP rank.""" - return get_ep_config().num_local_experts - - @dataclass(frozen=True) class EpLayerConfig: """Per-layer EP config; mirrors C ``NVTEEpLayerConfig``. @@ -156,24 +150,22 @@ def _ep_output_spec(*trailing): def _ep_spec_ok(spec, trailing_count): - """Accept ``(ep, *[None])`` (no DP) or ``((dp,ep), *[None])`` / - ``(("dp",), *[None])`` / ``("dp", *[None])`` / ``(None, *[None])`` (with DP) - on an EP-output tensor's single leading dim. JAX may collapse a size-1 - mesh axis to ``None`` (matters for dp_size=1 like 1x4).""" + """Leading dim shards along ep (and outer dp/fsdp when set); trailing dims + are replicated. JAX may collapse size-1 mesh axes to ``None`` or drop them, + so the leading entry is normalized to a set of named axes before comparing. + """ gsr = global_mesh_resource() ep_axis = gsr.ep_resource outer = _ep_outer_axis() - expected_len = 1 + trailing_count - if len(spec) != expected_len: + if len(spec) != 1 + trailing_count: return False if any(ax is not None for ax in spec[1:]): return False leading = spec[0] - if outer is None: - return leading == ep_axis - allowed = {ep_axis, outer, None} elts = leading if isinstance(leading, tuple) else (leading,) - return all(a in allowed for a in elts) + actual = frozenset(a for a in elts if a is not None) + expected = {ep_axis} if outer is None else {ep_axis, outer} + return actual <= expected # ── ep_prepare ────────────────────────────────────────────────────────────── @@ -213,15 +205,9 @@ def abstract(topk_idx_aval, *, top_k, dispatch_output_per_expert_alignment, is_o return token_counts_aval, handle_mem_aval, workspace_aval @staticmethod - def outer_abstract(topk_idx_aval, *, top_k, dispatch_output_per_expert_alignment, is_outer): - del is_outer - avals = EpPreparePrimitive.abstract( - topk_idx_aval, - top_k=top_k, - dispatch_output_per_expert_alignment=dispatch_output_per_expert_alignment, - is_outer=True, - ) - return avals[:2] + def outer_abstract(*args, **kwargs): + kwargs["is_outer"] = True + return EpPreparePrimitive.abstract(*args, **kwargs)[:2] # pylint: disable=missing-kwoa @staticmethod def lowering(ctx, topk_idx, *, top_k, dispatch_output_per_expert_alignment, is_outer): @@ -331,10 +317,8 @@ def abstract( @staticmethod def outer_abstract(*args, **kwargs): - kwargs = dict(kwargs) kwargs["is_outer"] = True - avals = EpDispatchPrimitive.abstract(*args, **kwargs) # pylint: disable=missing-kwoa - return avals[:2] + return EpDispatchPrimitive.abstract(*args, **kwargs)[:2] # pylint: disable=missing-kwoa @staticmethod def lowering( @@ -806,7 +790,6 @@ def abstract( @staticmethod def outer_abstract(*args, **kwargs): - kwargs = dict(kwargs) kwargs["is_outer"] = True return EpCombineBwdPrimitive.abstract(*args, **kwargs) # pylint: disable=missing-kwoa diff --git a/transformer_engine/jax/ep.py b/transformer_engine/jax/ep.py index ed22ad5be7..666b46f95b 100644 --- a/transformer_engine/jax/ep.py +++ b/transformer_engine/jax/ep.py @@ -14,6 +14,7 @@ import transformer_engine_jax import transformer_engine.jax.cpp_extensions as tex +from transformer_engine.jax.cpp_extensions.ep import _ep_outer_axis from transformer_engine.jax.cpp_extensions.misc import jax_dtype_to_te_dtype from transformer_engine.jax.sharding import ( global_mesh_resource, @@ -113,7 +114,7 @@ def ep_bootstrap( " global_shard_guard(MeshResource(..., ep_resource=)) before bootstrap." ) ep_size = get_mesh_axis_size(ep_resource) - outer_axis = gsr.dp_resource or gsr.fsdp_resource + outer_axis = _ep_outer_axis() if outer_axis is None: if world_size != ep_size: raise ValueError( @@ -138,16 +139,11 @@ def ep_bootstrap( rank_within_group = rank % ep_size is_color_root = rank_within_group == 0 if is_color_root: - try: - from nccl import get_unique_id - - uid_bytes = bytes(get_unique_id())[:UID_SIZE] - except ImportError: - libnccl = ctypes.CDLL("libnccl.so.2", use_errno=True) - uid_arr = (ctypes.c_uint8 * UID_SIZE)() - ret = libnccl.ncclGetUniqueId(ctypes.cast(uid_arr, ctypes.c_void_p)) - assert ret == 0, f"ncclGetUniqueId failed with code {ret}" - uid_bytes = bytes(uid_arr) + libnccl = ctypes.CDLL("libnccl.so.2", use_errno=True) + uid_arr = (ctypes.c_uint8 * UID_SIZE)() + ret = libnccl.ncclGetUniqueId(ctypes.cast(uid_arr, ctypes.c_void_p)) + assert ret == 0, f"ncclGetUniqueId failed with code {ret}" + uid_bytes = bytes(uid_arr) else: uid_bytes = bytes(UID_SIZE) @@ -196,7 +192,7 @@ def _default_out_partition_spec(): raise ValueError( "ep_resource is not set on the active MeshResource; pass out_sharding=... explicitly." ) - outer = gsr.dp_resource or gsr.fsdp_resource + outer = _ep_outer_axis() leading = (outer, gsr.ep_resource) if outer is not None else gsr.ep_resource return (leading,) From e41f588b0b875b68bd9530a0c0ac268da977edd0 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 9 Jun 2026 22:28:44 -0700 Subject: [PATCH 14/22] jax/ep: apply clang-format and silence pylint unused-arg in lowering Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/ep.py | 4 ++-- transformer_engine/jax/csrc/extensions/ep.cpp | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/ep.py b/transformer_engine/jax/cpp_extensions/ep.py index 88e7c2abcd..b8a1bdc564 100644 --- a/transformer_engine/jax/cpp_extensions/ep.py +++ b/transformer_engine/jax/cpp_extensions/ep.py @@ -514,7 +514,7 @@ def lowering( out_leading_shape, out_partition_spec, ): - del out_partition_spec + del out_leading_shape, out_partition_spec return ffi.ffi_lowering(EpCombinePrimitive.name)( ctx, handle_mem, @@ -650,7 +650,7 @@ def lowering( out_leading_shape, out_partition_spec, ): - del out_partition_spec + del out_leading_shape, out_partition_spec return ffi.ffi_lowering(EpDispatchBwdPrimitive.name)( ctx, handle_mem, diff --git a/transformer_engine/jax/csrc/extensions/ep.cpp b/transformer_engine/jax/csrc/extensions/ep.cpp index 9cd1422d37..ee204e7594 100644 --- a/transformer_engine/jax/csrc/extensions/ep.cpp +++ b/transformer_engine/jax/csrc/extensions/ep.cpp @@ -446,8 +446,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(EpDispatchBwdHandler, EpDispatchBwdFFI, // ── ep_combine_bwd ──────────────────────────────────────────────────────────── Error_Type EpCombineBwdFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type handle_mem, - Buffer_Type grad, Result_Type grad_expert_out, - EpConfig config) { + Buffer_Type grad, Result_Type grad_expert_out, EpConfig config) { (void)ep_state; auto grad_dims = grad.dimensions(); NVTE_CHECK(grad_dims.size() >= 2, From 572889cc123e43e9e139175f7a3768f4cecbad71 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 10 Jun 2026 16:54:12 -0700 Subject: [PATCH 15/22] qa: wire NCCL EP tests into L1 (multi-process unittest) and L0 (ep_moe example) jax distributed suites Signed-off-by: Phuong Nguyen --- qa/L0_jax_distributed_unittest/test.sh | 4 ++++ qa/L1_jax_distributed_unittest/test.sh | 3 +++ 2 files changed, 7 insertions(+) diff --git a/qa/L0_jax_distributed_unittest/test.sh b/qa/L0_jax_distributed_unittest/test.sh index bf4652c31a..f86cea284e 100644 --- a/qa/L0_jax_distributed_unittest/test.sh +++ b/qa/L0_jax_distributed_unittest/test.sh @@ -37,6 +37,10 @@ wait TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/collective_gemm/run_test_cgemm.sh || test_fail "run_test_cgemm.sh" wait +# NCCL EP example (ep_moe.py). Self-skips on <4 GPUs or platforms without NVLink. +TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/ep/run_test_ep.sh || test_fail "run_test_ep.sh" +wait + # MoE custom_vjp distributed suite. Runs one Python process per GPU # via tests/jax/run_multiprocess_moe_vjp.sh (mirrors the pattern in # examples/jax/encoder/run_test_multiprocessing_encoder.sh). Requires diff --git a/qa/L1_jax_distributed_unittest/test.sh b/qa/L1_jax_distributed_unittest/test.sh index 4f92d1c783..8e0ef2c267 100644 --- a/qa/L1_jax_distributed_unittest/test.sh +++ b/qa/L1_jax_distributed_unittest/test.sh @@ -37,6 +37,9 @@ XLA_FLAGS="$XLA_FLAGS --xla_gpu_enable_nccl_comm_splitting=false" python3 -m pyt python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_fused_attn.xml $TE_PATH/tests/jax/test_distributed_fused_attn.py || test_fail "test_distributed_fused_attn.py" +# NCCL EP multi-process suite. Self-skips on <4 GPUs. +TE_PATH=$TE_PATH bash $TE_PATH/tests/jax/multi_process_launch_ep.sh || test_fail "test_multi_process_ep.py" + # TODO(Phuong): add this test back after it is verified # SCRIPT_NAME=$TE_PATH/tests/jax/test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh || test_fail "test_multi_process_distributed_grouped_gemm.py" From 943a970beea985de015cbe1f6fdb9ca99f01bb27 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 10 Jun 2026 17:14:44 -0700 Subject: [PATCH 16/22] jax/ep: resolve test path from launcher SCRIPT_DIR and export LD_LIBRARY_PATH for libnccl_ep.so Signed-off-by: Phuong Nguyen --- examples/jax/ep/run_test_ep.sh | 9 +++++++++ tests/jax/multi_process_launch_ep.sh | 22 +++++++++++++++++++--- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/examples/jax/ep/run_test_ep.sh b/examples/jax/ep/run_test_ep.sh index 55b958f146..1305ca6fd1 100755 --- a/examples/jax/ep/run_test_ep.sh +++ b/examples/jax/ep/run_test_ep.sh @@ -32,6 +32,15 @@ export PYTHONPATH="${TE_PATH}${PYTHONPATH:+:${PYTHONPATH}}" COORD="${COORD:-127.0.0.1:12345}" TEST_TIMEOUT_S="${TEST_TIMEOUT_S:-300}" +# Editable installs don't embed rpath; libtransformer_engine.so needs +# libnccl_ep.so.0 from the TE editable location at dlopen time. +TE_LIB_PATH=$(pip3 show transformer-engine 2>/dev/null \ + | grep -E "Location:|Editable project location:" \ + | tail -n 1 | awk '{print $NF}') +if [ -n "$TE_LIB_PATH" ]; then + export LD_LIBRARY_PATH="${TE_LIB_PATH}${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" +fi + XLA_BASE_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_graph_min_graph_size=1" export XLA_FLAGS="${XLA_BASE_FLAGS}" diff --git a/tests/jax/multi_process_launch_ep.sh b/tests/jax/multi_process_launch_ep.sh index a37ffc2952..d32ce5f5d3 100755 --- a/tests/jax/multi_process_launch_ep.sh +++ b/tests/jax/multi_process_launch_ep.sh @@ -17,6 +17,15 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" TE_REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)" export PYTHONPATH="${TE_REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" +# Editable installs don't embed rpath; libtransformer_engine.so needs +# libnccl_ep.so.0 from the TE editable location at dlopen time. +TE_LIB_PATH=$(pip3 show transformer-engine 2>/dev/null \ + | grep -E "Location:|Editable project location:" \ + | tail -n 1 | awk '{print $NF}') +if [ -n "$TE_LIB_PATH" ]; then + export LD_LIBRARY_PATH="${TE_LIB_PATH}${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" +fi + NUM_RUNS=$(nvidia-smi -L | wc -l) if [ "${NUM_RUNS}" -lt 4 ]; then @@ -29,15 +38,22 @@ NUM_RUNS="${NVTE_TEST_EP_NUM_RANKS:-4}" OVERALL_RET=0 for SCRIPT_NAME in $SCRIPT_NAMES; do - echo "=== Running ${SCRIPT_NAME} ===" + # Allow callers to pass either a bare test name (resolved against this + # script's directory) or an absolute/relative path. + if [ -f "$SCRIPT_NAME" ]; then + SCRIPT_PATH="$SCRIPT_NAME" + else + SCRIPT_PATH="${SCRIPT_DIR}/${SCRIPT_NAME}" + fi + echo "=== Running ${SCRIPT_PATH} ===" for ((i=1; i stdout_rank_${i}.txt 2>&1 & + python "$SCRIPT_PATH" 127.0.0.1:12345 $i $NUM_RUNS > stdout_rank_${i}.txt 2>&1 & done timeout --foreground --signal=KILL "${TEST_TIMEOUT_S}" \ - python $SCRIPT_NAME 127.0.0.1:12345 0 $NUM_RUNS 2>&1 | tee stdout_multi_process.txt + python "$SCRIPT_PATH" 127.0.0.1:12345 0 $NUM_RUNS 2>&1 | tee stdout_multi_process.txt wait From 8818438bd13b09ef23ceafcb38124e538798b1e3 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 23 Jun 2026 05:56:53 -0700 Subject: [PATCH 17/22] [JAX] Generalize EP MoE example to N stacked layers via --num-layers Signed-off-by: Phuong Nguyen --- examples/jax/ep/ep_moe.py | 235 +++++++++++++++++++++++--------------- 1 file changed, 142 insertions(+), 93 deletions(-) diff --git a/examples/jax/ep/ep_moe.py b/examples/jax/ep/ep_moe.py index a23a0b33c9..88bb844d48 100644 --- a/examples/jax/ep/ep_moe.py +++ b/examples/jax/ep/ep_moe.py @@ -30,6 +30,12 @@ def _parse_args(): p.add_argument("--top-k", type=int, default=2) p.add_argument("--hidden", type=int, default=32) p.add_argument("--hidden-out", type=int, default=32) + p.add_argument( + "--num-layers", + type=int, + default=1, + help="Number of stacked MoE layers. >1 requires --hidden-out == --hidden.", + ) p.add_argument( "--num-experts", type=int, @@ -82,6 +88,11 @@ def _build_mesh_and_resource(args): assert args.num_experts % args.ep_size == 0 args.num_local_experts = args.num_experts // args.ep_size args.recv_capacity_per_rank = args.ep_size * args.num_tokens * args.top_k + if args.num_layers > 1 and args.hidden_out != args.hidden: + raise ValueError( + f"--num-layers > 1 needs square layers: --hidden-out ({args.hidden_out})" + f" must equal --hidden ({args.hidden})" + ) devs = np.asarray(jax.devices()).reshape(args.dp_size, args.ep_size) mesh = Mesh(devs, ("dp", "ep")) @@ -89,12 +100,16 @@ def _build_mesh_and_resource(args): return mesh, mr -def _make_routing(dp_color, num_tokens, top_k, num_experts, num_local_experts): - """Deterministic routing: topk_idx[t, k] = (dp_color*NLE + t*K + k) % E.""" +def _make_routing(dp_color, num_tokens, top_k, num_experts, num_local_experts, offset=0): + """Deterministic routing: topk_idx[t, k] = (dp_color*NLE + t*K + k + offset) % E. + + ``offset`` (the layer index) makes stacked layers route differently.""" topk_idx = np.empty((num_tokens, top_k), dtype=np.int32) for t in range(num_tokens): for k in range(top_k): - topk_idx[t, k] = (dp_color * num_local_experts + t * top_k + k) % num_experts + topk_idx[t, k] = ( + dp_color * num_local_experts + t * top_k + k + offset + ) % num_experts return topk_idx @@ -102,20 +117,24 @@ def _make_inputs(args): """Build 3D ``[B, S, H]`` arrays sharded ``(("dp","ep"), None, None)``. B = num_processes (sharded across the compound (dp,ep) axis so each rank - holds one slot); S = args.num_tokens. Global numpy views (rank-0 - reference) are kept 2D for the legacy reference implementation. + holds one slot); S = args.num_tokens. Routing and kernels are produced + per layer (layer ``i`` routes with offset ``i``); the last layer maps + ``H -> H_out``, all others ``H -> H``. Global numpy views (rank-0 + reference) are kept 2D for the reference implementation. """ T, K, H, H_out = args.num_tokens, args.top_k, args.hidden, args.hidden_out E = args.num_experts + L = args.num_layers dp_size = args.dp_size ep_size = args.ep_size num_procs = args.num_processes dp_color = args.process_id // ep_size + NLE = args.num_local_experts rng_dp = np.random.default_rng(seed=42 + dp_color) tokens_np = (rng_dp.standard_normal((T, H), dtype=np.float32) * 0.5).astype(np.float32) - topk_idx_np = _make_routing(dp_color, T, K, E, args.num_local_experts) w_np = np.full((T, K), 1.0 / K, dtype=np.float32) + idx_np_list = [_make_routing(dp_color, T, K, E, NLE, offset=i) for i in range(L)] tokens_global_np = np.concatenate( [ @@ -126,16 +145,23 @@ def _make_inputs(args): ], axis=0, ) - topk_idx_global_np = np.concatenate( - [_make_routing(c, T, K, E, args.num_local_experts) for c in range(dp_size)], axis=0 - ) + idx_global_np_list = [ + np.concatenate([_make_routing(c, T, K, E, NLE, offset=i) for c in range(dp_size)], axis=0) + for i in range(L) + ] w_global_np = np.full((dp_size * T, K), 1.0 / K, dtype=np.float32) - # Same seed on every rank → identical kernel array everywhere. + # Same seed on every rank → identical kernels everywhere. One set per layer; + # only the last layer carries H_out (all others are square H -> H). rng = np.random.default_rng(seed=42) - kernels_np = (rng.standard_normal((E, H, H_out), dtype=np.float32) * (1.0 / np.sqrt(H))).astype( - np.float32 - ) + kernels_np_list = [] + for i in range(L): + out_dim = H_out if i == L - 1 else H + kernels_np_list.append( + (rng.standard_normal((E, H, out_dim), dtype=np.float32) * (1.0 / np.sqrt(H))).astype( + np.float32 + ) + ) # Each rank contributes one [1, T, ...] slab; the global shape is # [num_procs, T, ...] sharded on the first dim across (dp, ep). @@ -144,20 +170,21 @@ def _make_inputs(args): tokens = jax.make_array_from_process_local_data( dpep_spec, tokens_np[None, :, :].astype(np.float32), (num_procs, T, H) ).astype(jnp.bfloat16) - topk_idx = jax.make_array_from_process_local_data( - dpep_spec, topk_idx_np[None, :, :], (num_procs, T, K) - ) + topk_idx_list = [ + jax.make_array_from_process_local_data(dpep_spec, idx_np[None, :, :], (num_procs, T, K)) + for idx_np in idx_np_list + ] topk_w = jax.make_array_from_process_local_data(dpep_spec, w_np[None, :, :], (num_procs, T, K)) - kernels = jnp.asarray(kernels_np, dtype=jnp.bfloat16) + kernels_list = [jnp.asarray(k, dtype=jnp.bfloat16) for k in kernels_np_list] return ( tokens_global_np, - topk_idx_global_np, + idx_global_np_list, w_global_np, - kernels_np, + kernels_np_list, tokens, - topk_idx, + topk_idx_list, topk_w, - kernels, + kernels_list, ) @@ -187,87 +214,107 @@ def _batched_expert_linear(recv_tokens, kernels, num_local_experts, dp_size, ep_ return out.reshape(num_procs, recv_pr, H_out) -def _moe_step(args, topk_idx, tokens, topk_w, kernels): - """Jit'd MoE step: dispatch -> batched per-expert linear -> combine. - - Inputs are 3D ``[B, S, H]`` with the first dim compound-sharded across - ``("dp","ep")``. Combine returns the same 3D shape. - """ - B = args.num_processes - S = args.num_tokens +def _moe_layer(args, cfg, mesh, topk_idx, tokens, topk_w, local_kernels): + """One MoE layer: dispatch -> batched per-expert linear -> combine. ``cfg`` + is layer-private, so each layer gets its own handle_mem. Returns the same + 3D ``[B, S, H_out]`` layout as ``tokens``.""" + B, S = args.num_processes, args.num_tokens NLE = args.num_local_experts dp_size, ep_size = args.dp_size, args.ep_size - mesh = args.mesh in_spec = PartitionSpec(("dp", "ep"), None, None) # [B, S, ...] ep3 = PartitionSpec(("dp", "ep"), None, None) # [num_procs, recv_pr, H] ep2 = PartitionSpec(("dp", "ep"), None) # [num_procs, recv_pr] # Kernels are EP-replicated across dp colors; shard only the ep-rank axis. kernel_spec = PartitionSpec("ep", None, None, None) - kernels = kernels.reshape(ep_size, NLE, *kernels.shape[1:]) - layer_cfg = EpLayerConfig(top_k=args.top_k, dispatch_output_per_expert_alignment=16) + topk_idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(mesh, in_spec)) + tokens = jax.lax.with_sharding_constraint(tokens, NamedSharding(mesh, in_spec)) + topk_w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(mesh, in_spec)) + local_kernels = jax.lax.with_sharding_constraint(local_kernels, NamedSharding(mesh, kernel_spec)) + recv_tokens, recv_topk_w, handle_mem, _tc = ep_dispatch( + cfg, topk_idx, tokens, topk_w, args.recv_capacity_per_rank + ) + recv_tokens = jax.lax.with_sharding_constraint(recv_tokens, NamedSharding(mesh, ep3)) + recv_topk_w = jax.lax.with_sharding_constraint(recv_topk_w, NamedSharding(mesh, ep2)) + expert_out = _batched_expert_linear(recv_tokens, local_kernels, NLE, dp_size, ep_size) + expert_out = jax.lax.with_sharding_constraint(expert_out, NamedSharding(mesh, ep3)) + # ep_combine is unweighted: pre-multiply by recv_topk_w and zero + # padded slots (recv_topk_w == 0) before the scatter-sum. + mask = (recv_topk_w != 0).astype(jnp.float32)[..., None] + weighted = (expert_out.astype(jnp.float32) * recv_topk_w[..., None] * mask).astype( + expert_out.dtype + ) + weighted = jax.lax.with_sharding_constraint(weighted, NamedSharding(mesh, ep3)) + return ep_combine( + cfg, + handle_mem, + _tc, + weighted, + num_local_tokens=(B, S), + out_sharding=(("dp", "ep"), None, None), + ) + + +def _moe_step(args, topk_idx_list, tokens, topk_w, kernels_list): + """Jit'd MoE: ``num_layers`` stacked MoE layers, each with its own + ``EpLayerConfig`` (hence its own handle_mem). Layer output feeds the next + layer's dispatch. Inputs are 3D ``[B, S, H]`` compound-sharded across + ``("dp","ep")``; the result keeps that 3D layout. + """ + NLE = args.num_local_experts + ep_size = args.ep_size + mesh = args.mesh + in_spec = PartitionSpec(("dp", "ep"), None, None) + + kernels_list = [k.reshape(ep_size, NLE, *k.shape[1:]) for k in kernels_list] + cfgs = [ + EpLayerConfig(top_k=args.top_k, dispatch_output_per_expert_alignment=16) + for _ in range(args.num_layers) + ] @jax.jit - def step(topk_idx, tokens, topk_w, local_kernels): - topk_idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(mesh, in_spec)) - tokens = jax.lax.with_sharding_constraint(tokens, NamedSharding(mesh, in_spec)) - topk_w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(mesh, in_spec)) - local_kernels = jax.lax.with_sharding_constraint( - local_kernels, NamedSharding(mesh, kernel_spec) - ) - recv_tokens, recv_topk_w, handle_mem, _tc = ep_dispatch( - layer_cfg, topk_idx, tokens, topk_w, args.recv_capacity_per_rank - ) - recv_tokens = jax.lax.with_sharding_constraint(recv_tokens, NamedSharding(mesh, ep3)) - recv_topk_w = jax.lax.with_sharding_constraint(recv_topk_w, NamedSharding(mesh, ep2)) - expert_out = _batched_expert_linear(recv_tokens, local_kernels, NLE, dp_size, ep_size) - expert_out = jax.lax.with_sharding_constraint(expert_out, NamedSharding(mesh, ep3)) - # ep_combine is unweighted: pre-multiply by recv_topk_w and zero - # padded slots (recv_topk_w == 0) before the scatter-sum. - mask = (recv_topk_w != 0).astype(jnp.float32)[..., None] - weighted = (expert_out.astype(jnp.float32) * recv_topk_w[..., None] * mask).astype( - expert_out.dtype - ) - weighted = jax.lax.with_sharding_constraint(weighted, NamedSharding(mesh, ep3)) - return ep_combine( - layer_cfg, - handle_mem, - _tc, - weighted, - num_local_tokens=(B, S), - out_sharding=(("dp", "ep"), None, None), - ) + def step(topk_idx_list, tokens, topk_w, kernels_list): + h = tokens + for cfg, idx, kern in zip(cfgs, topk_idx_list, kernels_list): + h = _moe_layer(args, cfg, mesh, idx, h, topk_w, kern) + h = jax.lax.with_sharding_constraint(h, NamedSharding(mesh, in_spec)) + return h - return step(topk_idx, tokens, topk_w, kernels) + return step(topk_idx_list, tokens, topk_w, kernels_list) # ── Reference (numerical check) ───────────────────────────────────────────── -def _reference_moe(tokens, topk_idx, topk_w, kernels): - """Single-rank dense MoE reference. tokens [T, H], output [T, H_out].""" +def _token_mixing(topk_idx, topk_w, kernels): + """Per-token effective matrix ``M[t] = sum_k w[t,k] * kernels[idx[t,k]]``. + A MoE layer is per-token linear: ``out[t] = in[t] @ M[t]``.""" T, K = topk_idx.shape - H_out = kernels.shape[-1] - out = np.zeros((T, H_out), dtype=np.float32) + H_in, H_out = kernels.shape[1], kernels.shape[2] + M = np.zeros((T, H_in, H_out), dtype=np.float32) for t in range(T): - tok = tokens[t].astype(np.float32) for k in range(K): - e = int(topk_idx[t, k]) - out[t] += float(topk_w[t, k]) * (tok @ kernels[e].astype(np.float32)) - return out + M[t] += float(topk_w[t, k]) * kernels[int(topk_idx[t, k])].astype(np.float32) + return M -def _reference_grad(tokens, topk_idx, topk_w, kernels): - """d/dtokens of 0.5 * sum(ref_out**2) — used by --check to validate bwd.""" - T, K = topk_idx.shape - H = tokens.shape[-1] - ref_out = _reference_moe(tokens, topk_idx, topk_w, kernels) - grad = np.zeros((T, H), dtype=np.float32) +def _reference_grad(tokens, topk_idx_list, topk_w, kernels_list): + """Single-rank dense reference for ``num_layers`` stacked MoE layers. + + Each layer is per-token linear, so ``out[t] = x[t] @ C[t]`` with + ``C[t] = M_1[t] @ ... @ M_L[t]``. Returns ``(ref_out, grad)`` where ``grad`` + is ``d/dtokens of 0.5 * sum(ref_out**2)`` = ``out[t] @ C[t].T``.""" + T = tokens.shape[0] + Ms = [_token_mixing(idx, topk_w, k) for idx, k in zip(topk_idx_list, kernels_list)] + ref_out = np.zeros((T, Ms[-1].shape[-1]), dtype=np.float32) + grad = np.zeros_like(tokens, dtype=np.float32) for t in range(T): - mixed = np.zeros_like(kernels[0]) - for k in range(K): - mixed = mixed + float(topk_w[t, k]) * kernels[int(topk_idx[t, k])] - grad[t] = ref_out[t] @ mixed.T + C = Ms[0][t] + for M in Ms[1:]: + C = C @ M[t] + y = tokens[t].astype(np.float32) @ C + ref_out[t] = y + grad[t] = y @ C.T return ref_out, grad @@ -300,31 +347,31 @@ def main(): ( tokens_global_np, - topk_idx_global_np, + topk_idx_global_np_list, w_global_np, - kernels_np, + kernels_np_list, tokens, - topk_idx, + topk_idx_list, topk_w, - kernels, + kernels_list, ) = _make_inputs(args) - def loss_fn(toks, idx, w, kern): - out = _moe_step(args, idx, toks, w, kern) + def loss_fn(toks, idx_list, w, kern_list): + out = _moe_step(args, idx_list, toks, w, kern_list) return 0.5 * (out.astype(jnp.float32) ** 2).sum(), out step_jit = jax.jit(jax.value_and_grad(loss_fn, has_aux=True)) # Same jit + same inputs each iter: handle_mem cache must give identical loss/grad. for it in range(args.iters): - (loss, out_fwd), grad_tokens = step_jit(tokens, topk_idx, topk_w, kernels) + (loss, out_fwd), grad_tokens = step_jit(tokens, topk_idx_list, topk_w, kernels_list) grad_tokens.block_until_ready() out_fwd.block_until_ready() if args.process_id == 0: print( f"[ep_moe] iter={it} loss={float(loss):.4f}" f" grad_tokens.shape={grad_tokens.shape}" - f" dp={args.dp_size} ep={args.ep_size}" + f" layers={args.num_layers} dp={args.dp_size} ep={args.ep_size}" f" num_experts={args.num_experts} recv_pr={args.recv_capacity_per_rank}" ) @@ -356,7 +403,7 @@ def _norm(spec, ndim): grad_global.block_until_ready() ref_out, ref_grad = _reference_grad( - tokens_global_np, topk_idx_global_np, w_global_np, kernels_np + tokens_global_np, topk_idx_global_np_list, w_global_np, kernels_np_list ) # 3D global ``[num_procs, S, H]`` with num_procs = dp * ep. Each EP # column in a DP color sees identical inputs (and produces identical @@ -373,18 +420,20 @@ def _norm(spec, ndim): .reshape(dp_size, ep_size, -1, ref_grad.shape[-1])[:, 0] .reshape(-1, ref_grad.shape[-1]) ) + # bf16 error compounds across stacked layers; relax tol slightly for >1 layer. + tol = 5e-2 if args.num_layers == 1 else 6e-2 np.testing.assert_allclose( global_out, ref_out, - rtol=5e-2, - atol=5e-2, + rtol=tol, + atol=tol, err_msg=f"rank {args.process_id}: fwd mismatch", ) np.testing.assert_allclose( global_grad, ref_grad, - rtol=5e-2, - atol=5e-2, + rtol=tol, + atol=tol, err_msg=f"rank {args.process_id}: bwd mismatch", ) if args.process_id == 0: From a3378db7ea705b3a1146f3b5d325dcd1e53c8acd Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 23 Jun 2026 08:35:11 -0700 Subject: [PATCH 18/22] Unify NCCL EP build flag: rename to NVTE_WITH_NCCL_EP, share arch check via nccl_ep_enabled() Signed-off-by: Phuong Nguyen --- build_tools/jax.py | 4 ++-- build_tools/utils.py | 31 +++++++++++++++++++++++++++++++ setup.py | 20 ++------------------ 3 files changed, 35 insertions(+), 20 deletions(-) diff --git a/build_tools/jax.py b/build_tools/jax.py index 4ad0442055..1d2578f8f8 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -10,7 +10,7 @@ import setuptools -from .utils import get_cuda_include_dirs, all_files_in_dir, debug_build_enabled, setup_mpi_flags +from .utils import get_cuda_include_dirs, all_files_in_dir, debug_build_enabled, setup_mpi_flags, nccl_ep_enabled from typing import List @@ -120,7 +120,7 @@ def setup_jax_extension( if bool(int(os.getenv("NVTE_WITH_CUBLASMP", 0))): cxx_flags.append("-DNVTE_WITH_CUBLASMP") - if bool(int(os.getenv("NVTE_BUILD_WITH_NCCL_EP", "1"))): + if nccl_ep_enabled(): cxx_flags.append("-DNVTE_WITH_NCCL_EP") # Define TE/JAX as a Pybind11Extension diff --git a/build_tools/utils.py b/build_tools/utils.py index f2548b4de6..82a1dd968c 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -265,6 +265,37 @@ def cuda_archs() -> str: return archs +def nccl_ep_enabled(archs: str = None) -> bool: + """Return True when NCCL EP should be compiled into this build. + + Reads NVTE_WITH_NCCL_EP (default on). Auto-skips with a printed warning + when no arch >= 90 is targeted; raises RuntimeError if the flag was + explicitly set to 1 but no qualifying arch is present. Mirrors the same + logic in both TE/Common (setup.py) and TE/JAX (build_tools/jax.py) so a + single env var controls both sides consistently. + """ + if archs is None: + archs = cuda_archs() + nccl_ep_env = os.getenv("NVTE_WITH_NCCL_EP") + nccl_ep_explicit = nccl_ep_env is not None + build_ep = bool(int(nccl_ep_env if nccl_ep_explicit else "1")) + if build_ep: + arch_tokens = [a.strip() for a in str(archs or "").split(";") if a.strip()] + has_hopper_or_newer = any( + t.lower() == "native" or (t.rstrip("af").isdigit() and int(t.rstrip("af")) >= 90) + for t in arch_tokens + ) + if not has_hopper_or_newer: + if nccl_ep_explicit: + raise RuntimeError( + f"NVTE_WITH_NCCL_EP=1 was set but NVTE_CUDA_ARCHS ('{archs}') " + "contains no arch >= 90. NCCL EP requires Hopper or newer." + ) + print(f"[NCCL EP] No arch >= 90 in NVTE_CUDA_ARCHS ('{archs}'); skipping build.") + build_ep = False + return build_ep + + def cuda_version() -> Tuple[int, ...]: """CUDA Toolkit version as a (major, minor) tuple. diff --git a/setup.py b/setup.py index 64ed120268..2ff305867a 100644 --- a/setup.py +++ b/setup.py @@ -24,6 +24,7 @@ get_frameworks, remove_dups, min_python_version_str, + nccl_ep_enabled, ) frameworks = get_frameworks() @@ -86,24 +87,7 @@ def setup_common_extension() -> CMakeExtension: # NCCL EP (Hopper+): on by default; auto-skipped when no arch >= 90 is # targeted. Set NVTE_WITH_NCCL_EP=0 to force off. - nccl_ep_env = os.getenv("NVTE_WITH_NCCL_EP") - nccl_ep_explicit = nccl_ep_env is not None - build_with_nccl_ep = bool(int(nccl_ep_env if nccl_ep_explicit else "1")) - if build_with_nccl_ep: - arch_tokens = [a.strip() for a in str(archs or "").split(";") if a.strip()] - has_hopper_or_newer = any( - t.lower() == "native" or (t.rstrip("af").isdigit() and int(t.rstrip("af")) >= 90) - for t in arch_tokens - ) - if not has_hopper_or_newer: - if nccl_ep_explicit: - raise RuntimeError( - f"NVTE_WITH_NCCL_EP=1 was set but NVTE_CUDA_ARCHS ('{archs}') " - "contains no arch >= 90. NCCL EP requires Hopper or newer." - ) - print(f"[NCCL EP] No arch >= 90 in NVTE_CUDA_ARCHS ('{archs}'); skipping build.") - build_with_nccl_ep = False - if build_with_nccl_ep: + if nccl_ep_enabled(archs): nccl_home = build_nccl_ep_submodule() cmake_flags.append(f"-DNCCL_INCLUDE_DIR={nccl_home}/include") else: From 06440a6f7aface1850bfdbb68dd0673f6e599c56 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 23 Jun 2026 08:42:52 -0700 Subject: [PATCH 19/22] Add arg descriptions to ep_bootstrap docstring Signed-off-by: Phuong Nguyen --- transformer_engine/jax/ep.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/transformer_engine/jax/ep.py b/transformer_engine/jax/ep.py index 666b46f95b..9704e51546 100644 --- a/transformer_engine/jax/ep.py +++ b/transformer_engine/jax/ep.py @@ -85,9 +85,16 @@ def ep_bootstrap( num_ep_groups are read from the mesh axes named by MeshResource.ep_resource and MeshResource.dp_resource/fsdp_resource. - max_token_dtype is the widest jnp dtype the group will dispatch; tensors - passed to ep_dispatch may use any narrower dtype. - max_num_sms caps the SMs allotted to EP kernels (0 = auto). + Args: + world_size: Total number of processes (dp_size * ep_size). + rank: Global rank of the calling process. + num_experts: Total experts across the EP group. + max_tokens_per_rank: Max tokens one rank dispatches per step (sizes send buffers). + recv_capacity_per_rank: Max tokens one rank receives per step; set to + at least ep_size * max_tokens_per_rank * top_k to avoid drops. + hidden_dim: Feature dimension of token tensors passed to ep_dispatch. + max_token_dtype: Widest dtype the group will dispatch (only bfloat16 supported). + max_num_sms: SM budget for EP kernels; 0 = auto. """ if jnp.dtype(max_token_dtype) != jnp.bfloat16: raise NotImplementedError( From ebf5d96fbf00eb3bbe4df9f116182301b6488ebd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Jun 2026 16:41:37 +0000 Subject: [PATCH 20/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- build_tools/jax.py | 8 +++++++- examples/jax/ep/ep_moe.py | 8 ++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/build_tools/jax.py b/build_tools/jax.py index 1d2578f8f8..654f8bfc17 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -10,7 +10,13 @@ import setuptools -from .utils import get_cuda_include_dirs, all_files_in_dir, debug_build_enabled, setup_mpi_flags, nccl_ep_enabled +from .utils import ( + get_cuda_include_dirs, + all_files_in_dir, + debug_build_enabled, + setup_mpi_flags, + nccl_ep_enabled, +) from typing import List diff --git a/examples/jax/ep/ep_moe.py b/examples/jax/ep/ep_moe.py index 88bb844d48..671150d655 100644 --- a/examples/jax/ep/ep_moe.py +++ b/examples/jax/ep/ep_moe.py @@ -107,9 +107,7 @@ def _make_routing(dp_color, num_tokens, top_k, num_experts, num_local_experts, o topk_idx = np.empty((num_tokens, top_k), dtype=np.int32) for t in range(num_tokens): for k in range(top_k): - topk_idx[t, k] = ( - dp_color * num_local_experts + t * top_k + k + offset - ) % num_experts + topk_idx[t, k] = (dp_color * num_local_experts + t * top_k + k + offset) % num_experts return topk_idx @@ -230,7 +228,9 @@ def _moe_layer(args, cfg, mesh, topk_idx, tokens, topk_w, local_kernels): topk_idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(mesh, in_spec)) tokens = jax.lax.with_sharding_constraint(tokens, NamedSharding(mesh, in_spec)) topk_w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(mesh, in_spec)) - local_kernels = jax.lax.with_sharding_constraint(local_kernels, NamedSharding(mesh, kernel_spec)) + local_kernels = jax.lax.with_sharding_constraint( + local_kernels, NamedSharding(mesh, kernel_spec) + ) recv_tokens, recv_topk_w, handle_mem, _tc = ep_dispatch( cfg, topk_idx, tokens, topk_w, args.recv_capacity_per_rank ) From 9df769a7a0b525e3c5e44accbd92be9e40b6c8ae Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 24 Jun 2026 00:21:59 -0700 Subject: [PATCH 21/22] L0_jax_unittest: exclude multi-process EP test on single-node runner Signed-off-by: Phuong Nguyen --- qa/L0_jax_unittest/test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index 9c020ddd33..5d833792cc 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -28,7 +28,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" : ${XML_LOG_DIR:=/logs} mkdir -p "$XML_LOG_DIR" -python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*" +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax --ignore=$TE_PATH/tests/jax/test_multi_process_ep.py -k 'not distributed' || test_fail "tests/jax/*not_distributed_*" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_fused_attn_score_mod.xml $TE_PATH/tests/jax/test_fused_attn_score_mod.py || test_fail "tests/jax/test_fused_attn_score_mod.py" NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_fused_attn_with_determinism.xml $TE_PATH/tests/jax/test_fused_attn.py -k "TestFusedAttnWithDeterminism" || test_fail "tests/jax/test_fused_attn.py" From db8d43b2f42592b7275949908028777df68dfd81 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Thu, 25 Jun 2026 06:44:01 -0700 Subject: [PATCH 22/22] Guard against None out_partition_spec in EP combine/dispatch-bwd partition methods Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/ep.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/ep.py b/transformer_engine/jax/cpp_extensions/ep.py index b8a1bdc564..2a4e17991a 100644 --- a/transformer_engine/jax/cpp_extensions/ep.py +++ b/transformer_engine/jax/cpp_extensions/ep.py @@ -572,9 +572,13 @@ def partition( " None, None) (or ((dp, ep), None, None) when dp/fsdp is set)" f" over [num_procs, recv_pr, H]; got spec={eo_spec}." ) - per_shard_leading = _leading_per_shard(out_leading_shape, out_partition_spec[0], mesh) + if out_partition_spec is not None: + per_shard_leading = _leading_per_shard(out_leading_shape, out_partition_spec[0], mesh) + out_sharding = NamedSharding(mesh, PartitionSpec(*out_partition_spec)) + else: + per_shard_leading = out_leading_shape + out_sharding = NamedSharding(mesh, PartitionSpec()) arg_shardings = tuple(a.sharding for a in arg_infos) - out_sharding = NamedSharding(mesh, PartitionSpec(*out_partition_spec)) def sharded_impl(handle_mem, expert_out): return EpCombinePrimitive.impl( @@ -723,9 +727,13 @@ def partition( "EpDispatchBwd: grad and g_recv_topk_weights must share the leading" f" axis; got grad={g_spec}, g_recv_topk_weights={gw_spec}." ) - per_shard_leading = _leading_per_shard(out_leading_shape, out_partition_spec[0], mesh) + if out_partition_spec is not None: + per_shard_leading = _leading_per_shard(out_leading_shape, out_partition_spec[0], mesh) + out_sharding = NamedSharding(mesh, PartitionSpec(*out_partition_spec)) + else: + per_shard_leading = out_leading_shape + out_sharding = NamedSharding(mesh, PartitionSpec()) arg_shardings = tuple(a.sharding for a in arg_infos) - out_sharding = NamedSharding(mesh, PartitionSpec(*out_partition_spec)) out_shardings = [out_sharding, out_sharding] def sharded_impl(handle_mem, grad, g_recv_topk_weights):