u/NoVibeCoding — 2 days ago
▲ 62 r/CUDA+1 crossposts

Writing an LLM compiler from scratch [Part 2]: Lowering to a GPU Schedule

The modern ML (LLM) compiler stack is brutal. TVM is 500K+ lines of C++. PyTorch piles Dynamo, Inductor, and Triton on top of each other. I built a hackable LLM compiler from scratch and am documenting the process. It takes a small model (TinyLlama, Qwen2.5-7B) and lowers it to a sequence of CUDA kernels through six IRs.

Currently, on RTX 5090, the emitted FP32 kernels run at geomean 1.11× vs PyTorch eager and 1.20× vs torch.compile, with full-block parity on TinyLlama-128 and Qwen2.5-7B at seq=128. Wins on small reductions / SDPA / kv-projections (up to 4.7×); losses on dense matmul at seq=512.

Part 1 took an RMSNorm layer end-to-end and walked the upper half of that pipeline in detail. This second part closes the gap and explains Tile IR, Kernel IR, and associated lowering rules in depth.

Full article: A Principled ML Compiler Stack in 5,000 Lines of Python

Repo: deplodock

The article focuses on producing a GPU schedule for an operation written in loop-nest form (Loop IR). Example for RMSNorm:

v0 = reciprocal(2048)
for a0 in 0..32:  # free
    for a1 in 0..2048:  # reduce
        in2 = load x[0, a0, a1]
        v1 = multiply(in2, in2)
        acc0 <- add(acc0, v1)
    v2 = multiply(acc0, v0)
    v3 = add(v2, 1e-06)
    v4 = rsqrt(v3)
    for a2 in 0..2048:  # free
        in3 = load x[0, a0, a2]
        in4 = load p_weight[a2]
        v5 = multiply(in3, v4)
        v6 = multiply(v5, in4)
        merged_n0[0, a0, a2] = v6
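
For orientation, here is a literal (and slow) Python reading of that loop nest, checked against torch.nn.RMSNorm. This only pins down the semantics (eps=1e-6 is taken from the add(v2, 1e-06) line above); it is not how the compiler executes it:

import torch

x = torch.randn(1, 32, 2048)
m = torch.nn.RMSNorm(2048, eps=1e-6)         # p_weight starts as all-ones
out = torch.empty_like(x)

v0 = 1.0 / 2048
for a0 in range(32):                         # free
    acc0 = 0.0
    for a1 in range(2048):                   # reduce
        in2 = float(x[0, a0, a1])
        acc0 += in2 * in2
    v4 = (acc0 * v0 + 1e-6) ** -0.5          # rsqrt
    for a2 in range(2048):                   # free
        out[0, a0, a2] = float(x[0, a0, a2]) * v4 * float(m.weight[a2])

print(torch.allclose(out, m(x), atol=1e-4))  # True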

The stack mimics the sequence of optimization steps a CUDA engineer would perform by hand when tuning a kernel: stage inputs to smem, reduce bank conflicts, increase occupancy, and so on.

LoopOp
  │
  ▼
[001] tileify                 — lift outer free Loops to thread axes
[002] chunk_matmul_k          — chunk the K reduce into K-outer × K-inner (intra-CTA)
[003] split_matmul_k          — promote the K-outer chunk loop into a grid dimension
[004] cooperative_reduce      — let multiple threads share one reduce; tree-merge with Combine
[005] blockify_launch         — pick block extents; partition free axes into BLOCK and THREAD
[006] chunk_reduce            — chunk non-matmul reduces so their Loads fit in shared memory
[007] stage_inputs            — hoist hot input slabs into Stage nodes
[008] register_tile           — replicate the inner tile so each thread owns a register block
[009] permute_register_tile   — reorder the register strip so bank-conflicting loads land on far columns
[010] double_buffer           — promote K-outer Stages to BufferedStage (ping-pong)
[011] tma_copy                — narrow eligible BufferedStages to TmaBufferedStage (sm_90+)
[012] split_inner_for_swizzle — split the inner cache axis of a TmaBufferedStage for swizzle
[013] async_copy              — narrow the rest to AsyncBufferedStage (cp.async, sm_80+)
[014] pad_smem                — pad shared-memory strides to break bank conflicts
[015] pipeline_k_outer        — rotate the K-outer loop into prologue/steady-state/epilogue (cp.async + TMA)
[016] mark_unroll             — annotate small inner loops for #pragma unroll
  │
  ▼
TileOp (fully scheduled)

Each pass can be reproduced with a CLI command. For example, the stage_inputs pass stages input buffers into smem when it is both possible and beneficial (i.e. the inputs are read multiple times within a CTA). To see its effect:

deplodock compile \
  -c "torch.nn.RMSNorm(2048)(torch.randn(1,32,2048))" \
  --ir tile -vv \
  | awk '/^>>> t:007/,/^<<< t:007/'
>>> t:007_stage_inputs
@@ matched at rms_norm (in-place) @@
@@ -2,6 +2,7 @@
   v0 = reciprocal(2048)
   Tile(axes=(a0:256=THREAD, a1:32=BLOCK)):
+      x_smem = Stage(x, origin=(0, a1, 0), slab=(a2:2048@2))
       StridedLoop(a2 = a0; < 2048; += 256):  # reduce
-          in2 = load x[0, a1, a2]
+          in2 = load x_smem[a2]
           v1 = multiply(in2, in2)
           acc0 <- add(acc0, v1)
@@ -11,5 +12,5 @@
       v4 = rsqrt(v3)
       StridedLoop(a2 = a0; < 2048; += 256):  # free
-          in3 = load x[0, a1, a2]
+          in3 = load x_smem[a2]
           in4 = load p_weight[a2]
           v5 = multiply(in3, v4)
<<< t:007_stage_inputs
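
Under the hood the pass is just an IR-to-IR rewrite. A toy sketch of that shape, with made-up node classes rather than the real deplodock Tile IR (the real pass also checks the slab size against the smem budget):

from dataclasses import dataclass
from collections import Counter

@dataclass(frozen=True)
class Load:                  # "read buffer[...] inside the tile"
    buffer: str

@dataclass(frozen=True)
class Stage:                 # "copy this buffer's slab into shared memory first"
    buffer: str

def stage_inputs(body):
    """Stage every input buffer that is loaded more than once within the CTA."""
    counts = Counter(n.buffer for n in body if isinstance(n, Load))
    hot = [b for b, c in counts.items() if c > 1]
    return [Stage(b) for b in hot] + body

body = [Load("x"), Load("x"), Load("p_weight")]   # x is read by both loops above
print(stage_inputs(body))                         # [Stage(buffer='x'), Load(buffer='x'), ...]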

The final CUDA kernel for the RMSNorm layer:

deplodock compile \
  -c "torch.nn.RMSNorm(2048)(torch.randn(1,32,2048))" \
  --target sm_120 --ir cuda
extern "C" __global__
__launch_bounds__(256) void k_rms_norm_reduce(
    const float* x, const float* p_weight, float* rms_norm) {
    float v0 = 1.0f / 2048.0f;
    int a1 = blockIdx.x;
    int a0 = threadIdx.x;
    int lane = threadIdx.x & 31;
    int warp = threadIdx.x >> 5;
    float acc0 = 0.0f;
    __shared__ float x_smem[2048];
    for (int x_smem_flat = a0; x_smem_flat < 2048; x_smem_flat += 256) {
        float x_smem_v = x[a1 * 2048 + x_smem_flat];
        x_smem[x_smem_flat] = x_smem_v;
    }
    __syncthreads();
    for (int a2 = a0; a2 < 2048; a2 += 256) {
        float in2 = x_smem[a2];
        float v1 = in2 * in2;
        acc0 += v1;
    }
    float acc0_w = acc0;
    acc0_w = acc0_w + __shfl_xor_sync(0xffffffff, acc0_w, 16);
    acc0_w = acc0_w + __shfl_xor_sync(0xffffffff, acc0_w, 8);
    acc0_w = acc0_w + __shfl_xor_sync(0xffffffff, acc0_w, 4);
    acc0_w = acc0_w + __shfl_xor_sync(0xffffffff, acc0_w, 2);
    acc0_w = acc0_w + __shfl_xor_sync(0xffffffff, acc0_w, 1);
    __shared__ float acc0_smem[8];
    if (lane == 0) {
        acc0_smem[warp] = acc0_w;
    }
    __syncthreads();
    for (int s = 4; s > 0; s >>= 1) {
        if (warp < s) {
            acc0_smem[warp] = acc0_smem[warp] + acc0_smem[warp + s];
        }
        __syncthreads();
    }
    float acc0_b = acc0_smem[0];
    float v2 = acc0_b * v0;
    float v3 = v2 + 1e-06f;
    float v4 = rsqrtf(v3);
    for (int a2 = a0; a2 < 2048; a2 += 256) {
        float in3 = x_smem[a2];
        float in4 = p_weight[a2];
        float v5 = in3 * v4;
        float v6 = v5 * in4;
        rms_norm[a1 * 2048 + a2] = v6;
    }
}
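
To sanity-check the emitted kernel outside the stack, it can be compiled and launched directly, for example with CuPy's RawModule. A sketch, assuming the CUDA text above was saved to rms_norm.cu (hypothetical path); grid=(32,) and block=(256,) are read off the code itself (blockIdx.x indexes the 32 rows, __launch_bounds__(256) gives the block size):

import cupy as cp

kernel_src = open("rms_norm.cu").read()            # the kernel printed above
kern = cp.RawModule(code=kernel_src).get_function("k_rms_norm_reduce")

x = cp.random.randn(1, 32, 2048, dtype=cp.float32)
w = cp.ones(2048, dtype=cp.float32)
out = cp.empty((1, 32, 2048), dtype=cp.float32)

kern((32,), (256,), (x, w, out))                   # grid, block, kernel arguments

ref = x / cp.sqrt((x * x).mean(axis=-1, keepdims=True) + 1e-6) * w
print(cp.allclose(out, ref, atol=1e-4))            # True
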
cloudrift.ai
u/NoVibeCoding — 3 days ago

Hey r/MachineLearning,

The modern ML (LLM) compiler stack is brutal. TVM is 500K+ lines of C++. PyTorch piles Dynamo, Inductor, and Triton on top of each other. Then there's XLA, MLIR, Halide, Mojo. There is no tutorial that covers the high-level design of an ML compiler without dropping you straight into the guts of one of these frameworks.

I built a reference compiler from scratch in ~5K lines of pure Python that emits raw CUDA. It takes a small model (TinyLlama, Qwen2.5-7B) and lowers it to a sequence of CUDA kernels through six IRs. The goal isn't to beat Triton; it is to build a hackable, easy-to-follow compiler.

Full article: A Principled ML Compiler Stack in 5,000 Lines of Python

Repo: deplodock

The pipeline consists of six IRs, each closer to the hardware than the last. Let's walk the following PyTorch expression through every stage (real compiler output, with names shortened for brevity and comments added):

torch.relu(torch.matmul(x + bias, w))   # x: (16, 64), bias: (64,), w: (64, 16)

Torch IR. Captured FX graph, 1:1 mirror of PyTorch ops:

bias_bc =  bias[j]                          -> (16, 64) float32
add     =  add(x, bias_bc)                  -> (16, 64) float32
matmul  =  matmul(add, w, has_bias=False)   -> (16, 16) float32
relu    =  relu(matmul)                     -> (16, 16) float32
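
That capture stage is plain torch.fx; you can reproduce an equivalent graph yourself (this is just the raw FX trace the compiler starts from, not its internal Torch IR):

import torch
import torch.fx as fx

def f(x, bias, w):
    return torch.relu(torch.matmul(x + bias, w))

gm = fx.symbolic_trace(f)
for node in gm.graph.nodes:
    print(node.op, node.name, node.target)
# placeholders x/bias/w, then call_function add -> matmul -> relu, then output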

Tensor IR. Every op is decomposed into Elementwise / Reduction / IndexMap. Minimal unified op surface, so future frontends (ONNX, JAX) plug in without touching downstream passes:

bias_bc  =  bias[j]                 -> (16, 64) float32
w_bc     =  w[j, k]                 -> (16, 64, 16) float32
add      =  add(x, bias_bc)         -> (16, 64) float32
add_bc   =  add[i, j]               -> (16, 64, 16) float32
prod     =  multiply(add_bc, w_bc)  -> (16, 64, 16) float32
red      =  sum(prod, axis=-2)      -> (16, 1, 16) float32
matmul   =  red[i, na, j]           -> (16, 16) float32
relu     =  relu(matmul)            -> (16, 16) float32

The (16, 64, 16) intermediate looks ruinous, but it's never materialized; the next stage fuses it out.
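
The decomposition is easy to check numerically: broadcast to (16, 64, 16), multiply, and sum over the shared axis, and you get exactly the matmul:

import torch

x, bias, w = torch.randn(16, 64), torch.randn(64), torch.randn(64, 16)
add = x + bias
prod = add[:, :, None] * w[None, :, :]   # (16, 64, 16) -- the compiler never materializes this
red = prod.sum(dim=1)                    # (16, 16)
print(torch.allclose(torch.relu(red), torch.relu(add @ w), atol=1e-5))   # True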

Loop IR. Each kernel is now a loop nest, fused with adjacent kernels. Prologue, broadcasted multiply, reduction, output layout, and epilogue all collapse into a single loop nest with no intermediate buffers.

=== merged_relu -> relu ===
for a0 in 0..16:  # free (M)
    for a1 in 0..16:  # free (N)
        for a2 in 0..64:  # reduce (K)
            in0 = load bias[a2]
            in1 = load x[a0, a2]
            in2 = load w[a2, a1]
            v0 = add(in1, in0)      # prologue (inside reduce)
            v1 = multiply(v0, in2)
            acc0 <- add(acc0, v1)
        v2 = relu(acc0)             # epilogue (outside reduce)
        merged_relu[a0, a1] = v2

Tile IR. The first GPU-aware IR. Loop axes get scheduled onto threads/blocks, Stage hoists shared inputs into shared memory, and a 2×2 register tile lets each thread accumulate four outputs at once. The K-axis is tiled into two outer iterations of a 32-wide reduce. Three annotations on the Stage nodes below carry the heaviest optimizations:

  • buffers=2@a2 — double-buffer the smem allocation along the a2 K-tile loop, so loads for iteration a2+1 overlap compute for a2.
  • async — emit cp.async.ca.shared.global so the warp doesn't block on global→smem transfers; pairs with commit_group/wait_group fences in Kernel IR.
  • pad=(0, 1, 0) — add 1 element of padding to the middle smem dim so warp-wide loads don't all hit the same bank.

kernel k_relu_reduce
    Tile(axes=(a0:8=THREAD, a1:8=THREAD)):
        for a2 in 0..2:  # K-tile
            bias_smem = Stage(bias,
                              origin=((a2 * 32)),
                              slab=(a3:32@0))
                          buffers=2@a2

            x_smem = Stage(x,
                           origin=(0, (a2 * 32)),
                           slab=(a0:8@0, a3:32@1, cell:2@0)) 
                       pad=(0, 1, 0) buffers=2@a2 async

            w_smem = Stage(w,
                           origin=((a2 * 32), 0),
                           slab=(a3:32@0, a1:8@1, cell:2@1))
                       buffers=2@a2 async
  
            # reduce
            for a3 in 0..32:  
                in0 = load bias_smem[a2, a3]
                in1 = load x_smem[a2, a0, a3, 0];
                in2 = load x_smem[a2, a0, a3, 1]
                in3 = load w_smem[a2, a3, a1, 0];
                in4 = load w_smem[a2, a3, a1, 1]
                
                # prologue, reused 2× across N
                v0 = add(in1, in0); v1 = add(in2, in0)
                
                # 2×2 register tile   
                acc0 <- add(acc0, multiply(v0, in3))          
                acc1 <- add(acc1, multiply(v0, in4))
                acc2 <- add(acc2, multiply(v1, in3))
                acc3 <- add(acc3, multiply(v1, in4))

        # epilogue
        relu[a0*2,     a1*2    ] = relu(acc0)                 
        relu[a0*2,     a1*2 + 1] = relu(acc1)
        relu[a0*2 + 1, a1*2    ] = relu(acc2)
        relu[a0*2 + 1, a1*2 + 1] = relu(acc3)
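
To see what the pad=(0, 1, 0) annotation buys: shared memory is organized as 32 four-byte banks, and with an unpadded row of 32×2 floats every thread's a0 row starts in the same bank, while the padded 33×2-float row (the a0*66 stride visible in the emitted kernel below) spreads the row starts across banks:

# 32 shared-memory banks, one 4-byte float per bank slot
def row_start_banks(stride_floats):
    return sorted({(a0 * stride_floats) % 32 for a0 in range(8)})

print(row_start_banks(64))   # unpadded: [0]                -- every row starts in bank 0
print(row_start_banks(66))   # padded:   [0, 2, 4, ..., 14] -- rows spread across banks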

Kernel IR. The schedule materialized into hardware primitives. THREAD/BLOCK become threadIdx/blockIdx, an async Stage becomes Smem + a cp.async fill with commit/wait fences, a sync Stage becomes a strided fill loop. Backend-agnostic: the same IR could lower to Metal or HIP:

kernel k_relu_reduce
    Tile(axes=(a0:8=THREAD, a1:8=THREAD)):
        Init(acc0..acc3, op=add)
        for a2 in 0..2:  # K-tile
            Smem bias_smem[2, 32] (float)
            StridedLoop(flat = a0*8 + a1; < 32; += 64):
                bias_smem[a2, flat] = load bias[a2*32 + flat]
            Sync
            
            # pad row to 33 to kill bank conflicts
            Smem x_smem[2, 8, 33, 2] (float)
            StridedLoop(flat = a0*8 + a1; < 512; += 64):
                cp.async x_smem[a2, flat/64, (flat/2)%32, flat%2]
                    <- x[flat/64*2 + flat%2, a2*32 + (flat/2)%32]
            cp.async.commit_group;  cp.async.wait_group(0);  Sync
            
            Smem w_smem[2, 32, 8, 2] (float)
            StridedLoop(flat = a0*8 + a1; < 512; += 64):
                cp.async w_smem[a2, flat/16, (flat/2)%8, flat%2]
                    <- w[a2*32 + flat/16, (flat/2)%8*2 + flat%2]
            cp.async.commit_group;  cp.async.wait_group(0);  Sync
            
            for a3 in 0..32:  # reduce
                ...
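
A quick way to convince yourself the StridedLoop index math is right: the flat -> (a0, a3, cell) mapping of the x_smem fill has to cover the 8×32×2 slab exactly once per K-tile. That is plain integer arithmetic and can be checked without the compiler:

# 64 threads each walk flat = a0*8 + a1, +64, ... so together they visit flat = 0..511
dests = {(flat // 64, (flat // 2) % 32, flat % 2) for flat in range(512)}
print(len(dests) == 8 * 32 * 2)   # True: every slab slot is written exactly once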

CUDA. One-to-one tree walk over Kernel IR, ready for nvcc. Bias-add, the K-axis reduction, the 2×2 register tile, and the relu activation all in one kernel. One HBM read each of x, bias, w, one HBM write of relu, no intermediates between ops.

extern "C" __global__
__launch_bounds__(256)
void k_relu_reduce(const float* bias,
                   const float* x,
                   const float* w,
                   float* relu) {
    long long tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid < 64) {
        int a0 = tid / 8;
        int a1 = tid % 8;
        float acc0 = 0.0f, acc1 = 0.0f, acc2 = 0.0f, acc3 = 0.0f;
        #pragma unroll
        for (int a2 = 0; a2 < 2; a2++) {
            __shared__ float bias_smem[64];
            for (int f = a0*8 + a1; f < 32; f += 64)
                bias_smem[a2*32 + f] = bias[a2*32 + f];
            __syncthreads();
            
            // padded to avoid bank conflicts
            __shared__ float x_smem[1056];
            for (int f = a0*8 + a1; f < 512; f += 64) {
                unsigned int addr = __cvta_generic_to_shared(
                    &x_smem[a2*528 + f/64*66 + f/2%32*2 + f%2]
                );
                asm volatile(
                    "cp.async.ca.shared.global [%0], [%1], 4;\n"
                    :: "r"(addr),
                       "l"(&x[(f/64*2 + f%2)*64 + (a2*32 + f/2%32)])
                    : "memory");
            }
            asm volatile("cp.async.commit_group;\n"
                         ::: "memory");
            asm volatile("cp.async.wait_group 0;\n"
                         ::: "memory");
            __syncthreads();
            
            __shared__ float w_smem[1024];
            for (int f = a0*8 + a1; f < 512; f += 64) {
                unsigned int addr = __cvta_generic_to_shared(
                    &w_smem[a2*512 + f/16*16 + f/2%8*2 + f%2]
                );
                asm volatile(
                    "cp.async.ca.shared.global [%0], [%1], 4;\n"
                    :: "r"(addr),
                       "l"(&w[(a2*32 + f/16)*16 + (f/2%8*2 + f%2)])
                    : "memory");
            }
            asm volatile("cp.async.commit_group;\n"
                         ::: "memory");
            asm volatile("cp.async.wait_group 0;\n"
                         ::: "memory");
            __syncthreads();

            #pragma unroll
            for (int a3 = 0; a3 < 32; a3++) {
                float in0 = bias_smem[a2*32 + a3];
                float in1 = x_smem[a2*528 + a0*66 + a3*2    ];
                float in2 = x_smem[a2*528 + a0*66 + a3*2 + 1];
                float in3 = w_smem[a2*512 + a3*16 + a1*2    ];
                float in4 = w_smem[a2*512 + a3*16 + a1*2 + 1];
                float v0 = in1 + in0;  float v1 = in2 + in0;
                acc0 += v0 * in3;  acc1 += v0 * in4;
                acc2 += v1 * in3;  acc3 += v1 * in4;
            }
        }
        relu[a0*2*16     + a1*2    ] = fmaxf(0.0f, acc0);
        relu[a0*2*16     + a1*2 + 1] = fmaxf(0.0f, acc1);
        relu[(a0*2+1)*16 + a1*2    ] = fmaxf(0.0f, acc2);
        relu[(a0*2+1)*16 + a1*2 + 1] = fmaxf(0.0f, acc3);
    }
}

Every stage is printable on demand. No GPU needed.

deplodock compile -c "torch.relu(torch.matmul(torch.randn(16,64) + torch.randn(64), torch.randn(64,16)))" --ir tensor|loop|tile|kernel|cuda

Benchmarking against eager PyTorch and torch.compile (attention scores at Qwen-block size, where the compiler ties torch.compile):

deplodock run --bench -c "torch.nn.Softmax(dim=-1)(torch.randn(1,28,2048,2048))"

End-to-end compilation of a real model:

deplodock compile Qwen/Qwen2.5-7B

The linked article goes through the design in detail (RMSNorm walked through every IR, the σ-based fusion algorithm with blowup guard, validation against torch.compile on TinyLlama and Qwen2.5-7B blocks). The forthcoming second part will go through the codegen internals.

u/NoVibeCoding — 13 days ago
▲ 22 r/Compilers+1 crossposts

Hey r/LocalLLaMA,

I wanted to come up with a simple overview of the modern ML compiler stack, essentially what happens between model.generate() and the GPU executing a kernel. However, the stack is brutal to read. TVM is 500K+ lines of C++. PyTorch piles Dynamo, Inductor, and Triton on top of each other. Then there's XLA, MLIR, Halide, and Mojo.

Instead, I took a different approach and just built one from scratch: pure Python and raw CUDA. It takes a small model (Qwen2.5-7B, TinyLlama) and compiles it into a sequence of CUDA kernels. The goal isn't to beat Triton today, but to create a hackable compiler that doesn't require a PhD in compilers to modify, or at least one that is easier to follow.

The final performance is about 50-90% of the production stack (PyTorch eager and torch.compile).

I built it in a principled way, with a layered pipeline and concerns clearly separated:

  1. Torch IR — captured FX graph (rmsnorm, linear, softmax, ...)
  2. Tensor IR — every op decomposed into Elementwise / Reduction / IndexMap
  3. Loop IR — a kernel written as a loop nest fused with other kernels
  4. Tile IR — a kernel scheduled onto the GPU (threads, blocks, shared memory)
  5. Kernel IR — schedule materialized into hardware primitives
  6. CUDA — emitted source ready for nvcc

Tensor IR is introduced to support future frontends, such as ONNX and JAX. Loop fusion collapses long chains of pointwise and reduction ops. The lowering stages introduce optimizations such as tiled matmul, smem staging, and double-buffering.

Each stage can be inspected and debugged independently (repository link). No GPU needed:

deplodock compile -c "nn.RMSNorm(2048)(torch.randn(1,32,2048))" --ir tensor|loop|tile|kernel|cuda

Benchmarking:

deplodock run --bench --profile -c "torch.nn.Softmax(dim=-1)(torch.randn(1,28,2048,2048))"

End-to-end compilation:

deplodock compile Qwen/Qwen2.5-7B

The generated CUDA kernel for RMSNorm looks like this:

extern "C" __global__
__launch_bounds__(256) void k_rms_norm_reduce(const float* x, const float* p_weight, float* rms_norm) {
    float in0 = 2048.0f;
    float in1 = 1e-06f;
    {
        int a1 = blockIdx.x;
        int a0 = threadIdx.x;
        float acc0 = 0.0f;
        __syncthreads();
        __shared__ float x_smem[2048];
        for (int x_smem_flat = a0; x_smem_flat < 2048; x_smem_flat += 256) {
            {
                unsigned int _smem_addr = __cvta_generic_to_shared(&x_smem[x_smem_flat]);
                asm volatile("cp.async.ca.shared.global [%0], [%1], 4;\n"
                             :: "r"(_smem_addr), "l"(&x[a1 * 2048 + x_smem_flat])
                             : "memory");
            }
        }
        asm volatile("cp.async.commit_group;\n" ::: "memory");
        asm volatile("cp.async.wait_group 0;\n" ::: "memory");
        __syncthreads();
        __shared__ float p_weight_smem[2048];
        for (int p_weight_smem_flat = a0; p_weight_smem_flat < 2048; p_weight_smem_flat += 256) {
            {
                unsigned int _smem_addr = __cvta_generic_to_shared(&p_weight_smem[p_weight_smem_flat]);
                asm volatile("cp.async.ca.shared.global [%0], [%1], 4;\n"
                             :: "r"(_smem_addr), "l"(&p_weight[p_weight_smem_flat])
                             : "memory");
            }
        }
        asm volatile("cp.async.commit_group;\n" ::: "memory");
        asm volatile("cp.async.wait_group 0;\n" ::: "memory");
        __syncthreads();
        for (int a2 = a0; a2 < 2048; a2 += 256) {
            float in2 = x_smem[a2];
            float v0 = in2 * in2;
            acc0 += v0;
        }
        __shared__ float acc0_smem[256];
        acc0_smem[a0] = acc0;
        __syncthreads();
        for (int s = 128; s > 0; s >>= 1) {
            if (a0 < s) {
                acc0_smem[a0] = acc0_smem[a0] + acc0_smem[a0 + s];
            }
            __syncthreads();
        }
        __syncthreads();
        float acc0_b = acc0_smem[0];
        float v1 = acc0_b / in0;
        float v2 = v1 + in1;
        float v3 = rsqrtf(v2);
        for (int a3 = a0; a3 < 2048; a3 += 256) {
            float in3 = x_smem[a3];
            float in4 = p_weight_smem[a3];
            float v4 = in3 * v3;
            float v5 = v4 * in4;
            rms_norm[a1 * 2048 + a3] = v5;
        }
    }
}
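
The block-wide reduction in the middle of that kernel (acc0_smem and the s = 128 ... 1 loop) is a standard shared-memory tree reduction. Modelled in plain Python, as an illustration of the schedule rather than the actual codegen:

import random

partials = [random.random() for _ in range(256)]   # one acc0 per thread
smem = partials[:]                                 # acc0_smem[a0] = acc0
s = 128
while s > 0:                                       # 8 halving rounds for 256 threads
    for a0 in range(s):                            # threads with a0 < s are active
        smem[a0] += smem[a0 + s]
    s >>= 1                                        # __syncthreads() sits here in CUDA
print(abs(smem[0] - sum(partials)) < 1e-9)         # True: smem[0] is the block sum
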
medium.com
u/NoVibeCoding — 14 days ago