A Deep Dive into the MiniMind GRPO Training Source Code

Figure 1: GRPO training pipeline overview

Introduction: The GRPO Algorithm and MiniMind Alignment in Practice#

In the LLM training pipeline, pretraining provides knowledge, SFT provides instruction alignment, and RLHF is what actually teaches the model to weigh "preferences". GRPO is a lighter-weight preference optimization scheme: instead of relying on an explicit critic, it normalizes multiple responses to the same prompt relative to each other within the group and uses the resulting advantage signal to update the policy directly.

This post follows train_grpo.py and breaks down the following key modules:

  1. Group-Relative Advantage: after generating multiple responses for the same prompt, normalize the rewards within the group to obtain the advantage signal.

  2. Critic-free direct policy optimization: weight the token-level log-probabilities directly and add a KL penalty to limit policy drift.

  3. Hybrid reward engineering: reward-model scores plus rule-based rewards together form the final reward, tailored to the structured output of reasoning models.


A Brief Look at the Core GRPO Algorithm#

1. Group Sampling and Rewards#

For each prompt, generate $K$ candidate responses:

$$\{y_1, y_2, \dots, y_K\}$$

Compute the reward for each response:

$$r_i = R(x, y_i)$$

Normalize within the group to obtain the advantages:

$$\mu = \frac{1}{K}\sum_{i=1}^K r_i,\quad \sigma = \sqrt{\frac{1}{K}\sum_{i=1}^K (r_i - \mu)^2},\quad A_i = \frac{r_i - \mu}{\sigma + \epsilon}$$
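
These statistics translate into a handful of tensor operations. Below is a minimal PyTorch sketch of the group-relative normalization; it mirrors what grpo_train_epoch does later, minus the clamping and the extra batch-wide re-normalization that the actual code adds:

```python
import torch

def group_relative_advantages(rewards: torch.Tensor, K: int, eps: float = 1e-4) -> torch.Tensor:
    """rewards: flat tensor of shape [B*K], the K responses of each prompt stored consecutively."""
    grouped = rewards.view(-1, K)                       # [B, K]
    mean = grouped.mean(dim=1).repeat_interleave(K)     # broadcast back to [B*K]
    # Note: torch's std divides by K-1 by default, slightly different from the 1/K in the formula;
    # the training code below uses the same default.
    std = grouped.std(dim=1).repeat_interleave(K)
    return (rewards - mean) / (std + eps)               # A_i = (r_i - mu) / (sigma + eps)

# Example: 2 prompts, K=4 responses each
r = torch.tensor([1.0, 0.0, 0.5, 0.5, 2.0, 2.0, 1.0, 3.0])
print(group_relative_advantages(r, K=4))
```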

2. Policy Update Objective#

GRPO weights the policy's token-level log-probabilities directly by the advantage and adds a KL penalty term:

$$\mathcal{L} = -\mathbb{E}\left[A_i \cdot \log \pi_\theta(y_i\mid x)\right] + \beta\,\mathrm{KL}(\pi_\theta \parallel \pi_{\text{ref}})$$

where:

  1. $\pi_\theta$ is the current policy model.
  2. $\pi_{\text{ref}}$ is the reference model (usually the SFT weights).
  3. $\beta$ controls the strength of the KL penalty.
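
Differentiating the first term recovers the familiar policy-gradient form, which the implementation reproduces token by token with the advantage treated as a constant (no gradient flows through $A_i$):

$$\nabla_\theta \mathcal{L} = -\mathbb{E}\left[A_i \,\nabla_\theta \log \pi_\theta(y_i\mid x)\right] + \beta\,\nabla_\theta\,\mathrm{KL}(\pi_\theta \parallel \pi_{\text{ref}})$$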

In practice, a token-level approximation of the KL is commonly used:

$$\Delta = \log \pi_{\text{ref}} - \log \pi_\theta,\quad \mathrm{KL}_{\text{token}} \approx \exp(\Delta) - \Delta - 1$$
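
This approximation has the form $e^x - x - 1 \ge 0$, so it is always non-negative and vanishes only when the policy and reference agree on the sampled token. A quick check with toy numbers:

```python
import torch

logp_ref = torch.tensor([-2.0, -1.5, -3.0])   # log pi_ref for three sampled tokens
logp_pi = torch.tensor([-2.2, -1.4, -2.5])    # log pi_theta for the same tokens

delta = logp_ref - logp_pi                    # Delta = log pi_ref - log pi_theta
kl_token = torch.exp(delta) - delta - 1       # exp(Delta) - Delta - 1 >= 0, zero iff Delta == 0

print(kl_token)                               # per-token penalty, always non-negative
```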

3. Training Loop Overview#

  1. Take a batch of prompts.
  2. Generate $K$ responses for each prompt.
  3. Compute rewards and normalize them within each group to obtain advantages.
  4. Compute per-token log-probabilities under the policy and the reference model.
  5. Build the loss (advantage term + KL penalty).
  6. Backpropagate and update the parameters.
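
Put together, one update can be exercised end to end on a toy problem. The sketch below is a self-contained example (a 5-token vocabulary, length-1 "responses", and a reward that prefers token 3); it only illustrates the moving parts of steps 1-6 and is not the MiniMind implementation:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy setup: one "prompt", a 5-token vocabulary, responses of length 1.
# The policy is just a learnable logit vector; the reward prefers token 3.
vocab, K, beta = 5, 8, 0.02
policy_logits = torch.zeros(vocab, requires_grad=True)
ref_logits = torch.zeros(vocab)                                   # frozen reference (uniform)
optimizer = torch.optim.AdamW([policy_logits], lr=0.1)

for step in range(50):
    # 1-2. sample K responses (here: K single tokens) from the current policy
    probs = F.softmax(policy_logits, dim=-1)
    tokens = torch.multinomial(probs, K, replacement=True)        # [K]
    # 3. reward and group-relative advantage
    rewards = (tokens == 3).float()                               # [K]
    adv = (rewards - rewards.mean()) / (rewards.std() + 1e-4)
    # 4. log-probs under policy and reference
    logp = F.log_softmax(policy_logits, dim=-1)[tokens]           # [K], requires grad
    logp_ref = F.log_softmax(ref_logits, dim=-1)[tokens]          # [K], no grad
    # 5. GRPO loss: -(advantage-weighted ratio trick) + beta * token-level KL
    kl = torch.exp(logp_ref - logp) - (logp_ref - logp) - 1
    loss = -(torch.exp(logp - logp.detach()) * adv - beta * kl).mean()
    # 6. update
    optimizer.zero_grad(); loss.backward(); optimizer.step()

print(F.softmax(policy_logits, dim=-1))   # probability mass should concentrate on token 3
```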

Global Imports and Environment Initialization (Imports & Setup)#

This part prepares the runtime environment: it fixes the package path, imports the training dependencies, and sets up distributed training and warning filtering, laying the groundwork for the functions and the main entry point that follow.

Code: imports & environment setup
import os
import sys
__package__ = "trainer"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import argparse
import re
import gc
import warnings
import torch
import torch.distributed as dist
from transformers import AutoTokenizer
from contextlib import nullcontext
from torch import optim
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from torch.optim.lr_scheduler import CosineAnnealingLR
from transformers import AutoModel
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from dataset.lm_dataset import RLAIFDataset
from trainer.trainer_utils import Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, SkipBatchSampler, init_model
warnings.filterwarnings('ignore')

Reward Computation (calculate_rewards)#

This function combines rule-based rewards with reward-model scores to produce the total reward for each response, providing the signal for the subsequent advantage estimation and policy update.

Code: calculate_rewards
def calculate_rewards(prompts, responses, reward_model, reward_tokenizer):
    """Combine all reward functions to compute the total reward"""
    def reasoning_model_reward(rewards):
        # Regex-match the overall format of each response
        pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"
        pattern2 = r"^<think>\n.*?\n</think>\n\n<answer>\n.*?\n</answer>$"
        matches_pattern = [re.match(pattern, response, re.S) for response in responses]
        matches_pattern2 = [re.match(pattern2, response, re.S) for response in responses]
        format_rewards = []
        for match_pattern, match_pattern2 in zip(matches_pattern, matches_pattern2):
            # Add 0.5 if the response matches either pattern
            if match_pattern or match_pattern2:
                format_rewards.append(0.5)
            else:
                format_rewards.append(0.0)
        # Add the format rewards element-wise
        rewards += torch.tensor(format_rewards, device=args.device)

        def mark_num(text):
            reward = 0
            # Independent tag-count scoring; the order of the tags is not checked
            if text.count("<think>") == 1: reward += 0.25
            if text.count("</think>") == 1: reward += 0.25
            if text.count("<answer>") == 1: reward += 0.25
            if text.count("</answer>") == 1: reward += 0.25
            return reward
        # Add the tag-count rewards element-wise
        mark_rewards = [mark_num(response) for response in responses]
        rewards += torch.tensor(mark_rewards, device=args.device)
        return rewards

    rewards = torch.zeros(len(responses), device=args.device)
    if args.reasoning == 1:
        rewards = reasoning_model_reward(rewards)
    # No gradients needed for reward-model scoring
    with torch.no_grad():
        reward_model_scores = []
        # Batch size
        batch_size = len(prompts)
        # Used for score clipping
        scale = 3.0
        for i in range(batch_size):
            for j in range(args.num_generations):
                # Flatten the 2D (prompt, generation) index into a 1D index
                response_idx = i * args.num_generations + j
                response = responses[response_idx]
                prompt = prompts[i]
                pattern = r"<\|im_start\|>(system|user|assistant)\s+(.*?)<\|im_end\|>"
                matches = re.findall(pattern, prompt, re.DOTALL)
                # Parse the chat-template segments of the prompt into a message list
                messages = [{"role": role, "content": content.strip()} for role, content in matches]
                # Append the generated response as the assistant turn
                tmp_chat = messages + [{"role": "assistant", "content": response}]
                # score is a scalar, clipped to [-3, 3]
                score = reward_model.get_score(reward_tokenizer, tmp_chat)
                score = max(min(score, scale), -scale)
                if args.reasoning == 1:
                    answer_match = re.search(r'<answer>(.*?)</answer>', response, re.DOTALL)
                    if answer_match:
                        # Score the <answer></answer> content separately and blend it with the full-text score
                        answer_content = answer_match.group(1).strip()
                        tmp_chat = messages + [{"role": "assistant", "content": answer_content}]
                        answer_score = reward_model.get_score(reward_tokenizer, tmp_chat)
                        answer_score = max(min(answer_score, scale), -scale)
                        # 40% from the full text, 60% from the answer
                        score = score * 0.4 + answer_score * 0.6
                reward_model_scores.append(score)
        reward_model_scores = torch.tensor(reward_model_scores, device=args.device)
        rewards += reward_model_scores
    return rewards
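
To make the composition concrete, here is a worked example with assumed scores (not outputs from the real models): a response whose overall layout matches one of the two patterns earns 0.5, and having each of the four tags exactly once adds another 4 x 0.25 = 1.0. If the reward model scores the full response 1.2 and the extracted <answer> content 2.0, the blended model score is 0.4 x 1.2 + 0.6 x 2.0 = 1.68, so the total reward is 0.5 + 1.0 + 1.68 = 3.18. Also note the calling convention: prompts has length B, while responses has length B x num_generations, flattened so that the responses of prompt i occupy indices i*G .. i*G + G - 1.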

A Single GRPO Training Epoch (grpo_train_epoch)#

This part covers generation sampling, log-probability computation, advantage normalization, the KL constraint, and the gradient update; it is the core loop of GRPO training.

Code: grpo_train_epoch
def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokenizer, start_step=0, wandb=None):
    for step, batch in enumerate(loader, start=start_step + 1):
        prompts = batch['prompt']  # list[str], length B
        # Left-pad the sequences; do not add BOS/EOS or other special tokens automatically
        prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False,
                                  padding_side="left", add_special_tokens=False).to(args.device)  # input_ids: [B, P], attention_mask: [B, P]; P is the longest prompt length (in tokens) in this batch
        if args.max_seq_len:
            # After truncation: Tensor[B, L]; we keep the tail of the sequence, so truncate from the left
            prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -args.max_seq_len:]
            prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -args.max_seq_len:]
        with torch.no_grad():
            # A DDP-wrapped model needs .module to reach the generate method
            model_for_gen = model.module if isinstance(model, DistributedDataParallel) else model
            # Sampling: num_return_sequences is G, i.e. generate G responses per prompt
            outputs = model_for_gen.generate(
                **prompt_inputs, max_new_tokens=args.max_gen_len, do_sample=True, temperature=0.8,
                num_return_sequences=args.num_generations, pad_token_id=tokenizer.pad_token_id)  # [B*num_gen, P+R]; each sequence is the original prompt plus the generated tokens
        # prompt_inputs["input_ids"].size(1) is P; outputs[:, P:] takes each sequence from position P onward
        completion_ids = outputs[:, prompt_inputs["input_ids"].size(1):]  # [B*num_gen, R]

        def get_per_token_logps(mdl, input_ids, n_keep):
            # N=B*G, L=P+R; the shape stays Tensor[N, L]
            input_ids = input_ids.detach().clone() if input_ids.is_inference() else input_ids
            # Tensor[N, n_keep+1, V]; V is the vocab size, n_keep is the number of generated response tokens R
            # Take the logits of the last n_keep+1 positions, then drop the last one
            logits = mdl(input_ids, logits_to_keep=n_keep + 1).logits[:, :-1, :]
            per_token_logps = []
            # logits has shape [N, n_keep, V]; input_ids[:, -n_keep:] has shape [N, n_keep]
            for logits_row, ids_row in zip(logits, input_ids[:, -n_keep:]):
                # logits_row is one sample's [n_keep, V]; ids_row is the same sample's [n_keep]
                ids_row = ids_row.detach().clone() if ids_row.is_inference() else ids_row
                # unsqueeze(1) gives [n_keep, 1]; gather along the vocab dim by the true token ids to get each position's log-prob of its actual token, shape [n_keep, 1]; squeeze(1) drops the extra dim, giving [n_keep]
                per_token_logps.append(torch.gather(logits_row.log_softmax(dim=-1), 1, ids_row.unsqueeze(1)).squeeze(1))
            # Stack all samples into one tensor, final shape [N, n_keep]
            return torch.stack(per_token_logps)

        # Enter the automatic mixed-precision context
        with autocast_ctx:
            # Per-token log-probs of the last R tokens, shape [B*G, R]
            per_token_logps = get_per_token_logps(model, outputs, completion_ids.size(1))  # [B*num_gen, R]
            # For MoE, run one more forward pass to obtain the auxiliary output
            res = model(outputs) if lm_config.use_moe else None
            aux_loss = res.aux_loss if res is not None else torch.tensor(0.0, device=args.device)
            with torch.no_grad():
                # Per-token log-probs of the same output sequences under the reference model, shape [B*G, R]
                ref_per_token_logps = get_per_token_logps(ref_model, outputs, completion_ids.size(1))  # [B*num_gen, R]
        # Decode back into text
        completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
        # Scalar reward for each response
        rewards = calculate_rewards(prompts, completions, reward_model, reward_tokenizer).to(args.device)  # [B*num_gen]
        # Group by prompt to compute advantages, [B, G]
        grouped_rewards = rewards.view(-1, args.num_generations)  # [B, num_gen]
        # Per-group mean/std, broadcast back to [B*G]
        mean_r = grouped_rewards.mean(dim=1).repeat_interleave(args.num_generations)  # [B*num_gen]
        std_r = grouped_rewards.std(dim=1).repeat_interleave(args.num_generations)  # [B*num_gen]
        # Normalized advantages
        advantages = torch.clamp((rewards - mean_r) / (std_r + 1e-4), -10, 10)
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)  # [B*num_gen]
        # Position of the first EOS in each sample; defaults to R if there is none
        is_eos = completion_ids == tokenizer.eos_token_id  # [B*num_gen, R]
        eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=args.device)
        eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
        completion_mask = (torch.arange(is_eos.size(1), device=args.device).expand(is_eos.size(0), -1) <= eos_idx.unsqueeze(1)).int()  # [B*num_gen, R]
        # Equivalent to log p_ref - log p_pi
        kl_div = ref_per_token_logps - per_token_logps
        # A smooth form of the KL derived from exp(x)-x-1: numerically stable and penalizes deviation
        per_token_kl = torch.exp(kl_div) - kl_div - 1  # [B*num_gen, R]
        # Forward value is exp(0)=1, but the backward gradient is equivalent to ∇log π_θ, weighted by the advantage
        per_token_loss = -(torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) - args.beta * per_token_kl)  # [B*num_gen, R]
        # Backpropagate
        policy_loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
        loss = (policy_loss + aux_loss) / args.accumulation_steps  # scalar
        loss.backward()
        # Parameter update
        if (step + 1) % args.accumulation_steps == 0:
            if args.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
        # Console logging and wandb tracking
        if step % args.log_interval == 0 or step == iters:
            policy_loss_val = loss.item() * args.accumulation_steps
            current_aux_loss = aux_loss.item()
            avg_reward_val = rewards.mean().item()
            avg_len_val = completion_mask.sum(dim=1).float().mean().item()
            current_lr = optimizer.param_groups[0]['lr']
            Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), '
                   f'Actor Loss: {policy_loss_val:.4f}, Aux Loss: {current_aux_loss:.4f}, Reward: {avg_reward_val:.4f}, '
                   f'Avg Response Len: {avg_len_val:.2f}, Learning Rate: {current_lr:.8f}')
            if wandb and is_main_process():
                wandb.log({
                    "policy_loss": policy_loss_val,
                    "aux_loss": current_aux_loss,
                    "reward": avg_reward_val,
                    "avg_response_len": avg_len_val,
                    "advantages_mean": advantages.mean().item(),
                    "learning_rate": current_lr
                })
        if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
            model.eval()
            moe_suffix = '_moe' if lm_config.use_moe else ''
            ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
            raw_model = model.module if isinstance(model, DistributedDataParallel) else model
            raw_model = getattr(raw_model, '_orig_mod', raw_model)
            state_dict = raw_model.state_dict()
            torch.save({k: v.half().cpu() for k, v in state_dict.items()}, ckp)
            lm_checkpoint(lm_config, weight=args.save_weight, model=model, optimizer=optimizer,
                          epoch=epoch, step=step, wandb=wandb, save_dir='../checkpoints', scheduler=scheduler)
            model.train()
            del state_dict
        del prompt_inputs, outputs, completion_ids, per_token_logps, ref_per_token_logps
        del completions, rewards, grouped_rewards, mean_r, std_r, advantages, completion_mask
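
One line in the loss deserves a closer look: torch.exp(per_token_logps - per_token_logps.detach()) evaluates to 1 in the forward pass, so the reported loss per token is simply -(advantage) + beta * KL, but its gradient with respect to the log-probabilities is exactly the advantage-weighted ∇log π_θ of the GRPO objective. A tiny standalone check with toy tensors, independent of the training script:

```python
import torch

logp = torch.tensor([-1.0, -2.0], requires_grad=True)   # stand-in for per-token log pi_theta
adv = torch.tensor([0.5, -0.5])                          # stand-in advantages

ratio = torch.exp(logp - logp.detach())                  # forward value == 1 everywhere
loss = -(ratio * adv).sum()
loss.backward()

print(ratio)        # tensor([1., 1.], ...)
print(logp.grad)    # tensor([-0.5000, 0.5000]) == -advantages, the gradient of -(adv * log pi)
```

This surrogate keeps the forward value interpretable while preserving the REINFORCE-style gradient, a pattern seen in several policy-gradient implementations.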

Main Entry and Training Flow (Main Entry)#

The main entry point handles argument parsing, distributed initialization, building the model, data, and optimizer, checkpoint resumption, DDP wrapping, and overall training orchestration.

Code: main entry
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MiniMind GRPO (Group Relative Policy Optimization)")
    parser.add_argument("--save_dir", type=str, default="../out", help="Directory for saving models")
    parser.add_argument('--save_weight', default='grpo', type=str, help="Prefix for saved weights")
    parser.add_argument("--epochs", type=int, default=1, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=2, help="Batch size")
    parser.add_argument("--learning_rate", type=float, default=8e-8, help="Initial learning rate")
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="Training device")
    parser.add_argument("--dtype", type=str, default="bfloat16", help="Mixed-precision dtype")
    parser.add_argument("--num_workers", type=int, default=8, help="Number of data-loading workers")
    parser.add_argument("--accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
    parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping threshold")
    parser.add_argument("--log_interval", type=int, default=1, help="Logging interval")
    parser.add_argument("--save_interval", type=int, default=10, help="Model saving interval")
    parser.add_argument('--hidden_size', default=512, type=int, help="Hidden dimension")
    parser.add_argument('--num_hidden_layers', default=8, type=int, help="Number of hidden layers")
    parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="Use the MoE architecture (0=no, 1=yes)")
    parser.add_argument('--max_seq_len', default=66, type=int, help="Maximum prompt length")
    parser.add_argument("--max_gen_len", type=int, default=1536, help="Maximum generation length")
    parser.add_argument("--data_path", type=str, default="../dataset/rlaif-mini.jsonl", help="Path to the RLAIF data")
    parser.add_argument("--num_generations", type=int, default=8, help="Number of samples generated per prompt")
    parser.add_argument("--beta", type=float, default=0.02, help="KL penalty coefficient")
    parser.add_argument("--reasoning", type=int, default=1, choices=[0, 1], help='Model type (0=standard model, 1=reasoning model)')
    parser.add_argument("--reward_model_path", type=str, default="../../internlm2-1_8b-reward", help="Path to the reward model")
    parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="Auto-detect checkpoint & resume training (0=no, 1=yes)")
    parser.add_argument("--use_wandb", action="store_true", help="Use wandb")
    parser.add_argument("--wandb_project", type=str, default="MiniMind-GRPO", help="wandb project name")
    parser.add_argument("--use_compile", default=0, type=int, choices=[0, 1], help="Use torch.compile acceleration (0=no, 1=yes)")
    args = parser.parse_args()
    # ========== 1. Initialize environment and random seed ==========
    local_rank = init_distributed_mode()
    if dist.is_initialized(): args.device = f"cuda:{local_rank}"
    setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))
    # ========== 2. Configure directories and model config, check for checkpoints ==========
    os.makedirs(args.save_dir, exist_ok=True)
    lm_config = MiniMindConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
                               max_seq_len=args.max_seq_len + args.max_gen_len, use_moe=bool(args.use_moe))
    ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume == 1 else None
    # ========== 3. Set up mixed precision ==========
    device_type = "cuda" if "cuda" in args.device else "cpu"
    dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
    autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)
    # ========== 4. Configure wandb ==========
    wandb = None
    if args.use_wandb and is_main_process():
        import swanlab as wandb
        wandb_id = ckp_data.get('wandb_id') if ckp_data else None
        resume = 'must' if wandb_id else None
        wandb_run_name = f"MiniMind-GRPO-Epoch-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}"
        wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
    # ========== 5. Initialize models and data ==========
    base_weight = "reason" if args.reasoning == 1 else "full_sft"
    # Policy model
    model, tokenizer = init_model(lm_config, base_weight, device=args.device)
    if args.use_compile == 1:
        model = torch.compile(model)
        Logger('torch.compile enabled')
    # Reference model
    ref_model, _ = init_model(lm_config, base_weight, device=args.device)
    ref_model = ref_model.eval().requires_grad_(False)
    # Reward model
    reward_model = AutoModel.from_pretrained(
        args.reward_model_path, torch_dtype=torch.float16, trust_remote_code=True
    )
    reward_model = reward_model.to(args.device).eval().requires_grad_(False)
    reward_tokenizer = AutoTokenizer.from_pretrained(args.reward_model_path, trust_remote_code=True)
    # Data and optimizer
    train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
    train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
    loader_for_count = DataLoader(train_ds, batch_size=args.batch_size, sampler=train_sampler)
    iters = len(loader_for_count)
    total_optimizer_steps = (iters // args.accumulation_steps) * args.epochs
    scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
    # ========== 6. Restore state from checkpoint ==========
    start_epoch, start_step = 0, 0
    if ckp_data:
        model.load_state_dict(ckp_data['model'])
        optimizer.load_state_dict(ckp_data['optimizer'])
        scheduler.load_state_dict(ckp_data['scheduler'])
        start_epoch = ckp_data['epoch']
        start_step = ckp_data.get('step', 0)
    # ========== 7. Wrap the model with DDP ==========
    if dist.is_initialized():
        model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
        model = DistributedDataParallel(model, device_ids=[local_rank])
    # ========== 8. Start training ==========
    for epoch in range(start_epoch, args.epochs):
        train_sampler and train_sampler.set_epoch(epoch)
        setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
        skip = start_step if (epoch == start_epoch and start_step > 0) else 0
        batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
        loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
        if skip > 0:
            Logger(f'Epoch [{epoch + 1}/{args.epochs}]: skipping the first {start_step} steps, resuming from step {start_step + 1}')
            grpo_train_epoch(epoch, loader, len(loader) + skip, ref_model, reward_model, reward_tokenizer, start_step, wandb)
        else:
            grpo_train_epoch(epoch, loader, len(loader), ref_model, reward_model, reward_tokenizer, 0, wandb)
    # ========== 9. Clean up the distributed process group ==========
    if dist.is_initialized(): dist.destroy_process_group()