CN116681945A - 一种基于强化学习的小样本类增量识别方法 - Google Patents
一种基于强化学习的小样本类增量识别方法 Download PDFInfo
- Publication number
- CN116681945A CN116681945A CN202310688597.2A CN202310688597A CN116681945A CN 116681945 A CN116681945 A CN 116681945A CN 202310688597 A CN202310688597 A CN 202310688597A CN 116681945 A CN116681945 A CN 116681945A
- Authority
- CN
- China
- Prior art keywords
- classifier
- class
- learning
- small sample
- increment
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 53
- 230000002787 reinforcement Effects 0.000 title claims abstract description 34
- 238000012549 training Methods 0.000 claims abstract description 45
- 238000005457 optimization Methods 0.000 claims description 21
- 230000008569 process Effects 0.000 claims description 11
- 230000006870 function Effects 0.000 claims description 9
- 238000012360 testing method Methods 0.000 claims description 7
- 230000008901 benefit Effects 0.000 claims description 5
- 239000006185 dispersion Substances 0.000 claims description 5
- 238000005070 sampling Methods 0.000 claims description 5
- 230000008859 change Effects 0.000 claims description 4
- 238000011156 evaluation Methods 0.000 claims description 4
- 230000007774 longterm Effects 0.000 claims description 4
- ORILYTVJVMAKLC-UHFFFAOYSA-N Adamantane Natural products C1C(C2)CC3CC1CC2C3 ORILYTVJVMAKLC-UHFFFAOYSA-N 0.000 claims description 3
- 238000005259 measurement Methods 0.000 claims description 2
- 238000013461 design Methods 0.000 description 6
- 238000013473 artificial intelligence Methods 0.000 description 2
- 230000009471 action Effects 0.000 description 1
- 230000006399 behavior Effects 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000015556 catabolic process Effects 0.000 description 1
- 230000007423 decrease Effects 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000006731 degradation reaction Methods 0.000 description 1
- 230000001419 dependent effect Effects 0.000 description 1
- 238000010586 diagram Methods 0.000 description 1
- 230000002708 enhancing effect Effects 0.000 description 1
- 238000007689 inspection Methods 0.000 description 1
- 238000004519 manufacturing process Methods 0.000 description 1
- 206010027175 memory impairment Diseases 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000002360 preparation method Methods 0.000 description 1
- 238000012545 processing Methods 0.000 description 1
- 238000003860 storage Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
- G06V10/765—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects using rules for classification or partitioning the feature space
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
- G06N3/0455—Auto-encoder networks; Encoder-decoder networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0464—Convolutional networks [CNN, ConvNet]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/092—Reinforcement learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/096—Transfer learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- Health & Medical Sciences (AREA)
- Computing Systems (AREA)
- General Physics & Mathematics (AREA)
- General Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Biomedical Technology (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Medical Informatics (AREA)
- Databases & Information Systems (AREA)
- Multimedia (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Image Analysis (AREA)
Abstract
本发明提供的一种基于强化学习的小样本类增量识别方法,所述识别方法包括:下载小样本类增量学习数据集;设计基于强化学习的小样本增量分类识别网络;利用训练数据对所述网络进行训练,获得训练网络;根据所述训练网络生成小样本类增量学习模型;采用所述小样本类增量学习模型进行小样本增量分类识别。解决上述现有技术未自适应探索新类分类器学习策略造成不能缓和小样本增量学习灾难遗忘的问题。
Description
技术领域
本发明涉及强化学习、小样本分类处理和计算机视觉技术领域,尤其涉及一种基于强化学习的小样本类增量识别方法。
背景技术
在某些场景下,例如工业视觉检测,新概念通常只有相当少的样本,例如新的缺陷类别,这就产生了一种极具有挑战性的场景——小样本学习场景,这时我们希望人工智能系统能够快速整合这些新概念,同时保持住旧知识,这一过程被称为小样本增量学习。小样本类增量学习的标准流程是分阶段地输入不同类别的训练数据,而每一个阶段,一组特定类别的样本被输入到模型中,这些样本往往数量很少,同时其他类别的数据也不可用。增量学习赋予模型增量迭代、动态学习的能力,在智能辅导系统和风控领域具有天然优势和用武之地。然而,当人工智能模型只用新类别的有限样本更新时,往往会出现“灾难性遗忘”的情况,即从以前的数据中获得的知识急剧退化。另外由于训练数据不足,模型也可能会新类上严重过拟合。因此如何缓和灾难遗忘和过拟合成为亟需解决的问题。目前的小样本增量识别方法可分为“基于回放的”、“基于正则化的”和“基于网络体系结构的”。具体有:
1)基于回放的增量学习方法。主要技术手段:通过隐式地保留相关的旧知识来缓解在增量学习各阶段中旧知识的遗忘,在此基础上利用相应的分类器进行增量识别。问题和缺点:这些方法需要额外的计算资源和存储空间用于回忆旧知识,当任务种类不断增多时,要么训练成本会变高,要么代表样本的代表性会减弱,同时在实际生产环境中,这种方法还可能存在数据隐私泄露的问题。
2)基于正则化的增量学习方法。模型输出结果多采用交叉熵/基于feature和中间结果多采用L2范数或者余弦相似度),使新模型与旧模型相似。问题和缺点:这种方法高度依赖于新旧任务之间的相关性,当任务差异太大时会出现任务混淆的现象,并且一个任务的训练时间会随着学习任务的数量线性增长,同时引入的正则项常常不能有效地约束模型在新任务上的优化过程。
基于网络体系结构的增量学习方法。主要技术手段:通过扩张模型结构或者利用拓扑结构进行增量学习,不断地对网络结构进行修改。这样旧任务通过旧的模型权重进行保持,同时新增的模型结构适应了增量任务,从而达到了既适应旧任务,也适应新任务的目的。
现有技术存在的技术问题和缺点:模型结构增加必然造成任务存储负担和运算负担的加剧,如何通过更优的模型增扩方式。
发明内容
鉴于上述问题,提出了本发明以便提供克服上述问题或者至少部分地解决上述问题的一种基于强化学习的小样本类增量识别方法。
根据本发明的一个方面,提供了一种基于强化学习的小样本类增量识别方法,所述识别方法包括:
下载小样本类增量学习数据集;
设计基于强化学习的小样本增量分类识别网络;
利用训练数据对所述网络进行训练,获得训练网络;
根据所述训练网络生成小样本类增量学习模型;
采用所述小样本类增量学习模型进行小样本增量分类识别。
可选的,所述下载小样本类增量学习数据集具体包括:
搜集主流的小样本增量学习数据集CIFAR-100,包括100个类,每个类包含600个图像;
每类各有500个训练图像和100个测试图像,所述小样本增量学习数据集CIFAR-100中的100个类被分成20个超类;
每个图像都带有一个精细标签和一个粗糙标签;
遵守小样本增量N-way K-shot设置,将所述小样本增量学习数据集CIFAR-100数据集划分为60个基类和40个新类;
在基础阶段的基类训练结束后,剩下的40个类进一步划分为8个增量阶段,每个阶段是一个5-way 5-shot的分类任务。
可选的,所述设计基于强化学习的小样本增量分类识别网络具体包括:分类器更新和探索性优化两部分。
可选的,所述分类器更新具体包括:
模型输入为数据集CIFAR-100,编码器φ采用ResNet20作为骨干,设Wt是某个阶段t时刻新类分类器的权重,如果每次增量阶段新类包含M个类别,则Wt=[w1,…,wk,…,wM]T,其中wk表示第k类的分类器权重,d表示特征维度;
对于来自批次数据X中的一张输入样本xi,使用余弦相似度来度量提取出的特征嵌入和对应的分类权重Wt来描述分类性能,并定义为分类器状态/>在获取当前批次数据预测状态St后,计算出当前时刻分类器的交叉熵损失:
其中,是独热编码下标签向量yi的第k个元素,/>表示分类器状态/>的第k个元素,对应新类中第k个类上样本xi的余弦距离;/>越大,表示越像第k类;
得到当前时刻分类器权重参数Wt的梯度并结合探索性优化模块得到的分类器学习策略at一起更新权重参数:
下一个分类器状态通过该次探索更新后的权重Wt+1,用样本xi计算得出;同时获得学习策略at的奖励/>其中R为奖励函数。
根据权利要求1所述的一种基于强化学习的小样本类增量识别方法,其特征在于,所述探索性优化具体包括:
利用批判者ψ对分类器状态St进行评估,以获得长期回报vt,与分类器一步更新的奖励rt不同;
vt从状态St开始评估分类器的性能,并使用at持续更新分类器;
下一个分类器状态St+1通过探索性批判者ψ得到返回值vt,描述了下一个状态St+1在长时间使用策略at后,对分类器性能的影响;
使用强化学习中单步差分的思想指导批判者ψ的学习:
其中,γ是一个比例系数,用来缩放未来的收益;
当前差异和当前分类器策略at的log值同时自适应地监督高斯行动者探索优化:
作为监督信号行动者损失La引导高斯行动者更新到/>
下一个分类器状态St+1通过更新后的行动者学习分类器策略均值μ和方差σ,并通过高斯采样N(μ,σ)获得新的分类器学习策略at+1;
新策略at+1用于推动分类器参数Wt+1的下一次更新,交替推进分类器更新和分类器策略的探索性优化,并实现将新类分类器的更新和新类分类器的探索性优化集成到一个端到端闭环学习过程中。
根据权利要求4所述的一种基于强化学习的小样本类增量识别方法,其特征在于,所述奖励函数R具体包括:
评估新类中类内紧致性的奖励rintra:
其中表示取对应于标签yi的操作;
评估新类数据的类间分散的奖励rinter:
当新的分类器参数Wt+1使输入样本xi越接近其他类时,奖励分数rinter变得越低,取变化最大的一项作为惩罚;
评估新旧类之间分散的奖励rcross:
其中,Yold表示当前阶段的旧类标签。样本xi的特征嵌入越接近旧类的分类器,奖励rcross值就越小,表示错误分类的情况越多;
rintra表示相比于前一个分类器,更新后的分类器正确分类概率与前一个分类器相比的增益,其中正确分类概率与余弦度量成正比;
rinter表示更新前后最大误分类概率的变化;
rcross表示新类样本被错误分类到旧类的概率值之和,rcross越大,表示旧类的知识能被保留的越多;
强化学习的总奖励rt(xi)形式化:rt(xi)=rintra(xi)+rintra(xi)+rcross(xi)。
可选的,所述利用训练数据对所述网络进行训练,获得训练网络具体包括:
在基础训练阶段,使用大量的基类数据对基础模型进行训练;
在增量阶段,冻结模型的骨干参数,只优化分类器参数;
使用Adam优化器对模型进行了2000个回合训练,学习率为0.0003;
在整个训练过程中,进行数据增强。
可选的,所述进行数据增强的方法具体包括:采用随机裁剪、随机缩放和随机水平翻转方法进行数据增强。
可选的,所述识别方法还包括:
每个增量阶段结束后,将训练好的模型在相应的测试集上进行评估,并报告分类准确率。
本发明提供的一种基于强化学习的小样本类增量识别方法,所述识别方法包括:下载小样本类增量学习数据集;设计基于强化学习的小样本增量分类识别网络;利用训练数据对所述网络进行训练,获得训练网络;根据所述训练网络生成小样本类增量学习模型;采用所述小样本类增量学习模型进行小样本增量分类识别。解决上述现有技术未自适应探索新类分类器学习策略造成不能缓和小样本增量学习灾难遗忘的问题。
上述说明仅是本发明技术方案的概述,为了能够更清楚了解本发明的技术手段,而可依照说明书的内容予以实施,并且为了让本发明的上述和其它目的、特征和优点能够更明显易懂,以下特举本发明的具体实施方式。
附图说明
为了更清楚地说明本发明实施例的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其它的附图。
图1为本发明实施例提供的一种基于强化学习的小样本类增量识别方法的流程图;
图2为本发明实施例提供的基于强化方法的小样本增量学习模型示意图。
具体实施方式
下面将参照附图更详细地描述本公开的示例性实施例。虽然附图中显示了本公开的示例性实施例,然而应当理解,可以以各种形式实现本公开而不应被这里阐述的实施例所限制。相反,提供这些实施例是为了能够更透彻地理解本公开,并且能够将本公开的范围完整的传达给本领域的技术人员。
本发明的说明书实施例和权利要求书及附图中的术语“包括”和“具有”以及他们的任何变形,意图在于覆盖不排他的包含,例如,包含了一系列步骤或单元。
下面结合附图和实施例,对本发明的技术方案做进一步的详细描述。
如图1所示,数据准备阶段
1.1搜集主流的小样本增量学习数据集CIFAR-100,它有100个类,每个类包含600个图像。每类各有500个训练图像和100个测试图像。CIFAR-100中的100个类被分成20个超类。每个图像都带有一个“精细”标签(它所属的类)和一个“粗糙”标签(它所属的超类)。遵守小样本增量N-way K-shot设置,将CIFAR-100数据集划分为60个基类和40个新类。在基础阶段的基类训练结束后,剩下的40个类进一步划分为8个增量阶段,每个阶段是一个5-way5-shot的分类任务。
模型设计阶段,端到端的模型设计如下:
2.1整体模型如图2所示,包含分类器更新和探索性优化两部分。在分类器更新部分,模型输入为数据集CIFAR-100,编码器φ采用ResNet20作为骨干,设Wt是某个阶段t时刻新类分类器的权重,如果每次增量阶段新类包含M个类别,则Wt=[w1,…,wk,…,wM]T,其中wk表示第k类的分类器权重,d表示特征维度。对于来自批次数据X中的一张输入样本xi,我们使用余弦相似度来度量提取出的特征嵌入/>和对应的分类权重Wt来描述分类性能,并将其定义为分类器状态/>在获取当前批次数据预测状态St后,计算出当前时刻分类器的交叉熵损失:
其中,是独热编码下标签向量yi的第k个元素,/>表示分类器状态/>的第k个元素,对应新类中第k个类上样本xi的余弦距离。/>越大,表示越像第k类。然后得到当前时刻分类器权重参数Wt的梯度/>并结合探索性优化模块得到的分类器学习策略at一起更新权重参数:
随后,下一个分类器状态通过该次探索更新后的权重Wt+1,用样本xi计算得出。同时获得学习策略at的奖励/>另一方面在探索优化模块,首先利用批判者ψ对分类器状态St进行评估,以获得长期回报vt(与分类器一步更新的奖励rt不同)。然后vt从状态St开始评估分类器的性能,并使用at持续更新分类器。类似地,下一个分类器状态St+1通过探索性批判者ψ得到返回值vt,它描述了下一个状态St+1在长时间使用策略at后,对分类器性能的影响。接下来使用强化学习中单步差分的思想指导批判者ψ的学习:
其中,γ是一个比例系数,用来缩放未来的收益。另一方面,当前差异和当前分类器策略at的log值同时自适应地监督高斯行动者探索优化:
接下来,作为监督信号行动者损失La引导高斯行动者更新到/>相应地,下一个分类器状态St+1通过更新后的行动者/>学习分类器策略均值μ和方差σ,并通过高斯采样N(μ,σ)获得新的分类器学习策略at+1。新策略at+1用于推动分类器参数Wt+1的下一次更新,从而交替推进分类器更新和分类器策略的探索性优化,并实现将新类分类器的更新和新类分类器的探索性优化集成到一个端到端闭环学习过程中。
2.2奖励函数R的设计。
(1)评估新类中类内紧致性的奖励rintra:
其中,表示取对应于标签yi的操作。
(2)评估新类数据的类间分散的奖励rinter:
当新的分类器参数Wt+1使输入样本xi越接近其他类时,奖励分数rinter变得越低,这里取变化最大的一项作为惩罚。
(3)评估新旧类之间分散的奖励rcross:
其中,Yold表示当前阶段的旧类标签。样本xi的特征嵌入越接近旧类的分类器,奖励rcross值就越小,表示错误分类的情况越多。
综上,rintra表示相比于前一个分类器,更新后的分类器正确分类概率与前一个分类器相比的增益,其中正确分类概率与余弦度量成正比。rinter表示更新前后最大误分类概率的变化。rcross表示新类样本被错误分类到旧类的概率值之和,rcross越大,表示旧类的知识能被保留的越多。最终强化学习的总奖励rt(xi)形式化如下:
rt(xi)=rintra(xi)+rintra(xi)+rcross(xi)。
3.模型训练阶段
在基础训练阶段,使用大量的基类数据对基础模型进行训练。在增量阶段,冻结模型的骨干参数,只优化分类器参数。使用Adam优化器对模型进行了2000个回合训练,学习率为0.0003。在整个训练过程中,采用随机裁剪、随机缩放和随机水平翻转等方法进行数据增强。
4.模型测试阶段
每个增量阶段结束后,我们将训练好的模型在相应的测试集上进行评估,并报告分类准确率。
小样本探索性增量学习方法:如图2所示,该方法包括两部分:新分类器更新和探索性优化。
图像数据首先通过编码器提取特征,然后通过分类器获得每个类的度量值,接下来将与新类分类器权重Wt相关的度量值作为分类器状态St,然后,通过分类器学习策略at和计算出的梯度将新类分类器权重由Wt更新为Wt+1。当分类器权重更新后,就通过Wt+1得到新的分类器度量,从而得到下一个状态St+1。其中,分类器更新策略at由策略探索优化部分来完成,首先分类器状态St被送到高斯行动者,获取策略的均值μ和方差σ。在此基础上,可以通过高斯采样选择一个分类器学习策略。另一方面,状态St通过探索性批判者获得当前分类器状态返回值vt,vt从长期角度评估分类器状态。然后利用优化奖励值rt和分类器状态返回值vt计算批判者损失Lc(rt;vt),以监督批判者的更新。同时,使用分类器行为损失La(rt;vt;at)监督高斯行动者的更新。该方法在训练过程中,通过多个共同作用的损失确保了模型能够在增量阶段探索到较优的新类分类器学习策略,从而提升最终的小样本类增量识别性能。
利用新旧类间距离设计奖励函数:在分类器的优化中,需要每次评估更新后分类器的性能。另一方面探索性批判者的更新学习,更需要对分类器进行长期的评估,因此奖励函数R的设计是影响优化方向的关键。综合从三个方面设计:(1)新类中的类内紧致性。给定样本xi和得到的分类器学习策略at,在分类器权重从Wt更新到Wt+1后,样本xi在Wt+1上的性能应该优于Wt,所以样本xi在Wt+1上与对应标签yi的余弦度量应该大于分类器Wt上的余弦度量,从而可以定义一个距离差来表示两个余弦度量之间的差异。2)新类数据的类间分散。通过分析新类的类间距离,通过约束不同类的距离变大,从而样本才不容易分错类。(3)新旧类之间的分散。增量过程中,所有的类别应尽可能分散,以减少特征空间中的误分类。为此,考虑约束新类样本与旧类权重之间的距离。最后通过上述综合考虑类内类间距离设计奖励函数,减少了增量阶段的误分概率,小样本增量学习中的灾难遗忘被有效地缓和。
有益效果:本发明提供了一个基于强化学习的小样本探索性增量学习方法,将新类分类器的更新和新类分类器的探索性优化集成到一个端到端闭环学习过程中,显著提升小样本类增量学习的精度。利用强化学习中试错学习的特点,提出了一种基于强化学习的小样本探索策略来优化新的分类器,并采用高斯推理的策略采样方法自适应调整增量模型参数以适应新知识数据;设计分类器奖励函数来平衡新旧类别的影响,并对分类器进行长远评估,即时调整优化方向,在提高新类的可区别性的同时减轻旧类的性能下降,进一步缓和灾难遗忘问题,从而提升了小样本增量学习的识别性能。
以上的具体实施方式,对本发明的目的、技术方案和有益效果进行了进一步详细说明,所应理解的是,以上仅为本发明的具体实施方式而已,并不用于限定本发明的保护范围,凡在本发明的精神和原则之内,所做的任何修改、等同替换、改进等,均应包含在本发明的保护范围之内。
Claims (9)
1.一种基于强化学习的小样本类增量识别方法,其特征在于,所述识别方法包括:
下载小样本类增量学习数据集;
设计基于强化学习的小样本增量分类识别网络;
利用训练数据对所述网络进行训练,获得训练网络;
根据所述训练网络生成小样本类增量学习模型;
采用所述小样本类增量学习模型进行小样本增量分类识别。
2.根据权利要求1所述的一种基于强化学习的小样本类增量识别方法,其特征在于,所述下载小样本类增量学习数据集具体包括:
搜集主流的小样本增量学习数据集CIFAR-100,包括100个类,每个类包含600个图像;
每类各有500个训练图像和100个测试图像,所述小样本增量学习数据集CIFAR-100中的100个类被分成20个超类;
每个图像都带有一个精细标签和一个粗糙标签;
遵守小样本增量N-way K-shot设置,将所述小样本增量学习数据集CIFAR-100数据集划分为60个基类和40个新类;
在基础阶段的基类训练结束后,剩下的40个类进一步划分为8个增量阶段,每个阶段是一个5-way 5-shot的分类任务。
3.根据权利要求1所述的一种基于强化学习的小样本类增量识别方法,其特征在于,所述设计基于强化学习的小样本增量分类识别网络具体包括:分类器更新和探索性优化两部分。
4.根据权利要求1所述的一种基于强化学习的小样本类增量识别方法,其特征在于,所述分类器更新具体包括:
模型输入为数据集CIFAR-100,编码器φ采用ResNet20作为骨干,设Wt是某个阶段t时刻新类分类器的权重,如果每次增量阶段新类包含M个类别,则Wt=[w1,…,wk,…,wM]T,其中wk表示第k类的分类器权重,d表示特征维度;
对于来自批次数据X中的一张输入样本xi,使用余弦相似度来度量提取出的特征嵌入和对应的分类权重Wt来描述分类性能,并定义为分类器状态/>在获取当前批次数据预测状态St后,计算出当前时刻分类器的交叉熵损失:
其中,是独热编码下标签向量yi的第k个元素,/>表示分类器状态/>的第k个元素,对应新类中第k个类上样本xi的余弦距离;/>越大,表示越像第k类;
得到当前时刻分类器权重参数Wt的梯度并结合探索性优化模块得到的分类器学习策略at一起更新权重参数:
下一个分类器状态通过该次探索更新后的权重Wt+1,用样本xi计算得出;同时获得学习策略at的奖励/>其中R为奖励函数。
5.根据权利要求1所述的一种基于强化学习的小样本类增量识别方法,其特征在于,所述探索性优化具体包括:
利用批判者ψ对分类器状态St进行评估,以获得长期回报vt,与分类器一步更新的奖励rt不同;
vt从状态St开始评估分类器的性能,并使用at持续更新分类器;
下一个分类器状态St+1通过探索性批判者ψ得到返回值vt,描述了下一个状态St+1在长时间使用策略at后,对分类器性能的影响;
使用强化学习中单步差分的思想指导批判者ψ的学习:
其中,γ是一个比例系数,用来缩放未来的收益;
当前差异和当前分类器策略at的log值同时自适应地监督高斯行动者探索优化:
作为监督信号行动者损失La引导高斯行动者更新到/>
下一个分类器状态St+1通过更新后的行动者学习分类器策略均值μ和方差σ,并通过高斯采样N(μ,σ)获得新的分类器学习策略at+1;
新策略at+1用于推动分类器参数Wt+1的下一次更新,交替推进分类器更新和分类器策略的探索性优化,并实现将新类分类器的更新和新类分类器的探索性优化集成到一个端到端闭环学习过程中。
6.根据权利要求4所述的一种基于强化学习的小样本类增量识别方法,其特征在于,所述奖励函数R具体包括:
评估新类中类内紧致性的奖励rintra:
其中表示取对应于标签yi的操作;
评估新类数据的类间分散的奖励rinter:
当新的分类器参数Wt+1使输入样本xi越接近其他类时,奖励分数rinter变得越低,取变化最大的一项作为惩罚;
评估新旧类之间分散的奖励rcross:
其中,Yold表示当前阶段的旧类标签。样本xi的特征嵌入越接近旧类的分类器,奖励rcross值就越小,表示错误分类的情况越多;
rintra表示相比于前一个分类器,更新后的分类器正确分类概率与前一个分类器相比的增益,其中正确分类概率与余弦度量成正比;
rinter表示更新前后最大误分类概率的变化;
rcross表示新类样本被错误分类到旧类的概率值之和,rcross越大,表示旧类的知识能被保留的越多;
强化学习的总奖励rt(xi)形式化:rt(xi)=rintra(xi)+rintra(xi)+rcross(xi)。
7.根据权利要求4所述的一种基于强化学习的小样本类增量识别方法,其特征在于,所述利用训练数据对所述网络进行训练,获得训练网络具体包括:
在基础训练阶段,使用大量的基类数据对基础模型进行训练;
在增量阶段,冻结模型的骨干参数,只优化分类器参数;
使用Adam优化器对模型进行了2000个回合训练,学习率为0.0003;
在整个训练过程中,进行数据增强。
8.根据权利要求7所述的一种基于强化学习的小样本类增量识别方法,其特征在于,所述进行数据增强的方法具体包括:采用随机裁剪、随机缩放和随机水平翻转方法进行数据增强。
9.根据权利要求1所述的一种基于强化学习的小样本类增量识别方法,其特征在于,所述识别方法还包括:
每个增量阶段结束后,将训练好的模型在相应的测试集上进行评估,并报告分类准确率。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310688597.2A CN116681945A (zh) | 2023-06-12 | 2023-06-12 | 一种基于强化学习的小样本类增量识别方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310688597.2A CN116681945A (zh) | 2023-06-12 | 2023-06-12 | 一种基于强化学习的小样本类增量识别方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116681945A true CN116681945A (zh) | 2023-09-01 |
Family
ID=87790551
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310688597.2A Pending CN116681945A (zh) | 2023-06-12 | 2023-06-12 | 一种基于强化学习的小样本类增量识别方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116681945A (zh) |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117011672A (zh) * | 2023-09-27 | 2023-11-07 | 之江实验室 | 基于类特定元提示学习的小样本类增对象识别方法和装置 |
CN117975203A (zh) * | 2024-04-02 | 2024-05-03 | 山东大学 | 基于数据增强的小样本图像类增量学习方法及系统 |
-
2023
- 2023-06-12 CN CN202310688597.2A patent/CN116681945A/zh active Pending
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117011672A (zh) * | 2023-09-27 | 2023-11-07 | 之江实验室 | 基于类特定元提示学习的小样本类增对象识别方法和装置 |
CN117011672B (zh) * | 2023-09-27 | 2024-01-09 | 之江实验室 | 基于类特定元提示学习的小样本类增对象识别方法和装置 |
CN117975203A (zh) * | 2024-04-02 | 2024-05-03 | 山东大学 | 基于数据增强的小样本图像类增量学习方法及系统 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
US11341424B2 (en) | Method, apparatus and system for estimating causality among observed variables | |
CN110427654B (zh) | 一种基于敏感状态的滑坡预测模型构建方法及系统 | |
CN116681945A (zh) | 一种基于强化学习的小样本类增量识别方法 | |
Ayodeji et al. | Causal augmented ConvNet: A temporal memory dilated convolution model for long-sequence time series prediction | |
CN114548591B (zh) | 一种基于混合深度学习模型和Stacking的时序数据预测方法及系统 | |
CN108399434B (zh) | 基于特征提取的高维时间序列数据的分析预测方法 | |
CN111832228A (zh) | 基于cnn-lstm的振动传递系统 | |
CN118070682B (zh) | 基于人工智能的螺旋栓吊环受损评估方法及装置 | |
CN112766603A (zh) | 一种交通流量预测方法、系统、计算机设备及存储介质 | |
CN116258877A (zh) | 土地利用场景相似度变化检测方法、装置、介质及设备 | |
CN112001115A (zh) | 一种半监督动态软测量网络的软测量建模方法 | |
CN115761868A (zh) | 不确定环境下人脸表情分类的鲁棒自适应更新方法 | |
CN112101482B (zh) | 一种对有缺失卫星数据进行参数异常模式检测的方法 | |
CN117635718B (zh) | 一种基于图像增强的弱光环境下矿车定位方法及系统 | |
CN117494573B (zh) | 一种风速预测方法、系统及电子设备 | |
CN107437112A (zh) | 一种基于改进多尺度核函数的混合rvm模型预测方法 | |
CN116701875A (zh) | 一种特高压交流输电线路可听噪声概率预测方法及系统 | |
CN116340384A (zh) | 基于规则演化的核递归最大相关熵时间序列在线预测方法 | |
CN116579408A (zh) | 一种基于模型结构冗余度的模型剪枝方法及系统 | |
CN116306292A (zh) | 一种基于卷积神经网络的水电站物理场级数字孪生模型构建方法 | |
CN115048856A (zh) | 一种基于ms-alstm的滚动轴承剩余寿命预测方法 | |
CN114332594B (zh) | 一种基于ddqn的触觉材料不平衡数据的分类方法 | |
CN118378178B (zh) | 基于残差图卷积神经网络的变压器故障识别方法及系统 | |
CN118552907B (zh) | 一种基于优选下采样尺度的周界入侵人员检测方法及系统 | |
CN118393368B (zh) | 一种电池续航能力评估方法、装置、存储介质及设备 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |