cp_async_bulk_tensor_shared_cluster_global_im2col_multicast
cp_async_bulk_tensor_shared_cluster_global_im2col_multicast[dst_type: AnyType, mbr_type: AnyType, tensor_rank: Int, /, *, cta_group: Int = 1](dst_mem: UnsafePointer[dst_type, address_space=AddressSpace.SHARED], tma_descriptor: UnsafePointer[NoneType], mem_bar: UnsafePointer[mbr_type, address_space=AddressSpace.SHARED], coords: IndexList[tensor_rank], filter_offsets: IndexList[(tensor_rank - 2)], multicast_mask: UInt16)
Initiates an asynchronous multicast TMA load with im2col addressing.
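This combines im2col addressing with multicast. The tile is read from global memory once and written to the same CTA-relative shared memory offset in every CTA of the cluster selected by multicast_mask. The memory barrier signal is delivered to the same CTA-relative offset as mem_bar in each destination CTA.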
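To make the multicast semantics concrete, here is a small host-side Python simulation. It is not Mojo and not the real hardware path. It assumes each CTA's shared memory can be modeled as a bytearray and reduces each CTA's memory barrier to a counter of transaction bytes received. The function name simulate_multicast_load is hypothetical.

```python
def simulate_multicast_load(cluster_smem, cluster_tx_bytes, offset, tile, multicast_mask):
    """Hypothetical model of one multicast TMA load.

    cluster_smem: list of bytearrays, one per CTA (index = CTA rank in the cluster).
    cluster_tx_bytes: list of ints standing in for each CTA's barrier byte count.
    The tile lands at the same offset in every CTA whose mask bit is set.
    """
    if not 0 <= multicast_mask < (1 << 16):
        raise ValueError("multicast_mask must fit in 16 bits")
    for rank, smem in enumerate(cluster_smem):
        if (multicast_mask >> rank) & 1:
            # Same CTA-relative offset in every destination CTA.
            smem[offset:offset + len(tile)] = tile
            # Each destination CTA's barrier observes the bytes that arrived.
            cluster_tx_bytes[rank] += len(tile)


smem = [bytearray(8) for _ in range(4)]
tx = [0, 0, 0, 0]
simulate_multicast_load(smem, tx, offset=2, tile=b"\x01\x02", multicast_mask=0b0101)
print(tx)  # [2, 0, 2, 0]
```

Only CTAs 0 and 2 (bits 0 and 2 of the mask) receive the two bytes at offset 2. CTAs 1 and 3 are left untouched.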
For a 2D convolution with a 4D NHWC tensor:
- coords: (c, w, h, n), listed innermost dimension first: the channel, the output spatial position (w, h), and the batch index
- filter_offsets: (offset_w, offset_h), the position within the filter window
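The gather that im2col addressing performs can be sketched on the host as a conceptual model. The sketch does not reproduce the TMA unit's exact address calculation. In the sketch, the output pixel (n, oh, ow) plays roughly the role of the spatial and batch part of coords, and the loop variables offset_h and offset_w play the role of filter_offsets. The stride and pad parameters are assumptions added to keep the example self-contained. They are not arguments of the function documented here. The helper name im2col_row is hypothetical.

```python
def im2col_row(x, n, oh, ow, filter_h, filter_w, stride=1, pad=0):
    """Gather the input values that feed output pixel (n, oh, ow).

    x is an NHWC tensor stored as nested lists: x[n][h][w][c].
    Positions that fall in the zero padding read as 0.
    """
    height, width, channels = len(x[n]), len(x[n][0]), len(x[n][0][0])
    row = []
    for offset_h in range(filter_h):
        for offset_w in range(filter_w):
            # Input position = output position scaled by stride, minus padding,
            # plus the offset within the filter window.
            h = oh * stride - pad + offset_h
            w = ow * stride - pad + offset_w
            for c in range(channels):
                inside = 0 <= h < height and 0 <= w < width
                row.append(x[n][h][w][c] if inside else 0)
    return row


# 1x3x3x1 input whose value at (h, w) is h * 3 + w.
x = [[[[h * 3 + w] for w in range(3)] for h in range(3)]]
print(im2col_row(x, 0, 0, 0, 3, 3, pad=1))  # [0, 0, 0, 0, 0, 1, 0, 3, 4]
```

Each call produces one row of the im2col matrix. Stacking the rows for every output pixel turns the convolution into a matrix multiply, and that is why a TMA load that fetches these rows directly can feed a GEMM-style kernel.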
The PTX instruction format depends on cta_group (shown here for a 4D tensor):
- cta_group=1: uses the SM90-style multicast im2col PTX, with no cta_group modifier: cp.async.bulk.tensor.4d.shared::cluster.global.im2col...multicast::cluster
- cta_group=2: uses the SM100-style multicast im2col PTX with the cta_group::2 modifier (as in CUTLASS): cp.async.bulk.tensor.4d.im2col.cta_group::2.shared::cluster.global...multicast::cluster...
Parameters:
- dst_type (AnyType): The data type of the destination memory.
- mbr_type (AnyType): The data type of the memory barrier.
- tensor_rank (Int): The rank of the tensor (3, 4, or 5).
- cta_group (Int): The CTA group to use for the copy operation. Must be 1 or 2.
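The compile-time parameters constrain the shapes of the runtime arguments: coords has tensor_rank entries, and filter_offsets has tensor_rank - 2 entries. The following hypothetical host-side check restates those documented constraints. It is a sketch, and the real function enforces them through its parameter types rather than with code like this.

```python
def check_im2col_multicast_args(tensor_rank, coords, filter_offsets, cta_group=1):
    """Hypothetical check mirroring the documented parameter constraints."""
    if tensor_rank not in (3, 4, 5):
        raise ValueError("tensor_rank must be 3, 4, or 5")
    if cta_group not in (1, 2):
        raise ValueError("cta_group must be 1 or 2")
    if len(coords) != tensor_rank:
        raise ValueError("coords needs tensor_rank entries")
    if len(filter_offsets) != tensor_rank - 2:
        raise ValueError("filter_offsets needs tensor_rank - 2 entries")


# 4D NHWC case: four coordinates (c, w, h, n) and two offsets (offset_w, offset_h).
check_im2col_multicast_args(4, (0, 1, 2, 0), (0, 1), cta_group=2)
```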
Args:
- dst_mem (UnsafePointer[dst_type, address_space=AddressSpace.SHARED]): Pointer to the destination in shared memory.
- tma_descriptor (UnsafePointer[NoneType]): Pointer to the TMA im2col descriptor.
- mem_bar (UnsafePointer[mbr_type, address_space=AddressSpace.SHARED]): Pointer to the shared memory barrier.
- coords (IndexList[tensor_rank]): Tensor coordinates ((c, w, h, n) for 4D).
- filter_offsets (IndexList[(tensor_rank - 2)]): Filter window offsets ((offset_w, offset_h) for 4D).
- multicast_mask (UInt16): Bitmask of the destination CTAs. Bit i selects the CTA whose rank within the cluster is i.
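Building a multicast_mask comes down to setting one bit per destination CTA rank. The Python helpers below are hypothetical and only illustrate the bit arithmetic. row_mask assumes that CTA ranks in the cluster are linearized x-fastest (rank = x + y * cluster_x), so check that assumption against your cluster layout. The resulting integer is the value that would be passed as the UInt16 multicast_mask.

```python
def multicast_mask(cta_ranks):
    """Build a 16-bit mask: bit i set means the CTA with cluster rank i receives the tile."""
    mask = 0
    for rank in cta_ranks:
        if not 0 <= rank < 16:
            raise ValueError(f"CTA rank {rank} does not fit in a 16-bit mask")
        mask |= 1 << rank
    return mask


def row_mask(cluster_x, cluster_y, y):
    """Mask selecting every CTA in row y of a cluster_x by cluster_y cluster.

    Assumes ranks are linearized x-fastest: rank = x + y * cluster_x.
    """
    if not 0 <= y < cluster_y:
        raise ValueError("row index out of range")
    return multicast_mask(x + y * cluster_x for x in range(cluster_x))


print(bin(row_mask(2, 2, 1)))  # 0b1100: CTAs with ranks 2 and 3
```

A typical use is to let every CTA that shares the same operand tile receive it from a single load, instead of having each CTA issue its own copy.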