RaggedTMA3DTile
struct RaggedTMA3DTile[dtype: DType, swizzle_mode: TensorMapSwizzle, *, BM: Int, BN: Int, group: Int = 1]
Creates a TMA descriptor for loading from and storing to ragged 3D arrays whose leading dimension is ragged. It loads 2D tiles, indexing into the middle dimension. When using these loads, at least BM_seq * stride space must be allocated in front of the gmem pointer, where BM_seq = BM // group; otherwise the access may fail with CUDA_ERROR_ILLEGAL_ADDRESS.
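The front-padding requirement can be checked on the host before the descriptor is built. The sketch below is a hypothetical Python model, not part of the Mojo API: `front_padding_elems` and `has_enough_front_padding` are illustrative names. It assumes that "stride" means the element distance between consecutive positions of the ragged leading dimension and that the required space is counted in elements.

```python
def front_padding_elems(bm: int, group: int, row_stride: int) -> int:
    """Minimum elements that must be allocated before the gmem pointer.

    Assumption: the requirement is BM_seq * stride elements, where
    BM_seq = BM // group and row_stride is the element distance between
    consecutive sequence positions (rows of the ragged leading dimension).
    """
    bm_seq = bm // group
    return bm_seq * row_stride


def has_enough_front_padding(ptr_offset_elems: int, bm: int, group: int, row_stride: int) -> bool:
    """ptr_offset_elems: number of allocated elements that precede the gmem pointer."""
    return ptr_offset_elems >= front_padding_elems(bm, group, row_stride)


# Example: a (rows, middle_dim=8, depth=128) array, so consecutive rows
# are 8 * 128 = 1024 elements apart in a row-major layout.
middle_dim, depth = 8, 128
row_stride = middle_dim * depth
print(front_padding_elems(bm=64, group=1, row_stride=row_stride))  # 65536
```

Under the same row-major assumption, the 4D view used when group > 1 would give a row stride of middle_dim * group * depth.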
When group > 1, the gmem is treated as 4D (rows, middle_dim, group, depth)
and a 5D TMA descriptor is created. The smem tile has BM_seq * group = BM
rows, where BM_seq = BM // group is the number of distinct sequence positions.
The dynamic_dim parameter of the copy methods is the number of valid sequence positions, not the number of smem rows.
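The row layout of the smem tile can be modeled with plain index arithmetic. This is a hypothetical Python sketch, not the library's implementation: the function names are illustrative, and it assumes the TMA box preserves the gmem dimension order (rows, middle_dim, group, depth), so the group index varies fastest across smem rows (smem_row = seq * group + g). With group = 1 the mapping reduces to the plain 3D (rows, middle_dim, depth) case, where each smem row is one sequence position at a fixed middle-dim index.

```python
from typing import Tuple


def smem_row_to_gmem(smem_row: int, group: int, row_start: int) -> Tuple[int, int]:
    """Map a row of the BM x BN smem tile to (gmem sequence row, group index).

    Assumption: smem_row = seq * group + g, i.e. group varies fastest.
    """
    seq, g = divmod(smem_row, group)
    return row_start + seq, g


def gmem_offset(row: int, mid: int, g: int, col: int,
                middle_dim: int, group: int, depth: int) -> int:
    """Row-major element offset into a (rows, middle_dim, group, depth) view."""
    return ((row * middle_dim + mid) * group + g) * depth + col


def valid_smem_rows(dynamic_dim: int, bm: int, group: int) -> int:
    """Smem rows backed by valid data, assuming dynamic_dim counts sequence
    positions (not smem rows) and is capped at BM_seq = BM // group."""
    bm_seq = bm // group
    return min(dynamic_dim, bm_seq) * group


# Example: BM = 64, group = 4 -> BM_seq = 16 sequence positions per tile.
print(smem_row_to_gmem(9, group=4, row_start=100))    # (102, 1)
print(valid_smem_rows(dynamic_dim=5, bm=64, group=4))  # 20
```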
Parameters
- dtype (DType): The data type of the tensor.
- swizzle_mode (TensorMapSwizzle): The swizzling mode to use for memory access.
- BM (Int): The number of rows of the corresponding 2D shared memory tile.
- BN (Int): The number of columns of the corresponding 2D shared memory tile.
- group (Int): The number of heads fused into each sequence position (default 1).
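Because the smem tile holds BM_seq * group = BM rows, BM must be a multiple of group. The hypothetical helper below derives the quantities implied by the parameters; `derived_tile_params` and `dtype_bytes` are illustrative names (for example, `dtype_bytes=2` for a 16-bit dtype), and the divisibility check is inferred from the BM_seq definition rather than taken from the library.

```python
def derived_tile_params(bm: int, bn: int, group: int = 1, dtype_bytes: int = 2) -> dict:
    """Quantities implied by BM, BN and group (hypothetical helper).

    BM must be a multiple of group so that BM_seq * group == BM.
    """
    if group < 1:
        raise ValueError("group must be >= 1")
    if bm % group != 0:
        raise ValueError(f"BM={bm} is not divisible by group={group}")
    bm_seq = bm // group
    return {
        "BM_seq": bm_seq,                     # distinct sequence positions per tile
        "smem_rows": bm_seq * group,          # always equals BM
        "smem_bytes": bm * bn * dtype_bytes,  # one BM x BN tile, no alignment padding
    }


print(derived_tile_params(64, 128, group=4))
```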