继续预训练
Continued pretraining 是指在一个已经预训练过的base模型基础上继续进行预训练的过程,目的是训练出适配下游领域的模型,训练方式与预训练一致。
什么时候需要继续预训练
预训练的语料与下游任务语料的【数据分布/领域差异】大时。假如预训练的数据缺少原神相关的数据(或者没有在原神数据上进行充分训练),我想微调一个原神问答模型,最好先利用大量的原神资料(攻略、更新日志等无监督文本)做继续预训练,再利用原神问答对做微调。
典型场景:
- 垂直领域适配:医疗、法律、金融等专业领域,通用模型缺乏领域知识。
- 多语言适配:英文为主的基座模型需要在中文、日文等语言上增强能力。
- 领域数据增强:已有一定的领域数据,需要进一步注入更多领域知识。
- 知识时效性更新:将新知识通过继续预训练注入模型,提升对新信息的覆盖。
领域数据比例
领域数据占比过高,可能导致训练损失直接崩掉,占比太低会导致在领域知识方面提升不大。
假设领域数据的比例是r,关于领域数据比例的scaling law的公式为:
其中 r' = r + ε。
当r表示域语料混合比例r_d时,L表示领域语料验证损失L_d。同样,当r表示域语料混合比例r_g时,L表示领域语料验证损失L_g。
随着领域数据的比例增大,会导致领域损失降低而通用损失上升,最终趋于稳定。基于scaling law,就可以估计不同数据配比下损失的预估值。
求解最优混合比例
设模型通用能力下降的阈值不超过T,则优化目标为:
总体而言,领域数据量应控制在15%以内,避免过度干扰模型的泛化能力。
继续预训练 vs 微调
| 维度 | 继续预训练 | 微调 |
|---|---|---|
| 训练数据 | 特定领域的无标注数据集 | 特定任务的标注数据集 |
| 训练规模 | 数据量大,训练时间长 | 数据量少,训练时间短 |
| 训练参数 | 更新模型的全部参数 | 可以只更新部分参数(如PEFT) |
| 训练结果 | 将通用模型适配到垂直领域 | 做特定任务的优化 |
继续预训练的具体实现步骤
1. 基座模型选择
选择合适的预训练基座模型,优先考虑:
- 模型架构与目标领域数据的匹配度(如领域涉及大量代码,选择代码能力较强的基座)
- 模型参数量与可用计算资源的匹配
- 基座模型在目标领域上的当前能力(通过PPL或Benchmark评估)
2. 领域数据准备
数据收集
- 公开领域语料:如医学领域的PubMed、法律领域的判决文书、学术领域的arXiv论文等。
- 行业内部数据:如企业技术文档、内部知识库、行业报告等。
- 合成数据:利用大模型生成领域相关文本,补充数据不足的部分。
数据清洗
- 使用与基座模型预训练一致的清洗工具链(如DataTrove、Trafilatura)。
- 对领域数据进行质量过滤:去除乱码、低质量文本、无关内容。
- 对领域数据进行去重处理:避免重复文本导致的过拟合。
- 对领域敏感信息进行处理:如医疗数据中的患者隐私信息脱敏。
数据格式化
将领域数据统一转换为预训练时使用的格式,通常为纯文本或JSONL格式:
{"text": "领域文本内容..."}
{"text": "领域文本内容..."}3. 训练配置
- 学习率:继续预训练的学习率通常为预训练的1/10到1/100,避免过度遗忘。
- Batch Size:与预训练保持一致或适当减小。
- 训练步数:取决于领域数据量,通常几万到几十万步。
- 学习率调度:同样采用WSD策略,但warmup步数可以适当减少。
4. 训练监控与评估
- 监控领域数据和通用数据的loss变化趋势。
- 定期在领域评测集上评估模型能力。
- 使用概率探针监测关键领域知识的记忆情况。
- 检查通用能力是否出现显著下降。
5. 模型选择
从训练过程中保存的多个checkpoint中选择最优模型,平衡领域能力和通用能力:
其中
领域数据处理方法
领域数据分类
在继续预训练中,领域数据通常需要进一步细分和配比:
| 数据类型 | 说明 | 典型占比 |
|---|---|---|
| 核心领域文本 | 与目标领域直接相关的文本 | 60%-80% |
| 相关领域文本 | 与目标领域间接相关的文本 | 10%-20% |
| 通用语料 | 维持通用能力的基础语料 | 10%-20% |
数据增强策略
当领域数据量不足时,可以采用以下策略:
- 回译增强:将领域文本翻译为其他语言再翻译回来,生成语义等价但表述不同的新文本。
- 同义替换:使用领域专用词典进行同义词替换。
- 指令改写:利用大模型将领域文本改写为不同风格或格式。
常见问题与解决方案
问题一:灾难性遗忘
表现:继续预训练后,模型在通用任务上的能力显著下降。
解决方案:
- 降低学习率(建议为预训练学习率的1/50到1/100)
- 在训练数据中混入一定比例的通用语料(10%-20%)
- 使用EWC(Elastic Weight Consolidation)等正则化方法保护关键参数
- 采用较短的训练步数,避免过度训练
问题二:训练损失震荡或上升
表现:训练过程中loss不稳定,出现大幅波动或持续上升。
解决方案:
- 检查领域数据质量,可能存在大量噪声数据
- 降低学习率,增加warmup步数
- 检查数据格式是否与基座模型预训练一致
- 减小batch size,降低训练不稳定性
问题三:领域能力提升不明显
表现:继续预训练后,领域任务的表现没有明显改善。
解决方案:
- 增加领域数据量或提高领域数据占比
- 检查领域数据的质量和相关性
- 延长训练步数,给予模型更多学习时间
- 检查基座模型是否在该领域已有较好的基础
问题四:评估集数据泄漏
表现:领域评测集分数虚高,但实际应用效果不佳。
解决方案:
- 确保评估集与训练数据完全隔离
- 使用动态更新的评估集或私有测试集
- 在训练数据中搜索是否存在与评估集高度相似的文本
问题五:计算资源不足
表现:无法使用足够大的batch size或训练足够长的时间。
解决方案:
- 使用梯度累积(Gradient Accumulation)模拟大batch size
- 采用LoRA等参数高效微调方法减少计算开销
- 优先训练关键领域的数据子集
- 利用小规模实验(如1/10参数量的模型)预估训练效果
继续预训练经验总结
数据集需要足够大,至少几B的token,指令数据占比高效果更好。
训练开始阶段可能出现loss上升,然后慢慢收敛。
要设置学习率warmup,但warmup步数对于充分训练的模型性能影响不大。
继续预训练要在已经过预训练的模型的基础上做,充分利用其先验知识。
继续预训练阶段学习率低于预训练,对于充分训练的模型,学习率越大在下游任务性能越好,上游任务性能越差(遗忘严重)。
真实训练中可能没有足够的数据和计算资源,这种情况下建议选择较小的学习率和较长的warmup步数。
预训练中遇到训练中断,需要继续训练时,应该把学习率和衰减率都恢复到中断前的状态。