Like What You Like: Knowledge Distill via Neuron Selectivity Transfer

Author Avatar
Xin Wei 7月 10, 2017
  • 在其它设备中阅读本文章

Like What You Like: Knowledge Distill via Neuron Selectivity Transfer

Zehao Huang, Naiyan Wang by TuSimple

arXiv:1707.01219v1

What?

模型蒸馏技术是一种knowledge transfer方法, 它的思想是用一个已经训练好的更大、更深的大网络指导小网络的训练,将大网络中的”知识”提取到小网络中。

传统的模型蒸馏方法使用大网络和小网络的softmax输出之间的差异作为蒸馏损失,即期望在训练的过程中小网络尽可能“学习”到大网络的输出类别分布,虽然这种蒸馏策略简单直观并且确实能有效提高小网络的效果,然而其缺点也显而易见:只适用于具有softmax的分类任务。

为了使得模型具有更好的适用范围,Zehao Huang, Naiyan Wang等提出了一种称为Neuron Selectivity Transfer(NST)的模型蒸馏方法,

网络结构图

大体来说,这个方法将teacher和student网络在FC之前的特征图按层计算MMD Matching Loss

这种方法具有以下优点:

  1. 能够神经网络加速和压缩
  2. 在多种数据集上证明了NST能够显著提高知识蒸馏的效果
  3. 能够很好的和其他knowledge transfer方法结合

Why?

以VGG16为例,作者将conv5_3层的特征值映射回原图,

http://ww1.sinaimg.cn/large/6425ef91ly1fhcg2jhzmnj20qr09zwqm.jpg

可以看到,conv5_3输出在左图猴子的脸部、右图的路牌处有较大的激活值,这也暗示着神经元具有对某些特定区域敏感而对其他区域不过多关注。受到这个现象的启发,作者认为student模型如果想要获得接近teacher模型的效果,尽量学习到teacher模型的感兴趣部位也是至关重要的。

之前的文章中(Adriana et al. Hints for thin deep nets; Zagoruyko Sergey

)

How?

Maximum Mean Discrepancy(MMD) 最大平均差异

为了使teacher网络的输出分布和student网络的输出分布尽可能相一致,作者使用了MMD的平方作为distill loss,最大平均差异最早用于检测两个分布是否相同,它的基本思想是如果两个分布p和q生成的足够多的样本对应f的映射均值相等,则可以认为两个分布是相等的。

它的具体形式如式(1)所示

http://ww1.sinaimg.cn/large/6425ef91ly1fhciqwyzrnj20nq03074h.jpg

其中,${x^i}{i=1}^N,{y^j}{i=1}^N$分别由两个分布p和q生成, $\phi(\cdot)$为特征图输出的映射函数。将上式展开,并使用类似SVM中的”核技巧“,获得公式(2)

上式中$k(\cdot, \cdot)$是核函数,它的作用是将低维特征映射到高维空间,SVM中核函数的使用是为了在高维空间中将线性不可分样本转换成线性可分的样本分布,而本文的作用是同时映射student网络和teacher网络的中间结果到相同的空间分布从而使teacher网络的知识能够更准确的提取到student网络中。因为当且仅当p和q分布完全相同时MMD 损失才为0,所以最小化MMD损失等同于使p分布和q分布尽可能接近。

Neuron Selectivity Transfer(NST)

具体到模型的训练上,作者采用了一种”神经元选择性转移”的方法来训练student网络,以$f^k$表示中间特征,定义下面的损失函数

NST中使用的损失函数分为两部分,第一部分为真实标签与student预测结果的交叉熵损失,第二部分为MMD损失,将第二项展开

LMM

为了确保每个样本取值范围一样,上式对所有特征都进行了L2归一化,最小化MMD损失等效于将teacher网络中的知识迁移到student。

核函数的选择

论文中,作者验证了三种常用的核函数,他们分别是

  1. 线性核:$k(x,y)=x^Ty$
  2. 多项式核:$k(x,y)=(x^Ty+c)^d$, 文中采用$d=2$, $c=0$
  3. 高斯核:$k(x, y)=exp(-\frac{|x-y|^2}{2\sigma^2})$:高斯核中$\sigma$设为两网络特征的平均距离

实验

作者分别在CIFAR-10,CIFAR-100和ImageNet LSVRC 2012数据集上验证了NST的效果,对于CIFAR数据集采用ResNet-10011网络作为teacher网络,ImageNet数据集采用ResNet-1012网络作为teacher网络,Inception BN3网络作为student网络.

具体结果如下图所示:

上表中,对于CIFAR-100数据集,KD方法取得了比NST更好的结果,作者认为CIFAR-100数据集具有更多的类别而更多的类别能够在softmax输出提供更多的的类内方差信息造成的。即NST方法更适用于类别数较少的分类任务。

上表中,作者实验了将多种方法结合起来的效果,可以看到,KD+NST的组合具有更好的效果。

在ImageNet数据集下的实验结果图表3所示,当使用单个distill方法时,FitNet4取得了最好的效果,但是当将多种方法组合在一起时,KD+NST具有最好的效果。

下图中作者使用t-SNE5可视化了在使用NST前后student网络中间输出的分布,可以看到在经过NST之后teacher网络和student网络的特征分布明显更加一致了。

Conclusion

总得来说,作者提出了一种将student和teacher网络输出映射到相同特征空间作为distill loss的方法,虽然在类别数目多时这个方法弱于KD,但是KD只能用于带softmax的分类任务;虽然FitNet不需要softmax输出,且大数据集下更牛,但是FitNet强迫student网络完全和teacher输出一致,无形中和KD等价了。

最近会尝试使用NST做几个实验,到时候再补充自己的实验结果。


  1. 1.He K, Zhang X, Ren S, et al. Identity mappings in deep residual networks[C]//European Conference on Computer Vision. Springer International Publishing, 2016: 630-645.
  2. 2.He K, Zhang X, Ren S, et al. Deep residual learning for image recognition[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2016: 770-778.
  3. 3.Ioffe S, Szegedy C. Batch normalization: Accelerating deep network training by reducing internal covariate shift[C]//International Conference on Machine Learning. 2015: 448-456.
  4. 4.Romero A, Ballas N, Kahou S E, et al. Fitnets: Hints for thin deep nets[J]. arXiv preprint arXiv:1412.6550, 2014.
  5. 5.Maaten L, Hinton G. Visualizing data using t-SNE[J]. Journal of Machine Learning Research, 2008, 9(Nov): 2579-2605.