Nothing Special   »   [go: up one dir, main page]

CN112598082B - 基于非校验集预测图像识别模型泛化误差的方法及系统 - Google Patents

基于非校验集预测图像识别模型泛化误差的方法及系统 Download PDF

Info

Publication number
CN112598082B
CN112598082B CN202110017334.XA CN202110017334A CN112598082B CN 112598082 B CN112598082 B CN 112598082B CN 202110017334 A CN202110017334 A CN 202110017334A CN 112598082 B CN112598082 B CN 112598082B
Authority
CN
China
Prior art keywords
training
image recognition
recognition model
output
model
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202110017334.XA
Other languages
English (en)
Other versions
CN112598082A (zh
Inventor
伍冬睿
张潇
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Huazhong University of Science and Technology
Original Assignee
Huazhong University of Science and Technology
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Huazhong University of Science and Technology filed Critical Huazhong University of Science and Technology
Priority to CN202110017334.XA priority Critical patent/CN112598082B/zh
Publication of CN112598082A publication Critical patent/CN112598082A/zh
Application granted granted Critical
Publication of CN112598082B publication Critical patent/CN112598082B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02TCLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00Road transport of goods or passengers
    • Y02T10/10Internal combustion engine [ICE] based vehicles
    • Y02T10/40Engine management systems

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Software Systems (AREA)
  • Mathematical Physics (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Computing Systems (AREA)
  • Molecular Biology (AREA)
  • General Health & Medical Sciences (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Image Analysis (AREA)

Abstract

本发明公开了一种基于非校验集预测图像识别模型泛化误差的方法及系统,属于深度学习优化与泛化领域,方法包括:在每一个训练回合结束后,随机采样K组训练图片,使用模型优化器计算K组训练图片对应图像识别模型的参数更新量;利用参数更新量,得到对应的K个更新后的模型,并记录K个更新后的模型对各张训练图片的输出;计算各张训练图片的输出的方差值,使用输出模长对方差值进行归一化,得到输出相对方差;以输出相对方差预测图像识别模型的泛化误差在训练过程中的变化趋势。如此,本发明不需要使用校验集故能够将所有的训练样本投入训练,从而获得更好的泛化性能;另外该过程只需要训练一轮神经网络,减少了多次训练带来的能量与硬件的损耗。

Description

基于非校验集预测图像识别模型泛化误差的方法及系统
技术领域
本发明属于深度学习优化与泛化领域,更具体地,涉及一种基于非校验集预测图像识别模型泛化误差的方法及系统。
背景技术
机器学习作为目前人工智能的研究热点,常用于挖掘数据之间的潜在关系。近几年,基于数据驱动的机器学习算法在生物、医疗、金融、军事等各个领域都取得了卓越的成绩。随着数据与算力的提升,深度学习作为一种能很好处理图像的机器学习算法,成为了目前的研究热点并被广泛应用于各行各业。
虽然深度学习在图像识别的任务上具有良好的表现,但是其尚存在着诸多问题亟待解决与研究。用于图像识别的神经网络模型在训练过程中存在着复杂的泛化现象,如现有技术中提到的训练过程中的测试误差二次下降现象:随着训练回合数的增加,神经网络在图像测试集上的误差先下降,然后由于过拟合开始上升,最后在某个时候又会再次下降。这些复杂的泛化现象使得在训练过程中预测模型泛化误差的变化趋势尤为重要。目前最常用的预测手段为将图像训练集划分出一部分作为校验集,然后图像识别模型在剩下的训练集上进行训练而在校验集上计算误差从而来预测测试误差的变化趋势,最后通过预测的测试误差变化趋势来进行如早停等其他下游处理。
虽然使用校验集信息来预测图像识别模型训练过程中泛化误差曲线的方法简单实用,但是校验集划去了部分的训练图片,往往使得预测的泛化误差曲线跟实际中使用所有训练样本训练时的泛化误差曲线不太一致,从而影响到后续的早停等处理;除此之外,由于划分校验集而导致的训练图片数量的减少也常常会带来泛化性能的下降。后者可以通过两轮的训练来减轻,即先将训练集划分出一部分作为校验集,然后通过检验校验集上的结果来确定训练的回合数,最后将校验集并入训练集整体从而在所有图片上训练同样的回合数;但是由此增加的训练代价又会使得硬件与能源的损耗成为新的问题,同时这种流程依然没有办法保证训练图片数量不同的情况下泛化误差曲线一致变化。
发明内容
针对现有技术的以上缺陷或改进需求,本发明提供了一种基于非校验集预测图像识别模型泛化误差的方法及系统,由此解决现有图像识别模型训练过程中使用校验集预测泛化性能时存在的多次训练代价大、预测不准的技术问题。
为实现上述目的,一方面,本发明提供了一种基于非校验集预测图像识别模型泛化误差的方法,包括以下步骤:
(1)在每一个训练回合结束后,随机采样K组训练图片,使用模型优化器计算所述K组训练图片对应的图像识别模型的参数更新量;
(2)利用所述参数更新量,得到对应的K个更新后的模型,并记录所述K个更新后的模型对各张训练图片的输出;
(3)计算所述各张训练图片的输出的方差值,使用输出模长对所述方差值进行归一化,得到输出相对方差;以所述输出相对方差预测所述图像识别模型的泛化误差在训练过程中的变化趋势。
进一步地,所述图像识别模型的参数更新量为参数更新梯度。
进一步地,所述模型优化器包括ADAM优化器、SGD优化器。
进一步地,所述输出相对方差RV表示为:
Figure BDA0002887424320000031
其中,n为图片样本数,i=1,2,……,n,j=1,2,……,K,f表示图像识别模型。
另一方面,本发明提供了一种基于非校验集预测图像识别模型泛化误差的系统,包括:
第一计算模块,用于在每一个训练回合结束后,随机采样K组训练图片,使用模型优化器计算所述K组训练图片对应的图像识别模型的参数更新量;
更新模块,用于利用所述参数更新量,得到对应的K个更新后的模型,并记录所述K个更新后的模型对各张训练图片的输出;
第二计算模块,用于计算所述各张训练图片的输出的方差值,使用输出模长对所述方差值进行归一化,得到输出相对方差;以所述输出相对方差预测所述图像识别模型的泛化误差在训练过程中的变化趋势。
进一步地,所述图像识别模型的参数更新量为参数更新梯度。
进一步地,所述模型优化器包括ADAM优化器、SGD优化器。
进一步地,所述输出相对方差RV表示为:
Figure BDA0002887424320000032
其中,n为图片样本数,i=1,2,……,n,j=1,2,……,K,f表示图像识别模型。
总体而言,通过本发明所构思的以上技术方案与现有技术相比,能够取得下列有益效果:
本发明以输出相对方差来预测图像识别模型的泛化误差在训练过程中的变化趋势,可以直接在训练集上进行估计,并且能够较为准确的判定图像识别模型训练过程中泛化误差曲线的变化趋势。同时,该过程由于并不需要使用校验集故能够将所有的训练图片投入训练,从而获得更好的泛化性能;另外该过程只需要训练一轮神经网络,减少了多次训练带来的能量与硬件的损耗。
附图说明
图1是本发明提出的模型输出相对方差的计算流程简图;
图2是不同标签噪声下(即随机扰乱不同比例的标签)神经网络模型VGG16在数据集CIFAR100上训练时的测试误差曲线以及使用训练集计算出来的RV曲线;
图3是CIFAR10数据集上不同宽度的ResNet18其对应的RV曲线以及测试准确率曲线。
具体实施方式
为了使本发明的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本发明进行进一步详细说明。应当理解,此处所描述的具体实施例仅仅用以解释本发明,并不用于限定本发明。此外,下面所描述的本发明各个实施方式中所涉及到的技术特征只要彼此之间未构成冲突就可以相互组合。
参阅图1,本发明提供了一种基于非校验集预测图像识别模型泛化误差的方法,包括以下步骤:
(1)在每一个训练回合结束后,随机采样K组训练图片,使用模型优化器计算所述K组训练图片对应的图像识别模型的参数更新量;
(2)利用所述参数更新量,得到对应的K个更新后的模型,并记录所述K个更新后的模型对各张训练图片的输出;
(3)计算所述各张训练图片的输出的方差值,使用输出模长对所述方差值进行归一化,得到输出相对方差;以所述输出相对方差预测所述图像识别模型的泛化误差在训练过程中的变化趋势。
具体的,以使用包含n个样本的训练数据集
Figure BDA0002887424320000051
来训练模型f为例(如CIFAR10中n=50000),在每个训练回合之后,从训练数据集D中随机采样K组包含B个训练样本的训练批次(例K=100~150,B=128或256),然后使用训练模型时的如ADAM(学习率为1e-3或1e-4)、SGD(学习率为1e-2或1e-3,动量为0.9)等优化器根据各个训练批次求取对应的模型参数更新量,从而得到对应的K个更新后的模型
Figure BDA0002887424320000052
计算这K个模型在训练样本上的模型相对方差值:
Figure BDA0002887424320000053
实验表明,RV的值与模型的泛化性能在训练过程中具有相同的变化趋势,故可以不需要划分验证集而直接通过使用RV的值来预测模型的泛化性能。
上述求取RV指标的过程需要多次计算模型参数更新量,使得计算变得相对繁琐复杂。一种简化的方案是使用直接采样的随机噪声(如神经网络每层使用均值为0且方差为0.001倍该层参数模长的高斯噪声)替代需要计算的模型参数更新量,从而大大降低了计算量。需要指出的是,虽然这种方案具有更加简单的计算方式,但是其在部分数据集上(如CIFAR100)并不有效。该简化方案通常只对类别数较少(一般小于20类)的简单数据集有效。
RV除了可以用来预测训练过程中单一模型的泛化性能曲线,也可以用来预测架构逐渐变化时泛化性能的改变。例如不同宽度的ResNet18在CIFAR10上训练相同的回合数后,使用不带动量的SGD优化器(学习率为1e-3)分别计算其对应的RV,便可以预测相应的测试准确率变化。实验结果表明RV与准确率具有极高的相关性,可一定程度上预测ResNet18随着宽度变化其泛化性能的变化趋势。
图1展示了模型相对方差的计算流程简图。在训练数据集中采样不同的训练批次来计算其对应的模型参数更新量,然后估计分别使用这些参数更新后的各个模型对同一个训练样本点输出的方差,使用输出模长进行归一化后求取该值在训练样本点上的期望,从而得到输出相对方差指标。通过在不同训练阶段估计该指标并记录其在训练过程中的变化趋势,便可以获得泛化误差的变化趋势。
图2展示了不同标签噪声下(即随机扰乱不同比例的标签)神经网络模型VGG16在数据集CIFAR100上训练时的测试误差曲线以及使用训练集计算出来的RV曲线。可以发现这两条曲线在竖直方向上对称,实验结果说明训练过程中RV能够很好的预测模型的泛化性能的变化曲线。
图3展示了CIFAR10数据集上不同宽度的ResNet18其对应的RV以及测试准确率。宽度分别为原始模型宽度的0.25倍-2.0倍,并且使用ADAM优化器(学习率为1e-4)训练100个回合。经过计算RV与测试准确率相关度为-0.94,显著性检验p值为0.0006,该结果表明RV对不同宽度的模型测试准确率也具有较好的预测效果。
另一方面,本发明提供了一种基于非校验集预测图像识别模型泛化误差的系统,包括:
第一计算模块,用于在每一个训练回合结束后,随机采样K组训练图片,使用模型优化器计算所述K组训练图片对应的图像识别模型的参数更新量;
更新模块,用于利用所述参数更新量,得到对应的K个更新后的模型,并记录所述K个更新后的模型对各张训练图片的输出;
第二计算模块,用于计算所述各张训练图片的输出的方差值,使用输出模长对所述方差值进行归一化,得到输出相对方差;以所述输出相对方差预测所述图像识别模型的泛化误差在训练过程中的变化趋势。
上述基于非校验集预测图像识别模型泛化误差的系统中各个模块的划分仅用于举例说明,在其他实施例中,可将基于非校验集预测图像识别模型泛化误差的系统按照需要划分为不同的模块,以完成上述系统的全部或部分功能。
本领域的技术人员容易理解,以上所述仅为本发明的较佳实施例而已,并不用以限制本发明,凡在本发明的精神和原则之内所作的任何修改、等同替换和改进等,均应包含在本发明的保护范围之内。

Claims (6)

1.一种基于非校验集预测图像识别模型泛化误差的方法,其特征在于,包括以下步骤:
(1)在每一个训练回合结束后,随机采样K组训练图片,使用模型优化器计算所述K组训练图片对应的图像识别模型的参数更新量;
(2)利用所述参数更新量,得到对应的K个更新后的模型,并记录所述K个更新后的模型对各张训练图片的输出;
(3)计算所述各张训练图片的输出的方差值,使用输出模长对所述方差值进行归一化,得到输出相对方差;以所述输出相对方差预测所述图像识别模型的泛化误差在训练过程中的变化趋势;
所述输出相对方差RV表示为:
Figure FDA0003648364390000011
其中,n为图片样本数,i=1,2,……,n,j=1,2,……,K,f表示图像识别模型。
2.如权利要求1所述的基于非校验集预测图像识别模型泛化误差的方法,其特征在于,所述图像识别模型的参数更新量为参数更新梯度。
3.如权利要求1所述的基于非校验集预测图像识别模型泛化误差的方法,其特征在于,所述模型优化器包括ADAM优化器、SGD优化器。
4.一种基于非校验集预测图像识别模型泛化误差的系统,其特征在于,包括:
第一计算模块,用于在每一个训练回合结束后,随机采样K组训练图片,使用模型优化器计算所述K组训练图片对应的图像识别模型的参数更新量;
更新模块,用于利用所述参数更新量,得到对应的K个更新后的模型,并记录所述K个更新后的模型对各张训练图片的输出;
第二计算模块,用于计算所述各张训练图片的输出的方差值,使用输出模长对所述方差值进行归一化,得到输出相对方差;以所述输出相对方差预测所述图像识别模型的泛化误差在训练过程中的变化趋势;
所述输出相对方差RV表示为:
Figure FDA0003648364390000021
其中,n为图片样本数,i=1,2,……,n,j=1,2,……,K,f表示图像识别模型。
5.如权利要求4所述的基于非校验集预测图像识别模型泛化误差的系统,其特征在于,所述图像识别模型的参数更新量为参数更新梯度。
6.如权利要求4所述的基于非校验集预测图像识别模型泛化误差的系统,其特征在于,所述模型优化器包括ADAM优化器、SGD优化器。
CN202110017334.XA 2021-01-07 2021-01-07 基于非校验集预测图像识别模型泛化误差的方法及系统 Active CN112598082B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110017334.XA CN112598082B (zh) 2021-01-07 2021-01-07 基于非校验集预测图像识别模型泛化误差的方法及系统

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110017334.XA CN112598082B (zh) 2021-01-07 2021-01-07 基于非校验集预测图像识别模型泛化误差的方法及系统

Publications (2)

Publication Number Publication Date
CN112598082A CN112598082A (zh) 2021-04-02
CN112598082B true CN112598082B (zh) 2022-07-12

Family

ID=75207068

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110017334.XA Active CN112598082B (zh) 2021-01-07 2021-01-07 基于非校验集预测图像识别模型泛化误差的方法及系统

Country Status (1)

Country Link
CN (1) CN112598082B (zh)

Families Citing this family (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113361575B (zh) * 2021-05-28 2023-10-20 北京百度网讯科技有限公司 模型训练方法、装置和电子设备

Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN103954914A (zh) * 2014-05-16 2014-07-30 哈尔滨工业大学 基于概率集成的锂离子电池剩余寿命直接预测方法
CN106169096A (zh) * 2016-06-24 2016-11-30 山西大学 一种机器学习系统学习性能的评估方法
CN106951959A (zh) * 2017-01-24 2017-07-14 上海交通大学 基于学习自动机的深度神经网络优化方法
CN112115973A (zh) * 2020-08-18 2020-12-22 吉林建筑大学 一种基于卷积神经网络图像识别方法

Family Cites Families (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20200327450A1 (en) * 2019-04-15 2020-10-15 Apple Inc. Addressing a loss-metric mismatch with adaptive loss alignment

Patent Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN103954914A (zh) * 2014-05-16 2014-07-30 哈尔滨工业大学 基于概率集成的锂离子电池剩余寿命直接预测方法
CN106169096A (zh) * 2016-06-24 2016-11-30 山西大学 一种机器学习系统学习性能的评估方法
CN106951959A (zh) * 2017-01-24 2017-07-14 上海交通大学 基于学习自动机的深度神经网络优化方法
CN112115973A (zh) * 2020-08-18 2020-12-22 吉林建筑大学 一种基于卷积神经网络图像识别方法

Non-Patent Citations (2)

* Cited by examiner, † Cited by third party
Title
Analysis of Variance of Cross-Validation Estimators of the Generalization Error;Marianthi Markatou等;《Journal of Machine Learning Research》;20051231;第6卷;第1127-1168页 *
基于样本抽样和权重调整的SWA-Adaboost算法;高敬阳 等;《计算机工程》;20140930;第40卷(第9期);第248-251、256页 *

Also Published As

Publication number Publication date
CN112598082A (zh) 2021-04-02

Similar Documents

Publication Publication Date Title
WO2018196760A1 (en) Ensemble transfer learning
Ardia Financial risk management with Bayesian estimation of GARCH models theory and applications
CN113196314B (zh) 适配预测模型
CN109993236B (zh) 基于one-shot Siamese卷积神经网络的少样本满文匹配方法
Zheng et al. Resolving the bias in electronic medical records
CN111144542A (zh) 油井产能预测方法、装置和设备
US20210271980A1 (en) Deterministic decoder variational autoencoder
Zheng et al. Capturing feature-level irregularity in disease progression modeling
CN111047078B (zh) 交通特征预测方法、系统及存储介质
CN112000808B (zh) 一种数据处理方法及装置、可读存储介质
Zhou et al. Disentangled network alignment with matching explainability
CN112084330A (zh) 一种基于课程规划元学习的增量关系抽取方法
CN112884570A (zh) 一种模型安全性的确定方法、装置和设备
CN112041880B (zh) 用于评估信用风险的深度学习方法
CN112598082B (zh) 基于非校验集预测图像识别模型泛化误差的方法及系统
CN116525117A (zh) 一种面向数据分布漂移检测与自适应的临床风险预测系统
Lipkovich et al. Modern approaches for evaluating treatment effect heterogeneity from clinical trials and observational data
CN115659978A (zh) 多注意力机制跨度级小样本命名实体识别方法
CN114495114B (zh) 基于ctc解码器的文本序列识别模型校准方法
CN117371493A (zh) 基于多尺度tcn的数据预测方法、系统、终端及存储介质
Pálsson et al. Semi-supervised variational autoencoder for survival prediction
CN115905545A (zh) 一种基于变分自编码器的无监督读者书评情感分析方法
CN111581469B (zh) 基于多子空间表示的偏多标记学习方法
CN114971059A (zh) 基于时间演化图建模动态交互的行为预测系统及方法
CN115131816A (zh) 一种基于掩膜遮蔽的深度学习模型去偏方法及装置

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant