天大团队构建决策模型代码库,支持多种扩散模型和网络架构
近年来,扩散模型一直很火热,在图像生成或视频生成方面,例如,被人们所熟知的 Stable Diffusion、Sora 等应用。
近期,天津大学团队构建了一种面向决策领域的扩散模型代码库 CleanDiffuser。通过重新审视扩散模型在决策领域中的角色,识别出一系列关键的基本子模块,即扩散模型、网络架构、引导采样方法。
此外,CleanDiffuser 让所有部分都相互解耦,让新算法开发与算法组件消融变得更简单。CleanDiffuser 的代码和文档已经在 GitHub 网站开源。
“希望更多开发者能够加入,共同将代码库完善,以促进更多科研成果,甚至是落地的机器人应用。”该论文共同第一作者、天津大学博士生董子斌表示。
图丨董子斌(左)和袁逸夫(来源:董子斌、袁逸夫)
日前,相关论文以《CleanDiffuser: 易于使用的决策扩散模型模块化库》(CleanDiffuser: An Easy-to-use Modularized Library for Diffusion Models in Decision Making)为题发表在预印本网站 arXiv 上 [1]。
天津大学博士生董子斌和袁逸夫是第一作者,天津大学郝建业副教授担任通讯作者。
图丨相关论文(来源:arXiv)
为确保与不同角色的兼容性,在扩散模型(DM,Diffusion Model)库中,DM 骨干和网络架构实现解耦模块。
同时,将决策特定功能(如屏蔽机制、采样机制)纳入模块设计,并开发一个无缝集成模块和机制的算法管道,以满足不同的 DM 使用范例。
图丨CleanDiffuser 的架构(来源:arXiv)
用户通过使用 CleanDiffuser,算法可以通过选择构建块,并将它们集成到流程中来实现。
CleanDiffuser 耦合的模块化架构具有两方面优势:
一方面,方便新手入门理解和比较不同算法差异;另一方面,也为科研和从业人员快速针对特定应用场景或需求提供便利,通过对系统中的各个模块进行独立修改和优化,从而提高整体灵活性和工作效率。
CleanDiffuser 中的 DM 由两个核心部分实现,即随机微分方程(SDE,Stochastic Differential Equation)/常微分方程(ODE,Ordinary Differential Equation)和求解器。
相较于其他 DM 代码库,CleanDiffuser 为决策任务实现了多项特性机制,Diffusion-X 采样,Warm-strating 采样,掩码机制等等。
值得一提的是,CleanDiffuser 支持多种高级扩散模型和网络结构。在 CleanDiffuser 中,所有架构都继承自同一个父类,并共享一个标准的应用程序编程接口(API,Application Programming Interface)调用,这使得研究人员可以轻松地基于基础设计新的架构。
图丨CleanDiffuser 中实现的网络架构(来源:arXiv)
该课题组希望,可以简单地实现将解耦的模块构建为集成的管道。在该研究中,实现 Difusion 算法不需要了解几十个文件夹的很多代码,只需要选择神经网络、DM 和引导。
为了展示 CleanDiffuser 的可靠性和灵活性,研究人员在 37 个强化学习和模仿学习环境中,使用 CleanDiffuser 以单文件方式复现 12 种主流扩散决策算法及其变体进行了广泛的实验。
结果表明,复现算法的性能能够达到甚至超过算法的官方实现,且代码更加轻量简洁。
此外,在不同架构、求解器、采样步骤和模型大小上也进行了广泛的实证分析,发现且指出了目前扩散决策模型仍存在的大量机遇与挑战,为后续的研究工作提供思路。
决策领域目前是缺少统一且易用的代码库,而 CleanDiffuser 填补了当前决策领域的一个空白,通过模块化设计,简化了算法开发和定制,支持多种高级扩散模型和网络架构。
据了解,研究人员会继续关注社区的使用反馈,随时进行更新和补充。
董子斌表示,代码库目前仍在不断更新中,追加了一致性模型(Consistency Models),SfBC/QGPO/DiffuserLite 三种扩散决策算法,ViT/R3M 等视觉表征生成条件网络。
除组件内容以外,整个代码库正在进行 PyTorch Lightning 的重构,借助 PyTorch Lightning 的支持,CleanDiffuser 能够支持多卡并行、混合精度、预训练模型等等先进深度学习技术,感兴趣的读者可以在 GitHub 仓库的 lightning 分支中试用。
接下来,团队成员计划将深入研究解决相关反常现象的方法。
参考资料:
1.https://arxiv.org/pdf/2406.09509
运营/排版:何晨龙
01/
02/
03/
04/