get_accum_type
get_accum_type[dtype: DType, *, preferred_accum_type: DType = DType.float32]() -> DType
Returns the recommended dtype for accumulation operations.
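Half-precision and float8 types can introduce significant numerical error when used in reduction or accumulation operations: every partial sum is rounded back to the low-precision format, so small contributions can be lost once the running total grows. This function returns a higher-precision dtype to accumulate in when a float8 or half-precision dtype is provided (for float16, subject to preferred_accum_type); otherwise it returns the original dtype.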
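To illustrate this failure mode, the following Python sketch emulates float16 and float32 rounding with the standard-library `struct` module (format codes `e` and `f`). It is not part of the Mojo API, and the helper names are hypothetical. Summing 4,096 ones with a float16 accumulator stalls at 2048, because float16 cannot represent 2049 and the tie rounds back to 2048. A float32 accumulator over the same float16 inputs gets the exact answer.

```python
import struct


def to_float16(x: float) -> float:
    """Round x to the nearest IEEE 754 binary16 value (round half to even)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_float32(x: float) -> float:
    """Round x to the nearest IEEE 754 binary32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def sum_with_float16_accumulator(values: list[float]) -> float:
    acc = 0.0
    for v in values:
        # Each partial sum is rounded back to float16, as a float16 accumulator would be.
        acc = to_float16(acc + to_float16(v))
    return acc


def sum_with_float32_accumulator(values: list[float]) -> float:
    acc = 0.0
    for v in values:
        # Inputs are float16, but the running sum is kept in float32.
        acc = to_float32(acc + to_float16(v))
    return acc


ones = [1.0] * 4096
print(sum_with_float16_accumulator(ones))  # 2048.0
print(sum_with_float32_accumulator(ones))  # 4096.0
```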
The rules are as follows:
- If the dtype is a float8 type, return float16.
- If the dtype is bfloat16, return float32.
- If the dtype is float16, return float32 if preferred_accum_type is float32; otherwise return float16.
- Otherwise, return the original dtype.
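The rules above can be modeled as a small decision function. This is a hypothetical Python sketch, not the Mojo implementation: dtypes are plain strings that mirror Mojo `DType` names, and `accum_type_model` is an illustrative name.

```python
def accum_type_model(dtype: str, preferred_accum_type: str = "float32") -> str:
    """Hypothetical model of the documented rules; dtypes are plain strings."""
    if dtype.startswith("float8"):
        # Any float8 variant (e.g. "float8_e4m3fn", "float8_e5m2") accumulates in float16.
        return "float16"
    if dtype == "bfloat16":
        # bfloat16 always maps to float32; preferred_accum_type is not consulted.
        return "float32"
    if dtype == "float16":
        # float16 is widened only when float32 is the preferred accumulation type.
        return "float32" if preferred_accum_type == "float32" else "float16"
    # All other dtypes (float32, float64, integers, ...) are returned unchanged.
    return dtype


print(accum_type_model("float8_e5m2"))  # float16
print(accum_type_model("bfloat16"))  # float32
print(accum_type_model("float16"))  # float32
print(accum_type_model("float16", preferred_accum_type="float16"))  # float16
print(accum_type_model("int32"))  # int32
```

In the Mojo function itself, both dtype and preferred_accum_type are compile-time parameters (they appear in square brackets in the signature), so the accumulation type is resolved at compile time.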
Parameters:
- dtype (DType): The dtype of the values being accumulated.
- preferred_accum_type (DType): The preferred dtype for accumulation.
Returns:
DType: The recommended dtype for accumulation operations, based on the input dtype and the preferred accumulation type.