行业报告 AI展会 数据标注 标注供求
数据标注数据集
主页 > 机器学习 正文

模型剪枝,不可忽视的推断效率提升方法

剪枝是常用的模型压缩方法之一,本文对剪枝的原理、效果进行了简单介绍。
 
目前,模型需要大量算力、内存和电量。当我们需要执行实时推断、在设备端运行模型、在计算资源有限的情况下运行浏览器时,这就是瓶颈。能耗是人们对于当前深度学习模型的主要担忧。而解决这一问题的方法之一是提高推断效率。
 
大模型 => 更多内存引用 => 更多能耗
剪枝正是提高推断效率的方法之一,它可以高效生成规模更小、内存利用率更高、能耗更低、推断速度更快、推断准确率损失最小的模型,此类技术还包括权重共享和量化。深度学习从神经科学中汲取过灵感,而剪枝同样受到生物学的启发。
 
随着深度学习的发展,当前最优的模型准确率越来越高,但这一进步伴随的是成本的增加。本文将对此进行讨论。
 
挑战 1:模型规模越来越大
我们很难通过无线更新(over-the-air update)分布大模型。

 

 
来自 Bill Dally 在 NIPS 2016 workshop on Efficient Methods for Deep Neural Networks 的演讲。
 
挑战 2:速度

 

使用 4 块 M40 GPU 训练 ResNet 的时间,所有模型遵循 fb.resnet.torch 训练。
 
训练时间之长限制了研究者的生产效率。
 
挑战 3:能耗
AlphaGo 使用了 1920 块 CPU 和 280 块 GPU,每场棋局光电费就需要 3000 美元。
 

 

这对于移动设备意味着:电池耗尽
 
对于数据中心意味着:总体拥有成本(TCO)上升
 
解决方案:高效推断
剪枝
权重共享
低秩逼近
二值化网络(Binary Net)/三值化网络(Ternary Net)
Winograd 变换
 
剪枝所受到的生物学启发
人工中的剪枝受启发于人脑中的突触修剪(Synaptic Pruning)。突触修剪即轴突和树突完全衰退和死亡,是许多哺乳动物幼年期和青春期间发生的突触消失过程。突触修剪从公出生时就开始了,一直持续到 20 多岁。

 

 
Christopher A Walsh. Peter Huttenlocher (1931–2013). Nature, 502(7470):172–172, 2013.
 
修剪深度神经网络

 

 
[Lecun et al. NIPS 89] [Han et al. NIPS 15]
 
神经网络通常如上图左所示:下层中的每个神经元与上一层有连接,但这意味着我们必须进行大量浮点相乘操作。完美情况下,我们只需将每个神经元与几个其他神经元连接起来,不用进行其他浮点相乘操作,这叫做「稀疏」网络。
 
稀疏网络更容易压缩,我们可以在推断期间跳过 zero,从而改善延迟情况。
 
如果你可以根据网络中神经元但贡献对其进行排序,那么你可以将排序较低的神经元移除,得到规模更小且速度更快的网络。
 
速度更快/规模更小的网络对于在移动设备上运行它们非常重要。
 
如果你根据神经元权重的 L1/L2 范数进行排序,那么剪枝后模型准确率会下降(如果排序做得好的话,可能下降得稍微少一点),网络通常需要经过训练-剪枝-训练-剪枝的迭代才能恢复。如果我们一次性修剪得太多,则网络可能严重受损,无法恢复。因此,在实践中,剪枝是一个迭代的过程,这通常叫做「迭代式剪枝」(Iterative Pruning):修剪-训练-重复(Prune / Train / Repeat)。
 
想更多地了解迭代式剪枝,可参考 TensorFlow 团队的代码:
https://github.com/tensorflow/model-optimization/blob/master/tensorflow_model_optimization/g3doc/guide/pruning/pruning_with_keras.ipynb
 
权重修剪
将权重矩阵中的多个权重设置为 0,这对应上图中的删除连接。为了使稀疏度达到 k%,我们根据权重大小对权重矩阵 W 中的权重进行排序,然后将排序最末的 k% 设置为 0。
f = h5py.File("model_weights.h5",'r+')
for k in [.25, .50, .60, .70, .80, .90, .95, .97, .99]: 
  ranks = {} 
  for l in list(f[『model_weights』])[:-1]: 
    data = f[『model_weights』][l][l][『kernel:0』] 
    w = np.array(data) 
    ranks[l]=(rankdata(np.abs(w),method= 'dense')—1).astype(int).reshape(w.shape) 
    lower_bound_rank = np.ceil(np.max(ranks[l])*k).astype(int) 
    ranks[l][ranks[l]<=lower_bound_rank] = 0 
    ranks[l][ranks[l]>lower_bound_rank] = 1 
    w = w*ranks[l] 
    data[…] = w
 
单元/神经元修剪
将权重矩阵中的多个整列设置为 0,从而删除对应的输出神经元。
为使稀疏度达到 k%,我们根据 L2 范数对权重矩阵中的列进行排序,并删除排序最末的 k%。
 
f = h5py.File("model_weights.h5",'r+')
for k in [.25, .50, .60, .70, .80, .90, .95, .97, .99]: 
  ranks = {} 
  for l in list(f['model_weights'])[:-1]: 
    data = f['model_weights'][l][l]['kernel:0'] 
    w = np.array(data) 
    norm = LA.norm(w,axis=0) 
    norm = np.tile(norm,(w.shape[0],1)) 
    ranks[l] = (rankdata(norm,method='dense')—1).astype(int).reshape(norm.shape) 
    lower_bound_rank = np.ceil(np.max(ranks[l])*k).astype(int) 
    ranks[l][ranks[l]<=lower_bound_rank] = 0 
    ranks[l][ranks[l]>lower_bound_rank] = 1 
    w = w*ranks[l]
    data[…] = w
 
随着稀疏度的增加、网络删减越来越多,任务性能会逐渐下降。那么你觉得稀疏度 vs. 性能的下降曲线是怎样的呢?
 
我们来看一个例子,使用简单的图像分类神经网络架构在 MNIST 数据集上执行任务,并对该网络进行剪枝操作。
 
下图展示了神经网络的架构:

 

参考代码中使用的模型架构。

 

稀疏度 vs. 准确率。读者可使用代码复现上图(https://drive.google.com/open?id=1GBLFxyFQtTTve_EE5y1Ulo0RwnKk_h6J)。

 

 
总结
很多研究者认为剪枝方法被忽视了,它需要得到更多关注和实践。本文展示了如何在小型数据集上使用非常简单的神经网络架构获取不错的结果。我认为深度学习在实践中用来解决的许多问题与之类似,因此这些问题也可以从剪枝方法中获益。
 
参考资料
本文相关代码:https://drive.google.com/open?id=1GBLFxyFQtTTve_EE5y1Ulo0RwnKk_h6J
To prune, or not to prune: exploring the efficacy of pruning for model compression, Michael H. Zhu, Suyog Gupta, 2017(https://arxiv.org/pdf/1710.01878.pdf)
Learning to Prune Filters in Convolutional Neural Networks, Qiangui Huang et. al, 2018(https://arxiv.org/pdf/1801.07365.pdf)
Pruning deep neural networks to make them fast and small(https://jacobgil.github.io/deeplearning/pruning-deep-learning)
使用 Tensorflow 模型优化工具包优化机器学习模型(https://www.tensorflow.org/model_optimization)
 
声明:本文版权归原作者所有,文章收集于网络,为传播信息而发,如有侵权,请联系小编及时处理,谢谢!
 
 

微信公众号

声明:本站部分作品是由网友自主投稿和发布、编辑整理上传,对此类作品本站仅提供交流平台,转载的目的在于传递更多信息及用于网络分享,并不代表本站赞同其观点和对其真实性负责,不为其版权负责。如果您发现网站上有侵犯您的知识产权的作品,请与我们取得联系,我们会及时修改或删除。

网友评论:

发表评论
请自觉遵守互联网相关的政策法规,严禁发布色情、暴力、反动的言论。
评价:
表情:
用户名: 验证码:点击我更换图片
SEM推广服务

Copyright©2005-2028 Sykv.com 可思数据 版权所有    京ICP备14056871号

关于我们   免责声明   广告合作   版权声明   联系我们   原创投稿   网站地图  

可思数据 数据标注

扫码入群
扫码关注

微信公众号

返回顶部