2708 字
14 分钟
MiniMind DPO 训练源码深度解析

DPO算法完整流程图


DPO 算法核心原理简述#

DPO 算法核心原理简述#

DPO (Direct Preference Optimization, 直接偏好优化) 是大模型对齐领域的一次**“极简主义革命”**。

它的核心思想非常犀利:移除“中间商”,用解析解替代强化学习

在传统的 RLHF 中,我们需要先训练一个“裁判”(奖励模型),再用复杂的 PPO 算法去讨好这个裁判。而 DPO 通过数学推导证明:最优策略模型本身,就是奖励模型的隐式表达

1. 核心公式:从“打分”到“比较”#

DPO 将复杂的强化学习问题简化为了一个二分类问题。它的目标非常直观:增大“好回答”与“坏回答”之间的概率差

在 MiniMind 的代码实现中,损失函数 LDPO\mathcal{L}_{\text{DPO}} 被定义为:

LDPO=logσ(β[logπθ(ywx)πref(ywx)logπθ(ylx)πref(ylx)])\mathcal{L}_{\text{DPO}} = -\log \sigma \left( \beta \left[ \log \frac{\pi_\theta(y_w|x)}{\pi_{\text{ref}}(y_w|x)} - \log \frac{\pi_\theta(y_l|x)}{\pi_{\text{ref}}(y_l|x)} \right] \right)
  • πθ\pi_\theta: 当前训练的模型。
  • πref\pi_{\text{ref}}: 冻结的参考模型(防止模型跑偏)。
  • ywy_w / yly_l: 人类偏好(Chosen)与 拒绝(Rejected)的回答。
  • β\beta: 控制约束力度的超参数。

2. 算法逻辑:隐式奖励优化#

代码不再需要维护独立的 Critic 网络,而是直接计算**“相对优势”**:

  1. 计算偏好差:分别计算当前模型和参考模型对“好坏答案”的概率差值(Log Ratios)。
  2. 构建 Logits:如果当前模型比参考模型更倾向于“好答案”,则 Logits 变大。
  3. Sigmoid 损失:通过 -log(sigmoid(logits)) 将优化目标转化为最小化损失。

在 MiniMind 中,采用 DPO 意味着我们彻底告别了 PPO 的显式采样 (Rollout)。训练过程变成了类似于 SFT 的监督学习,极大地降低了显存占用 (VRAM),并获得了如磐石般的训练稳定性

全局引用与辅助函数 (Imports & Helper Functions)#

与 PPO 不同,DPO 不需要复杂的 Actor-Critic 交互,也不需要额外的 Reward Model 进行推理时打分。它更多依赖于数学上的推导,将强化学习问题转化为二分类问题。

首先是必要的库导入和计算 Log Probabilities 的辅助函数。logits_to_log_probs 是 DPO 算法的基础工具,用于从模型输出的 logits 中提取特定 label 对应的对数概率。

👉 点击展开查看完整引用与环境初始化代码
import os
import sys
__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import time
import warnings
import torch
import torch.nn.functional as F
import torch.distributed as dist
from contextlib import nullcontext
from torch import optim
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from model.model_minimind import MiniMindConfig
from dataset.lm_dataset import DPODataset
from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
warnings.filterwarnings('ignore')
def logits_to_log_probs(logits, labels):
# logits shape: (batch_size, seq_len, vocab_size), 模型最后一层的原始打分
# labels shape: (batch_size, seq_len), 真实数据的索引
# log_probs shape: (batch_size, seq_len)
log_probs = F.log_softmax(logits, dim=2)
log_probs_per_token = torch.gather(log_probs, dim=2, index=labels.unsqueeze(2)).squeeze(-1)
return log_probs_per_token

DPO 核心损失函数 (DPO Loss)#

这是本代码的灵魂所在。dpo_loss 实现了 DPO 论文中的核心公式。它计算了 Policy Model(当前训练模型)和 Reference Model(参考模型)在 “Chosen”(偏好数据)和 “Rejected”(非偏好数据)上的概率差值,并通过 LogSigmoid 计算损失。

