
What's 'order' in Triton make_block_ptr

I’ve been trying to figure out what order in tl.make_block_ptr means. The docs don’t help:

order – The order of the original data format

Isn’t stride enough?
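
For context, here is roughly what a call looks like. This is only a sketch modeled on the official matmul tutorial (peek_kernel and the shapes are made up): for a row-major (M, N) matrix the strides are already (N, 1), yet order must be passed on top of them.

import triton
import triton.language as tl


@triton.jit
def peek_kernel(a_ptr, M: tl.constexpr, N: tl.constexpr):
    # The strides (N, 1) already say the matrix is row-major...
    blk = tl.make_block_ptr(
        base=a_ptr,
        shape=(M, N),
        strides=(N, 1),
        offsets=(0, 0),
        block_shape=(M, N),  # block dims must be powers of two to compile
        # ...so why does order=(1, 0) have to be spelled out as well?
        order=(1, 0),
    )
    tl.store(blk, tl.load(blk))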

Update:

I asked around, and found out order is just the rank of the strides: it lists the dimensions from the fastest-varying (smallest stride) to the slowest, which is why the official tutorials pass order=(1, 0) for a row-major matrix. Assuming distinct strides:

order = np.argsort(strides)

This parameter is necessary because in some cases the strides are not known at compile time. To make the best use of Hopper’s TMA, the compiler needs a compile-time hint about the order of the strides. Its value doesn’t matter at all for anything below Hopper, i.e. anything earlier than sm_90.
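
As a sanity check, here is a tiny sketch of that relationship. order_from_strides is just an illustrative helper, assuming the fastest-first convention above:

import numpy as np


def order_from_strides(strides):
    # Rank the dimensions from the fastest-varying (smallest stride)
    # to the slowest, which is the permutation `order` expects.
    return tuple(int(i) for i in np.argsort(np.asarray(strides)))


# Row-major (C-contiguous) 2D tensor: strides (N, 1) -> order (1, 0).
assert order_from_strides((16, 1)) == (1, 0)
# The 3D tensor used in the test below has strides (a_y * a_z, a_z, 1).
assert order_from_strides((32, 2, 1)) == (2, 1, 0)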

Original answer below:

Anyway, I wrote a small test. It launches two kernels; the only difference is the order argument passed to the make_block_ptr that loads the input: one uses (0, 1, 2), the other (2, 1, 0).

import shutil

import modal


triton_image = modal.Image.from_registry(
    "pytorch/pytorch:2.2.2-cuda12.1-cudnn8-devel")

with triton_image.imports():
    import torch
    import triton
    import triton.language as tl

vol = modal.Volume.from_name("triton-load-block-test", create_if_missing=True)
stub = modal.Stub(
    "triton-load-block-test",
    image=triton_image
)


@triton.jit
def load_block_kernel(a_ptr: tl.tensor,
                      output_ptr: tl.tensor,
                      a_x: tl.constexpr,
                      a_y: tl.constexpr,
                      a_z: tl.constexpr,
                      load_block_order: tl.constexpr):
    a_block_ptr = tl.make_block_ptr(
        base=a_ptr,
        shape=(a_x, a_y, a_z),
        strides=(a_y * a_z,  a_z,  1),
        offsets=(0, 0, 0),
        block_shape=(a_x, a_y, a_z),
        order=(0, 1, 2) if load_block_order else (2, 1, 0),
    )

    a = tl.load(a_block_ptr)
    out = (a + 1.0).to(a.dtype)
    o_block_ptr = tl.make_block_ptr(
        base=output_ptr,
        shape=(a_x, a_y, a_z),
        strides=(a_y * a_z, a_z, 1),
        offsets=(0, 0, 0),
        block_shape=(a_x, a_y, a_z),
        order=(0, 1, 2)
    )

    tl.store(o_block_ptr, out)


GPU_TYPE = 'A100'

@stub.function(gpu=GPU_TYPE, image=triton_image, volumes={"/root/triton": vol})
def run_load_block(load_block_order: bool):
    vol.reload()
    m, n, k = 4096, 16, 2
    a = torch.randn((m, n, k), dtype=torch.float16, device='cuda')
    triton_output = torch.empty((m, n, k), dtype=a.dtype, device="cuda")

    # A single program instance loads the whole tensor as one block.
    def grid(_):
        return (1, )
    load_block_kernel[grid](a, triton_output, a_x=m, a_y=n,
                            a_z=k, load_block_order=load_block_order, num_warps=8, num_stages=3)

    dir_name = '012' if load_block_order else '210'
    try:
        shutil.rmtree(f'/root/triton/{GPU_TYPE}/{dir_name}')
    except FileNotFoundError:
        pass
    # Save Triton's compilation cache (which holds the generated .ttgir and
    # .ptx) to the volume so the artifacts can be downloaded and inspected.
    shutil.copytree('/root/.triton', f'/root/triton/{GPU_TYPE}/{dir_name}/')
    vol.commit()


@stub.local_entrypoint()
def main():
    run_load_block.remote(True)
    run_load_block.remote(False)

I used Modal because I wanted to test on both A100 and H100 GPUs. Install:

pip install modal
pip3 install torch==2.2.2 --index-url https://download.pytorch.org/whl/cu121 # Actually not necessary if you only want to run remotely.

triton is automatically installed along with torch. To run:

modal run test_load_block.py
# Download the `.triton` folder
modal volume get triton-load-block-test /A100/** 
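
Once both dumps are downloaded, one way to check whether the two orders actually produced different compiler artifacts is to compare them by extension. This is just a sketch: the local paths are whatever modal volume get created on your machine (assumed here to be A100/012 and A100/210), and files are matched by extension because the Triton cache nests them under hash-named directories.

import filecmp
from pathlib import Path


def artifacts(root, ext):
    # Collect all compiler artifacts of one kind, e.g. ".ttgir" or ".ptx".
    return sorted(Path(root).rglob(f"*{ext}"))


for ext in (".ttgir", ".ptx"):
    pairs = zip(artifacts("A100/012", ext), artifacts("A100/210", ext))
    for f012, f210 in pairs:
        same = filecmp.cmp(f012, f210, shallow=False)
        print(ext, "identical" if same else "DIFFERS", f012.name, f210.name)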

The results are perplexing. First, if I set GPU_TYPE to H100, then the ttgir optimization pass fails:

loc("/root/test_load_block.py":37:16): error: 'tt.expand_dims' op inferred type(s) 'tensor<4096x1x1xi64, #triton_gpu.blocked<{sizePerThread = [8, 1, 1], threadsPerWarp = [32, 1, 1], warpsPerCTA = [8, 1, 1], order = [0, 1, 2], CTAsPerCGA = [1, 1, 1], CTASplitNum = [1, 1, 1], CTAOrder = [2, 1, 0]}>>' are incompatible with return type(s) of operation 'tensor<4096x1x1xi64>'
loc("/root/test_load_block.py":37:16): error: 'tt.expand_dims' op failed to infer returned types

A100 succeeds. But it turns out the ttgir and ptx files for the two orders are byte-for-byte identical. And I don’t understand the first line:

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [32, 1, 1], warpsPerCTA = [8, 1, 1], order = [2, 0, 1], CTAsPerCGA = [1, 1, 1], CTASplitNum = [1, 1, 1], CTAOrder = [2, 1, 0]}>

What’s order, and what’s CTAOrder? There are no CTA clusters (CGAs) on A100s.

Digging into the source code doesn’t help either. This line is as far as I’m able to get. The TableGen definition doesn’t say anything about order.