Kernel code, of the kind that runs low-level operations like matrix multiplication, convolution, and so on, is usually written in a procedural style. Why is this? The short answer is that this style is closest to how hardware operates, and allows easier reasoning about performance. This is key to writing a good kernel.
To put that into context, let’s first recap: what are the styles of programming? We can divide them into three paradigms: procedural, functional, and object-oriented. You could count many more, but these are a good start.
And what’s the difference between these paradigms? Mostly they relate to how you manage the state of your program. In brief, the differences are as follows:
Functional
In functional programming, data is immutable and all logic is expressed in terms of pure functions that map the given inputs to some new outputs and don’t make any other changes to the state of the program. A program is composed of chains of functions, and the state is defined by the data at each step as it flows through those functions.
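As a minimal sketch of this style, written in Python purely for readability, each step below returns a new tuple and leaves its input untouched. The function names scale, shift, and pipeline are made up for illustration.

```python
def scale(values, factor):
    """Return a new tuple with every element multiplied by factor."""
    return tuple(v * factor for v in values)


def shift(values, offset):
    """Return a new tuple with offset added to every element."""
    return tuple(v + offset for v in values)


def pipeline(values):
    """Chain pure functions; the state at each step is just the data flowing through."""
    return shift(scale(values, 2.0), 1.0)


inputs = (1.0, 2.0, 3.0)
outputs = pipeline(inputs)  # a brand new tuple; inputs is unchanged
```

Because nothing is modified in place, the intermediate result of scale could be inspected or reused without affecting anything downstream.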
Object-oriented
In object-oriented programming, state doesn’t really flow like that. It is often bound together with logic and hidden behind some interface. This is an object, and it will probably hang around for some time while you use it. But when you are done with it, you had better tidy up or you may quickly run out of memory. We mainly think of the program in terms of these objects and how they interact with each other.
Each object may have logic attached to it that allows us to query its state, mutate it, or perform some other transformation. It’s as if we animated some of the data in our program. Our data doesn’t just define the state of the program; it has a body too that you can interact with. For example, think of an object defining a neural network, such as a torch.nn.Module. Not only does it have data describing the parameters of the network, but it may also have methods to modify those parameters, run the model, save a checkpoint, and so on.
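To make that concrete without depending on any library, here is a small sketch of such an object. TinyLinear is a hypothetical stand-in, not a torch.nn.Module, and its method names are assumptions chosen for illustration.

```python
class TinyLinear:
    """Hypothetical one-neuron model: parameters bundled with the methods that use them."""

    def __init__(self, weight, bias):
        self._weight = weight
        self._bias = bias

    def forward(self, x):
        """Run the model on a single input."""
        return self._weight * x + self._bias

    def update(self, grad_weight, grad_bias, lr):
        """Mutate the parameters in place, as one gradient-descent step would."""
        self._weight -= lr * grad_weight
        self._bias -= lr * grad_bias

    def state_dict(self):
        """Return a copy of the parameters, for example to save a checkpoint."""
        return {"weight": self._weight, "bias": self._bias}


model = TinyLinear(weight=2.0, bias=0.5)
before = model.forward(3.0)   # 2.0 * 3.0 + 0.5
model.update(grad_weight=1.0, grad_bias=1.0, lr=0.5)
after = model.forward(3.0)    # same call, different answer: the hidden state changed
```

The caller never touches the parameters directly; everything goes through the object's interface, and the object persists between calls.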
Procedural
Next comes procedural programming. Of these three, it is the earliest. In procedural programming we can have many functions, but rather than operating on immutable data and producing brand new outputs, these functions typically operate on some shared mutable state that is passed around the program. For example, that data may be stored in a buffer.
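A minimal sketch of that pattern, using Python's standard array module as a stand-in for a raw buffer. The helper names scale_in_place and add_in_place are invented for this example.

```python
from array import array


def scale_in_place(buf, factor):
    """Overwrite every element of buf with its scaled value; nothing new is allocated."""
    for i in range(len(buf)):
        buf[i] *= factor


def add_in_place(dst, src):
    """Accumulate src into dst element by element, mutating dst."""
    for i in range(len(dst)):
        dst[i] += src[i]


# One shared buffer of 32-bit floats that is passed around and mutated.
acc = array("f", [1.0, 2.0, 3.0, 4.0])
bias = array("f", [0.5, 0.5, 0.5, 0.5])

scale_in_place(acc, 2.0)  # acc now holds 2, 4, 6, 8
add_in_place(acc, bias)   # acc now holds 2.5, 4.5, 6.5, 8.5
```

Neither function returns anything. The result of the computation is simply whatever is left in acc afterwards.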
What’s a buffer? A buffer is a region of memory used to temporarily hold data while it is being moved, processed, read, or written. These buffers can be very close in concept to the real registers and memory that the program is running on. In reality, hardware isn’t immutable. Immutability is an invented abstraction, and certainly a useful one, but it doesn’t describe what’s really going on in the bits of your device.
As such, there isn’t much abstraction between a buffer and a symbol in your program. That symbol is a reference pointing to a physical place in an electronic device, and your program is manipulating the data in that place. This is just the reality of what you’re doing, rather than thinking in terms of abstract objects or immutable data.
That said, just because this is the reality doesn’t mean you always want to work at that level. After all, you could write machine code directly if all you valued was realism. The reason to prefer procedural programming is that it helps you achieve a goal, such as performance.
What limits performance?
One of the primary issues that typically limits performance is how quickly your processor can access the data on which it operates, and we will see that many of the approaches to dealing with this fit into a procedural style.
A modern processor can do arithmetic very fast, much faster than it can pull data from main memory. A multiply-add can be issued roughly once per cycle; fetching a value from DRAM can cost hundreds of cycles. So if your kernel is forever reading all the way out to main memory for its next number, the arithmetic units spend most of their time sitting idle.
This is why memory is often the limiting factor, especially for memory-heavy calculations such as running an LLM. The usual answer is a hierarchy of storage. A few registers sit right next to the arithmetic units of your processor, then small, fast caches of increasing size (L1, L2, L3), then large, slow memory further out. On a GPU, HBM plays the role of main memory and on-chip shared memory that of a small, fast, software-managed cache. Each level closer to the processor is faster, but smaller. The closer your data sits to the processor when needed, the less time you waste. So writing a fast kernel is, to a large degree, the task of keeping the data the processor needs as close as possible, and reusing it as much as you can.
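A rough back-of-the-envelope model shows why reuse matters so much. The 200-cycle fetch cost below is an assumed stand-in for "hundreds of cycles", not a measurement, and the model deliberately ignores the overlap of fetches with compute that real hardware provides.

```python
def busy_fraction(fetch_cycles, ops_per_fetch, cycles_per_op=1):
    """Share of time spent on arithmetic if each fetch stalls the processor completely."""
    compute = ops_per_fetch * cycles_per_op
    return compute / (compute + fetch_cycles)


ASSUMED_DRAM_CYCLES = 200  # illustrative value only

# One multiply-add per value fetched from main memory.
no_reuse = busy_fraction(ASSUMED_DRAM_CYCLES, ops_per_fetch=1)

# Each fetched value is reused for 200 multiply-adds before the next fetch.
heavy_reuse = busy_fraction(ASSUMED_DRAM_CYCLES, ops_per_fetch=200)
```

Under these assumptions the arithmetic units are busy well under 1% of the time with no reuse, and half the time once each fetch feeds 200 operations.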
Three concepts are especially impactful here.
Data layout and packing
Before you can keep data close to the processor, you have to think about how it sits in memory in the first place. The fact is that memory doesn’t move one value at a time. When the processor reads from memory it pulls a cache line, typically 64 bytes, enough for sixteen 32-bit floats, not just the single number you asked for. So whenever you touch one value, its neighbours in memory are brought in as well.
This makes the order of your data matter significantly. If the next values your kernel needs are the ones sitting right beside the current one, they’ve already arrived and the fetch is paid for. If instead each value lives far from the last, a large stride apart, then every access pulls a fresh line and you use only one number out of the sixteen, wasting most of the bandwidth you just spent. This is an area in which too much abstraction can be a serious problem, as it can hide the details that most critically affect performance. By using a procedural style in which you clearly control the way the data in your buffers is accessed, you can more clearly reason about performance issues and, if necessary, rearrange the data.
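The effect of stride can be sketched by counting how many distinct cache lines a sequence of reads touches. This is a toy model, assuming 64-byte lines, 32-bit floats, and a buffer that starts on a line boundary; lines_touched is a name made up for this example.

```python
CACHE_LINE_BYTES = 64   # typical line size, as in the text
FLOAT_BYTES = 4         # one 32-bit float


def lines_touched(num_values, stride):
    """Count distinct cache lines hit when reading num_values floats, stride elements apart.

    Assumes the buffer starts on a cache-line boundary.
    """
    lines = set()
    for i in range(num_values):
        byte_offset = i * stride * FLOAT_BYTES
        lines.add(byte_offset // CACHE_LINE_BYTES)
    return len(lines)


contiguous = lines_touched(1024, stride=1)   # 64 lines: sixteen useful floats per line
strided = lines_touched(1024, stride=16)     # 1024 lines: one useful float per line
```

Both loops read the same number of values, but the strided one moves sixteen times as much data through the memory system.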
Tiling
For matrix multiplication specifically, one of the most useful locality techniques is tiled matrix multiplication. Where data layout is about spatial locality (putting the values you need next to each other in memory), tiling is about temporal locality (reusing each value many times while it is still close to the processor). When you write this on paper:
C[i, j] = sum over k of A[i, k] * B[k, j]
the straightforward way to compute it is to walk across a whole row of one matrix and a whole column of the other for every single output element. For anything but a tiny matrix, the data you need doesn’t all fit in cache, so as you go through the computation the same values get fetched from main memory again and again, each trip costing hundreds of cycles.
Tiling helps to fix this. Instead of working element by element, you break the matrices into sub-blocks (tiles) small enough to fit in fast cache. You load a tile from each matrix once, do every multiply-and-accumulate those tiles contribute to while they’re sitting in cache, and only then move on. The slow fetch is paid for once and amortised over a whole block of arithmetic. The accumulating output tile, meanwhile, lives in registers or cache and is updated in place as the partial results come in.
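The loop structure can be sketched as follows, on flat row-major lists. Python will not show the speedup, since it has no control over cache residency; what carries over to a C or CUDA kernel is the loop order. The names matmul_naive and matmul_tiled and the tile size of 4 are assumptions for illustration.

```python
def matmul_naive(a, b, n):
    """Reference: C[i, j] = sum over k of A[i, k] * B[k, j], on flat row-major lists."""
    c = [0.0] * (n * n)
    for i in range(n):
        for j in range(n):
            total = 0.0
            for k in range(n):
                total += a[i * n + k] * b[k * n + j]
            c[i * n + j] = total
    return c


def matmul_tiled(a, b, c, n, tile):
    """Accumulate A times B into the caller's buffer c, one tile-sized block at a time."""
    for i0 in range(0, n, tile):
        for j0 in range(0, n, tile):
            for k0 in range(0, n, tile):
                # All work for this pair of tiles happens before moving on.
                for i in range(i0, min(i0 + tile, n)):
                    for k in range(k0, min(k0 + tile, n)):
                        a_ik = a[i * n + k]  # read once, reused across the j loop
                        for j in range(j0, min(j0 + tile, n)):
                            c[i * n + j] += a_ik * b[k * n + j]


n = 6
a = [float(i % 7) for i in range(n * n)]
b = [float((3 * i) % 5) for i in range(n * n)]
c = [0.0] * (n * n)  # output buffer, updated in place
matmul_tiled(a, b, c, n, tile=4)
```

The min calls handle edge tiles when the tile size does not divide n, as with n = 6 and tile = 4 here. The output buffer must be zeroed by the caller because the kernel only accumulates into it.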
SIMD
Whereas layout and tiling are about getting data close to the processor, there’s another lever that becomes useful once the data is there: how much work can you do per instruction?
An ordinary instruction operates on one number at a time. Add this to that, multiply these two together. But the arithmetic hardware on a modern processor is wide. A single unit can work through a whole row of numbers in one go, if you feed it the right way. This is single instruction, multiple data (SIMD): one multiply applied across eight or sixteen values at once, in a single step.
To use this you have to think in terms of the vector registers the hardware uses. You load a contiguous chunk of a buffer into a wide register, issue one instruction across all its lanes, and write the result back. This can raise the number of operations per instruction by a factor of 8 to 16, depending on the data type: a 512-bit AVX-512 register holds sixteen 32-bit floats, but only eight 64-bit doubles. So a single fused multiply-add instruction performs sixteen float multiply-adds in one step, where scalar code would need sixteen separate instructions. That factor is a ceiling on the work done per instruction, not an end-to-end speedup; a kernel that cannot feed the lanes fast enough from memory will not reach it.
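The lane arithmetic can be worked through directly. This sketch only counts instructions; it does not execute any SIMD, and the helper names lanes and instructions_needed are made up for illustration.

```python
def lanes(register_bits, element_bits):
    """Number of values that fit side by side in one vector register."""
    return register_bits // element_bits


def instructions_needed(num_values, lane_count):
    """Instructions needed to cover num_values, rounding up for a partial final chunk."""
    return -(-num_values // lane_count)  # ceiling division


f32_lanes = lanes(512, 32)   # 16 floats per 512-bit register
f64_lanes = lanes(512, 64)   # 8 doubles per 512-bit register

scalar_ops = instructions_needed(1000, 1)          # 1000: one value per instruction
vector_ops = instructions_needed(1000, f32_lanes)  # 63: sixteen values per instruction
```

The final vector instruction covers only 8 of its 16 lanes, since 1000 is not a multiple of 16. A real kernel has to handle that tail, for example with masking or a scalar clean-up loop.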
Procedural inner loop
Looking back at the three levers: data layout is about controlling exactly where your bytes sit so that the right ones travel together. Tiling is about controlling when each block of data is loaded and how long it stays resident. SIMD is about controlling how many values a single instruction acts on. Each one is an exercise in manual control over data and the machine that moves it. But having that control means minimising the abstraction that comes with object-oriented programming, and also accepting that data on the device is mutable.
That said, this doesn’t mean that object-oriented or functional ideas aren’t useful when organising kernels. For example, once a clear interface and implementation are defined, kernels can be treated as functions from the outside, or we may dispatch at runtime to different kernel implementations, often generated via templating or overloading, based on some configuration of the program. JAX is a good example of the functional approach: you write pure functions over immutable arrays, and its jit compiler traces them into an XLA graph that is lowered into the kind of mutable, buffer-reusing procedural kernels described above. But all of that lies at or above the interface of the kernel, and the implementation below is largely in the procedural domain.