CN111461329B - 一种模型的训练方法、装置、设备及可读存储介质 - Google Patents
一种模型的训练方法、装置、设备及可读存储介质 Download PDFInfo
- Publication number
- CN111461329B CN111461329B CN202010269451.0A CN202010269451A CN111461329B CN 111461329 B CN111461329 B CN 111461329B CN 202010269451 A CN202010269451 A CN 202010269451A CN 111461329 B CN111461329 B CN 111461329B
- Authority
- CN
- China
- Prior art keywords
- sample data
- model
- test
- prediction
- data
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000012549 training Methods 0.000 title claims abstract description 102
- 238000000034 method Methods 0.000 title claims abstract description 97
- 238000003860 storage Methods 0.000 title claims description 9
- 238000012360 testing method Methods 0.000 claims abstract description 155
- 230000008569 process Effects 0.000 claims abstract description 25
- 238000009826 distribution Methods 0.000 claims abstract description 18
- 230000006870 function Effects 0.000 claims description 79
- 238000004364 calculation method Methods 0.000 claims description 13
- 238000004590 computer program Methods 0.000 claims description 6
- 238000004140 cleaning Methods 0.000 abstract description 6
- 238000001514 detection method Methods 0.000 abstract description 4
- 230000000694 effects Effects 0.000 abstract description 4
- 238000004891 communication Methods 0.000 description 7
- 238000010801 machine learning Methods 0.000 description 7
- 238000010586 diagram Methods 0.000 description 6
- 230000009471 action Effects 0.000 description 3
- 230000007246 mechanism Effects 0.000 description 3
- 238000005457 optimization Methods 0.000 description 3
- 238000012545 processing Methods 0.000 description 2
- 238000005070 sampling Methods 0.000 description 2
- 230000003247 decreasing effect Effects 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000005516 engineering process Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000003062 neural network model Methods 0.000 description 1
- 230000000750 progressive effect Effects 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F16/00—Information retrieval; Database structures therefor; File system structures therefor
- G06F16/20—Information retrieval; Database structures therefor; File system structures therefor of structured data, e.g. relational data
- G06F16/21—Design, administration or maintenance of databases
- G06F16/215—Improving data quality; Data cleansing, e.g. de-duplication, removing invalid entries or correcting typographical errors
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/10—Pre-processing; Data cleansing
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Artificial Intelligence (AREA)
- Evolutionary Computation (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Life Sciences & Earth Sciences (AREA)
- Software Systems (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Databases & Information Systems (AREA)
- Computing Systems (AREA)
- Evolutionary Biology (AREA)
- Mathematical Physics (AREA)
- Biophysics (AREA)
- Quality & Reliability (AREA)
- Molecular Biology (AREA)
- General Health & Medical Sciences (AREA)
- Medical Informatics (AREA)
- Computational Linguistics (AREA)
- Biomedical Technology (AREA)
- Health & Medical Sciences (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本申请提供了一种模型的训练方法,将损失函数值小于或等于第一预设阈值的预设的目标模型作为待测试模型,获取待测试模型的测试结果,当测试结果满足预设的测试条件时,将待测试模型作为预测模型,或者,当预测结果不满足测试条件时,依据测试结果,更新样本数据的影响因子。本申请实施例提供的模型训练方法可以自动控制样本数据对模型训练的影响,一方面可以避免模型在训练过程中由于样本数据分布问题引起的模型训练效果不佳,最终导致得到的预测模型的预测准确度低。另一方面,相对于现有技术中,人工清洗数据的方法,避免了漏检的情况,节省了大量的人工成本和时间。
Description
技术领域
本申请涉及机器学习技术领域,更具体地说,涉及一种模型的训练方法、装置、设备及可读存储介质。
背景技术
模型训练所使用的样本数据往往参杂有脏数据或缺省数据,并且可能存在样本数据不均衡。现有的模型训练方法将样本数据直接用于机器学习任务,往往导致模型训练结果不准确,从而使训练得到的预测模型在实际预测过程中,预测结果与实际出现较大偏差。
发明内容
有鉴于此,本申请提供了一种模型的训练方法、装置、设备及可读存储介质,如下:
一种模型的训练方法,包括:
将损失函数值小于或等于第一预设阈值的预设的目标模型作为待测试模型;其中,所述损失函数值由样本数据的真实值以及所述样本数据的预设的影响因子,计算得到,所述影响因子用于表示所述样本数据的预测误差的权重值,任一所述样本数据的预测误差表示所述目标模型输出的所述样本数据的预测结果与所述样本数据的真实值之间的偏差;
获取所述待测试模型的测试结果;
当所述测试结果满足预设的测试条件时,将所述待测试模型作为预测模型;或者,当所述预测结果不满足所述测试条件时,依据所述测试结果,更新所述样本数据的所述影响因子。
可选地,在将小于或等于第一预设阈值的预设的目标模型作为待测试模型之前,还包括:
获取所述样本数据以及每一所述样本数据的所述影响因子;
将所述样本数据输入至所述目标模型;
获取所述目标模型输出的所述样本数据的预测结果;
依据所述样本数据的真实值、所述样本数据的预测结果以及所述样本数据的所述影响因子,计算所述损失函数值。
可选地,依据所述样本数据的真实值、所述样本数据的预测结果以及所述样本数据的所述影响因子,计算所述损失函数值,包括:
计算每一所述样本数据的真实值与所述样本数据的预测结果之间的偏差,作为所述样本数据的预测误差;
将每一所述样本数据的所述影响因子与所述样本数据的预测误差相乘,得到所述样本数据的预测损失值;
依据每一所述样本数据的预测损失值与预设的正则函数,计算得到所述目标模型的损失函数值。
可选地,获取所述待测试模型的测试结果,包括:
获取测试数据;
将所述测试数据输入至所述待测试模型,得到每一所述测试数据的预测结果;
依据每一所述测试数据的预测结果,计算所述待测试模型的测试结果,所述测试结果至少包括查准率和查全率。
可选地,预设的测试条件至少包括第一测试条件和第二测试条件,所述第一测试条件为,所述查全率大于第二预设阈值,所述第二测试条件为,所述查准率大于第三预设阈值。
可选地,依据所述测试结果,更新所述样本数据的所述影响因子,至少包括:
增大第一类样本数据的影响因子,其中,所述第一类样本数据的真实值与第一类测试数据的真实值相同,所述第一类测试数据的预测准确率小于第四预设阈值,所述预测准确率为所述第一类样本数据的预测结果与所述第一类样本数据的真实值相同的第一类样本数据的数量,与全部第一类测试数据的数量的比值。
一种模型的训练装置,包括:
模型获取单元,用于将损失函数值小于或等于第一预设阈值的预设的目标模型作为待测试模型;其中,所述损失函数值由样本数据的真实值以及所述样本数据的预设的影响因子,计算得到,所述影响因子用于表示所述样本数据的预测误差的权重值,任一所述样本数据的预测误差表示所述目标模型输出的所述样本数据的预测结果与所述样本数据的真实值之间的偏差;
测试结果获取单元,用于获取所述待测试模型的测试结果;
结果判定单元,用于当所述测试结果满足预设的测试条件时,将所述待测试模型作为预测模型;或者,当所述预测结果不满足所述测试条件时,依据所述测试结果,更新所述样本数据的所述影响因子。
可选地,模型的训练装置还包括:损失函数值计算单元,用于在将小于或等于第一预设阈值的预设的目标模型作为待测试模型之前,获取所述损失函数值;所述损失函数值计算单元具体用于:
获取所述样本数据以及每一所述样本数据的所述影响因子;
将所述样本数据输入至所述目标模型;
获取所述目标模型输出的所述样本数据的预测结果;
依据所述样本数据的真实值、所述样本数据的预测结果以及所述样本数据的所述影响因子,计算所述损失函数值。
一种模型的训练设备,包括:存储器和处理器;
所述存储器,用于存储程序;
所述处理器,用于执行所述程序,实现如上所述的模型的训练方法的各个步骤。
一种可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时,实现如上所述的模型的训练方法的各个步骤。
由上述技术方案可以看出,本申请实施例提供的模型训练方法,对经过样本数据训练的模型进行测试,并依据测试结果重置样本数据的影响因子。并且,本实施例提供的模型训练方法在原始损失函数机制基础上增加了影响因子,通过样本数据的影响因子来对损失函数中每一样本数据的预测误差进行加权相加,自动调整样本数据在训练过程中的分布,以此来控制每一个样本对训练过程的影响程度,决定最终模型的优化方向。因此,本方法得到的预测模型既满足训练条件(损失函数值小于或等于第一预设阈值),也满足预设的测试条件。显然,本申请实施例提供的模型训练方法可以自动控制样本数据对模型训练的影响,一方面可以避免模型在训练过程中由于样本数据分布问题(例如样本数据分布不均衡、数值缺省以及取值不合理)引起的模型训练效果不佳,最终导致得到的预测模型的预测准确度低。另一方面,相对于现有技术中,人工清洗数据的方法,避免了漏检的情况,节省了大量的人工成本和时间。
附图说明
为了更清楚地说明本申请实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据提供的附图获得其他的附图。
图1为本申请实施例提供的模型的训练方法的一种具体实施方法的流程示意图;
图2为本申请实施例提供的一种模型的训练系统的结构示意图;
图3为本申请实施例提供的一种模型的训练方法的流程示意图;
图4为本申请实施例提供的一种模型的训练装置的结构示意图;
图5为本申请实施例提供的一种模型的训练设备的结构示意图。
具体实施方式
下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。
本申请实施例提供的模型的训练方法可以应用于对任意类型的机器学习模型进行训练的过程。例如,以二分类器为例,二分类器的输入数据为待分类数据,输出为输入数据所属的类别,一般地,可以为1或0。所以,训练任一二分类器的方法为将大量的已知类别的样本数据输入至待训练的模型,每一输入的样本数据对应的模型的目标输出即为样本数据的标签。但是,一般情况下,样本数据分布问题至少包括样本数据不均衡、样本数据存在特征值缺省、以及样本数据取值不合理。在实际的训练过程中,样本数据不均衡会导致训练好的二分类器的预测结果出现偏差,例如,二分类器更倾向于将输入数据预测成样本数据数量占多数的类别。并且,样本数据存在特征值缺省以及样本数据取值不合理,也会导致模型预测结果不合理。本申请实施例提出一种自动调节样本数据的分布的训练方法,图1为本申请实施例提供的模型的训练方法的一种具体实施方法的流程示意图,具体可以包括:
S101、获取样本数据。
本申请实施例中,可以获取样本数据集合,样本数据集合中包括多条带标签的样本数据,其中,样本数据的标签为该样本数据的真实值。需要说明的是,本实施例中获取的样本数据集合中不包括重复的样本数据,也即本实施例获取的样本数据集合为最小样本数据集合。
为描述方便,本实施例将包含有n条样本数据的样本数据集合(记为X)中任一样本数据记为xi,xi的标签为yi,其中1≤i≤n。
S102、将样本数据输入至目标模型进行预测。
需要说明的是,目标模型可以为任意一种类型的机器学习模型,例如,线性模型或神经网络模型。
S103、获取目标模型的预测结果。
可以理解的是,目标模型可以对输入的样本数据通过预测函数进行预测,得到每一样本数据的预测结果。需要说明的是,机器学习模型的结构不同,则预测函数不同,并且,预测函数中包括大量的模型参数。
本实施例中,目标模型的预测函数记为f,f中包括m个模型参数,任一模型参数记为ωj,1≤j≤m。
将样本数据xi输入至该目标模型,目标模型可以依据预测函数f对样本数据进行预测,并输出预测结果f(xi)。需要说明的是,目标模型的结构不同,则输出的预测结果f(xi)的类型不同。以二分类器为例,针对任一样本数据xi,二分类器可以输出该样本数据xi的预测值(1或0),也可以输出该样本数据xi为1的概率值。
本步骤可以获取所有样本数据的预测结果,即获取每一样本数据的预测值。
S104、针对每一样本数据,计算得到该样本数据的预测误差。
具体地,任一样本数据的预测误差可以表征该样本数据的预测值与真实值之间的偏差程度。本实施例中,任一样本数据的预测误差可以为样本数据的预测值与真实值的均方误差,或者预测误差也可以为样本数据的预测值与真实值的交叉熵。
本实施例中,均方误差的计算方法可以参照下述公式(1):
li=(yi-f(xi))2 (1)
其中,yi为样本数据xi的真实值,f(xi)为样本数据xi的预测值,li为样本数据xi的预测误差。通过公式(1)可以看出,本实施例可以通过每个样本数据的预测值与真实值的平方差来作为该样本数据的预测值与真实值之间的偏差程度的参照。
本实施例中,以二分类器为例,交叉熵的计算方法可以参照下述公式(2):
其中,yi为样本数据xi的真实值,f(xi)为样本数据xi的预测值为1的概率值,li为样本数据xi的预测误差。
需要说明的是,除上述介绍的样本数据的预测值与真实值的均方误差,或者样本数据的预测值与真实值的交叉熵,预测误差还可以为其他任意可以度量预测值与真实值之间的偏差程度的值,本实施例对此不做限定。
S105、依据预设的影响因子,计算每一样本数据的预测损失值。
其中,影响因子为预先设置的每个样本数据的权重值,需要说明的是,初始的影响因子需要数据分析师通过对样本数据进行观察来确定。本实施例中,对影响因子进行归一化处理,也即任一样本数据xi对应的影响因子αi取值为[0,1]。
依据预设的影响因子,计算每一样本数据的预测损失值的方法为,将样本数据的预测误差与样本数据的影响因子相乘。
S106、依据每一样本数据的预测损失值,计算目标模型的损失函数值。
本实施例中,目标模型的损失函数可以参照下述公式(3)。
其中,L(Y,f(X))为目标模型的损失函数,αili为任一样本数据的预测损失值,为目标模型的正则项。需要说明的是,γ为预设的正则参数,/>为预设的正则函数,ωj为模型参数。
需要说明的是,正则项可以对目标模型的模型参数进行约束,简化模型参数,进而防止目标模型的训练过程中出现过拟合现象。正则函数可以包括多种,其中,可选的一种正则函数可以为L1范数,即将目标模型中的所有模型参数的绝对值相加。可选的另一种正则函数可以为L1范数,即计算目标模型中的所有模型参数的平方和的平方根。
进一步需要说明的是,计算正则项的具体实施方式可以参照现有技术,本实施例不做赘述,并且对于具体计算方法不做限定。
S107、判断损失函数值是否大于第一预设阈值,若是,则执行S108,若否,则执行S109。
S108、当损失函数值大于第一预设阈值,更新模型参数,得到更新后的目标模型。进一步,返回S102,将样本数据输入至更新后的目标模型进行预测,重复上述S102~S107的模型训练过程。需要说明的是,执行一次S102~S107的训练过程,输入至目标模型中的训练数据的个数可以预设,例如,可以将样本数据集合中的所有样本数据输入至目标模型进行训练,也可以将样本数据集合中的样本数据的按照预设比例划分为训练数据和测试数据,将训练数据输入至目标模型进行训练。
本实施例中,更新模型参数的方法为:分别以每一模型参数为变量,对损失函数进行求导计算,求取更新模型参数的变化量,然后,根据每一模型参数的变化量,更新每一模型参数,得到更新后的目标模型。具体地,更新模型参数的方法可以参照现有技术,本申请实施例不做赘述。
S109、当损失函数值小于或等于第一预设阈值,将目标模型作为待测试模型,并将测试数据输入至待测试模型。
其中,测试数据可以由样本数据集合中的样本数据按比例划分得到,也可以获取新的带标签的样本数据作为测试数据,用于测试模型。
S110、获取待测试模型输出的每一测试数据的预测结果。
可以理解的是,待测试模型利用预测函数计算输入的每一测试数据的预测结果,并输出预测结果f(qr),其中qr为任一测试数据。
S111、依据所有测试数据的真实值与预测结果,获取测试结果。
本实施例中,测试结果可以包括多种,本实施例中,测试结果可以包括测试数据的查全率和查准率。
S112、判断测试结果是否满足预设的测试条件,若是则执行S113,若否则执行S114。
需要说明的是,预设的测试条件可以为查全率大于第二预设阈值,并且查准率大于第三预设阈值。
S113、测试结果满足预设的测试条件,将待测试模型确定为预测模型。
可以理解的是,预测模型经过训练和测试,可以用于数据的预测,并且预测结果具有较高的准确性。
S114、测试结果不满足预设的测试条件,依据测试数据的预测结果与真实值之间的预测误差,更新样本数据的影响因子。
具体地,将预测准确率大于第四预设阈值的测试数据记为第一类测试数据,增大与第一类测试数据同类别(真实值相同)的样本数据的影响因子,或减小与第一类测试数据不同类别(真实值不相同)的样本数据的影响因子。预测准确率指的是,在第一类测试数据中,预测错误的测试数据的数量与第一类测试数据的总数量的比值。
例如,待测试模型为二分类器,对于100条真实值为1的测试数据,二分类器输出的预测值的查准率为95%,对于100条真实值为0的测试数据,二分类器输出的预测值的查准率为80%,则本实施例可以增大样本数据中真实值为0的影响因子,并减小样本数据中真实值为1的影响因子。
进一步,返回S102,将样本数据输入至目标模型进行预测。
由上述技术方案可以看出,本实施例提供的模型训练方法中的损失函数(可参照上述公式3),对比现有技术的损失函数(参照下述公式4):
在原始损失函数机制基础上增加了影响因子,通过样本数据的影响因子来对损失函数中每一样本数据的预测误差进行加权相加,自动调整样本数据在训练过程中的分布,以此来控制每一个样本对训练过程的影响程度,决定最终模型的优化方向。并且,本方法中对训练好的模型进行测试,并依据测试结果重置样本数据的影响因子,直至目标模型满足训练条件,即损失函数值不大于第一预设阈值,并且满足预设的测试条件。显然,本申请实施例提供的模型训练方法可以自动控制样本数据对模型训练的影响,一方面可以避免模型在训练过程中由于样本数据分布问题(例如样本数据分布不均衡、数值缺省以及取值不合理)引起的模型训练效果不佳,最终导致得到的预测模型的预测准确度低。另一方面,相对于现有技术中,人工清洗数据的方法,避免了漏检的情况,节省了大量的人工成本和时间。
例如,现有的模型训练方法中,对于样本数据分布问题的处理方法之一为:由数据分析师通过观察样本数据,对缺省数据以及脏数据手动进行数据清洗。显然,大批量的数据清洗工作量会浪费大量的人工成本,导致数据处理效率低,并且受限于人工的水平,数据清洗的准确度往往较低。可见,本方法是通过设置每一样本数据的影响因子,调整样本数据对于训练的重要性,例如,当样本数据为脏数据,可以设置其影响因子为0,则自动排除该数据对训练的干扰。
再例如,现有技术中处理样本数据分布问题的另一种方法为:对样本数据进行over-sampling(上采样)或under-sampling(下采样),即增加数量过少的样本数据的数量,减少数量过多的样本数据的数量。但是,这种方法的工作效率很低,例如,需要提前统计各种样本数据的数量,并且,上采样后样本数据中带有重复的样本数据,增加了资源存储的负担。可见,本方法是通过设置每一样本数据的影响因子,可以将样本数据的数据量体现在样本数据的影响因子上,因此,本方法中使用的样本数据集合为最小样本数据,无需使用大量的重复数据,提高了训练的效率,并且对于大数据来说,节省了大量的存储资源。
需要说明的是,本申请提供的模型的训练方法可以应用于模型的训练系统,图2示例了一种模型的训练系统的结构示意图。如图所示,具体可以包括:
样本数据获取单元201,用于获取样本数据。
模型预测单元202,用于获取预测结果。
误差计算单元203,用于计算预测误差,其中包括多个误差计算器,每一误差计算器可以计算一条样本数据的预测误差。
影响度控制单元204,用于计算预测损失,其中包括多个影响度控制门,每一影响度控制门可以依据一条样本数据的预测误差以及预设的影响因子,计算该条样本数据的预测损失值。
损失函数计算单元205,用于计算损失函数值。
第一判断单元206,用于判断损失函数值与第一预设阈值的大小。
模型更新单元207,用于更新模型参数。
第一模型生成单元208,用于生成待测试模型。
模型测试单元209,用于获取测试结果。
第二判断单元210,用于判断测试结果是否满足预设的测试条件。
影响因子调整单元211,用于调整每一样本数据的影响因子。
第二模型生成单元212,用于生成最终的预测模型。
需要说明的是,上述各个单元可以单独设置于一个模块,也可以多个单元设置于同一模块,执行相应的功能。具体的执行过程,可以与上述模型的训练方法相互参照,本实施例不做赘述。
由上述模型的训练系统可以看出,本申请实施例相对于现有的模型训练系统,增加影响度控制单元、模型测试单元、第二判断单元、影响因子调整单元、以及第二模型生成单元。其中,影响度控制单元将每一个样本数据的预测误差,乘以该样本数据的影响因子对样本分布进行调整,可以理解的是,影响因子越大,样本数据在模型训练过程中起到的作用就越大。因此,影响度控制单元可以控制每一样本数据对训练过程中模型的影响程度。又因为,影响因子调整单元可以获取调整后的影响因子,所以,本系统可以依据测试结果随时获取调整后的影响因子,保证预测模型的准确度。
综上,本申请实施例提供的模型训练方法通过设置并调整每一样本数据的影响因子,改善由样本数据分布问题造成的预测模型不准确的缺点。具体地,图3为本申请实施例提供的模型的训练方法的流程示意图,如图3所示,本实施例将模型的训练方法总结如下述S301~S303。
S301、将损失函数值小于或等于第一预设阈值的预设的目标模型作为待测试模型。
其中,预设的目标模型为待训练的机器学习模型,机器学习模型可以为任一类型的模型。
损失函数值由样本数据的真实值以及样本数据的预设的影响因子,计算得到。具体地,计算损失函数值的方法包括多种,其中可选的一种方法可以包括A1~A4。
A1、将样本数据输入至目标模型,并获取目标模型输出的样本数据的预测结果。
A2、计算每一样本数据的真实值与样本数据的预测结果之间的偏差,作为样本数据的预测误差,可选地,样本数据的预测误差可以为均方误差或者交叉熵,可以理解的是,预测误差可以表征目标模型的预测准确度。
A3、将每一样本数据的影响因子与样本数据的预测误差相乘,得到样本数据的预测损失值。
A4、依据每一样本数据的预测损失值与预设的正则函数,计算得到目标模型的损失函数值。其中,正则函数用于防止目标模型的训练过程中出现过拟合现象。
需要说明的是,损失函数值的具体计算方法可以参照上述S101~S106。可以理解的是,预测损失值越大,表示该目标模型的预测准确度越低,预测损失值越小,表示该目标模型的预测准确度越高。因此,本实施例中,当损失函数值小于或等于第一预设阈值,将目标模型作为待测试模型进行进一步的测试。
进一步需要说明的是,本实施例中,当损失函数值大于第一预设阈值,更新目标模型中模型参数,重新计算损失函数值。
S302、获取待测试模型的测试结果。
具体地,首先,将测试数据输入至待测试模型,得到每一测试数据的预测结果,其中,测试数据为已知真实值的数据,可以从样本数据中选取。
进一步,依据每一测试数据的预测结果,计算待测试模型的测试结果。本实施例中,测试结果至少包括查准率和查全率。其中,查准率为预测结果等于真实值的测试数据的数量,与测试数据的总数量的比值。查全率为预测结果不为空的测试数据的数量,与测试数据的总数量的比值。需要说明的是,预测结果为空指的是待测试模型没有输出该测试数据的预测结果,也即测试失败。
S303、当测试结果满足预设的测试条件时,将待测试模型作为预测模型。或者,当预测结果不满足测试条件时,更新样本数据的影响因子。
具体地,待测试模型的查准率和查全率均能表征待测试模型的测试准确度,所以预设的测试条件至少包括第一测试条件和第二测试条件,第一测试条件为,查全率大于第二预设阈值。第二测试条件为,查准率大于第三预设阈值。
本实施例中,将查全率大于第二预设阈值并且查准率大于第三预设阈值的待测试模型作为预测模型,即该待测试模型为预测准确率较高的模型。
本实施例中,当待测试模型的查全率不大于第二预设阈值和/或查准率不大于第三预设阈值时,将待测试模型重新作为目标模型,并更新样本数据的影响因子重新对该目标模型进行训练。
需要说明的是,更新样本数据的影响因子的方法为依据测试结果,更新影响因子,具体的更新方法可以参照上述S114,本实施例不做赘述。
由上述技术方案可以看出,本申请实施例提供的模型训练方法,本方法中对经过样本数据训练的模型进行测试,并依据测试结果重置样本数据的影响因子。其中,本实施例提供的模型训练方法在原始损失函数机制基础上增加了影响因子,通过样本数据的影响因子来对损失函数中每一样本数据的预测误差进行加权相加,自动调整样本数据在训练过程中的分布,以此来控制每一个样本对训练过程的影响程度,决定最终模型的优化方向。因此,本方法得到的预测模型既满足训练条件(损失函数值小于或等于第一预设阈值),也满足预设的测试条件。显然,本申请实施例提供的模型训练方法可以自动控制样本数据对模型训练的影响,一方面可以避免模型在训练过程中由于样本数据分布问题(例如样本数据分布不均衡、数值缺省以及取值不合理)引起的模型训练效果不佳,最终导致得到的预测模型的预测准确度低。另一方面,相对于现有技术中,人工清洗数据的方法,避免了漏检的情况,节省了大量的人工成本和时间。
本申请实施例还提供了一种模型的训练装置,下面对本申请实施例提供的模型的训练装置进行描述,下文描述的模型的训练装置与上文描述的模型的训练方法可相互对应参照。
请参阅图4,示出了本申请实施例提供的一种模型的训练装置的结构示意图,如图4所示,该装置可以包括:
模型获取单元401,用于将损失函数值小于或等于第一预设阈值的预设的目标模型作为待测试模型;其中,所述损失函数值由样本数据的真实值以及所述样本数据的预设的影响因子,计算得到,所述影响因子用于表示所述样本数据的预测误差的权重值,任一所述样本数据的预测误差表示所述目标模型输出的所述样本数据的预测结果与所述样本数据的真实值之间的偏差;
测试结果获取单元402,用于获取所述待测试模型的测试结果;
结果判定单元403,用于当所述测试结果满足预设的测试条件时,将所述待测试模型作为预测模型;或者,当所述预测结果不满足所述测试条件时,依据所述测试结果,更新所述样本数据的所述影响因子。
可选地,该装置还包括:损失函数值计算单元,用于在将小于或等于第一预设阈值的预设的目标模型作为待测试模型之前,获取所述损失函数值;所述损失函数值计算单元具体用于:
获取所述样本数据以及每一所述样本数据的所述影响因子;
将所述样本数据输入至所述目标模型;
获取所述目标模型输出的所述样本数据的预测结果;
依据所述样本数据的真实值、所述样本数据的预测结果以及所述样本数据的所述影响因子,计算所述损失函数值。
可选地,损失函数值计算单元用于依据所述样本数据的真实值、所述样本数据的预测结果以及所述样本数据的所述影响因子,计算所述损失函数值,包括:损失函数值计算单元具体用于:
计算每一所述样本数据的真实值与所述样本数据的预测结果之间的偏差,作为所述样本数据的预测误差;
将每一所述样本数据的所述影响因子与所述样本数据的预测误差相乘,得到所述样本数据的预测损失值;
依据每一所述样本数据的预测损失值与预设的正则函数,计算得到所述目标模型的损失函数值。
可选地,测试结果获取单元用于获取所述待测试模型的测试结果,包括:测试结果获取单元具体用于:
获取测试数据;
将所述测试数据输入至所述待测试模型,得到每一所述测试数据的预测结果;
依据每一所述测试数据的预测结果,计算所述待测试模型的测试结果,所述测试结果至少包括查准率和查全率。
可选地,所述预设的测试条件至少包括第一测试条件和第二测试条件,所述第一测试条件为,所述查全率大于第二预设阈值,所述第二测试条件为,所述查准率大于第三预设阈值。
可选地,结果判定单元用于依据所述测试结果,更新所述样本数据的所述影响因子,至少包括:结果判定单元具体用于:
增大第一类样本数据的影响因子,其中,所述第一类样本数据的真实值与第一类测试数据的真实值相同,所述第一类测试数据的预测准确率小于第四预设阈值,所述预测准确率为所述第一类样本数据的预测结果与所述第一类样本数据的真实值相同的第一类样本数据的数量,与全部第一类测试数据的数量的比值。
本申请实施例还提供了一种模型的训练设备,请参阅图5,示出了该模型的训练设备的结构示意图,该设备可以包括:至少一个处理器501,至少一个通信接口502,至少一个存储器503和至少一个通信总线504;
在本申请实施例中,处理器501、通信接口502、存储器503、通信总线504的数量为至少一个,且处理器501、通信接口502、存储器503通过通信总线504完成相互间的通信;
处理器501可能是一个中央处理器CPU,或者是特定集成电路ASIC(ApplicationSpecific Integrated Circuit),或者是被配置成实施本发明实施例的一个或多个集成电路等;
存储器503可能包含高速RAM存储器,也可能还包括非易失性存储器(non-volatile memory)等,例如至少一个磁盘存储器;
其中,存储器存储有程序,处理器可执行存储器存储的程序,实现上述的模型的训练方法的各个步骤。
本申请实施例还提供一种可读存储介质,该可读存储介质可存储有适于处理器执行的计算机程序,计算机程序被处理器执行时,实现上述的模型的训练方法的各个步骤。
最后,还需要说明的是,在本文中,诸如第一和第二等之类的关系术语仅仅用来将一个实体或者操作与另一个实体或操作区分开来,而不一定要求或者暗示这些实体或操作之间存在任何这种实际的关系或者顺序。而且,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者设备不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者设备所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括所述要素的过程、方法、物品或者设备中还存在另外的相同要素。
本说明书中各个实施例采用递进的方式描述,每个实施例重点说明的都是与其他实施例的不同之处,各个实施例之间相同相似部分互相参见即可。
对所公开的实施例的上述说明,使本领域专业技术人员能够实现或使用本申请。对这些实施例的多种修改对本领域的专业技术人员来说将是显而易见的,本文中所定义的一般原理可以在不脱离本申请的精神或范围的情况下,在其它实施例中实现。因此,本申请将不会被限制于本文所示的这些实施例,而是要符合与本文所公开的原理和新颖特点相一致的最宽的范围。
Claims (9)
1.一种模型的训练方法,其特征在于,包括:
将损失函数值小于或等于第一预设阈值的预设的目标模型作为待测试模型;其中,所述损失函数值由样本数据的真实值以及所述样本数据的预设的影响因子,计算得到,所述影响因子用于表示所述样本数据的预测误差的权重值,任一所述样本数据的预测误差表示所述目标模型输出的所述样本数据的预测结果与所述样本数据的真实值之间的偏差,所述影响因子对损失函数中每一样本数据的预测误差进行加权相加,自动调整样本数据在训练过程中的分布;
获取所述待测试模型的测试结果;
当所述测试结果满足预设的测试条件时,将所述待测试模型作为预测模型;或者,当所述预测结果不满足所述测试条件时,依据所述测试结果,更新所述样本数据的所述影响因子;
其中,所述依据所述测试结果,更新所述样本数据的所述影响因子,至少包括:
增大第一类样本数据的影响因子,其中,所述第一类样本数据的真实值与第一类测试数据的真实值相同,所述第一类测试数据的预测准确率小于第四预设阈值,所述预测准确率为所述第一类样本数据的预测结果与所述第一类样本数据的真实值相同的第一类样本数据的数量,与全部第一类测试数据的数量的比值;当样本数据为脏数据,设置其影响因子为0,以自动排除该数据对训练的干扰。
2.根据权利要求1所述的模型的训练方法,其特征在于,在将小于或等于第一预设阈值的预设的目标模型作为待测试模型之前,还包括:
获取所述样本数据以及每一所述样本数据的所述影响因子;
将所述样本数据输入至所述目标模型;
获取所述目标模型输出的所述样本数据的预测结果;
依据所述样本数据的真实值、所述样本数据的预测结果以及所述样本数据的所述影响因子,计算所述损失函数值。
3.根据权利要求2所述的模型的训练方法,其特征在于,所述依据所述样本数据的真实值、所述样本数据的预测结果以及所述样本数据的所述影响因子,计算所述损失函数值,包括:
计算每一所述样本数据的真实值与所述样本数据的预测结果之间的偏差,作为所述样本数据的预测误差;
将每一所述样本数据的所述影响因子与所述样本数据的预测误差相乘,得到所述样本数据的预测损失值;
依据每一所述样本数据的预测损失值与预设的正则函数,计算得到所述目标模型的损失函数值。
4.根据权利要求1所述的模型的训练方法,其特征在于,所述获取所述待测试模型的测试结果,包括:
获取测试数据;
将所述测试数据输入至所述待测试模型,得到每一所述测试数据的预测结果;
依据每一所述测试数据的预测结果,计算所述待测试模型的测试结果,所述测试结果至少包括查准率和查全率。
5.根据权利要求4所述的模型的训练方法,其特征在于,所述预设的测试条件至少包括第一测试条件和第二测试条件,所述第一测试条件为,所述查全率大于第二预设阈值,所述第二测试条件为,所述查准率大于第三预设阈值。
6.一种模型的训练装置,其特征在于,包括:
模型获取单元,用于将损失函数值小于或等于第一预设阈值的预设的目标模型作为待测试模型;其中,所述损失函数值由样本数据的真实值以及所述样本数据的预设的影响因子,计算得到,所述影响因子用于表示所述样本数据的预测误差的权重值,任一所述样本数据的预测误差表示所述目标模型输出的所述样本数据的预测结果与所述样本数据的真实值之间的偏差,所述影响因子对损失函数中每一样本数据的预测误差进行加权相加,自动调整样本数据在训练过程中的分布;
测试结果获取单元,用于获取所述待测试模型的测试结果;
结果判定单元,用于当所述测试结果满足预设的测试条件时,将所述待测试模型作为预测模型;或者,当所述预测结果不满足所述测试条件时,依据所述测试结果,更新所述样本数据的所述影响因子;其中,所述依据所述测试结果,更新所述样本数据的所述影响因子,至少包括:增大第一类样本数据的影响因子,其中,所述第一类样本数据的真实值与第一类测试数据的真实值相同,所述第一类测试数据的预测准确率小于第四预设阈值,所述预测准确率为所述第一类样本数据的预测结果与所述第一类样本数据的真实值相同的第一类样本数据的数量,与全部第一类测试数据的数量的比值;当样本数据为脏数据,设置其影响因子为0,以自动排除该数据对训练的干扰。
7.根据权利要求6所述的模型的训练装置,其特征在于,还包括:损失函数值计算单元,用于在将小于或等于第一预设阈值的预设的目标模型作为待测试模型之前,获取所述损失函数值;所述损失函数值计算单元具体用于:
获取所述样本数据以及每一所述样本数据的所述影响因子;
将所述样本数据输入至所述目标模型;
获取所述目标模型输出的所述样本数据的预测结果;
依据所述样本数据的真实值、所述样本数据的预测结果以及所述样本数据的所述影响因子,计算所述损失函数值。
8.一种模型的训练设备,其特征在于,包括:存储器和处理器;
所述存储器,用于存储程序;
所述处理器,用于执行所述程序,实现如权利要求1~5中任一项所述的模型的训练方法的各个步骤。
9.一种可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时,实现如权利要求1~5中任一项所述的模型的训练方法的各个步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010269451.0A CN111461329B (zh) | 2020-04-08 | 2020-04-08 | 一种模型的训练方法、装置、设备及可读存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010269451.0A CN111461329B (zh) | 2020-04-08 | 2020-04-08 | 一种模型的训练方法、装置、设备及可读存储介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN111461329A CN111461329A (zh) | 2020-07-28 |
CN111461329B true CN111461329B (zh) | 2024-01-23 |
Family
ID=71681409
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202010269451.0A Active CN111461329B (zh) | 2020-04-08 | 2020-04-08 | 一种模型的训练方法、装置、设备及可读存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN111461329B (zh) |
Families Citing this family (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112801178B (zh) * | 2021-01-26 | 2024-04-09 | 上海明略人工智能(集团)有限公司 | 模型训练方法、装置、设备及计算机可读介质 |
CN114880995B (zh) * | 2022-06-30 | 2022-10-04 | 浙江大华技术股份有限公司 | 算法方案部署方法及相关装置、设备和存储介质 |
Citations (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN108520220A (zh) * | 2018-03-30 | 2018-09-11 | 百度在线网络技术(北京)有限公司 | 模型生成方法和装置 |
CN109214436A (zh) * | 2018-08-22 | 2019-01-15 | 阿里巴巴集团控股有限公司 | 一种针对目标场景的预测模型训练方法及装置 |
CN109409318A (zh) * | 2018-11-07 | 2019-03-01 | 四川大学 | 统计模型的训练方法、统计方法、装置及存储介质 |
CN109815332A (zh) * | 2019-01-07 | 2019-05-28 | 平安科技(深圳)有限公司 | 损失函数优化方法、装置、计算机设备及存储介质 |
CN109871702A (zh) * | 2019-02-18 | 2019-06-11 | 深圳前海微众银行股份有限公司 | 联邦模型训练方法、系统、设备及计算机可读存储介质 |
CN110070117A (zh) * | 2019-04-08 | 2019-07-30 | 腾讯科技(深圳)有限公司 | 一种数据处理方法及装置 |
WO2020022639A1 (ko) * | 2018-07-18 | 2020-01-30 | 한국과학기술정보연구원 | 딥러닝 기반의 가치 평가 방법 및 그 장치 |
Family Cites Families (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
JP6954082B2 (ja) * | 2017-12-15 | 2021-10-27 | 富士通株式会社 | 学習プログラム、予測プログラム、学習方法、予測方法、学習装置および予測装置 |
-
2020
- 2020-04-08 CN CN202010269451.0A patent/CN111461329B/zh active Active
Patent Citations (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN108520220A (zh) * | 2018-03-30 | 2018-09-11 | 百度在线网络技术(北京)有限公司 | 模型生成方法和装置 |
WO2020022639A1 (ko) * | 2018-07-18 | 2020-01-30 | 한국과학기술정보연구원 | 딥러닝 기반의 가치 평가 방법 및 그 장치 |
CN109214436A (zh) * | 2018-08-22 | 2019-01-15 | 阿里巴巴集团控股有限公司 | 一种针对目标场景的预测模型训练方法及装置 |
CN109409318A (zh) * | 2018-11-07 | 2019-03-01 | 四川大学 | 统计模型的训练方法、统计方法、装置及存储介质 |
CN109815332A (zh) * | 2019-01-07 | 2019-05-28 | 平安科技(深圳)有限公司 | 损失函数优化方法、装置、计算机设备及存储介质 |
CN109871702A (zh) * | 2019-02-18 | 2019-06-11 | 深圳前海微众银行股份有限公司 | 联邦模型训练方法、系统、设备及计算机可读存储介质 |
CN110070117A (zh) * | 2019-04-08 | 2019-07-30 | 腾讯科技(深圳)有限公司 | 一种数据处理方法及装置 |
Also Published As
Publication number | Publication date |
---|---|
CN111461329A (zh) | 2020-07-28 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN106874581B (zh) | 一种基于bp神经网络模型的建筑空调能耗预测方法 | |
CN107832581B (zh) | 状态预测方法和装置 | |
US20170300546A1 (en) | Method and Apparatus for Data Processing in Data Modeling | |
TWI539298B (zh) | 具取樣率決定機制的量測抽樣方法 與其電腦程式產品 | |
CN111461329B (zh) | 一种模型的训练方法、装置、设备及可读存储介质 | |
CN107730286A (zh) | 一种目标客户筛选方法及装置 | |
CN107908864A (zh) | 一种基于特征融合的复杂设备剩余使用寿命预测方法 | |
CN112418921A (zh) | 用电需量预测方法、装置、系统与计算机存储介质 | |
CA2344769A1 (en) | System and method for on-line adaptive prediction using dynamic management of multiple sub-models | |
CN107220500B (zh) | 基于逆高斯过程的性能退化试验贝叶斯可靠性评估方法 | |
CN114978956A (zh) | 智慧城市网络设备性能异常突变点检测方法及装置 | |
CN103489034A (zh) | 预测与诊断在线海流监测数据的方法和装置 | |
CN111325310A (zh) | 一种数据预测方法、装置及存储介质 | |
CN112182056A (zh) | 一种数据检测方法、装置、设备及存储介质 | |
CN109376929B (zh) | 配送参数的确定方法、确定装置、存储介质和电子设备 | |
CN114840375A (zh) | 一种半导体存储产品的老化性能测试方法及系统 | |
CN114330102A (zh) | 基于降雨相似与模型参数智能适配的快速洪水预报方法及装置 | |
CN110929849B (zh) | 一种基于神经网络模型压缩的视频检测方法和装置 | |
CN109993374B (zh) | 货物量预测方法及装置 | |
CN113592090B (zh) | 基于深度学习的建筑质量预测方法、装置及存储介质 | |
CN112800037B (zh) | 工程造价数据处理的优化方法及装置 | |
CN110134575B (zh) | 一种服务器集群的服务能力计算方法及装置 | |
CN115453447A (zh) | 基于嫌疑电表分步补偿剔除的超差电表在线检测方法 | |
CN113610148A (zh) | 一种基于偏置加权AdaBoost的故障诊断方法 | |
CN110236558A (zh) | 婴儿发育情况预测方法、装置、存储介质及电子设备 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |