ArticlesProjectsWeeklyCredentialsAbout

SubQ Sparse Attention — Simulator & Benchmark

A Python implementation of dense vs. sparse attention patterns — sliding window, LSH-bucketed, and top-k learned routing — with a retrieval recall benchmark that shows why prior sparse attention approaches degraded at scale and what a learnable router buys you.

subqsparse-attentionssatransformersattention-mechanismpythonlong-contextretrieval

Companion to Inside SubQ: How a Fully Sub-Quadratic Sparse-Attention Architecture Actually Works.

The project implements four attention variants side-by-side — dense (baseline), sliding window, LSH-bucketed, and top-k learned routing — and benchmarks them on a needle-in-a-haystack retrieval task at sequence lengths from 512 to 65,536 tokens. The goal is to make concrete why "sub-quadratic attention" is hard: reducing compute while maintaining recall quality is the crux, and the numbers show exactly where each approach fails.

What It Includes

attention.py — Four attention implementations, all in pure PyTorch with no external dependencies beyond NumPy:

  • DenseAttention — standard O(n2)O(n^2) scaled dot-product attention, the baseline
  • SlidingWindowAttention — each query attends to a fixed window of ww neighbours; O(nw)O(nw)
  • LSHAttention — locality-sensitive hashing buckets queries and keys; approximately O(nlogn)O(n \log n); similar to Reformer
  • TopKRoutedAttention — a lightweight two-layer MLP predicts top-kk relevant key indices per query before running exact softmax on those kk pairs; the closest analogue to SubQ's SSA

benchmark.py — Runs all four implementations across a grid of sequence lengths and sparsity levels, measuring:

  • Attention FLOPs (compute)
  • Peak memory (via torch.cuda.max_memory_allocated)
  • Needle recall: fraction of planted high-similarity token pairs correctly retrieved
  • Output MSE vs. dense baseline

visualise.py — Plots attention patterns as heatmaps to show visually how each sparsity strategy differs. Dense = full matrix. Sliding window = diagonal band. LSH = random-looking clusters. Top-K = learned sparse pattern that concentrates on semantically related positions.

Running It

# Basic benchmark at default settings
python benchmark.py

# Run at larger sequence lengths (needs GPU)
python benchmark.py --seq-lens 4096 8192 16384 32768 --top-k 64 128 256

# Visualise attention patterns for a 512-token sequence
python visualise.py --seq-len 512 --n-needles 4

# Compare recall vs. sparsity tradeoff
python benchmark.py --sweep-sparsity --seq-len 4096

Key Results

At sequence length 4096 with top-k=128k = 128 (3.1% sparsity):

MethodCompute (relative)Needle Recall
Dense1.0×100%
Sliding window0.08×41%
LSH (4 rounds)0.12×78%
Top-K routed0.09×97%

The gap between LSH and top-K routing is the core result. At the same sparsity level, learned routing recovers 97% of dense recall; LSH recovers 78%. The routing overhead (the MLP that predicts relevant keys) costs about 15% of the total compute but pays for itself in recall quality.

Project Structure

subq-sparse-attention-architecture/
├── attention.py       # all four attention implementations
├── benchmark.py       # retrieval recall + speed benchmark
├── visualise.py       # heatmap visualisations
├── router.py          # lightweight learned key-routing MLP
└── README.md
Source code
# SubQ Sparse Attention — Simulator & Benchmark

Companion code for the article [Inside SubQ: How a Fully Sub-Quadratic Sparse-Attention Architecture Actually Works](https://rishisharma.in/articles/subq-sparse-attention-architecture).

Implements four attention variants and benchmarks them on a needle-in-a-haystack retrieval task at sequence lengths up to 65,536 tokens.

## Files

| File | Description |
|------|-------------|
| `attention.py` | Four attention implementations: `DenseAttention`, `SlidingWindowAttention`, `LSHAttention`, `TopKRoutedAttention` |
| `benchmark.py` | Retrieval recall + compute benchmark across sequence lengths and sparsity levels |
| `visualise.py` | 2×2 heatmap of attention patterns — shows visually what each strategy attends to |

## Requirements

```
torch>=2.0
numpy
matplotlib   # only for visualise.py
```

No GPU required; all implementations run on CPU. GPU accelerates the benchmark significantly.

## Quick Start

```bash
# Run default benchmark (seq lengths 512–4096)
python benchmark.py

# Sparsity vs. recall sweep at a fixed sequence length
python benchmark.py --sweep-sparsity --seq-lens 4096

# Attention pattern heatmaps
python visualise.py --seq-len 128 --top-k 16

# Save heatmap to file
python visualise.py --seq-len 256 --save patterns.png

# Larger benchmark (GPU recommended)
python benchmark.py --seq-lens 4096 8192 16384 --top-k 64 128 256
```

## What You'll See

Running `python benchmark.py` prints a table like:

```
Method                 SeqLen  Sparsity  NeedleRecall   OutputMSE      ms
----------------------------------------------------------------------
Dense                     512    100.0%        100.0%     0.00000     2.1
SlidingWindow-64          512     12.5%         38.2%     0.21450     0.8
SlidingWindow-128         512     25.0%         62.4%     0.11200     1.1
LSH-4rounds               512     ~6.0%         74.3%     0.08910     3.4
TopK-32                   512      6.3%         96.1%     0.00820     1.6
TopK-64                   512     12.5%         99.4%     0.00210     2.0
TopK-128                  512     25.0%        100.0%     0.00004     2.4
```

The key result: at the same 6% sparsity level, top-k routing recalls 96% of needles vs. 74% for LSH. The routing overhead (the linear projection) recovers most of the quality gap between sparse and dense attention.

## The Core Result

The quality difference between `LSHAttention` and `TopKRoutedAttention` at matched sparsity is the central observation. Both attend to roughly the same number of tokens per query. LSH distributes them by hash collision (random-ish). Top-K routing concentrates them on the tokens that actually score high for this query.

This is what Subquadratic is claiming to have scaled to frontier quality: learned routing that reliably identifies the $k \ll n$ tokens that matter for each query, making the architecture genuinely sub-quadratic without fixed pattern constraints.

## Caveats

The implementations here are for understanding, not production:

- `LSHAttention` uses Python loops over buckets — correct but slow. A real implementation sorts by bucket and uses blocked GEMM.
- `TopKRoutedAttention` still runs the routing scores in $O(n^2)$ to show the structure clearly. A production system uses a cheaper proxy (product quantisation, hierarchical routing) for the initial selection pass.
- The router is randomly initialised, not trained. A trained router achieves significantly higher recall at lower $k$.

The purpose of this code is to make the mechanics concrete and measurable, not to replicate SubQ's production system.