Speed Up PyTorch With Custom Kernels. But It Gets Progressively Darker


We’ll begin with torch.compile, move on to writing a custom Triton kernel, and finally dive into designing a CUDA kernel


PyTorch offers remarkable flexibility, allowing you to code complex GPU-accelerated operations in a matter of seconds. However, this convenience comes at a cost. In its default mode, PyTorch executes your operations one by one, exactly as written, which often results in suboptimal performance. This translates into slower model training, which lengthens the iteration cycle of your experiments, slows your team down, increases costs, and so on.

In this post, I’ll explore three strategies for accelerating your PyTorch operations. Each method uses softmax as our “Hello World” demonstration, but you can swap it with any function you like, and the discussed methods would still apply.

We’ll begin with torch.compile, move on to writing a custom Triton kernel, and finally dive into designing a CUDA kernel.

So, this post may get complicated, but bear with me.

torch.compile — A Quick Way to Boost Performance

💥 “Wait, you just turn on a single function call and it speeds up your code? That’s it? Sounds too good to be true.”

 — Yes.

torch.compile is a relatively new PyTorch API that uses runtime graph capture and kernel fusion under the hood. With one decorator, you can often see speed improvements without significant changes to your code.

Put simply, it can speed up calculations by merging several operations into one GPU function, which removes the overhead of separate GPU calls. Even better, it can optimize a chain of operations by replacing it with a single equivalent one!

Such optimizations are not possible in the regular PyTorch execution mode (eager), as it executes operations exactly as they are called in the code.

Softmax Implementation with torch.compile

Below is a simple example showing how to implement and compile a softmax function using torch.compile. Use it in your model’s forward pass, and your code (hopefully) runs faster.
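The original embedded snippet is not reproduced here, so what follows is a minimal sketch of what such an example might look like, not the author's exact code. The names naive_softmax and compiled_softmax are made up for illustration, and the softmax is taken over the last dimension.

```python
import torch


def naive_softmax(x: torch.Tensor) -> torch.Tensor:
    # Subtract the row-wise max first so exp() cannot overflow.
    x_max = x.max(dim=-1, keepdim=True).values
    exp = torch.exp(x - x_max)
    return exp / exp.sum(dim=-1, keepdim=True)


# Equivalent to decorating naive_softmax with @torch.compile.
# Compilation is lazy: it happens on the first call, not on this line.
compiled_softmax = torch.compile(naive_softmax)

# Usage (the first call is slow because it compiles):
# y = compiled_softmax(torch.randn(4096, 1024, device="cuda"))
```

Keeping the eager function around next to the compiled one makes it easy to compare outputs and timings between the two.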

❗ Note that you’ll have bigger speedups if you compile the whole model and not just one operation.

Pros:

- One line to enable the compiler.
- No black magic rituals needed (except for the dynamic shapes maybe).

Cons:

- The first pass can be slower while it compiles; afterwards, it picks up speed.
- Doesn’t always produce dramatic speed-ups for all models and can occasionally break if your code is too creative.
- Still has problems with handling dynamic shapes.

😡 Dynamic shapes compilation mode is needed when input shapes change and we don’t want to recompile the code for each specific size.
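As a small sketch of the dynamic shapes mode mentioned above: torch.compile accepts a dynamic argument, and passing dynamic=True asks the compiler to generate shape-generic code up front instead of specializing on the first input size it sees. The function name row_softmax is only an illustration.

```python
import torch


def row_softmax(x: torch.Tensor) -> torch.Tensor:
    return torch.softmax(x, dim=-1)


# dynamic=True requests dynamic shape tracing, which aims to avoid
# a recompilation every time the input size changes.
dynamic_softmax = torch.compile(row_softmax, dynamic=True)

# Usage: inputs of different lengths should be able to reuse the compiled code.
# for n_cols in (128, 300, 1024):
#     dynamic_softmax(torch.randn(8, n_cols))
```

Whether this actually avoids recompilation depends on the model and the PyTorch version, so it is worth checking on your own workload.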

Debugging this is a topic for a whole separate article.

Triton Code — Write GPU Kernels With Python Breeze

Why Use Triton?

Triton is a language that compiles to efficient GPU kernels while letting you write Pythonic code. It’s used under the hood of PyTorch’s dynamo/inductor stack, but you can also write your own custom ops! For many matrix/tensor operations — like softmax — you can get huge speed-ups. Because why wait for official PyTorch kernels when you can write your own?

Softmax in Triton

Here’s a minimal snippet that shows how we might do a naive softmax forward in Triton. I’ll keep it short and sweet for demonstration. In a real project, you’d likely do more advanced tiling and block management.

💥 This may look complicated, but you just need to get familiar with Triton, and it will start making sense.

Check out their guides.
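The embedded kernel is not reproduced here. As a stand-in, here is a minimal sketch in the spirit of the fused softmax from the official Triton tutorials, not the author's kernel. It assumes a 2D float tensor on a CUDA device and a row short enough to fit into a single block; softmax_kernel and triton_softmax are illustrative names.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def softmax_kernel(out_ptr, in_ptr, in_row_stride, out_row_stride, n_cols,
                   BLOCK_SIZE: tl.constexpr):
    # One program instance handles one row.
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols
    # Out-of-range lanes read -inf, so they contribute exp(-inf) = 0.
    x = tl.load(in_ptr + row * in_row_stride + cols, mask=mask, other=-float("inf"))
    x = x.to(tl.float32)
    x = x - tl.max(x, axis=0)
    num = tl.exp(x)
    den = tl.sum(num, axis=0)
    tl.store(out_ptr + row * out_row_stride + cols, num / den, mask=mask)


def triton_softmax(x: torch.Tensor) -> torch.Tensor:
    assert x.is_cuda and x.ndim == 2, "sketch assumes a 2D CUDA tensor"
    x = x.contiguous()
    n_rows, n_cols = x.shape
    out = torch.empty_like(x)
    # tl.arange needs a power-of-two length, so round the row length up.
    block_size = triton.next_power_of_2(n_cols)
    softmax_kernel[(n_rows,)](out, x, x.stride(0), out.stride(0), n_cols,
                              BLOCK_SIZE=block_size)
    return out
```

Loading the whole row into one block keeps the sketch short; very long rows would need a loop over column blocks instead.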

Indeed, it looks complicated. But the core of the algorithm is summarized in a few lines.
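To make the split between "core" and "data management" concrete, here is a plain PyTorch emulation of what one kernel instance might do for a single row. It is an illustration under assumptions, not the author's excerpt: softmax_row_emulated and block_size are hypothetical names, and block_size stands for the padded, power-of-two block a kernel would work on.

```python
import torch


def softmax_row_emulated(row: torch.Tensor, block_size: int) -> torch.Tensor:
    n_cols = row.numel()

    # Data management: offsets, mask, and a masked "load" padded with -inf.
    offsets = torch.arange(block_size)
    mask = offsets < n_cols
    x = torch.full((block_size,), float("-inf"), dtype=row.dtype)
    x[mask] = row

    # Core of the algorithm: three lines of numerically stable softmax.
    x = x - x.max()
    num = torch.exp(x)
    out = num / num.sum()

    # Data management again: the masked "store" keeps only the valid lanes.
    return out[mask]
```

The padded lanes hold -inf, so they turn into zeros after exp() and do not affect the sum.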


Everything else is just data management and side-hustle.

If we benchmark across different data lengths, we’ll see that we match the performance of torch.nn.functional.softmax (which is a highly optimized kernel!) and dramatically outperform the naive torch implementation.

Benchmarking | Image by the author
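A comparison of this kind can be sketched with torch.utils.benchmark. This is not the author's benchmark script: the function names, tensor sizes, and the choice of the median as the reported statistic are assumptions made for illustration, and a custom kernel would be added as a third entry.

```python
import torch
import torch.nn.functional as F
from torch.utils import benchmark


def naive_softmax(x: torch.Tensor) -> torch.Tensor:
    x_max = x.max(dim=-1, keepdim=True).values
    exp = torch.exp(x - x_max)
    return exp / exp.sum(dim=-1, keepdim=True)


def median_ms(fn, x: torch.Tensor, min_run_time: float = 0.2) -> float:
    # Timer takes care of warm-up and CUDA synchronization.
    timer = benchmark.Timer(stmt="fn(x)", globals={"fn": fn, "x": x})
    return timer.blocked_autorange(min_run_time=min_run_time).median * 1e3


def run_benchmark(lengths=(256, 1024, 4096), n_rows=4096, min_run_time=0.2):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    results = {}
    for n_cols in lengths:
        x = torch.randn(n_rows, n_cols, device=device)
        results[n_cols] = (
            median_ms(naive_softmax, x, min_run_time),
            median_ms(lambda t: F.softmax(t, dim=-1), x, min_run_time),
        )
    return results  # {row length: (naive ms, F.softmax ms)}
```

Calling run_benchmark() returns a dictionary that can be printed or plotted against row length.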

You can find the full code for the kernel and the benchmark in the following GitHub file.

kernels/src/softmax/kernel.py at main · alexdremov/kernels

Pros:

- Potentially huge speed-ups by fusing ops and optimizing memory access patterns.
- More control than torch.compile.
- Easy to write efficient code (we matched torch implementation!)
- Easy to write inefficient code (if you don’t know what you’re doing).

Cons:

- You’re now the kernel developer, which means debugging if something goes sideways. Which is tough. Really.
- If you go further with custom backward passes, you might need a second coffee… or more. That’s because torch autograd cannot differentiate through a Triton kernel, so you will need to define the backward pass yourself.
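Defining the backward pass yourself usually means wrapping the kernel in a torch.autograd.Function. In the sketch below, torch.softmax stands in for the custom forward kernel so that the example runs anywhere; the class name SoftmaxFn is an assumption. The backward formula is the standard softmax gradient, dx = y * (g - sum(g * y)).

```python
import torch


class SoftmaxFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        # Placeholder: a real version would launch the custom kernel here.
        y = torch.softmax(x, dim=-1)
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> torch.Tensor:
        (y,) = ctx.saved_tensors
        # Softmax gradient: dx = y * (g - sum(g * y)) along the softmax axis.
        # A real version might implement this as a second custom kernel.
        return y * (grad_out - (grad_out * y).sum(dim=-1, keepdim=True))


# Usage: call it like a function; autograd picks up the custom backward.
# y = SoftmaxFn.apply(x)
```

Checking the result against autograd on the built-in softmax is a cheap way to catch mistakes in a hand-written backward.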

Pure CUDA (a.k.a. Going Hardcore)

Sometimes even Triton won’t cut it, or you just enjoy living on the edge. In that case, you can write a custom CUDA kernel in C++, compile it, and tie it into PyTorch via a custom extension. Projects like [this fused CUDA softmax reference] show how people build specialized kernels for maximum speed.

Softmax in Custom CUDA

You’ll typically have a setup.py that compiles a .cu or .cpp file and exposes a Python function as an extension.
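For orientation only, such a setup.py might look like the sketch below, built on PyTorch's torch.utils.cpp_extension helpers. It covers the build scaffolding and nothing else: the package name softmax_cuda and the source files softmax.cpp and softmax_kernel.cu are hypothetical placeholders for your own binding and kernel code.

```python
# setup.py (sketch; the package and file names are hypothetical)
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="softmax_cuda",
    ext_modules=[
        CUDAExtension(
            name="softmax_cuda",
            # softmax.cpp would hold the Python bindings,
            # softmax_kernel.cu the CUDA kernel itself.
            sources=["softmax.cpp", "softmax_kernel.cu"],
        )
    ],
    # BuildExtension supplies the compiler flags needed for mixed C++/CUDA builds.
    cmdclass={"build_ext": BuildExtension},
)
```

After installing the package (for example with pip install .), the compiled functions would be importable from Python under the chosen module name.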

GitHub – fattorib/CudaSoftmax: Softmax CUDA kernel 🙂

I will not provide the code for this method in this post, and that fact speaks for itself. This approach is quite complicated, requires good justification, and is usually the last thing you should try.

It’s very easy to write inefficient, buggy, unsafe code.

Pros:

- Maximum control. “If you want something done right, do it yourself.”
- Potential for the fastest possible kernel if well-optimized.

Cons:

- Requires deep CUDA understanding.
- Memory management, block sizes, shared memory: those are hard!
- Maintenance overhead can be extremely high.

Conclusion

When it comes to speeding up PyTorch operations, you can choose from progressively more intricate methods:

- torch.compile: Minimal code changes needed.
- Triton Kernel: More control over kernel behaviour, still quite easy coding.
- Pure CUDA: Maximum optimisation potential, but a lot higher complexity.

If you’re looking for the simplest improvement, start with torch.compile. If that’s insufficient, explore Triton. For advanced users, writing a custom CUDA kernel can yield further gains, though it demands deep GPU programming skills.

Subscribe to not miss posts about other optimisations and useful deep learning techniques!

References

- Compiling the optimizer with torch.compile (PyTorch Docs)
- How should I use torch.compile properly? (PyTorch discussion)
- Using User-Defined Triton Kernels with torch.compile (PyTorch Docs)
- Torch.compile with custom Triton kernel (PyTorch discussion)
- GitHub: fattorib/CudaSoftmax

Choose the path that fits your project’s needs and your comfort level. Good luck optimizing!

The story was originally published at alexdremov.me

Speed Up PyTorch With Custom Kernels. But It Gets Progressively Darker was originally published in Towards Data Science on Medium.
