Yang's blog Yang's blog
首页
Java
密码学
机器学习
命令手册
关于
友链
  • 分类
  • 标签
  • 归档
GitHub (opens new window)

xiaoyang

编程爱好者
首页
Java
密码学
机器学习
命令手册
关于
友链
  • 分类
  • 标签
  • 归档
GitHub (opens new window)
  • 传统机器学习

    • 机器学习前言
    • 数据预处理
    • 简单线性回归
    • 多元线性回归
    • 逻辑回归(一)
    • 逻辑回归(二)
    • K近邻法(k-NN)
    • k最近邻分类任务代码演示
    • 支持向量机(SVM)
    • 使用SVM进行二分类
    • 决策树
    • 随机森林
    • 什么是K-means聚类算法
    • 使用K-Means算法进行数据聚类:以鸢尾花数据集为例
  • 联邦学习

    • 联邦学习中的基础算法介绍
    • Advances and Open Problems in Federated Learning
    • Vertical Federated Learning Concepts,Advances, and Challenges
    • 机器学习中的并行计算
    • Boosted Trees 简介
    • SecureBoost:一种无损的联邦学习框架
    • FedGen & Data-Free Knowledge Distillation for Heterogeneous Federated Learning
    • Towards Personalized Federated Learning
    • Distilling the Knowledge in a Neural Network
      • 摘要
      • 1 引言
      • 2 蒸馏(Distillation)
        • 2.1举例说明
        • 2.11背景
        • 2.12转移集
        • 2.13目标函数的构建
        • 2.14加权组合
        • 2.2 匹配logits是蒸馏的一个特例
      • 3 MNIST 数据集上的初步实验
      • 4 语音识别实验
        • 4.1 结果
      • 5. 在大数据集上训练专家模型集成
        • 5.1 JFT数据集
        • 5.2 专家模型
        • 5.3 将类别分配给专家模型
        • 5.4 使用专家集成进行推理
        • 5.5 结果
        • 6 软目标作为正则化器
        • 6.1 使用软目标防止专家模型过拟合
        • 7 与专家混合模型的关系
        • 8 讨论
        • 8 致谢
    • FedMD & Heterogenous Federated Learning via Model Distillation
    • FedFTG & Fine-tuning Global Model via Data-Free Knowledge Distillation for Non-IID Federated Learning
    • MOON & Model-Contrastive Federated Learning
    • Knowledge Distillation in Federated Learning:A Practical Guide
    • DKD-pFed & A novel framework for personalized federated learning via decoupling knowledge distillation and feature decorrelation
    • pFedSD & Personalized Edge Intelligence via Federated Self-Knowledge Distillation
    • FedFD&FAug:Communication-Efficient On-Device Machine Learning:Federated Distillation and Augmentation under Non-IID Private Data
  • 机器学习
  • 联邦学习
xiaoyang
2024-10-09
目录

Distilling the Knowledge in a Neural Network

# Distilling the Knowledge in a Neural Network

提示

该论文提出一种知识蒸馏的方法将一个复杂模型中的知识通过软目标(soft targets)提炼并转移到一个较小的模型中,从而在不牺牲性能的情况下大大简化模型的计算复杂度。

原论文地址:https://arxiv.org/abs/1503.02531

# 摘要

提高几乎任何机器学习算法性能的一个非常简单的方法是,在相同的数据上训练多个不同的模型,然后对它们的预测结果进行平均【3】。不幸的是,使用整个模型集(ensemble)进行预测会非常繁琐,并且可能计算成本过高,难以为大量用户部署,尤其是当各个模型是大型神经网络时。Caruana及其合作者【1】已经表明,可以将模型集中的知识压缩到一个更易于部署的单个模型中。我们在此基础上,使用不同的压缩技术进一步发展了这一方法。我们在MNIST数据集上取得了一些惊人的结果,并展示了通过将模型集中的知识提炼到一个单模型中,可以显著改善一个被广泛使用的商业系统的声学模型。此外,我们还引入了一种新的模型集类型,它由一个或多个完整模型和许多专家模型(specialist models)组成,这些专家模型专门学习区分完整模型混淆的细粒度类别。与专家网络(mixture of experts)不同的是,这些专家模型可以快速并行训练。

