Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions tests/pytorch/test_multi_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,49 @@ def test_multi_tensor_l2norm(input_size_pair, applier, repeat, in_type, per_tens
assert overflow_buf.item() == 0


# (sizea, sizeb) element counts for the two tensors that
# test_multi_tensor_raw_moments builds for each parametrized case.
# 2048 * 32 + 1 == 65537; presumably chosen to land one element past a
# 2048 * 32 chunk so a second, single-element chunk is exercised -- TODO confirm
# against the chunk size used by `appliers`.
raw_moment_size_pairs = [
(777, 555),
(2048 * 32 + 1, 555),
]


def _raw_moment_reference(tensor):
values = tensor.float()
values_2 = values * values
return torch.stack(
[
torch.tensor(float(values.numel()), dtype=torch.float32, device=tensor.device),
values.sum(),
values_2.sum(),
(values_2 * values).sum(),
(values_2 * values_2).sum(),
]
)


@pytest.mark.parametrize("input_size_pair", raw_moment_size_pairs)
@pytest.mark.parametrize("applier", appliers)
@pytest.mark.parametrize("repeat", [1, 55])
@pytest.mark.parametrize("in_type", [torch.float32, torch.float16, torch.bfloat16])
def test_multi_tensor_raw_moments(input_size_pair, applier, repeat, in_type):
    """Compare tex.multi_tensor_raw_moments against the PyTorch reference.

    The inputs are small integer ramps (values in [-8, 8] and [-5, 5]), so
    converting them to float16 / bfloat16 does not change their values.
    """
    cuda = torch.device("cuda")
    size_a, size_b = input_size_pair

    # Flag tensor handed to the kernel; it must still read 0 afterwards.
    overflow_buf = torch.zeros(1, dtype=torch.int32, device=cuda)

    ramp_a = (torch.arange(size_a, dtype=torch.float32, device=cuda) % 17) - 8
    ramp_b = (torch.arange(size_b, dtype=torch.float32, device=cuda) % 11) - 5

    # Interleave fresh copies of both ramps `repeat` times: a, b, a, b, ...
    in_list = []
    for _ in range(repeat):
        in_list.append(ramp_a.clone().to(in_type))
        in_list.append(ramp_b.clone().to(in_type))

    moments = applier(tex.multi_tensor_raw_moments, overflow_buf, [in_list])

    expected_rows = [_raw_moment_reference(tensor) for tensor in in_list]
    references = torch.stack(expected_rows)

    torch.testing.assert_close(moments, references, rtol=1e-5, atol=1e-2)
    assert overflow_buf.item() == 0


@pytest.mark.parametrize("input_size_pair", input_size_pairs)
@pytest.mark.parametrize("applier", appliers)
@pytest.mark.parametrize("repeat", [1, 55])
Expand Down
2 changes: 2 additions & 0 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ list(APPEND transformer_engine_cuda_sources
common.cu
multi_tensor/adam.cu
multi_tensor/l2norm.cu
multi_tensor/raw_moments.cu
multi_tensor/scale.cu
multi_tensor/sgd.cu
transpose/cast_transpose.cu
Expand Down Expand Up @@ -569,6 +570,7 @@ list(APPEND nvte_sources_with_fast_math fused_softmax/scaled_masked_softmax.cu
multi_tensor/adam.cu
multi_tensor/compute_scale.cu
multi_tensor/l2norm.cu
multi_tensor/raw_moments.cu
multi_tensor/scale.cu
multi_tensor/sgd.cu
fused_attn/flash_attn.cu
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,28 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
int per_tensor, int max_chunks_per_tensor,
cudaStream_t stream);

/*! \brief Computes raw moments for a list of tensors.
 *
 *  Each row of the result holds five float values for one input tensor, in order:
 *  the element count, then the raw sums of powers 1, 2, 3 and 4 of its elements.
 *
 *  \warning   This API is **experimental** and subject to change.
 *
 *  \param[in]     chunk_size              Number of tensor elements processed by a CUDA block.
 *  \param[in,out] noop_flag               Single element flag tensor.
 *                                         NOTE(review): the raw-moments functor in
 *                                         multi_tensor/raw_moments.cu never reads this flag; it
 *                                         only writes 1 to it when a block-level sum is
 *                                         non-finite. Whether the launcher skips work on a
 *                                         non-zero flag is not visible there -- confirm.
 *  \param[in]     tensor_lists            2D array of input tensors. The implementation reads
 *                                         the tensors of the first list (tensor_lists[0]).
 *  \param[in]     num_tensor_lists        Size (dim0) of tensor_lists.
 *  \param[in]     num_tensors_per_list    Size (dim1) of tensor_lists.
 *  \param[in,out] output_per_tensor       Auxiliary float scratch space with one 5-value row per
 *                                         (tensor, chunk) pair, i.e.
 *                                         num_tensors_per_list * max_chunks_per_tensor * 5
 *                                         elements. The final reduction sums all
 *                                         max_chunks_per_tensor rows of every tensor, so rows
 *                                         that no block writes must already be zero.
 *  \param[out]    ret                     Raw-moment rows for each tensor
 *                                         (num_tensors_per_list * 5 float values).
 *  \param[in]     max_chunks_per_tensor   Maximum number of chunks in any input tensor.
 *  \param[in]     stream                  CUDA stream used for this operation.
 */
void nvte_multi_tensor_raw_moments_cuda(int chunk_size, NVTETensor noop_flag,
NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list,
NVTETensor output_per_tensor, NVTETensor ret,
int max_chunks_per_tensor, cudaStream_t stream);

/*! \brief Compute and apply gradient update to parameters for Adam optimizer.
*
* \warning This API is **experimental** and subject to change.
Expand Down
194 changes: 194 additions & 0 deletions transformer_engine/common/multi_tensor/raw_moments.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#include <assert.h>
#include <cuda_fp8.h>
#include <transformer_engine/multi_tensor.h>
#include <transformer_engine/transformer_engine.h>

#include "../utils.cuh"
#include "multi_tensor_apply.cuh"

namespace transformer_engine {
namespace multi_tensor_raw_moments {

// Thread count used for the cleanup kernel launch and passed to multi_tensor_apply;
// also the length of the shared-memory buffers handed to reduce_block_sum.
#define BLOCK_SIZE 512
// Number of elements moved by one load_store call and covered by one pass of the
// unrolled inner loops in RawMomentsFunctor.
#define ILP 4
// Values per output row: element count followed by the sums of powers 1-4.
#define RAW_MOMENT_FIELDS 5

// Returns true when `p` lies on an (ILP * sizeof(T))-byte boundary, which is the
// width of the vector type that load_store copies in one assignment.
template <typename T>
__device__ __forceinline__ bool is_aligned(T *p) {
  const uint64_t address = (uint64_t)p;  // NOLINT(*)
  const uint64_t misalignment = address % (ILP * sizeof(T));
  return misalignment == 0;
}

// Copies ILP consecutive elements of type T in a single assignment.
// Both pointers are reinterpreted as arrays of an ILP-element-wide storage type, so
// `dst_offset` and `src_offset` count groups of ILP elements, not single elements.
template <typename T>
__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset, int src_offset) {
  using Vec = typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type;
  Vec *dst_vec = (Vec *)dst;  // NOLINT(*)
  Vec *src_vec = (Vec *)src;  // NOLINT(*)
  dst_vec[dst_offset] = src_vec[src_offset];
}

// Sums `val` across all threads of the block.
//
// `x` is scratch storage indexed by the flattened thread index, so it needs at least
// blockDim.x * blockDim.y floats. The halving loop pairs element i with element
// i + stride, so the block size is expected to be a power of two (callers in this
// file launch BLOCK_SIZE threads).
//
// The complete sum ends up in the value returned to thread 0; threads 1-31 return
// partial sums and threads >= 32 return 0. Every thread must call this function
// because it contains block-wide barriers.
__device__ __forceinline__ float reduce_block_sum(float *x, float val) {
  const int flat_tid = threadIdx.x + threadIdx.y * blockDim.x;
  const int num_threads = blockDim.x * blockDim.y;
  const bool use_scratch = num_threads >= 64;

  if (use_scratch) {
    x[flat_tid] = val;
    __syncthreads();
  }

  // Tree reduction through the scratch buffer until 64 partial sums remain.
#pragma unroll
  for (int stride = (num_threads >> 1); stride >= 64; stride >>= 1) {
    if (flat_tid < stride) x[flat_tid] += x[flat_tid + stride];
    __syncthreads();
  }

  float total = 0.f;
  if (flat_tid < 32) {
    // Fold the remaining 64 values into 32, or start from `val` for small blocks.
    total = use_scratch ? x[flat_tid] + x[flat_tid + 32] : val;

    // Finish inside the first warp with shuffles.
#pragma unroll
    for (int offset = 16; offset >= 1; offset >>= 1) {
      total += __shfl_down_sync(0xffffffff, total, offset);
    }
  }

  // Barrier before returning so that `x` is no longer being read when callers reuse it.
  __syncthreads();
  return total;
}

// Per-chunk functor: each block reduces one chunk of one tensor to a row of
// RAW_MOMENT_FIELDS floats (element count, then sums of powers 1-4) and stores that
// row in `output_per_tensor` at index (tensor_idx * max_chunks_per_tensor + chunk_idx).
// Inputs are converted to float before accumulation. If any of the four block sums is
// non-finite, thread 0 writes 1 to `noop_gmem`.
template <typename x_t>
struct RawMomentsFunctor {
  __device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem,
                                             TensorListMetadata<1> &tl,  // NOLINT(*)
                                             float *output_per_tensor, int max_chunks_per_tensor) {
    __shared__ float s_vals[RAW_MOMENT_FIELDS - 1][BLOCK_SIZE];

    const int tensor_loc = tl.block_to_tensor[blockIdx.x];
    const int chunk_idx = tl.block_to_chunk[blockIdx.x];
    const int tensor_idx = tl.start_tensor_this_launch + tensor_loc;

    // Elements from the start of this chunk to the end of the tensor.
    const int remaining = tl.sizes[tensor_loc] - chunk_idx * chunk_size;
    const int elements_this_chunk = remaining < chunk_size ? remaining : chunk_size;

    x_t *chunk = reinterpret_cast<x_t *>(tl.addresses[0][tensor_loc]);
    chunk += chunk_idx * chunk_size;

    float acc_1 = 0.f;
    float acc_2 = 0.f;
    float acc_3 = 0.f;
    float acc_4 = 0.f;

    // Adds one element's contribution to the four running power sums.
    auto accumulate = [&](float v) {
      const float v_sq = v * v;
      acc_1 += v;
      acc_2 += v_sq;
      acc_3 += v_sq * v;
      acc_4 += v_sq * v_sq;
    };

    x_t staged[ILP];
    for (int k = 0; k < ILP; k++) staged[k] = 0;

    const bool vectorizable = (remaining % ILP == 0) && (chunk_size % ILP == 0) && is_aligned(chunk);
    if (vectorizable) {
      // Aligned path: each thread pulls ILP contiguous elements per iteration.
      for (int vec = threadIdx.x; vec * ILP < elements_this_chunk; vec += blockDim.x) {
        load_store(staged, chunk, 0, vec);
#pragma unroll
        for (int k = 0; k < ILP; k++) {
          accumulate(static_cast<float>(staged[k]));
        }
      }
    } else {
      // General path: scalar loads with a per-element bounds check.
      for (int base = 0; base < elements_this_chunk; base += blockDim.x * ILP) {
#pragma unroll
        for (int k = 0; k < ILP; k++) {
          int idx = base + threadIdx.x + k * blockDim.x;
          if (idx < elements_this_chunk) {
            accumulate(static_cast<float>(chunk[idx]));
          }
        }
      }
    }

    const float block_sum_1 = reduce_block_sum(s_vals[0], acc_1);
    const float block_sum_2 = reduce_block_sum(s_vals[1], acc_2);
    const float block_sum_3 = reduce_block_sum(s_vals[2], acc_3);
    const float block_sum_4 = reduce_block_sum(s_vals[3], acc_4);

    if (threadIdx.x == 0) {
      const bool all_finite = isfinite(block_sum_1) && isfinite(block_sum_2) &&
                              isfinite(block_sum_3) && isfinite(block_sum_4);
      if (!all_finite) {
        *noop_gmem = 1;
      }

      const int row_idx = tensor_idx * max_chunks_per_tensor + chunk_idx;
      float *row = output_per_tensor + row_idx * RAW_MOMENT_FIELDS;
      row[0] = static_cast<float>(elements_this_chunk);
      row[1] = block_sum_1;
      row[2] = block_sum_2;
      row[3] = block_sum_3;
      row[4] = block_sum_4;
    }
  }
};

// Second-stage reduction: folds the per-chunk rows into one row per tensor.
//
// Launch layout: blockIdx.x selects the tensor, blockIdx.y selects the field
// (0 .. RAW_MOMENT_FIELDS - 1). The shared buffer holds BLOCK_SIZE floats, so the
// block must not have more than BLOCK_SIZE threads.
// Each block sums all `max_chunks_per_tensor` entries of its (tensor, field) column
// in `output_per_tensor` and writes the result to
// ret[tensor_idx * RAW_MOMENT_FIELDS + field_idx].
__global__ void cleanup(float *output_per_tensor, float *ret, int max_chunks_per_tensor) {
  __shared__ float vals[BLOCK_SIZE];

  const int tensor_idx = blockIdx.x;
  const int field_idx = blockIdx.y;

  // First entry of this tensor's column for the selected field; consecutive chunks
  // are RAW_MOMENT_FIELDS floats apart.
  const float *column =
      output_per_tensor + tensor_idx * max_chunks_per_tensor * RAW_MOMENT_FIELDS + field_idx;

  float partial = 0.f;
  int chunk = threadIdx.x;
  while (chunk < max_chunks_per_tensor) {
    partial += column[chunk * RAW_MOMENT_FIELDS];
    chunk += blockDim.x;
  }

  const float total = reduce_block_sum(vals, partial);
  if (threadIdx.x == 0) {
    ret[tensor_idx * RAW_MOMENT_FIELDS + field_idx] = total;
  }
}

// Host entry point: computes one row of RAW_MOMENT_FIELDS floats per tensor in
// tensor_lists[0] and stores the rows in `ret`.
//
// Stage 1 runs RawMomentsFunctor through multi_tensor_apply and fills
// `output_per_tensor` with one row per (tensor, chunk). Stage 2 launches `cleanup`
// with a (num_tensors, RAW_MOMENT_FIELDS) grid to sum those rows per tensor.
//
// Preconditions checked here:
//   - tensor_lists has at least one list; an empty first list is a no-op,
//   - chunk_size > 0 and max_chunks_per_tensor >= 0,
//   - output_per_tensor and ret are non-null float32 buffers (they are reinterpreted
//     as float* below),
//   - all input tensors share the dtype of the first one, because the kernel is
//     instantiated for that single dtype.
// Not checked: buffer sizes. `output_per_tensor` must hold
// num_tensors * max_chunks_per_tensor * RAW_MOMENT_FIELDS floats, zero-filled for
// rows no block writes, and `ret` must hold num_tensors * RAW_MOMENT_FIELDS floats.
void multi_tensor_raw_moments_cuda(int chunk_size, Tensor noop_flag,
                                   std::vector<std::vector<Tensor *>> tensor_lists,
                                   Tensor output_per_tensor, Tensor ret, int max_chunks_per_tensor,
                                   cudaStream_t stream) {
  NVTE_CHECK(!tensor_lists.empty(), "multi_tensor_raw_moments expects one tensor list.");

  // With no tensors there are no rows to produce, and a grid with x == 0 is an
  // invalid launch configuration.
  if (tensor_lists[0].empty()) {
    return;
  }

  NVTE_CHECK(chunk_size > 0, "chunk_size must be positive.");
  NVTE_CHECK(max_chunks_per_tensor >= 0, "max_chunks_per_tensor must be non-negative.");
  NVTE_CHECK(output_per_tensor.dtype() == DType::kFloat32, "output_per_tensor must be float32.");
  NVTE_CHECK(ret.dtype() == DType::kFloat32, "ret must be float32.");
  NVTE_CHECK(output_per_tensor.data.dptr != nullptr, "output_per_tensor has no data.");
  NVTE_CHECK(ret.data.dptr != nullptr, "ret has no data.");

  const auto in_dtype = tensor_lists[0][0]->dtype();
  for (Tensor *input : tensor_lists[0]) {
    NVTE_CHECK(input->dtype() == in_dtype, "All input tensors must have the same dtype.");
  }

  // Captured before the launch so the grid size does not depend on how
  // multi_tensor_apply consumes its arguments.
  const size_t num_tensors = tensor_lists[0].size();

  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
      in_dtype, dtype,
      multi_tensor_apply<1>(
          BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, RawMomentsFunctor<dtype>(), stream,
          reinterpret_cast<float *>(output_per_tensor.data.dptr), max_chunks_per_tensor);)

  NVTE_CHECK_CUDA(cudaGetLastError());

  dim3 grid(num_tensors, RAW_MOMENT_FIELDS);
  cleanup<<<grid, BLOCK_SIZE, 0, stream>>>(reinterpret_cast<float *>(output_per_tensor.data.dptr),
                                           reinterpret_cast<float *>(ret.data.dptr),
                                           max_chunks_per_tensor);
  NVTE_CHECK_CUDA(cudaGetLastError());
}

} // namespace multi_tensor_raw_moments
} // namespace transformer_engine

// C API entry point: converts the NVTETensor handles to internal Tensor objects and
// forwards to multi_tensor_raw_moments::multi_tensor_raw_moments_cuda.
void nvte_multi_tensor_raw_moments_cuda(int chunk_size, NVTETensor noop_flag,
                                        NVTETensor **tensor_lists, const size_t num_tensor_lists,
                                        const size_t num_tensors_per_list,
                                        NVTETensor output_per_tensor, NVTETensor ret,
                                        int max_chunks_per_tensor, cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_tensor_raw_moments_cuda);
  using namespace transformer_engine;

  auto *noop = convertNVTETensorCheck(noop_flag);
  auto *scratch = convertNVTETensorCheck(output_per_tensor);
  auto *moments = convertNVTETensorCheck(ret);

  multi_tensor_raw_moments::multi_tensor_raw_moments_cuda(
      chunk_size, *noop, convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
      *scratch, *moments, max_chunks_per_tensor, stream);
}
3 changes: 3 additions & 0 deletions transformer_engine/pytorch/csrc/extensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,9 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor inv_scale, at::optional<bool> per_tensor_python);

at::Tensor multi_tensor_raw_moments_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists);

void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
const float beta1, const float beta2, const float epsilon,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#include "../../extensions.h"

namespace transformer_engine::pytorch {

// Computes raw moments for every tensor in tensor_lists[0].
//
// Returns a float32 tensor of shape (ntensors, 5); each row is
// [element count, sum(x), sum(x^2), sum(x^3), sum(x^4)] for the matching input.
//
//   chunk_size   - elements per chunk; used as a divisor below, so it must be non-zero.
//   noop_flag    - flag tensor forwarded to the kernel.
//   tensor_lists - list of tensor lists; only the first list is read here.
//
// An empty input (no lists, or an empty first list) yields an empty (0, 5) tensor
// on noop_flag's device instead of indexing into an empty vector.
at::Tensor multi_tensor_raw_moments_cuda(int chunk_size, at::Tensor noop_flag,
                                         std::vector<std::vector<at::Tensor>> tensor_lists) {
  if (tensor_lists.empty() || tensor_lists[0].empty()) {
    return at::zeros({0, 5}, noop_flag.options().dtype(at::kFloat));
  }

  auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);

  // The largest chunk count over all tensors fixes the per-tensor stride of the scratch buffer.
  int ntensors = tensor_lists[0].size();
  int max_chunks_per_tensor = 0;
  for (int t = 0; t < ntensors; t++) {
    int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
    if (max_chunks_this_tensor > max_chunks_per_tensor) {
      max_chunks_per_tensor = max_chunks_this_tensor;
    }
  }

  auto ret = at::empty({ntensors, 5}, float_options);
  if (max_chunks_per_tensor == 0) {
    // Every tensor is empty: all counts and sums are zero and no kernel is needed.
    ret.zero_();
    return ret;
  }

  // Scratch laid out as [tensor][chunk][field]. It is zero-filled because the final
  // reduction sums max_chunks_per_tensor rows for every tensor, including rows of
  // chunks a shorter tensor does not have. Passing the extents separately (rather
  // than their product computed in int) avoids 32-bit overflow of the element count;
  // the memory layout is the same contiguous buffer.
  auto output_per_tensor = at::zeros({ntensors, max_chunks_per_tensor, 5}, float_options);

  auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
  auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
      makeTransformerEngineTensorList(tensor_lists);
  auto output_per_tensor_cu = makeTransformerEngineTensor(output_per_tensor);
  auto ret_cu = makeTransformerEngineTensor(ret);

  nvte_multi_tensor_raw_moments_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(),
                                     num_lists, num_tensors, output_per_tensor_cu.data(),
                                     ret_cu.data(), max_chunks_per_tensor,
                                     at::cuda::getCurrentCUDAStream());

  return ret;
}

} // namespace transformer_engine::pytorch
3 changes: 3 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Computes L2 norm for a list of contiguous tensors after unscaling (unscaling is only "
"performed for L2 norm computation, and tensors are not updated)",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_raw_moments", &transformer_engine::pytorch::multi_tensor_raw_moments_cuda,
"Computes count and raw sums of powers 1-4 for a list of contiguous tensors",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_adam", &transformer_engine::pytorch::multi_tensor_adam_cuda,
"Compute and apply gradient update to parameters for Adam optimizer",
py::call_guard<py::gil_scoped_release>());
Expand Down
1 change: 1 addition & 0 deletions transformer_engine/pytorch/optimizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
multi_tensor_scale_tensor,
multi_tensor_l2norm,
multi_tensor_unscale_l2norm,
multi_tensor_raw_moments,
multi_tensor_adam,
multi_tensor_adam_fp8,
multi_tensor_adam_capturable,
Expand Down
Loading