
CN118245638B - Method, device, equipment and storage medium for predicting graph data based on generalization model - Google Patents


Info

Publication number
CN118245638B
CN118245638B (granted from application CN202410649703.0A)
Authority
CN
China
Prior art keywords
sample
node
target
subgraph
matrix
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202410649703.0A
Other languages
Chinese (zh)
Other versions
CN118245638A (en)
Inventor
张亮
胡志威
朱光旭
史清江
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Chinese University of Hong Kong Shenzhen
Original Assignee
Chinese University of Hong Kong Shenzhen
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Chinese University of Hong Kong Shenzhen filed Critical Chinese University of Hong Kong Shenzhen
Priority to CN202410649703.0A priority Critical patent/CN118245638B/en
Publication of CN118245638A publication Critical patent/CN118245638A/en
Application granted granted Critical
Publication of CN118245638B publication Critical patent/CN118245638B/en
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F16/00Information retrieval; Database structures therefor; File system structures therefor
    • G06F16/90Details of database functions independent of the retrieved data types
    • G06F16/901Indexing; Data structures therefor; Storage structures
    • G06F16/9024Graphs; Linked lists
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F16/00Information retrieval; Database structures therefor; File system structures therefor
    • G06F16/90Details of database functions independent of the retrieved data types
    • G06F16/903Querying
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F16/00Information retrieval; Database structures therefor; File system structures therefor
    • G06F16/90Details of database functions independent of the retrieved data types
    • G06F16/906Clustering; Classification
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/042Knowledge-based neural networks; Logical representations of neural networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Databases & Information Systems (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Software Systems (AREA)
  • Computational Linguistics (AREA)
  • General Health & Medical Sciences (AREA)
  • Evolutionary Computation (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Mathematical Physics (AREA)
  • Artificial Intelligence (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

The embodiments of the present application provide a method, device, equipment and storage medium for predicting graph data based on a generalization model, belonging to the technical field of graph data processing. The method includes: obtaining target graph data, where the target graph data includes multiple target nodes and the corresponding node connection structure; based on the node connection structure formed among the multiple target nodes in the target graph data, finding, for each target node, the target subgraph of its target connection structure within the target association level; and inputting the target subgraph into a generalization model to obtain the node category label of the target node in the target subgraph. The generalization model is obtained through causal prediction learning between the sample node connection structures and the sample node category labels in sample subgraphs, based on a training process that minimizes a target loss composed of a first loss, a second loss and a third loss. The present application can improve the accuracy of graph data prediction.

Description

Graph data prediction method, device, equipment and storage medium based on a generalization model

Technical Field

The present application relates to the technical field of graph data processing, and in particular to a graph data prediction method, device, equipment and storage medium based on a generalization model.

Background Art

Graph Neural Networks (GNNs) are deep learning models for processing graph-structured data. In a graph neural network, node features are updated and aggregated by considering the relationships between nodes (i.e., the existence of edges), so as to gradually build the relationships between nodes and learn a structural representation of the graph.

In the related art, during the training of a graph neural network, key features are usually extracted from the input graph data for prediction under the premise of maximizing the mutual information between the graph data and the true labels. However, this training approach drives the graph neural network to learn, as far as possible, all statistical correlations between input features and labels in the graph data, without distinguishing causal from non-causal effects between input features and labels. The graph neural network therefore tends to use non-causal features as a shortcut for graph data prediction, which reduces the accuracy of graph data prediction after training is complete.

Summary of the Invention

The main purpose of the embodiments of the present application is to propose a graph data prediction method, device, equipment and storage medium based on a generalization model, which can improve the accuracy of graph data prediction.

To achieve the above purpose, a first aspect of the embodiments of the present application proposes a graph data prediction method based on a generalization model, the method comprising:

acquiring target graph data, wherein the target graph data includes a plurality of target nodes and a node connection structure composed of the plurality of target nodes;

based on the node connection structure formed among the plurality of target nodes in the target graph data, finding, for each target node, a target subgraph of its target connection structure within a target association level;

inputting the target subgraph into a generalization model to obtain a node category label of the target node in the target subgraph;

wherein the generalization model is obtained through causal prediction learning between sample node connection structures and sample node category labels in sample subgraphs, based on a training process that minimizes a target loss; the target loss is composed of a first loss, a second loss and a third loss, where the first loss is determined from a first prediction probability of the core features of a sample subgraph under the corresponding sample node category label, the second loss is determined from a second prediction probability of the redundant features of the sample subgraph under the corresponding sample node category label, and the third loss is determined from the sample distance between the core features and the redundant features; the core features are obtained by enhancing the representation of the sample node connection structure of the sample subgraph according to an edge mask matrix produced by an attention mechanism, and the redundant features are obtained by weakening the representation of the sample node connection structure of the sample subgraph according to the complement matrix of the edge mask matrix.
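Reading the "target association level" as a hop count k, the subgraph lookup step above amounts to extracting each target node's k-hop ego subgraph. A minimal sketch of this lookup, assuming NetworkX as the graph library (the function name and parameters are illustrative, not from the patent):

```python
import networkx as nx

def target_subgraph(graph: nx.Graph, node: int, level: int) -> nx.Graph:
    """Subgraph of all nodes within `level` hops of `node`, plus their edges."""
    # ego_graph keeps every node whose shortest-path distance from `node`
    # is at most `radius`, together with the edges among those nodes.
    return nx.ego_graph(graph, node, radius=level)

# Example on a path graph 0-1-2-3-4: the 1-hop subgraph around node 2
# contains nodes 1, 2, 3 and the two edges between them.
g = nx.path_graph(5)
sub = target_subgraph(g, node=2, level=1)
print(sorted(sub.nodes()), sub.number_of_edges())  # [1, 2, 3] 2
```

Because the model only ever sees such a local subgraph, prediction for one target node does not require processing the entire target graph data.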

Correspondingly, a second aspect of the embodiments of the present application proposes a graph data prediction device based on a generalization model, the device comprising:

an acquisition module, configured to acquire target graph data, wherein the target graph data includes a plurality of target nodes and a node connection structure composed of the plurality of target nodes;

a search module, configured to find, based on the node connection structure formed among the plurality of target nodes in the target graph data, a target subgraph of the target connection structure of each target node within a target association level;

an input module, configured to input the target subgraph into a generalization model to obtain a node category label of the target node in the target subgraph, wherein the generalization model is obtained through causal prediction learning between sample node connection structures and sample node category labels in sample subgraphs, based on a training process that minimizes a target loss; the target loss is composed of a first loss, a second loss and a third loss, where the first loss is determined from a first prediction probability of the core features of a sample subgraph under the corresponding sample node category label, the second loss is determined from a second prediction probability of the redundant features of the sample subgraph under the corresponding sample node category label, and the third loss is determined from the sample distance between the core features and the redundant features; the core features are obtained by enhancing the representation of the sample node connection structure of the sample subgraph according to an edge mask matrix produced by an attention mechanism, and the redundant features are obtained by weakening the representation of the sample node connection structure of the sample subgraph according to the complement matrix of the edge mask matrix.

In some implementations, the sample nodes of a sample subgraph include a central sample node and a plurality of adjacent sample nodes, and the graph data prediction device based on a generalization model further includes a training module, configured to:

acquire sample graph data for training a preset model, and generate a corresponding sample subgraph for each sample node in the sample graph data according to the sample node connection structure of the sample graph data;

acquire an initial feature matrix and an adjacency matrix for each sample subgraph, and sequentially input the initial feature matrices into the preset model for feature aggregation to obtain a fused feature matrix for each sample subgraph;

determine the core features and redundant features of each sample subgraph according to the adjacency matrix and the fused feature matrix;

determine the target loss of the preset model based on the core features and the redundant features; and

update the parameters of the preset model based on the target loss, and obtain the trained generalization model when the preset model converges.

In some implementations, the training module is further configured to:

based on the fused feature matrix of the sample subgraph, input each pair of sample nodes in the sample subgraph into a multi-layer perceptron to obtain an edge mask value between the two sample nodes;

generate an edge mask matrix for the sample subgraph from the plurality of edge mask values in the sample subgraph;

generate the core features corresponding to the sample subgraph based on the adjacency matrix, the fused feature matrix and the edge mask matrix;

generate the corresponding complement matrix from the edge mask matrix of the sample subgraph; and

obtain the redundant features of the sample subgraph based on the adjacency matrix, the fused feature matrix and the complement matrix.
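The masking steps above can be sketched numerically. The snippet below is a minimal illustration, not the patent's implementation: a single random linear scorer with a sigmoid stands in for the trained multi-layer perceptron, and element-wise re-weighting of the adjacency matrix stands in for the enhanced/weakened structure representations.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 4, 8                         # sample nodes, feature dimension

# Adjacency matrix and fused feature matrix of one sample subgraph.
A = np.array([[0., 1., 1., 0.],
              [1., 0., 0., 1.],
              [1., 0., 0., 1.],
              [0., 1., 1., 0.]])
H = rng.normal(size=(n, d))

# Edge mask: every node pair is scored; a random linear layer plus sigmoid
# stands in here for the trained multi-layer perceptron.
W = rng.normal(size=(2 * d, 1))
pairs = np.concatenate([np.repeat(H, n, axis=0), np.tile(H, (n, 1))], axis=1)
M = 1.0 / (1.0 + np.exp(-(pairs @ W).reshape(n, n)))  # edge mask matrix
M_comp = 1.0 - M                                       # complement matrix

# Core features: structure enhanced by the mask.
# Redundant features: structure weakened by the complement.
core_features = (A * M) @ H
redundant_features = (A * M_comp) @ H
```

Since the mask values lie in (0, 1), the mask and its complement split each edge's weight between the core and redundant views: `A * M + A * M_comp` recovers the original adjacency matrix.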

In some implementations, the training module is further configured to:

acquire basic graph data and a plurality of sub-patterns of different categories, and attach the plurality of sub-patterns to arbitrary sample nodes of the basic graph data to obtain sample graph data; and

determine, according to the number of nodes in the sample graph data, a preset number of edges to be added to the sample graph data, and add edges between arbitrary pairs of sample nodes in the sample graph data, wherein the number of edges added to the sample graph data equals the preset number.
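One way to realize this construction, sketched under assumptions the patent does not fix (a Barabási–Albert base graph, a 5-node "house" motif as the sub-pattern, and one random edge per ten nodes as the preset number):

```python
import random
import networkx as nx

random.seed(0)

# Basic graph data: a Barabási–Albert graph (an assumed choice).
base = nx.barabasi_albert_graph(30, 2, seed=0)

# One sub-pattern category: a 5-node "house" motif, attached to a
# randomly chosen sample node of the basic graph.
motif = nx.house_graph()
offset = base.number_of_nodes()
g = nx.disjoint_union(base, motif)            # motif nodes become 30..34
g.add_edge(random.randrange(offset), offset)  # hook the motif onto the base

# Preset number of extra edges, derived from the node count
# (one per ten nodes here, purely illustrative).
preset = g.number_of_nodes() // 10
edges_before = g.number_of_edges()
added = 0
while added < preset:
    u, v = random.sample(range(g.number_of_nodes()), 2)
    if not g.has_edge(u, v):
        g.add_edge(u, v)
        added += 1
```

The random extra edges act as structural noise, so the model must learn to separate the motif (causal structure) from spurious connections.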

In some implementations, the training module is further configured to:

acquire the node density of the sample graph data, and determine the sample association level of each sample node based on the node density; and

taking each sample node in turn as a central sample node, find, based on the sample node connection structure formed between the central sample node and its adjacent sample nodes, the sample subgraph of the target sample node connection structure of each central sample node within the sample association level.

In some implementations, the graph data prediction device based on a generalization model further includes a generation module, configured to:

acquire a preset graph theory library;

acquire preset feature indicators for the sample nodes, and generate, through the graph theory library, a node feature vector for each sample node under the corresponding feature indicators, wherein the feature indicators include at least one of node identifier, node degree, clustering coefficient, betweenness centrality and closeness centrality; and

generate a first feature matrix from the plurality of node feature vectors corresponding to the plurality of sample nodes.
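A minimal sketch of building such per-node feature vectors, assuming NetworkX as the preset graph theory library and the karate-club graph as stand-in sample graph data (both assumptions, not specified by the patent):

```python
import networkx as nx

# Stand-in sample graph data: Zachary's karate club (34 nodes, 78 edges).
g = nx.karate_club_graph()

# Preset feature indicators, computed through the graph theory library.
deg   = dict(g.degree())
clust = nx.clustering(g)
btw   = nx.betweenness_centrality(g)
close = nx.closeness_centrality(g)

# One node feature vector per sample node:
# [node identifier, node degree, clustering coefficient,
#  betweenness centrality, closeness centrality]
first_feature_matrix = [
    [v, deg[v], clust[v], btw[v], close[v]] for v in g.nodes()
]
```

Stacking one such row per node yields the first feature matrix from which each sample subgraph's initial feature matrix is later sliced by node identifier.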

In some implementations, the training module is further configured to:

in each sample subgraph, determine the node identifier of each sample node;

based on the node identifiers, determine the node feature vector corresponding to each sample node from the first feature matrix;

generate the initial feature matrix of the sample subgraph from the plurality of node feature vectors corresponding to the plurality of sample nodes of the sample subgraph; and

generate the adjacency matrix corresponding to the feature matrix based on the connection relationships between the sample nodes.

In some implementations, the training module is further configured to:

determine the node degree matrix of each sample subgraph from the first feature matrix according to the node identifiers of the sample subgraph;

acquire the adjacency matrix of the sample subgraph, and normalize the adjacency matrix using the node degree matrix to obtain a normalized adjacency matrix;

taking the central sample node as the fusion center, perform layer-by-layer fusion of the sample node features based on the normalized adjacency matrix to obtain an updated central node feature vector for the central sample node; and

update the node feature vector of the central sample node in the initial feature matrix according to the central node feature vector to obtain the fused feature matrix of each sample subgraph.
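The normalization and layer-by-layer fusion can be sketched as follows. The symmetric D^{-1/2}(A+I)D^{-1/2} form used here is one common choice (the GCN convention); the patent only specifies normalization by the node degree matrix, so treat the exact form as an assumption:

```python
import numpy as np

# Adjacency matrix of a tiny sample subgraph whose center is node 0.
A = np.array([[0., 1., 1.],
              [1., 0., 0.],
              [1., 0., 0.]])
X = np.eye(3)                       # initial feature matrix (one-hot ids)

A_hat = A + np.eye(3)               # self-loops keep each node's own feature
D = A_hat.sum(axis=1)               # node degrees (diagonal of the degree matrix)
D_inv_sqrt = np.diag(1.0 / np.sqrt(D))
A_norm = D_inv_sqrt @ A_hat @ D_inv_sqrt  # degree-normalized adjacency

# Two rounds of layer-by-layer fusion toward the center node.
H = X
for _ in range(2):
    H = A_norm @ H                  # each node aggregates its neighbors

center_vector = H[0]                # updated feature vector of the center node
```

After k rounds, the center row mixes information from all nodes within k hops, which matches the center-node-oriented fusion described above.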

In some implementations, the training module is further configured to:

determine the first loss based on the difference between the predicted node category label output for the core features of the sample subgraph and the sample node category label of those core features;

determine the second loss based on the uniformity of the probability distribution produced when the preset model predicts each sample node category label from the redundant features of the sample subgraph;

determine the third loss from the sample distance between the core features and the redundant features; and

construct the target loss of the preset model as the sum of the first loss, the second loss and the third loss.
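A numerical sketch of the three terms, under assumptions the patent leaves open: cross-entropy as the label difference, KL divergence to the uniform distribution as the (non-)uniformity measure, and the Euclidean norm as the sample distance. All numbers are illustrative:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

num_classes = 3
label = 1  # sample node category label of this sample subgraph

# Hypothetical model outputs (logits) for core and redundant features.
logits_core = np.array([0.2, 2.5, -0.3])
logits_red  = np.array([0.4, 0.1, -0.2])
p_core = softmax(logits_core)   # first prediction probability
p_red  = softmax(logits_red)    # second prediction probability

# First loss: cross-entropy between the core-feature prediction and the label.
loss1 = -np.log(p_core[label])

# Second loss: KL divergence of the redundant-feature prediction from the
# uniform distribution (zero when the prediction is perfectly uniform).
uniform = np.full(num_classes, 1.0 / num_classes)
loss2 = float(np.sum(p_red * np.log(p_red / uniform)))

# Third loss: Euclidean distance between core and redundant representations.
core_feat = np.array([1.0, 0.0, 0.5])
red_feat  = np.array([0.2, 0.1, 0.4])
loss3 = float(np.linalg.norm(core_feat - red_feat))

target_loss = loss1 + loss2 + loss3
```

Minimizing the sum thus pushes the model to predict the label from core features, to become uninformative on redundant features, and to shrink the sample distance between the two representations.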

Correspondingly, a third aspect of the embodiments of the present application proposes a computer device, which includes a memory and a processor, wherein the memory stores a computer program, and the processor, when executing the computer program, implements the graph data prediction method based on a generalization model according to any embodiment of the first aspect of the present application.

Correspondingly, a fourth aspect of the embodiments of the present application proposes a computer-readable storage medium, which stores a computer program that, when executed by a processor, implements the graph data prediction method based on a generalization model according to any embodiment of the first aspect of the present application.

In the embodiments of the present application, target graph data is acquired, where the target graph data includes a plurality of target nodes and the node connection structure they form; based on the node connection structure formed among the plurality of target nodes in the target graph data, the target subgraph of the target connection structure of each target node within the target association level is found; and the target subgraph is input into the generalization model to obtain the node category label of the target node in the target subgraph. The generalization model is obtained through causal prediction learning between sample node connection structures and sample node category labels in sample subgraphs, based on a training process that minimizes a target loss composed of a first loss, a second loss and a third loss. The first loss is determined from the difference between the predicted node category label output for the core features of a sample subgraph and the sample node category label of those core features; the second loss is determined from the uniformity of the probability distribution produced when the preset model predicts each sample node category label from the redundant features of the sample subgraph; and the third loss is determined from the sample distance between the core features and the redundant features. The core features are obtained by enhancing the representation of the sample node connection structure of the sample subgraph according to the edge mask matrix produced by the attention mechanism; the redundant features are obtained by weakening that representation according to the complement matrix of the edge mask matrix.

In this way, the preset model can be trained against the constructed target loss so that, during training, it tends to predict from core features, reduces its preference for redundant features, and reduces the sample distance between core and redundant features, giving the trained generalization model stronger causal explanation and generalization abilities. Moreover, enhancing the representation of the sample connection structure with the attention-derived edge mask matrix, and weakening it with the complement matrix, helps the trained generalization model better capture the important relationships between target nodes, reduces interference from redundant information, and keeps the preset model focused on learning the core features.

When the generalization model is applied, the target subgraph of the target node to be predicted can be obtained from the target graph data, so the generalization model does not need to process the entire target graph data; it only needs the relevant target subgraph of the target node to be predicted. This helps the generalization model analyze the causal relationships between target nodes more attentively, and improves the efficiency and accuracy of prediction. In summary, the present application can improve the accuracy with which the trained generalization model predicts graph data.

Brief Description of the Drawings

FIG. 1 is a schematic diagram of the architecture of the graph data prediction system based on a generalization model provided in an embodiment of the present application;

FIG. 2 is a flowchart of the graph data prediction method based on a generalization model provided in an embodiment of the present application;

FIG. 3 is a flowchart of the steps for training a preset model provided in an embodiment of the present application;

FIG. 4 is an overall flowchart of the graph data prediction method based on a generalization model provided in an embodiment of the present application;

FIG. 5 is a schematic diagram of the functional modules of the graph data prediction device based on a generalization model provided in an embodiment of the present application;

FIG. 6 is a schematic diagram of the hardware structure of a computer device provided in an embodiment of the present application.

Detailed Description

In order to make the purpose, technical solutions and advantages of the present application clearer, the present application is further described in detail below in conjunction with the accompanying drawings and embodiments. It should be understood that the specific embodiments described herein are only intended to explain the present application and are not intended to limit it.

It should be noted that although functional modules are divided in the device schematic diagrams and a logical order is shown in the flowcharts, in some cases the steps shown or described may be performed with a module division different from that in the device, or in an order different from that in the flowcharts. The terms "first", "second", and the like in the specification, claims and the above drawings are used to distinguish similar objects and are not necessarily used to describe a specific order or sequence.

Unless otherwise defined, all technical and scientific terms used herein have the same meanings as commonly understood by those skilled in the art to which this application belongs. The terms used herein are only for the purpose of describing the embodiments of this application and are not intended to limit it.

Graph Neural Networks (GNNs) are deep learning models for processing graph-structured data. In a graph neural network, node features are updated and aggregated by considering the relationships between nodes (i.e., the existence of edges), so as to gradually build the relationships between nodes, learn a structural representation of the graph, and make predictions for target nodes.

In the related art, during the training of a graph neural network, key features are usually extracted from the input graph data for prediction under the premise of maximizing the mutual information between the graph data and the true labels. However, this training approach drives the graph neural network to learn, as far as possible, all statistical correlations between input features and labels in the graph data, without distinguishing causal from non-causal effects between input features and labels. The graph neural network therefore tends to use non-causal features as a shortcut for graph data prediction, which reduces the accuracy of graph data prediction after training is complete.

Based on this, the embodiments of the present application provide a graph data prediction method, device, equipment and storage medium based on a generalization model, which can improve the accuracy of graph data prediction.

The graph data prediction method, device, equipment and storage medium based on a generalization model provided in the embodiments of the present application are described through the following embodiments, beginning with the graph data prediction system based on a generalization model in the embodiments of the present application.

In some implementations, an embodiment of the present application provides a graph data prediction system based on a generalization model. Referring to FIG. 1, the graph data prediction system based on a generalization model includes a terminal 11 and a server 12.

Specifically, the terminal 11 may be a mobile terminal device or a non-mobile terminal device. The mobile terminal device may be a mobile phone, a tablet computer, a laptop computer, a palmtop computer, a wearable device, an ultra-mobile personal computer, a netbook, a personal digital assistant, a wireless hotspot device, or the like; the non-mobile terminal device may be a personal computer or the like, which is not specifically limited in the embodiments of this application. A technician can interact directly with the server 12 through the terminal 11, for example sending instructions to the server 12 to train the preset model and setting the number of training rounds for the preset model, and can display the results once the generalization model has been trained.

Further, the server 12 may include a cloud computing server, a data center server, or the like, and is responsible for storing data related to the generalization model and, upon receiving a model training instruction from the terminal 11, training the preset model to obtain the generalization model. The server 12 has high computing, storage and network performance and can support requests from multiple terminals 11.

The graph data prediction method based on a generalization model in the embodiments of the present application can be illustrated by the following embodiments.

It should be noted that, in each specific implementation of the present application, whenever processing involves data related to a user's identity or characteristics, such as user information, user behavior data, user history data and user location information, the user's permission or consent is obtained first, and the collection, use and processing of such data comply with relevant laws, regulations and standards. In addition, when an embodiment of the present application needs to obtain a user's sensitive personal information, the user's separate permission or consent is obtained through a pop-up window or a redirect to a confirmation page, and only after the user's separate permission or consent has been explicitly obtained is the user-related data necessary for the normal operation of the embodiment acquired.

In the embodiments of the present application, the description is given from the perspective of a graph data prediction device based on a generalization model, which may be integrated in a computer device. Referring to FIG. 2, FIG. 2 is a flowchart of the steps of the graph data prediction method based on a generalization model provided in an embodiment of the present application. Taking as an example a graph data prediction device based on a generalization model integrated in a terminal or a server, when the processor of the terminal or server executes the program instructions corresponding to the graph data prediction method based on a generalization model, the specific process is as follows:

步骤101,获取目标图数据,其中,目标图数据包括多个目标节点和多个目标节点组成的节点连接结构。Step 101, obtaining target graph data, wherein the target graph data includes a plurality of target nodes and a node connection structure composed of the plurality of target nodes.

需要说明的是,为了有效对目标图数据进行分类预测,可以获取目标图数据,以便于后续泛化模型能够理解目标图数据的目标节点组成的节点连接结构,从而更准确地进行分类。It should be noted that in order to effectively perform classification prediction on the target graph data, the target graph data can be obtained so that the subsequent generalization model can understand the node connection structure composed of the target nodes of the target graph data, thereby performing classification more accurately.

其中,目标图数据可以是由多个目标节点和对应的边组成的图结构数据,比如社交网络、推荐系统中的用户-物品关系图、生物信息学中的蛋白质相互作用网络等等。Among them, the target graph data can be graph structure data composed of multiple target nodes and corresponding edges, such as social networks, user-item relationship graphs in recommendation systems, protein interaction networks in bioinformatics, etc.

其中,目标节点可以是在目标图数据中的实体,比如社交网络中的用户或推荐系统中的物品等等,每个目标节点都有其特定的属性或类别标签。Among them, the target node can be an entity in the target graph data, such as a user in a social network or an item in a recommendation system, etc. Each target node has its specific attributes or category labels.

其中,节点连接结构可以是目标节点之间的关系或者连接,例如,在社交网络中,当两个用户之间存在好友关系时,则两个用户之间存在边,两个用户和边构成节点连接结构。当一个用户曾购买过某个物品,那么可以用一条边连接用户节点和物品节点,表示用户对物品的购买行为,用户、边和物品可以组成节点连接结构,等等。The node connection structure can be the relationship or connection between target nodes. For example, in a social network, when two users are friends, there is an edge between the two users. The two users and the edge constitute a node connection structure. When a user has purchased an item, an edge can be used to connect the user node and the item node to indicate the user's purchase behavior of the item. The user, edge, and item can constitute a node connection structure, and so on.

示例性的,可以首先获取待识别的目标图数据,目标图数据可以预先设定,也可以实时获取。例如,可以确定需要获取的目标图数据的数据源,例如,需要从社交网络中获取目标图数据时,可以使用网络爬虫技术,收集社交媒体平台上的用户数据和用户之间的关系;需要获取生物信息学相关的目标图数据时,可以从公共的生物信息数据平台获取蛋白质相互作用的网络数据等等。进一步的,在采集到原始数据之后,可以对原始数据进行预处理,包括数据清洗、数据格式化等等。Exemplarily, the target graph data to be identified can be obtained first. The target graph data can be preset or obtained in real time. For example, the data source of the target graph data to be obtained can be determined. For example, when the target graph data needs to be obtained from a social network, the web crawler technology can be used to collect user data and the relationship between users on the social media platform; when the target graph data related to bioinformatics needs to be obtained, the network data of protein interactions can be obtained from a public bioinformatics data platform, etc. Furthermore, after the raw data is collected, the raw data can be preprocessed, including data cleaning, data formatting, etc.

通过以上方式获取目标图数据,可以便于后续通过泛化模型对目标图数据进行分类。Obtaining the target graph data in the above manner can facilitate subsequent classification of the target graph data through a generalized model.

步骤102,基于目标图数据中的多个目标节点之间组成的节点连接结构,查找出每个目标节点在目标关联层级内的目标连接结构的目标子图。Step 102, based on the node connection structure formed by multiple target nodes in the target graph data, find out the target subgraph of the target connection structure of each target node in the target association level.

可以理解的是,在对待预测的目标节点进行分类时,若将整个目标图数据输入至泛化模型中,会耗费大量的计算资源,通过查找每个目标节点在目标关联层级内的目标连接结构的目标子图,可以更准确地捕捉与待预测的目标节点的分类任务相关的信息,在提高分类的准确性的同时,节约计算资源。It is understandable that when classifying the target node to be predicted, if the entire target graph data is input into the generalization model, it will consume a lot of computing resources. By finding the target subgraph of the target connection structure of each target node within the target association hierarchy, the information related to the classification task of the target node to be predicted can be captured more accurately, thereby improving the accuracy of classification and saving computing resources.

其中,目标关联层级可以是在目标图数据中通过宽度优先搜索进行K-hop采样所确定的层级,K-hop采样可以帮助泛化模型确定对应的目标节点在目标图数据中的局部区域信息。例如,需要对待预测的目标节点进行预测时,可以通过目标图数据的密度确定距离中心的目标节点的步数,并将确定的步数作为目标关联层级。进一步的,当目标图数据中的目标节点分布较为稀疏时,可以确定较大的目标关联层级,当目标图数据中的目标节点分布较为密集时,可以确定较小的目标关联层级。示例性的,目标关联层级可以为2级、3级等等,本申请实施例对此不作具体限制。Among them, the target association level can be a level determined by K-hop sampling through breadth-first search in the target graph data, and K-hop sampling can help the generalization model determine the local area information of the corresponding target node in the target graph data. For example, when it is necessary to predict the target node to be predicted, the number of steps from the target node to the center can be determined by the density of the target graph data, and the determined number of steps can be used as the target association level. Furthermore, when the target nodes in the target graph data are sparsely distributed, a larger target association level can be determined, and when the target nodes in the target graph data are densely distributed, a smaller target association level can be determined. Exemplarily, the target association level can be level 2, level 3, etc., and the embodiments of the present application do not impose specific restrictions on this.

其中,目标连接结构为在目标关联层级内,待预测的目标节点与相连接的其他目标节点组成的节点连接结构。The target connection structure is a node connection structure consisting of a target node to be predicted and other connected target nodes within a target association level.

其中,目标子图可以是在目标图数据中,从待预测的目标节点出发,通过宽度优先搜索算法获取的局部图结构。每个目标子图包括待预测的目标节点作为中心节点,以及在目标关联层级内与中心节点连接的所有其他的目标节点。The target subgraph may be a local graph structure obtained by a breadth-first search algorithm starting from the target node to be predicted in the target graph data. Each target subgraph includes the target node to be predicted as a central node and all other target nodes connected to the central node in the target association hierarchy.

具体的,可以将待预测的目标节点作为中心起点,通过宽度优先搜索,在目标关联层级内查找与待预测的目标节点相关联的其他目标节点,并通过待预测的目标节点与其他目标节点的节点连接结构,生成待预测的目标节点对应的目标子图。Specifically, the target node to be predicted can be taken as the central starting point, and other target nodes associated with the target node to be predicted can be found in the target association hierarchy through breadth-first search, and the target subgraph corresponding to the target node to be predicted can be generated through the node connection structure between the target node to be predicted and other target nodes.

示例性的,当目标图数据为社交网络图时,需要对一个新的用户节点进行分类,假设待预测的目标节点是用户节点A,从节点A开始,使用宽度优先搜索算法在目标图数据中确定与节点A直接相连的其他用户节点。若根据目标图数据的节点密度确定目标关联层级为3,那么,可以获取社交网络图的节点连接结构,并在以用户节点A为起点的3个目标关联层级内,查找与节点A直接相连的其他用户节点,若同用户节点A直接相连的有用户节点B、C、D,那么,可以将用户节点A、用户节点B、用户节点C、用户节点D组成的节点连接结构作为目标连接结构,并将该目标连接结构作为目标子图。Exemplarily, when the target graph data is a social network graph, a new user node needs to be classified. Assuming that the target node to be predicted is user node A, starting from node A, the breadth-first search algorithm is used to determine other user nodes directly connected to node A in the target graph data. If the target association level is determined to be 3 according to the node density of the target graph data, then the node connection structure of the social network graph can be obtained, and other user nodes directly connected to node A can be found within the three target association levels starting from user node A. If user node A is directly connected to user nodes B, C, and D, then the node connection structure composed of user node A, user node B, user node C, and user node D can be used as the target connection structure, and the target connection structure can be used as the target subgraph.
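作为示意,上述通过宽度优先搜索在目标关联层级内查找目标子图的过程可以用如下Python草图表示(其中邻接表结构、函数名均为说明用的假设,并非本申请的具体实现)。As an illustrative sketch only (the adjacency-list structure and function names are assumptions for explanation, not the specific implementation of this application), the breadth-first search for the target subgraph described above can be written as:

```python
from collections import deque

def k_hop_subgraph(adj, center, k):
    # 宽度优先搜索:返回以 center 为中心、目标关联层级 k 以内的目标子图
    # adj 为邻接表 {节点: [相邻节点, ...]}(假设的数据结构,仅作示意)
    depth = {center: 0}
    queue = deque([center])
    while queue:
        node = queue.popleft()
        if depth[node] == k:  # 已到达目标关联层级,停止向外扩展
            continue
        for nbr in adj.get(node, []):
            if nbr not in depth:
                depth[nbr] = depth[node] + 1
                queue.append(nbr)
    nodes = set(depth)
    # 目标子图保留两端都落在 nodes 内的边,即目标连接结构
    edges = {(u, v) for u in nodes for v in adj.get(u, []) if v in nodes}
    return nodes, edges

# 对应正文示例:用户节点A与用户节点B、C、D直接相连,目标关联层级为3
social = {"A": ["B", "C", "D"], "B": ["A"], "C": ["A"], "D": ["A"]}
nodes, edges = k_hop_subgraph(social, "A", 3)
```

返回的 nodes 与 edges 即为以用户节点A为中心节点的目标子图。The returned nodes and edges form the target subgraph centered on user node A.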

通过以上方式,可以通过查找待预测的目标节点对应的目标子图,更加准确地捕捉与该目标节点的分类任务相关的信息,同时也有利于减少后续泛化模型对目标子图预测时的计算量。In the above way, by finding the target subgraph corresponding to the target node to be predicted, the information related to the classification task of the target node can be captured more accurately, which is also beneficial to reduce the amount of calculation when the subsequent generalization model predicts the target subgraph.

步骤103,将目标子图输入至泛化模型,得到目标子图中的目标节点的节点类别标签;Step 103, input the target subgraph into the generalization model to obtain the node category label of the target node in the target subgraph;

其中,泛化模型基于目标损失最小化的训练过程进行样本子图中样本节点连接结构与样本节点类别标签之间的因果预测学习得到,目标损失由第一损失、第二损失和第三损失构成,第一损失基于针对样本子图的核心特征输出的预测节点类别标签相对于核心特征的样本节点类别标签之间的差异确定;第二损失基于预设模型针对样本子图的冗余特征预测各样本节点类别标签时的概率分布均匀度确定;第三损失根据核心特征和冗余特征之间的样本距离确定,核心特征为按照注意力机制得到的边掩码矩阵对样本子图的样本节点连接结构进行增强表示得到的特征;冗余特征为按照边掩码矩阵的补集矩阵对样本子图的样本节点连接结构进行弱化表示得到的特征。Among them, the generalization model is obtained by performing causal prediction learning between the sample node connection structure and the sample node category label in the sample subgraph based on the training process of minimizing the target loss. The target loss is composed of a first loss, a second loss and a third loss. The first loss is determined based on the difference between the predicted node category label output for the core feature of the sample subgraph and the sample node category label of the core feature; the second loss is determined based on the uniformity of the probability distribution when the preset model predicts the category label of each sample node for the redundant features of the sample subgraph; the third loss is determined according to the sample distance between the core feature and the redundant feature. The core feature is the feature obtained by enhancing the sample node connection structure of the sample subgraph according to the edge mask matrix obtained by the attention mechanism; the redundant feature is the feature obtained by weakening the sample node connection structure of the sample subgraph according to the complement matrix of the edge mask matrix.

在一些实施方式中,为了提高对目标图数据中的目标节点预测的准确性,可以通过预先对预设模型进行训练,使得预设模型不断通过专注于因果特征学习样本节点和真实标签之间的真正因果关系,而减少对捷径特征(即冗余特征)的关注,以使得训练好的泛化模型具有更高的泛化性能。In some embodiments, in order to improve the accuracy of prediction of target nodes in target graph data, a preset model can be pre-trained so that the preset model continuously learns the true causal relationship between sample nodes and true labels by focusing on causal features, while reducing attention to shortcut features (i.e., redundant features), so that the trained generalization model has higher generalization performance.

其中,泛化模型可以是用于预测目标子图中目标节点的节点类别标签的、具备良好的泛化能力的模型,也可以是用于预测整个目标图数据的类别的模型。泛化模型由预设模型预先训练得到,当面对新任务和数据时,泛化模型能够保持良好的泛化性能。示例性的,泛化模型可以是图神经网络(Graph Neural Networks,GNNs)、图注意力网络(Graph Attention Networks,GATs)、图卷积网络(Graph Convolutional Networks,GCNs)等适用于处理目标图数据并进行目标节点的预测的神经网络。Among them, the generalization model can be a model with good generalization ability for predicting the node category label of the target node in the target subgraph, or it can be a model for predicting the category of the entire target graph data. The generalization model is pre-trained by a preset model, and when faced with new tasks and data, the generalization model can maintain good generalization performance. Exemplarily, the generalization model can be a graph neural network (Graph Neural Networks, GNNs), a graph attention network (Graph Attention Networks, GATs), a graph convolutional network (Graph Convolutional Networks, GCNs) and other neural networks suitable for processing target graph data and predicting target nodes.

其中,节点类别标签可以是每个目标节点所属的具体类别。例如,在社交网络中,目标节点可能代表用户,而节点类别标签可以表示用户的兴趣或群组。在一些实施方式中,当要对整个目标图数据进行预测时,节点类别标签表示整个目标图数据的类别,例如,在生物信息学中,目标图数据可以表示蛋白质的相互作用网络,而节点类别标签可以表示这个网络的功能类型,等等。The node category label may be a specific category to which each target node belongs. For example, in a social network, a target node may represent a user, and a node category label may represent the user's interest or group. In some embodiments, when the entire target graph data is to be predicted, the node category label represents the category of the entire target graph data. For example, in bioinformatics, the target graph data may represent a protein interaction network, and the node category label may represent the functional type of this network, and so on.

其中,样本子图可以是样本图数据中,从待预测的样本节点出发,通过宽度优先搜索算法获取的局部样本图结构。每个样本子图包括待预测的样本节点作为样本中心节点,以及在样本关联层级内与样本中心节点连接的所有其他的样本节点。The sample subgraph can be a local sample graph structure obtained by a breadth-first search algorithm starting from the sample node to be predicted in the sample graph data. Each sample subgraph includes the sample node to be predicted as the sample center node and all other sample nodes connected to the sample center node in the sample association hierarchy.

其中,样本节点连接结构可以是样本节点之间的关系或者连接,例如,在社交网络中,当两个用户之间存在好友关系时,则两个用户之间存在边,两个用户和好友关系构成节点连接结构。当一个用户曾购买过某个物品,那么可以用一条边连接用户节点和物品节点,表示用户对物品的购买行为,用户、边和物品可以组成节点连接结构,等等。The sample node connection structure may be the relationship or connection between sample nodes. For example, in a social network, when two users have a friend relationship, there is an edge between the two users. The two users and the friend relationship constitute a node connection structure. When a user has purchased an item, an edge may be used to connect the user node and the item node to indicate the user's purchase behavior of the item. The user, edge, and item may constitute a node connection structure, and so on.

其中,目标损失可以是由第一损失、第二损失和第三损失构成的损失,用于指示对预设模型的参数的调节,以对预设模型进行训练,得到泛化模型。具体的,第一损失可以为交叉熵损失,第二损失可以为熵损失,第三损失可以为对比损失,通过将第一损失、第二损失和第三损失构成目标损失,并利用目标损失来更新模型参数,可以使泛化模型更加专注地学习核心特征,避免泛化模型过度关注冗余特征,从而提高整体特征表征的质量。Among them, the target loss can be a loss composed of a first loss, a second loss, and a third loss, which is used to indicate the adjustment of the parameters of the preset model to train the preset model to obtain a generalized model. Specifically, the first loss can be a cross entropy loss, the second loss can be an entropy loss, and the third loss can be a contrast loss. By forming the target loss with the first loss, the second loss, and the third loss, and using the target loss to update the model parameters, the generalized model can be made to learn the core features more attentively, avoid the generalized model from over-focusing on redundant features, and thus improve the quality of the overall feature representation.

其中,第一损失可以是交叉熵损失,用于评估预设模型对于核心特征预测的各预测节点类别标签的概率分布与核心特征的样本节点类别标签之间的差异。具体的,可以针对每个样本节点类别标签设置独热编码,将每个样本节点类别标签都用一个二进制向量表示,并只设置核心特征对应的真实的样本节点类别标签为1,其余样本节点类别标签为0。具体的,第一损失可以为:L1 = -Σ_v y_v log(p_v)。The first loss may be a cross-entropy loss, which is used to evaluate the difference between the probability distribution of the predicted node category labels predicted by the preset model for the core features and the sample node category label of the core features. Specifically, a one-hot encoding may be set for each sample node category label, so that each sample node category label is represented by a binary vector in which only the real sample node category label corresponding to the core feature is set to 1 and the remaining sample node category labels are set to 0. Specifically, the first loss can be: L1 = -Σ_v y_v log(p_v).

其中,y_v 是样本节点类别标签的独热编码,p_v 是模型对第v个样本节点类别标签的第一预测概率。当预测节点类别标签与核心特征对应的真实的样本节点类别标签一致时,第一损失最小,由此,可以通过最小化第一损失,使得预设模型能够更好地学习数据的因果关系,从而不断作出更加准确的预测。Here, y_v is the one-hot encoding of the sample node category label, and p_v is the model's first predicted probability for the v-th sample node category label. When the predicted node category label is consistent with the real sample node category label corresponding to the core feature, the first loss is minimal. Therefore, by minimizing the first loss, the preset model can better learn the causal relationship of the data, thereby continuously making more accurate predictions.

其中,预测节点类别标签为预设模型对于核心特征预测得到的多个节点类别标签。例如,对于一个动物图像,可以得到预测节点类别标签为猫、狗、兔子,每个预测节点类别标签都对应一个预测概率。The predicted node category labels are multiple node category labels predicted by the preset model for the core features. For example, for an animal image, the predicted node category labels may be cat, dog, and rabbit, and each predicted node category label corresponds to a predicted probability.
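下面给出第一损失(交叉熵)的一个最小示意实现,假设以概率列表表示预测分布(变量名均为说明用的假设)。A minimal illustrative sketch of the first loss (cross-entropy), assuming the predicted distribution is represented as a list of probabilities (all variable names are assumptions for explanation):

```python
import math

def first_loss(one_hot, probs):
    # 第一损失(交叉熵):one_hot 为样本节点类别标签的独热编码,
    # probs 为预设模型输出的各类别的第一预测概率(变量名为示意用的假设)
    return -sum(y * math.log(p) for y, p in zip(one_hot, probs) if y > 0)

# 真实标签为第0类:预测分布越接近独热标签,第一损失越小
loss_good = first_loss([1, 0, 0], [0.9, 0.05, 0.05])
loss_bad = first_loss([1, 0, 0], [0.3, 0.4, 0.3])
```

可见当预测概率集中于真实的样本节点类别标签时,第一损失更小(loss_good < loss_bad)。It can be seen that the first loss is smaller when the predicted probability concentrates on the real sample node category label (loss_good < loss_bad).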

其中,第二损失可以是熵损失,用于衡量预设模型针对样本子图的冗余特征预测各样本节点类别标签时的概率分布均匀度,概率分布越均匀,表明预设模型对于冗余特征的分类偏好越小,由此,能够有效地防止预设模型学习捷径特征。具体的,第二损失可以为:L2 = -Σ_v u log(q_v)。The second loss may be an entropy loss, which is used to measure the uniformity of the probability distribution when the preset model predicts each sample node category label for the redundant features of the sample subgraph. The more uniform the probability distribution, the smaller the classification preference of the preset model for the redundant features, which effectively prevents the preset model from learning shortcut features. Specifically, the second loss can be: L2 = -Σ_v u log(q_v).

其中,u 代表均匀分布时对每个样本节点类别标签的均匀概率,通常 u = 1/类别数,q_v 是预设模型对第v个样本节点类别标签的第二预测概率。Here, u represents the uniform probability of each sample node category label under a uniform distribution, usually u = 1/(number of categories), and q_v is the second predicted probability of the preset model for the v-th sample node category label.

具体来说,当冗余特征输入至预设模型之后,输出的第二预测概率都接近于均匀分布,则表明冗余特征对于预设模型的分类并无帮助。示例性的,若样本节点类别标签有4个,分别为猫、兔、狗、猪,那么 u = 1/4,若将冗余特征输入至预设模型,得到对猫、兔、狗、猪的预测概率均为1/4时,预设模型的第二损失也最小,此时,表明预设模型对于冗余特征并无分类偏好,输出的概率分布均匀度高。因此,需要通过不断对预设模型进行训练,使得预设模型对于冗余特征在各个样本节点类别标签的预测概率接近于均匀分布。Specifically, when the redundant features are input into the preset model and the output second predicted probabilities are all close to a uniform distribution, it indicates that the redundant features are of no help to the classification of the preset model. For example, if there are four sample node category labels, namely cat, rabbit, dog and pig, then u = 1/4. If the redundant features are input into the preset model and the predicted probabilities for cat, rabbit, dog and pig are all 1/4, the second loss of the preset model is also minimal. This indicates that the preset model has no classification preference for the redundant features, and the uniformity of the output probability distribution is high. Therefore, it is necessary to continuously train the preset model so that its predicted probabilities for the redundant features over the sample node category labels approach a uniform distribution.

其中,概率分布均匀度可以是在冗余特征输入至预设模型之后,预设模型对冗余特征预测输出的各个样本节点类别标签下的第二预测概率的均匀程度。各第二预测概率分布越均匀,表明概率分布均匀度越高,例如,第二预测概率为1/3、1/3、1/3的概率分布均匀度大于2/3、1/6、1/6。The uniformity of the probability distribution may be the degree of uniformity of the second predicted probabilities output by the preset model for the redundant features over the sample node category labels after the redundant features are input into the preset model. The more uniform the second predicted probabilities are, the higher the uniformity of the probability distribution. For example, second predicted probabilities of 1/3, 1/3 and 1/3 have a higher probability distribution uniformity than 2/3, 1/6 and 1/6.
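以下为第二损失的一种示意实现,采用以均匀分布为目标的交叉熵形式,该具体形式与变量名均为说明用的假设。An illustrative sketch of the second loss, using a cross-entropy-to-uniform form (this specific form and the variable names are assumptions for explanation):

```python
import math

def second_loss(probs):
    # 第二损失:以均匀分布为目标的交叉熵形式(该形式为示意用的假设),
    # u = 1/类别数;预测概率分布越均匀,第二损失越小
    u = 1.0 / len(probs)
    return -sum(u * math.log(q) for q in probs)

# 正文示例:四个样本节点类别标签(猫、兔、狗、猪)
uniform = second_loss([0.25, 0.25, 0.25, 0.25])  # 概率分布均匀度最高,损失最小
skewed = second_loss([0.7, 0.1, 0.1, 0.1])       # 存在分类偏好,损失更大
```

当预测概率均为1/4时损失取得最小值,与正文示例一致。The loss attains its minimum when all predicted probabilities are 1/4, consistent with the example in the text.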

其中,第三损失可以是对比损失,具体的,可以将核心特征作为正样本,冗余特征作为负样本,对比损失可以激励预设模型能够正确区分正样本和负样本,使正样本和正样本之间的距离更近,正样本和负样本之间的距离更远。进一步的,第三损失可以为:L3 = -E[log σ(C·C+/Q)] - E[log(1 - σ(C·C-/Q))]。The third loss may be a contrastive loss. Specifically, the core features may be taken as positive samples and the redundant features as negative samples. The contrastive loss encourages the preset model to correctly distinguish positive samples from negative samples, making positive samples closer to each other and positive samples farther from negative samples. Further, the third loss can be: L3 = -E[log σ(C·C+/Q)] - E[log(1 - σ(C·C-/Q))].

其中,σ 是sigmoid函数,通常用于将输入映射到(0,1)之间,Q为超参数,C和C+是核心特征的不同实例,(C, C+)表示正样本对,C和C+均表示正样本,即核心特征;(C, C-)表示正负样本对,C-表示负样本,即冗余特征;E表示期望值,可以通过预设模型对所有正样本对的核心特征的第一预测概率,和对所有负样本对的冗余特征的第二预测概率取平均得到。通过第三损失训练预设模型,可以激励预设模型能够正确区分正样本和负样本,使得正样本和正样本之间的距离更近,正样本和负样本之间的距离更远。Here, σ is the sigmoid function, which is usually used to map the input into (0, 1); Q is a hyperparameter; C and C+ are different instances of the core features, and (C, C+) denotes a positive sample pair, where both C and C+ represent positive samples, i.e., core features; (C, C-) denotes a positive-negative sample pair, where C- represents a negative sample, i.e., redundant features; E denotes the expectation, which can be obtained by averaging over the first predicted probabilities of the core features of all positive sample pairs and the second predicted probabilities of the redundant features of all negative sample pairs. Training the preset model with the third loss encourages it to correctly distinguish positive samples from negative samples, making positive samples closer to each other and positive samples farther from negative samples.

其中,样本距离可以是两个核心特征之间的距离或者核心特征和冗余特征之间的距离,样本距离可以通过欧氏距离或者余弦相似度等方式计算得到。The sample distance may be the distance between two core features or the distance between a core feature and a redundant feature. The sample distance may be calculated by Euclidean distance or cosine similarity.
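以下为第三损失(对比损失)的一种示意实现,这里以内积作为相似度、以参数 q 对应超参数Q,具体形式与变量名均为说明用的假设。An illustrative sketch of the third loss (contrastive loss), where the inner product is used as the similarity and the parameter q corresponds to the hyperparameter Q (the specific form and variable names are assumptions for explanation):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def third_loss(core, core_pos, redundant, q=1.0):
    # 第三损失(对比损失)的示意形式:正样本对 (C, C+) 相似度应高,
    # 正负样本对 (C, C-) 相似度应低;q 对应正文中的超参数Q(形式为假设)
    dot = lambda a, b: sum(x * y for x, y in zip(a, b))
    pos = -math.log(sigmoid(dot(core, core_pos) / q))
    neg = -math.log(1.0 - sigmoid(dot(core, redundant) / q))
    return pos + neg

# 核心特征彼此接近、与冗余特征远离时,第三损失更小
loss_sep = third_loss([1.0, 0.0], [0.9, 0.1], [-0.9, 0.1])
loss_mixed = third_loss([1.0, 0.0], [-0.9, 0.1], [0.9, 0.1])
```

当正样本对相似、正负样本对远离时损失更小(loss_sep < loss_mixed),体现了拉近正样本、推远负样本的作用。The loss is smaller when positive pairs are similar and positive-negative pairs are far apart (loss_sep < loss_mixed), reflecting the pull-together/push-apart effect.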

其中,核心特征可以是样本子图中具有因果关系或者因果解释能力的特征,核心特征具有待预测的样本节点与样本节点类别标签的因果信息。Among them, the core feature can be a feature with causal relationship or causal explanation ability in the sample subgraph, and the core feature has causal information of the sample node to be predicted and the sample node category label.

其中,边掩码矩阵可以是由样本子图中任意两个样本节点的节点特征向量经过多层感知机计算得到的边掩码值所组成的矩阵,即由样本子图中所有样本节点对的边掩码值组成的矩阵。The edge mask matrix may be a matrix composed of edge mask values obtained by passing the node feature vectors of any two sample nodes in the sample subgraph through a multi-layer perceptron, i.e., a matrix composed of the edge mask values of all pairs of sample nodes in the sample subgraph.

其中,冗余特征可以是核心特征和预测结果之间的捷径特征,冗余特征对于预测样本节点与样本节点类别标签并无额外的贡献,不具备因果解释特性。并且,冗余特征会使得预设模型倾向于学习捷径特征作出决策,从而使得预设模型对于样本图数据预测的准确性降低,导致训练得到的泛化模型在分布外测试数据中性能下降,也即泛化性能下降。示例性的,冗余特征可以是噪声特征。Among them, redundant features can be shortcut features between core features and prediction results. Redundant features have no additional contribution to the prediction of sample nodes and sample node category labels, and do not have causal explanation characteristics. Moreover, redundant features will make the preset model tend to learn shortcut features to make decisions, thereby reducing the accuracy of the preset model's prediction of sample graph data, resulting in the performance of the trained generalized model in out-of-distribution test data, that is, the generalization performance is reduced. Exemplarily, redundant features can be noise features.

其中,补集矩阵=1-边掩码矩阵,补集矩阵可以用于表示样本子图中冗余的边。Among them, the complement matrix = 1-edge mask matrix, and the complement matrix can be used to represent redundant edges in the sample subgraph.
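以下示意了边掩码矩阵、补集矩阵与增强/弱化表示之间的关系,其中矩阵取值为说明用的假设。The following sketch illustrates the relationship among the edge mask matrix, the complement matrix and the enhanced/weakened representations (the matrix values are assumptions for explanation):

```python
def complement_matrix(edge_mask):
    # 补集矩阵 = 1 - 边掩码矩阵:掩码值高的边属于核心结构,
    # 其补集对应冗余的边
    return [[1.0 - m for m in row] for row in edge_mask]

def apply_mask(adj, mask):
    # 用掩码对样本节点连接结构逐边加权:
    # 按边掩码矩阵加权得到增强表示(核心特征的来源),
    # 按补集矩阵加权得到弱化表示(冗余特征的来源)
    return [[a * m for a, m in zip(ra, rm)] for ra, rm in zip(adj, mask)]

adj = [[0.0, 1.0], [1.0, 0.0]]   # 两个样本节点间存在一条边(示意)
mask = [[0.0, 0.8], [0.8, 0.0]]  # 注意力机制输出的边掩码值(示意)
core_adj = apply_mask(adj, mask)                          # 增强表示
redundant_adj = apply_mask(adj, complement_matrix(mask))  # 弱化表示
```

同一条边在增强表示与弱化表示中的权重互补,二者分别作为核心特征与冗余特征的来源。The weights of the same edge in the enhanced and weakened representations are complementary, serving as the sources of the core and redundant features respectively.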

可以理解的是,将目标子图输入至训练好的泛化模型中,泛化模型可以快速提取目标子图中有利于对待预测的目标节点进行分类的核心特征,并结合核心特征对待预测的目标节点进行预测,从而提高对目标节点预测的效率和准确性。It can be understood that by inputting the target subgraph into the trained generalization model, the generalization model can quickly extract the core features in the target subgraph that are helpful for classifying the target node to be predicted, and predict the target node to be predicted based on the core features, thereby improving the efficiency and accuracy of the target node prediction.

示例性的,当需要对社交网络图中用户A的兴趣类别进行预测时,可以通过目标图数据的节点连接结构,确定用户A对应的目标连接结构作为用户A的目标子图,例如,目标子图可以是包含用户A及其朋友的目标子图,再将目标子图输入至泛化模型中,得到用户A的具体的兴趣类别。在这个过程中,泛化模型可以从用户A的社交网络结构中提取有用的信息,并据此做出准确的分类决策。For example, when the interest category of user A in the social network graph needs to be predicted, the target connection structure corresponding to user A can be determined as the target subgraph of user A through the node connection structure of the target graph data. For example, the target subgraph can be a target subgraph containing user A and his friends, and then the target subgraph is input into the generalization model to obtain the specific interest category of user A. In this process, the generalization model can extract useful information from the social network structure of user A and make accurate classification decisions based on it.

本申请实施例通过获取目标图数据,其中,目标图数据包括多个目标节点和多个目标节点组成的节点连接结构;基于目标图数据中的多个目标节点之间组成的节点连接结构,查找出每个目标节点在目标关联层级内的目标连接结构的目标子图;将目标子图输入至泛化模型,得到目标子图中的目标节点的节点类别标签;其中,泛化模型基于目标损失最小化的训练过程进行样本子图中样本节点连接结构与样本节点类别标签之间的因果预测学习得到,目标损失由第一损失、第二损失和第三损失构成,第一损失基于针对样本子图的核心特征输出的预测节点类别标签相对于核心特征的样本节点类别标签之间的差异确定;第二损失基于预设模型针对样本子图的冗余特征预测各样本节点类别标签时的概率分布均匀度确定;第三损失根据核心特征和冗余特征之间的样本距离确定,核心特征为按照注意力机制得到的边掩码矩阵对样本子图的样本节点连接结构进行增强表示得到的特征;冗余特征为按照边掩码矩阵的补集矩阵对样本子图的样本节点连接结构进行弱化表示得到的特征。以此,可以通过构建目标损失对预设模型进行训练,并在训练的过程中使得预设模型趋向于通过核心特征来进行预测、减少对冗余特征的偏好,减少核心特征和冗余特征之间的样本距离,由此,可以使得训练好的泛化模型具备更强的因果解释能力和泛化能力。并且,通过注意力机制得到的边掩码矩阵增强表示样本子图的样本连接结构,通过边掩码矩阵的补集矩阵弱化表示样本子图的样本节点连接结构,有助于训练好的泛化模型能够更好地捕捉目标节点之间的重要关系,减少冗余信息的干扰,使得预设模型更加专注于核心特征的学习。在泛化模型应用的过程中,能够基于目标图数据获取需要预测的目标节点的目标子图,使得泛化模型无需对整个目标图数据进行预测,只需要获取需要预测的目标节点的相关目标子图即可,有助于泛化模型更加专注地分析目标节点之间的因果关系,提高了预测的效率和准确性。综上,本申请能够提高训练好的泛化模型对图数据预测的准确性。In an embodiment of the present application, target graph data is obtained, wherein the target graph data includes a plurality of target nodes and a node connection structure composed of a plurality of target nodes; based on the node connection structure composed of a plurality of target nodes in the target graph data, a target subgraph of the target connection structure of each target node in the target association level is found; the target subgraph is input into a generalization model to obtain a node category label of the target node in the target subgraph; wherein the generalization model is obtained by performing causal prediction learning between the sample node connection structure and the sample node category label in the sample subgraph based on a training process of minimizing the target loss, and the target loss is composed of a first loss, a second loss and a third loss, and the first loss is determined based on the difference between the predicted node category label output for the core feature of the sample subgraph and the sample node category label of the core feature; the second loss is determined based on the uniformity of the probability distribution when predicting each sample node category label for the redundant feature of the sample subgraph according to the preset model; the third loss is determined according to the sample distance between the core feature and the redundant feature, and the core feature is a feature obtained by enhancing the sample node connection structure of the sample subgraph according to the edge mask matrix obtained by the attention mechanism; the redundant feature is a feature obtained by weakening the sample node connection structure of the sample subgraph according to the complement matrix of the edge mask matrix. In this way, the preset model can be trained by constructing a target loss, and during the training process, the preset model tends to predict through core features, reduce the preference for redundant features, and reduce the sample distance between core features and redundant features, thereby enabling the trained generalized model to have stronger causal explanation ability and generalization ability. In addition, the edge mask matrix obtained by the attention mechanism enhances the sample connection structure representing the sample subgraph, and the complement matrix of the edge mask matrix weakens the sample node connection structure representing the sample subgraph, which helps the trained generalized model to better capture the important relationship between target nodes, reduce the interference of redundant information, and make the preset model more focused on the learning of core features. In the process of applying the generalized model, the target subgraph of the target node to be predicted can be obtained based on the target graph data, so that the generalized model does not need to predict the entire target graph data, but only needs to obtain the relevant target subgraph of the target node to be predicted, which helps the generalized model to analyze the causal relationship between target nodes more attentively, and improves the efficiency and accuracy of prediction. In summary, the present application can improve the accuracy of the trained generalized model in predicting graph data.

请参照图3,在一些实施方式中,样本子图的样本节点可以包括中心样本节点和与中心样本节点相邻的相邻样本节点。为了提高泛化模型对于目标图数据预测的准确性,可以通过对预设模型进行训练,使得泛化模型能够学习因果特征(即核心特征),而避免学习捷径特征(冗余特征),具体的,泛化模型通过步骤201至步骤205训练得到:Please refer to FIG. 3 , in some embodiments, the sample nodes of the sample subgraph may include a central sample node and adjacent sample nodes adjacent to the central sample node. In order to improve the accuracy of the generalization model for predicting the target graph data, the preset model may be trained so that the generalization model can learn causal features (i.e., core features) and avoid learning shortcut features (redundant features). Specifically, the generalization model is trained through steps 201 to 205 to obtain:

步骤201,获取用于训练预设模型的样本图数据,并按照样本图数据的样本节点连接结构对样本图数据中的每个样本节点生成对应的样本子图。Step 201: obtain sample graph data for training a preset model, and generate a corresponding sample subgraph for each sample node in the sample graph data according to the sample node connection structure of the sample graph data.

其中,样本图数据可以是一个通过随机方式生成的合成图,通过在基础图的各节点中随机添加不同形状的图,可以得到不同类别的样本节点,并对任意两个样本节点添加边,可以将样本节点和边组成的样本节点连接结构作为样本图数据。在一些实施方式中,也可以通过获取现有的图数据作为样本图数据。The sample graph data may be a synthetic graph generated in a random manner, and sample nodes of different categories may be obtained by randomly adding graphs of different shapes to each node of the base graph, and edges may be added to any two sample nodes, and a sample node connection structure composed of sample nodes and edges may be used as sample graph data. In some implementations, existing graph data may also be obtained as sample graph data.

示例性的,样本子图可以是在样本图数据中,从待预测的样本节点出发,通过宽度优先搜索算法获取的局部样本图结构,以此,能够使得预设模型具备更良好的处理局部特征的能力,并且,尽管每个计算节点独立地处理样本子图,但是都能够对预设模型进行训练,并通过同步更新和共享预设模型的参数来提高对预设模型的训练效率,优化预设模型的性能。每个样本子图包括待预测的样本节点作为样本中心节点,以及在样本关联层级内与样本中心节点连接的所有其他的样本节点。Exemplarily, the sample subgraph can be a local sample graph structure obtained by a breadth-first search algorithm from the sample node to be predicted in the sample graph data, so that the preset model can have a better ability to process local features, and although each computing node processes the sample subgraph independently, it can train the preset model, and improve the training efficiency of the preset model by synchronously updating and sharing the parameters of the preset model, and optimize the performance of the preset model. Each sample subgraph includes the sample node to be predicted as the sample center node, and all other sample nodes connected to the sample center node in the sample association hierarchy.

例如,假设有一个社交网络图,其中每个样本节点代表一个用户,每条边代表两个用户之间的关系(例如朋友关系)。当需要根据社交网络图生成用户A的样本子图时,可以通过宽度优先搜索,确定与用户A连接的其他用户,假设用户A与用户B和C直接相连,那么生成的样本子图将包括用户A、B和C,用户A作为样本子图的中心节点。进一步的,也可以设定宽度优先搜索的范围,以避免样本子图过大,预设模型无法专注于核心特征。For example, suppose there is a social network graph, in which each sample node represents a user, and each edge represents a relationship between two users (such as a friend relationship). When it is necessary to generate a sample subgraph of user A based on the social network graph, other users connected to user A can be determined through breadth-first search. Assuming that user A is directly connected to users B and C, the generated sample subgraph will include users A, B, and C, with user A as the central node of the sample subgraph. Furthermore, the scope of the breadth-first search can also be set to avoid the sample subgraph being too large, so that the preset model cannot focus on the core features.
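
上述宽度优先搜索生成样本子图的过程,可以用如下示意性的纯标准库 Python 草图表示(邻接表、节点名与搜索范围均为示例假设,实际也可直接使用 NetworkX 等图论库提供的现成函数):The breadth-first search process above for generating a sample subgraph can be sketched with the following illustrative standard-library Python code (the adjacency list, node names and search range are assumptions for demonstration; a ready-made function from a graph library such as NetworkX could be used instead):

```python
from collections import deque

def bfs_subgraph(adj, center, hops):
    # 从中心样本节点出发做宽度优先搜索, 收集 hops 层以内的样本节点
    # breadth-first search from the center node, collecting nodes within `hops` levels
    visited = {center}
    frontier = deque([(center, 0)])
    while frontier:
        node, depth = frontier.popleft()
        if depth == hops:
            continue  # 达到设定的搜索范围, 不再向外扩展 / stop expanding at the range limit
        for nbr in adj[node]:
            if nbr not in visited:
                visited.add(nbr)
                frontier.append((nbr, depth + 1))
    return visited

# 社交网络示例: 用户A与B、C直接相连, B与D相连 / example: A-B, A-C, B-D
social = {"A": ["B", "C"], "B": ["A", "D"], "C": ["A"], "D": ["B"]}
print(sorted(bfs_subgraph(social, "A", 1)))  # ['A', 'B', 'C']
```

将搜索范围设为1层时,得到的样本子图即为用户A、B、C,与上文示例一致。With the search range set to 1 level, the resulting subgraph is exactly users A, B and C, consistent with the example above.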

通过以上方式,可以生成每个样本节点的样本子图,以此来训练预设模型,提高预设模型对待预测的样本节点的分类能力,提高预测的准确性。Through the above method, a sample subgraph of each sample node can be generated to train the preset model, improve the classification ability of the preset model for the sample nodes to be predicted, and improve the accuracy of the prediction.

可以理解的是,为了提高泛化模型的泛化能力,可以随机生成样本图数据,并根据随机生成的样本图数据的方式生成具有多样化结构的合成图,以助于预设模型能够更全面地理解和预测不同类型的图结构,增加训练好的泛化模型的鲁棒性。例如,步骤201中的“获取用于训练预设模型的样本图数据”,包括:It is understandable that in order to improve the generalization ability of the generalization model, sample graph data can be randomly generated, and synthetic graphs with diversified structures can be generated according to the randomly generated sample graph data, so as to help the preset model to more comprehensively understand and predict different types of graph structures and increase the robustness of the trained generalization model. For example, "obtaining sample graph data for training the preset model" in step 201 includes:

(201.a1)获取基础图数据和不同类别的多个子图案,并将多个子图案分配至基础图数据的任一样本节点上,得到样本图数据;(201.a1) Obtaining basic graph data and multiple sub-patterns of different categories, and assigning the multiple sub-patterns to any sample node of the basic graph data to obtain sample graph data;

(201.a2)针对样本图数据的节点数量确定添加样本图数据的边的预设数量,并在样本图数据中对任意两个样本节点添加边,其中,对样本图数据添加的边的数量等于预设数量。(201.a2) Determine a preset number of edges to be added to the sample graph data according to the number of nodes in the sample graph data, and add edges to any two sample nodes in the sample graph data, wherein the number of edges added to the sample graph data is equal to the preset number.

其中,基础图数据可以是一个无标度网络模型图,无标度网络模型图是用于模拟真实世界网络结构的数学模型,无标度网络模型图可以从一个单个节点开始构建,随着时间的推移,可以以恒定的速率将节点添加到网络中。Among them, the basic graph data can be a scale-free network model graph, which is a mathematical model used to simulate the real-world network structure. The scale-free network model graph can be built from a single node, and nodes can be added to the network at a constant rate over time.

其中,子图案可以是任意形状的图案,例如房屋、圆圈、网格以及其他图案,通过对基础图数据的节点添加不同的子图案,可以使对应的节点具有特定的类别。The sub-pattern may be a pattern of any shape, such as a house, a circle, a grid, or other patterns. By adding different sub-patterns to the nodes of the basic graph data, the corresponding nodes may have specific categories.

其中,边是在基础图数据中连接各个节点的元素。Among them, edges are elements that connect various nodes in the basic graph data.

其中,预设数量可以根据样本图数据的节点数量确定,例如预设数量可以为节点数量的10%、15%等等,也可以是具体的数量,例如20、25等等。Among them, the preset number can be determined according to the number of nodes in the sample graph data. For example, the preset number can be 10%, 15%, etc. of the number of nodes, or it can be a specific number, such as 20, 25, etc.

示例性的,可以生成一个由基础图和四种形状组成的合成图,合成图的网络结构为 G=(V, E),其中 V 是节点集合,E 是边集合。具体的,合成图可以由一个基础图无标度网络模型图和多个四种形状的子图案组合而成,四种形状的子图案可以为房屋、圆圈、钻石、网格或者其他形状。之后,将四种形状的子图案随机附着在基础图的一个节点上,之后添加10%的随机边对合成图进一步扰动,生成样本图数据。For example, a composite graph consisting of a base graph and four shapes can be generated, and the network structure of the composite graph is G=(V, E), where V is the set of nodes and E is the set of edges. Specifically, the synthetic graph can be composed of a scale-free network model graph as the base graph and multiple sub-patterns of four shapes, where the sub-patterns of the four shapes can be houses, circles, diamonds, grids or other shapes. Then, the sub-patterns of the four shapes are randomly attached to a node of the base graph, and 10% random edges are added to further perturb the synthetic graph to generate the sample graph data.
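
作为(201.a1)和(201.a2)的一个示意性实现,下面的 Python 草图用链式基础图近似无标度网络模型图,用环形子图案近似房屋、圆圈等形状(均为演示用的简化假设),并按节点数量的10%添加随机扰动边:As an illustrative implementation of (201.a1) and (201.a2), the following Python sketch approximates the scale-free base graph with a chain and the house/circle sub-patterns with rings (both simplifying assumptions for demonstration), and adds random perturbation edges equal to 10% of the node count:

```python
import random

def make_sample_graph(n_base, motif_sizes, extra_edge_ratio=0.10, seed=0):
    rng = random.Random(seed)
    edges = set()
    # 基础图: 简化为链式连接 (实际可用无标度网络模型图, 如 BA 图)
    for i in range(1, n_base):
        edges.add((i - 1, i))
    labels = {i: 0 for i in range(n_base)}  # 基础图节点类别为 0
    next_id = n_base
    # 将不同类别的子图案 (以环形近似房屋/圆圈等形状) 随机附着到基础图节点上
    for cls, size in enumerate(motif_sizes, start=1):
        attach = rng.randrange(n_base)
        motif = list(range(next_id, next_id + size))
        next_id += size
        for a, b in zip(motif, motif[1:] + motif[:1]):
            edges.add((a, b))
        edges.add((attach, motif[0]))  # 附着边 / attachment edge
        for v in motif:
            labels[v] = cls  # 子图案节点具有特定类别
    # 添加预设数量 (节点数量的 10%) 的随机扰动边
    n_nodes = next_id
    target = int(n_nodes * extra_edge_ratio)
    added = 0
    while added < target:
        u, v = rng.randrange(n_nodes), rng.randrange(n_nodes)
        if u != v and (u, v) not in edges and (v, u) not in edges:
            edges.add((u, v))
            added += 1
    return edges, labels

edges, labels = make_sample_graph(20, [5, 6], 0.10, seed=1)
print(len(labels), len(edges))  # 31 个节点, 35 条边
```

基础图20个节点贡献19条链式边,两个环形子图案贡献 5+1 与 6+1 条边,再加 int(31×10%)=3 条随机边,共35条。The 20 base nodes contribute 19 chain edges, the two ring motifs contribute 5+1 and 6+1 edges, and int(31×10%)=3 random edges are added, 35 edges in total.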

进一步的,对于样本图数据中的所有样本节点,也可以通过随机选取的方式,选取目标数量的样本节点,进行标签掩盖,例如选取所有样本节点数量的15%的样本节点进行标签掩盖。Furthermore, for all sample nodes in the sample graph data, a target number of sample nodes may be selected randomly for label masking, for example, 15% of all sample nodes may be selected for label masking.

通过以上方式,可以生成多样化结构的样本图数据,从而有助于预设模型能够更全面地理解和预测不同类型的图结构,提高泛化模型在真实场景下的泛化性能,并且,通过对合成图添加随机边进行搅动、随机选取样本节点进行标签掩盖得到的样本图数据,可以使得泛化模型可以在面对一定程度的噪声和不完整信息时仍然保持良好的性能。Through the above methods, sample graph data with diverse structures can be generated, which will help the preset model to more comprehensively understand and predict different types of graph structures and improve the generalization performance of the generalization model in real scenarios. In addition, by adding random edges to the synthetic graph to stir it and randomly selecting sample nodes to mask the labels to obtain sample graph data, the generalization model can still maintain good performance in the face of a certain degree of noise and incomplete information.

在一些实施方式中,为了有效地控制计算复杂度,使得预设模型的训练过程更加高效,对于每个样本节点,可以生成该样本节点对应的样本子图,以便于后续利用样本子图对预设模型进行训练,提高预设模型的泛化能力。例如,步骤201中的“按照样本图数据的样本节点连接结构对样本图数据中的每个样本节点生成对应的样本子图”,包括:In some implementations, in order to effectively control the computational complexity and make the training process of the preset model more efficient, for each sample node, a sample subgraph corresponding to the sample node can be generated, so that the sample subgraph can subsequently be used to train the preset model and improve its generalization ability. For example, "generating a corresponding sample subgraph for each sample node in the sample graph data according to the sample node connection structure of the sample graph data" in step 201 includes:

(201.b1)获取样本图数据的节点密度,并基于节点密度,确定每个样本节点的样本关联层级;(201.b1) Obtain the node density of the sample graph data, and determine the sample association level of each sample node based on the node density;

(201.b2)针对每个样本节点作为中心样本节点,基于中心样本节点和相邻样本节点之间组成的样本节点连接结构,查找出每个中心样本节点在样本关联层级内的目标样本节点连接结构的样本子图。(201.b2) For each sample node as a central sample node, based on the sample node connection structure formed between the central sample node and the adjacent sample nodes, find out the sample subgraph of the target sample node connection structure of each central sample node in the sample association level.

其中,节点密度可以是在样本图数据中,各样本节点之间的紧密程度。示例性的,节点密度可以通过样本节点的数量与样本图数据的面积之比计算得到,也可以通过其他方式计算得到,在此不予限定。The node density may be the degree of closeness between sample nodes in the sample graph data. For example, the node density may be calculated by the ratio of the number of sample nodes to the area of the sample graph data, or by other methods, which are not limited here.

其中,样本关联层级可以是从待预测的样本节点出发,通过广度优先搜索算法搜索的层级数量。当样本图数据的节点密度较大时,意味着样本节点之间的连接较为紧密,可以选择较小的样本关联层级,以限制搜索的范围,减少计算量,同时能够较好地捕捉到样本节点间的相互关系。相反,当样本图数据的节点密度较小时,意味着样本节点之间的连接相对稀疏,为了充分考虑图中的样本节点之间的关联性,需要选择较大的样本关联层级,通过扩大搜索的范围,可以更好地捕捉到样本节点之间的潜在关系。因此,样本关联层级也可以是样本搜索步长。The sample association level may be the number of levels searched by the breadth-first search algorithm starting from the sample node to be predicted. When the node density of the sample graph data is high, the sample nodes are closely connected, so a smaller sample association level can be selected to limit the scope of the search and reduce the amount of computation, while still capturing the relationships between sample nodes well. Conversely, when the node density of the sample graph data is low, the connections between sample nodes are relatively sparse; in order to fully consider the correlation between the sample nodes in the graph, a larger sample association level needs to be selected, and by expanding the scope of the search, the potential relationships between sample nodes can be better captured. Therefore, the sample association level can also be regarded as the sample search step size.

其中,相邻样本节点可以是在样本图数据中,通过连通性判定为与待预测的样本节点直接连接或者通过中间节点相互连接的样本节点。The adjacent sample nodes may be sample nodes that are directly connected to the sample node to be predicted or are connected to each other through intermediate nodes in the sample graph data through connectivity.

通过以上方式,可以高效地从样本图数据中选取合适的样本子图,在确保选取的样本子图能够具有代表性的前提下,尽可能地节约计算资源。Through the above method, it is possible to efficiently select a suitable sample sub-graph from the sample graph data, and save computing resources as much as possible while ensuring that the selected sample sub-graph is representative.
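
按节点密度选取样本关联层级的启发式,可以粗略示意如下(密度按原文给出的节点数量与面积之比计算,具体阈值纯属示例假设):The heuristic of choosing the sample association level from node density can be roughly sketched as follows (density is computed as the ratio of node count to area as given in the text; the concrete thresholds are purely illustrative assumptions):

```python
def choose_association_level(num_nodes, area):
    # 节点密度: 样本节点数量与样本图数据面积之比 (原文给出的一种计算方式)
    density = num_nodes / area
    # 密度大 → 连接紧密, 选较小的样本关联层级以限制搜索范围;
    # 密度小 → 连接稀疏, 选较大的层级以捕捉潜在关系 (阈值为示例假设)
    if density > 100:
        return 1
    if density > 10:
        return 2
    return 3

print(choose_association_level(500, 1.0))  # 密集图 → 1
print(choose_association_level(5, 1.0))    # 稀疏图 → 3
```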

需要说明的是,可以通过生成样本图数据的第一特征矩阵,以便于后续需要对样本子图进行训练时,可以直接通过第一特征矩阵进行查找,得到样本子图的初始特征矩阵,因此,在获取各样本子图的初始特征矩阵和邻接矩阵之前,也即是步骤202之前,还包括:It should be noted that the first feature matrix of the sample graph data can be generated so that when the sample subgraph needs to be trained later, the initial feature matrix of the sample subgraph can be directly searched through the first feature matrix to obtain the initial feature matrix of the sample subgraph. Therefore, before obtaining the initial feature matrix and adjacency matrix of each sample subgraph, that is, before step 202, the following is also included:

(A1)获取预设的图论库;(A1) Obtain the preset graph theory library;

(A2)获取样本节点的预设的特征指标,并通过图论库生成每个样本节点在对应的特征指标下的节点特征向量;其中,特征指标包括节点标识、节点度、聚类系数、介数中心性和接近中心性中的至少一个;(A2) obtaining preset characteristic indicators of the sample nodes, and generating a node characteristic vector of each sample node under the corresponding characteristic indicator through a graph theory library; wherein the characteristic indicator includes at least one of a node identifier, a node degree, a clustering coefficient, a betweenness centrality, and a closeness centrality;

(A3)根据多个样本节点对应的多个节点特征向量,生成第一特征矩阵。(A3) Generate a first feature matrix according to a plurality of node feature vectors corresponding to the plurality of sample nodes.

其中,图论库可以是多个开源图论库,如NetworkX库、igraph库和SNAP库等。以NetworkX库为例,NetworkX库是一个开源的Python包,用于创建、操作复杂网络的结构、动态和功能。NetworkX库提供了一个全面的图论和网络分析的工具集,适合于研究各种类型的网络,包括社交网络、生物信息学网络、交通网络等等。NetworkX库提供了丰富的函数和方法来研究这些图的结构特性,包括节点标识、搜索和遍历图、计算节点度、聚类系数、介数中心性、接近中心性等图论特征,等等。Among them, the graph theory library can be multiple open source graph theory libraries, such as NetworkX library, igraph library and SNAP library. Taking NetworkX library as an example, NetworkX library is an open source Python package for creating and manipulating the structure, dynamics and functions of complex networks. NetworkX library provides a comprehensive set of graph theory and network analysis tools, which is suitable for studying various types of networks, including social networks, bioinformatics networks, transportation networks, etc. NetworkX library provides a wealth of functions and methods to study the structural characteristics of these graphs, including node identification, searching and traversing graphs, calculating node degrees, clustering coefficients, betweenness centrality, closeness centrality and other graph theory features, etc.

其中,预设的特征指标可以是节点标识、节点度、聚类系数、介数中心性和接近中心性中的至少一个。可以理解的是,节点标识也可以在生成样本图数据时生成。具体的,节点度可以是样本节点与相邻样本节点的连接的边的数量,聚类系数可以是衡量当前样本节点与邻居节点之间连接紧密程度的指标,介数中心性可以用于衡量当前样本节点在样本数据图中连接其他样本节点的重要性程度,接近中心性衡量了当前的样本节点与其他样本节点之间的距离程度。Among them, the preset characteristic index can be at least one of node identification, node degree, clustering coefficient, betweenness centrality and closeness centrality. It is understandable that the node identification can also be generated when the sample graph data is generated. Specifically, the node degree can be the number of edges connecting the sample node with the adjacent sample nodes, the clustering coefficient can be an indicator to measure the closeness of the connection between the current sample node and the neighboring nodes, the betweenness centrality can be used to measure the importance of the current sample node in connecting other sample nodes in the sample data graph, and the closeness centrality measures the distance between the current sample node and other sample nodes.

进一步的,也可以通过聚类系数确定每个样本节点的样本关联层级,聚类系数高的样本节点可以选取较小的样本关联层级,聚类系数低的样本节点可以选取较大的样本关联层级。Furthermore, the sample association level of each sample node may also be determined by the clustering coefficient. Sample nodes with a high clustering coefficient may select a smaller sample association level, and sample nodes with a low clustering coefficient may select a larger sample association level.

其中,节点特征向量可以是针对每个预设的特征指标,在图论库生成的节点特征向量。节点特征向量可以包括节点标识、节点度、聚类系数、介数中心性和接近中心性等特征指标的特征值。The node feature vector may be a node feature vector generated in the graph theory library for each preset feature index. The node feature vector may include feature values of feature indexes such as node identification, node degree, clustering coefficient, betweenness centrality, and closeness centrality.

其中,第一特征矩阵可以由预设的特征指标的节点特征向量组成。The first feature matrix may be composed of node feature vectors of preset feature indicators.

示例性的,若将聚类系数、介数中心性、接近中心性以及一个随机特征作为每个样本节点的节点特征向量,那么样本节点1的节点特征向量可以是[0.75,0.33,0.5,0.42],每个数字代表一个特定的特征指标对应的值。进一步的,当特征指标还包括节点标识和节点度时,样本节点1的节点特征向量就可以是[1,2,0.75,0.33,0.5,0.42],样本节点2的节点特征向量可以是[2,4,0.85,0.54,0.6,0.59]。进一步的,当样本图数据包括样本节点1和样本节点2时,第一特征矩阵就由样本节点1和样本节点2的节点特征向量组成,第一特征矩阵中的每一行代表一个样本节点,每一列代表一个特征指标的特征值。Exemplarily, if the clustering coefficient, betweenness centrality, closeness centrality and a random feature are taken as the node feature vector of each sample node, then the node feature vector of sample node 1 can be [0.75, 0.33, 0.5, 0.42], where each number represents the value of a specific feature indicator. Further, when the feature indicators also include the node identifier and node degree, the node feature vector of sample node 1 can be [1, 2, 0.75, 0.33, 0.5, 0.42], and the node feature vector of sample node 2 can be [2, 4, 0.85, 0.54, 0.6, 0.59]. Further, when the sample graph data includes sample node 1 and sample node 2, the first feature matrix is composed of the node feature vectors of sample node 1 and sample node 2, where each row in the first feature matrix represents a sample node and each column represents the feature value of a feature indicator.
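
节点度与聚类系数等特征指标的计算,可用如下纯 Python 草图示意(NetworkX 等图论库提供等价的现成函数,如 nx.clustering、nx.betweenness_centrality;示例图为假设):The computation of feature indicators such as node degree and clustering coefficient can be sketched in pure Python as follows (graph libraries such as NetworkX provide equivalent ready-made functions, e.g. nx.clustering and nx.betweenness_centrality; the example graph is an assumption):

```python
def node_features(adj):
    # 为每个样本节点生成 [节点标识, 节点度, 聚类系数] 形式的节点特征向量
    rows = []
    for node in sorted(adj):
        nbrs = set(adj[node])
        k = len(nbrs)  # 节点度: 与相邻样本节点连接的边的数量
        # 聚类系数: 邻居节点之间实际存在的边数与可能的边数之比
        links = sum(1 for u in nbrs for v in adj[u] if v in nbrs) / 2
        cc = 2 * links / (k * (k - 1)) if k > 1 else 0.0
        rows.append([node, k, cc])
    return rows  # 多个节点特征向量按行堆叠, 即构成第一特征矩阵

# 示例: 三角形 1-2-3 外加与节点 1 相连的节点 4
adj = {1: [2, 3, 4], 2: [1, 3], 3: [1, 2], 4: [1]}
print(node_features(adj))
```

节点2、3位于三角形中,聚类系数为1.0;节点4只有一个邻居,聚类系数记为0。Nodes 2 and 3 lie in the triangle and get a clustering coefficient of 1.0; node 4 has a single neighbour and is assigned 0.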

通过以上方式,可以生成样本图数据的第一特征矩阵,以便于后续根据第一特征矩阵生成每个样本子图的初始特征矩阵,不必针对每个样本子图单独生成特征矩阵,提高了对数据处理的效率。In the above manner, a first characteristic matrix of the sample graph data can be generated, so that an initial characteristic matrix of each sample sub-graph can be subsequently generated according to the first characteristic matrix. It is not necessary to generate a characteristic matrix separately for each sample sub-graph, thereby improving the efficiency of data processing.

步骤202,获取各样本子图的初始特征矩阵和邻接矩阵,并依次将初始特征矩阵输入至预设模型中进行特征聚合,得到每个样本子图的融合特征矩阵。Step 202, obtaining the initial feature matrix and adjacency matrix of each sample subgraph, and sequentially inputting the initial feature matrix into a preset model for feature aggregation to obtain a fused feature matrix of each sample subgraph.

在一些实施方式中,可以通过融合样本子图中的中心样本节点和与中心样本节点相邻的样本节点的信息,来增强中心样本节点的特征表达能力,使其能够捕捉到更丰富的局部结构和上下文信息,以便于后续提取核心特征。In some embodiments, the feature expression capability of the central sample node can be enhanced by fusing the information of the central sample node and its adjacent sample nodes in the sample subgraph, so that it can capture richer local structure and context information, facilitating the subsequent extraction of core features.

其中,初始特征矩阵可以是每个样本子图的特征矩阵,初始特征矩阵包含用于描述样本子图中每个样本节点的节点特征的向量。具体的,初始特征矩阵可以根据每个样本节点的节点标识,从样本图数据对应的第一特征矩阵中查找对应的节点特征向量,并根据多个样本节点的节点特征向量,生成初始特征矩阵。The initial feature matrix may be a feature matrix of each sample subgraph, and the initial feature matrix includes a vector for describing the node features of each sample node in the sample subgraph. Specifically, the initial feature matrix may search for the corresponding node feature vector from the first feature matrix corresponding to the sample graph data according to the node identifier of each sample node, and generate the initial feature matrix according to the node feature vectors of multiple sample nodes.

其中,邻接矩阵可以是用于表示样本子图中各个样本节点的连接关系的矩阵,邻接矩阵是一个二维矩阵,其中的元素表示样本节点之间是否存在边。示例性的,对于一个具有n个节点的样本子图,邻接矩阵是一个n×n的矩阵,其中的元素用于记录样本节点之间的连接情况:邻接矩阵中的元素可以用0和1表示,其中1表示存在边,0表示不存在边。The adjacency matrix may be a matrix used to represent the connection relationship between each sample node in the sample subgraph, and the adjacency matrix is a two-dimensional matrix, in which the elements represent whether there is an edge between the sample nodes. Exemplarily, for a sample subgraph with n nodes, the adjacency matrix is an n×n matrix, in which the elements are used to record the connection status between the sample nodes: the elements in the adjacency matrix can be represented by 0 and 1, in which 1 represents the existence of an edge and 0 represents the absence of an edge.

其中,预设模型可以是尚未训练完毕的图神经网络模型,通过对预设模型进行训练并更新模型参数,在对预设模型训练完毕之后,可以得到泛化模型。The preset model may be a graph neural network model that has not yet been trained. By training the preset model and updating the model parameters, the generalization model can be obtained after the training of the preset model is completed.

其中,融合特征矩阵可以是对初始特征矩阵中,中心样本节点的特征向量更新后得到的特征矩阵。融合特征矩阵可以由预设模型通过对每个样本节点及其相邻样本节点的信息进行特征转换和聚合得到。The fused feature matrix may be a feature matrix obtained by updating the feature vector of the central sample node in the initial feature matrix. The fused feature matrix may be obtained by a preset model by performing feature conversion and aggregation on the information of each sample node and its adjacent sample nodes.

通过以上方式,可以针对每个样本子图,融合各样本节点的重要特征得到融合特征矩阵,使得预设模型对于样本子图的预测准确性更高,有利于后续的分类任务。Through the above method, for each sample subgraph, the important features of each sample node can be fused to obtain a fused feature matrix, so that the preset model has higher prediction accuracy for the sample subgraph, which is beneficial to subsequent classification tasks.

在一些实施方式中,可以通过样本子图中每个样本节点的节点标识,在样本图数据的第一特征矩阵中查找对应样本子图的节点特征向量,并生成每个样本子图的初始特征矩阵,以此提高初始特征矩阵的生成效率,示例性的,步骤202中的“获取各样本子图的初始特征矩阵和邻接矩阵”,包括:In some implementations, the node feature vector of the corresponding sample subgraph can be searched in the first feature matrix of the sample graph data according to the node identifier of each sample node in the sample subgraph, and an initial feature matrix of each sample subgraph can be generated, so as to improve the generation efficiency of the initial feature matrix. Exemplarily, "obtaining the initial feature matrix and adjacency matrix of each sample subgraph" in step 202 includes:

(202.a1)在每个样本子图中,确定每个样本节点的节点标识;(202.a1) In each sample subgraph, determine the node identity of each sample node;

(202.a2)基于节点标识,从第一特征矩阵中,确定每个样本节点对应的节点特征向量;(202.a2) Based on the node identifier, determine the node feature vector corresponding to each sample node from the first feature matrix;

(202.a3)根据样本子图的多个样本节点对应的多个节点特征向量,生成样本子图的初始特征矩阵;(202.a3) Generate an initial feature matrix of the sample subgraph according to multiple node feature vectors corresponding to multiple sample nodes of the sample subgraph;

(202.a4)基于样本节点之间的连接关系,生成特征矩阵对应的邻接矩阵。(202.a4) Generate the adjacency matrix corresponding to the feature matrix based on the connection relationship between sample nodes.

其中,节点标识可以是每个样本图数据和样本子图中的每个样本节点的唯一标识或索引值。通过节点标识,可以在样本图数据的第一特征矩阵中快速定位对应样本子图的节点特征向量。The node identifier may be a unique identifier or index value of each sample node in each sample graph data and sample subgraph. Through the node identifier, the node feature vector of the corresponding sample subgraph can be quickly located in the first feature matrix of the sample graph data.

示例性的,假设样本子图中包括多个样本节点,每个样本节点具有一个唯一的用户标识,例如user123、user456等等。之后,根据user123和user456从样本图数据中查找对应的样本节点的节点特征向量,例如[user123,2,0.75,0.33,0.5,0.42]、[user456,4,0.85,0.54,0.6,0.59]、[user789,3,0.66,0.24,0.6,0.5]等等,即可以查找到user123的节点特征向量为[user123,2,0.75,0.33,0.5,0.42],user456的节点特征向量为[user456,4,0.85,0.54,0.6,0.59],并根据[user123,2,0.75,0.33,0.5,0.42]和[user456,4,0.85,0.54,0.6,0.59]组成样本子图的初始特征矩阵。For example, assume that the sample subgraph includes multiple sample nodes, each of which has a unique user ID, such as user123, user456, etc. Then, according to user123 and user456, the node feature vectors of the corresponding sample nodes are found from the sample graph data, such as [user123, 2, 0.75, 0.33, 0.5, 0.42], [user456, 4, 0.85, 0.54, 0.6, 0.59], [user789, 3, 0.66, 0.24, 0.6, 0.5], etc., that is, the node feature vectors of user123 can be found. The node feature vector is [user123, 2, 0.75, 0.33, 0.5, 0.42], the node feature vector of user456 is [user456, 4, 0.85, 0.54, 0.6, 0.59], and the initial feature matrix of the sample subgraph is composed according to [user123, 2, 0.75, 0.33, 0.5, 0.42] and [user456, 4, 0.85, 0.54, 0.6, 0.59].

示例性的,当样本节点之间存在连接关系时,也即存在边时,可以生成对应的邻接矩阵。例如,当样本子图为社交网络图时,两个用户是朋友,那么两个用户之间存在连接关系,这两个用户在邻接矩阵中对应的元素就是1,否则就是0。Exemplarily, when there is a connection relationship between sample nodes, that is, when there is an edge, a corresponding adjacency matrix can be generated. For example, when the sample subgraph is a social network graph, if two users are friends, then there is a connection relationship between the two users, and the corresponding elements of the two users in the adjacency matrix are 1, otherwise they are 0.
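
(202.a1)至(202.a4)的查找与组装过程可示意如下(第一特征矩阵与节点标识沿用上文示例数值):The lookup-and-assembly process of (202.a1) to (202.a4) can be sketched as follows (the first feature matrix and node identifiers reuse the example values from the text above):

```python
def subgraph_matrices(first_feature_matrix, sub_nodes, sub_edges):
    # (202.a1)-(202.a3): 按节点标识从第一特征矩阵中查找节点特征向量,
    # 按行堆叠组成样本子图的初始特征矩阵
    lookup = {row[0]: row for row in first_feature_matrix}
    x = [lookup[n] for n in sub_nodes]
    # (202.a4): 按样本节点之间的连接关系生成邻接矩阵 (存在边为 1, 否则为 0)
    index = {n: i for i, n in enumerate(sub_nodes)}
    n = len(sub_nodes)
    a = [[0] * n for _ in range(n)]
    for u, v in sub_edges:
        a[index[u]][index[v]] = a[index[v]][index[u]] = 1
    return x, a

first = [["user123", 2, 0.75, 0.33, 0.5, 0.42],
         ["user456", 4, 0.85, 0.54, 0.6, 0.59],
         ["user789", 3, 0.66, 0.24, 0.6, 0.5]]
x, a = subgraph_matrices(first, ["user123", "user456"], [("user123", "user456")])
print(x[0][0], a)  # user123 [[0, 1], [1, 0]]
```

直接按标识查表复用第一特征矩阵,避免了为每个样本子图重新计算特征。Looking up rows by identifier reuses the first feature matrix and avoids recomputing features per subgraph.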

可以理解的是,直接利用已有的第一特征矩阵获取初始特征矩阵,可以避免重新计算样本子图中的初始特征矩阵,显著提高了生成初始特征矩阵的效率和准确性,有利于后续获取核心特征和冗余特征。It can be understood that directly using the existing first feature matrix to obtain the initial feature matrix can avoid recalculating the initial feature matrix in the sample subgraph, significantly improving the efficiency and accuracy of generating the initial feature matrix, and facilitating the subsequent acquisition of core features and redundant features.

需要说明的是,中心样本节点作为样本子图的核心,也即是样本子图的分类对象,其特征的好坏直接影响到了预设模型对样本子图的分析,因此,可以对样本子图进行特征融合,使得中心样本节点能够吸收和整合相连接的相邻样本节点的信息,从而得到更全面和更准确的中心样本节点的特征表示,例如,步骤202中的“依次将初始特征矩阵输入至预设模型中进行特征聚合,得到每个样本子图的融合特征矩阵”,包括:It should be noted that the central sample node is the core of the sample subgraph, that is, the classification object of the sample subgraph. The quality of its features directly affects the analysis of the sample subgraph by the preset model. Therefore, the sample subgraph can be feature fused so that the central sample node can absorb and integrate the information of the connected adjacent sample nodes, thereby obtaining a more comprehensive and accurate feature representation of the central sample node. For example, the step 202 of "inputting the initial feature matrix into the preset model for feature aggregation to obtain the fused feature matrix of each sample subgraph" includes:

(202.b1)根据每个样本子图的节点标识,从第一特征矩阵中确定样本子图的节点度矩阵;(202.b1) According to the node identifier of each sample subgraph, determine the node degree matrix of the sample subgraph from the first feature matrix;

(202.b2)获取样本子图的邻接矩阵,并通过节点度矩阵对邻接矩阵进行归一化,得到归一化后的邻接矩阵;(202.b2) Obtain the adjacency matrix of the sample subgraph and normalize the adjacency matrix using the node degree matrix to obtain the normalized adjacency matrix;

(202.b3)将中心样本节点作为融合中心,基于归一化后的邻接矩阵,进行样本节点特征的逐层融合,得到中心样本节点对应更新后的中心节点特征向量;(202.b3) Take the central sample node as the fusion center, and perform layer-by-layer fusion of sample node features based on the normalized adjacency matrix to obtain the updated central node feature vector corresponding to the central sample node;

(202.b4)根据中心节点特征向量对初始特征矩阵的中心样本节点的节点特征向量进行更新,得到每个样本子图的融合特征矩阵。(202.b4) Update the node feature vector of the central sample node of the initial feature matrix according to the central node feature vector to obtain the fused feature matrix of each sample subgraph.

其中,节点度矩阵可以是样本子图的对角矩阵,节点度矩阵的对角线上的每个元素代表图中对应节点的节点度。每个样本节点的节点度可以通过样本子图中每个样本节点的节点标识,从第一特征矩阵中获取。示例性的,当每个样本节点的节点特征向量由节点标识、节点度、聚类系数、介数中心性、接近中心性以及一个随机特征构成时,如 [user123,2,0.75,0.33,0.5,0.42],则当样本节点的节点标识为user123时,可以通过从样本图数据中查找到对应的节点特征向量为[user123,2,0.75,0.33,0.5,0.42],则可以查找到该样本节点的节点度为2。通过样本子图中的多个样本节点对应的多个节点度,可以生成样本子图的节点度矩阵。Among them, the node degree matrix can be a diagonal matrix of the sample subgraph, and each element on the diagonal of the node degree matrix represents the node degree of the corresponding node in the graph. The node degree of each sample node can be obtained from the first feature matrix through the node identification of each sample node in the sample subgraph. Exemplarily, when the node feature vector of each sample node is composed of node identification, node degree, clustering coefficient, betweenness centrality, closeness centrality and a random feature, such as [user123, 2, 0.75, 0.33, 0.5, 0.42], then when the node identification of the sample node is user123, the corresponding node feature vector can be found from the sample graph data as [user123, 2, 0.75, 0.33, 0.5, 0.42], then the node degree of the sample node can be found to be 2. The node degree matrix of the sample subgraph can be generated by multiple node degrees corresponding to multiple sample nodes in the sample subgraph.

其中,中心节点特征向量可以是每个样本子图中的中心样本节点的节点特征向量,中心样本节点是样本子图的分类对象,预设模型可以基于中心节点特征向量对中心样本节点进行分类。The central node feature vector may be a node feature vector of a central sample node in each sample subgraph, the central sample node is a classification object of the sample subgraph, and the preset model may classify the central sample node based on the central node feature vector.

具体的,中心节点特征向量的融合过程可以包括邻接矩阵归一化、特征聚合、线性变换和非线性激活、多层迭代和读出。示例性的,首先可以计算归一化的拉普拉斯矩阵,以减少度数大的样本节点在特征传播中的影响,实现特征的归一化,保持传播过程中特征尺度的稳定,之后,在每一层中,将样本节点特征向量和归一化的邻接矩阵相乘,使得每个样本节点均能够接收其相邻样本节点的信息,之后,将聚合后的特征通过权重矩阵进行线性变换,并应用激活函数更新样本节点的表示。进一步的,通过多个层的迭代,中心样本节点逐渐融合了更远的相邻样本节点的信息。在每一层中,中心节点的表示都会根据其相邻样本节点的节点特征向量进行更新。在完成所有层的迭代后,可以直接使用中心样本节点最后一层的节点特征向量特征作为中心样本节点的中心节点特征向量,或者可以通过池化所有相邻样本节点的节点特征向量来得到整个样本子图的表示,并将其作为中心样本节点的中心节点特征向量。Specifically, the fusion process of the central node feature vector may include adjacency matrix normalization, feature aggregation, linear transformation and nonlinear activation, multi-layer iteration and readout. Exemplarily, the normalized Laplacian matrix can be calculated first to reduce the influence of sample nodes with large degrees in feature propagation, realize feature normalization, and maintain the stability of feature scale during propagation. After that, in each layer, the sample node feature vector and the normalized adjacency matrix are multiplied so that each sample node can receive the information of its adjacent sample nodes. After that, the aggregated features are linearly transformed through the weight matrix, and the activation function is applied to update the representation of the sample node. Further, through the iteration of multiple layers, the central sample node gradually integrates the information of the more distant adjacent sample nodes. In each layer, the representation of the central node is updated according to the node feature vector of its adjacent sample nodes. After completing the iteration of all layers, the node feature vector feature of the last layer of the central sample node can be directly used as the central node feature vector of the central sample node, or the node feature vector of all adjacent sample nodes can be pooled to obtain the representation of the entire sample subgraph and used as the central node feature vector of the central sample node.

示例性的,可以通过图卷积神经网络(Graph Convolutional Networks,GCN)的方式融合子图信息,从而得到中心样本节点的中心节点特征向量。具体公式如下:For example, the subgraph information can be fused by means of graph convolutional networks (GCN) to obtain the central node feature vector of the central sample node. The specific formula is as follows:

h_v^{(l+1)} = g(D^{-1/2} A D^{-1/2} h_v^{(l)} W^{(l)})

其中,h_v^{(l+1)} 表示第l+1层的样本节点v的节点特征向量;g(·)表示激活函数,比如ReLU或tanh,用于引入非线性;D^{-1/2} 表示归一化的节点度矩阵 D 的逆平方根,D 是节点度矩阵,其对角线上的每个元素 D_{ii} 表示样本节点i的节点度;A表示样本子图的邻接矩阵,用于表示样本子图中各个样本节点间的连接关系;h_v^{(l)} 表示第l层的样本节点v的节点特征向量;W^{(l)} 表示第l层的权重矩阵,为预设模型训练学习后得到的模型参数。where h_v^{(l+1)} represents the node feature vector of the sample node v at layer l+1; g(·) represents the activation function, such as ReLU or tanh, which is used to introduce nonlinearity; D^{-1/2} represents the inverse square root of the node degree matrix D, each diagonal element D_{ii} of which represents the node degree of sample node i; A represents the adjacency matrix of the sample subgraph, which describes the connection relationships between the sample nodes in the sample subgraph; h_v^{(l)} represents the node feature vector of the sample node v at layer l; W^{(l)} represents the weight matrix of layer l, which is a model parameter learned during training of the preset model.

对于每个样本子图,融合子图信息过程包括通过节点度矩阵对邻接矩阵进行归一化,即 D^{-1/2} A D^{-1/2},以减少节点度大的样本节点在特征传播中的影响,实现特征的归一化,保持传播过程中特征尺度的稳定。For each sample subgraph, the process of fusing subgraph information includes normalizing the adjacency matrix by the node degree matrix, that is, D^{-1/2} A D^{-1/2}, in order to reduce the influence of sample nodes with large node degrees in feature propagation, realize feature normalization, and maintain the stability of the feature scale during propagation.

进一步的,可以将样本节点v在第l层的节点特征向量 h_v^{(l)} 和归一化的邻接矩阵相乘,使得在第l层中,每个样本节点均能够接收其相邻样本节点的信息。进一步的,可以在相乘之后,再乘以权重矩阵 W^{(l)} 并应用激活函数g(·),以将当前各个样本节点的节点特征从原始空间转换到新的特征空间中,并引入非线性激活。以此,能够得到样本子图更新后的特征向量 h_v^{(l+1)}。Furthermore, the node feature vector h_v^{(l)} of the sample node v at layer l can be multiplied with the normalized adjacency matrix, so that each sample node in layer l can receive information from its adjacent sample nodes. After this multiplication, the result can further be multiplied by the weight matrix W^{(l)} and the activation function g(·) applied, so as to transform the node features of each sample node from the original space into a new feature space and introduce nonlinear activation. In this way, the updated feature vector h_v^{(l+1)} of the sample subgraph can be obtained.

进一步的,在完成所有层的迭代后,可以直接使用中心样本节点最后一层的节点特征向量作为中心样本节点的表示,或者可以通过池化所有样本节点的节点特征向量来得到整个样本子图的表示,并以此作为中心样本节点的全局表示。Furthermore, after completing the iteration of all layers, the node feature vector of the last layer of the central sample node can be directly used as the representation of the central sample node, or the node feature vectors of all sample nodes can be pooled to obtain the representation of the entire sample subgraph and use it as the global representation of the central sample node.

进一步的,可以更新得到中心样本节点的中心节点特征向量,并用中心节点特征向量替换中心样本节点在初始特征矩阵中的节点特征向量,由此,得到每个样本子图的融合特征矩阵。Furthermore, the central node feature vector of the central sample node can be updated, and the node feature vector of the central sample node in the initial feature matrix can be replaced by the central node feature vector, thereby obtaining a fusion feature matrix of each sample subgraph.

通过以上方式,可以利用节点度矩阵和邻接矩阵对样本子图的节点特征向量进行归一化处理,能够减少样本子图中节点度较大的样本节点对特征传播的影响,保持特征传播过程中的尺度稳定性,使得每个样本节点能够更加公平地接收来自相邻的样本节点的信息,并且,通过逐层融合中心样本节点的相邻样本节点的节点特征向量,可以有效地聚合样本子图内的信息,增强中心样本节点的表征能力。Through the above method, the node degree matrix and the adjacency matrix can be used to normalize the node feature vectors of the sample subgraph, which can reduce the influence of sample nodes with larger node degrees in the sample subgraph on feature propagation and maintain the scale stability in the feature propagation process, so that each sample node can receive information from adjacent sample nodes more fairly. In addition, by fusing the node feature vectors of the adjacent sample nodes of the central sample node layer by layer, the information in the sample subgraph can be effectively aggregated to enhance the representation ability of the central sample node.
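
与上文公式 h_v^{(l+1)} = g(D^{-1/2} A D^{-1/2} h_v^{(l)} W^{(l)}) 对应的单层特征聚合,可用如下纯 Python 草图验证(矩阵与权重均为演示假设,激活函数取 ReLU):A single layer of feature aggregation corresponding to the formula above can be checked with the following pure-Python sketch (the matrices and weights are illustrative assumptions; ReLU is used as the activation):

```python
import math

def matmul(a, b):
    # 朴素矩阵乘法 / naive matrix multiplication
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def gcn_layer(adj, h, w):
    # 归一化: D^{-1/2} A D^{-1/2}, 其中 D 为节点度矩阵 (对角元素为各节点的节点度)
    n = len(adj)
    d_inv_sqrt = [1.0 / math.sqrt(sum(row)) for row in adj]
    a_norm = [[d_inv_sqrt[i] * adj[i][j] * d_inv_sqrt[j] for j in range(n)]
              for i in range(n)]
    # 特征聚合 + 线性变换 + ReLU 非线性激活
    z = matmul(matmul(a_norm, h), w)
    return [[max(0.0, v) for v in row] for row in z]

# 3 节点样本子图: 中心节点 0 与节点 1、2 相连; w 取单位阵便于核对
adj = [[0, 1, 1], [1, 0, 0], [1, 0, 0]]
h = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
w = [[1.0, 0.0], [0.0, 1.0]]
print(gcn_layer(adj, h, w))
```

中心节点0的输出为其两个邻居特征之和乘以 1/√2,体现了归一化后的邻居信息聚合。The output for center node 0 is the sum of its two neighbours' features scaled by 1/√2, illustrating normalized neighbour aggregation.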

步骤203,根据邻接矩阵和融合特征矩阵,确定每个样本子图的核心特征和冗余特征。Step 203: Determine the core features and redundant features of each sample subgraph according to the adjacency matrix and the fusion feature matrix.

在一些实施方式中,可以通过邻接矩阵和融合特征矩阵对样本子图的核心特征和冗余特征进行计算,以便于后续通过核心特征和冗余特征,对预设模型进行训练,使得预设模型更加专注于因果特征,并减少对冗余特征的关注程度。In some embodiments, the core features and redundant features of the sample subgraph can be calculated through the adjacency matrix and the fusion feature matrix, so that the preset model can be trained through the core features and redundant features later, so that the preset model focuses more on causal features and reduces the attention to redundant features.

在一些实施方式中,可以在更新得到每个样本子图的中心节点特征向量之后,根据中心样本节点的节点标识,在样本图数据的第一特征矩阵中,更新该中心样本节点的中心节点特征向量。并在第一特征矩阵中的每个中心节点特征向量更新后,同步在所有样本子图中更新对应的节点特征向量。以此,能够尽可能地丰富每个样本子图中的样本节点的节点特征向量的表示。或者,也可以只对每个样本子图中的中心样本节点的中心节点特征向量进行更新,以使得预设模型在训练的过程中,能够更加关注核心特征。In some embodiments, after the central node feature vector of each sample subgraph is updated, the central node feature vector of the central sample node can be updated in the first feature matrix of the sample graph data according to the node identifier of the central sample node. And after each central node feature vector in the first feature matrix is updated, the corresponding node feature vector is updated synchronously in all sample subgraphs. In this way, the representation of the node feature vector of the sample node in each sample subgraph can be enriched as much as possible. Alternatively, only the central node feature vector of the central sample node in each sample subgraph can be updated, so that the preset model can pay more attention to the core features during the training process.

示例性的,可以针对样本子图中的任意两个样本节点的节点特征向量,将节点特征向量输入至多层感知机中,得到对应的边掩码值,并基于该样本子图对应的所有边掩码值生成边掩码矩阵。进一步的,可以将边掩码矩阵和对应样本子图的邻接矩阵相乘,以对各个边的权重进行调整,之后,再与融合特征矩阵结合,得到核心结构图,将核心结构图经过预设模型的处理,得到核心特征。可以理解的是,通过边掩码矩阵,可以调整邻接矩阵中各个存在的边的权重,使得样本节点间重要的边更加突出。Exemplarily, the node feature vectors of any two sample nodes in the sample subgraph can be input into a multilayer perceptron to obtain the corresponding edge mask values, and an edge mask matrix can be generated based on all edge mask values corresponding to the sample subgraph. Furthermore, the edge mask matrix can be multiplied with the adjacency matrix of the corresponding sample subgraph to adjust the weight of each edge, and then combined with the fusion feature matrix to obtain the core structure graph; the core structure graph is then processed by the preset model to obtain the core features. It can be understood that, through the edge mask matrix, the weight of each existing edge in the adjacency matrix can be adjusted so that the important edges between sample nodes become more prominent.

示例性的,可以将边掩码矩阵的补集矩阵与对应样本子图的邻接矩阵相乘,以对各个边的权重进行调整,之后,再与融合特征矩阵结合,得到冗余结构图,将冗余结构图经过预设模型的处理,得到冗余特征。可以理解的是,由于边掩码矩阵包含了各个样本节点之间的边的权重值,那么边掩码矩阵的补集矩阵则包含了各个样本节点之间的边的冗余部分。Exemplarily, the complement matrix of the edge mask matrix can be multiplied with the adjacency matrix of the corresponding sample subgraph to adjust the weight of each edge, and then combined with the fusion feature matrix to obtain the redundant structure graph; the redundant structure graph is then processed by the preset model to obtain the redundant features. It can be understood that, since the edge mask matrix contains the weight values of the edges between sample nodes, the complement matrix of the edge mask matrix contains the redundant parts of the edges between sample nodes.

通过以上方式可以计算得到的核心特征和冗余特征,可以便于后续对预设模型进行训练时,使得预设模型更专注于因果特征,而减少对捷径特征的关注。The core features and redundant features calculated in the above manner can facilitate the subsequent training of the preset model, so that the preset model can focus more on causal features and reduce the focus on shortcut features.

在一些实施方式中,为了准确获取核心特征和冗余特征,可以通过边掩码矩阵对邻接矩阵的权重进行调整,以量化两个样本节点的边的重要性,示例性的,步骤203可以包括:In some implementations, in order to accurately obtain core features and redundant features, the weight of the adjacency matrix may be adjusted by an edge mask matrix to quantify the importance of the edge between two sample nodes. Exemplarily, step 203 may include:

(203.1)基于样本子图的融合特征矩阵,将样本子图中任意两个样本节点输入至多层感知机,得到任意两个样本节点之间的边掩码值;(203.1) Based on the fusion feature matrix of the sample subgraph, any two sample nodes in the sample subgraph are input into the multi-layer perceptron to obtain the edge mask value between any two sample nodes;

(203.2)根据样本子图中的多个边掩码值,生成样本子图的边掩码矩阵;(203.2) generating an edge mask matrix of the sample subgraph according to a plurality of edge mask values in the sample subgraph;

(203.3)基于邻接矩阵、融合特征矩阵和边掩码矩阵,生成样本子图对应的核心特征;(203.3) Generate the core features corresponding to the sample subgraph based on the adjacency matrix, fusion feature matrix and edge mask matrix;

(203.4)根据样本子图的边掩码矩阵,生成对应的补集矩阵;(203.4) Generate the corresponding complement matrix based on the edge mask matrix of the sample subgraph;

(203.5)基于邻接矩阵、融合特征矩阵和补集矩阵,得到样本子图的冗余特征。(203.5) Based on the adjacency matrix, fusion feature matrix and complement matrix, the redundant features of the sample subgraph are obtained.

其中,边掩码值可以是针对样本子图中的每一条边,通过多层感知机的计算得到的标量值。边掩码值表示了两个样本节点之间的边的重要性。具体的,边掩码值可以通过将样本子图中的任意两个样本节点输入多层感知机计算得到,例如,边的边掩码值可以为0、0.1、0.6等等。The edge mask value can be a scalar value obtained by calculating the multi-layer perceptron for each edge in the sample subgraph. The edge mask value indicates the importance of the edge between two sample nodes. Specifically, the edge mask value can be calculated by inputting any two sample nodes in the sample subgraph into the multi-layer perceptron. For example, the edge mask value of the edge can be 0, 0.1, 0.6, etc.

其中,边掩码矩阵可以是通过将样本子图中任意两个样本节点的节点特征向量依次输入多层感知机,直至所有的样本节点均输入完毕,再根据所有边掩码值生成的矩阵。The edge mask matrix may be a matrix generated from all edge mask values after the node feature vectors of any two sample nodes of the sample subgraph are sequentially input into the multilayer perceptron until all sample nodes have been input.

具体的,可以在样本子图G中,将任意两个样本节点两两配对后,将两个样本节点的节点特征向量输入至多层感知机(Multilayer Perceptron,MLP),得到一个在0到1之间的边掩码值,得到该样本子图中所有样本节点之间的边掩码值之后,即可构成样本子图G的边掩码矩阵M。Specifically, in the sample subgraph G, after pairing any two sample nodes, the node feature vectors of the two sample nodes can be input into a multilayer perceptron (MLP) to obtain an edge mask value between 0 and 1. After the edge mask values between all sample nodes in the sample subgraph are obtained, the edge mask matrix M of the sample subgraph G can be constructed.

进一步的,可以将边掩码矩阵M和样本子图G的邻接矩阵A按照逐元素相乘之后,与融合特征矩阵X结合,得到核心结构图G_c=(M⊙A, X)。Furthermore, the edge mask matrix M and the adjacency matrix A of the sample subgraph G can be multiplied element by element, and then combined with the fusion feature matrix X to obtain the core structure graph G_c=(M⊙A, X).

进一步的,补集矩阵为1−M,将补集矩阵1−M和样本子图G的邻接矩阵A按照逐元素相乘之后,与融合特征矩阵X结合,得到冗余结构图G_r=((1−M)⊙A, X)。Furthermore, the complement matrix is 1−M; after the complement matrix 1−M and the adjacency matrix A of the sample subgraph G are multiplied element by element, the result is combined with the fusion feature matrix X to obtain the redundant structure graph G_r=((1−M)⊙A, X).

具体的,可以将核心结构图和冗余结构图通过预设模型的处理,分别得到核心特征和冗余特征。Specifically, the core structure graph and the redundant structure graph can be processed by the preset model to obtain the core features and the redundant features, respectively.
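步骤(203.1)~(203.5)的计算流程可以用如下示意代码表达(多层感知机以随机初始化的两层网络代替,W1、W2 为假设的随机权重,仅作说明;真实训练中其参数由目标损失共同学习)。A sketch of steps (203.1)–(203.5), with the MLP replaced by a randomly initialized two-layer network for illustration only:

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def edge_mask_matrix(A, X, W1, W2):
    """步骤(203.1)~(203.2)示意: 对每条存在的边, 将两端样本节点的
    特征向量拼接后输入两层感知机, 得到 0~1 之间的边掩码值。"""
    n = A.shape[0]
    M = np.zeros_like(A)
    for i in range(n):
        for j in range(n):
            if A[i, j] > 0:
                h = np.tanh(np.concatenate([X[i], X[j]]) @ W1)
                M[i, j] = sigmoid(h @ W2)    # 边掩码值, 量化该边的重要性
    return M

n, d, hidden = 3, 4, 8
A = np.array([[0., 1., 1.],
              [1., 0., 0.],
              [1., 0., 0.]])                 # 样本子图的邻接矩阵
X = rng.normal(size=(n, d))                  # 融合特征矩阵(随机示意)
W1 = rng.normal(size=(2 * d, hidden))
W2 = rng.normal(size=(hidden,))

M = edge_mask_matrix(A, X, W1, W2)           # 步骤(203.2): 边掩码矩阵
A_core = M * A                               # 步骤(203.3): 核心结构图的加权邻接
A_red = (1.0 - M) * A                        # 步骤(203.4)~(203.5): 补集矩阵加权
# A_core、A_red 分别与融合特征矩阵 X 结合, 再经预设模型得到核心特征与冗余特征
```

注意核心结构图与冗余结构图的加权邻接矩阵之和恰为原邻接矩阵,即两者互为补集。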

通过以上方式,可以更好地捕捉样本子图的结构信息,尽可能地提高预设模型对于样本子图的预测准确性和解释性,并通过生成冗余特征,可以减少冗余数据,如噪声数据对预设模型性能的负面影响,以便于后续提高预设模型对核心特征的关注程度,减少对冗余特征的依赖。Through the above methods, the structural information of the sample subgraph can be better captured, and the prediction accuracy and interpretability of the preset model for the sample subgraph can be improved as much as possible. By generating redundant features, redundant data, such as the negative impact of noise data on the performance of the preset model, can be reduced, so as to subsequently improve the preset model's attention to core features and reduce dependence on redundant features.

步骤204,基于核心特征和冗余特征,确定预设模型的目标损失。Step 204, determining the target loss of the preset model based on the core features and the redundant features.

其中,泛化模型为通过对预设模型进行训练之后,预设模型收敛或者达到了预设训练次数后,停止训练得到的模型。The generalized model is a model obtained by training a preset model and stopping training after the preset model converges or reaches a preset number of training times.

其中,目标损失可以是由第一损失、第二损失和第三损失构成的损失,用于指示对预设模型的参数的调节,以对预设模型进行训练,得到泛化模型。具体的,第一损失可以为交叉熵损失,第二损失可以为熵损失,第三损失可以为对比损失,通过将第一损失、第二损失和第三损失构成目标损失,并利用目标损失来更新模型参数,可以使泛化模型更加专注地学习核心特征,避免泛化模型过度关注冗余特征,从而提高整体特征表征的质量。Among them, the target loss can be a loss composed of a first loss, a second loss, and a third loss, which is used to indicate the adjustment of the parameters of the preset model to train the preset model to obtain a generalized model. Specifically, the first loss can be a cross entropy loss, the second loss can be an entropy loss, and the third loss can be a contrast loss. By forming the target loss with the first loss, the second loss, and the third loss, and using the target loss to update the model parameters, the generalized model can be made to learn the core features more attentively, avoid the generalized model from over-focusing on redundant features, and thus improve the quality of the overall feature representation.

可以理解的是,可以通过将核心特征和冗余特征输入至预设模型中,针对样本子图的核心特征输出的预测节点类别标签相对于核心特征的样本节点类别标签之间的差异确定第一损失、基于预设模型针对样本子图的冗余特征预测各样本节点类别标签时的概率分布均匀度确定第二损失、根据核心特征和冗余特征之间的样本距离确定第三损失,以此,能够使得泛化模型将注意力集中在因果特征,即核心特征上,减少对不稳定的冗余特征的依赖,减少过拟合的风险,同时做出更可靠的决策。It can be understood that by inputting the core features and redundant features into the preset model, the first loss can be determined by the difference between the predicted node category label output by the core features of the sample subgraph and the sample node category label of the core features, the second loss can be determined based on the uniformity of the probability distribution when predicting each sample node category label for the redundant features of the sample subgraph based on the preset model, and the third loss can be determined according to the sample distance between the core features and the redundant features. In this way, the generalization model can focus on the causal features, i.e., the core features, reduce the reliance on unstable redundant features, reduce the risk of overfitting, and make more reliable decisions.

在一些实施方式中,为了使得泛化模型更好地捕捉数据的本质结构,从而在面对不同分布的测试数据时具有更好的泛化能力,可以基于核心特征和冗余特征对预设模型进行训练,使得预设模型具备更好的分类性能,例如,步骤204可以包括:In some implementations, in order to enable the generalization model to better capture the essential structure of the data, thereby having better generalization ability when facing test data with different distributions, the preset model can be trained based on the core features and redundant features so that the preset model has better classification performance. For example, step 204 may include:

(204.1)基于针对样本子图的核心特征输出的预测节点类别标签相对于核心特征的样本节点类别标签之间的差异确定第一损失;(204.1) determining a first loss based on a difference between a predicted node category label output for a core feature of a sample subgraph and a sample node category label of the core feature;

(204.2)基于预设模型针对样本子图的冗余特征预测各样本节点类别标签时的概率分布均匀度确定第二损失;(204.2) Determine the second loss based on the uniformity of probability distribution when predicting the category label of each sample node according to the redundant features of the sample subgraph based on the preset model;

(204.3)根据核心特征和冗余特征之间的样本距离确定第三损失;(204.3) Determine the third loss based on the sample distance between the core features and the redundant features;

(204.4)基于第一损失、第二损失和第三损失之和,构成预设模型的目标损失。(204.4) Based on the sum of the first loss, the second loss and the third loss, the target loss of the preset model is constructed.

其中,第一损失可以是交叉熵损失,用于评估预设模型对于核心特征预测的各预测节点类别标签的概率分布与核心特征的样本节点类别标签之间的差异。具体的,可以针对每个样本节点类别标签设置独热编码,将每个样本节点类别标签都用一个二进制向量表示,并只设置核心特征对应的真实的样本节点类别标签为1,其余样本节点类别标签为0。具体的,第一损失可以为:The first loss may be a cross entropy loss, which is used to evaluate the difference between the probability distribution of each predicted node category label predicted by the preset model for the core feature and the sample node category label of the core feature. Specifically, a one-hot encoding may be set for each sample node category label, each sample node category label is represented by a binary vector, and only the real sample node category label corresponding to the core feature is set to 1, and the remaining sample node category labels are set to 0. Specifically, the first loss may be:

L1 = −∑_v y_v·log(p_v)

其中,y_v 是样本节点类别标签的独热编码,p_v 是模型对第v个样本节点类别标签的第一预测概率。当预测节点类别标签与核心特征对应的真实的样本节点类别标签一致时,第一损失最小,由此,可以通过最小化第一损失,使得预设模型能够更好地学习数据的因果关系,从而不断作出更加准确的预测。where y_v is the one-hot encoding of the sample node category label, and p_v is the model's first predicted probability for the v-th sample node category label. When the predicted node category label is consistent with the real sample node category label corresponding to the core feature, the first loss is minimal. Therefore, by minimizing the first loss, the preset model can better learn the causal relationship of the data, thereby continuously making more accurate predictions.
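第一损失(交叉熵)的性质可以用如下示意代码验证:第一预测概率越集中于真实的样本节点类别标签,损失越小(示意实现,函数与变量名均为假设)。A minimal sketch of the first (cross-entropy) loss:

```python
import math

def first_loss(one_hot, probs):
    """第一损失(交叉熵)示意: one_hot 为样本节点类别标签的独热编码,
    probs 为各预测节点类别标签的第一预测概率。"""
    return -sum(y * math.log(p) for y, p in zip(one_hot, probs) if y > 0)

# 真实样本节点类别标签为第 0 类
loss_good = first_loss([1, 0, 0], [0.9, 0.05, 0.05])  # 预测接近真实标签
loss_bad = first_loss([1, 0, 0], [0.2, 0.5, 0.3])     # 预测偏离真实标签
```

当真实标签上的预测概率为1时,第一损失为0;预测越偏离真实标签,损失越大。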

其中,预测节点类别标签为预设模型对于核心特征预测得到的多个节点类别标签。例如,对于一个动物图像,可以得到预测节点类别标签为猫、狗、兔子,每个预测节点类别标签都对应一个预测概率。The predicted node category labels are multiple node category labels predicted by the preset model for the core features. For example, for an animal image, the predicted node category labels may be cat, dog, and rabbit, and each predicted node category label corresponds to a predicted probability.

其中,第二损失可以是熵损失,用于衡量预设模型针对样本子图的冗余特征预测各样本节点类别标签时的概率分布均匀度,概率分布越均匀,表明预设模型对于冗余特征的分类偏好越小,由此,能够有效地防止预设模型学习捷径特征。具体的,第二损失可以为:The second loss may be an entropy loss, which is used to measure the uniformity of the probability distribution of the preset model when predicting the category labels of each sample node based on the redundant features of the sample subgraph. The more uniform the probability distribution is, the smaller the classification preference of the preset model for the redundant features is, thereby effectively preventing the preset model from learning shortcut features. Specifically, the second loss may be:

L2 = −∑_v u·log(q_v)

其中,u 代表均匀分布时对每个样本节点类别标签的均匀概率,通常 u=1/类别数,q_v 是预设模型对第v个样本节点类别标签的第二预测概率。where u represents the uniform probability of each sample node category label under a uniform distribution, usually u = 1/(number of categories), and q_v is the second predicted probability of the preset model for the v-th sample node category label.

具体来说,当冗余特征输入至预设模型之后,输出的第二预测概率都接近于均匀分布,则表明冗余特征对于预设模型的分类并无帮助。示例性的,若样本节点类别标签共有4类,分别为猫、兔、狗、猪,那么每个类别的均匀概率为1/4,若将冗余特征1输入至预设模型,得到对猫、兔、狗、猪的预测概率均为1/4时,预设模型的第二损失也最小,此时,表明预设模型对于冗余特征并无分类偏好,输出的概率分布均匀度高。因此,需要通过不断对预设模型进行训练,使得预设模型对于冗余特征在各个样本节点类别标签下的预测概率接近于均匀分布。Specifically, when redundant features are input into the preset model and the output second predicted probabilities are all close to a uniform distribution, it indicates that the redundant features are not helpful for the classification of the preset model. Exemplarily, if there are 4 sample node category labels, namely cat, rabbit, dog, and pig, the uniform probability of each category is 1/4; if redundant feature 1 is input into the preset model and the predicted probabilities of cat, rabbit, dog, and pig are all 1/4, the second loss of the preset model is also the smallest. At this time, it shows that the preset model has no classification preference for the redundant features, and the output probability distribution is highly uniform. Therefore, the preset model needs to be trained continuously so that its predicted probability for the redundant features under each sample node category label is close to a uniform distribution.

其中,概率分布均匀度可以是在冗余特征输入至预设模型之后,预设模型对冗余特征预测输出的各个样本节点类别标签下的第二预测概率的均匀程度。各第二预测概率分布越均匀,表明概率分布均匀度越高,例如,第二预测概率为1/3、1/3、1/3的概率分布均匀度大于2/3、1/6、1/6。The uniformity of the probability distribution may be the degree of uniformity of the second predicted probabilities under each sample node category label that the preset model outputs for the redundant features after the redundant features are input into the preset model. The more uniform the second predicted probabilities are, the higher the uniformity of the probability distribution; for example, the probability distribution with second predicted probabilities of 1/3, 1/3, and 1/3 is more uniform than that of 2/3, 1/6, and 1/6.
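第二损失(相对均匀分布的熵损失)可以用如下示意代码说明:第二预测概率越接近均匀分布,损失越小(示意实现,函数名为假设)。A minimal sketch of the second loss against the uniform distribution:

```python
import math

def second_loss(probs):
    """第二损失示意: 以均匀概率 u = 1/类别数 为参照的交叉熵,
    第二预测概率 probs 越接近均匀分布, 损失越小。"""
    u = 1.0 / len(probs)
    return -sum(u * math.log(q) for q in probs)

uniform = second_loss([0.25, 0.25, 0.25, 0.25])  # 概率分布均匀度最高, 损失最小
skewed = second_loss([0.70, 0.10, 0.10, 0.10])   # 模型对冗余特征有分类偏好
```

当预测概率恰为均匀分布时,该损失取得最小值 log(类别数);分布越偏斜,损失越大。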

其中,第三损失可以是对比损失,具体的,可以将核心特征作为正样本,冗余特征作为负样本,对比损失可以激励预设模型能够正确区分正样本和负样本,使正样本和正样本之间的距离更近,正样本和负样本之间的距离更远。进一步的,第三损失可以为:The third loss may be a contrast loss. Specifically, the core features may be used as positive samples and the redundant features may be used as negative samples. The contrast loss may encourage the preset model to correctly distinguish between positive samples and negative samples, so that the distance between positive samples is closer and the distance between positive and negative samples is farther. Furthermore, the third loss may be:

L3 = −E[log σ(C·C⁺/Q)] − E[log σ(−C·C⁻/Q)]

其中,σ 是sigmoid函数,通常用来将输入映射到(0,1)之间,Q为超参数,C和C⁺是核心特征的不同实例,(C, C⁺)表示正样本对,C和C⁺均表示正样本,即核心特征;(C, C⁻)表示正负样本对,C⁻表示负样本,即冗余特征;E表示期望值,可以通过预设模型对所有正样本对的核心特征的第一预测概率,和对所有负样本对的冗余特征的多个第二预测概率取平均得到。通过第三损失训练预设模型,可以激励预设模型能够正确区分正样本和负样本,使得正样本和正样本之间的距离更近,正样本和负样本之间的距离更远。where σ is the sigmoid function, usually used to map the input to (0, 1), Q is a hyperparameter, C and C⁺ are different instances of the core features, (C, C⁺) represents a positive sample pair, and both C and C⁺ represent positive samples, i.e., core features; (C, C⁻) represents a positive-negative sample pair, and C⁻ represents a negative sample, i.e., redundant features; E represents the expected value, which can be obtained by averaging the first predicted probabilities of the core features over all positive sample pairs and the second predicted probabilities of the redundant features over all negative sample pairs by the preset model. By training the preset model with the third loss, the preset model can be encouraged to correctly distinguish positive samples from negative samples, so that the distance between positive samples becomes closer and the distance between positive and negative samples becomes farther.
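以下代码给出对比损失的一种常见实现示意(以内积相似度与sigmoid为例,并非本申请给出的唯一形式;所有函数与变量名均为假设)。A common illustrative form of the contrastive loss (inner-product similarity plus sigmoid; not necessarily the exact form of this application):

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def third_loss(core, core_pos, redundant, Q=1.0):
    """第三损失(对比损失)示意: 以内积衡量样本距离,
    拉近正样本对 (C, C+), 推远正负样本对 (C, C-); Q 为温度式超参数。"""
    pos = -math.log(sigmoid(dot(core, core_pos) / Q))      # 正样本对项
    neg = -math.log(sigmoid(-dot(core, redundant) / Q))    # 正负样本对项
    return pos + neg

# 核心特征彼此靠近、且与冗余特征远离时, 第三损失更小
aligned = third_loss([1.0, 0.0], [0.9, 0.1], [-0.8, 0.2])
confused = third_loss([1.0, 0.0], [-0.9, 0.1], [0.8, 0.2])
```

当正样本对相似而正负样本对不相似时损失较小,反之损失增大,从而驱动嵌入空间区分核心特征与冗余特征。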

其中,样本距离可以是两个核心特征之间的距离或者核心特征和冗余特征之间的距离,样本距离可以通过欧氏距离和余弦相似度等方式计算得到。The sample distance may be the distance between two core features or the distance between a core feature and a redundant feature. The sample distance may be calculated by Euclidean distance, cosine similarity, and the like.

其中,核心特征可以是样本子图中具有因果关系或者因果解释能力的特征,核心特征具有待预测的样本节点与样本节点类别标签的因果信息。Among them, the core feature can be a feature with causal relationship or causal explanation ability in the sample subgraph, and the core feature has causal information of the sample node to be predicted and the sample node category label.

其中,冗余特征可以是核心特征和预测结果之间的捷径特征,对于预测样本节点与样本节点类别标签并无额外的贡献,不具备因果解释特性。并且,冗余特征会使得预设模型倾向于学习捷径特征作出决策,降低预设模型对于样本图数据预测的准确性,从而导致训练得到的泛化模型在分布外测试数据中性能下降。示例性的,冗余特征可以是噪声特征。Among them, redundant features can be shortcut features between core features and prediction results, which have no additional contribution to the prediction of sample nodes and sample node category labels, and do not have causal explanation characteristics. Moreover, redundant features will make the preset model tend to learn shortcut features to make decisions, reduce the accuracy of the preset model's prediction of sample graph data, and thus cause the performance of the generalized model obtained by training to degrade in out-of-distribution test data. Exemplarily, redundant features can be noise features.

示例性的,若样本图数据为社交网络图,预设模型需要对社交网络中的用户进行分类,标签包括“学生”、“工作者”、“自由职业者”等。样本子图是该用户(即待预测的样本节点)在社交网络上的活动记录,如帖子、照片、互动等。核心特征可以包括用户的职业信息、教育背景等。将核心特征输入至预设模型之后,预设模型对应输出该用户是学生、工作者和自由职业者的第一预测概率。进一步的,可以使用交叉熵损失作为第一损失,计算预设模型输出的在各预测节点类别标签下的概率分布与实际类别标签(用独热编码表示)之间的差异。例如,如果实际的样本类别标签为学生,则独热编码为[1,0,0];如果预测概率是[0.7,0.2,0.1],则第一损失将计算这两者之间的交叉熵损失,并调整预设模型的参数,以鼓励预设模型对该样本子图中的待预测的样本节点进行预测,提高对学生类别的预测准确率。Exemplarily, if the sample graph data is a social network graph, the preset model needs to classify users in the social network, and the labels include "student", "worker", "freelancer", etc. The sample subgraph is the activity record of the user (i.e., the sample node to be predicted) on the social network, such as posts, photos, interactions, etc. The core features may include the user's occupational information, educational background, etc. After the core features are input into the preset model, the preset model outputs the first predicted probability that the user is a student, a worker, and a freelancer. Further, the cross entropy loss can be used as the first loss to calculate the difference between the probability distribution under each predicted node category label output by the preset model and the actual category label (represented by one-hot encoding). For example, if the actual sample category label is a student, the one-hot encoding is [1, 0, 0]; if the predicted probability is [0.7, 0.2, 0.1], the first loss will calculate the cross entropy loss between the two, and adjust the parameters of the preset model to encourage the preset model to predict the sample nodes to be predicted in the sample subgraph, thereby improving the prediction accuracy for the student category.

示例性的,冗余特征可以是该用户所在地理位置的天气信息,虽然这看起来与分类任务相关,但实际上冗余特征可能带来偏差且对用户的职业信息的分类无实际帮助,因此,可以通过计算熵损失作为第二损失,当预设模型对冗余特征的输出趋于均匀分布时,表示预设模型没有从这些冗余特征中学习到任何有偏的分类决策,以此鼓励预设模型主要关注有实质性帮助于分类的核心特征。Exemplarily, redundant features may be weather information at the user's geographic location. Although this appears to be relevant to the classification task, redundant features may actually introduce bias and are of no practical help in classifying the user's occupational information. Therefore, entropy loss may be calculated as the second loss. When the output of the preset model for redundant features tends to be evenly distributed, it indicates that the preset model has not learned any biased classification decisions from these redundant features, thereby encouraging the preset model to focus mainly on core features that are of substantial help in classification.

进一步的,在对预设模型进行训练的过程中,可以通过对比损失来使预设模型能够区分相关特征(即核心特征)和非相关特征(即冗余特征),并尽可能学习将核心特征(正样本)之间彼此靠近,同时将冗余特征(负样本)与核心特征远离。进一步的,通过对比损失作为第三损失,计算正样本对之间的相似度,并惩罚正负样本对之间的相似度,以使学习到的嵌入空间能够更好地区分相关与不相关的信息。Furthermore, in the process of training the preset model, the contrast loss can be used to enable the preset model to distinguish relevant features (i.e., core features) from irrelevant features (i.e., redundant features), and learn to bring core features (positive samples) closer to each other as much as possible, while moving redundant features (negative samples) away from core features. Furthermore, the contrast loss is used as the third loss to calculate the similarity between positive sample pairs, and to penalize the similarity between positive and negative sample pairs, so that the learned embedding space can better distinguish relevant and irrelevant information.

可以理解的是,通过结合第一、第二和第三损失构成目标损失,可以确保预设模型在不同方面有均衡的性能提升,目标损失的最优化会帮助预设模型在学习样本的因果特性的同时,忽视对分类无帮助的冗余特征,并且更好地区分与预测任务相关的核心特征。It can be understood that by combining the first, second and third losses to form the target loss, it can ensure that the preset model has a balanced performance improvement in different aspects. The optimization of the target loss will help the preset model to learn the causal characteristics of the samples while ignoring redundant features that are not helpful for classification, and better distinguish the core features related to the prediction task.

步骤205,基于目标损失更新预设模型的参数,当预设模型收敛时,得到训练好的泛化模型。Step 205, updating the parameters of the preset model based on the target loss, and when the preset model converges, a trained generalized model is obtained.

进一步的,在对预设模型训练的过程中,可以根据目标损失来不断更新预设模型的参数,例如通过梯度下降算法。当预设模型的损失不再显著降低,或达到某个终止条件(例如达到预设的训练次数或预设模型收敛)时,即可停止对预设模型的训练,得到泛化模型。Furthermore, during the training of the preset model, the parameters of the preset model can be continuously updated according to the target loss, for example, by a gradient descent algorithm. When the loss of the preset model is no longer significantly reduced, or a certain termination condition is reached (for example, a preset number of training times is reached or the preset model converges), the training of the preset model can be stopped to obtain a generalized model.
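步骤205的收敛判定可以用如下示意代码表达(以损失下降幅度小于阈值作为收敛条件;step_fn 为假设的单步训练函数,并非本申请的具体实现)。A sketch of the convergence check in step 205, where step_fn is a hypothetical one-step training function:

```python
def train(step_fn, max_epochs=100, tol=1e-4):
    """训练循环示意: step_fn 执行一次参数更新(如梯度下降)并返回
    当前目标损失; 当损失不再显著降低(降幅小于 tol)或达到
    预设训练次数 max_epochs 时停止训练, 得到泛化模型。"""
    prev = float("inf")
    for epoch in range(max_epochs):
        loss = step_fn()
        if prev - loss < tol:        # 损失不再显著降低, 视为收敛
            return epoch + 1, loss
        prev = loss
    return max_epochs, prev

# 用一个损失单调下降的假设序列演示收敛判定
losses = iter([1.0, 0.5, 0.25, 0.25])
epochs, final = train(lambda: next(losses))
```

实际训练中 step_fn 内部按目标损失(第一、第二、第三损失之和)对预设模型参数做一次梯度下降更新。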

综上所述,通过多损失函数结合对预设模型进行训练,预设模型更有可能学习到对分类有决定性作用的因果特征,而不是简单地依靠易学习但可能导致过拟合的捷径特征。以此,可以得到一个泛化性能更好,且预测的准确性更高的泛化模型。In summary, by combining multiple loss functions to train the preset model, the preset model is more likely to learn the causal features that are decisive for classification, rather than simply relying on shortcut features that are easy to learn but may lead to overfitting. In this way, a generalized model with better generalization performance and higher prediction accuracy can be obtained.

请参阅图4,在一些实施方式中,结合图4对本申请的总体流程进行介绍。示例性的,可以通过随机生成的方式,生成样本图数据,并针对样本图数据中的每个样本节点,在样本关联层级内生成样本子图。进一步的,将每个样本子图的中心样本节点作为融合中心进行特征融合,生成中心样本节点的中心节点特征向量,并根据中心节点特征向量对对应样本子图的初始特征矩阵进行更新,得到融合特征矩阵,以此来提高对中心样本节点,也即是待预测样本节点的关注度。Please refer to Figure 4. In some embodiments, the overall process of the present application is introduced in conjunction with Figure 4. Exemplarily, the sample graph data can be generated by random generation, and a sample subgraph is generated in the sample association level for each sample node in the sample graph data. Furthermore, feature fusion is performed with the central sample node of each sample subgraph as the fusion center to generate a central node feature vector of the central sample node, and the initial feature matrix of the corresponding sample subgraph is updated according to the central node feature vector to obtain a fused feature matrix, so as to increase the attention to the central sample node, that is, the sample node to be predicted.

进一步的,将样本子图中任意两个样本节点的节点特征向量经过多层感知机,生成边掩码值,再根据所有样本节点之间的边掩码值,生成边掩码矩阵,将边掩码矩阵与邻接矩阵逐元素相乘,再结合融合特征矩阵,得到核心特征。进一步的,将边掩码矩阵的补集矩阵与邻接矩阵逐元素相乘,再结合融合特征矩阵,得到冗余特征,由此,可以对边的重要性进行量化。Furthermore, the node feature vectors of any two sample nodes in the sample subgraph are passed through a multilayer perceptron to generate edge mask values, and an edge mask matrix is generated based on the edge mask values between all sample nodes. The edge mask matrix is multiplied element by element with the adjacency matrix and then combined with the fusion feature matrix to obtain the core features. Furthermore, the complement matrix of the edge mask matrix is multiplied element by element with the adjacency matrix and then combined with the fusion feature matrix to obtain the redundant features, thereby quantifying the importance of each edge.

进一步的,可以基于针对样本子图的核心特征输出的预测节点类别标签相对于核心特征的样本节点类别标签之间的差异确定第一损失、基于预设模型针对样本子图的冗余特征预测各样本节点类别标签时的概率分布均匀度确定第二损失、根据核心特征和冗余特征之间的样本距离确定第三损失,并基于第一损失、第二损失和第三损失之和,构成预设模型的目标损失。以此,能够基于目标损失更新预设模型的参数,当预设模型收敛时,得到训练好的、具备良好的泛化能力的泛化模型,之后,即可通过泛化模型实现对分布范围外的数据进行准确预测和评估。Furthermore, the first loss can be determined based on the difference between the predicted node category label output for the core feature of the sample subgraph and the sample node category label of the core feature, the second loss can be determined based on the uniformity of the probability distribution when predicting each sample node category label for the redundant features of the sample subgraph based on the preset model, the third loss can be determined based on the sample distance between the core feature and the redundant feature, and the target loss of the preset model can be constructed based on the sum of the first loss, the second loss and the third loss. In this way, the parameters of the preset model can be updated based on the target loss, and when the preset model converges, a well-trained generalization model with good generalization ability is obtained, and then the generalization model can be used to accurately predict and evaluate data outside the distribution range.

本申请实施例通过获取目标图数据,其中,目标图数据包括多个目标节点和多个目标节点组成的节点连接结构;基于目标图数据中的多个目标节点之间组成的节点连接结构,查找出每个目标节点在目标关联层级内的目标连接结构的目标子图;将目标子图输入至泛化模型,得到目标子图中的目标节点的节点类别标签;其中,泛化模型基于目标损失最小化的训练过程进行样本子图中样本节点连接结构与样本节点类别标签之间的因果预测学习得到,目标损失由第一损失、第二损失和第三损失构成,第一损失基于针对样本子图的核心特征输出的预测节点类别标签相对于核心特征的样本节点类别标签之间的差异确定;第二损失基于预设模型针对样本子图的冗余特征预测各样本节点类别标签时的概率分布均匀度确定;第三损失根据核心特征和冗余特征之间的样本距离确定,核心特征为按照注意力机制得到的边掩码矩阵对样本子图的样本节点连接结构进行增强表示得到的特征;冗余特征为按照边掩码矩阵的补集矩阵对样本子图的样本节点连接结构进行弱化表示得到的特征。以此,可以通过构建目标损失对预设模型进行训练,并在训练的过程中使得预设模型趋向于通过核心特征来进行预测、减少对冗余特征的偏好,减少核心特征和冗余特征之间的样本距离,由此,可以使得训练好的泛化模型具备更强的因果解释能力和泛化能力。并且,通过注意力机制得到的边掩码矩阵增强表示样本子图的样本连接结构,通过边掩码矩阵的补集矩阵弱化表示样本子图的样本节点连接结构,有助于训练好的泛化模型能够更好地捕捉目标节点之间的重要关系,减少冗余信息的干扰,使得预设模型更加专注于核心特征的学习。在泛化模型应用的过程中,能够基于目标图数据获取需要预测的目标节点的目标子图,使得泛化模型无需对整个目标图数据进行预测,只需要获取需要预测的目标节点的相关目标子图即可,有助于泛化模型更加专注地分析目标节点之间的因果关系,提高了预测的效率和准确性。综上,本申请能够提高训练好的泛化模型对图数据预测的准确性。In an embodiment of the present application, target graph data is obtained, wherein the target graph data includes a plurality of target nodes and a node connection structure composed of a plurality of target nodes; based on the node connection structure composed of a plurality of target nodes in the target graph data, a target subgraph of the target connection structure of each target node in the target association level is found; the target subgraph is input into a generalization model to obtain a node category label of the target node in the target subgraph; wherein the generalization model is obtained by performing causal prediction learning between the sample node connection structure and the sample node category label in the sample subgraph based on a training process of minimizing the target loss, and the target loss is composed of a first loss, a second loss and a third loss, and the first loss is determined based on the difference between the predicted node category label output for the core feature of the sample subgraph and the sample node category label of the core feature; the second loss is determined based on the uniformity of the probability distribution when predicting each sample node category label for the redundant feature of the sample subgraph according to the preset model; the third loss is determined according to the sample distance between the core feature and the redundant feature, and the core feature is a feature obtained by enhancing the sample node connection structure of the sample subgraph according to the edge mask matrix obtained by the attention mechanism; the redundant feature is a feature obtained by weakening the sample node connection structure of the sample subgraph according to the complement matrix of the edge mask matrix. In this way, the preset model can be trained by constructing a target loss, and during the training process, the preset model tends to predict through core features, reduce the preference for redundant features, and reduce the sample distance between core features and redundant features, thereby enabling the trained generalized model to have stronger causal explanation ability and generalization ability. In addition, the edge mask matrix obtained by the attention mechanism enhances the sample connection structure representing the sample subgraph, and the complement matrix of the edge mask matrix weakens the sample node connection structure representing the sample subgraph, which helps the trained generalized model to better capture the important relationship between target nodes, reduce the interference of redundant information, and make the preset model more focused on the learning of core features. In the process of applying the generalized model, the target subgraph of the target node to be predicted can be obtained based on the target graph data, so that the generalized model does not need to predict the entire target graph data, but only needs to obtain the relevant target subgraph of the target node to be predicted, which helps the generalized model to analyze the causal relationship between target nodes more attentively, and improves the efficiency and accuracy of prediction. In summary, the present application can improve the accuracy of the trained generalized model in predicting graph data.

请参阅图5,本申请实施例还提供一种基于泛化模型的图数据预测装置,可以实现上述基于泛化模型的图数据预测方法,基于泛化模型的图数据预测装置包括:Referring to FIG. 5 , an embodiment of the present application further provides a graph data prediction device based on a generalization model, which can implement the above-mentioned graph data prediction method based on a generalization model. The graph data prediction device based on a generalization model includes:

获取模块51,用于获取目标图数据,其中,目标图数据包括多个目标节点和多个目标节点组成的节点连接结构;An acquisition module 51 is used to acquire target graph data, wherein the target graph data includes a plurality of target nodes and a node connection structure composed of the plurality of target nodes;

查找模块52,用于基于目标图数据中的多个目标节点之间组成的节点连接结构,查找出每个目标节点在目标关联层级内的目标连接结构的目标子图;A search module 52, configured to search for a target subgraph of a target connection structure of each target node in a target association level based on a node connection structure formed between multiple target nodes in the target graph data;

输入模块53,用于将目标子图输入至泛化模型,得到目标子图中的目标节点的节点类别标签;其中,泛化模型基于目标损失最小化的训练过程进行样本子图中样本节点连接结构与样本节点类别标签之间的因果预测学习得到,目标损失由第一损失、第二损失和第三损失构成,第一损失基于针对样本子图的核心特征输出的预测节点类别标签相对于核心特征的样本节点类别标签之间的差异确定;第二损失基于预设模型针对样本子图的冗余特征预测各样本节点类别标签时的概率分布均匀度确定;第三损失根据核心特征和冗余特征之间的样本距离确定,核心特征为按照注意力机制得到的边掩码矩阵对样本子图的样本节点连接结构进行增强表示得到的特征;冗余特征为按照边掩码矩阵的补集矩阵对样本子图的样本节点连接结构进行弱化表示得到的特征。The input module 53 is used to input the target subgraph into the generalization model to obtain the node category label of the target node in the target subgraph; wherein the generalization model is obtained by performing causal prediction learning between the sample node connection structure and the sample node category label in the sample subgraph based on a training process of minimizing the target loss, and the target loss is composed of a first loss, a second loss and a third loss. The first loss is determined based on the difference between the predicted node category label output for the core feature of the sample subgraph and the sample node category label of the core feature; the second loss is determined based on the uniformity of the probability distribution when predicting each sample node category label for the redundant features of the sample subgraph based on the preset model; the third loss is determined according to the sample distance between the core feature and the redundant feature, the core feature is the feature obtained by enhancing the sample node connection structure of the sample subgraph according to the edge mask matrix obtained by the attention mechanism; the redundant feature is the feature obtained by weakening the sample node connection structure of the sample subgraph according to the complement matrix of the edge mask matrix.

The specific implementation of the graph data prediction apparatus based on the generalization model is substantially the same as the embodiments of the graph data prediction method described above, and is not repeated here. Provided the requirements of the embodiments of this application are met, the apparatus may also be equipped with other functional modules to implement the graph data prediction method of the above embodiments.

An embodiment of the present application further provides a computer device comprising a memory and a processor. The memory stores a computer program, and the processor implements the above graph data prediction method based on the generalization model when executing the computer program. The computer device may be any intelligent terminal, including a tablet computer, an in-vehicle computer, and the like.

Referring to FIG. 6, which illustrates the hardware structure of a computer device according to another embodiment, the computer device includes:

a processor 61, which may be implemented as a general-purpose CPU (Central Processing Unit), a microprocessor, an Application-Specific Integrated Circuit (ASIC), or one or more integrated circuits, and is configured to execute the relevant programs to implement the technical solutions provided by the embodiments of the present application;

a memory 62, which may be implemented as a Read-Only Memory (ROM), a static storage device, a dynamic storage device, or a Random Access Memory (RAM). The memory 62 may store an operating system and other application programs. When the technical solutions provided by the embodiments of this specification are implemented in software or firmware, the relevant program code is stored in the memory 62 and invoked by the processor 61 to execute the graph data prediction method based on the generalization model of the embodiments of the present application;

an input/output interface 63, configured for information input and output;

a communication interface 64, configured for communication between this device and other devices, either wired (e.g., USB or network cable) or wireless (e.g., mobile network, Wi-Fi, or Bluetooth);

a bus 66, which transfers information between the components of the device (e.g., the processor 61, the memory 62, the input/output interface 63, and the communication interface 64);

The processor 61, the memory 62, the input/output interface 63, and the communication interface 64 are communicatively connected to one another inside the device via the bus 66.

An embodiment of the present application further provides a computer-readable storage medium storing a computer program. When executed by a processor, the computer program implements the above graph data prediction method based on the generalization model.

As a non-transitory computer-readable storage medium, the memory may be used to store non-transitory software programs and non-transitory computer-executable programs. The memory may include high-speed random access memory and may also include non-transitory memory, such as at least one magnetic disk storage device, a flash memory device, or another non-transitory solid-state storage device. In some implementations, the memory may optionally include memory located remotely from the processor, and such remote memory may be connected to the processor via a network. Examples of such networks include, but are not limited to, the Internet, intranets, local area networks, mobile communication networks, and combinations thereof.

The embodiments described herein are intended to illustrate the technical solutions of the embodiments of the present application more clearly and do not limit them. Those skilled in the art will appreciate that, as technology evolves and new application scenarios emerge, the technical solutions provided by the embodiments of the present application remain applicable to similar technical problems.

Those skilled in the art will appreciate that the technical solutions shown in the figures do not limit the embodiments of the present application, which may include more or fewer steps than illustrated, combine certain steps, or use different steps.

The apparatus embodiments described above are merely illustrative. Units described as separate components may or may not be physically separate; they may be located in one place or distributed across multiple network units. Some or all of the modules may be selected according to actual needs to achieve the purpose of the solution of this embodiment.

Those of ordinary skill in the art will appreciate that all or some of the steps of the methods disclosed above, and the functional modules/units of the systems and devices, may be implemented as software, firmware, hardware, or an appropriate combination thereof.

The terms "first", "second", "third", "fourth", and the like (if any) in the specification of the present application and the accompanying drawings are used to distinguish similar objects and do not necessarily describe a particular order or sequence. It should be understood that data so used are interchangeable where appropriate, so that the embodiments of the present application described herein can be implemented in orders other than those illustrated or described herein. Furthermore, the terms "comprising" and "having", and any variants thereof, are intended to cover non-exclusive inclusion: a process, method, system, product, or device comprising a series of steps or units is not necessarily limited to those steps or units expressly listed, and may include other steps or units not expressly listed or inherent to such a process, method, product, or device.

It should be understood that in the present application, "at least one (item)" and "several" mean one or more, and "a plurality of" means two or more. "And/or" describes an association between objects and indicates that three relationships may exist; for example, "A and/or B" may mean that only A exists, only B exists, or both A and B exist, where A and B may be singular or plural. The character "/" generally indicates an "or" relationship between the objects before and after it. "At least one of the following" and similar expressions refer to any combination of the listed items, including any combination of single or plural items. For example, "at least one of a, b, or c" may mean a, b, c, "a and b", "a and c", "b and c", or "a and b and c", where a, b, and c may each be single or multiple.

In the several embodiments provided in the present application, it should be understood that the disclosed systems and methods may be implemented in other ways. For example, the system embodiments described above are merely illustrative; the division of the units described above is only a division by logical function, and other divisions are possible in actual implementation, for example, multiple units or components may be combined or integrated into another system, or some features may be omitted or not executed. Furthermore, the mutual couplings, direct couplings, or communication connections shown or discussed may be indirect couplings or communication connections through interfaces, devices, or units, and may be electrical, mechanical, or in other forms.

The units described above as separate components may or may not be physically separate, and components shown as units may or may not be physical units; they may be located in one place or distributed across multiple network units. Some or all of the units may be selected according to actual needs to achieve the purpose of the solution of this embodiment.

In addition, the functional units in the embodiments of the present application may be integrated into one processing unit, each unit may physically exist alone, or two or more units may be integrated into one unit. The integrated unit may be implemented in the form of hardware or in the form of a software functional unit.

If implemented in the form of a software functional unit and sold or used as an independent product, the integrated unit may be stored in a computer-readable storage medium. Based on this understanding, the technical solution of the present application, in essence, or the part contributing to the prior art, or all or part of the technical solution, may be embodied in the form of a software product. The computer software product is stored in a storage medium and includes a number of instructions for causing a computer device (which may be a personal computer, a server, a network device, or the like) to execute all or some of the steps of the methods of the various embodiments of the present application. The aforementioned storage medium includes media capable of storing programs, such as a USB flash drive, a removable hard disk, a Read-Only Memory (ROM), a Random Access Memory (RAM), a magnetic disk, or an optical disc.

The preferred embodiments of the present application have been described above with reference to the accompanying drawings, which does not thereby limit the scope of rights of the embodiments of the present application. Any modification, equivalent substitution, or improvement made by a person skilled in the art without departing from the scope and essence of the embodiments of the present application shall fall within the scope of rights of the embodiments of the present application.

Claims (11)

1. A graph data prediction method based on a generalization model, the method comprising:
obtaining target graph data, wherein the target graph data comprises a plurality of target nodes and a node connection structure formed by the plurality of target nodes;
Searching a target subgraph of a target connection structure of each target node in a target association level based on a node connection structure formed by a plurality of target nodes in the target graph data;
Inputting the target subgraph into a generalization model to obtain a node class label of a target node in the target subgraph;
wherein the generalization model is obtained through causal prediction learning between sample node connection structures and sample node class labels in sample subgraphs, by a training process that minimizes a target loss; the target loss is composed of a first loss, a second loss and a third loss, the first loss being determined based on the difference between a predicted node class label output for a core feature of the sample subgraph and the sample node class label of the core feature; the second loss is determined based on the uniformity of the probability distribution produced when a preset model predicts each sample node class label from the redundant feature of the sample subgraph; the third loss is determined according to the sample distance between the core feature and the redundant feature, wherein the core feature is a feature obtained by enhancing the representation of the sample node connection structure of the sample subgraph according to an edge mask matrix obtained by an attention mechanism, and the redundant feature is a feature obtained by weakening the representation of the sample node connection structure of the sample subgraph according to a complement matrix of the edge mask matrix;
The sample nodes of the sample subgraph comprise a central sample node and a plurality of adjacent sample nodes; the generalization model is obtained by training in the following way:
Acquiring sample graph data for training a preset model, and generating a corresponding sample subgraph for each sample node in the sample graph data according to a sample node connection structure of the sample graph data;
acquiring an initial feature matrix and an adjacency matrix of each sample subgraph, and sequentially inputting the initial feature matrix into the preset model for feature aggregation to obtain a fused feature matrix of each sample subgraph;
according to the adjacency matrix and the fusion feature matrix, determining core features and redundant features of each sample subgraph;
determining a target loss of the preset model based on the core feature and the redundancy feature;
And updating parameters of the preset model based on the target loss, and obtaining a trained generalization model when the preset model converges.
2. The generalized model based graph data prediction method according to claim 1, wherein the determining the core feature and the redundancy feature of each of the sample subgraphs according to the adjacency matrix and the fusion feature matrix includes:
Inputting any two sample nodes in the sample subgraph to a multi-layer perceptron based on a fusion feature matrix of the sample subgraph to obtain an edge mask value between any two sample nodes;
generating an edge mask matrix of the sample subgraph according to a plurality of edge mask values in the sample subgraph;
Generating core features corresponding to the sample subgraph based on the adjacency matrix, the fusion feature matrix and the edge mask matrix;
generating a corresponding complement matrix according to the edge mask matrix of the sample subgraph;
and obtaining the redundant features of the sample subgraph based on the adjacency matrix, the fusion feature matrix and the complement matrix.
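The edge-mask construction of claim 2 can be sketched as follows. The two-layer perceptron weights, the tanh/sigmoid activations, and the use of one aggregation step as the feature readout are illustrative assumptions rather than the exact architecture of the embodiment:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def edge_mask_matrix(H, W1, W2):
    """Edge mask value for every node pair (i, j): a small multi-layer
    perceptron applied to the concatenated fused features of the two
    endpoints, squashed to (0, 1) by a sigmoid."""
    n = H.shape[0]
    M = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            pair = np.concatenate([H[i], H[j]])
            M[i, j] = sigmoid(np.tanh(pair @ W1) @ W2)
    return M

def core_and_redundant(A, H, M):
    """Core feature: adjacency enhanced by the edge mask; redundant
    feature: adjacency weakened by the complement matrix (1 - M)."""
    A_core = A * M
    A_red = A * (1.0 - M)          # complement matrix of the edge mask
    return A_core @ H, A_red @ H   # one aggregation step as readout
```

Because the mask and its complement sum to one on every edge, the core and redundant readouts decompose a plain aggregation of the original adjacency, which is the sense in which they split the structure into causal and redundant parts.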
3. The method for predicting map data based on a generalization model according to claim 1, wherein the obtaining sample map data for training a preset model comprises:
Acquiring basic graph data and a plurality of sub-patterns of different categories, and distributing the plurality of sub-patterns to any sample node of the basic graph data to obtain sample graph data;
determining a preset number of edges added to the sample graph data according to the number of nodes of the sample graph data, and adding edges to any two sample nodes in the sample graph data, wherein the number of edges added to the sample graph data is equal to the preset number.
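The edge-addition step of claim 3 (a preset number of edges determined by the node count) can be sketched as below. The ratio rule `preset = ratio * n` and the symmetric, self-loop-free treatment are assumptions for the sketch; the motif-attachment step is omitted:

```python
import numpy as np

def add_random_edges(A, ratio=0.05, seed=0):
    """Add a preset number of random edges (here assumed to be
    ratio * node count) between sample node pairs not yet connected."""
    rng = np.random.default_rng(seed)
    A = A.copy()
    n = A.shape[0]
    budget = max(1, int(ratio * n))  # preset number from node count
    added = 0
    while added < budget:
        i, j = rng.integers(0, n, size=2)
        if i != j and A[i, j] == 0:
            A[i, j] = A[j, i] = 1    # undirected edge
            added += 1
    return A, budget
```

Such random edges perturb the sample node connection structure so that the trained model must rely on causal structure rather than spurious connectivity.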
4. The method for predicting graph data based on a generalization model according to claim 1, wherein the generating a corresponding sample subgraph for each sample node in the sample graph data according to a sample node connection structure of the sample graph data comprises:
Acquiring node density of the sample graph data, and determining a sample association level of each sample node based on the node density;
And aiming at each sample node serving as a central sample node, based on a sample node connection structure formed between the central sample node and the adjacent sample nodes, searching a sample subgraph of a target sample node connection structure of each central sample node in a sample association hierarchy.
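The per-node subgraph search of claim 4 amounts to collecting, for each central sample node, the neighbors within the sample association level. A minimal breadth-first sketch over an adjacency matrix (the density-based choice of the level k is left to the caller) could look like:

```python
import numpy as np

def k_hop_subgraph(A, center, k):
    """Collect the center node plus every node reachable within k hops,
    and return the node list with the induced sub-adjacency matrix
    (the sample subgraph of the central sample node)."""
    reached = {center}
    frontier = {center}
    for _ in range(k):
        nxt = set()
        for u in frontier:
            nxt.update(np.nonzero(A[u])[0].tolist())
        frontier = nxt - reached
        reached |= nxt
    nodes = sorted(reached)
    return nodes, A[np.ix_(nodes, nodes)]
```

A denser graph would use a smaller k, since each hop already gathers many adjacent sample nodes.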
5. The method for predicting map data based on a generalization model according to claim 1, further comprising, before obtaining the initial feature matrix and the adjacency matrix of each sample subgraph:
Acquiring a preset graph theory library;
acquiring preset characteristic indexes of the sample nodes, and generating node characteristic vectors of each sample node under the corresponding characteristic indexes through the graph theory library; the characteristic index comprises at least one of node identification, node degree, clustering coefficient, medium centrality and near centrality;
And generating a first feature matrix according to a plurality of node feature vectors corresponding to the plurality of sample nodes.
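Two of the feature indices in claim 5, node degree and clustering coefficient, can be computed directly from the adjacency matrix as sketched below. In practice a graph theory library would supply all the listed indices (including betweenness and closeness centrality); restricting to these two here is a simplification:

```python
import numpy as np

def structural_features(A):
    """Per-node feature vector [degree, clustering coefficient] from an
    undirected adjacency matrix. The other indices of the claim (node
    identification, betweenness centrality, closeness centrality) are
    omitted for brevity."""
    deg = A.sum(axis=1)
    # Triangles through each node: diagonal of A^3 counts closed walks
    # of length 3, two per triangle per node.
    tri = np.diag(A @ A @ A) / 2.0
    possible = deg * (deg - 1) / 2.0
    clust = np.divide(tri, possible, out=np.zeros_like(tri),
                      where=possible > 0)
    return np.column_stack([deg, clust])
```

Stacking these per-node vectors row by row yields the first feature matrix described in the claim.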
6. The method for predicting map data based on a generalization model according to claim 5, wherein said obtaining an initial feature matrix and an adjacency matrix of each sample subgraph comprises:
determining node identification of each sample node in each sample subgraph;
determining a node feature vector corresponding to each sample node from the first feature matrix based on the node identification;
Generating an initial feature matrix of the sample subgraph according to a plurality of node feature vectors corresponding to a plurality of sample nodes of the sample subgraph;
and generating an adjacency matrix corresponding to the feature matrix based on the connection relations between the sample nodes.
7. The method for predicting graph data based on a generalization model according to claim 5, wherein the sequentially inputting the initial feature matrix into a preset model for feature aggregation to obtain a fused feature matrix of each sample subgraph comprises:
determining a node degree matrix of each sample subgraph from the first feature matrix according to the node identification of each sample subgraph;
Acquiring an adjacency matrix of the sample subgraph, and normalizing the adjacency matrix through the node degree matrix to obtain a normalized adjacency matrix;
taking the central sample node as a fusion center, and carrying out layer-by-layer fusion of sample node characteristics based on the normalized adjacency matrix to obtain a central node characteristic vector after updating corresponding to the central sample node;
And updating the node characteristic vector of the central sample node of the initial characteristic matrix according to the central node characteristic vector to obtain a fusion characteristic matrix of each sample subgraph.
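The degree normalization and layer-by-layer fusion of claim 7 can be sketched as below. The symmetric GCN-style normalization with added self-loops is an assumption; the claim only states that the adjacency matrix is normalized through the node degree matrix:

```python
import numpy as np

def normalize_adjacency(A):
    """Normalize the adjacency with the node degree matrix:
    D^{-1/2} (A + I) D^{-1/2}. Adding self-loops (the +I) is a common
    choice assumed here, not stated in the claim."""
    A_hat = A + np.eye(A.shape[0])
    d = A_hat.sum(axis=1)
    D_inv_sqrt = np.diag(1.0 / np.sqrt(d))
    return D_inv_sqrt @ A_hat @ D_inv_sqrt

def aggregate(A_norm, X, layers=2):
    """Layer-by-layer fusion: each layer mixes every node's features with
    its neighbors' via the normalized adjacency, so after enough layers
    the central node's row has fused its whole sample subgraph."""
    H = X
    for _ in range(layers):
        H = A_norm @ H
    return H
```

After aggregation, the central node's row of the result replaces its row in the initial feature matrix, giving the fused feature matrix of the sample subgraph.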
8. The generalized model based graph data prediction method according to claim 1, wherein the determining the target loss of the preset model based on the core feature and the redundancy feature includes:
determining a first penalty based on a difference between a predicted node class label output for a core feature of the sample subgraph relative to a sample node class label of the core feature;
Determining a second loss based on a preset model for predicting probability distribution uniformity when each sample node class label is predicted for the redundant features of the sample subgraph;
Determining a third loss based on a sample distance between the core feature and the redundant feature;
And constructing a target loss of the preset model based on the sum of the first loss, the second loss and the third loss.
9. A graph data prediction apparatus based on a generalization model, the apparatus comprising:
The system comprises an acquisition module, a storage module and a storage module, wherein the acquisition module is used for acquiring target graph data, and the target graph data comprises a plurality of target nodes and a node connection structure formed by the plurality of target nodes;
the searching module is used for searching a target subgraph of a target connection structure of each target node in a target association level based on a node connection structure formed by a plurality of target nodes in the target graph data;
The input module is used for inputting the target subgraph into the generalization model to obtain a node class label of a target node in the target subgraph; the generalization model is obtained through causal prediction learning between sample node connection structures and sample node class labels in sample subgraphs, by a training process that minimizes a target loss; the target loss is composed of a first loss, a second loss and a third loss, the first loss being determined based on the difference between a predicted node class label output for a core feature of the sample subgraph and the sample node class label of the core feature; the second loss is determined based on the uniformity of the probability distribution produced when a preset model predicts each sample node class label from the redundant feature of the sample subgraph; the third loss is determined according to the sample distance between the core feature and the redundant feature, wherein the core feature is a feature obtained by enhancing the representation of the sample node connection structure of the sample subgraph according to an edge mask matrix obtained by an attention mechanism, and the redundant feature is a feature obtained by weakening that representation according to a complement matrix of the edge mask matrix; the sample nodes of the sample subgraph comprise a central sample node and a plurality of adjacent sample nodes; the generalization model is obtained by training in the following way: acquiring sample graph data for training a preset model, and generating a corresponding sample subgraph for each sample node in the sample graph data according to a sample node connection structure of the sample graph data; acquiring an initial feature matrix and an adjacency matrix of each sample subgraph, and sequentially inputting the initial feature matrix into the preset model for feature aggregation to obtain a fused feature matrix of each sample subgraph; determining core features and redundant features of each sample subgraph according to the adjacency matrix and the fused feature matrix; determining a target loss of the preset model based on the core features and the redundant features; and updating parameters of the preset model based on the target loss, and obtaining a trained generalization model when the preset model converges.
10. A computer device, characterized in that it comprises a memory storing a computer program and a processor implementing the generalization model based graph data prediction method according to any of claims 1 to 8 when said computer program is executed.
11. A computer-readable storage medium storing a computer program, wherein the computer program, when executed by a processor, implements the generalization model-based graph data prediction method of any one of claims 1 to 8.
CN202410649703.0A 2024-05-24 2024-05-24 Method, device, equipment and storage medium for predicting graph data based on generalization model Active CN118245638B (en)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202410649703.0A CN118245638B (en) 2024-05-24 2024-05-24 Method, device, equipment and storage medium for predicting graph data based on generalization model


Publications (2)

Publication Number Publication Date
CN118245638A (en) 2024-06-25
CN118245638B (en) 2024-08-27


Country Status (1)

Country Link
CN (1) CN118245638B (en)

Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2019001070A1 (en) * 2017-06-28 2019-01-03 浙江大学 Adjacency matrix-based connection information organization system, image feature extraction system, and image classification system and method
CN112231527A (en) * 2020-12-17 2021-01-15 北京百度网讯科技有限公司 Method and device for predicting label information of graph node and electronic equipment

Family Cites Families (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112214499B (en) * 2020-12-03 2021-03-19 腾讯科技(深圳)有限公司 Graph data processing method and device, computer equipment and storage medium
CN114389966B (en) * 2022-03-24 2022-06-21 合肥综合性国家科学中心人工智能研究院(安徽省人工智能实验室) Network traffic identification method and system based on graph neural network and stream space-time correlation

Patent Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2019001070A1 (en) * 2017-06-28 2019-01-03 浙江大学 Adjacency matrix-based connection information organization system, image feature extraction system, and image classification system and method
CN112231527A (en) * 2020-12-17 2021-01-15 北京百度网讯科技有限公司 Method and device for predicting label information of graph node and electronic equipment



Legal Events

Date Code Title Description
PB01 Publication
SE01 Entry into force of request for substantive examination
GR01 Patent grant