[P] I replaced Dot-Product Attention with distance-based RBF-Attention (so you don't have to...)
I recently asked myself: what would happen if we replaced the standard dot product in self-attention with a distance-based similarity, e.g. an RBF kernel?
Standard dot-product attention has this quirk where a key vector can "bully" the softmax simply by having a massive magnitude. A random key that points in roughly the right direction but is huge will easily outscore a perfectly aligned but shorter key. Distance-based (RBF) attention could fix this. To get a high attention score, Q and K actually have to be close to each other in high-dimensional space. You can't cheat by just being large.
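Here is a toy illustration of the "bullying" effect, using 2-D vectors and no 1/sqrt(d) scaling (both simplifications for readability). A perfectly aligned key competes against a much longer key that only points roughly the same way (cosine 0.8). The RBF logit used here is the plain negative squared distance. This is a minimal sketch, not code from the repo:

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def softmax(logits):
    # Subtract the max for numerical stability (softmax is shift-invariant).
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def dot_logits(q, keys):
    # Plain dot-product scores (no 1/sqrt(d) scaling in this toy example).
    return [dot(q, k) for k in keys]


def rbf_logits(q, keys):
    # Negative squared Euclidean distance: only keys close to q score high.
    return [-sum((x - y) ** 2 for x, y in zip(q, k)) for k in keys]


q = [1.0, 0.0]
aligned = [1.0, 0.0]  # points exactly where q points, same length
bully = [8.0, 6.0]    # roughly the right direction (cosine 0.8), but huge
keys = [aligned, bully]

p_dot = softmax(dot_logits(q, keys))
p_rbf = softmax(rbf_logits(q, keys))
print("dot-product:", [round(p, 4) for p in p_dot])
print("RBF:        ", [round(p, 4) for p in p_rbf])
```

Under the dot product, the long key takes almost all of the attention mass. Under the distance-based score, it is far away from q and gets essentially nothing.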
I thought this would be a quick 10-minute PyTorch experiment, but it was a reminder of how deeply the dot product is hardcoded into the entire ML stack. Changing one core operation triggered a massive domino effect. :D
Here is the chain of things that broke, and how I had to fix them just to get a model to train reasonably well:
Instant OOMs: If you naively compute pairwise Euclidean distances with torch.cdist (without the matmul trick), it materializes the full N x N distance matrix in memory, and you will instantly OOM at any decent context length. Luckily, a little high-school algebra helps: expanding the negative squared distance gives -||Q - K||^2 = -||Q||^2 - ||K||^2 + 2(Q · K). Since the softmax is shift-invariant and -||Q||^2 is the same constant for every key that a given query attends to, we can throw it in the trash. You're left with 2(Q · K) - ||K||^2. In other words, RBF attention is mathematically just standard dot-product attention with a built-in squared-L2 penalty on the keys.
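A quick numerical check of that algebra, with random vectors standing in for one query and a handful of keys. The expanded logits differ from the true negative squared distances by exactly ||q||^2, so both produce the same softmax. This is a minimal pure-Python sketch; the function names are hypothetical:

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def rbf_logits_naive(q, keys):
    # -||q - k||^2 computed from explicit pairwise differences.
    return [-sum((x - y) ** 2 for x, y in zip(q, k)) for k in keys]


def rbf_logits_expanded(q, keys):
    # 2(q . k) - ||k||^2: the -||q||^2 term is dropped because it is the same
    # for every key of this query, and softmax ignores constant shifts.
    return [2 * dot(q, k) - dot(k, k) for k in keys]


random.seed(0)
d, n = 8, 5
q = [random.gauss(0.0, 1.0) for _ in range(d)]
keys = [[random.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]

naive = rbf_logits_naive(q, keys)
expanded = rbf_logits_expanded(q, keys)

# The two logit vectors differ by exactly the constant ||q||^2 ...
shifts = [e - nv for e, nv in zip(expanded, naive)]
# ... so their softmax distributions are identical.
p_naive = softmax(naive)
p_expanded = softmax(expanded)
```

In batched form, the expanded version is one Q K^T matmul plus a broadcast subtraction of the per-key squared norms, which is exactly the shape of computation that attention kernels are already built around.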
Custom kernel: Even with that math trick, PyTorch's native scaled dot-product attention (SDPA) doesn't let you subtract an arbitrary key-norm penalty inside its fused loop. You can hack around it by padding your tensors with dummy dimensions, but that's clunky and moves unnecessary memory, so I gave up and wrote a custom Triton kernel. It mirrors the tiling logic of FlashAttention but computes the squared L2 norms of the keys on the fly in SRAM and subtracts them right before the softmax, so it only needs linear memory.
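To make both ideas concrete, here is a pure-Python sketch for a single query, assuming the 2(q · k) - ||k||^2 logit from above. `rbf_attention_tiled` uses a FlashAttention-style online softmax (running max, running denominator, running weighted sum) and computes each tile's key-norm penalty on the fly. `pad_for_plain_dot_product` is one hypothetical way to do the "dummy dimension" hack. This is not the repo's Triton kernel; a real kernel processes blocks of queries in parallel in SRAM, and the padded version would still have to account for SDPA's default 1/sqrt(d) scaling.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def rbf_attention_reference(q, keys, values):
    # Materializes every logit for this query, then applies a plain softmax.
    logits = [2 * dot(q, k) - dot(k, k) for k in keys]
    m = max(logits)
    weights = [math.exp(z - m) for z in logits]
    total = sum(weights)
    dv = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(dv)]


def rbf_attention_tiled(q, keys, values, block=4):
    # FlashAttention-style online softmax: visit keys/values one tile at a
    # time and keep only a running max, a running denominator and a running
    # weighted sum of values, instead of the full row of logits.
    dv = len(values[0])
    running_max = -math.inf
    denom = 0.0
    acc = [0.0] * dv
    for start in range(0, len(keys), block):
        k_tile = keys[start:start + block]
        v_tile = values[start:start + block]
        # Key-norm penalty computed on the fly, only for the current tile.
        logits = [2 * dot(q, k) - dot(k, k) for k in k_tile]
        new_max = max(running_max, max(logits))
        # Rescale the partial sums accumulated under the old max
        # (exp(-inf) == 0.0 on the first tile, so nothing breaks).
        correction = math.exp(running_max - new_max)
        denom *= correction
        acc = [a * correction for a in acc]
        for z, v in zip(logits, v_tile):
            w = math.exp(z - new_max)
            denom += w
            acc = [a + w * vj for a, vj in zip(acc, v)]
        running_max = new_max
    return [a / denom for a in acc]


def pad_for_plain_dot_product(q, k):
    # Hypothetical version of the "dummy dimension" hack: one extra coordinate
    # makes an ordinary dot product equal 2(q . k) - ||k||^2.
    q_aug = [2.0 * x for x in q] + [-1.0]
    k_aug = list(k) + [dot(k, k)]
    return q_aug, k_aug


random.seed(1)
d, dv, n = 6, 3, 10
q = [random.gauss(0.0, 1.0) for _ in range(d)]
keys = [[random.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
values = [[random.gauss(0.0, 1.0) for _ in range(dv)] for _ in range(n)]

out_ref = rbf_attention_reference(q, keys, values)
out_tiled = rbf_attention_tiled(q, keys, values, block=4)
```

The tiled and reference outputs agree up to floating-point error, while the tiled version only ever holds one tile of logits at a time.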
Attention Sinks: It turns out that models sometimes actually need magnitude bullying to create attention sinks. They scale up useless tokens (like <BOS>) so that queries have a place to dump their attention mass when they don't care about the context. But in distance math, a massive vector means a huge distance and therefore near-zero probability. To be a universal sink in Euclidean space, a key has to sit exactly at the origin, so I solved this with register tokens: I prepended learnable dummy vectors to the sequence and initialized them to zero. Whenever a query doesn't find anything useful, it naturally falls back to the register tokens, safely dumping its attention into the blank registers without corrupting actual tokens.
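A small sketch of why zero-initialized registers behave as sinks under the 2(q · k) - ||k||^2 logit. A zero key scores exactly 0 for every query, a huge <BOS>-style key is pushed far away, and a query with no nearby content key sends almost all its mass to the registers. The registers are frozen at zero here for illustration (in a real model they are learnable), and all vectors are made-up toy values:

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def rbf_logit(q, k):
    # Shift-equivalent RBF logit: 2(q . k) - ||k||^2.
    return 2 * dot(q, k) - dot(k, k)


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


d = 4
num_registers = 2
# Register tokens start at the origin (learnable in a real model, frozen here).
registers = [[0.0] * d for _ in range(num_registers)]

# A zero key scores exactly 0 for any query, whatever its direction or length.
some_queries = [[1.0, 0.0, 0.0, 0.0], [-3.0, 2.0, 0.5, 0.0], [0.0, 0.0, 0.0, 7.0]]
register_logits = [rbf_logit(q, registers[0]) for q in some_queries]

# A huge "<BOS>-style" key is pushed far away instead of becoming a sink.
huge_bos = [100.0, 0.0, 0.0, 0.0]
bos_logit = rbf_logit([1.0, 0.0, 0.0, 0.0], huge_bos)

# A query with no nearby content key falls back to the registers.
q = [0.1, 0.0, 0.0, 0.0]
content_keys = [[5.0, 5.0, 0.0, 0.0], [-5.0, 0.0, 5.0, 0.0]]
probs = softmax([rbf_logit(q, k) for k in registers + content_keys])
register_mass = sum(probs[:num_registers])
```

Geometrically, a content key k beats a zero register exactly when 2(q · k) - ||k||^2 > 0, i.e. when ||q - k||^2 < ||q||^2: the key has to be closer to the query than the origin is.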
RoPE makes zero sense anymore: Modern models use RoPE, which explicitly rotates vectors. This is mathematically elegant for dot products (relative angles), but applying rotations to vectors before measuring their absolute Euclidean distance completely destroys the geometry and makes no sense... So I ripped out RoPE entirely and swapped it for SuSiE (Subspace Sinusoidal Embeddings), which just adds cached, unrotated sinusoids directly to the vectors. Because it's additive, positional distance explicitly acts as a penalty in Euclidean space.
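To see how an additive sinusoidal code turns position into a distance penalty, here is a sketch using the classic Transformer-style sin/cos embedding. This is an assumption for illustration only, not SuSiE's actual construction (the "subspace" part is not modeled). For each frequency pair, (sin a - sin b)^2 + (cos a - cos b)^2 = 2 - 2cos(a - b), so the squared distance between two position codes depends only on the offset m - n:

```python
import math


def sinusoidal_embedding(pos, d, base=10000.0):
    # Classic additive sin/cos pairs at geometrically spaced frequencies.
    # Assumption: this is NOT claimed to be SuSiE's exact construction.
    emb = []
    for i in range(d // 2):
        freq = base ** (-2.0 * i / d)
        emb.append(math.sin(pos * freq))
        emb.append(math.cos(pos * freq))
    return emb


def sq_dist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def positional_penalty(m, n, d=16):
    # ||p_m - p_n||^2 = sum_i (2 - 2 cos((m - n) * freq_i)): depends only on m - n.
    return sq_dist(sinusoidal_embedding(m, d), sinusoidal_embedding(n, d))


same_offset_a = positional_penalty(3, 7)
same_offset_b = positional_penalty(103, 107)
penalties = [positional_penalty(0, offset) for offset in (0, 1, 2, 4)]
```

Once the codes are added to content vectors, the full squared distance expands to ||q - k||^2 + ||p_m - p_n||^2 + 2(q - k) · (p_m - p_n), so there is also a cross term between content and position alongside the pure positional penalty.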
Did it actually work? Hmm, kind of... I trained a tiny causal model on the minuscule TinyStories dataset, and it converged slightly faster than a standard SDPA baseline. Maybe that's because of the distance math: the pre-softmax logits are capped at 0 (a negative squared distance can never be positive), which might prevent early gradient spikes, but who knows...?
Is it going to replace FlashAttention in big models anytime soon? Nope. GPUs and the whole ML stack are heavily optimized for pure dot products, and the industry already handles magnitude bullying with QK-Norm instead. But it was a fun engineering exercise in breaking and rebuilding a part of the ML stack.
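For comparison, here is a simplified take on the QK-Norm idea applied to the same toy "bully" example: normalize q and every k to unit length, then multiply by a temperature. Real QK-Norm implementations typically apply LayerNorm or RMSNorm with learnable gains; the plain L2 normalization and fixed `scale=10.0` below are assumptions chosen to keep the sketch short:

```python
import math


def l2_normalize(v, eps=1e-6):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / (norm + eps) for x in v]


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def qk_norm_logits(q, keys, scale=10.0):
    # Simplified QK-Norm: unit-normalize q and each k, then multiply the cosine
    # similarity by a fixed temperature. Real implementations usually apply
    # LayerNorm/RMSNorm with learnable gains; the fixed `scale` is an assumption.
    qn = l2_normalize(q)
    return [scale * sum(a * b for a, b in zip(qn, l2_normalize(k))) for k in keys]


q = [1.0, 0.0]
aligned = [1.0, 0.0]
bully = [8.0, 6.0]  # cosine similarity 0.8 with q, but ten times longer
probs = softmax(qk_norm_logits(q, [aligned, bully]))
```

Once the key lengths are normalized away, only direction matters and the aligned key wins again, while the computation stays a plain dot product that existing fused kernels already support.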
I went through all of it so you don't have to. Here is the code:
Blog-Post: https://pisoni.ai/posts/scaled-rbf-attention/
Repo: https://github.com/4rtemi5/rbf_attention