def dpo_loss(ref_log_probs, policy_log_probs, mask, beta):
# ref_log_probs 和 policy_log_probs 都是 shape: (batch_size, seq_len)
# [https://github.com/jingyaogong/minimind/issues/298](https://github.com/jingyaogong/minimind/issues/298)
# 统计一行有多少个1, 防止零长度mask导致除零NaN
seq_lengths = mask.sum(dim=1, keepdim=True).clamp_min(1e-8)
# 把padding位置的概率值清零, 只保留有效内容, sum把一句话所有Token的对数概率加起来。/seq_lengths做平均
# [2B, L]->[2B, 1]
ref_log_probs = (ref_log_probs * mask).sum(dim=1) / seq_lengths.squeeze()
# 同理
policy_log_probs = (policy_log_probs * mask).sum(dim=1) / seq_lengths.squeeze()
# 将 chosen 和 rejected 数据分开
batch_size = ref_log_probs.shape[0]
# 前一半:参考模型对"好回答"的打分
chosen_ref_log_probs = ref_log_probs[:batch_size // 2]
# 后一半:参考模型对"坏回答"的打分
reject_ref_log_probs = ref_log_probs[batch_size // 2:]
# 前一半:训练模型对"好回答"的打分
chosen_policy_log_probs = policy_log_probs[:batch_size // 2]
# 后一半:训练模型对"坏回答"的打分
reject_policy_log_probs = policy_log_probs[batch_size // 2:]
# 策略模型对优选回答 vs 拒绝回答 的对数概率差
pi_logratios = chosen_policy_log_probs - reject_policy_log_probs
# 参考模型对优选回答 vs 拒绝回答 的对数概率差
ref_logratios = chosen_ref_log_probs - reject_ref_log_probs
logits = pi_logratios - ref_logratios
# 损失函数
loss = -F.logsigmoid(beta * logits)
# 取平均
return loss.mean()

训练循环逻辑 (Training Loop)#

train_epoch 函数定义了 DPO 的核心训练步骤。不同于 PPO 的采样过程,DPO 的训练更像标准的监督学习,但在 Loss 计算前需要同时获取 Policy Model 和 Reference Model 的输出。

1. 数据解包与双模型前向传播 (Forward Pass)#

在这一步,我们从 DataLoader 中提取成对的偏好数据(Chosen/Rejected),并将它们拼接以进行批量处理。随后,分别计算 Reference Model(无梯度)和 Policy Model(有梯度)的 Log Probabilities。

def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=None, beta=0.1):
start_time = time.time()
for step, batch in enumerate(loader, start=start_step + 1):
# chosen与rejected的输入token
x_chosen = batch['x_chosen'].to(args.device)
x_rejected = batch['x_rejected'].to(args.device)
# 对应监督目标
y_chosen = batch['y_chosen'].to(args.device)
y_rejected = batch['y_rejected'].to(args.device)
mask_chosen = batch['mask_chosen'].to(args.device)
mask_rejected = batch['mask_rejected'].to(args.device)
# 拼接成一个大batch
x = torch.cat([x_chosen, x_rejected], dim=0)
y = torch.cat([y_chosen, y_rejected], dim=0)
mask = torch.cat([mask_chosen, mask_rejected], dim=0)
lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
# 混合精读上下文,减少显存
with autocast_ctx:
# 关闭梯度记录,因为这里的参考模型只用来打分,不需要反向传播
with torch.no_grad():
# x:[2B, L]
ref_outputs = ref_model(x)
# ref_logits:[2B, L, V], 模型输出的Logits
ref_logits = ref_outputs.logits
# y:[2B, L], ref_log_probs:[2B, L], 每个位置只保留一个数,即该位置目标词的logp
ref_log_probs = logits_to_log_probs(ref_logits, y)
# 前向传播计算
outputs = model(x)
# 获得分数[2B, L, V]
logits = outputs.logits
policy_log_probs = logits_to_log_probs(logits, y)

2. 损失反向传播与参数更新 (Backward & Optimization)#

获取概率后,调用 dpo_loss 计算偏好损失,并加上辅助损失(Aux Loss,通常用于 MoE 负载均衡等)。随后进行标准的混合精度反向传播、梯度裁剪和参数更新。

# 计算DPO主损失
dpo_loss_val = dpo_loss(ref_log_probs, policy_log_probs, mask, beta=beta)
# 辅助损失
loss = dpo_loss_val + outputs.aux_loss
# 不除以N会导致N次梯度直接相加,等效于把学习率放大N倍
loss = loss / args.accumulation_steps
# 混合精度训练里的反向传播算法
scaler.scale(loss).backward()
if (step + 1) % args.accumulation_steps == 0:
scaler.unscale_(optimizer)
# 梯度裁剪
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
# 优化器更新
scaler.step(optimizer)
# 动态调整下一步的缩放因子
scaler.update()
optimizer.zero_grad(set_to_none=True)

3. 训练监控与检查点保存 (Logging & Checkpointing)#

最后,负责记录训练日志(Loss, LR, Time)、上传 WandB 数据,并根据设定的间隔保存模型权重和 Checkpoint 状态,以便后续恢复或推理使用。

# 记日志信息
if step % args.log_interval == 0 or step == iters - 1:
spend_time = time.time() - start_time
current_loss = loss.item() * args.accumulation_steps
current_dpo_loss = dpo_loss_val.item()
current_aux_loss = outputs.aux_loss.item()
current_lr = optimizer.param_groups[-1]['lr']
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, dpo_loss: {current_dpo_loss:.4f}, aux_loss: {current_aux_loss:.4f}, learning_rate: {current_lr:.8f}, epoch_time: {eta_min:.3f}min')
if wandb: wandb.log({"loss": current_loss, "dpo_loss": current_dpo_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
raw_model = model.module if isinstance(model, DistributedDataParallel) else model
raw_model = getattr(raw_model, '_orig_mod', raw_model)
state_dict = raw_model.state_dict()
torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer, scaler=scaler, epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints')
model.train()
del state_dict
del x_chosen, x_rejected, y_chosen, y_rejected, mask_chosen, mask_rejected, x, y, mask
del ref_outputs, ref_logits, ref_log_probs, outputs, logits, policy_log_probs, loss

主程序入口 (Main Entry)#

主程序负责整合所有模块。为了清晰展示,我们将主程序代码进一步拆分为三个逻辑块:参数解析、模型构建、以及状态恢复与循环执行。

1. 参数解析与全局配置 (Arguments & Configuration)#

这里定义了 DPO 特有的超参数,例如 beta(控制对参考模型的偏离程度)以及数据路径配置。同时完成了分布式环境的初始化和随机种子的固定。

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind DPO (Direct Preference Optimization)")
parser.add_argument("--save_dir", type=str, default="../out", help="模型保存目录")
parser.add_argument('--save_weight', default='dpo', type=str, help="保存权重的前缀名")
parser.add_argument("--epochs", type=int, default=1, help="训练轮数")
parser.add_argument("--batch_size", type=int, default=4, help="batch size")
parser.add_argument("--learning_rate", type=float, default=4e-8, help="初始学习率(建议<=5e-8避免遗忘)")
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="训练设备")
parser.add_argument("--dtype", type=str, default="bfloat16", help="混合精度类型")
parser.add_argument("--num_workers", type=int, default=8, help="数据加载线程数")
parser.add_argument("--accumulation_steps", type=int, default=1, help="梯度累积步数")
parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")
parser.add_argument("--log_interval", type=int, default=100, help="日志打印间隔")
parser.add_argument("--save_interval", type=int, default=100, help="模型保存间隔")
parser.add_argument('--hidden_size', default=512, type=int, help="隐藏层维度")
parser.add_argument('--num_hidden_layers', default=8, type=int, help="隐藏层数量")
parser.add_argument('--max_seq_len', default=1024, type=int, help="训练的最大截断长度(中文1token≈1.5~1.7字符)")
parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="是否使用MoE架构(0=否,1=是)")
parser.add_argument("--data_path", type=str, default="../dataset/dpo.jsonl", help="DPO训练数据路径")
parser.add_argument('--from_weight', default='full_sft', type=str, help="基于哪个权重训练")
parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="是否自动检测&续训(0=否,1=是)")
parser.add_argument('--beta', default=0.1, type=float, help="DPO中的beta参数")
parser.add_argument("--use_wandb", action="store_true", help="是否使用wandb")
parser.add_argument("--wandb_project", type=str, default="MiniMind-DPO", help="wandb项目名")
parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="是否使用torch.compile加速(0=否,1=是)")
args = parser.parse_args()
# 初始化环境和随机种子
local_rank = init_distributed_mode()
if dist.is_initialized(): args.device = f"cuda:{local_rank}"
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))

