u/zk4x

Yet another tensor graph compiler

Hello, I've spent the last few years learning ML and this is the result - an ML library that spans the whole pytorch stack - from backends (fully complete CUDA, OpenCL, WGPU + partially implemented PTX, HIP and SPIRV) all the way to neural network modules in zyx-nn.

For lovers of python, python bindings are also available, and their API doesn't differ from the rust one. The python wheel is 4 MB. Zyx supports all pytorch ops across more hardware than pytorch. The current drawback is speed, but on small models it should not be that bad: on my 2060, the MNIST example takes 0.9ms per step, while compiled torch takes 0.7ms.

Zyx was inspired by tinygrad to use a minimal opset. Therefore zyx uses only these nodes in the graph: leaf, unary, binary, cast, expand, permute, pad, reshape, reduce.
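To make the node list concrete, here is a rough sketch of how such a graph node could be declared in Rust. All names here (`Node`, `UOp`, `BOp`, `ROp`, `DType`, `operands`) and the choice of individual unary/binary/reduce ops are my assumptions for illustration, not zyx's actual definitions. I also assume shapes, axes and padding live in side tables keyed by node id, which is one way to keep a node within the 16 bytes mentioned elsewhere in the post.

```rust
use std::mem::size_of;

// Hypothetical names throughout; this is not zyx's actual node definition.
type Id = u32; // index of an operand node inside the graph's node table

#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum UOp { Neg, Exp2, Log2, Sqrt, Reciprocal }

#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BOp { Add, Mul, Max, CmpLt }

#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ROp { Sum, Max }

#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DType { F16, F32, I32, Bool }

// The nine node kinds named in the post. Shapes, axes and padding are assumed
// to be stored in side tables keyed by node id, so every node stays small.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Node {
    Leaf,
    Unary { x: Id, op: UOp },
    Binary { x: Id, y: Id, op: BOp },
    Cast { x: Id, dtype: DType },
    Expand { x: Id },
    Permute { x: Id },
    Pad { x: Id },
    Reshape { x: Id },
    Reduce { x: Id, op: ROp },
}

impl Node {
    // Operands of a node; any graph traversal can be built on top of this.
    fn operands(&self) -> Vec<Id> {
        match *self {
            Node::Leaf => vec![],
            Node::Binary { x, y, .. } => vec![x, y],
            Node::Unary { x, .. }
            | Node::Cast { x, .. }
            | Node::Expand { x }
            | Node::Permute { x }
            | Node::Pad { x }
            | Node::Reshape { x }
            | Node::Reduce { x, .. } => vec![x],
        }
    }
}

fn main() {
    // sum(a + b) written as four nodes
    let graph = [
        Node::Leaf,
        Node::Leaf,
        Node::Binary { x: 0, y: 1, op: BOp::Add },
        Node::Reduce { x: 2, op: ROp::Sum },
    ];
    for (id, node) in graph.iter().enumerate() {
        println!("{id}: {node:?} <- {:?}", node.operands());
    }
    println!("size of one node: {} bytes", size_of::<Node>());
}
```

With only nine kinds, every pass over the graph is a single `match` with nine arms, which is a large part of the appeal of a small opset.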

Most of the design decisions have finally stabilized. I've spent years trying different lowering approaches and finally got to the point where it seems to click.

So what is the secret sauce? Zyx is fully DYNAMIC. The graph is built dynamically and supports all branching, but execution is lazy. Currently the graph is chopped into kernels by heuristics, and autotune searches over possible optimizations on each kernel. The novel part here is that all optimization passes are optional and can be stacked in any order. There is no complex lowering. Optimizations include register tiling, hierarchical local memory reduce, LICM, CSE, DCE, etc. Some, like tensor cores and load/store vectorization, are partially implemented, while others, like local memory tiling, are not started yet.
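To illustrate what "optional and stackable in any order" means, here is a toy sketch on a made-up straight-line IR. The IR, the `Pass` type and the three passes (`const_fold`, `cse`, `dce`) are my own simplification, not zyx's kernel IR. The point is only that each pass maps a valid kernel to a valid kernel with the same result, so any sequence of passes is legal and only the quality of the output differs.

```rust
use std::collections::HashMap;

// Toy IR (hypothetical): operands are indices of earlier instructions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum Op { Const(i64), Load(usize), Add(usize, usize), Mul(usize, usize) }

#[derive(Debug, Clone, PartialEq)]
struct Kernel { ops: Vec<Op>, out: usize }

// Every pass has the same shape, so passes can be stacked in any order.
type Pass = fn(&Kernel) -> Kernel;

fn remap(op: Op, map: &[usize]) -> Op {
    match op {
        Op::Add(a, b) => Op::Add(map[a], map[b]),
        Op::Mul(a, b) => Op::Mul(map[a], map[b]),
        other => other,
    }
}

// Constant folding: Add/Mul of two constants becomes a constant.
fn const_fold(k: &Kernel) -> Kernel {
    let mut ops = k.ops.clone();
    for i in 0..ops.len() {
        let folded = match ops[i] {
            Op::Add(a, b) => match (ops[a], ops[b]) {
                (Op::Const(x), Op::Const(y)) => Op::Const(x + y),
                _ => ops[i],
            },
            Op::Mul(a, b) => match (ops[a], ops[b]) {
                (Op::Const(x), Op::Const(y)) => Op::Const(x * y),
                _ => ops[i],
            },
            other => other,
        };
        ops[i] = folded;
    }
    Kernel { ops, out: k.out }
}

// Common subexpression elimination: identical instructions are emitted once.
fn cse(k: &Kernel) -> Kernel {
    let mut seen: HashMap<Op, usize> = HashMap::new();
    let mut map = Vec::with_capacity(k.ops.len()); // old index -> new index
    let mut ops = Vec::new();
    for &op in &k.ops {
        let op = remap(op, &map);
        let idx = *seen.entry(op).or_insert_with(|| { ops.push(op); ops.len() - 1 });
        map.push(idx);
    }
    Kernel { ops, out: map[k.out] }
}

// Dead code elimination: keep only instructions reachable from the output.
fn dce(k: &Kernel) -> Kernel {
    let mut live = vec![false; k.ops.len()];
    live[k.out] = true;
    for i in (0..k.ops.len()).rev() {
        if live[i] {
            if let Op::Add(a, b) | Op::Mul(a, b) = k.ops[i] {
                live[a] = true;
                live[b] = true;
            }
        }
    }
    let mut map = vec![usize::MAX; k.ops.len()];
    let mut ops = Vec::new();
    for (i, &op) in k.ops.iter().enumerate() {
        if live[i] {
            map[i] = ops.len();
            ops.push(remap(op, &map));
        }
    }
    Kernel { ops, out: map[k.out] }
}

fn apply(k: &Kernel, passes: &[Pass]) -> Kernel {
    passes.iter().fold(k.clone(), |acc, pass| pass(&acc))
}

fn eval(k: &Kernel, inputs: &[i64]) -> i64 {
    let mut vals: Vec<i64> = Vec::with_capacity(k.ops.len());
    for &op in &k.ops {
        let v = match op {
            Op::Const(c) => c,
            Op::Load(i) => inputs[i],
            Op::Add(a, b) => vals[a] + vals[b],
            Op::Mul(a, b) => vals[a] * vals[b],
        };
        vals.push(v);
    }
    vals[k.out]
}

// (x + x) * (2 + 3), with a duplicated load and one dead multiply.
fn example_kernel() -> Kernel {
    Kernel {
        ops: vec![
            Op::Load(0), Op::Load(0), Op::Add(0, 1),
            Op::Const(2), Op::Const(3), Op::Add(3, 4),
            Op::Mul(2, 5), Op::Mul(0, 0),
        ],
        out: 6,
    }
}

fn main() {
    let k = example_kernel();
    let orders = [
        [const_fold as Pass, cse, dce],
        [dce as Pass, cse, const_fold],
        [cse as Pass, const_fold, dce],
    ];
    for order in &orders {
        let opt = apply(&k, order);
        println!("{} ops, result {}", opt.ops.len(), eval(&opt, &[4]));
    }
}
```

In this toy, every order computes the same value, but running `dce` first leaves the folded-away constants behind, so some orders produce smaller kernels than others. That difference is what a search over pass sequences can exploit.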

Typical compilers like XLA, TVM and tinygrad instead use a multi-stage lowering pipeline where optimizations have to be applied in a certain order to be valid.

Zyx has probably the smallest (by number of ops) unified IR you can see in this space. It makes pattern matching more complex, but is fully expressive and complexity of optimizations is very low, while their orthogonality produces interesting results.

Why not use zyx

  1. Performance - currently mainly due to unfinished tensor core support and local memory tiling. The other part is non-ideal graph partitioning.
  2. If there is some other reason, please tell me. Thanks to the stackability of ops and the small codebase, I believe zyx can improve faster than others and become the most ergonomic library, support the most hardware, and be the most correct in terms of numerical stability guarantees across dtypes and kernels.

Next steps

Performance, performance, performance. Other than adding those few obligatory optimization passes and improving autotune with a better cost function (btw. with the cost function and fast IR, zyx can search through tens of thousands of kernel variants per second per core - yes, autotune is multithreaded), the main next step is a rewrite of the graph splitting part. It requires no user API changes and is partially finished.
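For readers who have not seen a cost-model-driven autotuner, the general shape is roughly the following. This is a sketch under my own assumptions: the `Variant` knobs, the made-up analytical `cost` formula and the `autotune` function are hypothetical and are not taken from zyx. The idea it shows is that scoring variants with a cheap cost function, instead of compiling and timing each one, makes the search embarrassingly parallel.

```rust
use std::thread;

// Hypothetical tuning knobs for one kernel; the real search space is not
// described in the post.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Variant { tile_m: u32, tile_n: u32, unroll: u32 }

// Made-up analytical cost model: reward data reuse and unrolling, punish
// register pressure above an assumed budget of 64 registers per thread.
fn cost(v: Variant) -> f64 {
    let regs = (v.tile_m * v.tile_n + v.tile_m + v.tile_n) as f64;
    let reuse = (v.tile_m * v.tile_n) as f64 / (v.tile_m + v.tile_n) as f64;
    let spill_penalty = if regs > 64.0 { (regs - 64.0) * 0.5 } else { 0.0 };
    let loop_overhead = 1.0 / v.unroll as f64;
    1.0 / reuse + spill_penalty + loop_overhead
}

// Enumerate the whole (small) search space.
fn variants() -> Vec<Variant> {
    let sizes = [1, 2, 4, 8, 16];
    let unrolls = [1, 2, 4, 8];
    let mut out = Vec::new();
    for &tile_m in &sizes {
        for &tile_n in &sizes {
            for &unroll in &unrolls {
                out.push(Variant { tile_m, tile_n, unroll });
            }
        }
    }
    out
}

// Split the candidates into chunks, score each chunk on its own thread,
// then reduce the per-thread winners to a single best variant.
fn autotune(candidates: &[Variant], n_threads: usize) -> Option<Variant> {
    let n = n_threads.max(1);
    let chunk = ((candidates.len() + n - 1) / n).max(1);
    thread::scope(|s| {
        let handles: Vec<_> = candidates
            .chunks(chunk)
            .map(|part| {
                s.spawn(move || {
                    part.iter().copied().min_by(|a, b| cost(*a).total_cmp(&cost(*b)))
                })
            })
            .collect();
        handles
            .into_iter()
            .filter_map(|h| h.join().unwrap())
            .min_by(|a, b| cost(*a).total_cmp(&cost(*b)))
    })
}

fn main() {
    let candidates = variants();
    let best = autotune(&candidates, 4).expect("search space is not empty");
    println!(
        "{} variants searched, best: {:?} (cost {:.3})",
        candidates.len(), best, cost(best)
    );
}
```

A real autotuner would presumably validate the winner with an actual timed run; the sketch stops at the cost model.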

The new kernelizer will generate a hypergraph of both kernel fusion strategies AND device allocation strategies. Zyx already has automatic parallel pipelining with heuristics, but after the rewrite, it'll have proper search for this.

The part I am most excited about: since the graph in zyx is made up of such simple nodes, I can write a pattern matcher that will map parts of the graph to kernels in existing stacks - cuDNN, oneDNN, NPU kernels, etc. With hypergraph search, zyx will be able to select the best path among these, and custom zyx kernels will fill in the blanks. Expect this to take a few months.

Minimalism

Zyx is minimal not only in its graph, but also in its dependencies. All of zyx + dependencies is <50k LOC of pure Rust. I also wrote simple onnx bindings. The goal is to make zyx both the tiny runtime that runs all models on every hardware correctly, albeit with varying performance, and the nicest library to use, for the following reasons:

  1. I carefully crafted every user facing function to be as intuitive as possible.
  2. Zyx will not wish you good luck like pytorch if you mutate tensors before backprop; zyx tensors are immutable.
  3. Zyx will not run out of memory if you fill your VRAM; it falls back to RAM.
  4. Zyx will not take 30s to import in python; it takes 10ms.
  5. Zyx won't tell you that some ops are unsupported for some dtypes; all dtypes are supported with all ops, except for ops that only make mathematical sense on float dtypes.
  6. Zyx won't complain that the graph is too large, too dynamic or too complex; a single node is 16 bytes, so huge graphs run just fine.
  7. Zyx will fuse your kernels and won't run out of recompilation passes; recompilation happens at kernel level, not graph level.
  8. Zyx will fuse ALL graphs, no matter how complex.
  9. It may be a bit slow, but zyx will keep running your code correctly even on the oldest of hardware (e.g. GT 710, RX 480, or even CPUs without AVX).

I wish you pleasant experimentation.

https://github.com/zk4x/zyx

https://crates.io/crates/zyx

https://crates.io/crates/zyx-nn

https://crates.io/crates/zyx-optim

https://pypi.org/project/zyx-py

u/zk4x — 4 days ago
▲ 18 r/rust

zyx v0.15.3 released - ML from scratch in Rust

https://github.com/zk4x/zyx

https://crates.io/crates/zyx

https://crates.io/crates/zyx-nn

https://crates.io/crates/zyx-optim

https://pypi.org/project/zyx-py

Hello, I've spent the last few years learning Rust and ML and this is the result - an ML library, fully written in Rust, that spans the whole pytorch stack - from backends (fully complete CUDA, OpenCL, WGPU + partially implemented PTX, HIP and SPIRV) all the way to neural network modules in zyx-nn.

For lovers of python, python bindings are also available, and their API doesn't differ from the rust one. I accomplished this by removing all generics from the user facing API. The python wheel is 4 MB. Zyx supports all pytorch ops across more hardware than pytorch. The current drawback is speed, but on small models it should not be that bad: on my 2060, the MNIST example takes 0.9ms per step, while compiled torch takes 0.7ms.

Zyx was inspired by tinygrad to use a minimal opset. Therefore zyx uses only these nodes in the graph: leaf, unary, binary, cast, expand, permute, pad, reshape, reduce. By stacking these ops, zyx supports ALL of pytorch.
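As an illustration of what "stacking these ops" looks like, here is a toy sketch that builds softmax purely out of unary, binary, reduce and expand nodes. It works on 1-D `f32` vectors only, and the `Graph` type, the op names and the naive recursive `eval` are all hypothetical and not zyx's API. It only shows the composition idea: `sub`, `div` and `softmax` never need nodes of their own.

```rust
// Hypothetical toy graph over 1-D f32 tensors; not zyx's API.
#[derive(Debug, Clone, Copy)]
enum UOp { Neg, Exp, Reciprocal }

#[derive(Debug, Clone, Copy)]
enum BOp { Add, Mul }

#[derive(Debug, Clone, Copy)]
enum ROp { Sum, Max }

#[derive(Debug, Clone)]
enum Node {
    Leaf(Vec<f32>),
    Unary(usize, UOp),
    Binary(usize, usize, BOp),
    Reduce(usize, ROp),   // reduces the whole vector to length 1
    Expand(usize, usize), // broadcasts a length-1 vector to length n
}

#[derive(Default)]
struct Graph { nodes: Vec<Node> }

impl Graph {
    fn push(&mut self, n: Node) -> usize {
        self.nodes.push(n);
        self.nodes.len() - 1
    }

    fn leaf(&mut self, data: &[f32]) -> usize {
        self.push(Node::Leaf(data.to_vec()))
    }

    // Composite ops, each expressed only through the primitive nodes above.
    fn sub(&mut self, a: usize, b: usize) -> usize {
        let nb = self.push(Node::Unary(b, UOp::Neg));
        self.push(Node::Binary(a, nb, BOp::Add))
    }

    fn div(&mut self, a: usize, b: usize) -> usize {
        let rb = self.push(Node::Unary(b, UOp::Reciprocal));
        self.push(Node::Binary(a, rb, BOp::Mul))
    }

    // softmax(x) = exp(x - max(x)) / sum(exp(x - max(x))) for a vector of length n
    fn softmax(&mut self, x: usize, n: usize) -> usize {
        let m = self.push(Node::Reduce(x, ROp::Max));
        let m = self.push(Node::Expand(m, n));
        let shifted = self.sub(x, m);
        let e = self.push(Node::Unary(shifted, UOp::Exp));
        let s = self.push(Node::Reduce(e, ROp::Sum));
        let s = self.push(Node::Expand(s, n));
        self.div(e, s)
    }

    // Naive recursive interpreter; shared nodes are recomputed, fine for a toy.
    fn eval(&self, id: usize) -> Vec<f32> {
        match &self.nodes[id] {
            Node::Leaf(d) => d.clone(),
            Node::Unary(x, op) => {
                let a = self.eval(*x);
                a.into_iter()
                    .map(|v| match op {
                        UOp::Neg => -v,
                        UOp::Exp => v.exp(),
                        UOp::Reciprocal => 1.0 / v,
                    })
                    .collect()
            }
            Node::Binary(x, y, op) => {
                let a = self.eval(*x);
                let b = self.eval(*y);
                let out: Vec<f32> = a
                    .iter()
                    .zip(&b)
                    .map(|(p, q)| match op {
                        BOp::Add => p + q,
                        BOp::Mul => p * q,
                    })
                    .collect();
                out
            }
            Node::Reduce(x, op) => {
                let a = self.eval(*x);
                let r = match op {
                    ROp::Sum => a.iter().sum::<f32>(),
                    ROp::Max => a.iter().copied().fold(f32::NEG_INFINITY, f32::max),
                };
                vec![r]
            }
            Node::Expand(x, n) => vec![self.eval(*x)[0]; *n],
        }
    }
}

fn main() {
    let mut g = Graph::default();
    let x = g.leaf(&[1.0, 2.0, 3.0]);
    let y = g.softmax(x, 3);
    println!("softmax = {:?} using {} nodes", g.eval(y), g.nodes.len());
}
```

The same trick extends to multi-dimensional ops: a matmul, for example, can be written as reshape + expand on both inputs, an elementwise multiply, and a sum reduce over the shared axis.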

I am writing this post because zyx's API has now become mature. I haven't had the need to change it in about a year. All the remaining work will be on getting zyx up to speed. On this internal front, most of the design decisions have also stabilized: I've spent years trying different approaches and finally got to the point where it seems to click.

The secret sauce

So what is the secret sauce? Zyx is fully DYNAMIC. The graph is built dynamically and supports all branching, but execution is lazy. Currently the graph is chopped into kernels by heuristics, and autotune searches over possible optimizations on each kernel. The novel part here is that all optimization passes are optional and can be stacked in any order. There is no complex lowering. Optimizations include register tiling, hierarchical local memory reduce, LICM, CSE, DCE, etc. Some, like tensor cores and load/store vectorization, are partially implemented, while others, like local memory tiling, are not started yet.
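Here is a small sketch of what "built dynamically, executed lazily" means in practice. It is a scalar toy with hypothetical names (`Graph`, `realize`, `build`), not zyx's implementation. Ordinary host-language loops and branches decide which nodes get recorded, nothing is computed while recording, and `realize` only evaluates nodes reachable from the requested output.

```rust
// Hypothetical lazy scalar graph; not zyx's API.
#[derive(Debug, Clone, Copy)]
enum Node { Leaf(f64), Add(usize, usize), Mul(usize, usize) }

#[derive(Default)]
struct Graph {
    nodes: Vec<Node>,
    computed: usize, // number of Add/Mul nodes actually evaluated so far
}

impl Graph {
    // Recording functions: they only append a node, no arithmetic happens here.
    fn leaf(&mut self, v: f64) -> usize {
        self.nodes.push(Node::Leaf(v));
        self.nodes.len() - 1
    }
    fn add(&mut self, a: usize, b: usize) -> usize {
        self.nodes.push(Node::Add(a, b));
        self.nodes.len() - 1
    }
    fn mul(&mut self, a: usize, b: usize) -> usize {
        self.nodes.push(Node::Mul(a, b));
        self.nodes.len() - 1
    }

    // Execution starts here and touches only nodes reachable from `id`.
    fn realize(&mut self, id: usize) -> f64 {
        let mut cache = vec![None; self.nodes.len()];
        self.eval(id, &mut cache)
    }

    fn eval(&mut self, id: usize, cache: &mut Vec<Option<f64>>) -> f64 {
        if let Some(v) = cache[id] {
            return v;
        }
        let node = self.nodes[id];
        let v = match node {
            Node::Leaf(v) => v,
            Node::Add(a, b) => {
                self.computed += 1;
                self.eval(a, cache) + self.eval(b, cache)
            }
            Node::Mul(a, b) => {
                self.computed += 1;
                self.eval(a, cache) * self.eval(b, cache)
            }
        };
        cache[id] = Some(v);
        v
    }
}

// The graph's shape depends on runtime values: a plain loop and a plain branch.
fn build(g: &mut Graph, steps: usize, use_square: bool) -> usize {
    let mut x = g.leaf(1.5);
    let _unused = g.mul(x, x); // recorded, but never needed by the result
    for _ in 0..steps {
        x = if use_square { g.mul(x, x) } else { g.add(x, x) };
    }
    x
}

fn main() {
    let mut g = Graph::default();
    let out = build(&mut g, 3, false);
    println!("recorded {} nodes, computed {} so far", g.nodes.len(), g.computed);
    let value = g.realize(out);
    println!("value = {value}, computed {} ops", g.computed);
}
```

Because the whole graph is known before anything runs, a lazy runtime is free to decide at `realize` time how to split the work into kernels, which is where the heuristics and autotune described above would come in.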

Zyx has probably the smallest (by number of ops) unified IR you can see in this space. It makes pattern matching more complex, but is fully expressive and complexity of optimizations is very low, while their orthogonality produces interesting results.

Zyx has autograd and the standard machinery you can expect from any other ML library. The op support in particular is possibly the broadest among rust libraries.

Why not use zyx

  1. Performance - currently mainly due to unfinished tensor core support and local memory tiling. The other part is non-ideal graph partitioning.
  2. If there is some other reason, please tell me. Thanks to the stackability of ops and the small codebase, I believe zyx can improve faster than others and become the most ergonomic library, support the most hardware, and be the most correct in terms of numerical stability guarantees across dtypes and kernels.

Next steps

Performance, performance, performance. Other than adding those few obligatory optimization passes and improving autotune with a better cost function (btw. with the cost function and fast IR, zyx can search through tens of thousands of kernel variants per second per core - yes, autotune is multithreaded), the main next step is a rewrite of the graph splitting part. It requires no user API changes and is partially finished.

The new kernelizer will generate a hypergraph of both kernel fusion strategies AND device allocation strategies. Zyx already has automatic parallel pipelining with heuristics, but after the rewrite, it'll have proper search for this.

The part I am most excited about: since the graph in zyx is made up of such simple nodes, I can write a pattern matcher that will map parts of the graph to kernels in existing stacks - cuDNN, oneDNN, NPU kernels, etc. With hypergraph search, zyx will be able to select the best path among these, and custom zyx kernels will fill in the blanks. Expect this to take a few months.

Minimalism

Zyx is minimal not only in its graph, but also in its dependencies. All of zyx + dependencies is <50k LOC of pure Rust. I also wrote simple onnx bindings. The goal is to make zyx the tiny runtime that runs all models on every hardware correctly, albeit with varying performance.

Conclusion

Whether you are a hobbyist or a compiler enthusiast, I hope you give zyx a chance. Try the rust API, try the python bindings. I carefully crafted every user facing function to be as intuitive as possible. Zyx will not wish you good luck like pytorch if you mutate tensors before backprop. Zyx will not run out of memory if you fill your VRAM. Zyx will not take 30s to import in python (it takes 10ms). Zyx won't tell you that some ops are unsupported for some dtypes. Zyx won't complain that the graph is too large, too dynamic or too complex. Zyx will fuse your kernels and won't run out of recompilation passes. Zyx will fuse ALL graphs, no matter how complex. And it may be a bit slow, but zyx will keep running your code correctly even on the oldest of hardware.

With that, I am leaving you and wishing you pleasant experimentation.

u/zk4x — 6 days ago