Authors: Jiacai Liu, Zhuo Jiang, Yuqian Fu
\*Co-First Authors.
On-policy distillation (OPD) minimizes the sequence-level reverse-KL objective
$$ J_{\text{OPD}}(\theta) := \mathbb{E}_{x\sim D}\left[ D_{\mathrm{KL}}\left(\pi_\theta(\cdot \mid x)\,\|\,q(\cdot \mid x)\right) \right], $$
where $D$ is the prompt distribution, and $\pi_\theta$ and $q$ denote the student and teacher policies, respectively. For a student-generated response $y = (y_1, \dots, y_T) \sim \pi_\theta(\cdot \mid x)$ of length $T$, let $c_t := (x, y_{<t})$ denote the context at step $t$, let $y_t$ denote the token sampled at step $t$, and define
$$ r_t := \log \frac{\pi_\theta(y_t \mid c_t)}{q(y_t \mid c_t)}. $$
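Since $D_{\mathrm{KL}}\left(\pi_\theta(\cdot \mid x)\,\|\,q(\cdot \mid x)\right) = \mathbb{E}_{y\sim\pi_\theta(\cdot \mid x)}\left[-\log q(y \mid x)\right] - \mathcal{H}\left(\pi_\theta(\cdot \mid x)\right)$, minimizing this objective is a special case of an entropy-regularized finite-horizon RL problem, with per-token reward $\log q(y_t \mid c_t)$ and unit entropy coefficient. OPD therefore admits a policy-gradient formulation. Below we present several policy-gradient expressions that are equal in expectation, together with their proofs.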
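The per-token log ratio $r_t$ only needs the log-probabilities that the student and the teacher assign to the sampled tokens. The sketch below is a minimal illustration, assuming these log-probabilities have already been gathered into two aligned lists; the function names are hypothetical. Summing $r_t$ over the response gives $\log \pi_\theta(y \mid x) - \log q(y \mid x)$, whose expectation under $y \sim \pi_\theta(\cdot \mid x)$ is exactly the per-prompt reverse KL, so it is an unbiased one-sample estimate of it.

```python
import math
from typing import List, Sequence


def per_token_log_ratios(student_logps: Sequence[float],
                         teacher_logps: Sequence[float]) -> List[float]:
    """r_t = log pi_theta(y_t | c_t) - log q(y_t | c_t) for every sampled token y_t."""
    if len(student_logps) != len(teacher_logps):
        raise ValueError("student and teacher log-probs must be aligned token by token")
    return [lp_s - lp_q for lp_s, lp_q in zip(student_logps, teacher_logps)]


def sequence_log_ratio(student_logps: Sequence[float],
                       teacher_logps: Sequence[float]) -> float:
    """sum_t r_t = log pi_theta(y | x) - log q(y | x): a one-sample reverse-KL estimate."""
    return sum(per_token_log_ratios(student_logps, teacher_logps))


# Hypothetical two-token response: the student is twice as confident as the teacher at each step.
student = [math.log(0.5), math.log(0.8)]
teacher = [math.log(0.25), math.log(0.4)]
print(per_token_log_ratios(student, teacher))  # each entry is log 2 (up to rounding)
print(sequence_log_ratio(student, teacher))    # log 4 (up to rounding)
```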
Theorem 1 (Policy Gradient of OPD). For any parameter $\theta$, the gradient of the OPD objective can be written as
$$ \nabla_\theta J_{\text{OPD}}(\theta) = \mathbb{E}_{x\sim D}\, \mathbb{E}_{y\sim \pi_\theta(\cdot \mid x)} \left[ \sum_{t=1}^{T} \hat A_\theta(c_t,y_t)\, \nabla_\theta \log \pi_\theta(y_t \mid c_t) \right], $$
where the per-token coefficient $\hat A_\theta(c_t,y_t)$, which is held fixed (not differentiated) when taking $\nabla_\theta$, can be any of the following:
(i) $\sum_{t^{\prime}=1}^{T}{\log \frac{\pi_\theta\left( y_{t^{\prime}}|c_{t^{\prime}} \right)}{q\left( y_{t^{\prime}}|c_{t^{\prime}} \right)}}=\log \pi_\theta\left( y|x \right)-\log q\left( y|x \right)$: the log ratio of the full trajectory.
(ii) $\sum_{t^{\prime}=t}^{T}{\log \frac{\pi_\theta\left( y_{t^{\prime}}|c_{t^{\prime}} \right)}{q\left( y_{t^{\prime}}|c_{t^{\prime}} \right)}}$: log-ratio-to-go.
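(iii) $\sum_{t^{\prime}=t}^{T}{\log \frac{\pi_\theta\left( y_{t^{\prime}}|c_{t^{\prime}} \right)}{q\left( y_{t^{\prime}}|c_{t^{\prime}} \right)}} - b(c_t)$: log-ratio-to-go with a baseline $b(c_t)$, which may be any function of the context $c_t$ alone (it must not depend on $y_t$ or later tokens).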
(iv) $\log \frac{\pi_\theta\left( y_t|c_t \right)}{q\left( y_t|c_t \right)} + \sum_{t^{\prime}=t+1}^{T}{D_{\mathrm{KL}}\left( \pi_\theta\left( \cdot |c_{t^{\prime}} \right) \,\|\, q\left( \cdot |c_{t^{\prime}} \right) \right)}$: the current log ratio plus the per-step KL divergences at the subsequent visited contexts.
(v) $\log \frac{\pi_\theta\left( y_t|c_t \right)}{q\left( y_t|c_t \right)} + \sum_{t^{\prime}=t+1}^{T}{D_{\mathrm{KL}}\left( \pi_\theta\left( \cdot |c_{t^{\prime}} \right) \,\|\, q\left( \cdot |c_{t^{\prime}} \right) \right)} - b(c_t)$: baselined version of (iv).
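For a single sampled response, all five coefficients are simple reductions over the per-token quantities. The following is a minimal Python sketch with 0-indexed steps, assuming the caller supplies the log ratios `r[t]`, the full-vocabulary KL `kl[t]` at each visited context $c_t$, and, for (iii) and (v), a baseline value `baseline[t]` per context. `reverse_cumsum` and `opd_coefficients` are hypothetical helper names, not an existing API.

```python
from typing import List, Optional, Sequence


def reverse_cumsum(xs: Sequence[float]) -> List[float]:
    """out[t] = xs[t] + xs[t+1] + ... + xs[T-1]."""
    out, running = [0.0] * len(xs), 0.0
    for t in range(len(xs) - 1, -1, -1):
        running += xs[t]
        out[t] = running
    return out


def opd_coefficients(r: Sequence[float], kl: Sequence[float], variant: str,
                     baseline: Optional[Sequence[float]] = None) -> List[float]:
    """Per-token coefficients hat{A}_t for variants (i)-(v) of Theorem 1 (0-indexed steps).

    r[t]        -- log pi_theta(y_t|c_t) - log q(y_t|c_t) for the sampled token
    kl[t]       -- KL(pi_theta(.|c_t) || q(.|c_t)) over the full vocabulary at c_t
    baseline[t] -- b(c_t), any value that depends on the context c_t only
    """
    T = len(r)
    if variant == "i":
        coef = [sum(r)] * T                      # full-trajectory log ratio at every step
    elif variant in ("ii", "iii"):
        coef = reverse_cumsum(r)                 # log-ratio-to-go
    elif variant in ("iv", "v"):
        future = reverse_cumsum(kl) + [0.0]      # future[t + 1] = sum of kl[t'] for t' > t
        coef = [r[t] + future[t + 1] for t in range(T)]  # kl[0] is never used
    else:
        raise ValueError(f"unknown variant {variant!r}")
    if variant in ("iii", "v"):
        if baseline is None:
            raise ValueError("variants (iii) and (v) need a baseline b(c_t)")
        coef = [a - b for a, b in zip(coef, baseline)]
    return coef


# Hypothetical per-token values for a 3-token response.
r, kl, base = [1.0, 2.0, 3.0], [0.5, 0.25, 0.125], [1.0, 1.0, 1.0]
for v in ("i", "ii", "iii", "iv", "v"):
    print(v, opd_coefficients(r, kl, v, baseline=base))
```

Since $J_{\text{OPD}}$ is minimized, the surrogate loss $\sum_{t} \hat A_t \log \pi_\theta(y_t \mid c_t)$, with $\hat A_t$ held fixed, has the bracketed term of Theorem 1 as its gradient. Variants (iv) and (v) need the full next-token distributions of both models at every visited context, whereas (i) to (iii) only need the log-probabilities of the sampled tokens.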
Proof for (i)
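A direct computation yields the chain below, where the second equality uses the log-derivative trick $\nabla_\theta \pi_\theta(y \mid x) = \pi_\theta(y \mid x)\,\nabla_\theta \log \pi_\theta(y \mid x)$ together with the fact that $q$ does not depend on $\theta$, the third uses $\mathbb{E}_{y\sim\pi_\theta(\cdot \mid x)}\left[\nabla_\theta \log \pi_\theta(y \mid x)\right] = 0$, and the last two use the autoregressive factorizations $\log \pi_\theta(y \mid x) = \sum_{t=1}^{T} \log \pi_\theta(y_t \mid c_t)$ and $\log q(y \mid x) = \sum_{t=1}^{T} \log q(y_t \mid c_t)$: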
$$ \small{ \begin{align*} \nabla_\theta J_{\text{OPD}}(\theta) &= \nabla_\theta \mathbb{E}_{x\sim D} \mathbb{E}_{y\sim \pi_\theta(\cdot \mid x)} \left[ \log \pi_\theta(y \mid x) - \log q(y \mid x) \right] \\ &= \mathbb{E}_{x\sim D} \mathbb{E}_{y\sim \pi_\theta(\cdot \mid x)} \left[ \big(\log \pi_\theta(y \mid x)-\log q(y \mid x)\big)\, \nabla_\theta \log \pi_\theta(y \mid x) + \nabla_\theta \log \pi_\theta(y \mid x) \right] \\ &= \mathbb{E}_{x\sim D} \mathbb{E}_{y\sim \pi_\theta(\cdot \mid x)} \left[ \big(\log \pi_\theta(y \mid x)-\log q(y \mid x)\big)\, \nabla_\theta \log \pi_\theta(y \mid x) \right] \\ &= \mathbb{E}_{x\sim D} \mathbb{E}_{y\sim \pi_\theta(\cdot \mid x)} \left[ \left( \sum_{t^{\prime}=1}^{T} \log \frac{\pi_\theta(y_{t^{\prime}} \mid c_{t^{\prime}})}{q(y_{t^{\prime}} \mid c_{t^{\prime}})} \right) \sum_{t=1}^{T} \nabla_\theta \log \pi_\theta(y_t \mid c_t) \right] \\ &= \mathbb{E}_{x\sim D} \mathbb{E}_{y\sim \pi_\theta(\cdot \mid x)} \left[ \sum_{t=1}^{T} \left( \sum_{t^{\prime}=1}^{T} \log \frac{\pi_\theta(y_{t^{\prime}} \mid c_{t^{\prime}})}{q(y_{t^{\prime}} \mid c_{t^{\prime}})} \right) \cdot \nabla_\theta \log \pi_\theta(y_t \mid c_t) \right]. \end{align*}} $$
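Because the derivation only uses the log-derivative trick and the autoregressive factorization, it can be checked numerically on a toy problem. The sketch below assumes a single fixed prompt $x$, a two-token vocabulary, responses of length $T = 2$, and a tabular softmax student whose logits `theta[c]` are indexed by the prefix $c_t = y_{<t}$; these choices and all numbers are hypothetical. It enumerates every response to compute $J_{\text{OPD}}$ and the expectation of estimator (i) exactly, then compares the latter with a central finite-difference gradient of $J_{\text{OPD}}$.

```python
import itertools
import math

# Hypothetical toy setup: one fixed prompt x, vocabulary {0, 1}, responses of length T = 2,
# and tabular policies indexed by the prefix c_t = y_{<t}.
V, T = 2, 2
PREFIXES = [(), (0,), (1,)]


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]


def seq_logprob(probs, y):
    """log p(y|x) = sum_t log p(y_t | c_t) for a tabular autoregressive policy."""
    return sum(math.log(probs[y[:t]][y[t]]) for t in range(T))


def opd_objective(theta, q):
    """J(theta) = sum_y pi(y|x) * (log pi(y|x) - log q(y|x)), computed exactly."""
    pi = {c: softmax(theta[c]) for c in PREFIXES}
    total = 0.0
    for y in itertools.product(range(V), repeat=T):
        lp = seq_logprob(pi, y)
        total += math.exp(lp) * (lp - seq_logprob(q, y))
    return total


def finite_difference_grad(theta, q, eps=1e-6):
    """Central differences of J with respect to every logit theta[c][a]."""
    grad = {c: [0.0] * V for c in PREFIXES}
    for c in PREFIXES:
        for a in range(V):
            plus = {k: list(v) for k, v in theta.items()}
            minus = {k: list(v) for k, v in theta.items()}
            plus[c][a] += eps
            minus[c][a] -= eps
            grad[c][a] = (opd_objective(plus, q) - opd_objective(minus, q)) / (2 * eps)
    return grad


def formula_i_grad(theta, q):
    """Exact expectation of estimator (i), enumerating every response y."""
    pi = {c: softmax(theta[c]) for c in PREFIXES}
    grad = {c: [0.0] * V for c in PREFIXES}
    for y in itertools.product(range(V), repeat=T):
        lp = seq_logprob(pi, y)
        coef = lp - seq_logprob(q, y)  # coefficient (i): the same for every step t
        for t in range(T):
            c = y[:t]
            for a in range(V):
                # d log softmax(theta[c])[y_t] / d theta[c][a] = 1[a == y_t] - pi(a|c)
                grad[c][a] += math.exp(lp) * coef * (float(a == y[t]) - pi[c][a])
    return grad


# Hypothetical student logits and teacher probabilities.
theta = {(): [0.3, -0.2], (0,): [1.0, 0.0], (1,): [-0.5, 0.4]}
q = {(): [0.6, 0.4], (0,): [0.3, 0.7], (1,): [0.5, 0.5]}

fd, exact = finite_difference_grad(theta, q), formula_i_grad(theta, q)
for c in PREFIXES:
    print(c, fd[c], exact[c])
```

The two gradients should agree up to finite-difference error; exhaustive enumeration is only feasible for such tiny toys, which is why practical OPD relies on sampled responses instead.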
Proof for (ii)
Consider two fixed time indices $1 \le t^{\prime} < t \le T$. Then
$$ \small{ \begin{align*} \mathbb{E} \left[ \log \frac{\pi_\theta(y_{t^{\prime}} \mid c_{t^{\prime}})}{q(y_{t^{\prime}} \mid c_{t^{\prime}})} \cdot \nabla_\theta \log \pi_\theta(y_t \mid c_t) \right] &= \mathbb{E} \left[ \mathbb{E} \left[ \log \frac{\pi_\theta(y_{t^{\prime}} \mid c_{t^{\prime}})}{q(y_{t^{\prime}} \mid c_{t^{\prime}})} \cdot \nabla_\theta \log \pi_\theta(y_t \mid c_t) \,\middle|\, (c_{t^{\prime}}, y_{t^{\prime}}) \right] \right] \\ &= \mathbb{E} \left[ \log \frac{\pi_\theta(y_{t^{\prime}} \mid c_{t^{\prime}})}{q(y_{t^{\prime}} \mid c_{t^{\prime}})} \cdot \mathbb{E} \left[ \nabla_\theta \log \pi_\theta(y_t \mid c_t) \,\middle|\, (c_{t^{\prime}}, y_{t^{\prime}}) \right] \right] \\ &= \mathbb{E} \left[ \log \frac{\pi_\theta(y_{t^{\prime}} \mid c_{t^{\prime}})}{q(y_{t^{\prime}} \mid c_{t^{\prime}})} \cdot 0 \right] \\ &= 0, \end{align*}} $$
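where all expectations are over $x\sim D$ and $y\sim \pi_\theta(\cdot \mid x)$. The third equality follows from the tower property: because $t^{\prime} < t$, the context $c_t$ contains $(c_{t^{\prime}}, y_{t^{\prime}})$, so $\mathbb{E}\left[\nabla_\theta \log \pi_\theta(y_t \mid c_t) \,\middle|\, (c_{t^{\prime}}, y_{t^{\prime}})\right] = \mathbb{E}\left[\mathbb{E}\left[\nabla_\theta \log \pi_\theta(y_t \mid c_t) \,\middle|\, c_t\right] \,\middle|\, (c_{t^{\prime}}, y_{t^{\prime}})\right]$, and the inner expectation vanishes since $\sum_{y_t} \pi_\theta(y_t \mid c_t)\, \nabla_\theta \log \pi_\theta(y_t \mid c_t) = \nabla_\theta \sum_{y_t} \pi_\theta(y_t \mid c_t) = \nabla_\theta 1 = 0$. This is a causality property: the score $\nabla_\theta \log \pi_\theta(y_t \mid c_t)$ of the token sampled at step $t$ is uncorrelated with the log ratios of earlier tokens, so the gradient at step $t$ does not need to account for the KL contributions of previously generated tokens, only those of the current and subsequent tokens. Dropping the zero-mean terms with $t^{\prime} < t$ from the final expression in the proof of (i) gives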
$$ \small{ \begin{align*} \nabla_\theta J_{\text{OPD}}(\theta) &= \mathbb{E}_{x\sim D} \mathbb{E}_{y\sim \pi_\theta(\cdot \mid x)} \left[ \sum_{t=1}^{T} \left( \sum_{t^{\prime}=1}^{T} \log \frac{\pi_\theta(y_{t^{\prime}} \mid c_{t^{\prime}})}{q(y_{t^{\prime}} \mid c_{t^{\prime}})} \right) \cdot \nabla_\theta \log \pi_\theta(y_t \mid c_t) \right] \\ &= \mathbb{E}_{x\sim D} \mathbb{E}_{y\sim \pi_\theta(\cdot \mid x)} \left[ \sum_{t=1}^{T} \left( \sum_{t^{\prime}=t}^{T} \log \frac{\pi_\theta(y_{t^{\prime}} \mid c_{t^{\prime}})}{q(y_{t^{\prime}} \mid c_{t^{\prime}})} \right) \cdot \nabla_\theta \log \pi_\theta(y_t \mid c_t) \right]. \end{align*}} $$
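The causality argument can also be checked exactly on a small example. This sketch assumes a hypothetical toy: one fixed prompt, vocabulary $\{0, 1\}$, responses of length $T = 3$, tabular softmax student and teacher with logits drawn from a seeded generator, and an arbitrary context-only baseline `b` invented for illustration. It enumerates all $2^3$ responses to compute the exact expected gradient for each coefficient of Theorem 1, together with the expected contribution of the dropped past terms $\sum_{t^{\prime} < t} r_{t^{\prime}}$, which the proof above says is zero.

```python
import itertools
import math
import random

# Hypothetical toy setup: one fixed prompt x, vocabulary {0, 1}, responses of length T = 3,
# and tabular policies indexed by the prefix c_t = y_{<t}.
V, T = 2, 3
PREFIXES = [p for t in range(T) for p in itertools.product(range(V), repeat=t)]


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]


rng = random.Random(0)
pi = {c: softmax([rng.uniform(-1, 1) for _ in range(V)]) for c in PREFIXES}  # student
q = {c: softmax([rng.uniform(-1, 1) for _ in range(V)]) for c in PREFIXES}   # teacher


def r(y, t):
    """Per-token log ratio r_t at the visited context c_t = y[:t]."""
    c = y[:t]
    return math.log(pi[c][y[t]] / q[c][y[t]])


def kl(c):
    """KL(pi_theta(.|c) || q(.|c)) over the full vocabulary."""
    return sum(p * math.log(p / qc) for p, qc in zip(pi[c], q[c]))


def b(c):
    """Hypothetical baseline: any function of the context c_t alone."""
    return 0.3 * len(c) - 0.1 * sum(c)


def expected_grad(coef):
    """Exact E_y[sum_t coef(y, t) * grad log pi_theta(y_t|c_t)] w.r.t. the softmax logits."""
    grad = {c: [0.0] * V for c in PREFIXES}
    for y in itertools.product(range(V), repeat=T):
        prob = math.prod(pi[y[:t]][y[t]] for t in range(T))
        for t in range(T):
            c, w = y[:t], coef(y, t)
            for a in range(V):
                # d log softmax(theta[c])[y_t] / d theta[c][a] = 1[a == y_t] - pi(a|c)
                grad[c][a] += prob * w * (float(a == y[t]) - pi[c][a])
    return grad


def to_go(y, t):
    return sum(r(y, s) for s in range(t, T))


def future_kl(y, t):
    return sum(kl(y[:s]) for s in range(t + 1, T))


coefs = {
    "i": lambda y, t: sum(r(y, s) for s in range(T)),
    "ii": to_go,
    "iii": lambda y, t: to_go(y, t) - b(y[:t]),
    "iv": lambda y, t: r(y, t) + future_kl(y, t),
    "v": lambda y, t: r(y, t) + future_kl(y, t) - b(y[:t]),
}
grads = {name: expected_grad(f) for name, f in coefs.items()}

# The dropped past terms sum_{t' < t} r_{t'} should have zero expected contribution.
past = expected_grad(lambda y, t: sum(r(y, s) for s in range(t)))


def max_gap(g1, g2):
    return max(abs(g1[c][a] - g2[c][a]) for c in PREFIXES for a in range(V))


zero = {c: [0.0] * V for c in PREFIXES}
for name in grads:
    print(name, max_gap(grads[name], grads["i"]))
print("past terms", max_gap(past, zero))
```

All printed gaps should sit at floating-point rounding level. The variants agree in expectation, not sample by sample, so their single-sample estimates still differ.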