# 1 引言

许多昆虫拥有一种幼虫形态,该形态针对从环境中获取能量和营养进行了优化,而成虫形态则完全不同,主要优化以适应旅行和繁殖的需求。在大规模机器学习中,尽管训练阶段和部署阶段的要求完全不同,我们通常仍然使用非常相似的模型。对于像语音识别和物体识别这样的任务,训练阶段需要从非常庞大且高度冗余的数据集中提取结构,但它不需要实时运行,并且可以使用大量的计算资源。然而,部署到大量用户时,对延迟和计算资源的要求更加严格。通过昆虫的类比,我们应该愿意训练非常笨重的模型,因为这可以更容易地从数据中提取结构。笨重的模型可以是单独训练的模型集(ensemble),或者是使用强正则化(如dropout【9】)训练的非常大的单个模型。一旦训练好笨重的模型,我们可以使用一种称为“蒸馏”的不同训练方法,将笨重模型中的知识转移到更适合部署的小模型中。Rich Caruana及其合作者【1】已经在这方面开创了先例。在他们的重要论文中,他们有力地证明了可以将由大型模型集获得的知识转移到单个小模型中。

一个阻碍这一非常有前途的方法得到更多研究的概念性障碍是,我们倾向于将已训练模型中的知识与学习到的参数值联系在一起,这使得很难理解如何在改变模型形式的同时保持相同的知识。一个更抽象的知识观念是,它是从输入向量到输出向量的学习映射,与任何特定的实现形式无关(知识不仅仅是模型的参数,而是模型如何从输入映射到输出的这个过程。换句话说,知识是模型学会了如何将输入向量(比如图像的像素数据)转换成输出向量(比如分类结果)的能力,而这与模型的具体结构无关。)。对于那些学习区分大量类别的笨重模型,常规的训练目标是最大化正确答案的平均对数概率,但训练的副产品是模型为所有不正确的答案分配了概率,即使这些概率非常小,一些错误的概率比其他错误的概率仍然要大得多。错误答案的相对概率可以告诉我们很多关于笨重模型如何进行泛化的信息。例如,一张宝马的图片虽然只有很小的几率被误认为是垃圾车,但这种错误还是比误认为胡萝卜的几率大得多。

一般接受的观点是,用于训练的目标函数应该尽可能反映用户的真实目标。尽管如此,模型通常还是被训练来优化在训练数据上的表现,而真实目标是要在新数据上泛化得好。显然,更好的方式是训练模型以便在新数据上表现良好,但这需要正确泛化方式的信息,而这种信息通常是不可得的。然而,当我们将大模型的知识蒸馏到小模型时,我们可以训练小模型以与大模型相同的方式泛化。如果笨重的模型能很好地泛化,例如它是多个不同模型集成的平均值,那么一个以相同方式泛化的小模型在测试数据上的表现通常会比使用相同训练集的常规训练小模型好得多。

将笨重模型的泛化能力转移到小模型的显而易见的方法是,使用笨重模型产生的类别概率作为训练小模型的“软目标”。在这一转移阶段,我们可以使用相同的训练集或一个独立的“转移集”。当笨重模型是一个由简单模型组成的大型集成时,我们可以使用它们各自预测分布的算术或几何平均作为软目标。当软目标具有较高熵时,它们每个训练样例提供的信息比硬目标多得多,并且在训练样例之间的梯度方差要小得多,因此小模型通常可以在比笨重模型少得多的数据上训练,并且可以使用更高的学习率。

