基于梯度权值追踪的域自适应分类研究
doi: 10.13878/j.cnki.jnuist.20230927001
崔绍君1,2,3 , 季繁繁2,3 , 王婷2,3,4 , 袁晓彤2,3,4
1. 南京信息工程大学 自动化学院,南京, 210044
2. 南京信息工程大学 数字取证教育部工程研究中心,南京, 210044
3. 南京信息工程大学 江苏省大数据分析技术重点实验室,南京, 210044
4. 南京信息工程大学 计算机学院,南京, 210044
基金项目: 科技创新2030—“新一代人工智能”重大项目(2018AAA0100400) ; 国家自然科学基金(U21B2049,61936005)
Domain adaptive classification based on gradient weight pursuit
CUI Shaojun1,2,3 , JI Fanfan2,3 , WANG Ting2,3,4 , YUAN Xiaotong2,3,4
1. School of Automation,Nanjing University of Information Science & Technology,Nanjing 210044 ,China
2. Engineering Research Center of Digital Forensics Ministry of Education, Nanjing University of Information Science & Technology,Nanjing 210044 ,China
3. Jiangsu Key Laboratory of Big Data Analysis Technology,Nanjing University of Information Science & Technology,Nanjing 210044 ,China
4. School of Computer Science,Nanjing University of Information Science & Technology,Nanjing 210044 ,China
摘要
本文提出一种基于梯度权值追踪的剪枝与优化算法(GWP),旨在解决无监督领域中存在的过拟合问题,即在下游任务上的精度远低于在训练集上的精度.针对无监督领域自适应中基于差异与基于对抗的方法,将稠密-稀疏-稠密策略应用于解决过拟合问题.先对网络进行密集预训练,并学出哪些连接是重要的;在剪枝阶段,与原有的稠密-稀疏-稠密策略中的剪枝过程不同,本文的优化算法同时将权值和梯度联合考虑,既考虑到了权值信息(即零阶信息),也考虑到了梯度信息(即一阶信息)对网络剪枝过程的影响;在重密集阶段,恢复被修剪的连接,并以较小的学习率重新训练密集网络.最终,得到的网络在下游任务上取得了理想的效果.实验结果表明,与原有的基于差异和基于对抗的领域自适应方法相比,本文提出的GWP可以有效提升下游任务精度,且具有即插即用的效果.
Abstract
Here,we propose a pruning and optimization approach based on Gradient Weight Pursuit (GWP) to address the overfitting in unsupervised domain,which manifests as significantly lower accuracy on downstream tasks compared to that on training sets.To tackle the overfitting challenge in unsupervised domain,we employ the dense-sparse-dense strategy,focusing on both difference-based and adversarial adaptive methods.First,the network is pretrained intensively to identify crucial connections.Second,during the pruning stage,the optimization algorithm in this paper distinguishes itself from original dense-sparse-dense strategy by jointly considering both weight and gradient information.Specifically,it leverages both weight (i.e.zero-order information) and gradient (i.e.first-order information) to influence pruning process.In the final dense phase,the pruned connections are restored and the dense network is retrained with a reduced learning rate.Finally,the obtained network achieves desirable outcomes in downstream tasks.The experimental results show that the proposed GWP approach can effectively improve the accuracy of downstream tasks,offering a plug-and-play capability compared with original difference-based and adversarial domain adaptation methods.
0 引言
深度神经网络依靠其海量的训练数据和强大的计算能力在计算机视觉领域取得了巨大的成功,但随着其发展日趋成熟,问题也随之而来.目前机器学习中均假设训练数据和测试数据来自同一分布[1],但在实际应用中,这种假设往往不满足[2].在相同的应用领域中,数据会随着时间、空间等变化而发生变化,虽然在许多领域产生并提供了大量未标记的数据,但大多数现有的机器学习模型通常依赖于大量标记的训练数据来实现高性能,而实际应用程序中无法满足.因为标签的数量有限,而手动标注数据既昂贵又耗时,通常需要将知识从已有的领域转移到新的领域.由于域之间的差异,模型性能会下降.为了克服标注负担,缓解知识从一个领域转移到另一个相似但不同的领域时的领域转移问题,无监督领域自适应(Unsupervised Domain Adaptation,UDA)[3]应运而生.在UDA问题中,主要是处理标记的源域数据和未标注的目标域数据,通过减少两个域之间的差异,同时在训练过程中学习两个域之间的域不变表示,最终使得在有标签的源域上训练得到的模型迁移到无标签的目标域上时仍表现出良好的性能.
无监督领域自适应问题的关键在于减小两个领域之间的差异,提升模型的泛化性能.对于每个领域,都有其特有的知识,这些特有的知识对其他领域反而是一种干扰.无监督领域自适应学习的核心在于学习共有的特征,即领域不变特征.学习领域不变特征的两种方式:一是利用能够衡量两个领域特征差异的距离度量方法来对齐两个领域的边缘分布; 二是引入对抗式训练.针对这两种方式,典型的有基于差异和基于对抗的方法.本文的工作主要围绕基于差异与基于对抗的方法展开.基于差异的方法特点是通过最小化两个领域分布差异,来提取领域不变特征,代表性的工作有DAN[4](Deep Adaptation Network,深层自适应网络); 基于对抗的方法的核心思想是通过特征提取器和领域分类器二者之间的对抗模式来提取领域不变信息,代表性的工作有DANN[5]、CDAN[6]、DALN[7]等.基于差异与基于对抗的方法存在过拟合问题.训练数据和测试数据类别上的差异使得在源域上获得较高准确率的模型在目标域上最终的测试准确率较低,达不到训练集中的精度,对图像的分类效果不佳,这时模型被认为是过拟合的,而解决过拟合问题的方法包括增加数据量、运用正则化、简化模型、多种模型组合等.
稠密-稀疏-稠密(Dense-Sparse-Dense,DSD)[8]方法最初是为了解决过拟合问题而提出的.本文将稠密-稀疏-稠密的结构应用于UDA问题中,针对其稀疏的过程,原方法只利用网络中的权值信息,即零阶信息,而忽略了其一阶信息,即网络中的梯度信息.在神经网络的传播过程中,前向传播传递输入的信号,反向传播根据损失值计算其相关参数的梯度值.神经网络利用前向传播和反向传播的交替迭代来更新网络参数,当参数收敛或习得较好的结果时,就停止迭代训练.所以,参数的梯度信息对于神经网络的影响也是不容忽视的.文献[9-10]通过计算网络中相关权值的重要性,将其按重要性结果进行排序,删除重要性较小的权值,得到稀疏后的神经网络.这种剪枝准则仅考虑到了模型本身的权值信息(即零阶信息),却未考虑其梯度信息(即一阶信息)对网络的影响,使得有可能剪去一些比较重要的权值.有些权值本身很小,但是其变化却不可忽略,最终可能会对剪枝结果产生重大影响,这些梯度大的权值是不应该被剪去的.对此,本文引入基于梯度权值的非结构化剪枝算法,将权值和梯度信息联合考虑.
图1所示,现有的基于差异与基于对抗的无监督领域自适应方法存在两个领域之间的样本数目与领域差异问题,使在源域上获得较高准确率的模型在目标域上的准确率较低,不能将下游任务上的样本进行正确的分类.本文提出一种基于梯度权值追踪的剪枝与优化算法,首先对样本进行训练识别出哪些连接是重要的,用它们的绝对值来量化权重的重要性; 其次将梯度和权值联合考虑作为剪枝准则对模型进行剪枝操作;然后训练稀疏的网络,所有层的稀疏度相同; 最后对网络进行再稠密训练,恢复被剪枝的连接,即将之前剪枝的连接重新初始化为零,最终得到的是一个完整的模型.本文模型可以对目标域上的样本进行正确的分类,可以有效解决基于差异与基于对抗的无监督领域自适应方法中存在的过拟合问题.
本文的主要贡献如下:
1)将DSD框架运用于UDA方法中来解决其存在的过拟合问题;
2)设计了一种基于梯度权值追踪的非结构化剪枝算法,与传统的非结构化剪枝算法相比,本文的方法将权值和梯度结合作为剪枝准则,创新性地考虑到梯度对网络剪枝的影响,让剪枝准则更加优化,并将其应用于DSD框架的稀疏阶段;
3)将本文的算法运用到ResNet、AlexNet模型上,并在基准数据集上进行实验.结果表明,本文方法能够提升网络最终的精度,且具有即插即用的效果.
1 相关工作
1.1 无监督领域自适应研究
领域自适应(Domain Adaptation,DA)作为迁移学习[11]的一个特例,是机器学习的重要分支,旨在将在源域上训练得到的模型迁移到目标域上时能够获得良好的性能.根据目标域标签的有无可以将领域自适应问题分为三大类:有监督领域自适应、半监督领域自适应以及无监督领域自适应(UDA)[5].其中,UDA中用于训练的目标域数据是没有标签的,所以UDA的研究更具有挑战性,也越来越受到研究者的关注.
随着深度学习的快速发展,深度UDA方法表现出更大的发展前景.现有的UDA主流方法是学习域不变特征表示,主要可分为基于差异度量的UDA方法和基于对抗的UDA方法.基于差异度量的UDA方法先定义一个衡量两个领域之间特征差异的度量准则,再将提取出的源域特征和目标域特征映射到同一个再生核希尔伯特空间中,然后通过最小化上述准则进而减小源域和目标域之间的差异.代表性工作有基于最大均值差异(Maximum Mean Discrepancy,MMD)的深度域混淆(Deep Domain Confusion,DDC)[12],它通过最小化MMD来显式地对齐跨域的学习特征分布.DAN[6]和JAN[13]在DDC的基础上进行了改进,分别通过最小化多核最大均值差异(MK-MMD)以及联合最大均值差异(JMMD)来进行学习.
1本文动机
Fig.1Schematic of the gradient weight pursuit approach to address overfitting in UDA
本文基于差异(DAN,Deep Adaptation Network,深层自适应网络)和基于对抗(DANN、CDAN、DALN)的方法,将梯度权值追踪算法应用到UDA任务中,最终提升了算法的性能.其中,DANN从整体上将两个领域进行对齐,CDAN在DANN的基础上引入了预测得到的类别信息,DALN将类别分类器重用为判别器,通过统一的目标实现明确的领域对齐和类别区分,利用预测的判别信息进行充分的特征对齐.
1.2 网络剪枝
网络剪枝是一项重要的模型压缩技术,它可以减小模型的大小,同时使精度或性能损失最小化.网络剪枝的主要目的是删除冗余元素以减小网络大小,一般可分为结构化剪枝和非结构化剪枝两大类.其中,结构化剪枝能够对整个滤波器进行裁剪,以适度的性能下降为代价实现网络的加速,而非结构化剪枝具有保持剪枝模型性能的优势.本文方法运用在UDA任务中,主要是为了提升下游任务上的准确率,因此重点研究非结构化剪枝.Han等[9]将权值较小的权重做剪枝,即将低于阈值的所有权值从网络中移除.这种方式缩小了神经网络中模型的尺寸,也达到减小计算量的效果.但其不足之处是只考虑到网络中的权值信息(即零阶信息),根据权值对其进行裁剪,而未考虑到梯度信息(即一阶信息)对剪枝效果的影响.因此,本文的方法是将权值和梯度信息联合考虑,更全面地对网络参数信息进行研究.
剪枝算法在无监督领域自适应方面也有重大进展.Han等[14]提出TransPar来减少特定领域信息在学习过程中带来的副作用,从而增强对领域不变信息的记忆.具体做法是根据分布差异程度,在训练迭代中将参数分为可转移和不可转移参数,对两类参数采取不同的更新规则,最终在深度UDA网络上取得了很好的迁移效果.Liu等[15]提出TransTailor方法,旨在修剪预训练模型以改进迁移学习,与传统剪枝方法不同,他们根据目标感知的权重重要性对预训练模型进行剪枝和微调,生成了针对特定目标的最优子模型,在多个数据集上的实验结果表明,TransTailor方法优于传统的修剪方法.
1.3 DSD训练
本文的方法受到经典的DSD[8]训练流的启发,用于训练具有稀疏性规则的深度神经网络.该方法包含密集预训练(Dense pre-training,D)、稀疏剪枝(Sparse pruning,S)和密集再训练(Dense retraining,D)三个步骤.在预训练阶段,该方法训练一个密集网络,与传统的训练不同,这一步的主要目的不只是学习相关权重的连接,还需要学出哪些连接是重要的,用它们的绝对值来量化权重的重要性; 在剪枝阶段,根据权值的幅值对不重要的连接进行剪枝,然后训练稀疏的网络,所有层的稀疏度相同; 在最后的再密集阶段,恢复被剪枝的连接,将之前剪枝的连接重新初始化为零,整个网络以原始学习率的1/10重新训练,其他超参数保持不变,恢复修剪后的连接,增加了网络模型的容量,与之前剪枝后的稀疏模型相比,能达到较优的局部最小值.DSD训练改变了优化过程,通过修剪和重新训练对网络进行微调,显著提高了优化性能.DSD最初是为了解决训练样本和测试样本来自同一个域的过拟合问题而提出的,本文将机制扩展到UDA任务中,即应用于解决跨领域问题.
2 本文算法
本节主要介绍梯度权值追踪算法(Gradient Weight Pursuit,GWP),并将其应用到无监督领域自适应算法中.
2.1 领域自适应相关符号与定义
首先介绍相关的领域、任务、迁移学习以及领域自适应的相关符号和定义.
在领域自适应中,领域是学习的主体部分,包含两个域:源域(S)和目标域(T).一个领域 D主要可用两个部分概括:特征空间X以及边缘概率分布PX).X表示k个示例的集合,每个示例都与d维特征空间X中的特征向量一一对应,即X=x1x2xkX.而任务T是学习的目标.任务T用两部分表示为:标签空间和类别预测函数f(·).已知样本的特征向量表示,可用类别预测函数f(·)来预测其所属的类别标签fx).对源域DS和目标域DT进行相关的定义,源域样本可记为DS=xS1yS1xS2yS2xSkySk,其中,xSkXS表示源域的数据样本,ySk是对应所属类别的标签.同理,目标域可记为DT=xT1yT1xT2yT2xTkyTk,其中,xTkXT表示目标域的输入样本,yTk为其相应类别预测函数的输出.大多数设定下,源域样本充足,且存在大量有标签的数据,而目标域中数据量较少,且很难获取其真正的标签类别.即满足0≤nTnS.
2.2 基于梯度权值追踪的领域自适应方法
本文将提出的基于梯度权值追踪的剪枝与优化算法应用到无监督领域自适应方法中,用于提升下游任务的准确率.与无监督领域自适应方法中基于差异和基于对抗的方法相比,本文的策略能够解决过拟合问题,提高模型在测试集上的最终准确率.
在领域自适应模型中,每个领域都有自身特有的知识,而这些特有的知识对其他领域可能会是一种干扰.在UDA任务中,寻找一种共有特征,即领域不变特征[16],有利于两个领域之间的迁移.下文主要针对基于差异和基于对抗的方法,将梯度权值追踪的剪枝与优化算法应用到UDA任务中,提升下游任务的性能.
图2中:首先将源域和目标域的数据送入到特征提取器Gf中;然后对特征提取器中的卷积层进行稀疏操作,采用的是本文所提出的基于梯度权值追踪的剪枝算法(具体算法将在2.3节中介绍);接着得到经过参数更新后的特征提取器,提取出相应的特征ZSZT.其后的操作分成两路,具体如下:
1)基于差异(即DAN)的领域自适应方法.将提取出的特征ZSZT二者映射到同一个再生核希尔伯特空间中,计算其相应的MK-MMD(Multi-Kerncl Maximum Mean Discrepancy)损失dk2DslDtl,并将特征ZS送入到分类器Gy中,去训练一个分类器,即计算预测值与实际所属类别之间的分类损失Ly,总的优化表达式为
L=minΘ 1nbi=1nb Lyθxib,yib+λl=lalb dk2Dsl,Dtl.
(1)
式(1)中:nb表示源域和目标域中的标注数据; Ly是交叉熵损失函数; θxib表示将xib分配给标签yib的条件概率; λ>0是惩罚参数; 右边第一项是指输入网络中的参数xib所计算的输出值与实际标签yib二者之间计算得出的交叉熵损失值; 右边第二项dk2DslDtl是指源域和目标域在第l层上计算的MK-MMD损失值; lalb分别表示开始和终止的层索引.
2)基于对抗的领域自适应方法流程.相比于上一路中所介绍的DAN结构,这一路增加了一个领域分类器模块,主要用来判别特征提取器所提取到的特征属于哪一个域.在得到相应的特征ZSZT后,先将特征ZS中带有标签的部分送入到分类器Gy中,计算其交叉熵损失,将特征ZS中不带标签的部分和特征ZT同时送入到领域分类器Dφ中,用以分辨数据来自源域还是目标域,生成领域分类损失Ld.目标函数最终由两部分构成:
L=Ly+Ld.
(2)
其中:Ly表示标签预测的分类损失,即上文中所提到的交叉熵损失; Ld指对领域进行分类的损失.针对DANN方法,其目标函数具体表达式为
Lθf,θy,θd=i=1,,Ndi=0 LyGyGfxi;θf;θy,yi+i=1,,N LdGdRλGfxi;θf;θd,yi.
(3)
2基于梯度权值追踪的领域自适应算法框架
Fig.2Framework of the proposed domain adaptation approach based on gradient weight pursuit
式(3)中:θfθyθd分别表示为特征提取器Gf、标签预测器Gy以及领域分类器Gd的参数; xi表示输入的样本; yi表示数据所属的实际标签; di=0表示源域; Rλ表示梯度反转层函数; i表示样本数量; Ly表示源域数据的分类损失; Ld表示领域分类损失.
针对基于差异和基于对抗的方法,加入本文的优化算法后,可以有效地减少过拟合,在最大程度上提升模型的泛化能力,提升模型最终在测试集上的准确率.
2.3 梯度权值追踪的剪枝与优化算法
本节主要介绍梯度权值追踪的剪枝与优化算法,具体分为三步:密集预训练、剪枝及稀疏训练、再密集阶段.详细过程如图3所示.
1)密集预训练阶段
本阶段对原始网络进行训练,主要目的不只是学习相关权重的连接,还需要学出哪些连接是重要的,用它们的绝对值来量化权重的重要性.将源域数据和目标域数据同时送入特征提取器中去训练,训练n1轮:W=argminW {LxyW},其中,L表示损失函数,训练样本(xy),L层深度神经网络随机初始化权重为W=W1W2WL.
2)剪枝及稀疏训练阶段
本阶段包含分层网络剪枝和子网络稀疏训练.具体剪枝过程如图3所示,第一个矩阵表示网络参数相对应的梯度值,其下方的矩阵表示网络中的权值,将两者做乘积运算所得到的结果取绝对值,剪枝准则的表达式为
Pr=|L(W)×W|.
(4)
其中:Pr表示剪枝准则; W为训练中网络的权值; LW为损失对权值的梯度.
将上述计算出的结果从大到小进行排序,根据所设定的剪枝率0.5,计算出相应的阈值Sk,若结果大于阈值Sk,则二值掩码矩阵中相应位置上的值记为“1”,否则记为“0”.其中,“1”代表与其做相乘操作时原位置的参数保持不变,“0”所代表相应位置的参数会被置为零.然后将网络中的权值矩阵与掩码矩阵做相乘运算,得到新的权值矩阵.
经过剪枝操作后,得到的是一个稀疏的网络模型,上述剪枝操作将其中的部分冗余参数剪枝掉.相比于常规的随机丢弃(dropout)[17]方法,虽然都是在训练过程中有剪枝操作,但本文的剪枝方法是有一定的依据去选择去掉哪些连接,而dropout只是随机去掉,没有依据,可能将网络中的相关重要信息删除,对网络性能造成一定影响,且不可恢复.简单地降低模型容量会导致机器学习系统错过特征和目标输出之间的相关逻辑关系,所以对剪枝后的模型进行再次训练的操作是有必要的,即对修剪后的网络做再训练来提高模型容量.
3)再密集阶段
本阶段的具体操作是将剪枝操作中剪掉的连接重新连接并初始化为零,并使用较低的学习率(原始学习率的1/10)去重新训练所有权重,最终得到的是一个稠密的网络模型.与稀疏模型相比,网络模型的容量得以恢复,最终的下游分类任务的准确率也得到提升.
整个过程将稠密-稀疏-稠密[8]的策略运用到了其中,一方面,稀疏网络可以有效地减少过拟合,并且减小网络的容量,但会使得精度下降,另一方面,为了缓解网络模型的精度因容量的减少而下降,对网络进行再训练,将被剪枝的参数重新激活并继续训练,最终得到的是一个稠密的网络且具备良好的泛化性能.
本文的剪枝准则不仅考虑到了网络模型的零阶信息(即参数本身的信息),还考虑到了一阶信息(即梯度信息),将模型的权值信息和梯度信息联合考虑,在剪枝后有效地保留模型的有用信息,为后续进行再密集操作提供先验信息,也为网络的性能提升作出了贡献.
3梯度权值追踪剪枝优化
Fig.3Gradient weight pursuit pruning optimization
3 实验结果与分析
本节介绍将GWP算法运用到DAN、DANN、CDAN以及DALN[7]方法上取得的结果.主干网络分别采用AlexNet[18]、ResNet[19]网络,在4个标准数据集(Office-31、Office-Home、DomainNet、VisDA-2017)上进行实验,同时还与一些新的无监督领域自适应方法(MCD[20]、BSP[21]、GVB-GD[22]、DADA[23]、TSA[24]、SCDA[25])进行比较,所有的实验均是基于PyTorch框架[26]运行的.
3.1 基准数据集
Office-31[27]是迁移学习中的主流基准数据集,它由3个不同的域组成:亚马逊网站的Amazon(A)、网络摄像头的Webcam(W)及数码单反相机的DSLR(D).该数据集包含4 652张图像,总共31个不平衡类.基于此数据集,进行6个跨领域分类任务:A→W、D→W、W→D、A→D、D→A、W→A,同时进行跨领域分类任务评估(Avg).
Office-Home[28]也是一个域适应数据集,由15 599张图片和65个不平衡的类别组成.它由4个不同的域构成:艺术图片(A)、剪贴画艺术(C)、产品图像(P)和真实世界图像(R).基于此数据集,进行12个跨领域分类任务:A→C、A→P、A→R、C→A、C→P、C→R、P→A、P→C、P→R、R→A、R→C、R→P,同时进行跨领域分类任务评估(Avg).
DomainNet[29]数据集由6个不同的域组成:从剪贴画图像中收集的剪贴画(c)、从写实或现实世界图像中收集到的真实的照片(r)、从特定物体的草图中收集的素描(s)、特定物体的信息图形图像(i)、用绘画和快画的方式对物体做相应描述(p),以及从游戏图纸中收集的快画(q).该数据集包含345个类别,大约60万张图像.在这个数据集上构建12个跨领域分类任务,分别为c→p、c→r、c→s、p→c、p→r、p→s、r→c、r→p、r→s、s→c、s→p、s→r,同时进行跨领域分类任务评估(Avg).
VisDA-2017[30]数据集是一个由虚拟到真实的数据集,它由2个不同的领域所构成:从不同角度和不同光照条件下生成的3D模型的合成2D效果图(S),以及从真实或真实图像数据集中收集的真实图片(R).它一共包含12个类共28万张图像.
在处理Office-31[27]和Office-Home[28]数据集分类任务时,本文使用了ResNet50[19]以及AlexNet[18]模型,在这2种模型上,设置训练轮数均为20,在第10次训练时对模型进行剪枝,最终所得的精度为测试集的Top1精度.本文采用了SGD优化器[31].在ResNet50模型上优化器的超参数设置如下:动量系数为0.9; 权重衰减系数在用DAN作为框架时为5×10-4,用DANN、CDAN以及DALN作为框架时为0.001; 学习率在用DAN作为框架时为0.003,用DANN和CDAN作为框架时为0.01,用DALN作为框架时为0.005; 在DALN作为框架时训练的batchsize设置为36,其余情况下设置为32; 剪枝率设定为0.5.在DomainNet[29]数据集上的实验中,本文使用的模型为ResNet101,在此模型上,设置的训练轮数为30,其余超参数和在ResNet50模型上一致.在VisDA-2017数据集上,相关参数的设置和在ResNet50模型上一致.在表格所示的结果中,把每个数据集的所有迁移任务中最好的结果用粗体表示,便于观察与比较.
3.2 Office-31和Office-Home数据集上实验结果
将稀疏度设定为50%,再利用基于梯度权值追踪的剪枝与优化算法对网络进行一系列操作,相关结果如下.
表1是AlexNet[18]以及ResNet50[19]模型在Office-31数据集上的实验结果.本文的方法采用AlexNet以及ResNet50作为基本架构,在6个域的迁移任务上都取得了较高的分类精度,且均优于原有的方法.例如,对于ResNet50模型,在DANN方法上加入本文算法(GWP)后,它的平均准确率提升将近5个百分点,在CDAN以及DALN方法上加入本文算法平均准确率也分别提升了1.2以及1.1个百分点.表2显示的是在Office-Home数据集上的实验结果.可以观察到在该数据集上的12个迁移任务下,本文的方法也取得了较好的提升.对于ResNet50模型,在DANN方法上加入本文算法后,平均准确率提升将近8个百分点,在DAN以及DALN方法上,加入本文算法的平均准确率也分别提升了5.3和2.3个百分点.上述结果验证了本文算法的有效性,表明本文算法在一定程度上能有效地减少过拟合,提升模型在目标域上的最终分类精度,更好地实现从源域到目标域上的迁移效果.
1Office-31上无监督领域自适应的准确率
Table1Accuracy of unsupervised domain adaptation on Office-31
2Office-Home上无监督领域自适应的准确率
Table2Accuracy of unsupervised domain adaptation on Office-Home
3.3 DomainNet 数据集的实验结果
DomainNet数据集相较于Office-31和Office-Home这两个数据集,其数据量较大,最终测试集上的准确率会低一些.对于DomainNet数据集,使用本文算法在ResNet101模型上做了实验.当所设稀疏度为50%时,表3可得其实验结果.本文的方法在4个域的12种迁移任务上准确率均有了一定的提升.在DAN方法上,平均准确率提升了3.4个百分点,在DANN方法上,平均准确率提升了3.8个百分点,其提升的效果相对明显.实验结果说明本文的方法针对基于差异(DAN)以及基于对抗(DANN、CDAN)的方法在DomainNet数据集上可以有效地提升模型的精度,具有普适性.
3DomainNet上无监督领域自适应的准确率(ResNet101)
Table3Accuracy of unsupervised domain adaptation on DomainNet (ResNet101)
3.4 VisDA-2017数据集的实验结果
在VisDA-2017数据集上,使用本文的算法在ResNet101模型上做了相关实验.实验结果如表4所示,其中,S→R表示从合成域迁移到真实域,ResNet101是指不采用其他方法所得到的基准实验结果.在DALN上加入本文的算法后,最终准确率提升了1.2个百分点,相较于之前的SCDA以及TSA等方法,准确率分别提升了2.1和3.2个百分点.
4VisDA-2017上无监督领域自适应的准确率(ResNet101)
Table4Accuracy of unsupervised domain adaptation on VisDA-2017 (ResNet101)
3.5 测量分布距离
图4表示利用DANN和DANN+GWP的深层特征计算得出的PAD(Proxy A-Distance)[32].PAD表示H-散度的近似值,可用来衡量2个领域之间的差距.对于假设空间H的任意判别器A,假设在其源域和目标域之间存在误差ϵst,其定义为
d^A=21-ϵst.
(5)
可以观察到在Office-31的6种迁移任务上,DANN+GWP都成功地缩小了2个领域之间的分布距离,在A→D,D→A,W→A这3个迁移任务上相对明显,说明加入本文的方法后,源域和目标域的分布在进一步靠拢,分布差异的减小也有利于迁移任务的进行,最终也证明了本文方法的有效性.
4利用DANN和DANN+GWP的特征计算PAD
Fig.4PAD calculated by using characteristics of DANN and DANN+GWP
3.6 可视化结果
图5表示DANN和DANN+GWP在Office-31数据集上的A→W任务下学习到的特征可视化以后的结果.其中,红色和蓝色分别表示的是源域和目标域中的特征.图5a中DANN的特征分散比较混乱,而图5b在加入本文的方法后,特征在每一类样本示例中都具有比较强的鉴别性,且让类内距离[33]缩小,类间距离增大,同一类的样本实现更好地对齐.经过计算得出, DANN的轮廓系数值(sihouette value)为0.521 7(图5a),加入本文的算法后,所得到的轮廓系数值为0.843 9(图5b),由此可以看出,加入本文的方法使得类间特征更分散、类内特征更接近.
3.7 消融实验
为了研究本文方法中不同的参数设置对最终结果的影响,本文对Office-31数据集在ResNet50模型上做了相应的消融实验.
3.7.1 剪枝准则的选择
为了研究不同的剪枝准则对最终精度的影响,本文采用另外两种剪枝准则与本文的方法进行对比:一种是取幅值大的权值,即只利用网络中的权值信息作为剪枝准则; 另一种是取梯度大的参数,即只将梯度信息作为剪枝准则.
5特征可视化结果
Fig.5Visualized feature results
表5是不同的剪枝策略对最终结果的影响.其中,Weight only和Gradient only分别表示只使用权值和只使用梯度作为剪枝准则的结果.由表5可知,在Office-31的6种迁移任务下,用权值和梯度分别作为剪枝准则相比原始的方法,其结果均有提升.对于DANN框架,本文的方法比基于权值的方法提升了1.2个百分点,比基于梯度的方法提升了1.1个百分点,结果表明,将权值和梯度联合考虑,是一种较优的剪枝策略.它对后续的再训练操作以及最终的精度产生积极的影响,有更好的推广性.
5Office-31上使用不同剪枝准则的实验结果
Table5Results of experiment using different pruning criteria on Office-31
3.7.2 剪枝率的选择
为了验证不同稀疏度下网络性能的变化,对不同的稀疏度做消融实验.图6图7分别表示DAN和DANN模型在Office-31数据集上对不同稀疏度的精度变化影响,可以看出当剪枝率在30%~50%之间时,模型最终的分类精度相对其他稀疏度而言是较高的,而当剪枝率为10%或90%时,模型的分类精度相对较低,说明过高或过低的剪枝率对最终目标域上的分类性能都会产生消极影响.
6DAN框架在Office-31上不同稀疏度下精度变化曲线
Fig.6Accuracy variation curves of DAN framework under different sparsity on Office-31
7DANN框架在Office-31上不同稀疏度下精度变化曲线
Fig.7Accuracy variation curves of DANN framework under different sparsity on Office-31
3.7.3 稠密-稀疏-稠密策略的选择
为了研究密集-稀疏-密集策略在整个过程中所起的作用,本文算法还选择对网络进行剪枝且剪枝后不进行再密集操作.表6是DAN和DANN模型在Office-31数据集上的实验结果.其中,w/表示使用稠密-稀疏-稠密策略,w/o表示只使用剪枝算法.从实验结果可以看出,在Office-31数据集上运用稠密-稀疏-稠密策略后,在DAN和DANN方法上的6个迁移任务下以及平均准确率均有提升,DAN上平均准确率提升了0.7个百分点,DANN上平均准确率提升了1.7个百分点.结果表明,稠密-稀疏-稠密策略可以有效缓解过拟合问题,提升算法的迁移效果.
3.7.4 被裁剪参数的更新规则
在剪枝后对网络进行再密集操作,恢复网络模型的原始容量,针对被裁剪掉的参数,其更新规则既可以从零开始更新,也可以不从零开始更新,即随机进行初始化更新.具体实验结果如表7所示.其中:RandInit表示使用随机初始化的方法,即使用“均值为0、标准差为0.01、偏差为1的高斯(正常)噪声进行初始化”,这种正常的随机初始化方法不适用于训练非常深的网络,对于使用ReLU激活函数的网络,会产生梯度消失和爆炸问题; ZeroInit表示从零开始进行初始化,然后继续进行网络的更新.实验结果表明,对被裁剪的参数使用从零开始初始化比随机初始化最终得到的网络性能要好,在Office-31上的6个迁移任务以及最终的平均准确率上都有提升,且使用从零初始化比使用随机初始化提升了1.8个百分点.
6Office-31上密集-稀疏-密集策略的使用对无监督领域自适应的影响(ResNet50)
Table6Impact of dense-sparse-dense strategy on unsupervised domain adaptation on Office-31 (ResNet50)
7Office-31上被裁剪参数更新对无监督领域自适应的影响(ResNet50)
Table7Impact of cropped parameter updates on unsupervised domain adaptation on Office-31 (ResNet50)
4 结束语
本文提出了一种基于梯度权值追踪的非结构化剪枝与优化算法(GWP),该方法将DSD框架应用于UDA任务中,并把其中的剪枝过程做了改进,将权值信息和梯度信息联合考虑.既考虑到了网络模型的零阶信息(即权值信息),也考虑到了一阶信息(即梯度信息),从而能够在剪枝后有效地保留模型的相关重要信息,且为后续进行再密集操作提供了先验信息.本文的方法在几个领域自适应的基准数据集上均取得了较好的结果.
接下来可以考虑无源域领域自适应的应用,因为相关数据的保密性或小型设备的存储空间有限等原因,存在有些源域数据并不总是可以访问的,而无源域领域自适应[34-35]问题只提供经过良好训练的源域模型而不是经过注释的源域数据来实现对目标数据的自适应,其应用范围更广泛,也有望进一步提升网络在下游任务中的泛化性能.
1本文动机
Fig.1Schematic of the gradient weight pursuit approach to address overfitting in UDA
2基于梯度权值追踪的领域自适应算法框架
Fig.2Framework of the proposed domain adaptation approach based on gradient weight pursuit
3梯度权值追踪剪枝优化
Fig.3Gradient weight pursuit pruning optimization
4利用DANN和DANN+GWP的特征计算PAD
Fig.4PAD calculated by using characteristics of DANN and DANN+GWP
5特征可视化结果
Fig.5Visualized feature results
6DAN框架在Office-31上不同稀疏度下精度变化曲线
Fig.6Accuracy variation curves of DAN framework under different sparsity on Office-31
7DANN框架在Office-31上不同稀疏度下精度变化曲线
Fig.7Accuracy variation curves of DANN framework under different sparsity on Office-31
1Office-31上无监督领域自适应的准确率
Table1Accuracy of unsupervised domain adaptation on Office-31
2Office-Home上无监督领域自适应的准确率
Table2Accuracy of unsupervised domain adaptation on Office-Home
3DomainNet上无监督领域自适应的准确率(ResNet101)
Table3Accuracy of unsupervised domain adaptation on DomainNet (ResNet101)
4VisDA-2017上无监督领域自适应的准确率(ResNet101)
Table4Accuracy of unsupervised domain adaptation on VisDA-2017 (ResNet101)
5Office-31上使用不同剪枝准则的实验结果
Table5Results of experiment using different pruning criteria on Office-31
6Office-31上密集-稀疏-密集策略的使用对无监督领域自适应的影响(ResNet50)
Table6Impact of dense-sparse-dense strategy on unsupervised domain adaptation on Office-31 (ResNet50)
7Office-31上被裁剪参数更新对无监督领域自适应的影响(ResNet50)
Table7Impact of cropped parameter updates on unsupervised domain adaptation on Office-31 (ResNet50)
Huang K Z, Zheng D N, Sun J,et al. Sparse learning for support vector classification[J]. Pattern Recognition Letters,2010,31(13):1944-1951
Cortes C, Mohri M. Domain adaptation in regression[M]//Lecture Notes in Computer Science. Berlin, Heidelberg: Springer Berlin Heidelberg,2011:308-323
Wang M, Deng W H. Deep visual domain adaptation:a survey[J]. Neurocomputing,2018,312(C):135-153
Long M S, Cao Y, Wang J M,et al. Learning transferable features with deep adaptation networks[C]//Proceedings of the 32nd International Conference on Machine Learnin, Volume 37. July 6-11,2015, Lille, France. ACM,2015:97-105
Ganin Y, Lempitsky V. Unsupervised domain adaptation by backpropagation[C]//Proceedings of the 32nd International Conference on Machine Learnin, Volume 37. July 6-11,2015, Lille, France. ACM,2015:1180-1189
Long M S, Cao Z J, Wang J M,et al. Conditional adversarial domain adaptation[C]//Proceedings of the 32nd International Conference on Neural Information Processing Systems. December 3-8,2018, Montreal, Canada. ACM,2018:1647-1657
Chen L, Chen H A, Wei Z X,et al. Reusing the task-specific classifier as a discriminator:discriminator-free adversarial domain adaptation[C]//2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition(CVPR). June 18-24,2022, New Orleans, LA, USA. IEEE,2022. DOI:10.1109/cvpr52688.2022.00704
Han S, Pool J, Narang S R,et al. DSD:dense-sparse-dense training for deep neural networks[J].arXiv e-Print,2016,arXiv:1607.04381
Han S, Pool J, Tran J,et al. Learning both weights and connections for efficient neural networks[C]//Proceedings of the 28th International Conference on Neural Information Processing Systems, Volume 1. December 7-12,2015, Montreal, Canada. ACM,2015:1135-1143
LeCun Y, Denke J S, Solla S A. Optimal brain damage[M]//Touretzky D S. Advances in Neural Information Processing Systems 2. San Francisco, CA, USA: Morgan Kaufmann Publishers Inc.,1990:598-605
庄福振, 罗平, 何清, 等. 迁移学习研究进展[J]. 软件学报,2015,26(1):26-39. ZHUANG Fuzhen, LUO Ping, HE Qing,et al. Survey on transfer learning research[J]. Journal of Software,2015,26(1):26-39
Tzeng E, Hoffman J, Zhang N,et al. Deep domain confusion:maximizing for domain invariance[J].arXiv e-Print,2014,arXiv:1412.3474
Long M S, Zhu H, Wang J M,et al. Deep transfer learning with joint adaptation networks[J].arXiv e-Print,2016,arXiv:1605.06636
Han Z Y, Sun H L, Yin Y L. Learning transferable parameters for unsupervised domain adaptation[J]. IEEE Transactions on Image Processing,2022,31:6424-6439
Liu B Y, Cai Y F, Guo Y,et al. TransTailor:pruning the pre-trained model for improved transfer learning[J]. Proceedings of the AAAI Conference on Artificial Intelligence,2021,35(10):8627-8634
Fan B, Yang Y Z, Feng W S,et al. Seeing through darkness:visual localization at night via weakly supervised learning of domain invariant features[J]. IEEE Transactions on Multimedia,2023,25:1713-1726
Hinton G, Srivastava N, Krizhevsky A,et al. Improving neural networks by preventing co-adaptation of feature detectors[J].arXiv e-Print,2012,arXiv:1207.0580
Krizhevsky A, Sutskever I, Hinton G E. ImageNet classification with deep convolutional neural networks[J]. Communications of the ACM,2017,60(6):84-90
He K M, Zhang X Y, Ren S Q,et al. Deep residual learning for image recognition[C]//2016 IEEE Conference on Computer Vision and Pattern Recognition. June 27-30,2016, Las Vegas, NV, USA. IEEE,2016:770-778
Saito K, Watanabe K, Ushiku Y,et al. Maximum classifier discrepancy for unsupervised domain adaptation[C]//2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition. June 18-23,2018, Salt Lake City, UT, USA. IEEE,2018:3723-3732
Chen X Y, Wang S N, Long M S,et al. Transferability vs.discriminability:batch spectral penalization for adversarial domain adaptation[C]//International Conference on Machine Learning. June 9-15,2019, Long Beach, CA, USA. IMLS,2019:1081-1090
Cui S H, Wang S H, Zhuo J B,et al. Gradually vanishing bridge for adversarial domain adaptation[C]//2020 IEEE/CVF Conference on Computer Vision and Pattern Recognition(CVPR). June 13-19,2020, Seattle, WA, USA. IEEE,2020:12455-12464
Tang H, Jia K. Discriminative adversarial domain adaptation[J]. Proceedings of the AAAI Conference on Artificial Intelligence,2020,34(4):5940-5947
Li S, Xie M X, Gong K X,et al. Transferable semantic augmentation for domain adaptation[C]//2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition(CVPR). June 20-25,2021, Nashville, TN, USA. IEEE,2021:11516-11525
Li S, Xie M, Lv F,et al. Semantic concentration for domain adaptation[C]//2021 IEEE/CVF International Conference on Computer Vision. October 10-17,2021, Montreal, QC, Canada. IEEE,2021:9102-9111
Paszke A, Gross S, Massa F,et al. PyTorch:an imperative style,high-performance deep learning library[C]//Proceedings of the 33rd International Conference on Neural Information Processing Systems,2019:8026-8037
Saenko K, Kulis B, Fritz M,et al. Adapting visual category models to new domains[C]//Proceedings of the 11th European Conference on Computer Vision: Part IV. September 5-11,2010, Heraklion, Crete, Greece. ACM,2010:213-226
Venkateswara H, Eusebio J, Chakraborty S,et al. Deep hashing network for unsupervised domain adaptation[C]//2017 IEEE Conference on Computer Vision andPattern Recognition(CVPR). July 21-26,2017, Honolulu, HI. IEEE,2017:5018-5027
Peng X C, Bai Q X, Xia X D,et al. Moment matching for multi-source domain adaptation[C]//2019 IEEE/CVF International Conference on Computer Vision(ICCV). October 27-November 2,2019, Seoul, Korea(South). IEEE,2019. DOI:10.1109/iccv.2019.00149
Peng X, Usman B, Kaushik N,et al. Visda:the visual domain adaptation challenge[J].arXiv e-Print,2017,arXiv:1710.06924
Loshchilov I, Hutter F. SGDR:stochastic gradient descent with warm restarts[J].arXiv e-Print,2016,arXiv:1608.03983
Ben-David S, Blitzer J, Crammer K,et al. A theory of learning from different domains[J]. Machine Learning,2010,79(1):151-175
Yan L, Fan B, Liu HM,et al. Triplet adversarial domain adaptation for pixel-level classification of VHR remote sensing images[J]. IEEE Transactions on Geoscience and Remote Sensing,2019, PP(99):1-16
Karim N, Mithun N C, Rajvanshi A,et al. C-SFDA:a curriculum learning aided self-training framework for efficient source free domain adaptation[C]//2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition(CVPR). June 17-24,2023, Vancouver, BC, Canada. IEEE,2023:24120-24131
Litrico M, Del Bue A, Morerio P. Guiding pseudo-labels with uncertainty estimation for source-free unsupervised domain adaptation[C]//2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition(CVPR). June 17-24,2023, Vancouver, BC, Canada. IEEE,2023:7640-7650

地址:江苏省南京市宁六路219号    邮编:210044

联系电话:025-58731025    E-mail:nxdxb@nuist.edu.cn

南京信息工程大学学报 ® 2025 版权所有  技术支持:北京勤云科技发展有限公司