Revisiting Variance Reduction in Policy Gradients for LLM Reinforcement Learning

Yifan Zhang, Quanquan Gu
December 27, 2025
Tags: Reinforcement Learning, LLM RL, Variance Reduction, Optimization

The Variance Bottleneck

Reinforcement Learning (RL) has become central to aligning Large Language Models (LLMs) with complex reasoning tasks. While recent advancements have refined KL-regularized policy gradient objectives, the high variance inherent in gradient estimators remains a persistent bottleneck. This often necessitates prohibitively large batch sizes or conservative update steps, impeding sample efficiency.

The Contribution: This work revisits the principles of Stochastic Variance-Reduced Policy Gradient (SVRPG) and adapts them to the large-scale LLM Reinforcement Learning problem. We propose a variance-reduced estimator utilizing periodic policy snapshots to construct a control variate specifically for the KL-regularized objective.

We reconcile standard LLM alignment with SVRPG (Papini et al., 2018), demonstrating that a periodic snapshot mechanism can be efficiently integrated to stabilize learning without incurring prohibitive memory overheads.

Theoretical Framework

To apply stochastic variance reduction effectively, we require a rigorous definition of the gradient for the KL-regularized objective under off-policy sampling: $$J(\theta) = \mathbb{E}_{x\sim\pi_\theta}[R(x)] - \beta\,\mathrm{KL}(\pi_\theta\|\pi_{\mathrm{old}})$$

Unnormalized Divergences

In reasoning tasks, we often deal with unnormalized objectives. We derive exact gradients for both Unnormalized Forward KL (UFKL) and Unnormalized Reverse KL (URKL). The Unnormalized Forward KL is defined as:

$$ \text{UKL}(\pi_{\mathrm{old}}\|\pi_\theta) = \int_x \pi_{\mathrm{old}}(x)\log\frac{\pi_{\mathrm{old}}(x)}{\pi_\theta(x)}\,dx + \int_x \bigl(\pi_\theta(x)-\pi_{\mathrm{old}}(x)\bigr)\,dx $$
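This generalized divergence is non-negative even when the densities are unnormalized, and it reduces to the standard forward KL when both integrate to one. A discretized pure-Python sketch (toy masses chosen for illustration) checks these properties numerically:

```python
import math

def ukl(p_old, p_theta):
    """Unnormalized forward KL between two (possibly unnormalized) discrete
    densities: sum p_old*log(p_old/p_theta) + sum (p_theta - p_old)."""
    kl_term = sum(po * math.log(po / pt) for po, pt in zip(p_old, p_theta))
    mass_term = sum(pt - po for po, pt in zip(p_old, p_theta))
    return kl_term + mass_term

# Toy unnormalized densities on three points (masses need not sum to 1).
p_old = [0.5, 1.0, 0.3]
p_theta = [0.6, 0.8, 0.4]

print(ukl(p_old, p_old))    # 0.0: a density has zero divergence from itself
print(ukl(p_old, p_theta))  # strictly positive
```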

Differentiable Surrogate Losses

Crucially for implementation in frameworks like PyTorch, we derive differentiable surrogate losses $\mathcal{L}(\theta)$ such that $\nabla_\theta \mathcal{L}(\theta)$ is an unbiased estimator of $-\nabla_\theta J(\theta)$. For the Unnormalized Reverse KL (URKL), theoretically equivalent to the $k_3$ estimator used in methods like GRPO, the surrogate loss is:

$$ \mathcal{L}_{\mathrm{URKL}}(\theta) = Z_{\mathrm{old}} \mathbb{E}_{x\sim\tilde{\pi}_{\mathrm{old}}}\left[ -w(x)R(x) + \beta\bigl(w(x)\log w(x) - w(x)\bigr) \right] $$

Where $w(x) = \frac{\pi_\theta(x)}{\pi_{\mathrm{old}}(x)}$ is the importance weight.
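A Monte-Carlo estimate of this surrogate can be sketched in a few lines. In practice `logp_theta` would carry PyTorch autograd state so that backpropagating through the loss yields the gradient estimator; the pure-Python version below (with hypothetical log-probabilities and rewards) only evaluates the loss value:

```python
import math

def urkl_surrogate_loss(logp_theta, logp_old, rewards, beta, z_old=1.0):
    """Monte-Carlo estimate of the URKL surrogate
    L(theta) = Z_old * E_{x~pi_old}[ -w*R + beta*(w*log w - w) ],
    where w = pi_theta(x)/pi_old(x) is the importance weight."""
    total = 0.0
    for lt, lo, r in zip(logp_theta, logp_old, rewards):
        log_w = lt - lo          # log importance weight
        w = math.exp(log_w)
        total += -w * r + beta * (w * log_w - w)
    return z_old * total / len(rewards)

# Hypothetical per-sample log-probs and rewards, for illustration only.
loss = urkl_surrogate_loss(
    logp_theta=[-1.2, -0.7, -2.0],
    logp_old=[-1.0, -0.9, -1.8],
    rewards=[1.0, 0.0, 0.5],
    beta=0.1,
)
```

A useful sanity check: when $\theta$ equals the old policy, $w=1$ and $\log w = 0$, so the loss reduces to $-\mathbb{E}[R] - \beta$.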

Stochastic Variance Reduction

Standard estimators (like REINFORCE) suffer from high variance in reasoning tasks where rewards are sparse. We mitigate this using a control variate technique.

$$ \mathbf{v}_k = \mathbf{g}(\tau; \theta_k) - \rho_k(\tau) \mathbf{g}(\tau; \tilde{\theta}) + \mu $$

Snapshot Policy ($\tilde{\theta}$)

A "lagging anchor" policy, refreshed periodically. As $\theta_k$ converges toward $\tilde{\theta}$, the importance weight $\rho_k(\tau)$ tends to 1 and the two stochastic gradient terms cancel, leaving the low-variance anchor $\mu$.

Exact Gradient ($\mu$)

The exact gradient of the objective at the snapshot parameters, approximated via a large anchor batch ($B_L$) to serve as the stable baseline.

Importance Sampling ($\rho$)

Weights $\rho_k(\tau)$ correct the distribution mismatch between the current policy and the snapshot.
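Together these three components keep the estimator unbiased: since samples are drawn from the current policy, the importance-weighted snapshot term has expectation exactly $\mu$, so the correction cancels in expectation. A toy enumeration over a discrete outcome space (hypothetical numbers, pure Python) verifies this exactly:

```python
def expectation(probs, values):
    """Exact expectation over a finite outcome space."""
    return sum(p * v for p, v in zip(probs, values))

# Two policies over three outcomes (illustrative values).
pi_cur = [0.5, 0.3, 0.2]   # current policy pi_{theta_k}
pi_snap = [0.4, 0.4, 0.2]  # snapshot policy pi_{theta_tilde}
g_cur = [1.0, -0.5, 2.0]   # per-outcome gradient terms at theta_k
g_snap = [0.8, -0.4, 1.5]  # per-outcome gradient terms at theta_tilde

mu = expectation(pi_snap, g_snap)                    # exact anchor gradient
rho = [ps / pc for ps, pc in zip(pi_snap, pi_cur)]   # importance weights

# E_{pi_cur}[ g_cur - rho * g_snap ] + mu  should equal  E_{pi_cur}[ g_cur ],
# because E_{pi_cur}[ rho * g_snap ] = E_{pi_snap}[ g_snap ] = mu.
corrected = expectation(
    pi_cur, [gc - r * gs for gc, r, gs in zip(g_cur, rho, g_snap)]
) + mu
target = expectation(pi_cur, g_cur)
```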

Algorithm

We propose an interleaved update schedule. To handle the instability of importance weights in high-dimensional token spaces, we apply a Dual-Clip strategy.

Figure 1: The SVRPG algorithm logic. The snapshot phase computes a stable baseline ($\hat{\mu}$) on a large batch, which is used to correct the variance of the updates in the inner loop.

Update Logic

  1. Snapshot Phase (Anchor): Sample a "Large Batch" $\mathcal{D}_L$ from snapshot $\pi_{\tilde{\theta}}$ and compute $\hat{\mu}$.
  2. Inner Loop: Sample a "Small Batch" $\mathcal{D}_S$ from $\pi_{\theta_t}$, compute Importance Sampling weights $\rho(\tau)$ with Dual-Clip truncation, and construct the variance-reduced gradient direction:
$$ \mathbf{v}_t = \frac{1}{B_S} \sum_{\tau \in \mathcal{D}_S} \left[\nabla \mathcal{L}_{\text{Reg}}(\tau; \theta_t) - \rho(\tau) \nabla \mathcal{L}_{\text{Reg}}(\tau; \tilde{\theta}) \right] + \hat{\mu} $$

This approach decouples the variance of the reasoning reward from the model's structural updates, isolating the contribution of the recent parameter shift.
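The inner-loop direction can be sketched on scalar per-sample gradients (real gradients are parameter tensors, and the repository's exact Dual-Clip rule may differ; here we assume a simple two-sided truncation of $\rho$ into $[1/c, c]$ with a hypothetical $c = 3$):

```python
def dual_clip(rho, c=3.0):
    """Two-sided truncation of the importance weight into [1/c, c].
    A simple stand-in for the Dual-Clip strategy described in the text;
    the exact clipping rule is an assumption here."""
    return max(1.0 / c, min(rho, c))

def vr_direction(grad_cur, grad_snap, rho, mu_hat, c=3.0):
    """Variance-reduced direction for one small batch:
    v_t = mean[ grad_L(tau; theta_t) - rho(tau) * grad_L(tau; theta_tilde) ]
          + mu_hat
    (scalar sketch of the per-parameter computation)."""
    n = len(grad_cur)
    corr = sum(gc - dual_clip(r, c) * gs
               for gc, gs, r in zip(grad_cur, grad_snap, rho)) / n
    return corr + mu_hat
```

Note the limiting behavior: when the current policy matches the snapshot, every $\rho = 1$ and the per-sample gradients coincide, so the correction vanishes and the direction reduces to the stable anchor $\hat{\mu}$.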

Normalized KL Formulations

For completeness, the repository also supports normalized KL objectives, suitable for standard RLHF workflows.

Surrogate losses (sampling $x\sim \pi_{\mathrm{old}}$, with $w(x) = \pi_\theta(x)/\pi_{\mathrm{old}}(x)$):

Forward KL: $\mathbb{E}\left[ -w(x) R(x) - \beta \log \pi_\theta(x) \right]$

Reverse KL: $\mathbb{E}\left[ w(x)\,(-R(x) + \beta \log w(x)) \right]$
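Both normalized surrogates admit the same Monte-Carlo treatment as the unnormalized case. A pure-Python sketch (in practice the log-probabilities would come from the model with autograd attached):

```python
import math

def forward_kl_surrogate(logp_theta, logp_old, rewards, beta):
    """E_{x~pi_old}[ -w*R - beta*log pi_theta(x) ], w = pi_theta/pi_old."""
    total = 0.0
    for lt, lo, r in zip(logp_theta, logp_old, rewards):
        w = math.exp(lt - lo)
        total += -w * r - beta * lt
    return total / len(rewards)

def reverse_kl_surrogate(logp_theta, logp_old, rewards, beta):
    """E_{x~pi_old}[ w * (-R + beta*log w) ]."""
    total = 0.0
    for lt, lo, r in zip(logp_theta, logp_old, rewards):
        log_w = lt - lo
        total += math.exp(log_w) * (-r + beta * log_w)
    return total / len(rewards)
```

At $\theta = \theta_{\mathrm{old}}$ (so $w = 1$), the reverse-KL surrogate reduces to $-\mathbb{E}[R]$ while the forward-KL surrogate retains the entropy-like $-\beta\,\mathbb{E}[\log\pi_\theta]$ term.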

Citation

@article{zhang2025revisiting,
  title   = {Revisiting Variance Reduction in Policy Gradients for LLM Reinforcement Learning},
  author  = {Zhang, Yifan and Gu, Quanquan},
  year    = {2025},
  month   = {Dec},
  journal = {GitHub},
  url     = {https://yifanzhang-pro.github.io/Revisiting-SVRPG-LLM-RL}
}