对于像MNIST这样的任务,笨重模型几乎总是以非常高的置信度给出正确答案,关于学习函数的大部分信息存在于软目标中的非常小的概率比率中。例如,一个“2”的版本可能被赋予一个10−6​ 的概率被误认为是“3”,另一个“2”的版本可能是反过来。这些信息定义了一个数据上的丰富相似性结构(即说明哪些“2”看起来像“3”,哪些像“7”),但由于这些概率非常接近零,它对转移阶段的交叉熵损失函数影响很小。Caruana及其合作者通过使用logits(即最终softmax的输入)而不是softmax产生的概率作为学习小模型的目标,解决了这个问题,他们最小化笨重模型和小模型产生的logits之间的平方差。我们更通用的解决方案,称为“蒸馏”,是提高最终softmax的温度,直到笨重模型产生适当的软目标。然后,我们在训练小模型以匹配这些软目标时使用相同的高温。稍后我们将展示,匹配笨重模型的logits实际上是蒸馏的一个特例。(蒸馏的解决方案是提高softmax的温度,让大模型产生“较软”的目标。这意味着在softmax中减少输出值的差异,使得不同类别的概率差距变小,从而让小模型能够更好地学习这些软目标。温度控制:通过提高温度,模型输出的概率变得更加平滑,提供了更多的相对信息,特别是关于那些低概率的类别。)

用于训练小模型的转移集可以完全由未标记的数据组成【1】,也可以使用原始训练集。我们发现使用原始训练集效果很好,特别是当我们在损失函数中加入一项额外的项,鼓励小模型既能预测真实目标又能匹配笨重模型提供的软目标。通常,小模型不能完全匹配软目标,而朝着正确答案偏离反而是有益的。

# 2 蒸馏(Distillation)

神经网络通常通过使用一个 “softmax” 输出层来生成类别概率,该输出层将为每个类别计算的 logitzi转换为概率qi,方法是将zi与其他 logits 进行比较:

(1)qi=exp⁡(zi/T)∑jexp⁡(zj/T)

其中,T是通常设置为 1 的温度参数。使用更高的T值会生成更“柔和”的类别概率分布。

在最简单的蒸馏形式中,知识通过将笨重模型在高温T下使用 softmax 生成的软目标分布,作为转移集上的每个样例的目标,训练蒸馏模型来进行转移。训练蒸馏模型时使用相同的高温T,但在训练完成后,蒸馏模型使用的温度为 1。

当转移集中所有或部分样例的正确标签已知时,该方法可以通过同时训练蒸馏模型以生成正确的标签来显著改进。实现这一点的一种方法是使用正确的标签来修改软目标,但我们发现更好的方法是简单地使用两种不同目标函数的加权平均。第一个目标函数是与软目标的交叉熵,使用相同的高温T来计算蒸馏模型的 softmax。第二个目标函数是与正确标签的交叉熵,这在蒸馏模型的 softmax 中使用相同的 logits,但温度为 1。我们发现最佳结果通常是在第二个目标函数上使用显著较低的权重获得的。由于软目标生成的梯度的大小按1/T2缩放,因此在同时使用硬目标和软目标时,必须乘以T2,以确保在实验过程中改变蒸馏温度时,硬目标和软目标的相对贡献保持基本不变。

# 2.1举例说明

# 2.11背景

假设我们有一个复杂的笨重模型,经过训练后可以对图片进行分类,比如区分手写数字(MNIST数据集)。我们希望通过蒸馏,把笨重模型中的知识传递给一个更小的模型。我们有以下步骤:

  1. 笨重模型(大模型):已经训练好了,并且它输出的是手写数字图片的分类概率分布(如给出某张图片是“3”的概率为90%,是“8”的概率为5%)。
  2. 小模型:需要被训练得尽量接近笨重模型的表现,同时对计算资源要求低,适合部署。

# 2.12转移集

我们选择了一组图片作为转移集,用它们来训练小模型。我们知道这些图片的真实标签,也能得到笨重模型对这些图片的输出概率(软目标)。

# 2.13目标函数的构建