2. 模型构建与环境初始化 (Model & Data Setup)#

在此阶段,我们初始化混合精度环境、WandB 监控,并加载 Policy Model。最关键的一步是加载并冻结 Reference Model,确保它在训练过程中保持不变,作为 KL 散度的基准。

# 配置目录、模型参数、检查ckpt
os.makedirs(args.save_dir, exist_ok=True)
lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, use_moe=bool(args.use_moe))
ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None
# 设置混合精度
device_type = "cuda" if "cuda" in args.device else "cpu"
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
# 配wandb
wandb = None
if args.use_wandb and is_main_process():
import swanlab as wandb
wandb_id = ckp_data.get('wandb_id') if ckp_data else None
resume = 'must' if wandb_id else None
wandb_run_name = f"MiniMind-DPO-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
# 5. 定义模型和参考模型
model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
if args.use_compile == 1:
model = torch.compile(model)
Logger('torch.compile enabled')
Logger(f'策略模型总参数量:{sum(p.numel() for p in model.parameters()) / 1e6:.3f} M')
# 初始化参考模型(ref_model冻结)
ref_model, _ = init_model(lm_config, args.from_weight, device=args.device)
ref_model.eval()
ref_model.requires_grad_(False)
Logger(f'参考模型总参数量:{sum(p.numel() for p in ref_model.parameters()) / 1e6:.3f} M')
train_ds = DPODataset(args.data_path, tokenizer, max_length=args.max_seq_len)
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)

