TDPO
TDPO算法与KL约束在强化学习中的应用
元数据:
- 分类:算法与优化
- 标签:TDPO, KL约束, 强化学习, PPO, 前向KL
- 日期:2025年4月12日
TDPO与PPO中的KL约束
TDPO(Trust Region Policy Optimization)算法通过引入PPO(Proximal Policy Optimization)中的KL约束来优化策略。不同于PPO使用的backward KL,TDPO采用forward KL来计算KL惩罚。这种选择的原因在于KL距离的非对称性:forward KL旨在尽可能覆盖整个分布的大部分,而backward KL则专注于拟合分布中的某一部分。
TDPO的优势
由于TDPO使用forward KL进行训练,其模型在输出多样性上更为自由。相比之下,PPO训练后的模型输出风格趋于一致,因为输出分布已聚集到一个局部分布上,导致reward方差小于SFT(Softmax Function Transformation)。
💡 启发点:TDPO在多样性输出上的优势使其在需要多种可能性探索的任务中表现更佳。
代码示例与计算步骤
在实现TDPO时,forward KL的计算方式可以通过以下代码实现:
vocab_logps = logits.log_softmax(-1)
reference_vocab_ps = reference_logits.softmax(-1)
reference_vocab_logps = reference_vocab_ps.log()
# Forward KL 计算
per_position_kl = (reference_vocab_ps * (reference_vocab_logps - vocab_logps)).sum(-1)
per_token_logps = torch.gather(vocab_logps, dim=2, index=labels.unsqueeze(2)).squeeze(2)
per_reference_token_logps = torch.gather(reference_vocab_logps, dim=2, index=labels.unsqueeze(2)).squeeze(2)
操作步骤
- ✅ 初始化策略模型与参考模型的logits。
- ⚠ 计算logits的softmax并取log。
- ❗ 计算每个位置的forward KL值。
常见错误
注意:在实现TDPO时,务必确保forward KL计算的准确性,以避免模型输出的多样性不足。
行动清单
- 研究TDPO在不同任务中的表现。
- 比较TDPO和PPO在相同环境下的效果。
- 探索更多应用场景中的KL约束优化。
原始出处:[原文提供者未注明]
通过以上内容,我们总结了TDPO算法的核心概念及其与PPO的区别,特别是在KL约束的应用上。希望这篇笔记能够帮助你更好地理解TDPO算法及其优势。