在训练小模型时,我们使用两个目标函数:

  1. 软目标的交叉熵:

    • 我们将笨重模型对这些图片的预测结果作为“软目标”。例如,笨重模型给图片的预测可能是:“3”的概率为90%,而“8”的概率为5%。
    • 我们将小模型的输出通过softmax,并设置温度T=3来生成更加平滑的概率分布。这让模型输出的概率分布更加软化,不会过于偏向某个类别,从而保留更多样本间的相似性信息。
    • 计算小模型和笨重模型输出概率之间的交叉熵损失。
  2. 硬目标的交叉熵:

    • 知道转移集的真实标签,例如某张图片的真实标签是“3”。
    • 使用正常温度T=1的 softmax 来计算小模型的输出,并根据真实标签计算交叉熵损失。

# 2.14加权组合

现在,我们把这两个目标函数的损失组合在一起进行优化:

Ltotal=αLsoft+(1−α)Lhard

其中,Lsoft是软目标的交叉熵损失,Lhard是硬目标的交叉熵损失,α是控制两个损失权重的超参数。

  • 如果我们希望模型主要学习笨重模型的知识,那么我们可以设置较大的α,例如 0.9。
  • 但如果真实标签非常重要,则可以设置较小的α,如 0.1。

此外,由于软目标的梯度按1/T2缩放,所以我们在计算软目标梯度时,需要乘以T2来调整梯度的大小,使软目标和硬目标的贡献保持平衡。

# 2.2 匹配logits是蒸馏的一个特例

转移集中的每个样例都会对每个蒸馏模型的 logitzi贡献一个交叉熵梯度dCdzi。如果笨重模型有 logitsvi,这些 logits 生成了软目标概率pi,且转移训练在温度T下进行,则梯度为:

(2)∂C∂zi=1T(qi−pi)=1T(ezi/T∑jezj/T−evi/T∑jevj/T)

如果温度相对于 logits 的大小较高,我们可以近似为:

(3)∂C∂zi≈1T(1+zi/TN+∑jzj/T−1+vi/TN+∑jvj/T)

假设 logits 已经对每个转移样例分别做了零均值处理,即∑jzj=∑jvj=0,那么公式(3)可以简化为:

(4)∂C∂zi≈1NT2(zi−vi)

因此,在高温极限下,蒸馏等同于最小化12(zi−vi)2,前提是每个转移样例的 logits 已被分别零均值化。在较低温度下,蒸馏对与平均值相比非常负的 logits 关注较少。这可能是有利的,因为这些 logits 在训练笨重模型的成本函数中几乎没有约束,可能非常噪声化。然而,非常负的 logits 也可能传递笨重模型所获得的有用信息。哪种效果占主导地位是一个经验问题。我们展示了当蒸馏模型过小,无法捕捉笨重模型中的所有知识时,使用中等温度效果最好,这强烈表明忽略那些大的负 logits 可能是有帮助的。

# 3 MNIST 数据集上的初步实验

为了评估蒸馏的效果,我们训练了一个大型神经网络,该网络包含两个隐藏层,每层有1200个ReLU激活的隐藏单元,训练时使用了全部60,000个训练样例。该网络使用了强正则化策略,包括dropout和权重约束,正如文献[5]中所描述的。Dropout可以被视为训练一个共享权重的指数级大规模模型集的方法。此外,输入图像被随机抖动,最多可在任意方向上偏移两个像素。这个网络在测试集中取得了67个错误,而一个较小的网络(同样有两个隐藏层,每层800个ReLU激活的隐藏单元,但没有正则化)在相同测试集中产生了146个错误。但如果我们仅通过添加匹配由大网络在温度为20时生成的软目标任务来正则化小网络,小网络的测试错误数减少到了74个。这表明,软目标能够将大量知识传递给蒸馏模型,包括从偏移训练数据中学到的泛化知识,即使转移集不包含任何偏移的数据。

当蒸馏模型的每个隐藏层包含300个或更多单元时,所有高于8的温度均表现出相似的效果。但当网络规模大幅减少到每层仅包含30个单元时,温度在2.5到4之间的效果显著优于更高或更低的温度。

