
SubQ Sparse Attention — Simulator & Benchmark

A Python implementation of dense vs. sparse attention patterns — sliding window, LSH-bucketed, and top-k learned routing — with a retrieval recall benchmark that shows why prior sparse attention approaches degraded at scale and what a learnable router buys you.

Tags: subq, sparse-attention, ssa, transformers, attention-mechanism, python, long-context, retrieval

Companion to Inside SubQ: How a Fully Sub-Quadratic Sparse-Attention Architecture Actually Works.

The project implements four attention variants side-by-side — dense (baseline), sliding window, LSH-bucketed, and top-k learned routing — and benchmarks them on a needle-in-a-haystack retrieval task at sequence lengths from 512 to 65,536 tokens. The goal is to make concrete why "sub-quadratic attention" is hard: reducing compute while maintaining recall quality is the crux, and the numbers show exactly where each approach fails.

What It Includes

attention.py — Four attention implementations, all in pure PyTorch with no external dependencies beyond NumPy:

  • DenseAttention: standard $O(n^2)$ scaled dot-product attention, the baseline
  • SlidingWindowAttention: each query attends to a fixed window of $w$ neighbours; $O(nw)$
  • LSHAttention: locality-sensitive hashing buckets queries and keys; approximately $O(n \log n)$; similar to Reformer
  • TopKRoutedAttention: a lightweight two-layer MLP predicts the top-$k$ relevant key indices per query, then exact softmax runs on only those $k$ pairs; the closest analogue to SubQ's SSA
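As a rough illustration of the difference between the dense baseline and top-k attention, here is a minimal sketch. It uses NumPy rather than PyTorch to stay small, and it is not the code in attention.py. The function names and the `route_scores` argument are assumptions made for this example: `route_scores` stands in for whatever the learned router outputs, and the usage lines substitute the true query-key scores as an oracle router.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row maximum for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def dense_attention(q, k, v):
    """Scaled dot-product attention over all n x n query-key pairs."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    return softmax(scores) @ v

def topk_attention(q, k, v, route_scores, top_k):
    """Exact softmax restricted to the top_k keys selected per query.

    route_scores has shape (n_queries, n_keys) and is a placeholder for
    the router's output (hypothetical interface for this sketch).
    """
    d = q.shape[-1]
    # Indices of the top_k highest-scoring keys per query (unordered).
    idx = np.argpartition(-route_scores, top_k - 1, axis=1)[:, :top_k]
    k_sel = k[idx]  # (n_queries, top_k, d)
    v_sel = v[idx]  # (n_queries, top_k, d)
    scores = np.einsum("qd,qkd->qk", q, k_sel) / np.sqrt(d)
    weights = softmax(scores, axis=-1)
    return np.einsum("qk,qkd->qd", weights, v_sel)

# Usage with an oracle router: the true scores pick the keys.
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((16, 8)) for _ in range(3))
oracle = q @ k.T
sparse_out = topk_attention(q, k, v, oracle, top_k=4)
dense_out = dense_attention(q, k, v)
```

With `top_k` equal to the number of keys, the sparse path reproduces the dense output, which is a handy sanity check. The score computation touches n × top_k pairs instead of n × n, which is where the compute saving comes from.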

benchmark.py — Runs all four implementations across a grid of sequence lengths and sparsity levels, measuring:

  • Attention FLOPs (compute)
  • Peak memory (via torch.cuda.max_memory_allocated)
  • Needle recall: fraction of planted high-similarity token pairs correctly retrieved
  • Output MSE vs. dense baseline
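The needle recall metric can be sketched in a few lines. This is one plausible reading of "fraction of planted pairs correctly retrieved", not necessarily how benchmark.py computes it: a planted (query, key) pair counts as retrieved when the key position is among the positions that query is allowed to attend to. The function names and data layout are assumptions for illustration.

```python
def needle_recall(planted_pairs, attended):
    """Fraction of planted (query_pos, key_pos) pairs that were retrieved.

    attended[q] is the collection of key positions query q attends to.
    """
    pairs = list(planted_pairs)
    if not pairs:
        raise ValueError("need at least one planted pair")
    hits = sum(1 for q_pos, k_pos in pairs if k_pos in attended[q_pos])
    return hits / len(pairs)

def sliding_window_keys(n, w):
    """Key positions within w // 2 of each query, clipped to [0, n)."""
    half = w // 2
    return [set(range(max(0, i - half), min(n, i + half + 1)))
            for i in range(n)]

# Two needles sit inside the window, two are far away.
pairs = [(10, 12), (10, 50), (40, 3), (60, 58)]
recall = needle_recall(pairs, sliding_window_keys(n=64, w=8))
print(recall)  # 0.5
```

The example shows why a fixed window struggles on this task: any needle planted farther than half a window from its query cannot be retrieved, regardless of how similar the two tokens are.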

visualise.py — Plots attention patterns as heatmaps to show visually how each sparsity strategy differs. Dense = full matrix. Sliding window = diagonal band. LSH = random-looking clusters. Top-K = learned sparse pattern that concentrates on semantically related positions.
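The band and cluster shapes described above fall out of the boolean masks directly. The sketch below builds two of them with NumPy; it is an illustration, not the plotting code in visualise.py. The LSH mask uses a single round of random-hyperplane (sign) hashing on one shared set of vectors, which is a simplifying assumption; the function names are hypothetical.

```python
import numpy as np

def sliding_window_mask(n, w):
    """True where |i - j| <= w // 2, giving a diagonal band."""
    pos = np.arange(n)
    return np.abs(pos[:, None] - pos[None, :]) <= w // 2

def lsh_bucket_mask(x, n_bits, seed=0):
    """True where two positions land in the same hash bucket.

    One hashing round with random hyperplanes: each vector gets an
    n_bits sign code, and equal codes share a bucket.
    """
    rng = np.random.default_rng(seed)
    planes = rng.standard_normal((x.shape[1], n_bits))
    bits = (x @ planes > 0).astype(np.int64)   # (n, n_bits)
    codes = bits @ (1 << np.arange(n_bits))    # integer bucket id per position
    return codes[:, None] == codes[None, :]

def render(mask):
    """Text heatmap: '#' for attended pairs, '.' for skipped ones."""
    return "\n".join("".join("#" if m else "." for m in row) for row in mask)

band = sliding_window_mask(8, 2)
print(render(band))

x = np.random.default_rng(1).standard_normal((32, 16))
buckets = lsh_bucket_mask(x, n_bits=3)
```

Because bucket membership depends on vector direction rather than position, the LSH mask looks scattered when rows and columns are kept in sequence order, while the sliding-window mask is always the same band.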

Running It

# Basic benchmark at default settings
python benchmark.py

# Run at larger sequence lengths (needs GPU)
python benchmark.py --seq-lens 4096 8192 16384 32768 --top-k 64 128 256

# Visualise attention patterns for a 512-token sequence
python visualise.py --seq-len 512 --n-needles 4

# Compare recall vs. sparsity tradeoff
python benchmark.py --sweep-sparsity --seq-len 4096

Key Results

At sequence length 4096 with top-$k$ routing at $k = 128$ (each query attends to 128 of 4096 keys, about 3.1%):

| Method | Compute (relative) | Needle recall |
|---|---|---|
| Dense | 1.0× | 100% |
| Sliding window | 0.08× | 41% |
| LSH (4 rounds) | 0.12× | 78% |
| Top-K routed | 0.09× | 97% |

The gap between LSH and top-K routing is the core result. At the same sparsity level, learned routing recovers 97% of dense recall; LSH recovers 78%. The routing overhead (the MLP that predicts relevant keys) costs about 15% of the total compute but pays for itself in recall quality.
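The arithmetic behind these figures is easy to check. The sketch below recomputes the fraction of keys each query attends to and splits the top-K compute into router and attention parts. It assumes that "15% of the total compute" refers to the top-K variant's own total (the 0.09× figure), which is an interpretation rather than something stated explicitly.

```python
seq_len = 4096
top_k = 128

# Fraction of keys each query attends to.
density = top_k / seq_len
print(f"{density:.4f}")  # 0.0312, i.e. about 3.1%

# Reported figures for the top-K routed variant.
topk_relative_compute = 0.09   # relative to dense = 1.0
router_share = 0.15            # assumed: share of the top-K variant's total

router_cost = topk_relative_compute * router_share
attention_cost = topk_relative_compute - router_cost
print(f"router {router_cost:.4f}x, attention {attention_cost:.4f}x")
```

Under that reading, the router accounts for roughly 0.0135× of dense compute and the sparse attention itself for roughly 0.0765×.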

Project Structure

subq-sparse-attention-architecture/
├── attention.py       # all four attention implementations
├── benchmark.py       # retrieval recall + speed benchmark
├── visualise.py       # heatmap visualisations
├── router.py          # lightweight learned key-routing MLP
└── README.md