3. 状态恢复与训练执行 (Execution)#

最后,如果存在 Checkpoint,则恢复模型和优化器状态。接着使用 DDP 封装模型,利用 SkipBatchSampler 处理断点续训的数据对齐,开始 Epoch 循环,并在结束后清理分布式进程组。

# 6. 从ckpt恢复状态
start_epoch, start_step = 0, 0
if ckp_data:
model.load_state_dict(ckp_data['model'])
optimizer.load_state_dict(ckp_data['optimizer'])
scaler.load_state_dict(ckp_data['scaler'])
start_epoch = ckp_data['epoch']
start_step = ckp_data.get('step', 0)
# DDP包模型
if dist.is_initialized():
model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
model = DistributedDataParallel(model, device_ids=[local_rank])
# 开始训练
for epoch in range(start_epoch, args.epochs):
train_sampler and train_sampler.set_epoch(epoch)
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
if skip > 0:
Logger(f'Epoch [{epoch + 1}/{args.epochs}]: 跳过前{start_step}个step,从step {start_step + 1}开始')
train_epoch(epoch, loader, len(loader) + skip, ref_model, lm_config, start_step, wandb, args.beta)
else:
train_epoch(epoch, loader, len(loader), ref_model, lm_config, 0, wandb, args.beta)
# 清理分布进程
if dist.is_initialized(): dist.destroy_process_group()

“千里之行,始于足下。”

MiniMind DPO 训练源码深度解析
https://shenyize.com/posts/minimind_train_dpo/
作者
Shenyize
发布于
2026-02-01
许可协议
CC BY-NC-SA 4.0