摘要:近日,来自博世人工智能中心和蒂宾根大学的研究团队,包括Niclas Popp、Kevin Alexander Laube、Matthias Hein和Lukas Schott,在arXiv平台发表了一篇题为《通过置信引导型数据增强改善未知协变量偏移下的知识蒸馏
为什么我们需要关注这个研究?
想象你有一位经验丰富的烹饪大师(我们称之为"教师模型"),他掌握了无数烹饪秘诀,能够在各种条件下烹制出美味佳肴。现在,你希望将这些技巧传授给一位初学者(我们称之为"学生模型")。正常情况下,初学者通过观察大师的烹饪过程并模仿,逐渐掌握这些技巧。这个过程在人工智能领域被称为"知识蒸馏"。
然而,现实中常常会遇到这样的问题:初学者只能在有限的环境中观察大师(比如只看到大师在高档厨房使用优质食材的烹饪过程),但最终需要在各种不同的环境中施展技艺(如在普通家庭厨房使用普通食材)。当环境发生变化时,初学者往往会因为过度依赖某些特定条件(如高档厨具或特定食材)而无法适应新环境,这就是所谓的"协变量偏移"问题。
在机器学习领域,基础模型(如CLIP)经过大量数据训练后,展现出强大的零样本能力和分布鲁棒性。但这些大模型通常需要海量计算资源,难以在资源受限的环境中部署。知识蒸馏提供了一种将大模型知识转移到小模型的方法,但训练数据的局限性常常限制了蒸馏的效果,特别是当训练数据和测试数据存在协变量偏移时。
什么是协变量偏移?
协变量偏移是指训练数据和测试数据的输入特征分布发生变化,而输出与输入之间的条件分布保持不变。在实际应用中,这常常表现为训练数据中存在"欺骗性特征"(spurious features)——这些特征在训练数据中与目标类别高度相关,但在测试数据中这种相关性不再存在。
举个例子,假设我们在训练一个性别分类模型,训练数据中的女性都是金发、年轻且不戴眼镜的,而男性都是非金发、年长且戴眼镜的。模型很可能会学习到这些表面特征(发色、年龄、是否戴眼镜)与性别的关联,而不是真正学习到性别的本质特征。当测试数据中出现非金发女性或金发男性时,模型就会表现不佳。
研究团队的创新解决方案
研究团队提出了一种名为ConfiG(Confidence-Guided Data Augmentation,置信引导型数据增强)的方法,通过扩充训练数据来解决协变量偏移问题。这种方法的关键在于:利用教师模型和学生模型之间的预测差异,生成针对性的增强样本。
具体来说,ConfiG寻找那些教师模型预测正确但学生模型预测错误的区域,然后生成这些区域的新样本。这些样本保留了类别的本质特征(因为教师模型能正确识别),但改变了欺骗性特征(这些特征导致学生模型做出错误预测)。
这就像教师带着学生特意去练习那些学生容易出错的烹饪技巧一样,有针对性地弥补学生的不足。通过这种方式,即使不知道欺骗性特征具体是什么,也能有效地减少学生对这些特征的依赖。
方法实现细节
ConfiG方法基于扩散模型(Stable Diffusion)实现。首先,研究者使用一个预训练的教师模型和仅在真实训练数据上训练的辅助学生模型。辅助学生模型由于只见过有偏差的训练数据,会过度依赖欺骗性特征。
然后,对于每个训练样本,ConfiG执行以下步骤: 1. 将原始图像编码到扩散模型的潜空间 2. 通过最大化一个特殊的目标函数来优化潜空间表示: 最大化教师模型对正确类别的置信度 最小化学生模型对正确类别的置信度 3. 解码优化后的潜空间表示,得到新的增强图像
这个过程可以理解为在保持图像本质内容(如性别特征)的同时,修改那些导致学生模型出错的特征(如发色或眼镜)。最终,研究者使用原始训练图像和生成的增强图像一起训练最终的学生模型。
实验验证
研究团队在三个数据集上验证了ConfiG方法的有效性:CelebA(名人脸部照片)、SpuCo Birds(鸟类图像)和Spurious ImageNet(带有欺骗性特征的ImageNet子集)。
在CelebA数据集上,训练数据只包含年轻、金发、不戴眼镜的女性和年长、非金发、戴眼镜的男性。测试数据则包含各种组合。实验结果显示,使用ConfiG方法与CutMix和EDRM(经验蒸馏风险最小化)相结合,将最差组性能从原始的7.3%提升到66.1%,组平均准确率从68.0%提升到89.3%。
在SpuCo Birds数据集上,训练数据只包含水鸟在水背景上和陆鸟在陆地背景上的图像,测试数据则包含交叉组合。ConfiG方法将最差组性能从5.6%提升到62.7%,组平均准确率从53.9%提升到83.5%。
在Spurious ImageNet上,ConfiG也实现了最佳的spurious mAUC表现,证明其能有效减轻类别特定的欺骗性特征影响。
研究团队还进行了多项消融研究,包括不同数量的合成增强样本、不同学生模型架构等。结果表明,每个真实图像添加两个合成样本效果最佳,增加更多反而会降低性能,这与理论分析一致。
研究的理论支持
研究团队还提供了严格的理论分析,证明在合理假设下,ConfiG方法能够降低学生模型在测试数据上的泛化误差。这一理论分析直观地解释了为什么找到教师模型和学生模型之间的不一致区域,并在这些区域生成增强样本,能有效改善知识蒸馏过程。
这项研究的意义
归根结底,这项研究提供了一种实用的方法,使小型模型能够从大型基础模型中获取鲁棒性知识,即使训练数据存在明显的偏差。这对于资源受限环境下的AI应用具有重要意义,如移动设备或边缘计算设备上的AI系统。
ConfiG方法的一个重要优势是它不需要预先知道欺骗性特征是什么,也不需要任何组别标注。只要有一个鲁棒的教师模型,就能指导学生模型学习真正有效的特征,而不是依赖数据集中的偶然相关性。
这项研究为解决机器学习中的分布偏移问题提供了新思路,特别是在知识蒸馏这一重要技术中的应用。随着AI系统越来越广泛地部署在各种现实环境中,处理分布偏移的能力将变得至关重要,而ConfiG方法提供了一种有效的解决方案。
来源:至顶网一点号