Higher-order Linear Attention

A causal, streaming attention mechanism that realizes higher‑order interactions via compact prefix statistics, with exact masked identities and associative scans enabling parallel training that matches recurrent computations.

Yifan Zhang  ·  Zhen Qin  ·  Quanquan Gu

Princeton University  •  UCLA  •  October 30, 2025

Streaming · Causal Mask · Second‑order · Third‑order · Associative Scans

Abstract

The quadratic cost of scaled dot‑product attention limits long‑context inference. Linear‑time attention mechanisms and state space models provide scalable alternatives but are often restricted to first‑order or kernel‑based approximations. We introduce Higher-order Linear Attention (HLA), a causal, streaming mechanism that realizes higher‑order interactions via compact prefix sufficient statistics. In the second‑order case, HLA maintains an $O(d^2)$ state per head and computes per‑token outputs in $O(d^2)$ time without materializing any $n{\times}n$ matrices. We derive exact masked streaming identities that enforce strict autoregressive causality, and a chunk‑parallel training scheme based on associative scans that exactly reproduces the activations of the serial recurrence. We further outline masked third‑order HLA and decay‑aware variants. HLA combines attention‑style data‑dependent mixing with the efficiency of modern recurrent architectures.


Second‑Order HLA: Streaming Mechanism

Figure: Prefix summaries $S^K$, $C^{QV}$, and $M^Q$ (plus masked cross‑summaries) enable streaming updates, so HLA drives higher‑order attention without $n\times n$ matrices.

Prefix summaries (per head)

S_t^K   = \sum_{i \le t} \bm{k}_i \bm{k}_i^\top     \in \mathbb{R}^{d \times d}
C_t^{QV}= \sum_{i \le t} \bm{q}_i \bm{v}_i^\top     \in \mathbb{R}^{d \times d_v}
M_t^Q   = \sum_{i \le t} \bm{q}_i                   \in \mathbb{R}^d

Default (unnormalized) output

$$\mathbf{o}_t \;=\; \bm{q}_t^\top S_t^K C_t^{QV}.$$

Auxiliary normalized variant

$$\mathbf{o}_t \;=\; \frac{\bm{q}_t^\top S_t^K C_t^{QV}}{\bm{q}_t^\top S_t^K M_t^Q + \varepsilon}.$$

When $S_t^K=\mathbf{I}$ the normalized form reduces to a linear‑attention kernel with $K(\bm{q}_t,\bm{q}_i)=\bm{q}_t^\top\bm{q}_i$. In general, $S_t^K=\sum_{i\le t}\bm{k}_i\bm{k}_i^\top$ induces a data‑adaptive second‑order metric on query space, strictly enriching first‑order mechanisms while retaining streaming updates.
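The unmasked default output can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions (no decay, no normalization, a single head); the function names `hla2_unmasked_stream` and `hla2_unmasked_reference` are hypothetical and not taken from the project code. The quadratic-time reference exists only to check the streaming result.

```python
import numpy as np

def hla2_unmasked_stream(Q, K, V):
    """Stream o_t = q_t^T S_t^K C_t^{QV} one token at a time (no decay)."""
    n, d = Q.shape
    d_v = V.shape[1]
    S = np.zeros((d, d))      # S_t^K    = sum_{i<=t} k_i k_i^T
    C = np.zeros((d, d_v))    # C_t^{QV} = sum_{i<=t} q_i v_i^T
    out = np.empty((n, d_v))
    for t in range(n):
        S += np.outer(K[t], K[t])
        C += np.outer(Q[t], V[t])
        # Two mat-vecs; the d x d_v product S @ C is never formed.
        out[t] = (Q[t] @ S) @ C
    return out

def hla2_unmasked_reference(Q, K, V):
    """Check: sum over all i, j <= t of (q_t.k_i)(k_i.q_j) v_j."""
    n = Q.shape[0]
    out = np.zeros((n, V.shape[1]))
    for t in range(n):
        a = Q[t] @ K[: t + 1].T          # (q_t . k_i) for i <= t
        B = K[: t + 1] @ Q[: t + 1].T    # (k_i . q_j) for i, j <= t
        out[t] = (a @ B) @ V[: t + 1]
    return out

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((6, 4)) for _ in range(3))
print(np.allclose(hla2_unmasked_stream(Q, K, V),
                  hla2_unmasked_reference(Q, K, V)))
```

The state held between tokens is only `S` and `C`, whose sizes do not depend on the sequence length.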

Strict Causality via Masked Streaming Identities

Let $L$ be the binary causal mask (ones on and below the diagonal). To impose strict autoregressive causality, augment the prefix summaries with the cross‑summaries

$$G_t=\sum_{i\le t}\left(\bm{k}_i\bm{k}_i^\top\right) C_{i-1}^{QV},\qquad H_t=\sum_{i\le t}\left(\bm{k}_i\bm{k}_i^\top\right) M_{i-1}^{Q},$$

with the empty prefixes $C_0^{QV}=0$ and $M_0^{Q}=0$.

Masked identities (second‑order)

$$\mathbf{o}_t \;=\; \bm{q}_t^\top\left(S_t^K C_t^{QV} - G_t\right) \quad\text{(default, unnormalized)}$$

$$\mathbf{o}_t \;=\; \frac{\bm{q}_t^\top\left(S_t^K C_t^{QV}-G_t\right)} {\bm{q}_t^\top\left(S_t^K M_t^{Q}-H_t\right) + \varepsilon} \quad\text{(auxiliary normalized)}.$$
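To see why subtracting $G_t$ enforces the mask, the definitions above can be expanded directly. This is a short check, assuming the empty prefix $C_0^{QV}=0$ and writing $Q$, $K$, $V$ for the row‑stacked queries, keys, and values (notation introduced here for illustration):

$$S_t^K C_t^{QV}-G_t \;=\; \sum_{i\le t}\bm{k}_i\bm{k}_i^\top\left(C_t^{QV}-C_{i-1}^{QV}\right) \;=\; \sum_{i\le t}\,\sum_{j=i}^{t}\bm{k}_i\bm{k}_i^\top\,\bm{q}_j\bm{v}_j^\top,$$

$$\mathbf{o}_t \;=\; \sum_{i\le j\le t}\left(\bm{q}_t^\top\bm{k}_i\right)\left(\bm{k}_i^\top\bm{q}_j\right)\bm{v}_j^\top.$$

Only index pairs with $i\le j\le t$ survive, which in matrix form is row $t$ of $\big(((L\odot QK^\top)(L\odot QK^\top)^\top)\odot L\big)V$. The same expansion with $\bm{v}_j^\top$ replaced by $1$ gives the masked denominator.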

Online updates (per token)

S_t^K = S_{t-1}^K + \bm{k}_t \bm{k}_t^\top
C_t^{QV} = C_{t-1}^{QV} + \bm{q}_t \bm{v}_t^\top
M_t^Q = M_{t-1}^Q + \bm{q}_t
G_t = G_{t-1} + \bm{k}_t \big(\bm{k}_t^\top C_{t-1}^{QV}\big)
H_t = H_{t-1} + \bm{k}_t \big(\bm{k}_t^\top M_{t-1}^{Q}\big)
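The online updates translate almost line for line into NumPy. The sketch below assumes a single head and no decay; `hla2_masked_stream` and `hla2_masked_reference` are hypothetical names. The reference materializes $n\times n$ matrices purely as a test oracle, which the streaming path never does.

```python
import numpy as np

def hla2_masked_stream(Q, K, V, normalize=False, eps=1e-6):
    """Masked second-order HLA, one token at a time (no decay)."""
    n, d = Q.shape
    d_v = V.shape[1]
    S = np.zeros((d, d)); C = np.zeros((d, d_v)); M = np.zeros(d)
    G = np.zeros((d, d_v)); H = np.zeros(d)
    out = np.empty((n, d_v))
    for t in range(n):
        q, k, v = Q[t], K[t], V[t]
        # Cross-summaries first: they must read C_{t-1} and M_{t-1}.
        G += np.outer(k, k @ C)      # k_t (k_t^T C_{t-1})
        H += k * (k @ M)             # k_t (k_t^T M_{t-1})
        S += np.outer(k, k)
        C += np.outer(q, v)
        M += q
        qS = q @ S                   # shared by numerator and denominator
        o = qS @ C - q @ G
        if normalize:
            o = o / (qS @ M - q @ H + eps)
        out[t] = o
    return out

def hla2_masked_reference(Q, K, V):
    """Oracle: ((L*QK^T)(L*QK^T)^T * L) V with L the causal mask."""
    n = Q.shape[0]
    L = np.tril(np.ones((n, n)))
    W = L * (Q @ K.T)                # W[t, i] = q_t . k_i for i <= t
    return (L * (W @ W.T)) @ V

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((6, 4)) for _ in range(3))
print(np.allclose(hla2_masked_stream(Q, K, V),
                  hla2_masked_reference(Q, K, V)))
```

Updating `G` and `H` before `C` and `M` is the one ordering constraint: swapping the two groups would silently use $C_t^{QV}$ in place of $C_{t-1}^{QV}$.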

Decay (optional)

S_t^K = \gamma S_{t-1}^K + \bm{k}_t \bm{k}_t^\top     (similarly for C^{QV}, M^Q)
G_t   = \gamma G_{t-1} + \bm{k}_t(\bm{k}_t^\top C_{t-1}^{QV})
H_t   = \gamma H_{t-1} + \bm{k}_t(\bm{k}_t^\top M_{t-1}^{Q})

Here $\gamma\in(0,1]$ is the decay factor ($\gamma=1$ recovers the undecayed updates). Decay controls spectral growth and encourages recency bias while preserving scan‑friendly associativity.
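Unrolling the decayed recurrence makes the recency weighting explicit. This is a routine check, assuming a constant $\gamma$ and $S_0^K=0$:

$$S_t^K \;=\; \sum_{i\le t}\gamma^{\,t-i}\,\bm{k}_i\bm{k}_i^\top .$$

A key seen $\tau$ steps ago is weighted by $\gamma^{\tau}$; with $\gamma=0.9$, for instance, a key ten steps back contributes with weight $0.9^{10}\approx 0.35$. If additionally $\gamma<1$ and $\|\bm{k}_i\|\le 1$, the geometric series gives $\|S_t^K\|_2\le 1/(1-\gamma)$ for every $t$, which is one way to read the claim about spectral growth.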

Associative Scans: Chunk‑Parallel Training

We define an associative operator over segment summaries so that a standard exclusive Blelloch scan produces per‑token prefix states; local inclusions recover the exact serial activations. For the masked second‑order case, use state $\mathcal{S}=(S,C,M,G,H)$ with the concatenation operator $\oplus$, where segment $A$ precedes segment $B$:

(S_A,C_A,M_A,G_A,H_A) ⊕ (S_B,C_B,M_B,G_B,H_B)
= (S_A+S_B, C_A+C_B, M_A+M_B,
   G_A+G_B + S_B C_A,
   H_A+H_B + S_B M_A)

This semidirect product is associative. With decay, the left segment's summaries are multiplied by an attenuation $\rho_B=\gamma^{\ell(B)}$, where $\ell(B)$ is the length of the right segment $B$, before the cross‑terms are formed. Reverse‑mode uses the adjoint operator $\oplus^\ast$ with checkpointing at tile boundaries, yielding gradients identical to those of the serial recurrence.
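Associativity can be checked by hand on the $G$ component (the $H$ component is identical with $M$ in place of $C$). Take three consecutive segments, named $X$, $Y$, $Z$ here to avoid clashing with the summary $C$, and no decay:

$$G_{(X\oplus Y)\oplus Z}=\left(G_X+G_Y+S_Y C_X\right)+G_Z+S_Z\left(C_X+C_Y\right),$$

$$G_{X\oplus(Y\oplus Z)}=G_X+\left(G_Y+G_Z+S_Z C_Y\right)+\left(S_Y+S_Z\right)C_X.$$

Both expand to $G_X+G_Y+G_Z+S_Y C_X+S_Z C_X+S_Z C_Y$, so the grouping does not matter. The order of the segments still does, since $\oplus$ is not commutative.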

Reference pseudocode (masked, second‑order)

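// Within-chunk exclusive scan (masked)
// Segment t holds (ΔS_t, ΔC_t, ΔM_t, 0, 0) with
//   ΔS_t = k_t k_t^T,  ΔC_t = q_t v_t^T,  ΔM_t = q_t
prefixes = scan_exclusive(segments, ⊕)   // O(log w) span
for each token t in parallel:
  // Exclusive prefix = serial state after token t-1
  (S_{t-1}, C_{t-1}, M_{t-1}, G_{t-1}, H_{t-1}) = prefixes[t]
  // Inclusive state from local delta + prefix
  S_t = γ S_{t-1} + ΔS_t
  C_t = γ C_{t-1} + ΔC_t
  M_t = γ M_{t-1} + ΔM_t
  G_t = γ G_{t-1} + ΔS_t C_{t-1}
  H_t = γ H_{t-1} + ΔS_t M_{t-1}
  // Output (unnormalized by default)
  o_t = q_t^T (S_t C_t - G_t)
  // Optional normalization:
  // o_t = o_t / (q_t^T (S_t M_t - H_t) + ε)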
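A runnable counterpart to the pseudocode, as a sketch under stated assumptions: no decay, one chunk, and a plain-Python recursive pairing scan standing in for a parallel Blelloch scan. It evaluates $\oplus$ in a tree-shaped order but runs sequentially. The helper names `combine`, `exclusive_scan`, `hla2_masked_chunk`, and `hla2_masked_serial` are hypothetical.

```python
import numpy as np

def combine(A, B):
    """A ⊕ B on summaries (S, C, M, G, H); segment A precedes B."""
    S_a, C_a, M_a, G_a, H_a = A
    S_b, C_b, M_b, G_b, H_b = B
    return (S_a + S_b, C_a + C_b, M_a + M_b,
            G_a + G_b + S_b @ C_a, H_a + H_b + S_b @ M_a)

def exclusive_scan(items, op, identity):
    """Order-preserving exclusive scan by recursive pairing."""
    n = len(items)
    if n <= 1:
        return [identity] * n
    pairs = [op(items[i], items[i + 1]) for i in range(0, n - 1, 2)]
    if n % 2:
        pairs.append(items[-1])
    coarse = exclusive_scan(pairs, op, identity)
    return [coarse[i // 2] if i % 2 == 0
            else op(coarse[i // 2], items[i - 1]) for i in range(n)]

def hla2_masked_chunk(Q, K, V):
    """Masked second-order outputs for one chunk from scanned prefixes."""
    n, d = Q.shape
    d_v = V.shape[1]
    zG, zH = np.zeros((d, d_v)), np.zeros(d)
    identity = (np.zeros((d, d)), zG, zH, zG, zH)
    deltas = [(np.outer(k, k), np.outer(q, v), q, zG, zH)
              for q, k, v in zip(Q, K, V)]
    prefixes = exclusive_scan(deltas, combine, identity)
    out = np.empty((n, d_v))
    for t in range(n):  # iterations are independent of one another
        S, C, M, G, H = combine(prefixes[t], deltas[t])  # inclusive state
        out[t] = (Q[t] @ S) @ C - Q[t] @ G
    return out

def hla2_masked_serial(Q, K, V):
    """Token-by-token recurrence used as ground truth."""
    n, d = Q.shape
    d_v = V.shape[1]
    S, C, G = np.zeros((d, d)), np.zeros((d, d_v)), np.zeros((d, d_v))
    out = np.empty((n, d_v))
    for t in range(n):
        G = G + np.outer(K[t], K[t] @ C)   # reads C_{t-1}
        S = S + np.outer(K[t], K[t])
        C = C + np.outer(Q[t], V[t])
        out[t] = (Q[t] @ S) @ C - Q[t] @ G
    return out

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((7, 4)) for _ in range(3))
print(np.allclose(hla2_masked_chunk(Q, K, V), hla2_masked_serial(Q, K, V)))
```

Because the scan only relies on associativity, the same `combine` could in principle be handed to a parallel scan primitive; that substitution is not shown here.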

Third‑Order HLA

Unmasked third‑order HLA uses $S_t^K=\sum_{i\le t}\bm{k}_i\bm{k}_i^\top$, $S_t^Q=\sum_{i\le t}\bm{q}_i\bm{q}_i^\top$, $P_t^{KV}=\sum_{i\le t}\bm{k}_i\bm{v}_i^\top$, and $M_t^{K}=\sum_{i\le t}\bm{k}_i$, with default (unnormalized) output

$$\mathbf{o}_t^{(3)}=\bm{q}_t^\top S_t^K S_t^Q P_t^{KV}.$$

Strict causality adds cross‑summaries $G_t^{(1:3)}$, $H_t^{(1:3)}$; the masked numerator is

$$\text{num}_t^{(3)\mathrm{mask}}=\bm{q}_t^\top\Big(S_t^K S_t^Q P_t^{KV}-G^{(1)}_t-G^{(2)}_t-G^{(3)}_t\Big),$$

with an analogous denominator replacing $P^{KV}$ by $M^K$. Online updates keep the per‑token cost at $O(d^2{+}d\,d_v)$ by using only mat‑vec and outer‑product forms. An associative scan operator carries the same summaries and cross‑terms with decay‑aware scaling.

Streaming kernel sketch (masked, third‑order)

// Core step (per token t; decay γ)
// X_prev denotes the value of X before this step's update.
S^K ← γ S^K_prev + k_t k_t^T
S^Q ← γ S^Q_prev + q_t q_t^T
P   ← γ P_prev   + k_t v_t^T
M^K ← γ M^K_prev + k_t

// Cross-summaries via matvecs
u1 = S^Q_prev k_t
G^(1) ← γ G^(1) + k_t (u1^T P_prev)
H^(1) ← γ H^(1) + k_t (u1^T M^K_prev)

a2 = S^K_prev q_t
G^(2) ← γ G^(2) + a2 (q_t^T P_prev)
H^(2) ← γ H^(2) + a2 (q_t^T M^K_prev)

a3 = S^K_prev u1          // = S^K_prev (S^Q_prev k_t)
G^(3) ← γ G^(3) + a3 v_t^T
H^(3) ← γ H^(3) + a3

// Output (default)
y = S^K q_t;  z = S^Q y
o_t = z^T P - q_t^T(G^(1)+G^(2)+G^(3))
// Optional normalization by masked denominator
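The kernel sketch above can be made runnable as follows. This is an illustrative NumPy version of the numerator only (the $H^{(1:3)}$ denominator terms are omitted), with hypothetical function names. The reference recomputes the cross‑summaries from stored prefix snapshots for $\gamma=1$, mainly to confirm that the in‑place update order reads the previous‑step values.

```python
import numpy as np

def hla3_masked_stream(Q, K, V, gamma=1.0):
    """Masked third-order numerator, following the sketch's update rules."""
    n, d = Q.shape
    d_v = V.shape[1]
    SK = np.zeros((d, d)); SQ = np.zeros((d, d)); P = np.zeros((d, d_v))
    G1 = np.zeros((d, d_v)); G2 = np.zeros((d, d_v)); G3 = np.zeros((d, d_v))
    out = np.empty((n, d_v))
    for t in range(n):
        q, k, v = Q[t], K[t], V[t]
        # Cross-summaries read the previous S^K, S^Q, P, so they go first.
        u1 = SQ @ k
        G1 = gamma * G1 + np.outer(k, u1 @ P)
        a2 = SK @ q
        G2 = gamma * G2 + np.outer(a2, q @ P)
        a3 = SK @ u1
        G3 = gamma * G3 + np.outer(a3, v)
        # Advance the first-level summaries.
        SK = gamma * SK + np.outer(k, k)
        SQ = gamma * SQ + np.outer(q, q)
        P = gamma * P + np.outer(k, v)
        # Output via mat-vecs only (S^K and S^Q are symmetric).
        z = SQ @ (SK @ q)
        out[t] = z @ P - q @ (G1 + G2 + G3)
    return out

def hla3_masked_reference(Q, K, V):
    """Snapshot-based evaluation of the same quantities (gamma = 1)."""
    dK = np.einsum("ti,tj->tij", K, K)   # k_t k_t^T
    dQ = np.einsum("ti,tj->tij", Q, Q)   # q_t q_t^T
    dP = np.einsum("ti,tj->tij", K, V)   # k_t v_t^T
    SK, SQ, P = (np.cumsum(x, axis=0) for x in (dK, dQ, dP))
    SKp, SQp, Pp = SK - dK, SQ - dQ, P - dP   # previous-step values
    G = np.cumsum(dK @ SQp @ Pp + SKp @ dQ @ Pp + SKp @ SQp @ dP, axis=0)
    return np.einsum("ti,tij->tj", Q, SK @ SQ @ P - G)

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((6, 3)) for _ in range(3))
print(np.allclose(hla3_masked_stream(Q, K, V),
                  hla3_masked_reference(Q, K, V)))
```

Every per‑token operation is a mat‑vec or an outer product, so no $d\times d\times d$ contraction appears in the streaming path.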

Implementation & Complexity

  • Drop‑in mixer: replace the standard attention sublayer; keep positional encodings and masking unchanged.
  • Per‑head state (2nd‑order): $S^K\in\mathbb{R}^{d\times d}$, $C^{QV}\in\mathbb{R}^{d\times d_v}$, $M^Q\in\mathbb{R}^d$ (plus masked $G\in\mathbb{R}^{d\times d_v}$, $H\in\mathbb{R}^d$).
  • Per‑token cost (2nd‑order): $O(d^2{+}d\,d_v)$, dominated by the mat‑vecs $\bm{q}_t^\top S^K$, $(\bm{q}_t^\top S^K)\,C^{QV}$, and $(\bm{q}_t^\top S^K)\,M^Q$; the matrix product $S^K C^{QV}$ is never formed, and masked cross‑terms via $\bm{k}^\top X$ avoid cubic cost.
  • Multi‑query keys/values: share $\bm{K},\bm{V}$ across heads to store $S^K$ once per layer; memory $O(d^2{+}h\,d\,d_v)$.
  • Packed symmetric $S^K$: store the upper triangle to reduce bandwidth without changing the algebra.
  • Training throughput: within‑chunk exclusive Blelloch scans (span $O(\log w)$); inter‑chunk scans across $B_c$ chunks reuse the same operator.
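As a worked example of the state sizes listed above, the following helper counts stored scalars per head. The function name and the choice $d=d_v=64$ are illustrative assumptions, not figures from the paper.

```python
def hla2_state_floats(d, d_v, masked=True, packed_sym=False):
    """Scalars stored per head for second-order HLA."""
    # S^K: full d x d, or its upper triangle when packed.
    s = d * (d + 1) // 2 if packed_sym else d * d
    total = s + d * d_v + d          # S^K, C^{QV}, M^Q
    if masked:
        total += d * d_v + d         # cross-summaries G and H
    return total

print(hla2_state_floats(64, 64))                   # 12416
print(hla2_state_floats(64, 64, packed_sym=True))  # 10400
```

The count is independent of the sequence length, and packing $S^K$ saves $d(d-1)/2$ scalars per stored copy.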

Citation

If you find this work useful, please cite:

@article{zhang2025higher,
   title   = {Higher-order Linear Attention},
   author  = {Zhang, Yifan and Qin, Zhen and Gu, Quanquan},
   journal = {arXiv preprint arXiv:2510.27258},
   year    = {2025}
}