SubQ Sparse Attention — Simulator & Benchmark
A Python implementation of dense vs. sparse attention patterns — sliding window, LSH-bucketed, and top-k learned routing — with a retrieval recall benchmark that shows why prior sparse attention approaches degraded at scale and what a learnable router buys you.
Companion to *Inside SubQ: How a Fully Sub-Quadratic Sparse-Attention Architecture Actually Works*.
The project implements four attention variants side-by-side — dense (baseline), sliding window, LSH-bucketed, and top-k learned routing — and benchmarks them on a needle-in-a-haystack retrieval task at sequence lengths from 512 to 65,536 tokens. The goal is to make concrete why "sub-quadratic attention" is hard: reducing compute while maintaining recall quality is the crux, and the numbers show exactly where each approach fails.
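A back-of-envelope count shows why the quadratic term dominates over this range of sequence lengths. `attention_flops` is a hypothetical helper, not part of the repo. It assumes one head of width d = 64 and counts only the two attention matmuls (scores, then the weighted sum of values). It ignores softmax, projections, and any routing or hashing cost.

```python
from typing import Optional


def attention_flops(n: int, d: int, k: Optional[int] = None) -> int:
    """Approximate FLOPs for one attention head (hypothetical helper).

    k=None means dense attention: every query scores all n keys.
    Otherwise each query scores only min(k, n) selected keys.
    """
    keys_per_query = n if k is None else min(k, n)
    # Two matmuls (Q @ K^T and P @ V), each n * keys_per_query * d multiply-adds,
    # at 2 FLOPs per multiply-add -> 4 * n * keys_per_query * d.
    return 4 * n * keys_per_query * d


d = 64  # assumed head dimension
for n in (512, 4096, 65536):
    dense = attention_flops(n, d)
    sparse = attention_flops(n, d, k=128)
    print(f"n={n:>6}  dense={dense:.2e}  top-128={sparse:.2e}  ratio={sparse / dense:.4f}")
```

Doubling n quadruples the dense count but only doubles the top-128 count. Real sparse kernels add overhead (routing, hashing, gathers) on top of these idealized numbers.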
What It Includes
attention.py — Four attention implementations, all in pure PyTorch, with NumPy as the only other dependency:
- `DenseAttention`: standard scaled dot-product attention over all query-key pairs, $O(n^2)$; the baseline
- `SlidingWindowAttention`: each query attends to a fixed window of $w$ neighbours; $O(n \cdot w)$
- `LSHAttention`: locality-sensitive hashing groups queries and keys into buckets, and attention runs within each bucket; approximately $O(n \log n)$; similar to Reformer
- `TopKRoutedAttention`: a lightweight two-layer MLP predicts the top-k relevant key indices per query, then exact softmax runs only on those pairs; the closest analogue to SubQ's SSA
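The routed variant is the least familiar of the four, so here is a minimal NumPy sketch of the idea. The repo itself uses PyTorch, and its `TopKRoutedAttention` and `router.py` may differ. The following are all assumptions and hypothetical names:

- The two-layer MLP is given as weight matrices `w1` and `w2`.
- It is applied to both queries and keys to produce low-dimensional routing vectors.
- Keys are ranked by their dot product with the query in that routing space.

For clarity the sketch scores every query-key pair in routing space. That is still quadratic in n, just with a small constant, so it does not show how a real router stays sub-quadratic.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def dense_attention(q, k, v):
    """Baseline scaled dot-product attention: every query scores every key."""
    return softmax(q @ k.T / np.sqrt(q.shape[-1])) @ v


def topk_routed_attention(q, k, v, w1, w2, top_k):
    """Exact softmax attention restricted to each query's top_k routed keys (sketch)."""
    def route(x):
        # Hypothetical two-layer ReLU MLP mapping d-dim vectors to r-dim routing vectors.
        return np.maximum(x @ w1, 0.0) @ w2

    router_scores = route(q) @ route(k).T  # (n, n), but computed in only r dims
    # Indices of the top_k highest router scores per query (order within the set is arbitrary).
    idx = np.argpartition(-router_scores, top_k - 1, axis=-1)[:, :top_k]
    k_sel, v_sel = k[idx], v[idx]  # gathered keys/values: (n, top_k, d)
    scores = np.einsum("nd,nkd->nk", q, k_sel) / np.sqrt(q.shape[-1])
    return np.einsum("nk,nkd->nd", softmax(scores), v_sel)  # exact softmax over selected pairs


rng = np.random.default_rng(0)
n, d, hidden, r = 64, 16, 32, 8
q, k, v = (rng.standard_normal((n, d)) for _ in range(3))
w1 = rng.standard_normal((d, hidden)) / np.sqrt(d)  # untrained weights, for illustration only
w2 = rng.standard_normal((hidden, r)) / np.sqrt(hidden)

out = topk_routed_attention(q, k, v, w1, w2, top_k=8)
print(out.shape)  # → (64, 16)
```

Setting `top_k = n` recovers dense attention exactly, which makes a useful sanity check. With `top_k` much smaller than n, the exact-softmax part costs O(n·k·d) instead of O(n²·d).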
benchmark.py — Runs all four implementations across a grid of sequence lengths and sparsity levels, measuring:
- Attention FLOPs (compute)
- Peak memory (via `torch.cuda.max_memory_allocated`)
- Needle recall: fraction of planted high-similarity token pairs correctly retrieved
- Output MSE vs. dense baseline
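The two quality metrics are simple to state in code. The NumPy sketch below is an assumption about how they could be computed, not a copy of `benchmark.py`:

- `plant_needles` (hypothetical) overwrites a few early keys with scaled copies of later queries, creating high-similarity pairs.
- Needle recall is the fraction of those (query, key) pairs that an attention pattern's boolean mask lets through.
- Output MSE compares a sparse output with the dense one.

```python
import numpy as np


def plant_needles(q, k, n_needles, rng, strength=5.0):
    """Hypothetical needle planting: make key j a scaled copy of a distant query i.

    Queries come from the second half of the sequence and keys from the first quarter,
    so every needle spans more than n // 4 positions. Modifies k in place.
    """
    n = q.shape[0]
    qi = rng.choice(np.arange(n // 2, n), size=n_needles, replace=False)
    kj = rng.choice(n // 4, size=n_needles, replace=False)
    for i, j in zip(qi, kj):
        k[j] = strength * q[i] / np.linalg.norm(q[i])  # unit direction of q[i], scaled up
    return [(int(i), int(j)) for i, j in zip(qi, kj)]


def needle_recall(mask, pairs):
    """Fraction of planted (query, key) pairs that the attention mask allows."""
    return float(np.mean([mask[i, j] for i, j in pairs]))


def output_mse(out, dense_out):
    """Mean squared error of a sparse attention output against the dense baseline."""
    return float(np.mean((out - dense_out) ** 2))


rng = np.random.default_rng(0)
n, d, window = 256, 32, 16
q, k = rng.standard_normal((n, d)), rng.standard_normal((n, d))
pairs = plant_needles(q, k, n_needles=4, rng=rng)

pos = np.arange(n)
dense_mask = np.ones((n, n), dtype=bool)
band_mask = np.abs(pos[:, None] - pos[None, :]) <= window  # sliding-window pattern

print(needle_recall(dense_mask, pairs), needle_recall(band_mask, pairs))  # → 1.0 0.0
```

A local window only catches needles that land inside it. Here every needle spans more positions than the window, so the band's recall is zero.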
visualise.py — Plots attention patterns as heatmaps to show visually how each sparsity strategy differs. Dense = full matrix. Sliding window = diagonal band. LSH = random-looking clusters. Top-K = learned sparse pattern that concentrates on semantically related positions.
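The patterns that `visualise.py` draws can be reproduced as boolean masks. This NumPy sketch (function names hypothetical; the repo's code may differ) builds the sliding-window band and an idealized LSH pattern.

The hashing follows Reformer's angular LSH: project onto random directions R and take the argmax over [xR, −xR]. Query i may attend to key j if they share a bucket in any of `n_rounds` rounds. The sketch omits Reformer's sort-and-chunk step, so it shows which pairs can attend, not how a kernel batches them.

```python
import numpy as np


def sliding_window_mask(n: int, window: int) -> np.ndarray:
    """(n, n) boolean band: query i may attend to key j when |i - j| <= window."""
    pos = np.arange(n)
    return np.abs(pos[:, None] - pos[None, :]) <= window


def angular_lsh_buckets(x: np.ndarray, rotation: np.ndarray) -> np.ndarray:
    """Reformer-style hash: argmax over [xR, -xR], giving 2 * rotation.shape[1] buckets."""
    proj = x @ rotation
    return np.argmax(np.concatenate([proj, -proj], axis=-1), axis=-1)


def lsh_mask(q, k, n_buckets: int, n_rounds: int, rng) -> np.ndarray:
    """(n, n) boolean mask: True where query i and key j share a bucket in any round."""
    assert n_buckets % 2 == 0, "angular LSH needs an even number of buckets"
    mask = np.zeros((q.shape[0], k.shape[0]), dtype=bool)
    for _ in range(n_rounds):
        # Fresh random projection per round; more rounds means fewer missed neighbours.
        rotation = rng.standard_normal((q.shape[-1], n_buckets // 2))
        bq = angular_lsh_buckets(q, rotation)
        bk = angular_lsh_buckets(k, rotation)
        mask |= bq[:, None] == bk[None, :]
    return mask


rng = np.random.default_rng(0)
n, d = 512, 64
x = rng.standard_normal((n, d))
dense = np.ones((n, n), dtype=bool)
band = sliding_window_mask(n, window=32)
lsh = lsh_mask(x, x, n_buckets=16, n_rounds=4, rng=rng)  # shared queries/keys, as in Reformer
for name, m in (("dense", dense), ("sliding window", band), ("LSH", lsh)):
    print(f"{name:>14}: attends to {m.mean():.1%} of pairs")
```

Passing any of these masks to matplotlib's `plt.imshow` gives the kind of heatmap described above: a full square, a diagonal band, or scattered bucket clusters.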
Running It
```bash
# Basic benchmark at default settings
python benchmark.py
# Run at larger sequence lengths (needs GPU)
python benchmark.py --seq-lens 4096 8192 16384 32768 --top-k 64 128 256
# Visualise attention patterns for a 512-token sequence
python visualise.py --seq-len 512 --n-needles 4
# Compare recall vs. sparsity tradeoff
python benchmark.py --sweep-sparsity --seq-len 4096
```
Key Results
At sequence length 4096 with top-k = 128 (each query attends to 128 of 4,096 keys, about 3.1%):
| Method | Compute (relative to dense) | Needle Recall |
|---|---|---|
| Dense | 1.0× | 100% |
| Sliding window | 0.08× | 41% |
| LSH (4 rounds) | 0.12× | 78% |
| Top-K routed | 0.09× | 97% |
The gap between LSH and top-K routing is the core result. At the same sparsity level, learned routing recovers 97% of dense recall; LSH recovers 78%. The routing overhead (the MLP that predicts relevant keys) costs about 15% of the total compute but pays for itself in recall quality.
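The overhead claim can be checked against the Key Results table with a few lines of arithmetic. This assumes, since the text does not say, that the 15% is a share of the routed method's own 0.09× budget rather than of dense compute.

```python
# Figures from the Key Results table (sequence length 4096).
dense_recall, lsh_recall, topk_recall = 1.00, 0.78, 0.97
lsh_compute, topk_compute = 0.12, 0.09  # relative to dense
router_share = 0.15  # assumption: share of the routed method's own compute

router_compute = topk_compute * router_share  # spent in the routing MLP
attention_compute = topk_compute - router_compute  # spent on exact softmax over selected pairs

print(f"router: {router_compute:.4f}x dense, sparse attention: {attention_compute:.4f}x dense")
print(f"top-K routed: {topk_compute / lsh_compute:.0%} of LSH compute, "
      f"{topk_recall / dense_recall:.0%} vs {lsh_recall / dense_recall:.0%} of dense recall")
```

Under that reading, the routed variant pays for its router and still uses about three quarters of LSH's compute, while recovering 19 more points of recall.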
Project Structure
```text
subq-sparse-attention-architecture/
├── attention.py # all four attention implementations
├── benchmark.py # retrieval recall + speed benchmark
├── visualise.py # heatmap visualisations
├── router.py # lightweight learned key-routing MLP
└── README.md
```