CN110097130A - 分类任务模型的训练方法、装置、设备及存储介质 - Google Patents
分类任务模型的训练方法、装置、设备及存储介质 Download PDFInfo
- Publication number
- CN110097130A CN110097130A CN201910377510.3A CN201910377510A CN110097130A CN 110097130 A CN110097130 A CN 110097130A CN 201910377510 A CN201910377510 A CN 201910377510A CN 110097130 A CN110097130 A CN 110097130A
- Authority
- CN
- China
- Prior art keywords
- feature
- training
- task model
- classification task
- generator
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000012549 training Methods 0.000 title claims abstract description 144
- 238000000034 method Methods 0.000 title claims abstract description 52
- 238000003860 storage Methods 0.000 title claims abstract description 24
- 239000013598 vector Substances 0.000 claims description 70
- 230000006870 function Effects 0.000 claims description 20
- 230000008569 process Effects 0.000 claims description 13
- 230000015654 memory Effects 0.000 claims description 9
- 238000004140 cleaning Methods 0.000 claims description 4
- 238000001914 filtration Methods 0.000 claims description 3
- 238000010801 machine learning Methods 0.000 abstract description 6
- 238000010586 diagram Methods 0.000 description 12
- 238000005070 sampling Methods 0.000 description 11
- 238000012545 processing Methods 0.000 description 10
- 238000013461 design Methods 0.000 description 6
- 230000000670 limiting effect Effects 0.000 description 6
- ZOKXTWBITQBERF-UHFFFAOYSA-N Molybdenum Chemical compound [Mo] ZOKXTWBITQBERF-UHFFFAOYSA-N 0.000 description 5
- 238000000605 extraction Methods 0.000 description 5
- 229910052750 molybdenum Inorganic materials 0.000 description 5
- 239000011733 molybdenum Substances 0.000 description 5
- 238000012360 testing method Methods 0.000 description 5
- 238000004590 computer program Methods 0.000 description 4
- 238000005516 engineering process Methods 0.000 description 4
- 230000001575 pathological effect Effects 0.000 description 4
- 230000008901 benefit Effects 0.000 description 3
- 238000004422 calculation algorithm Methods 0.000 description 3
- 238000002059 diagnostic imaging Methods 0.000 description 3
- 230000003902 lesion Effects 0.000 description 3
- 238000003032 molecular docking Methods 0.000 description 3
- 238000004458 analytical method Methods 0.000 description 2
- 238000013473 artificial intelligence Methods 0.000 description 2
- 238000013527 convolutional neural network Methods 0.000 description 2
- 239000000284 extract Substances 0.000 description 2
- 230000002401 inhibitory effect Effects 0.000 description 2
- 230000014759 maintenance of location Effects 0.000 description 2
- 238000013507 mapping Methods 0.000 description 2
- 230000003287 optical effect Effects 0.000 description 2
- 241000208340 Araliaceae Species 0.000 description 1
- 235000005035 Panax pseudoginseng ssp. pseudoginseng Nutrition 0.000 description 1
- 235000003140 Panax quinquefolius Nutrition 0.000 description 1
- 230000004913 activation Effects 0.000 description 1
- 230000003321 amplification Effects 0.000 description 1
- 238000013528 artificial neural network Methods 0.000 description 1
- 210000000481 breast Anatomy 0.000 description 1
- 238000004364 calculation method Methods 0.000 description 1
- 238000004891 communication Methods 0.000 description 1
- 239000012141 concentrate Substances 0.000 description 1
- 238000013480 data collection Methods 0.000 description 1
- 238000013500 data storage Methods 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 230000004069 differentiation Effects 0.000 description 1
- 238000009826 distribution Methods 0.000 description 1
- 235000013399 edible fruits Nutrition 0.000 description 1
- 238000011156 evaluation Methods 0.000 description 1
- 238000000105 evaporative light scattering detection Methods 0.000 description 1
- 235000008434 ginseng Nutrition 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 238000009434 installation Methods 0.000 description 1
- 239000000203 mixture Substances 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000003062 neural network model Methods 0.000 description 1
- 238000010606 normalization Methods 0.000 description 1
- 238000003199 nucleic acid amplification method Methods 0.000 description 1
- 230000002829 reductive effect Effects 0.000 description 1
- 230000001105 regulatory effect Effects 0.000 description 1
- 230000000452 restraining effect Effects 0.000 description 1
- 238000012163 sequencing technique Methods 0.000 description 1
- 230000000007 visual effect Effects 0.000 description 1
- 238000005406 washing Methods 0.000 description 1
- 238000005303 weighing Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/22—Matching criteria, e.g. proximity measures
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/243—Classification techniques relating to the number of classes
- G06F18/2431—Multiple classes
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/047—Probabilistic or stochastic networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/74—Image or video pattern matching; Proximity measures in feature spaces
- G06V10/761—Proximity, similarity or dissimilarity measures
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/7715—Feature extraction, e.g. by transforming the feature space, e.g. multi-dimensional scaling [MDS]; Mappings, e.g. subspace methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G16—INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR SPECIFIC APPLICATION FIELDS
- G16H—HEALTHCARE INFORMATICS, i.e. INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR THE HANDLING OR PROCESSING OF MEDICAL OR HEALTHCARE DATA
- G16H30/00—ICT specially adapted for the handling or processing of medical images
- G16H30/40—ICT specially adapted for the handling or processing of medical images for processing medical images, e.g. editing
-
- G—PHYSICS
- G16—INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR SPECIFIC APPLICATION FIELDS
- G16H—HEALTHCARE INFORMATICS, i.e. INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR THE HANDLING OR PROCESSING OF MEDICAL OR HEALTHCARE DATA
- G16H50/00—ICT specially adapted for medical diagnosis, medical simulation or medical data mining; ICT specially adapted for detecting, monitoring or modelling epidemics or pandemics
- G16H50/20—ICT specially adapted for medical diagnosis, medical simulation or medical data mining; ICT specially adapted for detecting, monitoring or modelling epidemics or pandemics for computer-aided diagnosis, e.g. based on medical expert systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Health & Medical Sciences (AREA)
- Data Mining & Analysis (AREA)
- General Health & Medical Sciences (AREA)
- Computing Systems (AREA)
- Software Systems (AREA)
- Medical Informatics (AREA)
- Life Sciences & Earth Sciences (AREA)
- General Engineering & Computer Science (AREA)
- Databases & Information Systems (AREA)
- Multimedia (AREA)
- Biomedical Technology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Evolutionary Biology (AREA)
- Mathematical Physics (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Molecular Biology (AREA)
- Public Health (AREA)
- Epidemiology (AREA)
- Primary Health Care (AREA)
- Probability & Statistics with Applications (AREA)
- Nuclear Medicine, Radiotherapy & Molecular Imaging (AREA)
- Radiology & Medical Imaging (AREA)
- Pathology (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
- Image Analysis (AREA)
Abstract
本申请公开了一种分类任务模型的训练方法、装置、设备及存储介质,涉及机器学习技术领域,所述方法包括:采用第一数据集训练初始的特征提取器,该第一数据集是类别不均衡数据集;构建生成对抗网络,该生成对抗网络包括特征提取器、特征生成器和域分类器;采用第二类别样本对生成对抗网络进行训练,得到完成训练的特征生成器;构建分类任务模型,该分类任务模型包括完成训练的特征生成器、特征提取器和分类器;采用第一数据集对分类任务模型进行训练;其中,完成训练的特征生成器用于对第二类别样本在特征空间进行扩增。本申请通过特征生成器对少数类别样本在特征空间进行扩增,提高最终训练得到的分类任务模型的精度。
Description
技术领域
本申请实施例涉及机器学习技术领域,特别涉及一种分类任务模型的训练方法、装置、设备及存储介质。
背景技术
机器学习对于处理分类任务具有较好的性能表现,例如基于深度神经网络构建分类任务模型,并通过适当的训练样本对该模型进行训练,完成训练的分类任务模型即可用于处理分类任务,如图像识别、语音识别等分类任务。
在训练分类任务模型时,训练数据集中包含的训练样本的类别可能并不均衡,例如正样本的数量远少于负样本的数量,这样的训练数据集可以称为类别不均衡数据集。如果采用类别不均衡数据集对分类任务模型进行训练,会导致最终得到的分类任务模型的性能表现不佳。
在相关技术中,提出了通过样本上采样来使得类别不均衡数据集中不同类别的训练样本数量保持均衡。所谓样本上采样,就是以数量多的一方的样本数量为基准,把数量少的一方的样本数量进行扩增,生成和数量多的一方相同数量的样本。例如,当正样本的数量小于负样本的数量时,可以复制一些正样本,使得正样本的数量和负样本的数量相同。
经样本上采样得到的训练样本训练出的分类任务模型,存在过拟合的情况,即该分类任务模型的训练误差远小于其在测试数据集上的误差,无法训练得到高精度的分类任务模型。
发明内容
本申请实施例提供了一种分类任务模型的训练方法、装置、设备及存储介质,可用于解决相关技术提供的样本上采样手段,无法训练得到高精度的分类任务模型的技术问题。所述技术方案如下:
一方面,本申请实施例提供一种分类任务模型的训练方法,所述方法包括:
采用第一数据集训练初始的特征提取器;其中,所述第一数据集是包括第一类别样本和第二类别样本的类别不均衡数据集,所述第一类别样本的数量大于所述第二类别样本的数量;
构建生成对抗网络,所述生成对抗网络包括所述特征提取器、特征生成器和域分类器;其中,所述特征生成器用于生成与所述特征提取器相同维度的特征向量,所述域分类器用于对所述特征提取器输出的特征向量和所述特征生成器输出的特征向量进行区分;
采用所述第二类别样本对所述生成对抗网络进行训练,得到完成训练的特征生成器;
构建分类任务模型,所述分类任务模型包括所述完成训练的特征生成器、所述特征提取器和分类器;
采用所述第一数据集对所述分类任务模型进行训练;其中,所述完成训练的特征生成器用于对所述第二类别样本在特征空间进行扩增。
另一方面,本申请实施例提供一种分类任务模型的训练装置,所述装置包括:
第一训练模块,用于采用第一数据集训练初始的特征提取器;其中,所述第一数据集是包括第一类别样本和第二类别样本的类别不均衡数据集,所述第一类别样本的数量大于所述第二类别样本的数量;
第一构建模块,用于构建生成对抗网络,所述生成对抗网络包括所述特征提取器、特征生成器和域分类器;其中,所述特征生成器用于生成与所述特征提取器相同维度的特征向量,所述域分类器用于对所述特征提取器输出的特征向量和所述特征生成器输出的特征向量进行区分;
第二训练模块,用于采用所述第二类别样本对所述生成对抗网络进行训练,得到完成训练的特征生成器;
第二构建模块,用于构建分类任务模型,所述分类任务模型包括所述完成训练的特征生成器、所述特征提取器和分类器;
第三训练模块,用于采用所述第一数据集对所述分类任务模型进行训练;其中,所述完成训练的特征生成器用于对所述第二类别样本在特征空间进行扩增。
再一方面,本申请实施例提供一种计算机设备,所述计算机设备包括处理器和存储器,所述存储器中存储有至少一条指令、至少一段程序、代码集或指令集,所述至少一条指令、所述至少一段程序、所述代码集或指令集由所述处理器加载并执行以实现上述分类任务模型的训练方法。
又一方面,本申请实施例提供一种计算机可读存储介质,所述存储介质中存储有至少一条指令、至少一段程序、代码集或指令集,所述至少一条指令、所述至少一段程序、所述代码集或指令集由处理器加载并执行以实现上述分类任务模型的训练方法。
又一方面,本申请实施例提供一种计算机程序产品,当该计算机程序产品被执行时,其用于执行上述分类任务模型的训练方法。
本申请实施例提供的技术方案至少包括如下有益效果:
本申请实施例提供的技术方案中,基于生成对抗网络训练得到特征生成器,通过该特征生成器对少数类别样本(即类别不均衡数据集中数量偏少的一类训练样本)在特征空间进行扩增,从特征层面进行扩增,而非采用样本上采样手段对少数类别样本进行简单复制,使得最终训练得到的分类任务模型避免出现过拟合的情况,提高最终训练得到的分类任务模型的精度。
附图说明
为了更清楚地说明本申请实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是本申请一个实施例提供的分类任务模型的训练方法的流程图;
图2示例性示出了初始的分类任务模型的结构示意图;
图3示例性示出了生成对抗网络的结构示意图;
图4示例性示出了分类任务模型的结构示意图;
图5示例性示出了本申请技术方案的整体架构图;
图6和图7示例性示出了两组实验结果的示意图;
图8是本申请一个实施例提供的分类任务模型的训练装置的框图;
图9是本申请一个实施例提供的计算机设备的结构示意图。
具体实施方式
为使本申请的目的、技术方案和优点更加清楚,下面将结合附图对本申请实施方式作进一步地详细描述。
本申请实施例提供的技术方案中,基于生成对抗网络训练得到特征生成器,通过该特征生成器对少数类别样本(即类别不均衡数据集中数量偏少的一类训练样本)在特征空间进行扩增,从特征层面进行扩增,而非采用样本上采样手段对少数类别样本进行简单复制,使得最终训练得到的分类任务模型避免出现过拟合的情况,提高最终训练得到的分类任务模型的精度。
本申请实施例中涉及的分类任务模型,是指通过机器学习训练得到的、用于处理分类任务的机器学习模型。该分类任务模型可以是深度学习分类任务模型,即基于深度神经网络构建的分类任务模型,如基于深度卷积神经网络构建的分类任务模型。该分类任务模型可用于处理图像识别、语音识别等分类任务,本申请实施例对该分类任务模型的具体应用场景不作限定。
本申请实施例提供的方法,各步骤的执行主体可以是计算机设备,该计算机设备是指具备数据计算、处理和存储能力的电子设备,如PC(Personal Computer,个人计算机)或服务器。
请参考图1,其示出了本申请一个实施例提供的分类任务模型的训练方法的流程图。该方法可以包括以下几个步骤(101~105):
步骤101,采用第一数据集训练初始的特征提取器,第一数据集是包括第一类别样本和第二类别样本的类别不均衡数据集,第一类别样本的数量大于第二类别样本的数量。
第一类别样本和第二类别样本是第一数据集中两种不同类别的样本。例如,第一类别样本为正样本,第二类别样本为负样本;或者,第一类别样本为负样本,第二类别样本为正样本。第一类别样本的数量大于第二类别样本的数量,即第一类别样本可以称为多数类别样本,第二类别样本可以称为少数类别样本。在大多数场景下,负样本的数量大于甚至远大于正样本的数量,因此,第一类别样本可以是负样本,相应地第二类别样本则为正样本。
特征提取器是分类任务模型中用于提取样本特征的部分,特征提取器也称为编码器(encoder)。分类任务模型包括特征提取器和分类器,特征提取器的输出端和分类器的输入端对接,特征提取器从模型的输入样本中提取特征向量,分类器用于根据该特征向量确定输入样本所属的类别。以分类任务模型用于图像识别为例,特征提取器用于对输入图像进行映射编码,输出维度远低于输入图像像素的特征向量,特征提取器获得了一种非线性的、局部到全局的特征映射,融合了低层的视觉特征和高层的语义信息。
在示例性实施例中,分类任务模型基于深度卷积神经网络构建,特征提取器可以包括多个卷积层。例如,分类任务模型为Inception-v3模型,Inception-v3模型是一种深度神经网络模型,其对图像分类任务具有较好的性能表现。另外,Inception-v3模型的另一优点是可以将预训练好的Inception-v3模型作为初始化的分类任务模型来使用,而不必对分类任务模型中的参数进行随机初始化,这有助于提高模型的训练效率。分类器可以采用Softmax分类器或其它分类器,本申请实施例对此不作限定。
在示例性实施例中,步骤101包括如下几个子步骤:
1、构建初始的分类任务模型,初始的分类任务模型包括初始的特征提取器和初始的分类器;
如上文介绍,初始的分类任务模型可以是预训练好的Inception-v3模型。
2、采用第一数据集对初始的分类任务模型进行训练,得到初始训练后的特征提取器。
第一数据集中包括第一类别样本和第二类别样本,每一个训练样本根据其所属类别设定有相应的标签。例如,第一类别样本的标签为1,第二类别样本的标签为0;或者,第一类别样本的标签为0,第二类别样本的标签为1。将第一数据集中的训练样本(包括第一类别样本和第二类别样本)输入至初始的分类任务模型,将模型输出的分类结果和标签进行比对,计算该模型对应的损失函数值;然后,根据损失函数值使用反向传播算法计算模型中各个参数的梯度;最后,使用梯度更新模型中的各个参数,更新的步调由学习率控制。其中,损失函数可以采用交叉熵(Cross Entropy,CE)损失函数。
在初始的分类任务模型满足停止训练条件时,停止对该模型的训练,得到初始训练后的分类任务模型。该初始训练后的分类任务模型中包含初始训练后的特征提取器,该初始训练后的特征提取器被用于下述的生成对抗网络中。其中,初始的分类任务模型的停止训练条件可以预先进行设定,如模型精度达到预设要求、训练轮数达到预设轮数或训练时长达到预设时长等,本申请实施例对此不作限定。
如图2所示,其示例性示出了初始的分类任务模型的结构示意图。该初始的分类任务模型包括特征提取器EI和分类器CI,特征提取器EI的输入端即为模型的输入端,特征提取器EI的输出端和分类器CI的输入端对接,分类器CI的输出端即为模型的输出端。采用第一数据集(包括多数类别样本和少数类别样本)对该初始的分类任务模型进行训练,得到初始训练后的分类任务模型。该初始训练后的分类任务模型包括初始训练后的特征提取器EI和初始训练后的分类器CI。
步骤102,构建生成对抗网络,该生成对抗网络包括特征提取器、特征生成器和域分类器。
在生成对抗网络中,特征生成器的输出端和特征提取器的输出端,分别和域分类器的输入端对接。
特征提取器即为上述步骤101得到的初始训练后的特征提取器。
特征生成器用于生成与特征提取器相同维度的特征向量。例如,特征提取器输出的特征向量的维度为20,则特征生成器生成的特征向量的维度也为20。特征生成器也可以采用多个卷积层构建,如包括6个卷积层,前5个卷积层的卷积核尺寸为3*3,最后一个卷积层的卷积核尺寸为1*1,对应的每个卷积层的输出特征图数量分别为64、128、256、512、1024和2048,每个卷积层后都可以跟随一个批量归一化(batch norm)层和一个激活函数ReLU层。
域分类器用于对特征提取器输出的特征向量和特征生成器输出的特征向量进行区分。域分类器利用对抗学习来调整特征生成器,使其输出的特征向量尽可能地接近特征提取器输出的特征向量,通过这样一个对抗学习的过程找到最大-最小化博弈均衡的模型参数。
步骤103,采用第二类别样本对生成对抗网络进行训练,得到完成训练的特征生成器。
在对生成对抗网络进行训练的过程中,特征提取器的参数固定,也即不对特征提取器的参数进行更新。特征提取器的输入是第二类别样本,也即少数类别样本,输出是从上述第二类别样本中提取到的特征向量。
特征生成器的输入包括先验数据与噪声数据的叠加,输出是与特征提取器同维度的特征向量。先验数据可以从第一数据集的第二类别样本中提取,也可以从第二数据集中与第二类别样本同类别的样本中提取。其中,第二数据集可以是同类任务中不同于第一数据集的另一数据集。噪声数据可以是随机噪声数据。以先验数据为64*64的图像为例,噪声数据也可以是64*64的图像,但噪声数据的图像中各个像素的像素值是随机生成的。将先验数据与噪声数据叠加,即为将先验数据与噪声数据中相同位置像素的像素值进行加权求和,最终得到一张叠加后的图像。特征生成器从该叠加后的图像中提取得到特征向量。另外,考虑到特征生成器的网络层数可能较少,因此其输入不能过大,所以先验数据可以是对样本图像进行缩小后得到的小尺寸的样本图像,如64*64的样本图像。在本申请实施例中,特征生成器的输入并非完全是噪声数据,完全从噪声数据中生成与真实样本类似的特征向量的话,缺乏有效的约束,特征生成器的输入是先验数据与噪声数据的叠加,这样可以抑制生成对抗网络训练过程中不收敛和容易崩溃的问题,增加生成对抗网络的鲁棒性。
在示例性实施例中,步骤103包括如下几个子步骤:
1、在生成对抗网络的每一轮训练过程中,为特征提取器的输入赋予第一标签,为特征生成器的输入赋予第二标签;
2、计算域分类器的第一损失函数值;
3、根据第一损失函数值对域分类器的参数进行更新;
4、屏蔽特征提取器的输入,为特征生成器的输入赋予第一标签;
5、计算域分类器的第二损失函数值;
6、根据第二损失函数值对特征生成器的参数进行更新。
在生成对抗网络的训练过程中,特征生成器和域分类器互相进行对抗,即在每一轮训练过程中进行两次反向传播计算,第一次固定特征生成器的参数,更新域分类器的参数,第二次固定域分类器的参数,更新特征生成器的参数。上述第一标签和第二标签是两个不同的标签,例如第一标签为1且第二标签为0,或第一标签为0且第二标签为1。
在一个示例中,首先,为特征提取器的输入赋予标签1,为特征生成器的输入赋予标签0,计算域分类器的第一损失函数值,根据该第一损失函数值反向传播调整域分类器的参数;然后,屏蔽特征提取器的输入,为特征生成器的输入赋予标签1,计算域分类器的第二损失函数值,根据该第二损失函数值反向传播调整特征生成器的参数。
如图3所示,其示例性示出了生成对抗网络的结构示意图。该生成对抗网络包括特征提取器EI、特征生成器G和域分类器D。特征提取器EI的输出端和特征生成器G的输出端分别与域分类器D的输入端对接。特征生成器G的输入为先验数据和噪声数据的叠加,特征提取器EI的输入为第一数据集中的少数类别样本。完成训练的特征生成器G被用于下面的分类任务模型中。
步骤104,构建分类任务模型,该分类任务模型包括完成训练的特征生成器、特征提取器和分类器。
在分类任务模型中,完成训练的特征生成器的输出端和特征提取器的输出端,分别和分类器的输入端对接。
完成训练的特征生成器即为上述步骤103中利用生成对抗网络训练得到的特征生成器。本步骤中的特征提取器和分类器采用与步骤101中初始的分类任务模型相同的结构和配置。可选地,本步骤中的特征提取器采用步骤101中训练得到的特征提取器的参数进行初始化。
步骤105,采用第一数据集对分类任务模型进行训练;其中,完成训练的特征生成器用于对第二类别样本在特征空间进行扩增。
在对分类任务模型进行训练的过程中,配合原有的类别不均衡的第一数据集,利用生成对抗网络训练得到的特征生成器对少数类别样本在特征空间进行扩增,将类别不均衡的学习任务转化为类别均衡的学习任务,重新训练得到分类任务模型。
在示例性实施例中,分类任务模型还包括数据清洗单元,该数据清洗单元用于对特征生成器和特征提取器输出的异常特征向量进行过滤。数据清洗单元可以是一个通过软件、硬件或者软硬件结合实现的功能单元,通过采用合适的数据清洗技术(如Tomek Link算法)来抑制特征生成器生成的一些异常特征向量,从而进一步提高最终训练得到的分类任务模型的精度。
在示例性实施例中,数据清洗单元从特征生成器和特征提取器输出的特征向量中,筛选出符合预设条件的特征向量对,该符合预设条件的特征向量对是指标签不同且相似度最大的两个特征向量,然后将上述符合预设条件的特征向量对作为异常特征向量进行过滤。两个特征向量之间的相似度可以通过欧式距离算法或其它相似度算法进行计算得到,本申请实施例对此不作限定。示例性地,对于特征生成器和特征提取器输出的所有特征向量,遍历该所有特征向量,对于每一个特征向量,找到与该特征向量最相似的另一特征向量,比对这两个特征向量的标签是否相同,如果这两个特征向量的标签不相同,如一个特征向量的标签为1且另一个特征向量的标签为0,则这两个特征向量即为符合预设条件的特征向量对,将这两个特征向量作为异常特征向量进行过滤。
如图4所示,其示例性示出了分类任务模型的结构示意图。该分类任务模型包括完成训练的特征生成器G、特征提取器EF、分类器CF以及数据清洗单元。特征生成器G的输出端和特征提取器EF的输出端,分别与数据清洗单元的输入端对接,数据清洗单元的输出端与分类器CF的输入端对接。特征提取器EF与图2所示的分类任务模型中的特征提取器EI具有相同的结构和配置,分类器CF与图2所示的分类任务模型中的分类器C1具有相同的结构和配置。采用第一数据集(包括多数类别样本和少数类别样本)对该分类任务模型进行训练,当满足预设的停止训练条件时,停止对该分类任务模型的训练,得到完成训练的分类任务模型。其中,预设的停止训练条件可以是模型精度达到预设要求、训练轮数达到预设轮数或训练时长达到预设时长等,本申请实施例对此不作限定。
综上所述,本申请实施例提供的技术方案中,基于生成对抗网络训练得到特征生成器,通过该特征生成器对少数类别样本(即类别不均衡数据集中数量偏少的一类训练样本)在特征空间进行扩增,从特征层面进行扩增,而非采用样本上采样手段对少数类别样本进行简单复制,使得最终训练得到的分类任务模型避免出现过拟合的情况,提高最终训练得到的分类任务模型的精度。
另外,本申请实施例提供的技术方案中,在训练分类任务模型的过程中,还通过数据清洗单元对特征生成器和特征提取器输出的异常特征向量进行过滤,实现抑制特征生成器生成的一些异常特征向量,从而进一步提高最终训练得到的分类任务模型的精度。
另外,在本申请实施例中,特征生成器的输入并非完全是噪声数据,完全从噪声数据中生成与真实样本类似的特征向量的话,缺乏有效的约束,特征生成器的输入是先验数据与噪声数据的叠加,这样可以抑制生成对抗网络训练过程中不收敛和容易崩溃的问题,增加生成对抗网络的鲁棒性。
下面,结合图5,对本申请实施例提供的技术方案进行整体说明。本申请实施例提供的分类任务模型的训练过程可以包括如下3个步骤:
第一步:训练初始的特征提取器;
在本步骤中,构建初始的分类任务模型,包括特征提取器EI和分类器CI,采用类别不均衡数据集对该初始的分类任务模型进行训练,得到初始训练后的特征提取器EI。
第二步:训练特征生成器;
在本步骤中,构建生成对抗网络,包括初始训练后的特征提取器EI、特征生成器G和域分类器D,在训练过程中,固定初始训练后的特征提取器EI的参数不变,利用生成对抗网络训练特征生成器G。
第三步:训练最终的分类任务模型。
在本步骤中,构建分类任务模型,包括特征生成器G、特征提取器EF、数据清洗单元和分类器EF,在训练过程中,固定特征生成器G的参数不变,配合原有的类别不均衡数据集,利用特征生成器G对少数类别样本在特征空间进行扩增,将类别不均衡的学习任务转化为类别均衡的学习任务,训练得到最终的分类任务模型。
本申请实施例提供的技术方案,可应用于AI(Artificial Intelligence,人工智能)领域的机器学习分类任务的模型训练过程中,特别适用于训练数据集为类别不均衡数据集的分类任务模型的训练过程中。以对类别不均衡的医疗影像的分类任务为例,训练数据集可以包括多张从医疗影像中提取的子图,这些子图有的是正样本(也即病灶区域的图像),有的是负样本(也即非病灶区域的图像),负样本的数量往往远大于正样本的数量。在这种应用场景下,分类任务模型可以称为影像学病灶判别模型,其输入是一张从医疗影像中提取的子图,输出是该子图是否为病灶区域的判别结果。通过生成对抗网络训练得到特征生成器,利用该特征生成器对少数类别样本在特征空间进行扩增,最终训练出更准确的影像学病灶判别模型,辅助医生做出病灶诊断分析,例如乳腺钼靶图像中的肿块检测分析。
本方案分别在一个包含2194张钼靶影像的数据集和一个camelyon2016病理图像数据集上测试,对图像进行ROI(region of interest,感兴趣区域)提取得到子图集合,分别使用了1:10和1:20的类别不均衡比例。测试的结果如下表-1和表-2所示。
表-1
表-2
上述表-1是在钼靶影像的数据集上的测试结果,表-2是在camelyon2016病理图像数据集上的测试结果。
在上述表-1和表-2中,方案1代表不对数据集做任何处理,方案2代表对数据集进行样本下采样处理,方案3代表对数据集进行样本上采样处理,方案4代表对数据集从样本空间进行扩增,方案5代表采用本申请技术方案对数据集从特征空间进行扩增,且不包含数据清洗步骤,方案6代表采用本申请技术方案对数据集从特征空间进行扩增,且包含数据清洗步骤。
在上述表-1和表-2中,Acc和AUC均为模型评价参数。其中,Acc(Accuracy)代表最终训练得到的分类任务模型的准确率,Acc越大,代表模型的性能越优,Acc越小,代表模型的性能越差。AUC(Area under the ROC curve)表示ROC(receiver operatingcharacteristic curve,受试者工作特征曲线)曲线下的面积,AUC直观反映了ROC曲线表达的分类能力,AUC越大,代表模型的性能越优,AUC越小,代表模型的性能越差。
图6中(a)部分示出了上述6种方案在钼靶影像的数据集、1:10的类别不均衡比例下的ROC曲线及相应的AUC值。图6中(b)部分示出了上述6种方案在钼靶影像的数据集、1:20的类别不均衡比例下的ROC曲线及相应的AUC值。
图7中(a)部分示出了上述6种方案在camelyon2016病理图像数据集、1:10的类别不均衡比例下的ROC曲线及相应的AUC值。图7中(b)部分示出了上述6种方案在camelyon2016病理图像数据集、1:20的类别不均衡比例下的ROC曲线及相应的AUC值。
从上述测试结果的图表中可以看出,本申请技术方案大多优于样本上采样、样本下采样、样本空间扩增技术等其它方案,且增加数据清洗步骤后的方案能够进一步提升最终训练得到的分类任务模型的性能。
下述为本申请装置实施例,可以用于执行本申请方法实施例。对于本申请装置实施例中未披露的细节,请参照本申请方法实施例。
请参考图8,其示出了本申请一个实施例提供的分类任务模型的训练装置的框图。该装置具有实现上述方法示例的功能,所述功能可以由硬件实现,也可以由硬件执行相应的软件实现。该装置可以是计算机设备,也可以设置在计算机设备中。该装置800可以包括:第一训练模块810、第一构建模块820、第二训练模块830、第二构建模块840和第三训练模块850。
第一训练模块810,用于采用第一数据集训练初始的特征提取器;其中,所述第一数据集是包括第一类别样本和第二类别样本的类别不均衡数据集,所述第一类别样本的数量大于所述第二类别样本的数量。
第一构建模块820,用于构建生成对抗网络,所述生成对抗网络包括所述特征提取器、特征生成器和域分类器;其中,所述特征生成器用于生成与所述特征提取器相同维度的特征向量,所述域分类器用于对所述特征提取器输出的特征向量和所述特征生成器输出的特征向量进行区分。
第二训练模块830,用于采用所述第二类别样本对所述生成对抗网络进行训练,得到完成训练的特征生成器。
第二构建模块840,用于构建分类任务模型,所述分类任务模型包括所述完成训练的特征生成器、所述特征提取器和分类器。
第三训练模块850,用于采用所述第一数据集对所述分类任务模型进行训练;其中,所述完成训练的特征生成器用于对所述第二类别样本在特征空间进行扩增。
综上所述,本申请实施例提供的技术方案中,基于生成对抗网络训练得到特征生成器,通过该特征生成器对少数类别样本(即类别不均衡数据集中数量偏少的一类训练样本)在特征空间进行扩增,从特征层面进行扩增,而非采用样本上采样手段对少数类别样本进行简单复制,使得最终训练得到的分类任务模型避免出现过拟合的情况,提高最终训练得到的分类任务模型的精度。
在一些可能的设计中,所述第二训练模块830,用于:在所述生成对抗网络的每一轮训练过程中,为所述特征提取器的输入赋予第一标签,为所述特征生成器的输入赋予第二标签;计算所述域分类器的第一损失函数值;根据所述第一损失函数值对所述域分类器的参数进行更新;屏蔽所述特征提取器的输入,为所述特征生成器的输入赋予所述第一标签;计算所述域分类器的第二损失函数值;根据所述第二损失函数值对所述特征生成器的参数进行更新。
在一些可能的设计中,所述特征生成器的输入包括先验数据与噪声数据的叠加;其中,所述先验数据从所述第一数据集的所述第二类别样本中提取,或者,所述先验数据从第二数据集中与所述第二类别样本同类别的样本中提取。
在一些可能的设计中,所述分类任务模型还包括:数据清洗单元;所述数据清洗单元,用于对所述特征生成器和所述特征提取器输出的异常特征向量进行过滤。
在一些可能的设计中,所述数据清洗单元,用于:从所述特征生成器和所述特征提取器输出的特征向量中,筛选出符合预设条件的特征向量对,所述符合预设条件的特征向量对是指标签不同且相似度最大的两个特征向量;将所述符合预设条件的特征向量对作为所述异常特征向量进行过滤。
在一些可能的设计中,所述第一训练模块810,用于:构建初始的分类任务模型,所述初始的分类任务模型包括所述初始的特征提取器和初始的分类器;采用所述第一数据集对所述初始的分类任务模型进行训练,得到初始训练后的特征提取器,所述初始训练后的特征提取器被用于所述生成对抗网络中。
需要说明的是,上述实施例提供的装置,在实现其功能时,仅以上述各功能模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能模块完成,即将设备的内部结构划分成不同的功能模块,以完成以上描述的全部或者部分功能。另外,上述实施例提供的装置与方法实施例属于同一构思,其具体实现过程详见方法实施例,这里不再赘述。
请参考图9,其示出了本申请一个实施例提供的计算机设备的结构示意图。该计算机设备可以是任何具备数据处理和存储功能的电子设备,如PC(PersonalComputer,个人计算机)或服务器。该计算机设备用于实施上述实施例中提供的分类任务模型的训练方法。具体来讲:
所述计算机设备900包括中央处理单元(CPU)901、包括随机存取存储器(RAM)902和只读存储器(ROM)903的系统存储器904,以及连接系统存储器904和中央处理单元901的系统总线905。所述计算机设备900还包括帮助计算机内的各个器件之间传输信息的基本输入/输出系统(I/O系统)906,和用于存储操作系统913、应用程序914和其他程序模块915的大容量存储设备907。
所述基本输入/输出系统906包括有用于显示信息的显示器908和用于用户输入信息的诸如鼠标、键盘之类的输入设备909。其中所述显示器908和输入设备909都通过连接到系统总线905的输入输出控制器910连接到中央处理单元901。所述基本输入/输出系统906还可以包括输入输出控制器910以用于接收和处理来自键盘、鼠标、或电子触控笔等多个其他设备的输入。类似地,输入输出控制器910还提供输出到显示屏、打印机或其他类型的输出设备。
所述大容量存储设备907通过连接到系统总线905的大容量存储控制器(未示出)连接到中央处理单元901。所述大容量存储设备907及其相关联的计算机可读介质为计算机设备900提供非易失性存储。也就是说,所述大容量存储设备907可以包括诸如硬盘或者CD-ROM驱动器之类的计算机可读介质(未示出)。
不失一般性,所述计算机可读介质可以包括计算机存储介质和通信介质。计算机存储介质包括以用于存储诸如计算机可读指令、数据结构、程序模块或其他数据等信息的任何方法或技术实现的易失性和非易失性、可移动和不可移动介质。计算机存储介质包括RAM、ROM、EPROM、EEPROM、闪存或其他固态存储其技术,CD-ROM、DVD或其他光学存储、磁带盒、磁带、磁盘存储或其他磁性存储设备。当然,本领域技术人员可知所述计算机存储介质不局限于上述几种。上述的系统存储器904和大容量存储设备907可以统称为存储器。
根据本申请的各种实施例,所述计算机设备900还可以通过诸如因特网等网络连接到网络上的远程计算机运行。也即计算机设备900可以通过连接在所述系统总线905上的网络接口单元911连接到网络912,或者说,也可以使用网络接口单元911来连接到其他类型的网络或远程计算机系统(未示出)。
所述存储器中存储有至少一条指令、至少一段程序、代码集或指令集,所述至少一条指令、至少一段程序、代码集或指令集经配置以由一个或者一个以上处理器执行,以实现上述实施例提供的分类任务模型的训练方法。
在示例性实施例中,还提供了一种计算机可读存储介质,所述存储介质中存储有至少一条指令、至少一段程序、代码集或指令集,所述至少一条指令、所述至少一段程序、所述代码集或所述指令集在被计算机设备的处理器执行时实现上述实施例提供的分类任务模型的训练方法。在示例性实施例中,上述计算机可读存储介质可以是ROM、RAM、CD-ROM、磁带、软盘和光数据存储设备等。
在示例性实施例中,还提供了一种计算机程序产品,当该计算机程序产品被执行时,其用于实现上述实施例提供的分类任务模型的训练方法。
应当理解的是,在本文中提及的“多个”是指两个或两个以上。“和/或”,描述关联对象的关联关系,表示可以存在三种关系,例如,A和/或B,可以表示:单独存在A,同时存在A和B,单独存在B这三种情况。字符“/”一般表示前后关联对象是一种“或”的关系。另外,本文中描述的步骤编号,仅示例性示出了步骤间的一种可能的执行先后顺序,在一些其它实施例中,上述步骤也可以不按照编号顺序来执行,如两个不同编号的步骤同时执行,或者两个不同编号的步骤按照与图示相反的顺序执行,本申请实施例对此不作限定。
以上所述仅为本申请的示例性实施例,并不用以限制本申请,凡在本申请的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本申请的保护范围之内。
Claims (14)
1.一种分类任务模型的训练方法,其特征在于,所述方法包括:
采用第一数据集训练初始的特征提取器;其中,所述第一数据集是包括第一类别样本和第二类别样本的类别不均衡数据集,所述第一类别样本的数量大于所述第二类别样本的数量;
构建生成对抗网络,所述生成对抗网络包括所述特征提取器、特征生成器和域分类器;其中,所述特征生成器用于生成与所述特征提取器相同维度的特征向量,所述域分类器用于对所述特征提取器输出的特征向量和所述特征生成器输出的特征向量进行区分;
采用所述第二类别样本对所述生成对抗网络进行训练,得到完成训练的特征生成器;
构建分类任务模型,所述分类任务模型包括所述完成训练的特征生成器、所述特征提取器和分类器;
采用所述第一数据集对所述分类任务模型进行训练;其中,所述完成训练的特征生成器用于对所述第二类别样本在特征空间进行扩增。
2.根据权利要求1所述的方法,其特征在于,所述采用所述第二类别样本对所述生成对抗网络进行训练,得到完成训练的特征生成器,包括:
在所述生成对抗网络的每一轮训练过程中,为所述特征提取器的输入赋予第一标签,为所述特征生成器的输入赋予第二标签;
计算所述域分类器的第一损失函数值;
根据所述第一损失函数值对所述域分类器的参数进行更新;
屏蔽所述特征提取器的输入,为所述特征生成器的输入赋予所述第一标签;
计算所述域分类器的第二损失函数值;
根据所述第二损失函数值对所述特征生成器的参数进行更新。
3.根据权利要求1所述的方法,其特征在于,所述特征生成器的输入包括先验数据与噪声数据的叠加;
其中,所述先验数据从所述第一数据集的所述第二类别样本中提取,或者,所述先验数据从第二数据集中与所述第二类别样本同类别的样本中提取。
4.根据权利要求1至3任一项所述的方法,其特征在于,所述分类任务模型还包括:数据清洗单元;
所述数据清洗单元,用于对所述特征生成器和所述特征提取器输出的异常特征向量进行过滤。
5.根据权利要求4所述的方法,其特征在于,所述数据清洗单元,用于:
从所述特征生成器和所述特征提取器输出的特征向量中,筛选出符合预设条件的特征向量对,所述符合预设条件的特征向量对是指标签不同且相似度最大的两个特征向量;
将所述符合预设条件的特征向量对作为所述异常特征向量进行过滤。
6.根据权利要求1至3任一项所述的方法,其特征在于,所述采用第一数据集训练初始的特征提取器,包括:
构建初始的分类任务模型,所述初始的分类任务模型包括所述初始的特征提取器和初始的分类器;
采用所述第一数据集对所述初始的分类任务模型进行训练,得到初始训练后的特征提取器,所述初始训练后的特征提取器被用于所述生成对抗网络中。
7.一种分类任务模型的训练装置,其特征在于,所述装置包括:
第一训练模块,用于采用第一数据集训练初始的特征提取器;其中,所述第一数据集是包括第一类别样本和第二类别样本的类别不均衡数据集,所述第一类别样本的数量大于所述第二类别样本的数量;
第一构建模块,用于构建生成对抗网络,所述生成对抗网络包括所述特征提取器、特征生成器和域分类器;其中,所述特征生成器用于生成与所述特征提取器相同维度的特征向量,所述域分类器用于对所述特征提取器输出的特征向量和所述特征生成器输出的特征向量进行区分;
第二训练模块,用于采用所述第二类别样本对所述生成对抗网络进行训练,得到完成训练的特征生成器;
第二构建模块,用于构建分类任务模型,所述分类任务模型包括所述完成训练的特征生成器、所述特征提取器和分类器;
第三训练模块,用于采用所述第一数据集对所述分类任务模型进行训练;其中,所述完成训练的特征生成器用于对所述第二类别样本在特征空间进行扩增。
8.根据权利要求7所述的装置,其特征在于,所述第二训练模块,用于:
在所述生成对抗网络的每一轮训练过程中,为所述特征提取器的输入赋予第一标签,为所述特征生成器的输入赋予第二标签;
计算所述域分类器的第一损失函数值;
根据所述第一损失函数值对所述域分类器的参数进行更新;
屏蔽所述特征提取器的输入,为所述特征生成器的输入赋予所述第一标签;
计算所述域分类器的第二损失函数值;
根据所述第二损失函数值对所述特征生成器的参数进行更新。
9.根据权利要求7所述的装置,其特征在于,所述特征生成器的输入包括先验数据与噪声数据的叠加;
其中,所述先验数据从所述第一数据集的所述第二类别样本中提取,或者,所述先验数据从第二数据集中与所述第二类别样本同类别的样本中提取。
10.根据权利要求7至9任一项所述的装置,其特征在于,所述分类任务模型还包括:数据清洗单元;
所述数据清洗单元,用于对所述特征生成器和所述特征提取器输出的异常特征向量进行过滤。
11.根据权利要求10所述的装置,其特征在于,所述数据清洗单元,用于:
从所述特征生成器和所述特征提取器输出的特征向量中,筛选出符合预设条件的特征向量对,所述符合预设条件的特征向量对是指标签不同且相似度最大的两个特征向量;
将所述符合预设条件的特征向量对作为所述异常特征向量进行过滤。
12.根据权利要求7至9任一项所述的装置,其特征在于,所述第一训练模块,用于:
构建初始的分类任务模型,所述初始的分类任务模型包括所述初始的特征提取器和初始的分类器;
采用所述第一数据集对所述初始的分类任务模型进行训练,得到初始训练后的特征提取器,所述初始训练后的特征提取器被用于所述生成对抗网络中。
13.一种计算机设备,其特征在于,所述计算机设备包括处理器和存储器,所述存储器中存储有至少一条指令、至少一段程序、代码集或指令集,所述至少一条指令、所述至少一段程序、所述代码集或指令集由所述处理器加载并执行以实现如权利要求1至6任一项所述的方法。
14.一种计算机可读存储介质,其特征在于,所述存储介质中存储有至少一条指令、至少一段程序、代码集或指令集,所述至少一条指令、所述至少一段程序、所述代码集或指令集由处理器加载并执行以实现如权利要求1至6任一项所述的方法。
Priority Applications (4)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201910377510.3A CN110097130B (zh) | 2019-05-07 | 2019-05-07 | 分类任务模型的训练方法、装置、设备及存储介质 |
PCT/CN2020/085006 WO2020224403A1 (zh) | 2019-05-07 | 2020-04-16 | 分类任务模型的训练方法、装置、设备及存储介质 |
EP20802264.0A EP3968222B1 (en) | 2019-05-07 | 2020-04-16 | Classification task model training method, apparatus and device and storage medium |
US17/355,310 US20210319258A1 (en) | 2019-05-07 | 2021-06-23 | Method and apparatus for training classification task model, device, and storage medium |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201910377510.3A CN110097130B (zh) | 2019-05-07 | 2019-05-07 | 分类任务模型的训练方法、装置、设备及存储介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN110097130A true CN110097130A (zh) | 2019-08-06 |
CN110097130B CN110097130B (zh) | 2022-12-13 |
Family
ID=67447198
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN201910377510.3A Active CN110097130B (zh) | 2019-05-07 | 2019-05-07 | 分类任务模型的训练方法、装置、设备及存储介质 |
Country Status (4)
Country | Link |
---|---|
US (1) | US20210319258A1 (zh) |
EP (1) | EP3968222B1 (zh) |
CN (1) | CN110097130B (zh) |
WO (1) | WO2020224403A1 (zh) |
Cited By (18)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110570492A (zh) * | 2019-09-11 | 2019-12-13 | 清华大学 | 神经网络训练方法和设备、图像处理方法和设备以及介质 |
CN110706738A (zh) * | 2019-10-30 | 2020-01-17 | 腾讯科技(深圳)有限公司 | 蛋白质的结构信息预测方法、装置、设备及存储介质 |
CN110732139A (zh) * | 2019-10-25 | 2020-01-31 | 腾讯科技(深圳)有限公司 | 检测模型的训练方法和用户数据的检测方法、装置 |
CN110888911A (zh) * | 2019-10-11 | 2020-03-17 | 平安科技(深圳)有限公司 | 样本数据处理方法、装置、计算机设备及存储介质 |
CN111126503A (zh) * | 2019-12-27 | 2020-05-08 | 北京同邦卓益科技有限公司 | 一种训练样本的生成方法和装置 |
CN111241969A (zh) * | 2020-01-06 | 2020-06-05 | 北京三快在线科技有限公司 | 目标检测方法、装置及相应模型训练方法、装置 |
CN111291841A (zh) * | 2020-05-13 | 2020-06-16 | 腾讯科技(深圳)有限公司 | 图像识别模型训练方法、装置、计算机设备和存储介质 |
CN111444967A (zh) * | 2020-03-30 | 2020-07-24 | 腾讯科技(深圳)有限公司 | 生成对抗网络的训练方法、生成方法、装置、设备及介质 |
CN111582647A (zh) * | 2020-04-09 | 2020-08-25 | 上海淇毓信息科技有限公司 | 用户数据处理方法、装置及电子设备 |
CN111832404A (zh) * | 2020-06-04 | 2020-10-27 | 中国科学院空天信息创新研究院 | 一种基于特征生成网络的小样本遥感地物分类方法及系统 |
WO2020224403A1 (zh) * | 2019-05-07 | 2020-11-12 | 腾讯科技(深圳)有限公司 | 分类任务模型的训练方法、装置、设备及存储介质 |
WO2021082786A1 (zh) * | 2019-10-30 | 2021-05-06 | 腾讯科技(深圳)有限公司 | 语义理解模型的训练方法、装置、电子设备及存储介质 |
CN113610191A (zh) * | 2021-09-07 | 2021-11-05 | 中原动力智能机器人有限公司 | 垃圾分类模型建模方法、垃圾分类方法及装置 |
CN113723519A (zh) * | 2021-08-31 | 2021-11-30 | 平安科技(深圳)有限公司 | 基于对比学习的心电数据处理方法、装置及存储介质 |
CN113869398A (zh) * | 2021-09-26 | 2021-12-31 | 平安科技(深圳)有限公司 | 一种不平衡文本分类方法、装置、设备及存储介质 |
WO2022042123A1 (zh) * | 2020-08-25 | 2022-03-03 | 深圳思谋信息科技有限公司 | 图像识别模型生成方法、装置、计算机设备和存储介质 |
WO2022135450A1 (zh) * | 2020-12-24 | 2022-06-30 | 华为技术有限公司 | 信息生成方法及相关装置 |
US11983363B1 (en) * | 2023-02-09 | 2024-05-14 | Primax Electronics Ltd. | User gesture behavior simulation system and user gesture behavior simulation method applied thereto |
Families Citing this family (14)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110598840B (zh) * | 2018-06-13 | 2023-04-18 | 富士通株式会社 | 知识迁移方法、信息处理设备以及存储介质 |
CN112466436B (zh) * | 2020-11-25 | 2024-02-23 | 北京小白世纪网络科技有限公司 | 基于循环神经网络的智能中医开方模型训练方法及装置 |
CN112463972B (zh) * | 2021-01-28 | 2021-05-18 | 成都数联铭品科技有限公司 | 一种基于类别不均衡的文本样本分类方法 |
CN112905325B (zh) * | 2021-02-10 | 2023-01-10 | 山东英信计算机技术有限公司 | 一种分布式数据缓存加速训练的方法、系统及介质 |
CN113642621B (zh) * | 2021-08-03 | 2024-06-28 | 南京邮电大学 | 基于生成对抗网络的零样本图像分类方法 |
CN114186617B (zh) * | 2021-11-23 | 2022-08-30 | 浙江大学 | 一种基于分布式深度学习的机械故障诊断方法 |
CN113902131B (zh) * | 2021-12-06 | 2022-03-08 | 中国科学院自动化研究所 | 抵抗联邦学习中歧视传播的节点模型的更新方法 |
CN114360008B (zh) * | 2021-12-23 | 2023-06-20 | 上海清鹤科技股份有限公司 | 人脸认证模型的生成方法、认证方法、设备及存储介质 |
CN114358282B (zh) * | 2022-01-05 | 2024-10-29 | 深圳大学 | 深度网络对抗鲁棒性提升模型、构建方法、设备、介质 |
CN114545255B (zh) * | 2022-01-18 | 2022-08-26 | 广东工业大学 | 基于竞争型生成式对抗神经网络的锂电池soc估计方法 |
CN114493808A (zh) * | 2022-01-28 | 2022-05-13 | 中山大学 | 联邦学习中基于反向拍卖的隐私保护激励机制训练方法 |
CN116934385B (zh) * | 2023-09-15 | 2024-01-19 | 山东理工昊明新能源有限公司 | 用户流失预测模型的构建方法、用户流失预测方法及装置 |
CN118262181B (zh) * | 2024-05-29 | 2024-08-13 | 山东鲁能控制工程有限公司 | 一种基于大数据的自动化数据处理系统 |
CN118447342B (zh) * | 2024-07-08 | 2024-10-22 | 杭州心智医联科技有限公司 | 一种基于层级信息传播的小样本图像分类方法及系统 |
Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN106650721A (zh) * | 2016-12-28 | 2017-05-10 | 吴晓军 | 一种基于卷积神经网络的工业字符识别方法 |
CN108805188A (zh) * | 2018-05-29 | 2018-11-13 | 徐州工程学院 | 一种基于特征重标定生成对抗网络的图像分类方法 |
JP2019028839A (ja) * | 2017-08-01 | 2019-02-21 | 国立研究開発法人情報通信研究機構 | 分類器、分類器の学習方法、分類器における分類方法 |
Family Cites Families (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN106097355A (zh) * | 2016-06-14 | 2016-11-09 | 山东大学 | 基于卷积神经网络的胃肠道肿瘤显微高光谱图像处理方法 |
US11120337B2 (en) * | 2017-10-20 | 2021-09-14 | Huawei Technologies Co., Ltd. | Self-training method and system for semi-supervised learning with generative adversarial networks |
CN108537743B (zh) * | 2018-03-13 | 2022-05-20 | 杭州电子科技大学 | 一种基于生成对抗网络的面部图像增强方法 |
CN108763874A (zh) * | 2018-05-25 | 2018-11-06 | 南京大学 | 一种基于生成对抗网络的染色体分类方法及装置 |
CN109165666A (zh) * | 2018-07-05 | 2019-01-08 | 南京旷云科技有限公司 | 多标签图像分类方法、装置、设备及存储介质 |
CN109522973A (zh) * | 2019-01-17 | 2019-03-26 | 云南大学 | 基于生成式对抗网络与半监督学习的医疗大数据分类方法及系统 |
CN110097130B (zh) * | 2019-05-07 | 2022-12-13 | 深圳市腾讯计算机系统有限公司 | 分类任务模型的训练方法、装置、设备及存储介质 |
-
2019
- 2019-05-07 CN CN201910377510.3A patent/CN110097130B/zh active Active
-
2020
- 2020-04-16 WO PCT/CN2020/085006 patent/WO2020224403A1/zh unknown
- 2020-04-16 EP EP20802264.0A patent/EP3968222B1/en active Active
-
2021
- 2021-06-23 US US17/355,310 patent/US20210319258A1/en active Pending
Patent Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN106650721A (zh) * | 2016-12-28 | 2017-05-10 | 吴晓军 | 一种基于卷积神经网络的工业字符识别方法 |
JP2019028839A (ja) * | 2017-08-01 | 2019-02-21 | 国立研究開発法人情報通信研究機構 | 分類器、分類器の学習方法、分類器における分類方法 |
CN108805188A (zh) * | 2018-05-29 | 2018-11-13 | 徐州工程学院 | 一种基于特征重标定生成对抗网络的图像分类方法 |
Cited By (27)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2020224403A1 (zh) * | 2019-05-07 | 2020-11-12 | 腾讯科技(深圳)有限公司 | 分类任务模型的训练方法、装置、设备及存储介质 |
CN110570492B (zh) * | 2019-09-11 | 2021-09-03 | 清华大学 | 一种基于神经网络的ct伪影抑制方法、设备以及介质 |
CN110570492A (zh) * | 2019-09-11 | 2019-12-13 | 清华大学 | 神经网络训练方法和设备、图像处理方法和设备以及介质 |
CN110888911A (zh) * | 2019-10-11 | 2020-03-17 | 平安科技(深圳)有限公司 | 样本数据处理方法、装置、计算机设备及存储介质 |
WO2021068563A1 (zh) * | 2019-10-11 | 2021-04-15 | 平安科技(深圳)有限公司 | 样本数据处理方法、装置、计算机设备及存储介质 |
CN110732139B (zh) * | 2019-10-25 | 2024-03-05 | 腾讯科技(深圳)有限公司 | 检测模型的训练方法和用户数据的检测方法、装置 |
CN110732139A (zh) * | 2019-10-25 | 2020-01-31 | 腾讯科技(深圳)有限公司 | 检测模型的训练方法和用户数据的检测方法、装置 |
US11967312B2 (en) | 2019-10-30 | 2024-04-23 | Tencent Technology (Shenzhen) Company Limited | Method and apparatus for training semantic understanding model, electronic device, and storage medium |
WO2021082786A1 (zh) * | 2019-10-30 | 2021-05-06 | 腾讯科技(深圳)有限公司 | 语义理解模型的训练方法、装置、电子设备及存储介质 |
CN110706738A (zh) * | 2019-10-30 | 2020-01-17 | 腾讯科技(深圳)有限公司 | 蛋白质的结构信息预测方法、装置、设备及存储介质 |
CN111126503B (zh) * | 2019-12-27 | 2023-09-26 | 北京同邦卓益科技有限公司 | 一种训练样本的生成方法和装置 |
CN111126503A (zh) * | 2019-12-27 | 2020-05-08 | 北京同邦卓益科技有限公司 | 一种训练样本的生成方法和装置 |
CN111241969A (zh) * | 2020-01-06 | 2020-06-05 | 北京三快在线科技有限公司 | 目标检测方法、装置及相应模型训练方法、装置 |
CN111444967A (zh) * | 2020-03-30 | 2020-07-24 | 腾讯科技(深圳)有限公司 | 生成对抗网络的训练方法、生成方法、装置、设备及介质 |
CN111444967B (zh) * | 2020-03-30 | 2023-10-31 | 腾讯科技(深圳)有限公司 | 生成对抗网络的训练方法、生成方法、装置、设备及介质 |
CN111582647A (zh) * | 2020-04-09 | 2020-08-25 | 上海淇毓信息科技有限公司 | 用户数据处理方法、装置及电子设备 |
CN111291841A (zh) * | 2020-05-13 | 2020-06-16 | 腾讯科技(深圳)有限公司 | 图像识别模型训练方法、装置、计算机设备和存储介质 |
CN111291841B (zh) * | 2020-05-13 | 2020-08-21 | 腾讯科技(深圳)有限公司 | 图像识别模型训练方法、装置、计算机设备和存储介质 |
CN111832404A (zh) * | 2020-06-04 | 2020-10-27 | 中国科学院空天信息创新研究院 | 一种基于特征生成网络的小样本遥感地物分类方法及系统 |
WO2022042123A1 (zh) * | 2020-08-25 | 2022-03-03 | 深圳思谋信息科技有限公司 | 图像识别模型生成方法、装置、计算机设备和存储介质 |
WO2022135450A1 (zh) * | 2020-12-24 | 2022-06-30 | 华为技术有限公司 | 信息生成方法及相关装置 |
CN113723519B (zh) * | 2021-08-31 | 2023-07-25 | 平安科技(深圳)有限公司 | 基于对比学习的心电数据处理方法、装置及存储介质 |
CN113723519A (zh) * | 2021-08-31 | 2021-11-30 | 平安科技(深圳)有限公司 | 基于对比学习的心电数据处理方法、装置及存储介质 |
CN113610191B (zh) * | 2021-09-07 | 2023-08-29 | 中原动力智能机器人有限公司 | 垃圾分类模型建模方法、垃圾分类方法 |
CN113610191A (zh) * | 2021-09-07 | 2021-11-05 | 中原动力智能机器人有限公司 | 垃圾分类模型建模方法、垃圾分类方法及装置 |
CN113869398A (zh) * | 2021-09-26 | 2021-12-31 | 平安科技(深圳)有限公司 | 一种不平衡文本分类方法、装置、设备及存储介质 |
US11983363B1 (en) * | 2023-02-09 | 2024-05-14 | Primax Electronics Ltd. | User gesture behavior simulation system and user gesture behavior simulation method applied thereto |
Also Published As
Publication number | Publication date |
---|---|
CN110097130B (zh) | 2022-12-13 |
EP3968222A1 (en) | 2022-03-16 |
EP3968222A4 (en) | 2022-06-29 |
WO2020224403A1 (zh) | 2020-11-12 |
EP3968222B1 (en) | 2024-01-17 |
US20210319258A1 (en) | 2021-10-14 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110097130A (zh) | 分类任务模型的训练方法、装置、设备及存储介质 | |
Han et al. | Two-stage learning to predict human eye fixations via SDAEs | |
CN107492099B (zh) | 医学图像分析方法、医学图像分析系统以及存储介质 | |
CN109685819B (zh) | 一种基于特征增强的三维医学图像分割方法 | |
US10452899B2 (en) | Unsupervised deep representation learning for fine-grained body part recognition | |
Pan et al. | Shallow and deep convolutional networks for saliency prediction | |
Oyelade et al. | CovFrameNet: An enhanced deep learning framework for COVID-19 detection | |
US9111375B2 (en) | Evaluation of three-dimensional scenes using two-dimensional representations | |
CN113728335A (zh) | 用于3d图像的分类和可视化的方法和系统 | |
Rahman et al. | Hand gesture recognition using multiclass support vector machine | |
CN107220971A (zh) | 一种基于卷积神经网络和主成分分析法的肺结节特征提取方法 | |
CN109785399A (zh) | 合成病变图像的生成方法、装置、设备及可读存储介质 | |
Ogiela et al. | Natural user interfaces in medical image analysis | |
Mutepfe et al. | Generative adversarial network image synthesis method for skin lesion generation and classification | |
CN113011340B (zh) | 一种基于视网膜图像的心血管手术指标风险分类方法及系统 | |
Oliveira et al. | A comparison between end-to-end approaches and feature extraction based approaches for sign language recognition | |
Avraam | Static gesture recognition combining graph and appearance features | |
Kryvonos et al. | Information technology for the analysis of mimic expressions of human emotional states | |
Williams et al. | Fast blur detection and parametric deconvolution of retinal fundus images | |
CN110929731A (zh) | 一种基于探路者智能搜索算法的医疗影像处理方法及装置 | |
Ubale Kiru et al. | Comparative analysis of some selected generative adversarial network models for image augmentation: a case study of COVID-19 x-ray and CT images | |
Hossain et al. | Recognition of tuberculosis on medical X-ray images utilizing MobileNet transfer learning | |
Iqbal et al. | Implementation of the introduction of skin diseases based on augmented reality | |
Farouk | Principal component pyramids using image blurring for nonlinearity reduction in hand shape recognition | |
Santos et al. | Detection of Fundus Lesions through a Convolutional Neural Network in Patients with Diabetic Retinopathy |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |