What's 'order' in Triton make_block_ptr
I’ve been trying to figure out what `order` in `tl.make_block_ptr` means. The docs don’t help:

order – The order of the original data format

Isn’t `strides` enough?
Update:

I asked around and found out that `order` is just the order of the strides: the axes listed from the smallest stride (fastest-varying) to the largest, i.e.

order = np.argsort(strides)
This parameter is necessary because in some cases the strides may not be known at compile time. To best utilize Hopper TMA, the compiler needs a hint as to the order of the strides. Its value doesn’t matter at all for anything below Hopper, i.e. anything earlier than sm_90.
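For example, here is a tiny sketch of that rule. The `block_ptr_order` helper is mine, not a Triton API; it just evaluates the formula on a tensor’s strides:

import numpy as np
import torch

def block_ptr_order(t: torch.Tensor):
    # Axes listed from the smallest stride (fastest-varying) to the largest.
    return tuple(int(i) for i in np.argsort(t.stride()))

a = torch.empty(4096, 16, 2)              # contiguous, row-major
print(a.stride(), block_ptr_order(a))     # (32, 2, 1) -> (2, 1, 0)

b = a.permute(2, 1, 0)                    # strides become (1, 2, 32)
print(b.stride(), block_ptr_order(b))     # (1, 2, 32) -> (0, 1, 2)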
Original answer below:

Anyway, I wrote a small test. It launches two kernels; the only difference between them is the `order` parameter of the `make_block_ptr` that loads the input: one uses (0, 1, 2), the other (2, 1, 0).
import shutil

import modal

triton_image = modal.Image.from_registry(
    "pytorch/pytorch:2.2.2-cuda12.1-cudnn8-devel")

with triton_image.imports():
    import torch
    import triton
    import triton.language as tl

vol = modal.Volume.from_name("triton-load-block-test", create_if_missing=True)

stub = modal.Stub(
    "triton-load-block-test",
    image=triton_image,
)


@triton.jit
def load_block_kernel(a_ptr: tl.tensor,
                      output_ptr: tl.tensor,
                      a_x: tl.constexpr,
                      a_y: tl.constexpr,
                      a_z: tl.constexpr,
                      load_block_order: tl.constexpr):
    # Load the whole (a_x, a_y, a_z) row-major tensor as a single block.
    # Only the `order` of this block pointer differs between the two runs.
    a_block_ptr = tl.make_block_ptr(
        base=a_ptr,
        shape=(a_x, a_y, a_z),
        strides=(a_y * a_z, a_z, 1),
        offsets=(0, 0, 0),
        block_shape=(a_x, a_y, a_z),
        order=(0, 1, 2) if load_block_order else (2, 1, 0),
    )
    a = tl.load(a_block_ptr)
    out = (a + 1.0).to(a.dtype)
    # The output block pointer always uses the same order.
    o_block_ptr = tl.make_block_ptr(
        base=output_ptr,
        shape=(a_x, a_y, a_z),
        strides=(a_y * a_z, a_z, 1),
        offsets=(0, 0, 0),
        block_shape=(a_x, a_y, a_z),
        order=(0, 1, 2),
    )
    tl.store(o_block_ptr, out)


GPU_TYPE = 'A100'


@stub.function(gpu=GPU_TYPE, image=triton_image, volumes={"/root/triton": vol})
def run_load_block(load_block_order: bool):
    vol.reload()
    m, n, k = 4096, 16, 2
    a = torch.randn((m, n, k), dtype=torch.float16, device='cuda')
    triton_output = torch.empty((m, n, k), dtype=a.dtype, device="cuda")

    def grid(_):
        # A single program instance; the kernel handles the whole tensor.
        return (1, )

    load_block_kernel[grid](a, triton_output, a_x=m, a_y=n,
                            a_z=k, load_block_order=load_block_order,
                            num_warps=8, num_stages=3)

    # Save the Triton compiler cache (~/.triton, which holds the ttgir/ptx
    # artifacts) to the volume so it can be downloaded and inspected locally.
    dir_name = '012' if load_block_order else '210'
    try:
        shutil.rmtree(f'/root/triton/{GPU_TYPE}/{dir_name}')
    except FileNotFoundError:
        pass
    shutil.copytree('/root/.triton', f'/root/triton/{GPU_TYPE}/{dir_name}/')
    vol.commit()


@stub.local_entrypoint()
def main():
    run_load_block.remote(True)
    run_load_block.remote(False)
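The script never checks the output, by the way. If you want a sanity check, a couple of lines like these at the end of `run_load_block` will do (my addition, not part of the test above):

    # Hypothetical check: the kernel just adds 1 to every element.
    torch.cuda.synchronize()
    assert torch.allclose(triton_output, a + 1.0)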
I used modal because I wanted to test on both A100 and H100 GPUs. Install:

pip install modal
pip3 install torch==2.2.2 --index-url https://download.pytorch.org/whl/cu121  # Not actually necessary if you only want to run remotely.

triton is installed automatically along with torch. To run:

modal run test_load_block.py
# Download the `.triton` folder
modal volume get triton-load-block-test /A100/**
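Once both dumps are downloaded, comparing the compiler artifacts is straightforward. A sketch (the `artifact` helper is mine; it assumes each dumped cache contains exactly one kernel build, and the paths are whatever `modal volume get` produced):

from pathlib import Path

def artifact(root: str, suffix: str) -> bytes:
    # Assumes exactly one file with this suffix exists in the dumped cache.
    [path] = Path(root).rglob(f"*{suffix}")
    return path.read_bytes()

for suffix in (".ttgir", ".ptx"):
    same = artifact("A100/012", suffix) == artifact("A100/210", suffix)
    print(suffix, "byte-for-byte identical" if same else "differs")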
The results are perplexing. First, if I set `GPU_TYPE` to H100, then the ttgir optimization pass fails:
loc("/root/test_load_block.py":37:16): error: 'tt.expand_dims' op inferred type(s) 'tensor<4096x1x1xi64, #triton_gpu.blocked<{sizePerThread = [8, 1, 1], threadsPerWarp = [32, 1, 1], warpsPerCTA = [8, 1, 1], order = [0, 1, 2], CTAsPerCGA = [1, 1, 1], CTASplitNum = [1, 1, 1], CTAOrder = [2, 1, 0]}>>' are incompatible with return type(s) of operation 'tensor<4096x1x1xi64>'
loc("/root/test_load_block.py":37:16): error: 'tt.expand_dims' op failed to infer returned types
A100 succeeds. But it turns out the `ttgir` and `ptx` files from the two runs are byte-for-byte identical. And I don’t understand the first line:
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [32, 1, 1], warpsPerCTA = [8, 1, 1], order = [2, 0, 1], CTAsPerCGA = [1, 1, 1], CTASplitNum = [1, 1, 1], CTAOrder = [2, 1, 0]}>
What’s `order`, and what’s `CTAOrder`? There are no CTA clusters on A100s.
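I can at least make the other numbers add up. My own back-of-the-envelope reading of that layout, assuming the three per-thread/per-warp/per-CTA fields multiply out to the tile each CTA covers per repetition:

# How the (4096, 16, 2) block might be covered by that layout:
size_per_thread  = [1, 1, 2]   # each thread holds 2 contiguous fp16 along axis 2
threads_per_warp = [32, 1, 1]
warps_per_cta    = [8, 1, 1]   # matches num_warps=8

# Elements covered per CTA in one repetition of the pattern:
tile = [s * t * w for s, t, w in zip(size_per_thread, threads_per_warp, warps_per_cta)]
print(tile)                              # [256, 1, 2]

# Repetitions needed to cover the whole block:
print([4096 // 256, 16 // 1, 2 // 2])    # [16, 16, 1]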
Digging into the source code doesn’t help either. This line is as far as I’m able to get, and the table-gen definition doesn’t say anything about `order`.