编辑 | ScienceAI
OpenAI o1、DeepSeek R1 等模型成功实现了在数学、编程等领域的智能慢思考推理,通过自我反思和修正实现了运行时的性能外推。
然而,在医疗领域,仍然很少有模型可以实现具有长链慢思考的推理。目前医疗领域的推理模型大多是通过在医疗考试题上对 OpenAI 系列的模型进行蒸馏,并没有考虑推理过程的可验证性,以及医疗任务的覆盖度。
为了解决这些问题,上海交通大学人工智能学院、复旦大学和上海人工智能实验室的团队开发了一种新型医学推理系统——MedS3。
该系统采用自我进化的「慢思考」范式,无需预训练和模型蒸馏,能够对推理流程的每一步进行细粒度验证。
论文链接:https://arxiv.org/pdf/2501.12051
项目主页:https://pixas.github.io/MedS3-pages/
论文标题:MedS3: Towards Medical Small Language Models with Self-Evolved Slow Thinking
MedS3 由策略模型(Policy Model)和过程奖励模型(Process Reward Model; PRM)组成,通过在 16 种不同数据集上的学习,包括医疗诊断、生物医学和知识性问答等。
仅使用 7465 条种子数据,结合细粒度的蒙特卡洛树搜索和规则验证的过程监督信号,MedS3 迭代优化策略模型和过程奖励模型。
评估结果显示,MedS3 在医疗知识问答、生物医学问答、长上下文问答和医疗诊断任务上的推理能力显著超越现有医疗大模型和通用域推理模型,成为首个在医疗诊断任务上实现长链推理「R1」的大语言模型框架。
研究动机
以往的医疗模型训练面临医疗语料匮乏的问题,通常有两种解决方案:
(1)在大规模人工收集筛选的医疗语料上进行预训练;
(2)在少量特定任务数据集上进行有监督微调。然而,第一种方法消耗大量计算资源,但下游任务性能提升有限;第二种方法虽计算高效,但微调数据多为闭源模型生成的蒸馏数据或人工标注的短回复数据,限制了模型的优化空间和跨任务泛化能力。
系统框架
为了解决医疗模型的数据困境,MedS3 转向运行时缩放(test-time scaling),以一种数据高效的后训练方法进行提升,从而突破数据集标注的约束,在平衡计算资源与性能之间的矛盾下,高效利用现有的医疗数据。
MedS3 的核心在于其独特的自我进化框架。研究者首先利用蒙特卡洛树搜索(MCTS)技术,基于基础策略模型生成可验证的推理链。在推理链的每一步,都会基于这一步的正确性赋予一个展开值,通过这些经过验证的轨迹来训练策略模型和过程奖励模型(PRM)。
这种搜索对计算资源的依赖极小,通过策略模型演化得到的正负样本均可以作为 MedS3 的监督信号,大大增加了数据利用率,并且按步采样也能提升模型的探索空间。
在推理过程中,策略模型会生成多个回答,而奖励模型则通过新提出的 PRM 引导的投票求和(P-VS)策略来选择最终答案。
这种策略不仅考虑了 PRM 对每个回复的评判结果,也考虑了不同回复之间的语义一致性。这种自我进化的方式,不仅提高了模型的数据效率,还使其在多种临床任务中展现出了卓越的推理能力。
图 1 MedS3 框架的构建过程。
图 2 PRM 引导的投票求和计算示例。
MedS3 的优势在于:
- 数据利用率高:使用自我启发式搜索,扩大了数据的表征范围和利用率。
- 支持单步监督:搜索中的进化展开值可以为单步推理提供监督,从而规划正确推理轨迹。
- 高效支持多任务学习:对每个数据集采样约 500 条数据即可实现多任务同时学习。
图 3 种子数据集中各任务及设计的数据集。
实验结论
MedS3 的任务同样涵盖了来自不同任务的 11 个数据集,涵盖了知识问答、生物医学问答、长上下文问答、医疗语义推理以及医疗诊断式问答。
实验 1:同时领先医疗开源模型、通用推理模型
医疗模型的性能提升普遍较小,最大提升仅为 6.48(MMedS-Ins vs Llama3 8B),且多数模型因任务覆盖不全,难以超越通用的 Llama3 8B。
通用域强化推理能力的开源模型因缺乏医疗知识,效果受限,最佳表现 R1-Distill-Qwen32B 仅提升 4.36。
相比之下,MedS3 通过融合推理能力和医疗知识,并采用创新的 PRM 引导投票求和方法,相比 Llama3 8B 提升了 13.07,显著超越所有同等规模的开源模型,并在综合性能上领先更大规模的模型。
实验 2:P-VS 选择策略领先以往的 PRM 使用方法
传统 Best-of-N(BoN)方法依赖 PRM 筛选最优回复,但 PRM 存在训练不稳定性和训练/真实标签偏差问题。P-VS 创新性融合语义一致性校验与 PRM 评分,突破单一依赖瓶颈,实现 3.46 的性能跃升。
实验 3:几乎无界的性能外推
推理模型通过增加词元消耗,几乎无损提升性能,MedS3 也具备此特性。
通过在 5 个诊断数据集上分别设置采样条数为 2、4、8、12、16,采用 P-VS 策略评估其性能,结果显示,除了 Pubhealth 数据集性能收敛外,其余数据集呈现出几乎无界的性能提升,验证了 MedS3 强大的性能外推潜力。
图 4 MedS3 词元消耗量和性能的关系。
实验 4:MCTS+PRM 仍是实现医疗推理的最有效方案
DeepSeek R1 证实强化学习(RL)可提升模型推理能力,但研究者们的实验表示,医疗领域中,基于蒸馏的小模型泛化性不足,传统 RL 方法性能仍弱于 MCTS+PRM 范式(如 MedS3 所示)。
医疗场景的特殊性使 MCTS+PRM 优势显著:医学指南的强结构化特征天然适配过程监督需求——策略模型可精准划分诊疗步骤,PRM 能有效完成分步评估,规避其常规场景下的步骤划分难题。
值得注意的是,MCTS+PRM 与 RL 具备互补性,联合应用可进一步提升模型泛化能力。
结论和展望
这篇工作发布了涵盖多个医疗任务的推理系统 MedS3,通过蒙特卡洛树搜索训练了一个策略模型和一个过程监督模型,是医疗推理模型进展上的一个重要工作。
随着 DeepSeek R1 的发布,强化学习以其高泛化性和高数据利用率成为通用域广泛使用的方案。如何将强化学习的思想融入到医疗推理,使其能有效和过程监督结合,仍然是值得思考的一个问题。