Skip to content

继续预训练

Continued pretraining 是指在一个已经预训练过的base模型基础上继续进行预训练的过程,目的是训练出适配下游领域的模型,训练方式与预训练一致。

什么时候需要继续预训练

预训练的语料与下游任务语料的【数据分布/领域差异】大时。假如预训练的数据缺少原神相关的数据(或者没有在原神数据上进行充分训练),我想微调一个原神问答模型,最好先利用大量的原神资料(攻略、更新日志等无监督文本)做继续预训练,再利用原神问答对做微调。

典型场景:

  • 垂直领域适配:医疗、法律、金融等专业领域,通用模型缺乏领域知识。
  • 多语言适配:英文为主的基座模型需要在中文、日文等语言上增强能力。
  • 领域数据增强:已有一定的领域数据,需要进一步注入更多领域知识。
  • 知识时效性更新:将新知识通过继续预训练注入模型,提升对新信息的覆盖。

领域数据比例

领域数据占比过高,可能导致训练损失直接崩掉,占比太低会导致在领域知识方面提升不大。

假设领域数据的比例是r,关于领域数据比例的scaling law的公式为:

L(N,D,r)=E+ANα+BrηDβ+Crγ

其中 r' = r + ε。

当r表示域语料混合比例r_d时,L表示领域语料验证损失L_d。同样,当r表示域语料混合比例r_g时,L表示领域语料验证损失L_g。

随着领域数据的比例增大,会导致领域损失降低而通用损失上升,最终趋于稳定。基于scaling law,就可以估计不同数据配比下损失的预估值。

求解最优混合比例

设模型通用能力下降的阈值不超过T,则优化目标为:

argminrdLd(N=N0,D=D0,rd)s.t.LgLg0Lg0<T

总体而言,领域数据量应控制在15%以内,避免过度干扰模型的泛化能力。

继续预训练 vs 微调

维度继续预训练微调
训练数据特定领域的无标注数据集特定任务的标注数据集
训练规模数据量大,训练时间长数据量少,训练时间短
训练参数更新模型的全部参数可以只更新部分参数(如PEFT)
训练结果将通用模型适配到垂直领域做特定任务的优化

继续预训练的具体实现步骤

1. 基座模型选择

选择合适的预训练基座模型,优先考虑:

  • 模型架构与目标领域数据的匹配度(如领域涉及大量代码,选择代码能力较强的基座)
  • 模型参数量与可用计算资源的匹配
  • 基座模型在目标领域上的当前能力(通过PPL或Benchmark评估)

2. 领域数据准备

数据收集

  • 公开领域语料:如医学领域的PubMed、法律领域的判决文书、学术领域的arXiv论文等。
  • 行业内部数据:如企业技术文档、内部知识库、行业报告等。
  • 合成数据:利用大模型生成领域相关文本,补充数据不足的部分。

数据清洗

  • 使用与基座模型预训练一致的清洗工具链(如DataTrove、Trafilatura)。
  • 对领域数据进行质量过滤:去除乱码、低质量文本、无关内容。
  • 对领域数据进行去重处理:避免重复文本导致的过拟合。
  • 对领域敏感信息进行处理:如医疗数据中的患者隐私信息脱敏。

数据格式化

将领域数据统一转换为预训练时使用的格式,通常为纯文本或JSONL格式:

json
{"text": "领域文本内容..."}
{"text": "领域文本内容..."}

3. 训练配置

  • 学习率:继续预训练的学习率通常为预训练的1/10到1/100,避免过度遗忘。
  • Batch Size:与预训练保持一致或适当减小。
  • 训练步数:取决于领域数据量,通常几万到几十万步。
  • 学习率调度:同样采用WSD策略,但warmup步数可以适当减少。

4. 训练监控与评估

  • 监控领域数据和通用数据的loss变化趋势。
  • 定期在领域评测集上评估模型能力。
  • 使用概率探针监测关键领域知识的记忆情况。
  • 检查通用能力是否出现显著下降。

5. 模型选择

从训练过程中保存的多个checkpoint中选择最优模型,平衡领域能力和通用能力:

Score=αLdomain+(1α)Lgeneral

其中 α 为领域能力权重,根据实际需求调整。

领域数据处理方法

领域数据分类

在继续预训练中,领域数据通常需要进一步细分和配比:

数据类型说明典型占比
核心领域文本与目标领域直接相关的文本60%-80%
相关领域文本与目标领域间接相关的文本10%-20%
通用语料维持通用能力的基础语料10%-20%

数据增强策略

当领域数据量不足时,可以采用以下策略:

  • 回译增强:将领域文本翻译为其他语言再翻译回来,生成语义等价但表述不同的新文本。
  • 同义替换:使用领域专用词典进行同义词替换。
  • 指令改写:利用大模型将领域文本改写为不同风格或格式。

常见问题与解决方案

问题一:灾难性遗忘

表现:继续预训练后,模型在通用任务上的能力显著下降。

解决方案

  • 降低学习率(建议为预训练学习率的1/50到1/100)
  • 在训练数据中混入一定比例的通用语料(10%-20%)
  • 使用EWC(Elastic Weight Consolidation)等正则化方法保护关键参数
  • 采用较短的训练步数,避免过度训练

问题二:训练损失震荡或上升

表现:训练过程中loss不稳定,出现大幅波动或持续上升。

解决方案

  • 检查领域数据质量,可能存在大量噪声数据
  • 降低学习率,增加warmup步数
  • 检查数据格式是否与基座模型预训练一致
  • 减小batch size,降低训练不稳定性

问题三:领域能力提升不明显

表现:继续预训练后,领域任务的表现没有明显改善。

解决方案

  • 增加领域数据量或提高领域数据占比
  • 检查领域数据的质量和相关性
  • 延长训练步数,给予模型更多学习时间
  • 检查基座模型是否在该领域已有较好的基础

问题四:评估集数据泄漏

表现:领域评测集分数虚高,但实际应用效果不佳。

解决方案

  • 确保评估集与训练数据完全隔离
  • 使用动态更新的评估集或私有测试集
  • 在训练数据中搜索是否存在与评估集高度相似的文本

问题五:计算资源不足

表现:无法使用足够大的batch size或训练足够长的时间。

解决方案

  • 使用梯度累积(Gradient Accumulation)模拟大batch size
  • 采用LoRA等参数高效微调方法减少计算开销
  • 优先训练关键领域的数据子集
  • 利用小规模实验(如1/10参数量的模型)预估训练效果

继续预训练经验总结

  • 数据集需要足够大,至少几B的token,指令数据占比高效果更好。

  • 训练开始阶段可能出现loss上升,然后慢慢收敛。

  • 要设置学习率warmup,但warmup步数对于充分训练的模型性能影响不大。

  • 继续预训练要在已经过预训练的模型的基础上做,充分利用其先验知识。

  • 继续预训练阶段学习率低于预训练,对于充分训练的模型,学习率越大在下游任务性能越好,上游任务性能越差(遗忘严重)。

  • 真实训练中可能没有足够的数据和计算资源,这种情况下建议选择较小的学习率和较长的warmup步数。

  • 预训练中遇到训练中断,需要继续训练时,应该把学习率和衰减率都恢复到中断前的状态