
DPO 算法核心原理简述
DPO 算法核心原理简述
DPO (Direct Preference Optimization, 直接偏好优化) 是大模型对齐领域的一次**“极简主义革命”**。
它的核心思想非常犀利:移除“中间商”,用解析解替代强化学习。
在传统的 RLHF 中,我们需要先训练一个“裁判”(奖励模型),再用复杂的 PPO 算法去讨好这个裁判。而 DPO 通过数学推导证明:最优策略模型本身,就是奖励模型的隐式表达。
1. 核心公式:从“打分”到“比较”
DPO 将复杂的强化学习问题简化为了一个二分类问题。它的目标非常直观:增大“好回答”与“坏回答”之间的概率差。
在 MiniMind 的代码实现中,损失函数 被定义为:
- : 当前训练的模型。
- : 冻结的参考模型(防止模型跑偏)。
- / : 人类偏好(Chosen)与 拒绝(Rejected)的回答。
- : 控制约束力度的超参数。
2. 算法逻辑:隐式奖励优化
代码不再需要维护独立的 Critic 网络,而是直接计算**“相对优势”**:
- 计算偏好差:分别计算当前模型和参考模型对“好坏答案”的概率差值(Log Ratios)。
- 构建 Logits:如果当前模型比参考模型更倾向于“好答案”,则 Logits 变大。
- Sigmoid 损失:通过
-log(sigmoid(logits))将优化目标转化为最小化损失。
在 MiniMind 中,采用 DPO 意味着我们彻底告别了 PPO 的显式采样 (Rollout)。训练过程变成了类似于 SFT 的监督学习,极大地降低了显存占用 (VRAM),并获得了如磐石般的训练稳定性。
全局引用与辅助函数 (Imports & Helper Functions)
与 PPO 不同,DPO 不需要复杂的 Actor-Critic 交互,也不需要额外的 Reward Model 进行推理时打分。它更多依赖于数学上的推导,将强化学习问题转化为二分类问题。
首先是必要的库导入和计算 Log Probabilities 的辅助函数。logits_to_log_probs 是 DPO 算法的基础工具,用于从模型输出的 logits 中提取特定 label 对应的对数概率。
👉 点击展开查看完整引用与环境初始化代码
import osimport sys
__package__ = "trainer"sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparseimport timeimport warningsimport torchimport torch.nn.functional as Fimport torch.distributed as distfrom contextlib import nullcontextfrom torch import optimfrom torch.nn.parallel import DistributedDataParallelfrom torch.utils.data import DataLoader, DistributedSamplerfrom model.model_minimind import MiniMindConfigfrom dataset.lm_dataset import DPODatasetfrom trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
warnings.filterwarnings('ignore')
def logits_to_log_probs(logits, labels): # logits shape: (batch_size, seq_len, vocab_size), 模型最后一层的原始打分 # labels shape: (batch_size, seq_len), 真实数据的索引 # log_probs shape: (batch_size, seq_len) log_probs = F.log_softmax(logits, dim=2) log_probs_per_token = torch.gather(log_probs, dim=2, index=labels.unsqueeze(2)).squeeze(-1) return log_probs_per_tokenDPO 核心损失函数 (DPO Loss)
这是本代码的灵魂所在。dpo_loss 实现了 DPO 论文中的核心公式。它计算了 Policy Model(当前训练模型)和 Reference Model(参考模型)在 “Chosen”(偏好数据)和 “Rejected”(非偏好数据)上的概率差值,并通过 LogSigmoid 计算损失。
def dpo_loss(ref_log_probs, policy_log_probs, mask, beta): # ref_log_probs 和 policy_log_probs 都是 shape: (batch_size, seq_len) # [https://github.com/jingyaogong/minimind/issues/298](https://github.com/jingyaogong/minimind/issues/298) # 统计一行有多少个1, 防止零长度mask导致除零NaN seq_lengths = mask.sum(dim=1, keepdim=True).clamp_min(1e-8) # 把padding位置的概率值清零, 只保留有效内容, sum把一句话所有Token的对数概率加起来。/seq_lengths做平均 # [2B, L]->[2B, 1] ref_log_probs = (ref_log_probs * mask).sum(dim=1) / seq_lengths.squeeze() # 同理 policy_log_probs = (policy_log_probs * mask).sum(dim=1) / seq_lengths.squeeze()
# 将 chosen 和 rejected 数据分开 batch_size = ref_log_probs.shape[0] # 前一半:参考模型对"好回答"的打分 chosen_ref_log_probs = ref_log_probs[:batch_size // 2] # 后一半:参考模型对"坏回答"的打分 reject_ref_log_probs = ref_log_probs[batch_size // 2:] # 前一半:训练模型对"好回答"的打分 chosen_policy_log_probs = policy_log_probs[:batch_size // 2] # 后一半:训练模型对"坏回答"的打分 reject_policy_log_probs = policy_log_probs[batch_size // 2:] # 策略模型对优选回答 vs 拒绝回答 的对数概率差 pi_logratios = chosen_policy_log_probs - reject_policy_log_probs # 参考模型对优选回答 vs 拒绝回答 的对数概率差 ref_logratios = chosen_ref_log_probs - reject_ref_log_probs logits = pi_logratios - ref_logratios # 损失函数 loss = -F.logsigmoid(beta * logits) # 取平均 return loss.mean()训练循环逻辑 (Training Loop)
train_epoch 函数定义了 DPO 的核心训练步骤。不同于 PPO 的采样过程,DPO 的训练更像标准的监督学习,但在 Loss 计算前需要同时获取 Policy Model 和 Reference Model 的输出。
1. 数据解包与双模型前向传播 (Forward Pass)
在这一步,我们从 DataLoader 中提取成对的偏好数据(Chosen/Rejected),并将它们拼接以进行批量处理。随后,分别计算 Reference Model(无梯度)和 Policy Model(有梯度)的 Log Probabilities。
def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=None, beta=0.1): start_time = time.time()
for step, batch in enumerate(loader, start=start_step + 1): # chosen与rejected的输入token x_chosen = batch['x_chosen'].to(args.device) x_rejected = batch['x_rejected'].to(args.device) # 对应监督目标 y_chosen = batch['y_chosen'].to(args.device) y_rejected = batch['y_rejected'].to(args.device) mask_chosen = batch['mask_chosen'].to(args.device) mask_rejected = batch['mask_rejected'].to(args.device) # 拼接成一个大batch x = torch.cat([x_chosen, x_rejected], dim=0) y = torch.cat([y_chosen, y_rejected], dim=0) mask = torch.cat([mask_chosen, mask_rejected], dim=0)
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate) for param_group in optimizer.param_groups: param_group['lr'] = lr # 混合精读上下文,减少显存 with autocast_ctx: # 关闭梯度记录,因为这里的参考模型只用来打分,不需要反向传播 with torch.no_grad(): # x:[2B, L] ref_outputs = ref_model(x) # ref_logits:[2B, L, V], 模型输出的Logits ref_logits = ref_outputs.logits # y:[2B, L], ref_log_probs:[2B, L], 每个位置只保留一个数,即该位置目标词的logp ref_log_probs = logits_to_log_probs(ref_logits, y) # 前向传播计算 outputs = model(x) # 获得分数[2B, L, V] logits = outputs.logits policy_log_probs = logits_to_log_probs(logits, y)2. 损失反向传播与参数更新 (Backward & Optimization)
获取概率后,调用 dpo_loss 计算偏好损失,并加上辅助损失(Aux Loss,通常用于 MoE 负载均衡等)。随后进行标准的混合精度反向传播、梯度裁剪和参数更新。
# 计算DPO主损失 dpo_loss_val = dpo_loss(ref_log_probs, policy_log_probs, mask, beta=beta) # 辅助损失 loss = dpo_loss_val + outputs.aux_loss # 不除以N会导致N次梯度直接相加,等效于把学习率放大N倍 loss = loss / args.accumulation_steps # 混合精度训练里的反向传播算法 scaler.scale(loss).backward()
if (step + 1) % args.accumulation_steps == 0: scaler.unscale_(optimizer) # 梯度裁剪 torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) # 优化器更新 scaler.step(optimizer) # 动态调整下一步的缩放因子 scaler.update() optimizer.zero_grad(set_to_none=True)3. 训练监控与检查点保存 (Logging & Checkpointing)
最后,负责记录训练日志(Loss, LR, Time)、上传 WandB 数据,并根据设定的间隔保存模型权重和 Checkpoint 状态,以便后续恢复或推理使用。
# 记日志信息 if step % args.log_interval == 0 or step == iters - 1: spend_time = time.time() - start_time current_loss = loss.item() * args.accumulation_steps current_dpo_loss = dpo_loss_val.item() current_aux_loss = outputs.aux_loss.item() current_lr = optimizer.param_groups[-1]['lr'] eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, dpo_loss: {current_dpo_loss:.4f}, aux_loss: {current_aux_loss:.4f}, learning_rate: {current_lr:.8f}, epoch_time: {eta_min:.3f}min')
if wandb: wandb.log({"loss": current_loss, "dpo_loss": current_dpo_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process(): model.eval() moe_suffix = '_moe' if lm_config.use_moe else '' ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth' raw_model = model.module if isinstance(model, DistributedDataParallel) else model raw_model = getattr(raw_model, '_orig_mod', raw_model) state_dict = raw_model.state_dict() torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp) lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints') model.train() del state_dict
del x_chosen, x_rejected, y_chosen, y_rejected, mask_chosen, mask_rejected, x, y, mask del ref_outputs, ref_logits, ref_log_probs, outputs, logits, policy_log_probs, loss主程序入口 (Main Entry)
主程序负责整合所有模块。为了清晰展示,我们将主程序代码进一步拆分为三个逻辑块:参数解析、模型构建、以及状态恢复与循环执行。
1. 参数解析与全局配置 (Arguments & Configuration)
这里定义了 DPO 特有的超参数,例如 beta(控制对参考模型的偏离程度)以及数据路径配置。同时完成了分布式环境的初始化和随机种子的固定。
if __name__ == "__main__": parser = argparse.ArgumentParser(description="MiniMind DPO (Direct Preference Optimization)") parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录") parser.add_argument('--save_weight', default='dpo', type=str, help="保存权重的前缀名") parser.add_argument("--epochs", type=int, default=1, help="训练轮数") parser.add_argument("--batch_size", type=int, default=4, help="batch size") parser.add_argument("--learning_rate", type=float, default=4e-8, help="初始学习率(建议<=5e-8避免遗忘)") parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备") parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型") parser.add_argument("--num_workers", type=int, default=8, help="数据加载线程数") parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数") parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值") parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔") parser.add_argument("--save_interval", type=int, default=100, help="模型保存间隔") parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度") parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量") parser.add_argument('--max_seq_len', default=1024, type=int, help="训练的最大截断长度(中文1token≈1.5~1.7字符)") parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)") parser.add_argument("--data_path", type=str, default="../dataset/dpo.jsonl", help="DPO训练数据路径") parser.add_argument('--from_weight', default='full_sft', type=str, help="基于哪个权重训练") parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)") parser.add_argument('--beta', default=0.1, type=float, help="DPO中的beta参数") parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb") parser.add_argument("--wandb_project", type=str, default="MiniMind-DPO", help="wandb项目名") parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)") args = parser.parse_args()
# 初始化环境和随机种子 local_rank = init_distributed_mode() if dist.is_initialized(): args.device = f"cuda:{local_rank}" setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))2. 模型构建与环境初始化 (Model & Data Setup)
在此阶段,我们初始化混合精度环境、WandB 监控,并加载 Policy Model。最关键的一步是加载并冻结 Reference Model,确保它在训练过程中保持不变,作为 KL 散度的基准。
# 配置目录、模型参数、检查ckpt os.makedirs(args.save_dir, exist_ok=True) lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe)) ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# 设置混合精度 device_type = "cuda" if "cuda" in args.device else "cpu" dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
# 配wandb wandb = None if args.use_wandb and is_main_process(): import swanlab as wandb wandb_id = ckp_data.get('wandb_id') if ckp_data else None resume = 'must' if wandb_id else None wandb_run_name = f"MiniMind-DPO-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}" wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
# 5. 定义模型和参考模型 model, tokenizer = init_model(lm_config, args.from_weight, device=args.device) if args.use_compile == 1: model = torch.compile(model) Logger('torch.compile enabled') Logger(f'策略模型总参数量:{sum(p.numel() for p in model.parameters()) / 1e6:.3f} M') # 初始化参考模型(ref_model冻结) ref_model, _ = init_model(lm_config, args.from_weight, device=args.device) ref_model.eval() ref_model.requires_grad_(False) Logger(f'参考模型总参数量:{sum(p.numel() for p in ref_model.parameters()) / 1e6:.3f} M')
train_ds = DPODataset(args.data_path, tokenizer, max_length=args.max_seq_len) train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16')) optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)3. 状态恢复与训练执行 (Execution)
最后,如果存在 Checkpoint,则恢复模型和优化器状态。接着使用 DDP 封装模型,利用 SkipBatchSampler 处理断点续训的数据对齐,开始 Epoch 循环,并在结束后清理分布式进程组。
# 6. 从ckpt恢复状态 start_epoch, start_step = 0, 0 if ckp_data: model.load_state_dict(ckp_data['model']) optimizer.load_state_dict(ckp_data['optimizer']) scaler.load_state_dict(ckp_data['scaler']) start_epoch = ckp_data['epoch'] start_step = ckp_data.get('step', 0)
# DDP包模型 if dist.is_initialized(): model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} model = DistributedDataParallel(model, device_ids=[local_rank])
# 开始训练 for epoch in range(start_epoch, args.epochs): train_sampler and train_sampler.set_epoch(epoch) setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist() skip = start_step if (epoch == start_epoch and start_step > 0) else 0 batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip) loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True) if skip > 0: Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始') train_epoch(epoch, loader, len(loader) + skip, ref_model, lm_config, start_step, wandb, args.beta) else: train_epoch(epoch, loader, len(loader), ref_model, lm_config, 0, wandb, args.beta)
# 清理分布进程 if dist.is_initialized(): dist.destroy_process_group()“千里之行,始于足下。”