reduce_launch
```mojo
reduce_launch[
    num_reductions: Int,
    input_fn: def[dtype: DType, width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width],
    output_fn: def[dtype: DType, width: Int, rank: Int](IndexList[rank], StaticTuple[SIMD[dtype, width], num_reductions]) capturing -> None,
    reduce_fn: def[ty: DType, width: Int, reduction_idx: Int](SIMD[ty, width], SIMD[ty, width]) capturing -> SIMD[ty, width],
    rank: Int,
    dtype: DType,
](shape: IndexList[rank], axis: Int, init: StaticTuple[Scalar[dtype], num_reductions], ctx: DeviceContext)
```
Selects and launches the appropriate GPU reduction kernel based on the tensor shape, axis, and device saturation level.
Three-tier dispatch:
- Thread-saturated (many rows, non-contiguous axis): one row per thread via `saturated_reduce_kernel`.
- Block-saturated (enough rows to fill SMs at one block per row): `reduce_kernel` or `small_reduce_kernel`.
- Under-saturated (too few rows to fill the device): multiple blocks per row via `twophase_reduce_kernel` with a two-phase atomic finish.
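The tier selection above can be modeled with a short sketch. This is an illustrative Python rendition of the heuristic, not the actual implementation: the SM and thread counts, the thresholds, and the `pick_kernel` helper are all hypothetical; only the tier names come from this page.

```python
def pick_kernel(shape, axis, num_sms=80, threads_per_sm=2048):
    """Illustrative three-tier dispatch: thresholds are assumptions."""
    # Rows = number of independent reductions: product of all non-axis dims.
    rows = 1
    for i, d in enumerate(shape):
        if i != axis:
            rows *= d
    inner_contiguous = axis == len(shape) - 1
    device_threads = num_sms * threads_per_sm
    if rows >= device_threads and not inner_contiguous:
        return "saturated_reduce_kernel"   # one row per thread
    if rows >= num_sms:
        return "reduce_kernel"             # one block per row fills the SMs
    return "twophase_reduce_kernel"        # multiple blocks per row

print(pick_kernel((1024, 8, 1024), axis=1))  # → saturated_reduce_kernel
print(pick_kernel((1 << 20, 8), axis=1))     # → reduce_kernel
print(pick_kernel((8, 1 << 20), axis=1))     # → twophase_reduce_kernel
```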
Parameters:
- `num_reductions` (`Int`): The number of fused reductions to perform.
- `input_fn` (`def[dtype: DType, width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width]`): The lambda to load input elements.
- `output_fn` (`def[dtype: DType, width: Int, rank: Int](IndexList[rank], StaticTuple[SIMD[dtype, width], num_reductions]) capturing -> None`): The lambda to store output elements.
- `reduce_fn` (`def[ty: DType, width: Int, reduction_idx: Int](SIMD[ty, width], SIMD[ty, width]) capturing -> SIMD[ty, width]`): The binary reduction function.
- `rank` (`Int`): The tensor rank.
- `dtype` (`DType`): The data type of the elements.
Args:
- `shape` (`IndexList`): The shape of the input tensor.
- `axis` (`Int`): The axis along which to reduce.
- `init` (`StaticTuple`): The identity values for each reduction.
- `ctx` (`DeviceContext`): The device context for GPU execution.
Raises:
If the GPU kernel launch fails.
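To illustrate what "fused reductions" and the `init` identity values mean semantically, here is a plain-Python model of the computation (not the Mojo API): two reductions (min and max) performed in a single pass over each row, each accumulator seeded from `init`, mirroring `num_reductions = 2`. The `fused_reduce` helper is hypothetical.

```python
def fused_reduce(data, init, reduce_fns):
    """One pass over each row produces all num_reductions results at once."""
    out = []
    for row in data:
        accs = list(init)  # each accumulator starts at its identity value
        for x in row:
            for i, fn in enumerate(reduce_fns):
                accs[i] = fn(accs[i], x)
        out.append(tuple(accs))
    return out

rows = [[3, 1, 4], [1, 5, 9]]
init = (float("inf"), float("-inf"))  # identities for min and max
print(fused_reduce(rows, init, [min, max]))
# → [(1, 4), (1, 9)]
```

Fusing avoids re-reading the input once per reduction, which is the point of the `num_reductions`/`StaticTuple` design: one memory pass, several results.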