diff --git a/docs/envvars.rst b/docs/envvars.rst
index 044a7f6a0d..e8f90a5412 100644
--- a/docs/envvars.rst
+++ b/docs/envvars.rst
@@ -122,6 +122,19 @@ These environment variables control the behavior of Transformer Engine during ex
Attention Backend Selection
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Transformer Engine attention selects a backend in two stages. First, it filters the available
+backends by environment variables, GPU architecture, installed ``flash-attn`` and cuDNN versions,
+data type and FP8 recipe, training or inference mode, and the provided attention configuration.
+Then it applies a performance-based preference order among the remaining eligible backends.
+
+In PyTorch, the broad preference order is ``FlashAttention > FusedAttention >
+UnfusedDotProductAttention`` on supported pre-Hopper GPUs such as Ampere/Ada, and
+``FusedAttention > FlashAttention > UnfusedDotProductAttention`` on Hopper and newer GPUs,
+including Blackwell. In JAX, Transformer Engine uses cuDNN fused attention when
+``NVTE_FUSED_ATTN=1`` and an eligible cuDNN kernel is available; otherwise it falls back to the
+JAX-native implementation. See :doc:`examples/attention/attention` for a longer
+backend-selection overview.
+
.. envvar:: NVTE_FLASH_ATTN
:Type: ``int`` (0 or 1)
@@ -144,7 +157,7 @@ Attention Backend Selection
:Type: ``int`` (1 or 2)
:Default: Auto-selected
- :Description: Force a specific FusedAttention backend. ``1`` = F16_arbitrary_seqlen (cuDNN, any seq len), ``2`` = FP8 backend. If not set, the backend is automatically selected based on the input configuration.
+ :Description: Request a cuDNN FusedAttention backend when that request is supported by the active fused-attention path. ``1`` = F16_arbitrary_seqlen (cuDNN, any seq len), ``2`` = FP8 backend. If not set, the backend is automatically selected based on the input configuration. BF16/FP16 attention uses sub-backend ``1`` when eligible. FP8 attention uses sub-backend ``2`` when FP8 DPA is enabled and supported by the architecture, cuDNN version, and input configuration.
.. envvar:: NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT
diff --git a/docs/examples/attention/attention.ipynb b/docs/examples/attention/attention.ipynb
index e7253415d2..c1c8ff38bf 100644
--- a/docs/examples/attention/attention.ipynb
+++ b/docs/examples/attention/attention.ipynb
@@ -110,14 +110,6 @@
"
Additional info | \n",
" \n",
" \n",
- " | 0 | \n",
- " Non-Flash | \n",
- " BF16/FP16 | \n",
- " ≤512 | \n",
- " sm80, 90 | \n",
- " [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/latest/developer/graph-api.html#fused-attention-fprop) | \n",
- "
\n",
- " \n",
" | 1 | \n",
" Flash | \n",
" BF16/FP16 | \n",
@@ -208,11 +200,11 @@
"source": [
"## 2. Backend Selection\n",
"\n",
- "Given the various attention backends, Transformer Engine has a selection logic in place to choose the most appropriate backend for a particular set of user inputs and runtime environment. The selection logic is based on both backend availability and backend performance.\n",
+ "Given the various attention backends, Transformer Engine first determines which backends are eligible for the provided inputs and runtime environment, then applies a preference order among the eligible backends. Eligibility is affected by user environment variables, GPU architecture, installed `flash-attn` and cuDNN versions, data type and FP8 recipe, QKV layout, training or inference mode, dropout, and other attention features.\n",
"\n",
- "Backend availability is determined by factors such as model configuration, training hyper-parameters, software versions, and the GPU architecture in question. For example, some considerations are the sequence length, number of attention heads, head size, attention mask type, attention bias type, training or inference mode, self or cross attention, MHA or MQA/GQA, `flash-attn`/cuDNN library versions, and the compute capability of the GPU.\n",
+ "In PyTorch, the candidates are FlashAttention (`flash-attn` v2, v3, or v4), FusedAttention (cuDNN sub-backends), and UnfusedDotProductAttention. Users can disable whole backend families with `NVTE_FLASH_ATTN`, `NVTE_FUSED_ATTN`, or `NVTE_UNFUSED_ATTN`. In JAX, Transformer Engine checks whether a cuDNN fused-attention kernel is available when `NVTE_FUSED_ATTN=1`; otherwise it falls back to the JAX-native implementation.\n",
"\n",
- "When there are multiple backends available, Transformer Engine makes backend selection based on performance. In general, there are a few rules being followed in our selection logic (see table below). As we monitor the performance of different backends, the selection logic may change.\n",
+ "At a high level, the architecture-specific PyTorch selection order is:\n",
"\n",
"\n",
" \n",
@@ -220,22 +212,29 @@
" | Selection Order | \n",
"
\n",
" \n",
- " | PyTorch | \n",
- " sm90: cuDNN attention > flash-attention > PyTorch-native attention | \n",
+ " PyTorch | \n",
+ " sm8x (Ampere/Ada): flash-attention > cuDNN attention > PyTorch-native attention | \n",
"
\n",
" \n",
- " | sm80: flash-attention > cuDNN attention > PyTorch-native attention | \n",
+ " sm90 (Hopper): cuDNN attention > flash-attention > PyTorch-native attention | \n",
"
\n",
" \n",
- " | \n",
- " cuDNN attention: sub-backend 1 > sub-backend 0\n",
- " | \n",
+ " sm100/sm120 (Blackwell): cuDNN attention > flash-attention > PyTorch-native attention | \n",
+ "
\n",
+ " \n",
+ " | cuDNN attention: BF16/FP16 uses sub-backend 1 when eligible; FP8 uses sub-backend 2 when enabled and eligible | \n",
"
\n",
" \n",
" | JAX | \n",
" cuDNN attention > JAX-native attention | \n",
"
\n",
- "
"
+ "\n",
+ "\n",
+ "Within FlashAttention, TE uses the installed implementation that is supported for the architecture and input. FlashAttention 3 is Hopper-only (`sm90`). FlashAttention 4 supports `sm80`, `sm90`, `sm100`, and `sm120`; on Hopper, TE prefers FlashAttention 3 over FlashAttention 4 when both are installed and eligible. On Blackwell, FlashAttention 4 is the Blackwell-specific flash-attention path when installed and eligible, while FlashAttention 2 can still be eligible depending on the installed version and input configuration.\n",
+ "\n",
+ "Within cuDNN FusedAttention, TE asks the fused-attention helper which sub-backend is eligible. Sub-backend 1 is the BF16/FP16 flash-based path when available; sub-backend 2 is the FP8 path when FP8 DPA is enabled and the architecture, cuDNN version, and input configuration support it. Hopper supports eligible FP8 DPA through cuDNN sub-backend 2. In the current PyTorch selector, eligible FP8 DPA on Blackwell is an `sm100` path and is disabled on `sm120`.\n",
+ "\n",
+ "When all optimized backends are disabled or ineligible, TE falls back to UnfusedDotProductAttention if it is enabled. If no backend is eligible, backend selection returns no backend and the caller raises an error. As we monitor the performance of different backends, the selection logic may change."
]
},
{
@@ -350,7 +349,7 @@
"**cuDNN attention sub-backends:**\n",
"This environment variable allows users to express their preference of cuDNN attention sub-backends. However, the elected sub-backend will only be used *if* it is eligible, i.e. if it has support for the provided inputs and runtime environment.\n",
"```\n",
- "NVTE_FUSED_ATTN_BACKEND = 0/1/2 # user preference of cuDNN sub-backend\n",
+ "NVTE_FUSED_ATTN_BACKEND = 1/2 # user preference of cuDNN sub-backend\n",
"```\n",
"\n",
"**Execution paths of cuDNN sub-backend 1:**\n",
@@ -369,7 +368,7 @@
"\n",
"Note\n",
" \n",
- "Environment variables NVTE_FLASH_ATTN, NVTE_FUSED_ATTN, NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT and NVTE_ALLOW_NONDETERMINISTIC_ALGO are only supported in PyTorch, and will be added to JAX in the future.\n",
+ "Environment variables NVTE_FLASH_ATTN, NVTE_UNFUSED_ATTN, NVTE_FUSED_ATTN_BACKEND, NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT, and NVTE_FUSED_ATTN_USE_FAv2_BWD are supported in PyTorch. NVTE_FUSED_ATTN and NVTE_ALLOW_NONDETERMINISTIC_ALGO are supported in both PyTorch and JAX.\n",
"
\n",
"\n",
"### 2.3 Example Tests\n",
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
index 03008bb2d7..30f74e1614 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
@@ -1059,8 +1059,13 @@ def forward(
Users can use environment variables :attr:`NVTE_FLASH_ATTN`, :attr:`NVTE_FUSED_ATTN`,
and :attr:`NVTE_FUSED_ATTN_BACKEND` to control which DotProductAttention backend,
- and FusedAttention backend if applicable, to use. Transformer Engine prioritizes
- FlashAttention over FusedAttention and over UnfusedDotProductAttention.
+ and FusedAttention backend if applicable, to use. Transformer Engine first filters
+ backends by support for the runtime environment and input configuration, then applies
+ a performance-based preference order. On supported pre-Hopper GPUs, FlashAttention is
+ preferred over FusedAttention and UnfusedDotProductAttention when both optimized
+ backends are eligible. On Hopper and newer GPUs, including Blackwell, FusedAttention is
+ preferred over FlashAttention and UnfusedDotProductAttention when both optimized
+ backends are eligible.
If FusedAttention is being used, users can also choose to switch to flash-attn's
implementation for backward by setting :attr:`NVTE_FUSED_ATTN_USE_FAv2_BWD=1`
(default: 0), because of the performance differences between various versions of