随后,我们尝试从转移集中去除所有数字“3”的样例。因此,从蒸馏模型的角度来看,数字“3”是一个从未见过的“神秘数字”。即便如此,蒸馏模型在测试集中仅出现了206个错误,其中有133个错误出现在1010个数字“3”的测试样例中。大多数错误是由于学习到的数字“3”类别的偏差值过低导致的。如果将该类别的偏差值增加3.5(这可以优化测试集上的整体性能),蒸馏模型的错误数降到了109个,其中只有14个错误出现在数字“3”的样例上。因此,在调整了正确的偏差后,尽管蒸馏模型从未在训练过程中见过数字“3”,但它仍能正确识别98.6%的测试集中数字“3”的样例。

如果转移集中仅包含训练集中的数字“7”和“8”样例,蒸馏模型在测试集中产生了47.3%的错误率,但当对数字“7”和“8”的偏差值减少7.6以优化测试性能时,错误率下降到了13.2%。

# 4 语音识别实验

在本节中,我们探讨了在自动语音识别(ASR)中使用深度神经网络(DNN)声学模型进行集成的效果。我们展示了本文提出的蒸馏策略能够成功地将多个模型的集成转化为一个单一的模型,并且该单一模型的性能显著优于从相同训练数据中直接学习的同样大小的模型。

目前,最先进的ASR系统使用DNN将从波形中提取的(短时)特征上下文映射到离散状态隐马尔可夫模型(HMM)的概率分布上【4】。更具体地说,DNN在每个时间点生成三音子状态簇的概率分布,而解码器则在HMM状态中找到最优路径,这个路径在使用高概率状态和生成一个合理的转录结果之间做出最佳折中。

虽然可以(也很理想)在训练DNN时考虑解码器(因此也考虑语言模型),通过对所有可能路径进行边缘化来优化模型,但通常的做法是逐帧进行分类训练,通过最小化网络预测与通过强制对齐获得的真实状态序列标签之间的交叉熵来实现这一目标:

θ=arg⁡maxP(ht|st;θ′)

其中,θ是我们声学模型的参数,该模型将时刻t的声学观测st映射到正确HMM状态ht的概率P(ht|st;θ′),这些状态是通过强制对齐与正确的单词序列确定的。模型通过分布式随机梯度下降方法进行训练。

我们使用一个包含8个隐藏层的架构,每层有2560个ReLU单元,最后一层是包含14,000个标签(HMM目标ht)的softmax层。输入是40个Mel滤波器组系数的26帧数据,帧间隔为10毫秒,我们预测的是第21帧的HMM状态。模型总参数数量大约为8500万。这是Android语音搜索中使用的声学模型的一个稍旧版本,应该被视为一个非常强大的基线模型。为了训练DNN声学模型,我们使用了大约2000小时的英语语音数据,生成了大约7亿个训练样例。该系统在开发集上取得了58.9%的帧分类准确率和10.9%的单词错误率(WER)。

系统 帧分类准确率 单词错误率(WER)
基线 58.9% 10.9%
10x集成模型 61.1% 10.7%
蒸馏后的单一模型 60.8% 10.7%

表1展示了帧分类准确率和WER,表明蒸馏后的单一模型的表现与用于生成软目标的10个模型的平均预测几乎相同。

# 4.1 结果

我们训练了 10 个独立的模型来预测P(ht∣st;θ),使用的架构和训练过程与基线模型完全相同。模型使用不同的初始参数值进行随机初始化,我们发现这能在训练的模型中产生足够的多样性,从而使集成模型的平均预测结果显著优于单个模型的预测结果。我们还探索了通过改变每个模型所看到的数据集来增加模型的多样性,但结果并没有显著变化,因此我们选择了更简单的方法。对于蒸馏训练,我们尝试了温度为 [1, 2, 5, 10] 的不同设置,并在硬标签的交叉熵损失上使用了相对权重为 0.5 的设置,其中加粗的数值为表 1 中使用的最佳值。

表 1 显示,我们的蒸馏方法确实能够比仅使用硬标签训练单个模型从训练集中提取更多有用信息。通过使用 10 个模型的集成模型,帧分类准确率的提升超过 80% 被成功转移到蒸馏模型上,这与我们在 MNIST 数据集上的初步实验结果相似。由于目标函数不匹配,集成模型在最终目标(WER,基于 23K 单词的测试集)上的提升较小,但同样,集成模型在 WER 上的改进也被转移到了蒸馏模型上。

