Nothing Special   »   [go: up one dir, main page]

CN114092918A - 模型训练方法、装置、设备及存储介质 - Google Patents

模型训练方法、装置、设备及存储介质 Download PDF

Info

Publication number
CN114092918A
CN114092918A CN202210024447.7A CN202210024447A CN114092918A CN 114092918 A CN114092918 A CN 114092918A CN 202210024447 A CN202210024447 A CN 202210024447A CN 114092918 A CN114092918 A CN 114092918A
Authority
CN
China
Prior art keywords
model
target
student
training
data
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202210024447.7A
Other languages
English (en)
Inventor
袁振国
刘国清
杨广
王启程
朱爱晨
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Shenzhen Minieye Innovation Technology Co Ltd
Original Assignee
Shenzhen Minieye Innovation Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Shenzhen Minieye Innovation Technology Co Ltd filed Critical Shenzhen Minieye Innovation Technology Co Ltd
Priority to CN202210024447.7A priority Critical patent/CN114092918A/zh
Publication of CN114092918A publication Critical patent/CN114092918A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/082Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Software Systems (AREA)
  • Mathematical Physics (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Computing Systems (AREA)
  • Molecular Biology (AREA)
  • General Health & Medical Sciences (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

本申请公开了一种模型训练方法、装置、设备及存储介质,通过获取训练数据集,并利用已标注数据,预设的老师模型进行训练,直至老师模型达到预设的第一收敛条件,得到目标老师模型,以使老师模型学习到更多更深层的模型特征;再对目标老师模型和学生模型进行BN层权重共享,以能够利用目标老师模型指导学生模型进行训练,从而使学生模型能够具备目标老师模型的BN层权重进行训练;最后利用已标注数据和未标注数据,对学生模型和目标老师模型进行联合训练,直至学生模型达到预设的第二收敛条件,得到目标学生模型,以在保持较低模型复杂度时,提高学生模型的表达能力,从而能够有效压缩模型,降低计算资源消耗以及降低人工标注的人力成本。

Description

模型训练方法、装置、设备及存储介质
技术领域
本申请涉及人工智能技术领域,尤其涉及一种模型训练方法、装置、设备及存储介质。
背景技术
随着人工智能的快速发展,卷积神经网络广泛应用于车辆驾驶领域,如车辆检测和车道线检测等。其中,训练卷积神经网络需要大量高质量标注数据,以得到高复杂度模型,从而提高模型准确率。但是大量标注数据需要高额存储空间,训练过程也需要耗费巨额计算资源。
目前,由于成本限制,初级智能辅助驾驶系统往往采用计算力相对较低的计算平台,采用高复杂度模型会带来高延时问题。因此,如何将高复杂度模型压缩至端侧计算平台的可接受程度是亟需解决的问题。
发明内容
本申请提供了一种模型训练方法、装置、设备及存储介质,以解决卷积神经网络存在计算资源消耗大的技术问题。
为了解决上述技术问题,第一方面,本申请实施例提供了一种模型训练方法,包括:
获取训练数据集,训练数据集包括已标注数据和未标注数据;
利用已标注数据,对预设的老师模型进行训练,直至老师模型达到预设的第一收敛条件,得到目标老师模型;
对目标老师模型和学生模型进行BN层权重共享,目标老师模型的模型复杂度大于学生模型的模型复杂度;
利用已标注数据和未标注数据,对学生模型和目标老师模型进行联合训练,直至学生模型达到预设的第二收敛条件,得到目标学生模型,目标学生模型能够用于部署到端侧计算平台。
本实施例通过获取训练数据集,并利用已标注数据,预设的老师模型进行训练,直至老师模型达到预设的第一收敛条件,得到目标老师模型,以使老师模型学习到更多更深层的模型特征;再对目标老师模型和学生模型进行BN层权重共享,以能够利用目标老师模型指导学生模型进行训练,从而使学生模型能够具备目标老师模型的BN层权重进行训练;最后利用已标注数据和未标注数据,对学生模型和目标老师模型进行联合训练,直至学生模型达到预设的第二收敛条件,得到目标学生模型,以能够在保持较低模型复杂度的情况下,提高学生模型的表达能力,从而能够有效压缩模型,进而降低计算资源消耗以及降低人工标注的人力成本。
在一实施例中,目标老师模型和学生模型均有多个BN层,对目标老师模型和学生模型进行BN层权重共享,包括:
将目标老师模型的多级BN层权重共享至学生模型。
本实施例通过多级BN层权重共享,以使学生模型在训练阶段能够高效汲取目标老师模型的特征表达能力,从而有效解决学生模型卷积网络层少而导致表达能力差的问题。
在一实施例中,学生模型和目标老师模型在联合训练时,固定目标老师模型的多级BN层权重和学生模型的多级BN层权重。
本实施例通过固定BN层权重,以避免BN层权重更新对学生模型带来不利影响。
在一实施例中,利用已标注数据和未标注数据,对学生模型和目标老师模型进行联合训练,直至学生模型达到预设的第二收敛条件,得到目标学生模型,包括:
将训练数据集作为学生模型和目标老师模型的输入数据,输出学生模型的第一预测结果和目标老师模型的第二预测结果;
根据第一预测结果和第二预测结果,计算目标损失函数的总损失值;
根据总损失值,更新学生模型,直至学生模型收敛,得到目标学生模型。
本实施例通过已标注数据和未标注数据,对目标老师模型和学生模型进行联合训练,能够降低人工标注成本,以及统一学生模型和目标老师模型的输入数据分布,提高学生模型的表达能力。
在一实施例中,根据第一预测结果和第二预测结果,计算目标损失函数的总损失值,包括:
确定输入数据的数据类型,数据类型为已标注数据或未标注数据;
根据数据类型,计算目标损失函数的总损失值。
在一实施例中,目标损失函数为:
Figure 774763DEST_PATH_IMAGE001
其中,
Figure 398424DEST_PATH_IMAGE002
为所述学生模型的预测损失函数,
Figure DEST_PATH_IMAGE003
为所述目标老师模型的预测损失函数,
Figure 766258DEST_PATH_IMAGE004
为所述第一预测结果,
Figure DEST_PATH_IMAGE005
为所述第二预测结果,
Figure 853032DEST_PATH_IMAGE006
为所述第一预测结果与所述第二预测结果之间的均方误差,
Figure DEST_PATH_IMAGE007
为所述数据类型,若所述数据类型为已标注数据,则
Figure 356825DEST_PATH_IMAGE008
,若所述数据类型为未标注数据,则
Figure DEST_PATH_IMAGE009
在一实施例中,根据总损失值,更新学生模型,直至学生模型收敛,得到目标学生模型,包括:
若总损失值不小于预设阈值,则更新学生模型的第一特征层权重和目标老师模型的第二特征层权重,得到新的学生模型和新的目标老师模型;
利用新的学生模型和新的目标老师模型,对训练数据集进行预测,直至总损失值小于预设阈值,得到目标学生模型。
第二方面,本申请实施例提供一种模型训练装置,包括:
获取模块,用于获取训练数据集,训练数据集包括已标注数据和未标注数据;
第一训练模块,用于利用已标注数据,预设的老师模型进行训练,直至老师模型达到预设的第一收敛条件,得到目标老师模型;
共享模块,用于对目标老师模型和学生模型进行BN层权重共享,目标老师模型的模型复杂度大于学生模型的模型复杂度;
第二训练模块,用于利用已标注数据和未标注数据,对学生模型和目标老师模型进行联合训练,直至学生模型达到预设的第二收敛条件,得到目标学生模型,目标学生模型能够用于部署到端侧计算平台。
第三方面,本申请实施例提供一种计算机设备,包括处理器和存储器,存储器用于存储计算机程序,计算机程序被处理器执行时实现如第一方面的模型训练方法。
第四方面,本申请实施例提供一种计算机可读存储介质,其存储有计算机程序,计算机程序被处理器执行时实现如第一方面的模型训练方法。
需要说明的是,上述第二方面至第四方面的有益效果请参见第一方面的相关描述,在此不再赘述。
附图说明
图1为本申请实施例提供的模型训练方法的流程示意图;
图2为本申请实施例提供的多级权重共享的示意图;
图3为本申请实施例提供的模型训练装置的结构示意图;
图4为本申请实施例提供的计算机设备的结构示意图。
具体实施方式
下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。
如相关技术记载,由于成本限制,初级智能辅助驾驶系统往往采用计算力相对较低的计算平台,采用高复杂度模型会带来高延时问题。
为此,本申请实施例提供一种模型训练方法、装置、设备及存储介质,通过获取训练数据集,并利用所述已标注数据,对预设的老师模型进行训练,直至所述老师模型达到预设的第一收敛条件,得到目标老师模型,以使老师模型学习到更多更深层的模型特征;再对所述目标老师模型和学生模型进行BN层权重共享,以能够利用目标老师模型指导学生模型进行训练,从而使学生模型能够具备目标老师模型的BN层权重进行训练;最后利用所述已标注数据和所述未标注数据,对所述学生模型和所述目标老师模型进行联合训练,直至所述学生模型达到预设的第二收敛条件,得到目标学生模型,以能够在保持较低模型复杂度的情况下,提高学生模型的表达能力,从而能够有效压缩模型,进而降低计算资源消耗以及降低人工标注的人力成本。
请参照图1,图1为本申请实施例提供的一种模型训练方法的流程示意图。本申请实施例的模型训练方法能够应用于计算机设备,该计算机设备包括但不限于智能手机、平板电脑、笔记本电脑、桌上型计算机、物理服务器和云端服务器等设备。如图1所示,本实施例的模型训练方法包括步骤S101至步骤S104,详述如下:
步骤S101,获取训练数据集,所述训练数据集包括已标注数据和未标注数据。
在本步骤中,构建训练数据集和联合训练模型,联合训练模型包括老师模型和学生模型。可选地,针对待解决的分类问题,采集实际场景数据,对于可持续获得原始数据的场景问题,考虑到人工标注成本,按照一定规则抽取部分数据进行人工标注,比如等间隔抽取。
可选地,根据实际部署的硬件平台和场景需求,优先建立满足存储要求和计算延迟要求的学生模型,再增加学生模型的复杂度以作为老师模型。
步骤S102,利用所述已标注数据,对预设的老师模型进行训练,直至所述老师模型达到预设的第一收敛条件,得到目标老师模型。
在本步骤中,利用已标注数据将老师模型训练至收敛,使得老师模型具有满足甚至超过解决实际问题的精度.可选地,第一收敛条件可以是老师模型的损失函数小于预设值,或老师模型的迭代次数达到预设次数。
步骤S103,对所述目标老师模型和学生模型进行BN层权重共享,所述目标老师模型的模型复杂度大于所述学生模型的模型复杂度。
在本步骤中,目标老师模型和学生模型均有特征层和批量归一化(BatchNormalization,BN)层。本实施例的BN层权重共享为将目标老师模型的BN层权重赋值给学生模型,以使学生模型的BN层权重与目标老师模型相同,从而使学生模型能够具备目标老师模型的特征表达能力。
步骤S104,利用所述已标注数据和所述未标注数据,对所述学生模型和所述目标老师模型进行联合训练,直至所述学生模型达到预设的第二收敛条件,得到目标学生模型,所述目标学生模型能够用于部署到端侧计算平台。
在本步骤中,通过已标注数据和未标注数据共同对学生模型和目标老师模型进行训练,并计算学生模型与目标老师模型之间的总损失值,当该总损失值小于预设阈值时,学生模型达到第二收敛条件。本实施例利用BN层权重共享方式以及联合训练方法,统一学生模型和老师模型的输入数据分布,提高学生模型的表达能力,从而使得端侧部署模型在不增加复杂度的情况下有效提升精度,同时降低计算资源消耗。
在一实施例中,在图1所示实施例的基础上,上述步骤S103,包括:
将所述目标老师模型的多级BN层权重共享至所述学生模型。
在本步骤中,如图2所示的多级BN层权重共享,以常用的四层特征提取神经网络为例,特征提取层用于提取输入数据的特征,包括卷积层、池化层、和活函数层等,以提高模型的非线性特征表达能力,池化层能够降低特征维度,丰富卷积计算后输出的特征信息。以图2左侧学生模型为基础,BN层权重共享关系如下:<BN层1,BN层5>,<BN层2,BN层6>,<BN层3,BN层7>,<BN层4,BN层8>。可以理解的时,老师模型的特征提取层5,特征提取层6,特征提取层7,特征提取层8,相对学生模型的特征提取层1,特征提取层2,特征提取层3,特征提取层4,具有更多的卷积计算和激活函数,从而能够更好的解决实际应用需求。
本实施例通过多级BN层权重共享,以使学生模型在训练阶段能够高效汲取目标老师模型的特征表达能力,从而有效解决学生模型卷积网络层少而导致表达能力差的问题。
可选地,所述学生模型和所述目标老师模型在联合训练时,固定所述目标老师模型的多级BN层权重和所述学生模型的多级BN层权重。
在本可选实施例中,学生模型和老师模型共享BN层权重,共享关系如图2所示:<BN层1,BN层5>,<BN层2,BN层6>,<BN层3,BN层7>,<BN层4,BN层8>,在联合训练阶段,BN层1,BN层2,BN层3,BN层4,BN层5,BN层6,BN层7,BN层8的权重固定,不再随联合训练的损失函数进行更新。本实施例通过固定BN层权重,以避免BN层权重更新对学生模型带来不利影响。
在一实施例中,在图1所示实施例的基础上,上述步骤S104,包括:
将所述训练数据集作为所述学生模型和所述目标老师模型的输入数据,输出所述学生模型的第一预测结果和所述目标老师模型的第二预测结果;
根据所述第一预测结果和所述第二预测结果,计算目标损失函数的总损失值;
根据所述总损失值,更新所述学生模型,直至所述学生模型收敛,得到所述目标学生模型。
在本实施例中,利用已标注数据和未标注数据作为输入数据,联合训练学生模型。可选地,所述根据所述第一预测结果和所述第二预测结果,计算目标损失函数的总损失值,包括:确定所述输入数据的数据类型,所述数据类型为已标注数据或未标注数据;根据所述数据类型,计算所述目标损失函数的总损失值。
对于每个单次训练阶段,分别计算得到学生模型和目标老师模型的分类预测结果Ps和Pt。若输入数据为已标注数据,则分别计算学生模型和目标老师模型的预测损失Ls和Lt.若输入数据为未标注数据,则不计算学生模型和老师模型的预测损失。
本实施例通过已标注数据和未标注数据,对目标老师模型和学生模型进行联合训练,能够降低人工标注成本,以及统一学生模型和目标老师模型的输入数据分布,提高学生模型的表达能力。
可选地,所述目标损失函数为:
Figure 287522DEST_PATH_IMAGE010
其中,
Figure 648096DEST_PATH_IMAGE002
为所述学生模型的预测损失函数,
Figure 222166DEST_PATH_IMAGE003
为所述目标老师模型的预测损失函数,
Figure 778918DEST_PATH_IMAGE004
为所述第一预测结果,
Figure 43415DEST_PATH_IMAGE005
为所述第二预测结果,
Figure 107315DEST_PATH_IMAGE006
为所述第一预测结果与所述第二预测结果之间的均方误差,
Figure 965419DEST_PATH_IMAGE007
为所述数据类型,若所述数据类型为已标注数据,则
Figure 76594DEST_PATH_IMAGE008
,若所述数据类型为未标注数据,则
Figure 87276DEST_PATH_IMAGE011
可选地,所述根据所述总损失值,更新所述学生模型,直至所述学生模型收敛,得到所述目标学生模型,包括:若所述总损失值不小于预设阈值,则更新所述学生模型的第一特征层权重和所述目标老师模型的第二特征层权重,得到新的所述学生模型和新的所述目标老师模型;利用新的所述学生模型和新的所述目标老师模型,对所述训练数据集进行预测,直至所述总损失值小于所述预设阈值,得到所述目标学生模型。
在本实施例中,利用计算得到的损失值,通过反向传播算法对目标老师模型和学生模型的特征提取层对应的第一特征层权重和第二特征层权重进行更新。若总损失值不小于预设阈值,则进入下一个单次训练阶段,否则,训练结束。将训练结束后的目标学生模型作为端侧部署模型。
为了执行上述方法实施例对应的模型训练方法,以实现相应的功能和技术效果。参见图3,图3示出了本申请实施例提供的一种模型训练装置的结构框图。为了便于说明,仅示出了与本实施例相关的部分,本申请实施例提供的模型训练装置,包括:
获取模块301,用于获取训练数据集,所述训练数据集包括已标注数据和未标注数据;
第一训练模块302,用于利用所述已标注数据,对预设的老师模型进行训练,直至所述老师模型达到预设的第一收敛条件,得到目标老师模型;
共享模块303,用于对所述目标老师模型和学生模型进行BN层权重共享,所述目标老师模型的模型复杂度大于所述学生模型的模型复杂度;
第二训练模块304,用于利用所述已标注数据和所述未标注数据,对所述学生模型和所述目标老师模型进行联合训练,直至所述学生模型达到预设的第二收敛条件,得到目标学生模型,所述目标学生模型能够用于部署到端侧计算平台。
在一实施例中,所述共享模块303,具体用于:
将所述目标老师模型的多级BN层权重共享至所述学生模型。
在一实施例中,所述学生模型和所述目标老师模型在联合训练时,固定所述目标老师模型的多级BN层权重和所述学生模型的多级BN层权重。
在一实施例中,所述第二训练模块304,包括:
输出单元,用于将所述训练数据集作为所述学生模型和所述目标老师模型的输入数据,输出所述学生模型的第一预测结果和所述目标老师模型的第二预测结果;
计算单元,用于根据所述第一预测结果和所述第二预测结果,计算目标损失函数的总损失值;
更新单元,用于根据所述总损失值,更新所述学生模型,直至所述学生模型收敛,得到所述目标学生模型。
在一实施例中,所述计算单元,包括:
确定子单元,用于确定所述输入数据的数据类型,所述数据类型为已标注数据或未标注数据;
计算子单元,用于根据所述数据类型,计算所述目标损失函数的总损失值。
在一实施例中,所述目标损失函数为:
Figure 38920DEST_PATH_IMAGE010
其中,
Figure 338314DEST_PATH_IMAGE002
为所述学生模型的预测损失函数,
Figure 220557DEST_PATH_IMAGE003
为所述目标老师模型的预测损失函数,
Figure 85745DEST_PATH_IMAGE004
为所述第一预测结果,
Figure 427865DEST_PATH_IMAGE012
为所述第二预测结果,
Figure 948976DEST_PATH_IMAGE006
为所述第一预测结果与所述第二预测结果之间的均方误差,
Figure 198692DEST_PATH_IMAGE007
为所述数据类型,若所述数据类型为已标注数据,则
Figure 856069DEST_PATH_IMAGE008
,若所述数据类型为未标注数据,则
Figure 873485DEST_PATH_IMAGE011
在一实施例中,所述更新单元,包括:
更新子单元,用于若所述总损失值不小于预设阈值,则更新所述学生模型的第一特征层权重和所述目标老师模型的第二特征层权重,得到新的所述学生模型和新的所述目标老师模型;
迭代子单元,用于利用新的所述学生模型和新的所述目标老师模型,对所述训练数据集进行预测,直至所述总损失值小于所述预设阈值,得到所述目标学生模型。
上述的模型训练装置可实施上述方法实施例的模型训练方法。上述方法实施例中的可选项也适用于本实施例,这里不再详述。本申请实施例的其余内容可参照上述方法实施例的内容,在本实施例中,不再进行赘述。
图4为本申请一实施例提供的计算机设备的结构示意图。如图4所示,该实施例的计算机设备400包括:至少一个处理器401(图4中仅示出一个)处理器、存储器402以及存储在所述存储器402中并可在所述至少一个处理器401上运行的计算机程序403,所述处理器401执行所述计算机程序403时实现上述任意方法实施例中的步骤。
所述计算机设备400可以是智能手机、平板电脑、桌上型计算机和云端服务器等计算设备。该计算机设备可包括但不仅限于处理器401、存储器402。本领域技术人员可以理解,图4仅仅是计算机设备400的举例,并不构成对计算机设备400的限定,可以包括比图示更多或更少的部件,或者组合某些部件,或者不同的部件,例如还可以包括输入输出设备、网络接入设备等。
所称处理器401可以是中央处理单元(Central Processing Unit,CPU),该处理器401还可以是其他通用处理器、数字信号处理器 (Digital Signal Processor,DSP)、专用集成电路 (Application Specific Integrated Circuit,ASIC)、现成可编程门阵列(Field-Programmable Gate Array,FPGA) 或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。
所述存储器402在一些实施例中可以是所述计算机设备400的内部存储单元,例如计算机设备400的硬盘或内存。所述存储器402在另一些实施例中也可以是所述计算机设备400的外部存储设备,例如所述计算机设备400上配备的插接式硬盘,智能存储卡(SmartMedia Card, SMC),安全数字(Secure Digital, SD)卡,闪存卡(Flash Card)等。进一步地,所述存储器402还可以既包括所述计算机设备400的内部存储单元也包括外部存储设备。所述存储器402用于存储操作系统、应用程序、引导装载程序(BootLoader)、数据以及其他程序等,例如所述计算机程序的程序代码等。所述存储器402还可以用于暂时地存储已经输出或者将要输出的数据。
另外,本申请实施例还提供一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,所述计算机程序被处理器执行时实现上述任意方法实施例中的步骤。
本申请实施例提供了一种计算机程序产品,当计算机程序产品在计算机设备上运行时,使得计算机设备执行时实现上述各个方法实施例中的步骤。
在本申请所提供的几个实施例中,可以理解的是,流程图或框图中的每个方框可以代表一个模块、程序段或代码的一部分,所述模块、程序段或代码的一部分包含一个或多个用于实现规定的逻辑功能的可执行指令。也应当注意的是,在有些作为替换的实现方式中,方框中所标注的功能也可以以不同于附图中所标注的顺序发生。例如,两个连续的方框实际上可以基本并行地执行,它们有时也可以按相反的顺序执行,这依所涉及的功能而定。
所述功能如果以软件功能模块的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本申请的技术方案本质上或者说对现有技术做出贡献的部分或者该技术方案的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台计算机设备执行本申请各个实施例所述方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(ROM,Read-Only Memory)、随机存取存储器(RAM,Random Access Memory)、磁碟或者光盘等各种可以存储程序代码的介质。
以上所述的具体实施例,对本申请的目的、技术方案和有益效果进行了进一步的详细说明,应当理解,以上所述仅为本申请的具体实施例而已,并不用于限定本申请的保护范围。特别指出,对于本领域技术人员来说,凡在本申请的精神和原则之内,所做的任何修改、等同替换、改进等,均应包含在本申请的保护范围之内。

Claims (10)

1.一种模型训练方法,其特征在于,包括:
获取训练数据集,所述训练数据集包括已标注数据和未标注数据;
利用所述已标注数据,对预设的老师模型进行训练,直至所述老师模型达到预设的第一收敛条件,得到目标老师模型;
对所述目标老师模型和学生模型进行BN层权重共享,所述目标老师模型的模型复杂度大于所述学生模型的模型复杂度;
利用所述已标注数据和所述未标注数据,对所述学生模型和所述目标老师模型进行联合训练,直至所述学生模型达到预设的第二收敛条件,得到目标学生模型,所述目标学生模型能够用于部署到端侧计算平台。
2.如权利要求1所述的模型训练方法,其特征在于,所述目标老师模型和所述学生模型均有多个BN层,所述对所述目标老师模型和学生模型进行BN层权重共享,包括:
将所述目标老师模型的多级BN层权重共享至所述学生模型。
3.如权利要求2所述的模型训练方法,其特征在于,所述学生模型和所述目标老师模型在联合训练时,固定所述目标老师模型的多级BN层权重和所述学生模型的多级BN层权重。
4.如权利要求1所述的模型训练方法,其特征在于,所述利用所述已标注数据和所述未标注数据,对所述学生模型和所述目标老师模型进行联合训练,直至所述学生模型达到预设的第二收敛条件,得到目标学生模型,包括:
将所述训练数据集作为所述学生模型和所述目标老师模型的输入数据,输出所述学生模型的第一预测结果和所述目标老师模型的第二预测结果;
根据所述第一预测结果和所述第二预测结果,计算目标损失函数的总损失值;
根据所述总损失值,更新所述学生模型,直至所述学生模型收敛,得到所述目标学生模型。
5.如权利要求4所述的模型训练方法,其特征在于,所述根据所述第一预测结果和所述第二预测结果,计算目标损失函数的总损失值,包括:
确定所述输入数据的数据类型,所述数据类型为已标注数据或未标注数据;
根据所述数据类型,计算所述目标损失函数的总损失值。
6.如权利要求5所述的模型训练方法,其特征在于,所述目标损失函数为:
Figure 104157DEST_PATH_IMAGE001
其中,
Figure 587091DEST_PATH_IMAGE002
为所述学生模型的预测损失函数,
Figure 872710DEST_PATH_IMAGE003
为所述目标老师模型的预测损失函数,
Figure 48477DEST_PATH_IMAGE004
为所述第一预测结果,
Figure 285292DEST_PATH_IMAGE005
为所述第二预测结果,
Figure 384966DEST_PATH_IMAGE006
为所述第一预测结果与所述第二预测结果之间的均方误差,
Figure 508780DEST_PATH_IMAGE007
为所述数据类型,若所述数据类型为已标注数据,则
Figure 172891DEST_PATH_IMAGE008
,若所述数据类型为未标注数据,则
Figure 460784DEST_PATH_IMAGE009
7.如权利要求4所述的模型训练方法,其特征在于,所述根据所述总损失值,更新所述学生模型,直至所述学生模型收敛,得到所述目标学生模型,包括:
若所述总损失值不小于预设阈值,则更新所述学生模型的第一特征层权重和所述目标老师模型的第二特征层权重,得到新的所述学生模型和新的所述目标老师模型;
利用新的所述学生模型和新的所述目标老师模型,对所述训练数据集进行预测,直至所述总损失值小于所述预设阈值,得到所述目标学生模型。
8.一种模型训练装置,其特征在于,包括:
获取模块,用于获取训练数据集,所述训练数据集包括已标注数据和未标注数据;
第一训练模块,用于利用所述已标注数据,对预设的老师模型进行训练,直至所述老师模型达到预设的第一收敛条件,得到目标老师模型;
共享模块,用于对所述目标老师模型和学生模型进行BN层权重共享,所述目标老师模型的模型复杂度大于所述学生模型的模型复杂度;
第二训练模块,用于利用所述已标注数据和所述未标注数据,对所述学生模型和所述目标老师模型进行联合训练,直至所述学生模型达到预设的第二收敛条件,得到目标学生模型,所述目标学生模型能够用于部署到端侧计算平台。
9.一种计算机设备,其特征在于,包括处理器和存储器,所述存储器用于存储计算机程序,所述计算机程序被所述处理器执行时实现如权利要求1至7任一项所述的模型训练方法。
10.一种计算机可读存储介质,其特征在于,其存储有计算机程序,所述计算机程序被处理器执行时实现如权利要求1至7任一项所述的模型训练方法。
CN202210024447.7A 2022-01-11 2022-01-11 模型训练方法、装置、设备及存储介质 Pending CN114092918A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210024447.7A CN114092918A (zh) 2022-01-11 2022-01-11 模型训练方法、装置、设备及存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210024447.7A CN114092918A (zh) 2022-01-11 2022-01-11 模型训练方法、装置、设备及存储介质

Publications (1)

Publication Number Publication Date
CN114092918A true CN114092918A (zh) 2022-02-25

Family

ID=80308508

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210024447.7A Pending CN114092918A (zh) 2022-01-11 2022-01-11 模型训练方法、装置、设备及存储介质

Country Status (1)

Country Link
CN (1) CN114092918A (zh)

Citations (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20180268292A1 (en) * 2017-03-17 2018-09-20 Nec Laboratories America, Inc. Learning efficient object detection models with knowledge distillation
US20190287515A1 (en) * 2018-03-16 2019-09-19 Microsoft Technology Licensing, Llc Adversarial Teacher-Student Learning for Unsupervised Domain Adaptation
US20200134506A1 (en) * 2018-10-29 2020-04-30 Fujitsu Limited Model training method, data identification method and data identification device
CN113111968A (zh) * 2021-04-30 2021-07-13 北京大米科技有限公司 图像识别模型训练方法、装置、电子设备和可读存储介质
CN113205002A (zh) * 2021-04-08 2021-08-03 南京邮电大学 非受限视频监控的低清人脸识别方法、装置、设备及介质
CN113281048A (zh) * 2021-06-25 2021-08-20 华中科技大学 一种基于关系型知识蒸馏的滚动轴承故障诊断方法和系统
CN113724242A (zh) * 2021-09-10 2021-11-30 吉林大学 糖尿病视网膜病变和糖尿病性黄斑水肿的联合分级方法

Patent Citations (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20180268292A1 (en) * 2017-03-17 2018-09-20 Nec Laboratories America, Inc. Learning efficient object detection models with knowledge distillation
US20190287515A1 (en) * 2018-03-16 2019-09-19 Microsoft Technology Licensing, Llc Adversarial Teacher-Student Learning for Unsupervised Domain Adaptation
US20200134506A1 (en) * 2018-10-29 2020-04-30 Fujitsu Limited Model training method, data identification method and data identification device
CN113205002A (zh) * 2021-04-08 2021-08-03 南京邮电大学 非受限视频监控的低清人脸识别方法、装置、设备及介质
CN113111968A (zh) * 2021-04-30 2021-07-13 北京大米科技有限公司 图像识别模型训练方法、装置、电子设备和可读存储介质
CN113281048A (zh) * 2021-06-25 2021-08-20 华中科技大学 一种基于关系型知识蒸馏的滚动轴承故障诊断方法和系统
CN113724242A (zh) * 2021-09-10 2021-11-30 吉林大学 糖尿病视网膜病变和糖尿病性黄斑水肿的联合分级方法

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
王金甲等: "基于平均教师模型的弱标记半监督声音事件检测", 《复旦学报(自然科学版)》 *

Similar Documents

Publication Publication Date Title
US20190279088A1 (en) Training method, apparatus, chip, and system for neural network model
US9830526B1 (en) Generating image features based on robust feature-learning
US10262272B2 (en) Active machine learning
US20160358070A1 (en) Automatic tuning of artificial neural networks
JP2019528502A (ja) パターン認識に適用可能なモデルを最適化するための方法および装置ならびに端末デバイス
CN111523640B (zh) 神经网络模型的训练方法和装置
WO2021089013A1 (zh) 空间图卷积网络的训练方法、电子设备及存储介质
US20210295168A1 (en) Gradient compression for distributed training
KR102250728B1 (ko) 샘플 처리 방법, 장치, 기기 및 저장 매체
CN111406264A (zh) 神经架构搜索
CN116644804B (zh) 分布式训练系统、神经网络模型训练方法、设备和介质
CN116681127B (zh) 一种神经网络模型训练方法、装置及电子设备和存储介质
CN112966754B (zh) 样本筛选方法、样本筛选装置及终端设备
CN111784699B (zh) 一种对三维点云数据进行目标分割方法、装置及终端设备
CN112561050B (zh) 一种神经网络模型训练方法及装置
CN114241411B (zh) 基于目标检测的计数模型处理方法、装置及计算机设备
CN114998679A (zh) 深度学习模型的在线训练方法、装置、设备及存储介质
CN113409307A (zh) 基于异质噪声特性的图像去噪方法、设备及介质
CN115953651B (zh) 一种基于跨域设备的模型训练方法、装置、设备及介质
CN114092918A (zh) 模型训练方法、装置、设备及存储介质
CN116912923A (zh) 一种图像识别模型训练方法和装置
CN113139490B (zh) 一种图像特征匹配方法、装置、计算机设备及存储介质
CN114065913A (zh) 模型量化方法、装置及终端设备
CN109388784A (zh) 最小熵核密度估计器生成方法、装置和计算机可读存储介质
CN110222693B (zh) 构建字符识别模型与识别字符的方法和装置

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination