Version: Nightly

reduce_launch

reduce_launch[num_reductions: Int, input_fn: def[dtype: DType, width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], output_fn: def[dtype: DType, width: Int, rank: Int](IndexList[rank], StaticTuple[SIMD[dtype, width], num_reductions]) capturing -> None, reduce_fn: def[ty: DType, width: Int, reduction_idx: Int](SIMD[ty, width], SIMD[ty, width]) capturing -> SIMD[ty, width], rank: Int, dtype: DType](shape: IndexList[rank], axis: Int, init: StaticTuple[Scalar[dtype], num_reductions], ctx: DeviceContext)

Selects and launches the appropriate GPU reduction kernel based on the tensor shape, axis, and device saturation level.

Three-tier dispatch:

  1. Thread-saturated (many rows, non-contiguous axis): one row per thread via saturated_reduce_kernel.
  2. Block-saturated (enough rows to fill SMs at one block per row): reduce_kernel or small_reduce_kernel.
  3. Under-saturated (too few rows to fill the device): multiple blocks per row via twophase_reduce_kernel with a two-phase atomic finish.
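The tier selection above can be modeled as a simple heuristic. The sketch below is an illustrative Python model only: the threshold values, the `threads_per_block` default, and the `axis_is_innermost` contiguity check are assumptions for exposition, not the actual cutoffs used by the Mojo implementation.

```python
def choose_reduction_kernel(num_rows: int, axis_is_innermost: bool,
                            num_sms: int, threads_per_block: int = 256) -> str:
    """Pick a reduction strategy by how well the rows saturate the device.

    Hypothetical model of reduce_launch's three-tier dispatch; the real
    thresholds are internal to the kernel library.
    """
    # Tier 1: enough rows to give every thread its own row, and the reduced
    # axis is not the contiguous (innermost) one, so per-thread strided
    # loads are acceptable -> saturated_reduce_kernel.
    if not axis_is_innermost and num_rows >= num_sms * threads_per_block:
        return "saturated_reduce_kernel"
    # Tier 2: enough rows to fill the SMs at one block per row
    # -> reduce_kernel (or small_reduce_kernel for short rows).
    if num_rows >= num_sms:
        return "reduce_kernel"
    # Tier 3: too few rows; split each row across multiple blocks and
    # combine partials with a two-phase atomic finish.
    return "twophase_reduce_kernel"
```

For example, a tensor with millions of rows reduced along a non-contiguous axis lands in tier 1, while reducing a handful of long rows falls through to the two-phase kernel.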

Parameters:

  • num_reductions (Int): The number of fused reductions to perform.
  • input_fn (def[dtype: DType, width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width]): The lambda to load input elements.
  • output_fn (def[dtype: DType, width: Int, rank: Int](IndexList[rank], StaticTuple[SIMD[dtype, width], num_reductions]) capturing -> None): The lambda to store output elements.
  • reduce_fn (def[ty: DType, width: Int, reduction_idx: Int](SIMD[ty, width], SIMD[ty, width]) capturing -> SIMD[ty, width]): The binary reduction function.
  • rank (Int): The tensor rank.
  • dtype (DType): The data type of the elements.

Args:

  • shape (IndexList): The shape of the input tensor.
  • axis (Int): The axis along which to reduce.
  • init (StaticTuple): The identity values for each reduction.
  • ctx (DeviceContext): The device context for GPU execution.

Raises:

An error if the GPU kernel launch fails.