我们最近注意到相关的工作,他们通过匹配已训练的大模型的类别概率来学习一个小型声学模型 [8]。然而,他们的蒸馏是在温度为 1 的情况下使用一个大规模未标注的数据集进行的,并且他们的最佳蒸馏模型仅减少了小模型误差率与大模型误差率之间差距的 28%,而两者都是通过硬标签训练的。

# 5. 在大数据集上训练专家模型集成

训练一个模型集成是一种简单的方式来利用并行计算,尽管在测试时通常有人会提出集成需要过多计算的异议,但可以通过使用蒸馏来解决这一问题。不过,对于集成模型还有另一个重要的异议:如果单个模型是大型神经网络,并且数据集非常庞大,那么训练所需的计算量会过于庞大,即使很容易并行化。

在本节中,我们将提供一个这样的数据集的示例,并展示如何通过学习专门的模型,每个模型都专注于一组容易混淆的类别,从而减少训练集成所需的总计算量。专门针对细粒度差异进行区分的专家模型的主要问题是它们容易过拟合,我们会描述如何通过使用软目标来防止这种过拟合。

# 5.1 JFT数据集

JFT是Google的内部数据集,包含1亿张带标签的图像,涵盖15,000个标签。在进行这项工作时,Google的JFT基准模型是一个深度卷积神经网络【7】,它在大量核心上使用异步随机梯度下降训练了大约六个月。这次训练使用了两种并行化方法【2】。首先,神经网络的多个副本在不同的核心集上运行,并处理来自训练集的不同小批次。每个副本计算其当前小批次的平均梯度,并将该梯度发送到分片的参数服务器,服务器返回更新后的参数值。这些新值反映了自服务器上次发送参数到副本以来收到的所有梯度。其次,每个副本分布在多个核心上,通过将不同的神经元子集分配给不同的核心来实现。集成模型的训练是第三种并行化方式,它可以与前两种方式结合使用,但前提是有更多的核心资源可用。由于等待几年来训练模型集成并不可行,因此我们需要一种更快的方法来提升基准模型。

# 5.2 专家模型

当类别数量非常多时,训练一个包含所有数据的通用模型和多个“专家”模型的集成是有意义的。每个专家模型都专注于某个容易混淆的类别子集(例如不同种类的蘑菇)。这种专家模型的softmax可以通过将它不关心的所有类别组合成一个“灰尘”类来大大简化。

为了减少过拟合并共享低级特征检测器的学习工作,每个专家模型的初始化权重都来自通用模型。这些权重稍加修改后,专家模型就会通过其特定子集中的一半样本和从其余训练集中随机抽取的样本进行训练。在训练完成后,我们可以通过将灰尘类别的logit值增加相应的比率来校正有偏的训练集。

# 5.3 将类别分配给专家模型

为了为专家模型推导出对象类别的分组,我们决定专注于我们的完整网络经常混淆的类别。尽管我们可以计算混淆矩阵并将其用作发现这些类别集群的一种方式,但我们选择了一种不需要真实标签来构建集群的更简单方法。

特别地,我们将聚类算法应用于通用模型预测的协方差矩阵,以便将经常一起预测的类别集Sm用作某个专家模型m的目标。我们使用了一种在线版本的K均值算法来对协方差矩阵的列进行聚类,并得到了合理的集群(如表2所示)。我们尝试了几种聚类算法,产生了相似的结果。

# 5.4 使用专家集成进行推理

在研究蒸馏专家模型之前,我们首先想看看包含专家的集成模型表现如何。除了专家模型之外,我们始终保留一个通用模型,以便处理没有专家模型的类别,并帮助我们决定使用哪些专家模型。给定输入图像x,我们通过以下两个步骤进行top-1分类:

  • 第一步:对于每个测试样本,我们根据通用模型找到n个最可能的类别。这个类别集合称为k。在我们的实验中,取n=1。
  • 第二步:然后,我们将所有专家模型m的特殊混淆类别子集Sm与集合k有非空交集的模型称为活动专家集合Ak(该集合可能为空)。然后我们找到一个覆盖所有类别的完整概率分布q,使其最小化以下公式:
