拒绝采样(Rejection Sampling)


一句话总结

拒绝采样通过生成多个候选回答并选择最优的一个,是提升训练数据质量和推理输出质量的简单有效且广泛使用的方法。

核心概念

拒绝采样(也称Best-of-N)的流程:对每个输入生成N个候选回答(N通常为4-64),用奖励模型或评分函数对所有候选打分,选择得分最高的作为最终输出或训练数据。在训练中用于构造高质量SFT数据(Rejection Sampling Fine-tuning, RFT);在推理中用于提升单次输出质量。采样温度通常设较高值(0.7-1.0)以增加候选多样性,确保覆盖更多可能的回答空间。

为什么重要

拒绝采样是连接SFT和RLHF的桥梁——可以用奖励模型筛选数据后做SFT,效果接近RLHF但实现简单得多。Llama-2的对齐训练大量使用了这一技术来迭代提升数据质量。它也是迭代式自我改进的基础方法。

实践要点

N值越大质量越高但计算成本线性增长,实践中N=16-32是性价比最高的选择;奖励模型的质量是关键瓶颈,低质量奖励模型会选出错误答案;可以使用多个评分维度的加权组合作为选择标准。注意防范奖励模型的奖励黑客问题。

常见误区

误区一:N越大越好——存在明显的边际收益递减,且计算成本高昂,需要权衡成本效益。误区二:拒绝采样等于RLHF——它只是近似,缺少策略梯度优化的在线探索能力。