(5)KL(pg,q)+∑m∈AkKL(pm,q)

其中,KL表示KL散度,pm和pg分别表示某个专家模型或完整通用模型的概率分布。专家模型的分布pm是针对所有类别加上灰尘类的一个分布,因此在计算其与完整分布q的KL散度时,我们将q分布分配给专家模型中灰尘类的所有概率相加。

公式5没有一般的闭式解,尽管当所有模型为每个类别产生一个单一的概率时,解要么是算术平均数,要么是几何平均数,这取决于我们是使用KL(p, q)还是KL(q, p)。我们将q参数化为softmax(z)(T=1),并使用梯度下降来优化对公式5的logitsz。请注意,这个优化必须针对每个图像进行。

表3:JFT开发集上的分类准确率(Top 1)

系统类型 条件测试准确率 测试准确率
基线模型 43.1% 25.0%
+ 61个专家模型 45.9% 26.1%
  • 基线模型是最初的模型,没有任何专家模型支持。
  • 当加入61个专家模型后,条件测试准确率从43.1%提高到45.9%,测试准确率从25.0%提高到26.1%。这表明通过引入专家模型,模型在特定任务上的性能有所改善。

表4: JFT测试集上按覆盖正确类别的专家模型数量的Top 1准确率提升

覆盖专家数量 测试样本数量 Top 1准确率变化 相对准确率变化
0 350037 0 0.0%
1 141993 +1421 +3.4%
2 67161 +1572 +7.4%
3 38801 +1124 +8.8%
4 26298 +835 +10.5%
5 16474 +561 +11.1%
6 10682 +362 +11.3%
7 7376 +232 +12.8%
8 4703 +182 +13.6%
9 4706 +208 +16.6%
10个或更多 9082 +324 +14.1%

整体来看,数据表明通过引入专家模型,模型在处理特定类别时的性能有了显著的提升,尤其是专家模型数量增加时的效果更加明显。这是因为每个专家模型能够专注于处理某些特定的、容易混淆的类别,从而提高了模型的整体准确率。

# 5.5 结果

从训练好的基准全模型开始,专家模型的训练速度非常快(仅需几天,而JFT的全模型训练需时数周)。此外,所有专家模型都是完全独立训练的。表3显示了基准系统和结合专家模型的基准系统的绝对测试准确率。使用61个专家模型,总体测试准确率相对提升了4.4%。我们还报告了条件测试准确率,即仅考虑属于专家类别的样本,并将预测限制在该类别子集中的准确率。

在我们的JFT专家实验中,我们训练了61个专家模型,每个模型有300个类别(加上灰尘类)。由于这些专家模型的类别集并不互斥,因此我们经常有多个专家模型覆盖同一图像类别。表4显示了测试集中示例数量、使用专家模型后在top-1位置上预测正确的示例变化量,以及按覆盖该类别的专家模型数量划分的相对百分比top-1准确率提升。我们对有更多专家模型覆盖某个类别时准确率提升更大的趋势感到鼓舞,因为独立训练专家模型非常容易并行化。

# 6 软目标作为正则化器

我们关于使用软目标而不是硬目标的主要论点之一是,软目标可以承载很多在单一硬目标中无法编码的有用信息。在本节中,我们通过使用远少于的数据来拟合先前描述的基线语音模型的8500万个参数,展示了这一点是一个非常大的影响。表5显示,仅使用3%的数据(约2000万个样本),使用硬目标训练基线模型导致严重的过拟合(我们进行了提前停止,因为准确率在达到44.5%后急剧下降),而使用软目标训练的同一模型能够恢复几乎全部完整训练集中的信息(仅少了约2%)。更值得注意的是,我们不需要进行提前停止:使用软目标的系统简单地“收敛”到57%。这表明,软目标是将训练在所有数据上的模型发现的规律有效地传递给另一个模型的一种方式。

系统与训练集 训练帧准确率 测试帧准确率
基线(100%的训练集) 63.4% 58.9%
基线(3%的训练集) 67.3% 44.5%
软目标(3%的训练集) 65.4% 57.0%

表5:软目标使新模型能够仅从3%的训练集中良好地泛化。软目标是通过在完整训练集上训练获得的。

# 6.1 使用软目标防止专家模型过拟合

我们在JFT数据集上的实验中使用的专家模型将所有非专家类别合并为一个单一的“垃圾桶”类别。如果我们允许专家对所有类别进行完整的softmax,那么可能有更好的方法来防止它们过拟合,而不是使用提前停止。专家是在其特定类别的数据上进行训练的,这意味着它的有效训练集大小要小得多,并且有很强的倾向在其特定类别上过拟合。通过将专家模型缩小得很多来解决这个问题是不可行的,因为那样我们会失去从建模所有非专家类别中获得的非常有帮助的迁移效果。

我们使用3%的语音数据进行的实验强烈表明,如果专家模型用通用模型的权重初始化,我们可以通过对非特定类别的软目标进行训练,使其几乎保留所有关于非特定类别的知识。此外,专家模型还可以用硬目标进行训练。软目标可以由通用模型提供。我们目前正在探索这种方法。

# 7 与专家混合模型的关系

使用在数据子集上训练的专家与专家混合模型有一些相似之处,后者使用门控网络来计算将每个示例分配给每个专家的概率。专家在学习处理分配给它们的示例的同时,门控网络正在学习根据专家对该示例的相对判别性能来选择将每个示例分配给哪些专家。使用专家的判别性能来确定学习到的分配比仅仅对输入向量进行聚类并将专家分配给每个聚类要好得多,但这使得训练很难并行化:首先,每个专家的加权训练集以依赖于所有其他专家的方式不断变化,其次,门控网络需要比较不同专家在同一示例上的性能,以知道如何修正其分配概率。这些困难使得专家混合模型很少在它们可能最有益的领域被使用:具有显著不同子集的大型数据集任务。

并行化训练多个专家要容易得多。我们首先训练一个通用模型,然后使用混淆矩阵定义专家模型训练的子集。一旦这些子集被定义,专家可以完全独立地训练。在测试时,我们可以使用通用模型的预测来决定哪些专家是相关的,只有这些专家需要运行。

# 8 讨论

我们已经展示了,蒸馏在将知识从一个集成模型或一个大型高度正则化的模型转移到一个更小的蒸馏模型中效果非常好。在MNIST上,即使用于训练蒸馏模型的转移集缺少一个或多个类别的示例,蒸馏的效果仍然非常显著。对于一种深层声学模型(与Android语音搜索使用的模型版本相同),我们已经证明,通过训练一个深度神经网络的集成所获得的几乎所有改进都可以蒸馏到一个相同大小的单一神经网络中,这样更容易进行部署。

对于真正大的神经网络,训练一个完整的集成可能不可行,但我们已经展示,通过学习大量专家网络的知识,可以显著改善一个经过长时间训练的单个大型网络的性能,每个专家网络学习区分在高度混淆的簇中的类别。我们尚未证明我们可以将专家中的知识蒸馏回单个大型网络中。

# 8 致谢

我们感谢Yangqing Jia在ImageNet上训练模型时的帮助,Ilya Sutskever和Yoram Singer的有益讨论。

编辑 (opens new window)
#知识蒸馏
上次更新: 2025/04/01, 01:48:12

← Towards Personalized Federated Learning FedMD & Heterogenous Federated Learning via Model Distillation→

最近更新
01
操作系统
03-18
02
Nginx
03-17
03
后端服务端主动推送消息的常见方式
03-11
更多文章>
Theme by Vdoing | Copyright © 2023-2025 xiaoyang | MIT License
  • 跟随系统
  • 浅色模式
  • 深色模式
  • 阅读模式