CN118245638B - Method, device, equipment and storage medium for predicting graph data based on generalization model - Google Patents
Method, device, equipment and storage medium for predicting graph data based on generalization model
- Publication number: CN118245638B (application CN202410649703.0A)
- Authority: CN (China)
- Legal status: Active (the legal status is an assumption and is not a legal conclusion; Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed)
Classifications
- G06F16/9024 — Information retrieval; indexing and data structures therefor; graphs; linked lists
- G06F16/903 — Information retrieval; querying
- G06F16/906 — Information retrieval; clustering; classification
- G06N3/042 — Neural networks; knowledge-based neural networks; logical representations of neural networks
- G06N3/08 — Neural networks; learning methods
Abstract
The embodiment of the application provides a graph data prediction method, device and equipment based on a generalization model and a storage medium, belonging to the technical field of graph data processing. The method comprises the following steps: obtaining target graph data, wherein the target graph data comprises a plurality of target nodes and corresponding node connection structures; searching a target subgraph of a target connection structure of each target node in a target association level based on a node connection structure formed by a plurality of target nodes in target graph data; inputting the target subgraph into the generalization model to obtain a node class label of a target node in the target subgraph; the generalization model carries out causal prediction learning between a sample node connection structure and a sample node class label in a sample subgraph based on a training process of minimizing target loss, wherein the target loss is formed by a first loss, a second loss and a third loss. The method and the device can improve the accuracy of graph data prediction.
Description
Technical Field
The present application relates to the field of graph data processing technologies, and in particular, to a method, an apparatus, a device, and a storage medium for predicting graph data based on a generalization model.
Background
A graph neural network (Graph Neural Networks, GNNs) is a deep learning model for processing graph-structured data. In a graph neural network, node features are updated and aggregated according to the relationships between nodes (i.e., the presence of edges), so as to build up the relationships between nodes step by step and learn a structural representation of the graph.
In the related art, during training of a graph neural network, key features are extracted from the input graph data for prediction under the premise of maximizing the mutual information between the graph data and the real labels. However, this training mode drives the graph neural network to learn as many of the statistical correlations between input features and labels as possible, without distinguishing causal from non-causal effects between them, so the network tends to use non-causal features as shortcuts for graph data prediction, which reduces the accuracy of graph data prediction after training.
Disclosure of Invention
The embodiment of the application mainly aims to provide a graph data prediction method, device and equipment based on a generalization model and a storage medium, which can improve the accuracy of graph data prediction.
To achieve the above object, a first aspect of an embodiment of the present application provides a graph data prediction method based on a generalization model, where the method includes:
obtaining target graph data, wherein the target graph data comprises a plurality of target nodes and a node connection structure formed by the plurality of target nodes;
Searching a target subgraph of a target connection structure of each target node in a target association level based on a node connection structure formed by a plurality of target nodes in the target graph data;
Inputting the target subgraph into a generalization model to obtain a node class label of a target node in the target subgraph;
The generalization model performs causal prediction learning between the sample node connection structure and the sample node class label in a sample subgraph through a training process that minimizes a target loss, the target loss being formed by a first loss, a second loss and a third loss, wherein the first loss is determined according to a first prediction probability of the core feature of the sample subgraph under the corresponding sample node class label; the second loss is determined according to a second prediction probability of the redundant feature of the sample subgraph under the corresponding sample node class label; and the third loss is determined according to the sample distance between the core feature and the redundant feature. The core feature is a feature obtained by enhancing the representation of the sample node connection structure of the sample subgraph according to an edge mask matrix derived from an attention mechanism; the redundant feature is a feature obtained by weakening that representation according to the complement matrix of the edge mask matrix.
Accordingly, a second aspect of an embodiment of the present application proposes a graph data prediction apparatus based on a generalization model, the apparatus comprising:
The system comprises an acquisition module, a storage module and a storage module, wherein the acquisition module is used for acquiring target graph data, and the target graph data comprises a plurality of target nodes and a node connection structure formed by the plurality of target nodes;
the searching module is used for searching a target subgraph of a target connection structure of each target node in a target association level based on a node connection structure formed by a plurality of target nodes in the target graph data;
The input module is used for inputting the target subgraph into the generalization model to obtain the node class label of the target node in the target subgraph. The generalization model performs causal prediction learning between the sample node connection structure and the sample node class label in a sample subgraph through a training process that minimizes a target loss, the target loss being formed by a first loss, a second loss and a third loss, wherein the first loss is determined according to a first prediction probability of the core feature of the sample subgraph under the corresponding sample node class label; the second loss is determined according to a second prediction probability of the redundant feature of the sample subgraph under the corresponding sample node class label; and the third loss is determined according to the sample distance between the core feature and the redundant feature. The core feature is a feature obtained by enhancing the representation of the sample node connection structure of the sample subgraph according to an edge mask matrix derived from an attention mechanism; the redundant feature is a feature obtained by weakening that representation according to the complement matrix of the edge mask matrix.
In some embodiments, the sample nodes of the sample subgraph include a central sample node and a plurality of adjacent sample nodes; the graph data prediction device based on the generalization model further comprises a training module for:
Acquiring sample graph data for training a preset model, and generating a corresponding sample subgraph for each sample node in the sample graph data according to a sample node connection structure of the sample graph data;
acquiring an initial feature matrix and an adjacency matrix of each sample subgraph, and sequentially inputting the initial feature matrix into a preset model to perform feature aggregation to obtain a fusion feature matrix of each sample subgraph;
according to the adjacency matrix and the fusion feature matrix, determining core features and redundant features of each sample subgraph;
determining a target loss of the preset model based on the core feature and the redundancy feature;
And updating parameters of the preset model based on the target loss, and obtaining a trained generalization model when the preset model converges.
In some embodiments, the training module is further to:
Inputting any two sample nodes in the sample subgraph to a multi-layer perceptron based on a fusion feature matrix of the sample subgraph to obtain an edge mask value between any two sample nodes;
generating an edge mask matrix of the sample subgraph according to a plurality of edge mask values in the sample subgraph;
Generating core features corresponding to the sample subgraph based on the adjacency matrix, the fusion feature matrix and the edge mask matrix;
generating a corresponding complement matrix according to the edge mask matrix of the sample subgraph;
and obtaining the redundant features of the sample subgraph based on the adjacency matrix, the fusion feature matrix and the complement matrix.
In some embodiments, the training module is further to:
Acquiring basic graph data and a plurality of sub-patterns of different categories, and distributing the plurality of sub-patterns to any sample node of the basic graph data to obtain sample graph data;
determining a preset number of edges added to the sample graph data according to the number of nodes of the sample graph data, and adding edges to any two sample nodes in the sample graph data, wherein the number of edges added to the sample graph data is equal to the preset number.
In some embodiments, the training module is further to:
Acquiring node density of the sample graph data, and determining a sample association level of each sample node based on the node density;
And aiming at each sample node serving as a central sample node, based on a sample node connection structure formed between the central sample node and the adjacent sample nodes, searching a sample subgraph of a target sample node connection structure of each central sample node in a sample association hierarchy.
In some embodiments, the graph data prediction apparatus based on the generalization model further includes a generating module for:
Acquiring a preset graph theory library;
acquiring preset feature indices of the sample nodes, and generating a node feature vector for each sample node under the corresponding feature indices through the graph theory library; the feature indices comprise at least one of node identification, node degree, clustering coefficient, betweenness centrality and closeness centrality;
And generating a first feature matrix according to a plurality of node feature vectors corresponding to the plurality of sample nodes.
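By way of illustration only, the following minimal Python sketch shows how such a first feature matrix might be assembled; networkx is assumed as the graph theory library (the embodiment does not name one), and the column order is an assumption:

```python
# A hypothetical sketch: networkx stands in for the "preset graph theory
# library"; columns follow the feature indices listed above (node
# identification, node degree, clustering coefficient, betweenness
# centrality, closeness centrality).
import networkx as nx
import numpy as np

def build_first_feature_matrix(G: nx.Graph) -> np.ndarray:
    deg = dict(G.degree())                 # node degree
    clus = nx.clustering(G)                # clustering coefficient
    btw = nx.betweenness_centrality(G)     # betweenness centrality
    close = nx.closeness_centrality(G)     # closeness centrality
    # One row per sample node: [node id, degree, clustering, betweenness, closeness]
    return np.array([[n, deg[n], clus[n], btw[n], close[n]] for n in G.nodes()])
```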
In some embodiments, the training module is further to:
determining node identification of each sample node in each sample subgraph;
determining a node feature vector corresponding to each sample node from the first feature matrix based on the node identification;
Generating an initial feature matrix of the sample subgraph according to a plurality of node feature vectors corresponding to a plurality of sample nodes of the sample subgraph;
and generating an adjacency matrix corresponding to the initial feature matrix based on the connection relations between the sample nodes.
In some embodiments, the training module is further to:
determining a node degree matrix of each sample subgraph from the first feature matrix according to the node identification of each sample subgraph;
Acquiring the adjacency matrix of the sample subgraph, and normalizing the adjacency matrix through the node degree matrix to obtain a normalized adjacency matrix;
taking the central sample node as a fusion center, and carrying out layer-by-layer fusion of sample node characteristics based on the normalized adjacency matrix to obtain a central node characteristic vector after updating corresponding to the central sample node;
And updating the node characteristic vector of the central sample node of the initial characteristic matrix according to the central node characteristic vector to obtain a fusion characteristic matrix of each sample subgraph.
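For illustration, a minimal sketch of this aggregation step follows; the embodiment only states that the adjacency matrix is normalized through the node degree matrix, so the symmetric normalization $\hat{A}=D^{-1/2}(A+I)D^{-1/2}$ with added self-loops is an assumed, common choice rather than the claimed formula:

```python
# A sketch of layer-by-layer feature fusion over a normalized adjacency
# matrix; the self-loops and symmetric normalization are assumptions.
import torch

def fuse_features(A: torch.Tensor, X: torch.Tensor, num_layers: int = 2) -> torch.Tensor:
    A_hat = A + torch.eye(A.size(0))           # assumed self-loops
    d = A_hat.sum(dim=1)                       # diagonal of the node degree matrix
    D_inv_sqrt = torch.diag(d.pow(-0.5))
    A_norm = D_inv_sqrt @ A_hat @ D_inv_sqrt   # normalized adjacency matrix
    H = X
    for _ in range(num_layers):                # fuse neighbor features layer by layer
        H = A_norm @ H
    return H  # the central sample node's row is its updated feature vector
```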
In some embodiments, the training module is further to:
determining a first penalty based on a difference between a predicted node class label output for a core feature of the sample subgraph relative to a sample node class label of the core feature;
Determining a second loss based on a preset model for predicting probability distribution uniformity when each sample node class label is predicted for the redundant features of the sample subgraph;
Determining a third loss based on a sample distance between the core feature and the redundant feature;
And constructing a target loss of the preset model based on the sum of the first loss, the second loss and the third loss.
Accordingly, a third aspect of the embodiments of the present application proposes a computer device, where the computer device includes a memory and a processor, where the memory stores a computer program, and where the processor implements the method for predicting graph data based on a generalization model according to any one of the embodiments of the first aspect of the present application when the processor executes the computer program.
Accordingly, a fourth aspect of the embodiments of the present application proposes a computer readable storage medium storing a computer program, which when executed by a processor implements a method for predicting graph data based on a generalization model according to any one of the embodiments of the first aspect of the present application.
The embodiment of the application obtains target graph data, wherein the target graph data comprises a plurality of target nodes and a node connection structure formed by the plurality of target nodes; searches out, based on the node connection structure formed by the plurality of target nodes in the target graph data, a target subgraph of the target connection structure of each target node within a target association hierarchy; and inputs the target subgraph into the generalization model to obtain the node class label of the target node in the target subgraph. The generalization model performs causal prediction learning between the sample node connection structure and the sample node class label in a sample subgraph through a training process that minimizes a target loss, the target loss being formed by a first loss, a second loss and a third loss: the first loss is determined based on the difference between the predicted node class label output for the core feature of the sample subgraph and the sample node class label of the core feature; the second loss is determined based on the uniformity of the probability distribution output by the preset model when predicting each sample node class label for the redundant feature of the sample subgraph; and the third loss is determined according to the sample distance between the core feature and the redundant feature. The core feature is obtained by enhancing the representation of the sample node connection structure of the sample subgraph according to an edge mask matrix derived from an attention mechanism; the redundant feature is obtained by weakening that representation according to the complement matrix of the edge mask matrix. Training the preset model with this target loss therefore makes the model tend to predict through core features during training, reduces its preference for redundant features, and enlarges the sample distance between core features and redundant features, so that the trained generalization model has stronger causal interpretation capability and generalization capability. Moreover, because the edge mask matrix obtained through the attention mechanism enhances the representation of the sample connection structure of the sample subgraph while the complement matrix of the edge mask matrix weakens it, the trained generalization model can better capture the important relationships between target nodes, the interference of redundant information is reduced, and the preset model focuses more on learning core features. In the process of applying the generalization model, the target subgraph of the target node to be predicted can be obtained from the target graph data, so the generalization model does not need to predict over the whole target graph data; it only needs the relevant target subgraph of the target node to be predicted, which helps the generalization model analyze the causal relationships between target nodes more intensively and improves prediction efficiency and accuracy. In summary, the method and the device can improve the accuracy with which the trained generalization model predicts graph data.
Drawings
FIG. 1 is a schematic diagram of a graph data prediction system based on a generalization model according to an embodiment of the present application;
FIG. 2 is a flowchart of a graph data prediction method based on a generalization model according to an embodiment of the present application;
FIG. 3 is a flowchart of the steps for training a preset model provided by an embodiment of the present application;
FIG. 4 is a general flow chart of a generalized model-based graph data prediction method provided by an embodiment of the present application;
FIG. 5 is a schematic diagram of a functional module of a graph data prediction device based on a generalization model according to an embodiment of the present application;
Fig. 6 is a schematic hardware structure of a computer device according to an embodiment of the present application.
Detailed Description
The present application will be described in further detail with reference to the drawings and examples, in order to make the objects, technical solutions and advantages of the present application more apparent. It should be understood that the specific embodiments described herein are for purposes of illustration only and are not intended to limit the scope of the application.
It should be noted that although functional block division is performed in a device diagram and a logic sequence is shown in a flowchart, in some cases, the steps shown or described may be performed in a different order than the block division in the device, or in the flowchart. The terms first, second and the like in the description and in the claims and in the above-described figures, are used for distinguishing between similar elements and not necessarily for describing a particular sequential or chronological order.
Unless defined otherwise, all technical and scientific terms used herein have the same meaning as commonly understood by one of ordinary skill in the art to which this application belongs. The terminology used herein is for the purpose of describing embodiments of the application only and is not intended to be limiting of the application.
A graph neural network (Graph Neural Networks, GNNs) is a deep learning model for processing graph-structured data. In a graph neural network, node features are updated and aggregated according to the relationships between nodes (i.e., the presence of edges), so as to build up the relationships between nodes step by step, learn a structural representation of the graph, and realize the prediction of target nodes.
In the related art, during training of a graph neural network, key features are extracted from the input graph data for prediction under the premise of maximizing the mutual information between the graph data and the real labels. However, this training mode drives the graph neural network to learn as many of the statistical correlations between input features and labels as possible, without distinguishing causal from non-causal effects between them, so the network tends to use non-causal features as shortcuts for graph data prediction, which reduces the accuracy of graph data prediction after training.
Based on the above, the embodiment of the application provides a graph data prediction method, device and equipment based on a generalization model and a storage medium, which can improve the accuracy of graph data prediction.
The method, the device, the equipment and the storage medium for predicting the graph data based on the generalization model provided by the embodiment of the application are specifically described by the following embodiment, and the graph data prediction system based on the generalization model in the embodiment of the application is described first.
In some embodiments, referring to fig. 1, a graph data prediction system based on a generalization model includes a terminal 11 and a server 12.
Specifically, the terminal 11 may be a mobile terminal device or a non-mobile terminal device. The mobile terminal equipment can be a mobile phone, a tablet personal computer, a notebook computer, a palm computer, a wearable device, an ultra mobile personal computer, a netbook, a personal digital assistant, a wireless hot spot device and the like; the non-mobile terminal device may be a personal computer or the like, and embodiments of the present application are not particularly limited. The technician can directly interact with the server 12 through the terminal 11, send instructions for training the preset model to the server 12, set the training times of the preset model, and the like, and display results after the training to obtain the generalization model.
Further, the server 12 may include a cloud computing server, a data center server, and the like, and is responsible for storing related data of the generalization model, and training a preset model after receiving a model training instruction of the terminal 11, to obtain the generalization model. The server 12 has high computing, storage and network performance and is capable of supporting requests from a plurality of terminals 11.
The graph data prediction method based on the generalization model in the embodiment of the application can be illustrated by the following embodiment.
In the embodiments of the present application, when related processing is required according to user information, user behavior data, user history data, user location information, and other data related to user identity or characteristics, permission or consent of the user is obtained first. Moreover, the collection, use, processing, etc. of such data would comply with relevant laws and regulations. In addition, when the embodiment of the application needs to acquire the sensitive personal information of the user, the independent permission or independent consent of the user is acquired through popup or jump to a confirmation page and the like, and after the independent permission or independent consent of the user is definitely acquired, the necessary relevant data of the user for enabling the embodiment of the application to normally operate is acquired.
In the embodiment of the present application, description will be made from the dimension of the generalization model-based graph data prediction apparatus, which may be integrated in a computer device in particular. Referring to fig. 2, fig. 2 is a flowchart illustrating steps of a method for predicting map data based on a generalization model according to an embodiment of the present application, where, in an example in which a device for predicting map data based on a generalization model is specifically integrated on a terminal or a server, a specific flow is as follows when a processor on the terminal or the server executes a program instruction corresponding to the method for predicting map data based on a generalization model:
Step 101, obtaining target graph data, wherein the target graph data comprises a plurality of target nodes and a node connection structure formed by the plurality of target nodes.
In order to effectively classify and predict the target graph data, the target graph data may be obtained, so that the subsequent generalization model can understand the node connection structure formed by the target nodes of the target graph data, thereby classifying more accurately.
The target graph data may be graph structure data composed of a plurality of target nodes and corresponding edges, such as a social network, a user-item relationship graph in a recommendation system, a protein interaction network in bioinformatics, and the like.
Wherein the target nodes may be entities in the target graph data, such as users in a social network or items in a recommendation system, etc., each having its specific attribute or category labels.
The node connection structure may be a relationship or connection between target nodes, for example, in a social network, when a friend relationship exists between two users, an edge exists between the two users, and the two users and the edge form the node connection structure. When a user has purchased an item, an edge may be used to connect the user node to the item node, indicating the user's purchase of the item, the user, edge, and item may form a node connection structure, and so on.
For example, the target graph data to be identified may be obtained first; the target graph data may be preset or acquired in real time. For example, the data source of the target graph data may be determined first: when the target graph data needs to be acquired from a social network, web crawler technology may be used to collect user data and the relationships between users on a social media platform; when bioinformatics-related target graph data is needed, network data of protein interactions and the like may be acquired from a public bioinformatics data platform. Further, after the raw data is collected, it may be preprocessed, including data cleansing, data formatting, and the like.
By acquiring the target graph data in the above manner, the target graph data can be conveniently classified by the generalization model later.
Step 102, based on a node connection structure formed by a plurality of target nodes in the target graph data, searching out a target subgraph of a target connection structure of each target node in a target association level.
It can be understood that when classifying the target nodes to be predicted, if the whole target graph data is input into the generalization model, a great amount of computing resources are consumed, and the information related to the classification task of the target nodes to be predicted can be more accurately captured by searching the target subgraph of the target connection structure of each target node in the target association hierarchy, so that the classification accuracy is improved and the computing resources are saved.
The target association hierarchy may be a hierarchy determined by performing K-hop sampling in the target graph data through breadth-first search, where the K-hop sampling may help the generalization model determine local area information of the corresponding target node in the target graph data. For example, when the target node to be predicted needs to be predicted, the number of steps of the target node from the center can be determined through the density of the target graph data, and the determined number of steps is used as a target association level. Further, when the distribution of the target nodes in the target graph data is sparse, a larger target association level can be determined, and when the distribution of the target nodes in the target graph data is dense, a smaller target association level can be determined. Illustratively, the target association hierarchy may be level 2, level 3, etc., as embodiments of the application are not particularly limited.
The target connection structure is a node connection structure formed by a target node to be predicted and other connected target nodes in a target association level.
The target subgraph may be a local graph structure obtained from a target node to be predicted by a breadth-first search algorithm in target graph data. Each target subgraph includes a target node to be predicted as a central node, and all other target nodes connected to the central node within a target association hierarchy.
Specifically, the target node to be predicted can be used as a central starting point, other target nodes related to the target node to be predicted are searched in the target association hierarchy through breadth-first search, and a target subgraph corresponding to the target node to be predicted is generated through a node connection structure of the target node to be predicted and the other target nodes.
For example, when the target graph data is a social network graph, a new user node needs to be classified, and assuming that the target node to be predicted is a user node a, starting from the node a, other user nodes directly connected to the node a are determined in the target graph data using a breadth-first search algorithm. If the target association level is determined to be 3 according to the node density of the target graph data, the node connection structure of the social network graph can be obtained, other user nodes directly connected with the node a are searched in the 3 target association levels taking the user node a as a starting point, and if the user node B, C, D directly connected with the user node a exists, the node connection structure formed by the user node a, the user node B, the user node C and the user node D can be used as a target connection structure, and the target connection structure can be used as a target subgraph.
By the method, the information related to the classification task of the target node can be more accurately captured by searching the target subgraph corresponding to the target node to be predicted, and the method is also beneficial to reducing the calculated amount of the subsequent generalization model in the prediction of the target subgraph.
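As a minimal illustration (a sketch under assumed parameters, not the claimed method itself), the breadth-first extraction of a target subgraph within a k-hop target association hierarchy could look as follows in Python with networkx; the density threshold is a made-up value:

```python
import networkx as nx

def association_level(G: nx.Graph) -> int:
    # Hypothetical heuristic: a sparser graph gets a larger association level.
    return 3 if nx.density(G) < 0.01 else 2

def target_subgraph(G: nx.Graph, v) -> nx.Graph:
    k = association_level(G)
    # ego_graph performs the breadth-first expansion described above:
    # v plus every node within k hops, with the induced connection structure.
    return nx.ego_graph(G, v, radius=k)
```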
Step 103, inputting the target subgraph into the generalization model to obtain node class labels of target nodes in the target subgraph;
The generalization model carries out causal prediction learning between a sample node connection structure and a sample node class label in a sample sub-graph based on a training process of minimizing target loss, wherein the target loss is formed by first loss, second loss and third loss, and the first loss is determined based on the difference between a predicted node class label output for core characteristics of the sample sub-graph and a sample node class label of the core characteristics; the second loss is determined based on probability distribution uniformity when the redundancy features of the sample subgraphs are used for predicting the class labels of the nodes of each sample according to a preset model; the third loss is determined according to the sample distance between the core feature and the redundant feature, wherein the core feature is a feature obtained by enhancing and representing the sample node connection structure of the sample subgraph according to an edge mask matrix obtained by an attention mechanism; the redundant features are features obtained by weakening and representing the sample node connection structure of the sample subgraph according to the complement matrix of the edge mask matrix.
In some embodiments, in order to improve accuracy of target node prediction in the target graph data, the preset model may be trained in advance, so that the preset model continuously focuses on real causal relationships between causal feature learning sample nodes and real labels, and focus on shortcut features (i.e., redundant features) is reduced, so that the trained generalization model has higher generalization performance.
The generalization model may be a model with good generalization capability for predicting the node class labels of target nodes in the target subgraph, or a model for predicting the class of the whole target graph data. The generalization model is obtained by training a preset model in advance, and can maintain good generalization performance when facing new tasks and data. By way of example, the generalization model may be a graph neural network (Graph Neural Networks, GNNs), a graph attention network (Graph Attention Networks, GATs), a graph convolutional network (Graph Convolutional Networks, GCNs), or the like, suitable for processing target graph data and making predictions for target nodes.
The node class label may be a specific class to which each target node belongs. For example, in a social network, a target node may represent a user, and a node category label may represent an interest or group of the user. In some embodiments, when the entire target graph data is to be predicted, the node class label represents the class of the entire target graph data, e.g., in bioinformatics, the target graph data may represent an interaction network of proteins, while the node class label may represent a functional type of this network, and so on.
The sample subgraph can be a local sample graph structure obtained by a breadth-first search algorithm from sample nodes to be predicted in sample graph data. Each sample subgraph includes a sample node to be predicted as a sample center node, and all other sample nodes connected to the sample center node within a sample association hierarchy.
The sample node connection structure may be a relationship or connection between sample nodes, for example, in a social network, when a friend relationship exists between two users, an edge exists between the two users, and the two users and the friend relationship form the node connection structure. When a user has purchased an item, an edge may be used to connect the user node to the item node, indicating the user's purchase of the item, the user, edge, and item may form a node connection structure, and so on.
The target loss may be a loss composed of a first loss, a second loss and a third loss, and is used for indicating adjustment of parameters of the preset model so as to train the preset model and obtain the generalization model. Specifically, the first loss may be a cross entropy loss, the second loss may be an entropy loss, the third loss may be a contrast loss, and the first loss, the second loss and the third loss form a target loss, and model parameters are updated by using the target loss, so that the generalization model can learn core features more intensively, and the generalization model is prevented from paying excessive attention to redundant features, thereby improving the quality of overall feature characterization.
The first loss may be a cross-entropy loss, used to evaluate the difference between the probability distribution over predicted node class labels that the preset model outputs for the core feature and the sample node class label of the core feature. Specifically, a code may be set for each sample node class label, with each unique sample node class label represented by a binary vector in which only the position of the real sample node class label corresponding to the core feature is set to 1 and all other positions are set to 0. Specifically, the first loss $\mathcal{L}_1$ is:
$\mathcal{L}_1 = -\sum_{v} y_v \log p_v$
where $y_v$ is the one-hot encoding of the $v$-th sample node class label and $p_v$ is the first prediction probability of the model for the $v$-th sample node class label. When the predicted node class label is consistent with the real sample node class label corresponding to the core feature, the first loss is minimal, so minimizing the first loss enables the preset model to better learn the causal relationships in the data and to keep making more accurate predictions.
The predicted node class labels are a plurality of node class labels obtained by predicting core features by a preset model. For example, for an animal image, the predicted node class labels may be cat, dog, rabbit, each of which corresponds to a predicted probability.
The second loss may be an entropy loss, which measures the uniformity of the probability distribution output by the preset model when predicting each sample node class label for the redundant feature of the sample subgraph; the more uniform the probability distribution, the smaller the classification preference of the preset model for the redundant feature, which effectively prevents the preset model from learning shortcut features. In particular, the second loss $\mathcal{L}_2$ is:
$\mathcal{L}_2 = -\sum_{v} u \log q_v$
where $u$ represents the uniform probability over the sample node class labels, in general $u = 1/K$ for $K$ categories, and $q_v$ is the second prediction probability of the preset model for the $v$-th sample node class label.
Specifically, when the redundant feature is input into the preset model and the output second prediction probabilities are close to a uniform distribution, this indicates that the redundant feature does not help the preset model classify. For example, if there are 4 sample class labels, namely cat, rabbit, dog and pig, then $u = 1/4$; if the redundant feature is input into the preset model and the prediction probabilities of cat, rabbit, dog and pig are each 1/4, the second loss of the preset model is minimal, the preset model has no classification preference on the redundant feature, and the uniformity of the output probability distribution is high. Therefore, the preset model needs to be trained continuously so that its prediction probabilities for the redundant feature over all sample node class labels approach a uniform distribution.
The probability distribution uniformity may describe how uniform the second prediction probabilities are that the preset model outputs for each sample node class label after the redundant feature is input. The more uniform the second prediction probabilities, the higher the uniformity; for example, the uniformity of the probabilities 1/3, 1/3, 1/3 is greater than that of 2/3, 1/6, 1/6.
The third loss may be a contrastive loss. Specifically, the core feature may be used as a positive sample and the redundant feature as a negative sample; the contrastive loss drives the preset model to correctly distinguish positive from negative samples, so that positive samples are drawn closer together while positive and negative samples are pushed further apart. Further, the third loss $\mathcal{L}_3$ is:
$\mathcal{L}_3 = -\mathbb{E}\left[\log \sigma\left(s(C,\tilde{C})/Q\right)\right] - \mathbb{E}\left[\log\left(1-\sigma\left(s(C,\hat{C})/Q\right)\right)\right]$
where $\sigma$ is the sigmoid function, typically used to map its input into $(0,1)$; $Q$ is a hyperparameter; $C$ and $\tilde{C}$ are different instances of the core feature and form a positive sample pair; $C$ and $\hat{C}$ form a positive-negative pair, with $\hat{C}$ representing a negative sample, i.e., a redundant feature; $s(\cdot,\cdot)$ denotes a similarity between two features; and $\mathbb{E}$ denotes the expected value, which may be obtained by averaging the model's scores over all positive pairs and all negative pairs. Training the preset model with the third loss encourages it to correctly distinguish positive and negative samples, so that the distance between positive samples becomes closer and the distance between a positive and a negative sample becomes further.
The sample distance may be a distance between two core features or a distance between a core feature and a redundant feature, and may be calculated by means of Euclidean distance or cosine similarity.
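A compact sketch of the three loss terms and their sum is given below; the use of cosine similarity, the reading of $Q$ as a temperature-style divisor, and the tensor shapes are illustrative assumptions, since the text above fixes only the roles of the terms:

```python
# A sketch of the target loss; z_core/z_pos form a positive pair of core
# features, z_red is the redundant (negative) feature. All names are
# hypothetical.
import torch
import torch.nn.functional as F

def target_loss(core_logits, labels, red_logits, z_core, z_pos, z_red, Q=0.5):
    # First loss: cross entropy of core-feature predictions vs. sample labels.
    l1 = F.cross_entropy(core_logits, labels)
    # Second loss: cross entropy against a uniform target, pushing the
    # redundant-feature predictions toward the uniform distribution.
    k = red_logits.size(-1)
    uniform = torch.full_like(red_logits, 1.0 / k)
    l2 = -(uniform * F.log_softmax(red_logits, dim=-1)).sum(dim=-1).mean()
    # Third loss: contrastive term pulling positive pairs together and
    # pushing positive-negative pairs apart.
    s_pos = F.cosine_similarity(z_core, z_pos, dim=-1) / Q
    s_neg = F.cosine_similarity(z_core, z_red, dim=-1) / Q
    l3 = -(torch.log(torch.sigmoid(s_pos))
           + torch.log(1.0 - torch.sigmoid(s_neg))).mean()
    return l1 + l2 + l3  # target loss = first + second + third loss
```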
The core features may be the features in the sample subgraph that have a causal relation or causal interpretation capability; they carry the causal information linking the sample nodes to be predicted with the sample node class labels.
The edge mask matrix may be a matrix formed by the edge mask values obtained by feeding the node feature vectors of any two sample nodes in the sample subgraph into a multi-layer perceptron, collected over all pairs of sample nodes in the sample subgraph.
The redundancy feature may be a shortcut feature between the core feature and the prediction result; it makes no additional contribution to predicting the sample node and the sample node class label, and has no causal interpretation characteristic. Moreover, redundant features can cause the preset model to tend to learn shortcut features when making decisions, which lowers the accuracy of the preset model in predicting sample graph data and degrades the performance of the trained generalization model on out-of-distribution test data, i.e., its generalization performance. Illustratively, the redundancy feature may be a noise feature.
The complement matrix is defined as 1 minus the edge mask matrix, and may be used to represent the redundant edges in the sample subgraph.
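For illustration, the edge mask matrix and its complement might be computed as sketched below; the two-layer perceptron and the restriction of the mask to existing edges are assumptions:

```python
# A sketch of deriving the edge mask matrix M from the fusion feature
# matrix H via a small MLP, then using M and its complement to enhance or
# weaken the sample node connection structure. The MLP shape is hypothetical.
import torch
import torch.nn as nn

class EdgeMask(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(2 * dim, dim), nn.ReLU(),
                                 nn.Linear(dim, 1))

    def forward(self, H: torch.Tensor, A: torch.Tensor):
        n = H.size(0)
        # Concatenate the feature vectors of every pair of sample nodes.
        pairs = torch.cat([H.unsqueeze(1).expand(n, n, -1),
                           H.unsqueeze(0).expand(n, n, -1)], dim=-1)
        M = torch.sigmoid(self.mlp(pairs).squeeze(-1)) * A  # mask existing edges only
        A_core = M       # enhanced structure -> core features
        A_red = A - M    # complement (1 - mask) on existing edges -> redundant features
        return A_core, A_red
```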
It can be understood that the target subgraph is input into the trained generalization model, the generalization model can rapidly extract core features in the target subgraph, which are favorable for classifying target nodes to be predicted, and predict the target nodes to be predicted by combining the core features, so that the efficiency and the accuracy of predicting the target nodes are improved.
For example, when the interest class of the user a in the social network diagram needs to be predicted, the target connection structure corresponding to the user a may be determined through the node connection structure of the target diagram data as the target sub-graph of the user a, for example, the target sub-graph may be a target sub-graph including the user a and friends thereof, and then the target sub-graph is input into the generalization model, so as to obtain the specific interest class of the user a. In this process, the generalization model can extract useful information from the social network structure of user a and make accurate classification decisions accordingly.
The embodiment of the application obtains target graph data, wherein the target graph data comprises a plurality of target nodes and a node connection structure formed by the plurality of target nodes; searches out, based on the node connection structure formed by the plurality of target nodes in the target graph data, a target subgraph of the target connection structure of each target node within a target association hierarchy; and inputs the target subgraph into the generalization model to obtain the node class label of the target node in the target subgraph. The generalization model performs causal prediction learning between the sample node connection structure and the sample node class label in a sample subgraph through a training process that minimizes a target loss, the target loss being formed by a first loss, a second loss and a third loss: the first loss is determined based on the difference between the predicted node class label output for the core feature of the sample subgraph and the sample node class label of the core feature; the second loss is determined based on the uniformity of the probability distribution output by the preset model when predicting each sample node class label for the redundant feature of the sample subgraph; and the third loss is determined according to the sample distance between the core feature and the redundant feature. The core feature is obtained by enhancing the representation of the sample node connection structure of the sample subgraph according to an edge mask matrix derived from an attention mechanism; the redundant feature is obtained by weakening that representation according to the complement matrix of the edge mask matrix. Training the preset model with this target loss therefore makes the model tend to predict through core features during training, reduces its preference for redundant features, and enlarges the sample distance between core features and redundant features, so that the trained generalization model has stronger causal interpretation capability and generalization capability. Moreover, because the edge mask matrix obtained through the attention mechanism enhances the representation of the sample connection structure of the sample subgraph while the complement matrix of the edge mask matrix weakens it, the trained generalization model can better capture the important relationships between target nodes, the interference of redundant information is reduced, and the preset model focuses more on learning core features. In the process of applying the generalization model, the target subgraph of the target node to be predicted can be obtained from the target graph data, so the generalization model does not need to predict over the whole target graph data; it only needs the relevant target subgraph of the target node to be predicted, which helps the generalization model analyze the causal relationships between target nodes more intensively and improves prediction efficiency and accuracy. In summary, the method and the device can improve the accuracy with which the trained generalization model predicts graph data.
Referring to fig. 3, in some embodiments, sample nodes of a sample sub-graph may include a center sample node and neighboring sample nodes adjacent to the center sample node. In order to improve accuracy of the generalization model in predicting the target graph data, the preset model may be trained to enable the generalization model to learn causal features (i.e., core features) and avoid learning shortcut features (redundant features), and specifically, the generalization model is obtained by training in steps 201 to 205:
Step 201, obtaining sample graph data for training a preset model, and generating a corresponding sample subgraph for each sample node in the sample graph data according to a sample node connection structure of the sample graph data.
The sample graph data can be a synthetic graph generated in a random mode, different types of sample nodes can be obtained by randomly adding graphs with different shapes into each node of the base graph, and edges are added to any two sample nodes, so that a sample node connection structure formed by the sample nodes and the edges can be used as sample graph data. In some embodiments, existing graph data may also be obtained as sample graph data.
The sample subgraph may be a local sample graph structure obtained by a breadth-first search algorithm starting from the sample node to be predicted in the sample graph data, which gives the preset model a better capability of processing local features; each computing node can train the preset model while independently processing its own sample subgraphs, and by synchronously updating and sharing the parameters of the preset model, the training efficiency is improved and the performance of the preset model is optimized. Each sample subgraph includes the sample node to be predicted as the sample center node, together with all other sample nodes connected to the sample center node within the sample association hierarchy.
For example, assume there is a social network graph in which each sample node represents a user and each edge represents a relationship (e.g., a friendship) between two users. When a sample subgraph of user A needs to be generated from the social network graph, the other users connected to user A can be found through breadth-first search; if user A is directly connected to users B and C, the generated sample subgraph comprises users A, B and C, with user A as the central node of the sample subgraph. Furthermore, the range of the breadth-first search can be limited to avoid sample subgraphs that are so large the preset model cannot concentrate on the core features.
By the method, the sample subgraph of each sample node can be generated, so that the preset model is trained, the classification capacity of the preset model to the sample node to be predicted is improved, and the prediction accuracy is improved.
It can be understood that, in order to improve the generalization capability of the generalization model, sample graph data can be generated randomly; generating synthetic graphs with diversified structures in this way allows the preset model to understand and predict different types of graph structures more comprehensively, improving the robustness of the trained generalization model. For example, "acquire sample graph data for training a preset model" in step 201 includes:
(201.a1) acquiring base graph data and a plurality of sub-patterns of different categories, and attaching the plurality of sub-patterns to arbitrary sample nodes of the base graph data to obtain sample graph data;
(201.a2) determining, according to the number of nodes of the sample graph data, a preset number of edges to be added to the sample graph data, and adding edges between arbitrary pairs of sample nodes in the sample graph data, wherein the number of edges added to the sample graph data is equal to the preset number.
Wherein the base graph data may be a scale-free network model graph, a mathematical model for modeling real-world network structures; such a graph can be grown from a small initial set of nodes, with new nodes added to the network at a constant rate over time.
The sub-patterns may be patterns of any shape, such as houses, circles, grids, and other patterns, and the corresponding nodes may be given a specific category by adding different sub-patterns to the nodes of the base graph data.
Wherein an edge is an element that connects each node in the base graph data.
The preset number may be determined according to the number of nodes of the sample graph data, for example, the preset number may be 10%, 15% or the like of the number of nodes, or may be a specific number, for example, 20, 25 or the like.
For example, a composite graph composed of a base graph and four shapes can be generated; the network structure of the composite graph is $G = (V, E)$, wherein $V$ is the set of nodes and $E$ is the set of edges. Specifically, the composite graph may be formed by combining a scale-free network model graph serving as the base graph with a plurality of sub-patterns of four shapes, which may be houses, circles, diamonds, grids or other shapes. Then, the sub-patterns of the four shapes are randomly attached to nodes of the base graph, and 10% random edges are added to further perturb the base graph, thereby generating the sample graph data.
Further, for all sample nodes in the sample graph data, a target number of sample nodes can be selected randomly and their labels masked; for example, 15% of the sample nodes are selected for label masking.
By this method, sample graph data of various structures can be generated, so that the preset model can more comprehensively understand and predict different types of graph structures, improving the generalization performance of the generalization model in real scenarios; by perturbing the synthetic graph with random edges and randomly masking the labels of selected sample nodes, the generalization model can still maintain good performance in the face of a certain degree of noise and incomplete information.
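A hedged sketch of this random synthesis procedure follows, using a Barabási–Albert graph as the scale-free base graph and small template graphs standing in for the house/circle/grid sub-patterns; all sizes, rates and template choices are illustrative assumptions, not the patent's exact generator:

```python
import random
import networkx as nx

def make_sample_graph(n_base=300, n_motifs=80, edge_noise=0.10,
                      mask_rate=0.15, seed=0):
    rng = random.Random(seed)
    # Scale-free base graph; every plain base node starts with class label 0.
    g = nx.barabasi_albert_graph(n_base, m=2, seed=seed)
    nx.set_node_attributes(g, 0, "label")

    # Small template graphs standing in for the house/circle/grid sub-patterns.
    templates = {1: nx.house_graph(), 2: nx.cycle_graph(6), 3: nx.grid_2d_graph(2, 3)}
    for _ in range(n_motifs):
        cls = rng.choice(list(templates))
        motif = nx.convert_node_labels_to_integers(
            templates[cls], first_label=g.number_of_nodes())
        anchor = rng.randrange(n_base)            # base node the motif hangs off
        g = nx.union(g, motif)
        g.add_edge(anchor, next(iter(motif.nodes())))
        for v in motif.nodes():                   # the attached shape fixes the class
            g.nodes[v]["label"] = cls

    # Perturb with roughly edge_noise * |V| random extra edges.
    nodes = list(g.nodes())
    for _ in range(int(edge_noise * len(nodes))):
        u, v = rng.sample(nodes, 2)
        g.add_edge(u, v)

    # Randomly mask the labels of a fraction of the sample nodes.
    for v in rng.sample(nodes, int(mask_rate * len(nodes))):
        g.nodes[v]["masked"] = True
    return g
```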
In some embodiments, in order to effectively control the computational complexity and make the training process of the preset model more efficient, a corresponding sample subgraph may be generated for each sample node, so that the preset model is subsequently trained on the sample subgraphs and its generalization capability is improved. For example, "generate a corresponding sample subgraph for each sample node in the sample graph data according to the sample node connection structure of the sample graph data" in step 201 includes:
(201.b1) obtaining the node density of the sample graph data, and determining a sample association hierarchy for each sample node based on the node density;
(201.b2) for each sample node taken as a center sample node, finding the sample subgraph of the target sample node connection structure of that center sample node within the sample association hierarchy, based on the sample node connection structure formed between the center sample node and its adjacent sample nodes.
The node density may be the degree of compactness between each sample node in the sample graph data. The node density may be calculated by a ratio of the number of sample nodes to the area of the sample graph data, or may be calculated by other methods, which are not limited herein.
The sample association hierarchy may be the number of levels searched by a breadth-first search algorithm starting from the sample node to be predicted. When the node density of the sample graph data is high, meaning that the connections between sample nodes are tight, a smaller sample association hierarchy can be selected to limit the search range, reduce the amount of computation and better capture the interrelationships among sample nodes. Conversely, when the node density of the sample graph data is low, meaning that the connections between sample nodes are relatively sparse, a larger sample association hierarchy needs to be selected to fully consider the relevance between sample nodes in the graph; expanding the search range makes it possible to better capture the potential relations between sample nodes. Thus, the sample association hierarchy may also be regarded as a sample search depth.
The adjacent sample nodes can be sample nodes that are determined, through connectivity in the sample graph data, to be directly connected with the sample node to be predicted or connected to it through intermediate nodes.
By this method, a suitable sample subgraph can be efficiently selected from the sample graph data, saving computing resources as much as possible while ensuring that the selected sample subgraph is representative.
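One plausible way to map node density to a sample association hierarchy is sketched below; the density thresholds are arbitrary assumptions chosen only for illustration:

```python
import networkx as nx

def association_hops(graph: nx.Graph) -> int:
    """Choose a breadth-first search depth from the graph density:
    dense graphs get a small radius, sparse graphs a larger one."""
    density = nx.density(graph)   # ratio of existing edges to possible edges
    if density > 0.1:
        return 1
    if density > 0.01:
        return 2
    return 3
```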
It should be noted that the first feature matrix of the sample graph data may be generated in advance, so that when a sample subgraph subsequently needs to be used for training, the initial feature matrix of the sample subgraph can be looked up directly from the first feature matrix. Therefore, before the initial feature matrix and the adjacency matrix of each sample subgraph are obtained, that is, before step 202, the method further includes:
(A1) Acquiring a preset graph theory library;
(A2) Acquiring preset feature indicators of the sample nodes, and generating, through the graph theory library, the node feature vector of each sample node under the corresponding feature indicators; the feature indicators comprise at least one of node identification, node degree, clustering coefficient, betweenness centrality and closeness centrality;
(A3) And generating a first feature matrix according to a plurality of node feature vectors corresponding to the plurality of sample nodes.
The graph theory library can be any of a number of open-source graph theory libraries, such as the NetworkX library, the igraph library, the SNAP library and the like. Taking the NetworkX library as an example, NetworkX is an open-source Python package for creating and manipulating the structure, dynamics and functions of complex networks. It provides a comprehensive set of tools for graph theory and network analysis, suitable for studying various types of networks, including social networks, bioinformatics networks, transportation networks and the like. The NetworkX library provides rich functions and methods to study the structural characteristics of these graphs, including node identification, searching and traversing graphs, and computing node degrees, clustering coefficients, betweenness centrality, closeness centrality and so on.
The preset feature indicators can be at least one of node identification, node degree, clustering coefficient, betweenness centrality and closeness centrality. It will be appreciated that the node identification may also be generated at the time of generating the sample graph data. Specifically, the node degree may be the number of edges connecting a sample node with its adjacent sample nodes; the clustering coefficient may be an indicator measuring how tightly the current sample node is connected with its adjacent sample nodes; the betweenness centrality may be used to measure how important the current sample node is in connecting the other sample nodes in the sample graph data; and the closeness centrality measures how close the current sample node is to the other sample nodes.
Furthermore, the sample association level of each sample node can be determined through the clustering coefficient, the sample node with high clustering coefficient can select a smaller sample association level, and the sample node with low clustering coefficient can select a larger sample association level.
The node feature vector may be the vector generated through the graph theory library for each preset feature indicator. The node feature vector may include the feature values of indicators such as node identification, node degree, clustering coefficient, betweenness centrality and closeness centrality.
The first feature matrix may be composed of the node feature vectors of the sample nodes under the preset feature indicators.
For example, if the feature indicators are the clustering coefficient, the betweenness centrality, the closeness centrality and a random feature, then the node feature vector of sample node 1 may be [0.75, 0.33, 0.5, 0.42], where each number represents the value of a particular feature indicator. Further, when the feature indicators also include the node identification and the node degree, the node feature vector of sample node 1 may be [1, 2, 0.75, 0.33, 0.5, 0.42], and the node feature vector of sample node 2 may be [2, 4, 0.85, 0.54, 0.6, 0.59]. Further, when the sample graph data includes sample node 1 and sample node 2, the first feature matrix is formed by their node feature vectors: each row in the first feature matrix represents one sample node, and each column represents the feature value of one feature indicator.
By the method, the first feature matrix of the sample graph data can be generated, so that the initial feature matrix of each sample sub-graph is generated according to the first feature matrix conveniently, the feature matrix does not need to be generated for each sample sub-graph independently, and the data processing efficiency is improved.
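A minimal sketch of building the first feature matrix with the NetworkX indicators named above; the column ordering and the assumption of numeric node identifications are illustrative choices:

```python
import networkx as nx
import numpy as np

def first_feature_matrix(graph: nx.Graph) -> np.ndarray:
    """One row per sample node:
    [node id, node degree, clustering coefficient,
     betweenness centrality, closeness centrality]."""
    clustering = nx.clustering(graph)
    betweenness = nx.betweenness_centrality(graph)
    closeness = nx.closeness_centrality(graph)
    rows = [[node, graph.degree(node), clustering[node],
             betweenness[node], closeness[node]]
            for node in sorted(graph.nodes())]
    return np.asarray(rows, dtype=float)
```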
Step 202, obtaining an initial feature matrix and an adjacency matrix of each sample subgraph, and sequentially inputting the initial feature matrix into the preset model for feature aggregation to obtain the fusion feature matrix of each sample subgraph.
In some embodiments, the feature expression capability of the center node can be enhanced by fusing the information of the center sample node with that of its adjacent sample nodes in the sample subgraph, so that the center node can capture richer local structure and context information, which facilitates the subsequent extraction of core features.
Wherein the initial feature matrix may be the feature matrix of each sample subgraph, comprising the vectors describing the node features of each sample node in the sample subgraph. Specifically, the initial feature matrix may be obtained by looking up, according to the node identification of each sample node, the corresponding node feature vector in the first feature matrix of the sample graph data, and assembling the node feature vectors of the sample nodes into the initial feature matrix.
The adjacency matrix can be a matrix for representing the connection relation of each sample node in the sample subgraph, and the adjacency matrix is a two-dimensional matrix, wherein elements represent whether edges exist between the sample nodes. For example, for a sample subgraph with n nodes, the adjacency matrix is an n×n matrix, where the elements are used to record the connection between the sample nodes: elements in the adjacency matrix may be represented by 0 and 1, where 1 represents the presence of an edge and 0 represents the absence of an edge.
The preset model can be a neural network model to be trained; by training the preset model and updating its model parameters, the generalization model can be obtained.
The fusion feature matrix may be a feature matrix obtained by updating a feature vector of a central sample node in the initial feature matrix. The fusion feature matrix can be obtained by carrying out feature conversion and aggregation on the information of each sample node and the adjacent sample nodes by a preset model.
By the method, the important features of each sample node can be fused for each sample sub-graph to obtain the fused feature matrix, so that the prediction accuracy of the preset model on the sample sub-graph is higher, and the follow-up classification task is facilitated.
In some embodiments, the node feature vector of each sample node in a sample subgraph may be looked up in the first feature matrix of the sample graph data through its node identification, and the initial feature matrix of each sample subgraph generated accordingly, so as to improve the generation efficiency of the initial feature matrix. For example, "obtain an initial feature matrix and an adjacency matrix of each sample subgraph" in step 202 includes:
(202.a1) in each sample subgraph, determining a node identity for each sample node;
(202.a2) determining a node feature vector corresponding to each sample node from the first feature matrix based on the node identification;
(202.a3) generating an initial feature matrix of the sample subgraph according to a plurality of node feature vectors corresponding to a plurality of sample nodes of the sample subgraph;
(202.a4) generating an adjacency matrix corresponding to the feature matrix based on the connection relation between the sample nodes.
Wherein the node identification may be a unique identification or index value for each sample node in each sample graph data and sample subgraph. By means of the node identification, the node feature vector of the corresponding sample subgraph can be rapidly located in the first feature matrix of the sample graph data.
Illustratively, assume that a sample subgraph includes a plurality of sample nodes, each having a unique user identification, such as user123, user456, and so on. Then, according to user123 and user456, the node feature vectors of the corresponding sample nodes are looked up from the sample graph data, such as [user123, 2, 0.75, 0.33, 0.5, 0.42], [user456, 4, 0.85, 0.54, 0.6, 0.59], [user789, 3, 0.66, 0.24, 0.6, 0.5], and so on; that is, the node feature vector of user123 is [user123, 2, 0.75, 0.33, 0.5, 0.42] and the node feature vector of user456 is [user456, 4, 0.85, 0.54, 0.6, 0.59], and the initial feature matrix of the sample subgraph is formed from [user123, 2, 0.75, 0.33, 0.5, 0.42] and [user456, 4, 0.85, 0.54, 0.6, 0.59].
For example, when there is a connection relationship between sample nodes, i.e., when there is an edge, a corresponding adjacency matrix may be generated. For example, when the sample subgraph is a social network graph, two users are friends, then there is a connection relationship between the two users, and the corresponding element of the two users in the adjacency matrix is 1, otherwise, is 0.
It can be understood that, by obtaining the initial feature matrix directly from the existing first feature matrix, recalculating the features of each sample subgraph can be avoided, which significantly improves the efficiency and accuracy of generating the initial feature matrix and facilitates the subsequent extraction of the core features and redundant features.
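A small sketch of the lookup just described, assuming the node identification is stored in column 0 of the first feature matrix:

```python
import numpy as np

def subgraph_features(first_features: np.ndarray, node_ids) -> np.ndarray:
    """Assemble a subgraph's initial feature matrix by row lookup on the
    node identification stored in column 0 of the first feature matrix."""
    index = {int(row[0]): i for i, row in enumerate(first_features)}
    return first_features[[index[v] for v in node_ids]]
```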
It should be noted that the center sample node is the core of the sample subgraph, that is, the classification object of the sample subgraph, and the quality of its features directly affects how the preset model analyzes the sample subgraph. Therefore, feature fusion can be performed on the sample subgraph so that the center sample node absorbs and integrates the information of the adjacent sample nodes connected to it, yielding a more comprehensive and accurate feature representation of the center sample node. For example, "sequentially inputting the initial feature matrix into the preset model for feature aggregation to obtain the fusion feature matrix of each sample subgraph" in step 202 includes:
(202.b1) determining the node degree matrix of the sample subgraph from the first feature matrix according to the node identification of each sample node in the sample subgraph;
(202.b2) obtaining the adjacency matrix of the sample subgraph, and normalizing the adjacency matrix through the node degree matrix to obtain a normalized adjacency matrix;
(202.b3) taking the center sample node as the fusion center, fusing sample node features layer by layer based on the normalized adjacency matrix to obtain an updated center node feature vector corresponding to the center sample node;
(202.b4) updating the node feature vector of the center sample node in the initial feature matrix according to the center node feature vector, to obtain the fusion feature matrix of each sample subgraph.
The node degree matrix may be a diagonal matrix of the sample subgraph, where each element on the diagonal represents the node degree of the corresponding node in the graph. The node degree of each sample node can be obtained from the first feature matrix through the node identification of each sample node in the sample subgraph. Illustratively, when the node feature vector of each sample node consists of the node identification, node degree, clustering coefficient, betweenness centrality, closeness centrality and a random feature, such as [user123, 2, 0.75, 0.33, 0.5, 0.42], then for the sample node identified as user123, looking up the corresponding node feature vector [user123, 2, 0.75, 0.33, 0.5, 0.42] from the sample graph data shows that its node degree is 2. The node degree matrix of the sample subgraph can then be generated from the node degrees of the sample nodes in the sample subgraph.
The central node feature vector may be a node feature vector of a central sample node in each sample sub-graph, the central sample node is a classification object of the sample sub-graph, and the preset model may classify the central sample node based on the central node feature vector.
Specifically, the fusion process of the center node feature vector may include adjacency matrix normalization, feature aggregation, linear transformation with nonlinear activation, multi-layer iteration and readout. For example, a normalized Laplacian matrix may first be computed to reduce the influence of high-degree sample nodes in feature propagation, realize feature normalization, and keep the feature scale stable during propagation. Then, in each layer, the sample node feature vectors are multiplied by the normalized adjacency matrix so that each sample node can receive the information of its adjacent sample nodes; the aggregated features are then linearly transformed by a weight matrix, and an activation function is applied to update the representation of the sample nodes. Further, through multiple layers of iteration, the center sample node gradually fuses the information of more distant adjacent sample nodes; in each layer, the representation of the center node is updated based on the node feature vectors of its adjacent sample nodes. After the iteration of all layers is completed, the node feature vector of the last layer of the center sample node can be used directly as the center node feature vector, or the node feature vectors of all adjacent sample nodes can be pooled to obtain a representation of the whole sample subgraph, which is then used as the center node feature vector of the center sample node.
Illustratively, the subgraph information may be fused by way of a graph convolutional neural network (Graph Convolutional Networks, GCN), so as to obtain a central node feature vector of the central sample node. The specific formula is as follows:
$h_v^{(l+1)} = g\left(\left[D^{-1/2} A D^{-1/2} H^{(l)} W^{(l)}\right]_v\right)$

wherein $h_v^{(l+1)}$ represents the node feature vector of sample node v at layer l+1; g(·) represents an activation function, such as ReLU or tanh, used to introduce nonlinearity; $D^{-1/2}$ represents the inverse square root of the node degree matrix $D$, where each element $D_{ii}$ represents the node degree of sample node i; $A$ represents the adjacency matrix of the sample subgraph, used to represent the connection relations among the sample nodes in the sample subgraph; $H^{(l)}$ is the matrix of node feature vectors at layer l, whose row $h_v^{(l)}$ represents the node feature vector of sample node v at layer l; and $W^{(l)}$ represents the weight matrix of layer l, a model parameter obtained after training and learning of the preset model.
For each sample subgraph, the process of fusing subgraph information includes normalizing the adjacency matrix by the node degree matrix, i.e., computing $D^{-1/2} A D^{-1/2}$, which reduces the influence of sample nodes with large node degree in feature propagation, realizes feature normalization, and maintains the stability of the feature scale during propagation.
Further, the node feature vector $h_v^{(l)}$ of sample node v at layer l can be multiplied by the normalized adjacency matrix, so that at layer l each sample node receives the information of its adjacent sample nodes; the aggregated result is then multiplied by the weight matrix $W^{(l)}$ and the activation function g(·) is applied, converting the node features of each current sample node from the original space into a new feature space and introducing nonlinearity. In this way, the updated node feature vector $h_v^{(l+1)}$ of the sample subgraph can be obtained.
Further, after the iteration of all layers is completed, the node feature vector of the last layer of the central sample node can be directly used as the representation of the central sample node, or the node feature vector of all sample nodes can be pooled to obtain the representation of the whole sample subgraph, and the representation is used as the global representation of the central sample node.
Further, a central node feature vector of the central sample node can be updated and obtained, and the central node feature vector is used for replacing the node feature vector of the central sample node in the initial feature matrix, so that a fusion feature matrix of each sample subgraph is obtained.
Through the mode, the node characteristic vectors of the sample subgraph can be normalized by utilizing the node degree matrix and the adjacent matrix, the influence of the sample nodes with larger node degree in the sample subgraph on characteristic propagation can be reduced, the dimensional stability in the characteristic propagation process is kept, each sample node can receive information from adjacent sample nodes more fairly, and the information in the sample subgraph can be effectively aggregated by fusing the node characteristic vectors of the adjacent sample nodes of the central sample node layer by layer, so that the characterization capability of the central sample node is enhanced.
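A minimal numpy sketch of one such fusion layer, following the formula above without added self-loops; the toy star-shaped subgraph and random weights are illustrative assumptions:

```python
import numpy as np

def gcn_layer(h, adj, weight, activation=np.tanh):
    """One fusion layer: h^(l+1) = g(D^(-1/2) A D^(-1/2) h^(l) W)."""
    deg = adj.sum(axis=1)
    d_inv_sqrt = np.diag(1.0 / np.sqrt(np.maximum(deg, 1e-12)))
    norm_adj = d_inv_sqrt @ adj @ d_inv_sqrt      # normalized adjacency matrix
    return activation(norm_adj @ h @ weight)

# Two stacked layers let the center node (row 0) absorb two-hop information.
rng = np.random.default_rng(0)
adj = np.array([[0, 1, 1], [1, 0, 0], [1, 0, 0]], dtype=float)
h = rng.normal(size=(3, 5))                       # initial feature matrix
h = gcn_layer(h, adj, rng.normal(size=(5, 8)))
h = gcn_layer(h, adj, rng.normal(size=(8, 8)))
center_vector = h[0]                              # updated center node feature vector
```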
And 203, determining the core features and the redundant features of each sample subgraph according to the adjacency matrix and the fusion feature matrix.
In some embodiments, the core features and the redundant features of the sample subgraph can be calculated through the adjacency matrix and the fusion feature matrix, so that the preset model can be trained through the core features and the redundant features, the preset model is more focused on the causal features, and the attention degree to the redundant features is reduced.
In some embodiments, after the central node feature vector of each sample sub-graph is updated, the central node feature vector of the central sample node may be updated in the first feature matrix of the sample graph data according to the node identification of the central sample node. And after updating each central node feature vector in the first feature matrix, updating corresponding node feature vectors in all sample subgraphs synchronously. In this way, the representation of the node feature vector of the sample node in each sample sub-graph can be enriched as much as possible. Or only the central node feature vector of the central sample node in each sample sub-graph can be updated, so that the preset model can pay more attention to the core features in the training process.
For example, the node feature vectors of any two sample nodes in the sample subgraph can be input into a multi-layer perceptron to obtain the corresponding edge mask value, and an edge mask matrix is generated from all edge mask values of the sample subgraph. Further, the edge mask matrix can be multiplied element-wise with the adjacency matrix of the corresponding sample subgraph to adjust the weight of each edge, and then combined with the fusion feature matrix to obtain a core structure graph; the core structure graph is processed by the preset model to obtain the core features. It will be appreciated that, by means of the edge mask matrix, the weights of the edges present in each adjacency matrix can be adjusted so that the important edges between sample nodes are more prominent.
For example, the complement matrix of the edge mask matrix may be multiplied element-wise with the adjacency matrix of the corresponding sample subgraph to adjust the weights of the edges, and then combined with the fusion feature matrix to obtain a redundant structure graph; the redundant structure graph is processed by the preset model to obtain the redundant features. It will be appreciated that, since the edge mask matrix contains the weight values of the edges between sample nodes, the complement matrix of the edge mask matrix captures the redundant portions of the edges between sample nodes.
The core features and redundant features calculated in this way facilitate the subsequent training of the preset model, so that the preset model focuses more on the causal features and pays less attention to shortcut features.
In some embodiments, to accurately obtain the core feature and the redundant feature, the weights of the adjacency matrix may be adjusted by the edge mask matrix to quantify the importance of the edges of the two sample nodes, and step 203 may include, for example:
(203.1) inputting any two sample nodes in the sample subgraph into the multi-layer perceptron based on the fusion feature matrix of the sample subgraph to obtain an edge mask value between any two sample nodes;
(203.2) generating an edge mask matrix for the sample subgraph from the plurality of edge mask values in the sample subgraph;
(203.3) generating core features corresponding to the sample subgraph based on the adjacency matrix, the fusion feature matrix and the edge mask matrix;
(203.4) generating a corresponding complement matrix according to the edge mask matrix of the sample subgraph;
(203.5) obtaining redundant features of the sample subgraph based on the adjacency matrix, the fusion feature matrix and the complement matrix.
The edge mask value may be a scalar value obtained by calculation of the multi-layer perceptron for each edge in the sample subgraph. The edge mask value represents the importance of an edge between two sample nodes. Specifically, the edge mask value may be calculated by inputting any two sample nodes in the sample subgraph into the multi-layer perceptron, for example, the edge mask value of an edge may be 0, 0.1, 0.6, and so on.
The edge mask matrix may be a matrix generated from all the edge mask values, obtained by sequentially inputting the node feature vectors of each pair of sample nodes of the sample subgraph into the multi-layer perceptron until all sample node pairs have been processed.
In particular, in a sample subgraph $G_s$, after pairing any two sample nodes, the node feature vectors of the two sample nodes are input into a multi-layer perceptron (Multilayer Perceptron, MLP) to obtain an edge mask value between 0 and 1; after the edge mask values between all sample nodes in the sample subgraph are obtained, the edge mask matrix $M$ of the sample subgraph $G_s$ can be formed.
Further, the edge mask matrix $M$ and the adjacency matrix $A$ of the sample subgraph $G_s$ can be multiplied element by element and then combined with the fusion feature matrix $X$ to obtain the core structure graph $G_c = (M \odot A, X)$.
Further, the complement matrix is $\bar{M} = 1 - M$; the complement matrix $\bar{M}$ and the adjacency matrix $A$ of the sample subgraph $G_s$ are multiplied element by element and then combined with the fusion feature matrix $X$ to obtain the redundant structure graph $G_r = (\bar{M} \odot A, X)$.
Specifically, the core structure graph and the redundant structure graph can each be processed by the preset model to obtain the core features $C$ and the redundant features $S$, respectively.
By the method, the structural information of the sample subgraph can be better captured, the prediction accuracy and interpretation of the preset model on the sample subgraph can be improved as much as possible, and by generating the redundant features, the negative influence of redundant data such as noise data on the performance of the preset model can be reduced, so that the attention degree of the preset model on the core features can be improved subsequently, and the dependence on the redundant features can be reduced.
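A hedged PyTorch sketch of the edge mask computation and the two derived structure graphs; the MLP architecture and all dimensions are assumptions, not the patent's exact design:

```python
import torch
import torch.nn as nn

class EdgeMask(nn.Module):
    """Score every pair of sample nodes with an MLP; values lie in (0, 1)."""
    def __init__(self, dim: int):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(2 * dim, dim), nn.ReLU(),
                                 nn.Linear(dim, 1), nn.Sigmoid())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n = x.size(0)                              # x: [n, dim] fused node features
        pairs = torch.cat([x.repeat_interleave(n, dim=0),
                           x.repeat(n, 1)], dim=1) # all n*n ordered node pairs
        return self.mlp(pairs).view(n, n)          # edge mask matrix M

x = torch.randn(4, 16)                             # fusion feature matrix X
adj = (torch.rand(4, 4) > 0.5).float()             # adjacency matrix A
m = EdgeMask(16)(x)
core_adj = m * adj                                 # enhanced structure: M ⊙ A
redundant_adj = (1.0 - m) * adj                    # weakened structure: (1 - M) ⊙ A
# Feeding (core_adj, X) and (redundant_adj, X) through the preset model would
# yield the core features C and the redundant features S, respectively.
```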
Step 204, determining a target loss of the preset model based on the core feature and the redundancy feature.
The generalization model is a model obtained by training a preset model, and stopping training after the preset model converges or reaches the preset training times.
The target loss may be a loss composed of a first loss, a second loss and a third loss, and is used for guiding the adjustment of the parameters of the preset model so as to train the preset model and obtain the generalization model. Specifically, the first loss may be a cross entropy loss, the second loss may be an entropy loss, and the third loss may be a contrastive loss; the three losses together form the target loss, and the model parameters are updated using the target loss, so that the generalization model can concentrate on learning core features and avoid paying excessive attention to redundant features, thereby improving the quality of the overall feature representation.
It can be appreciated that the core feature and the redundant feature can be input into the preset model, the first loss is determined according to the difference between the prediction node class label output by the core feature of the sample sub-graph and the sample node class label of the core feature, the second loss is determined according to the probability distribution uniformity when the sample node class label is predicted according to the redundant feature of the sample sub-graph by the preset model, and the third loss is determined according to the sample distance between the core feature and the redundant feature, so that the generalization model can concentrate attention on the causal feature, namely the core feature, the dependence on the unstable redundant feature is reduced, the risk of overfitting is reduced, and meanwhile more reliable decision is made.
In some embodiments, in order to make the generalization model better capture the intrinsic structure of the data and thus have better generalization capability in the face of test data with different distributions, the preset model may be trained based on the core features and the redundancy features so that the preset model has better classification performance, for example, step 204 may include:
(204.1) determining a first penalty based on a difference between the predicted node class label output for the core feature of the sample subgraph relative to the sample node class label of the core feature;
(204.2) determining a second loss based on a probability distribution uniformity when predicting each sample node class label for the redundant features of the sample subgraph by a preset model;
(204.3) determining a third loss based on the sample distance between the core feature and the redundant feature;
(204.4) constructing a target loss of the preset model based on the sum of the first loss, the second loss and the third loss.
The first loss may be a cross entropy loss, and is used for evaluating a difference between a probability distribution of each prediction node class label of the core feature prediction by the preset model and a sample node class label of the core feature. Specifically, a single thermal code may be set for each sample node class label, each sample node class label is represented by a binary vector, and only the actual sample node class label corresponding to the core feature is set to be 1, and the other sample node class labels are set to be 0. Specifically, the first loss may be:
$\mathcal{L}_1 = -\sum_{v} y_v \log(\hat{p}_v)$

wherein $y_v$ is the one-hot encoding of the sample node class labels, and $\hat{p}_v$ is the first prediction probability of the preset model for the v-th sample node class label. When the predicted node class label is consistent with the real sample node class label corresponding to the core features, the first loss is minimal; therefore, by minimizing the first loss, the preset model can better learn the causal relations of the data and continuously make more accurate predictions.
The predicted node class labels are a plurality of node class labels obtained by predicting core features by a preset model. For example, for an animal image, the predicted node class labels may be cat, dog, rabbit, each of which corresponds to a predicted probability.
The second loss may be entropy loss, which is used for measuring probability distribution uniformity when the preset model predicts the class labels of the nodes of each sample aiming at the redundant features of the sample subgraph, and the more uniform the probability distribution, the smaller the classification preference of the preset model for the redundant features is indicated, so that the preset model can be effectively prevented from learning shortcut features. Specifically, the second loss may be:
$\mathcal{L}_2 = \sum_{v} u \log\left(\frac{u}{\hat{q}_v}\right)$

wherein $u$ represents the uniform probability of each sample node class label under a uniform distribution, in general $u = 1/\text{number of categories}$, and $\hat{q}_v$ is the second prediction probability of the preset model for the v-th sample node class label.
Specifically, when the redundant features are input into the preset model and the output second prediction probabilities are close to a uniform distribution, this indicates that the redundant features do not help the preset model classify. Illustratively, if there are 4 sample class labels, namely cat, rabbit, dog and pig, then $u = 1/4$; if redundant feature 1 is input into the preset model and the prediction probabilities of cat, rabbit, dog and pig are all 1/4, the second loss of the preset model is minimal; at this time, the preset model has no classification preference for the redundant feature, and the probability distribution uniformity of the output is high. Therefore, the preset model needs to be trained continuously, so that its prediction probabilities for the redundant features over all sample node class labels approach a uniform distribution.
The probability distribution uniformity may describe the second prediction probabilities output by the preset model under each sample node class label after the redundant features are input into it. The more uniform the second prediction probabilities, the higher the probability distribution uniformity; for example, the probability distribution uniformity of the second prediction probabilities (1/3, 1/3, 1/3) is greater than that of (2/3, 1/6, 1/6).
The third loss may be a contrastive loss. Specifically, the core features may be used as positive samples and the redundant features as negative samples; the contrastive loss encourages the preset model to correctly distinguish positive and negative samples, so that positive samples are drawn closer together while positive and negative samples are pushed further apart. Further, the third loss may be:
$\mathcal{L}_3 = -\mathbb{E}\left[\log \sigma\left(\frac{C \cdot C^{+}}{Q}\right) + \log\left(1 - \sigma\left(\frac{C \cdot S}{Q}\right)\right)\right]$

wherein $\sigma(\cdot)$ is a sigmoid function, typically used to map the input into (0, 1); $Q$ is a hyperparameter; $C$ and $C^{+}$ are different instances of the core features and form a positive sample pair, i.e., both $C$ and $C^{+}$ represent positive samples (core features); $C$ and $S$ form a positive-negative sample pair, with $S$ representing a negative sample, i.e., the redundant features; and $\mathbb{E}$ represents the expected value, which may be obtained by averaging, under the preset model, the first prediction probabilities of the core features over all positive sample pairs and the second prediction probabilities of the redundant features over all negative sample pairs. Training the preset model with the third loss encourages it to correctly distinguish positive and negative samples, so that positive samples are drawn closer together while positive and negative samples are pushed further apart.
The sample distance may be a distance between two core features or a distance between a core feature and a redundant feature, and the sample distance may be calculated by means of euclidean distance, cosine similarity, and the like.
The core features can be features with causal relation or causal interpretation capability in the sample subgraph, and the core features have causal information of sample nodes to be predicted and sample node category labels.
The redundant features may be shortcut features between the core features and the prediction result; they contribute nothing extra to predicting the sample node class label and carry no causal interpretation. Moreover, redundant features can lead the preset model to rely on shortcut features when making decisions, reducing its accuracy in predicting sample graph data and thereby degrading the performance of the trained generalization model on out-of-distribution test data. Illustratively, a redundant feature may be a noise feature.
For example, if the sample graph data is a social network graph, the preset model needs to classify users in the social network, with labels including "student", "worker", "freelancer" and the like. A sample subgraph is a record of a user's (i.e., the sample node to be predicted) activity on the social network, such as posts, photos, interactions, etc. The core features may include the user's professional information, educational background, and so on. After the core features are input into the preset model, the model outputs the first prediction probabilities that the user is a student, a worker or a freelancer. Further, the cross entropy loss may be used as the first loss to measure the difference between the probability distribution output by the preset model under each predicted node class label and the actual class label (represented by its one-hot encoding). For example, if the actual sample class label is "student", the one-hot encoding is [1, 0, 0]; if the prediction probabilities are [0.7, 0.2, 0.1], the first loss is the cross entropy between the two, and minimizing it adjusts the parameters of the preset model so as to improve the prediction accuracy for the "student" class on the sample node to be predicted in the sample subgraph.
For example, a redundant feature may be the weather information of the user's geographic location: although it appears relevant to the classification task, in practice it may introduce bias and does not actually help classify the user's professional information. Therefore, the entropy loss can be calculated as the second loss; when the preset model's output for the redundant features tends toward a uniform distribution, this indicates that the model has not learned any biased classification decision from them, thereby encouraging the preset model to focus mainly on the core features that substantially help classification.
Further, in the training process, the contrastive loss enables the preset model to distinguish relevant features (i.e., core features) from irrelevant features (i.e., redundant features), learning to draw the core features (positive samples) close to each other while keeping the redundant features (negative samples) away from the core features. Further, by using the contrastive loss as the third loss, the similarity between positive and negative sample pairs is computed and penalized, so that the learned embedding space better separates relevant from irrelevant information.
It will be appreciated that by combining the first, second and third losses to form the target loss, a balanced performance improvement of the pre-set model in different aspects can be ensured, and optimization of the target loss can help the pre-set model to disregard redundant features that are not helpful for classification while learning the causal characteristics of the sample, and better distinguish core features related to the predictive task.
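A hedged PyTorch sketch of the combined target loss, reading the second loss as a KL divergence to the uniform distribution and the third loss as a sigmoid-based contrastive term; both readings and all dimensions are assumptions:

```python
import torch
import torch.nn.functional as F

def target_loss(core_logits, labels, redundant_logits,
                core_emb, redundant_emb, q=0.5):
    """Sum of the three losses described above."""
    # First loss: cross entropy between core-feature predictions and labels.
    l1 = F.cross_entropy(core_logits, labels)

    # Second loss: KL divergence of the redundant-feature prediction from the
    # uniform distribution; it is zero when the prediction is perfectly uniform.
    log_q = F.log_softmax(redundant_logits, dim=-1)
    uniform = torch.full_like(log_q, 1.0 / log_q.size(-1))
    l2 = F.kl_div(log_q, uniform, reduction="batchmean")

    # Third loss: contrastive; pull core embeddings (positives) together,
    # push redundant embeddings (negatives) away from them.
    pos = torch.sigmoid(core_emb @ core_emb.t() / q)
    neg = torch.sigmoid(core_emb @ redundant_emb.t() / q)
    l3 = -(torch.log(pos + 1e-8).mean() + torch.log(1.0 - neg + 1e-8).mean())
    return l1 + l2 + l3

# Toy batch: 8 subgraphs, 4 classes, 16-dimensional features.
logits_c, logits_r = torch.randn(8, 4), torch.randn(8, 4)
emb_c, emb_r = torch.randn(8, 16), torch.randn(8, 16)
labels = torch.randint(0, 4, (8,))
loss = target_loss(logits_c, labels, logits_r, emb_c, emb_r)
```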
Step 205, updating the parameters of the preset model based on the target loss, and obtaining the trained generalization model when the preset model converges.
Further, in the training process of the preset model, parameters of the preset model may be continuously updated according to the target loss, for example, through a gradient descent algorithm. When the loss of the preset model is not significantly reduced, or a certain termination condition (for example, the preset training times or the preset model convergence is reached), training of the preset model can be stopped, and the generalization model is obtained.
In summary, by training the preset model with the combination of multiple loss functions, the preset model is more likely to learn the causal features that are decisive for classification, rather than relying on shortcut features that are easy to learn but may lead to overfitting. Thus, a generalization model with better generalization performance and higher prediction accuracy can be obtained.
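A minimal sketch of such a training loop with a simple plateau-based convergence check; the optimizer, learning rate and tolerance are assumptions:

```python
import torch

def train(model, batches, loss_fn, epochs=200, tol=1e-4, lr=1e-3):
    """Gradient-descent training with a simple convergence check: stop when
    the epoch loss no longer decreases noticeably or the epoch budget ends."""
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    prev = float("inf")
    for _ in range(epochs):
        total = 0.0
        for batch in batches:
            opt.zero_grad()
            loss = loss_fn(model, batch)  # e.g. the target loss sketched above
            loss.backward()
            opt.step()
            total += float(loss)
        if abs(prev - total) < tol:       # converged: loss has plateaued
            break
        prev = total
    return model                          # the trained generalization model
```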
Referring to fig. 4, in some embodiments, the general flow of the present application is described in conjunction with fig. 4. For example, sample graph data may be generated by means of random generation, and for each sample node in the sample graph data, a sample subgraph is generated within a sample association hierarchy. Further, feature fusion is carried out aiming at the central sample node of each sample sub-graph as a fusion center, a central node feature vector of the central sample node is generated, and an initial feature matrix of the corresponding sample sub-graph is updated according to the central node feature vector to obtain a fusion feature matrix, so that the attention degree to the central sample node, namely the sample node to be predicted, is improved.
Further, the node feature vectors of any two sample nodes in the sample subgraph are processed by the multi-layer perceptron to generate an edge mask value; an edge mask matrix is then generated from the edge mask values of all sample node pairs, multiplied element by element with the adjacency matrix, and combined with the fusion feature matrix to obtain the core features. Further, the redundant features are obtained by multiplying the complement matrix of the edge mask matrix element by element with the adjacency matrix and combining the result with the fusion feature matrix. In this way, the importance of each edge can be quantified.
Further, a first loss may be determined based on a difference between a predicted node class label output for the core feature of the sample subgraph relative to a sample node class label of the core feature, a second loss may be determined based on a probability distribution uniformity of a preset model when predicting each sample node class label for the redundancy feature of the sample subgraph, a third loss may be determined based on a sample distance between the core feature and the redundancy feature, and a target loss of the preset model may be formed based on a sum of the first loss, the second loss, and the third loss. Therefore, parameters of the preset model can be updated based on the target loss, when the preset model converges, a trained generalization model with good generalization capability is obtained, and then accurate prediction and evaluation of data outside a distribution range can be achieved through the generalization model.
Referring to fig. 5, an embodiment of the present application further provides a graph data prediction device based on a generalization model, which can implement the graph data prediction method based on a generalization model, where the graph data prediction device based on a generalization model includes:
An obtaining module 51, configured to obtain target graph data, where the target graph data includes a plurality of target nodes and a node connection structure formed by the plurality of target nodes;
The searching module 52 is configured to search a target subgraph of a target connection structure of each target node in the target association hierarchy based on a node connection structure formed between a plurality of target nodes in the target graph data;
The input module 53 is configured to input the target subgraph to the generalization model, and obtain a node class label of a target node in the target subgraph; the generalization model carries out causal prediction learning between a sample node connection structure and a sample node class label in a sample sub-graph based on a training process of minimizing target loss, wherein the target loss is formed by first loss, second loss and third loss, and the first loss is determined based on the difference between a predicted node class label output for core characteristics of the sample sub-graph and a sample node class label of the core characteristics; the second loss is determined based on probability distribution uniformity when the redundancy features of the sample subgraphs are used for predicting the class labels of the nodes of each sample according to a preset model; the third loss is determined according to the sample distance between the core feature and the redundant feature, wherein the core feature is a feature obtained by enhancing and representing the sample node connection structure of the sample subgraph according to an edge mask matrix obtained by an attention mechanism; the redundant features are features obtained by weakening and representing the sample node connection structure of the sample subgraph according to the complement matrix of the edge mask matrix.
The specific implementation manner of the map data prediction apparatus based on the generalization model is basically the same as the specific embodiment of the map data prediction method based on the generalization model, and will not be described herein. On the premise of meeting the requirements of the embodiment of the application, other functional modules can be arranged in the map data prediction device based on the generalization model so as to realize the map data prediction method based on the generalization model in the embodiment.
The embodiment of the application also provides computer equipment, which comprises a memory and a processor, wherein the memory stores a computer program, and the processor realizes the graph data prediction method based on the generalization model when executing the computer program. The computer equipment can be any intelligent terminal including a tablet personal computer, a vehicle-mounted computer and the like.
Referring to fig. 6, fig. 6 illustrates a hardware structure of a computer device according to another embodiment, where the computer device includes:
The processor 61 may be implemented by a general-purpose CPU (central processing unit), a microprocessor, an application-specific integrated circuit (Application Specific Integrated Circuit, ASIC), or one or more integrated circuits, etc., and is used for executing related programs to implement the technical solution provided by the embodiments of the present application;
the memory 62 may be implemented in the form of a read-only memory (Read Only Memory, ROM), a static storage device, a dynamic storage device, or a random access memory (Random Access Memory, RAM). The memory 62 may store an operating system and other application programs; when the technical solution provided in the embodiments of the present disclosure is implemented by software or firmware, the relevant program codes are stored in the memory 62, and the processor 61 invokes them to execute the graph data prediction method based on the generalization model of the embodiments of the present disclosure;
An input/output interface 63 for implementing information input and output;
The communication interface 64 is configured to implement communication interaction between the device and other devices, and may implement communication in a wired manner (such as USB, network cable, etc.), or may implement communication in a wireless manner (such as mobile network, WIFI, bluetooth, etc.);
a bus 66 for transferring information between the various components of the device (e.g., processor 61, memory 62, input/output interface 63, and communication interface 64);
Wherein the processor 61, the memory 62, the input/output interface 63 and the communication interface 64 are in communication connection with each other inside the device via a bus 66.
The embodiment of the application also provides a computer readable storage medium, wherein the computer readable storage medium stores a computer program, and the computer program realizes the graph data prediction method based on the generalization model when being executed by a processor.
The memory, as a non-transitory computer readable storage medium, may be used to store non-transitory software programs as well as non-transitory computer executable programs. In addition, the memory may include high-speed random access memory, and may also include non-transitory memory, such as at least one magnetic disk storage device, flash memory device, or other non-transitory solid state storage device. In some embodiments, the memory optionally includes memory remotely located relative to the processor, the remote memory being connectable to the processor through a network. Examples of such networks include, but are not limited to, the internet, intranets, local area networks, mobile communication networks, and combinations thereof.
The embodiments described in the embodiments of the present application are for more clearly describing the technical solutions of the embodiments of the present application, and do not constitute a limitation on the technical solutions provided by the embodiments of the present application, and those skilled in the art can know that, with the evolution of technology and the appearance of new application scenarios, the technical solutions provided by the embodiments of the present application are equally applicable to similar technical problems.
It will be appreciated by persons skilled in the art that the embodiments of the application are not limited by the illustrations, and that more or fewer steps than those shown may be included, or certain steps may be combined, or different steps may be included.
The above described apparatus embodiments are merely illustrative, wherein the units illustrated as separate components may or may not be physically separate, i.e. may be located in one place, or may be distributed over a plurality of network elements. Some or all of the modules may be selected according to actual needs to achieve the purpose of the solution of this embodiment.
Those of ordinary skill in the art will appreciate that all or some of the steps of the methods, systems, functional modules/units in the devices disclosed above may be implemented as software, firmware, hardware, and suitable combinations thereof.
The terms "first," "second," "third," "fourth," and the like in the description of the application and in the above figures, if any, are used for distinguishing between similar objects and not necessarily for describing a particular sequential or chronological order. It is to be understood that the data so used may be interchanged where appropriate such that the embodiments of the application described herein may be implemented in sequences other than those illustrated or otherwise described herein. Furthermore, the terms "comprises," "comprising," and "having," and any variations thereof, are intended to cover a non-exclusive inclusion, such that a process, method, system, article, or apparatus that comprises a list of steps or elements is not necessarily limited to those steps or elements expressly listed but may include other steps or elements not expressly listed or inherent to such process, method, article, or apparatus.
It should be understood that in the present application, "at least one (item)" and "a plurality" means one or more, and "a plurality" means two or more. "and/or" for describing the association relationship of the association object, the representation may have three relationships, for example, "a and/or B" may represent: only a, only B and both a and B are present, wherein a, B may be singular or plural. The character "/" generally indicates that the context-dependent object is an "or" relationship. "at least one of" or the like means any combination of these items, including any combination of single item(s) or plural items(s). For example, at least one (one) of a, b or c may represent: a, b, c, "a and b", "a and c", "b and c", or "a and b and c", wherein a, b, c may be single or plural.
In the several embodiments provided by the present application, it should be understood that the disclosed systems and methods may be implemented in other ways. For example, the system embodiments described above are merely illustrative: the division of the above units is merely a logical functional division, and other divisions are possible in actual implementation; for example, multiple units or components may be combined or integrated into another system, or some features may be omitted or not performed. Furthermore, the couplings, direct couplings, or communication connections shown or discussed may be indirect couplings or communication connections via interfaces, devices, or units, and may be electrical, mechanical, or in other forms.
The units described above as separate components may or may not be physically separate, and components shown as units may or may not be physical units; they may be located in one place or distributed over a plurality of network units. Some or all of the units may be selected according to actual needs to achieve the purpose of the solution of this embodiment.
In addition, each functional unit in the embodiments of the present application may be integrated in one processing unit, or each unit may exist alone physically, or two or more units may be integrated in one unit. The integrated units may be implemented in hardware or in software functional units.
The integrated units, if implemented in the form of software functional units and sold or used as stand-alone products, may be stored in a computer-readable storage medium. Based on such understanding, the technical solution of the present application, in essence, the part contributing to the prior art, or all or part of the technical solution, may be embodied in the form of a software product stored in a storage medium and including multiple instructions that cause a computer device (which may be a personal computer, a server, a network device, or the like) to perform all or part of the steps of the methods of the various embodiments of the present application. The aforementioned storage medium includes: a USB flash drive, a removable hard disk, a Read-Only Memory (ROM), a Random Access Memory (RAM), a magnetic disk, an optical disk, or any other medium capable of storing a program.
The preferred embodiments of the present application have been described above with reference to the accompanying drawings; they do not thereby limit the scope of the claims of the embodiments of the present application. Any modifications, equivalent substitutions, and improvements made by those skilled in the art without departing from the scope and spirit of the embodiments of the present application shall fall within the scope of the claims of the embodiments of the present application.
Claims (11)
1. A graph data prediction method based on a generalization model, the method comprising:
obtaining target graph data, wherein the target graph data comprises a plurality of target nodes and a node connection structure formed by the plurality of target nodes;
searching, within a target association level, for a target subgraph of the target connection structure of each target node, based on the node connection structure formed by the plurality of target nodes in the target graph data;
inputting the target subgraph into the generalization model to obtain a node class label of a target node in the target subgraph;
The generalization model performs causal prediction learning between the sample node connection structure and the sample node class labels in sample subgraphs through a training process that minimizes a target loss, wherein the target loss is composed of a first loss, a second loss, and a third loss; the first loss is determined based on the difference between the predicted node class label output for the core features of a sample subgraph and the sample node class label of those core features; the second loss is determined based on the uniformity of the probability distribution produced when a preset model predicts each sample node class label from the redundant features of the sample subgraph; the third loss is determined according to the sample distance between the core features and the redundant features; the core features are features obtained by enhancing the representation of the sample node connection structure of a sample subgraph according to an edge mask matrix derived from an attention mechanism; the redundant features are features obtained by weakening the representation of the sample node connection structure of a sample subgraph according to the complement matrix of the edge mask matrix;
The sample nodes of the sample subgraph comprise a central sample node and a plurality of adjacent sample nodes; the generalization model is obtained by training in the following way:
Acquiring sample graph data for training a preset model, and generating a corresponding sample subgraph for each sample node in the sample graph data according to a sample node connection structure of the sample graph data;
acquiring an initial feature matrix and an adjacency matrix of each sample subgraph, and sequentially inputting the initial feature matrices into the preset model for feature aggregation to obtain a fusion feature matrix of each sample subgraph;
according to the adjacency matrix and the fusion feature matrix, determining core features and redundant features of each sample subgraph;
determining the target loss of the preset model based on the core features and the redundant features;
and updating parameters of the preset model based on the target loss, and obtaining the trained generalization model when the preset model converges.
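For illustration only (not part of the claims): the training recipe above maps onto a short PyTorch-style loop. All names here (`model`, its `fuse`, `split`, and `classify` methods, and the `subgraphs` iterable) are hypothetical, and `target_loss` is sketched after claim 8 below; this is one plausible reading, not the patented implementation.

```python
import torch

def train(model, subgraphs, epochs=50, lr=1e-3):
    # Skeleton of the claim-1 training recipe with hypothetical helper names.
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(epochs):
        for x, adj, label in subgraphs:            # one sample subgraph per node
            h = model.fuse(x, adj)                 # fusion feature matrix (claim 7)
            core, redundant = model.split(h, adj)  # edge mask vs. complement (claim 2)
            loss = target_loss(model.classify(core),
                               model.classify(redundant),
                               label, core, redundant)
            opt.zero_grad()
            loss.backward()
            opt.step()
    return model
```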
2. The generalization model based graph data prediction method according to claim 1, wherein the determining the core features and the redundant features of each sample subgraph according to the adjacency matrix and the fusion feature matrix comprises:
inputting, based on the fusion feature matrix of the sample subgraph, the features of any two sample nodes in the sample subgraph into a multi-layer perceptron to obtain an edge mask value between the two sample nodes;
generating an edge mask matrix of the sample subgraph according to a plurality of edge mask values in the sample subgraph;
Generating core features corresponding to the sample subgraph based on the adjacency matrix, the fusion feature matrix and the edge mask matrix;
generating a corresponding complement matrix according to the edge mask matrix of the sample subgraph;
and obtaining the redundant features of the sample subgraph based on the adjacency matrix, the fusion feature matrix, and the complement matrix.
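For illustration only (not part of the claims): one plausible reading of claim 2, in which a small MLP scores each node pair from the fusion features, the sigmoid scores form the edge mask matrix, and its complement weakens the same connection structure. The class name, hidden sizes, and the one-hop propagation used to realize the features are assumptions.

```python
import torch
import torch.nn as nn

class EdgeMasker(nn.Module):
    # Scores every node pair from fused features; the mask enhances the
    # connection structure, its complement (1 - mask) weakens it.
    def __init__(self, dim):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(2 * dim, dim), nn.ReLU(),
                                 nn.Linear(dim, 1))

    def forward(self, h, adj):
        n = h.size(0)
        # Concatenate the fused features of every node pair (i, j).
        pairs = torch.cat([h.unsqueeze(1).expand(n, n, -1),
                           h.unsqueeze(0).expand(n, n, -1)], dim=-1)
        mask = torch.sigmoid(self.mlp(pairs)).squeeze(-1)   # edge mask matrix
        core = (adj * mask) @ h                # enhanced connection structure
        redundant = (adj * (1.0 - mask)) @ h   # complement-weakened structure
        return core, redundant
```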
3. The generalization model based graph data prediction method according to claim 1, wherein the acquiring sample graph data for training a preset model comprises:
acquiring basic graph data and a plurality of subgraph patterns of different categories, and attaching the subgraph patterns to arbitrary sample nodes of the basic graph data to obtain the sample graph data;
determining, according to the number of nodes of the sample graph data, a preset number of edges to add to the sample graph data, and adding edges between arbitrary pairs of sample nodes in the sample graph data until the number of added edges equals the preset number.
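For illustration only (not part of the claims): the patent does not name a graph generator, but a common construction matching claim 3 is a Barabási–Albert base graph with category motifs attached and a preset number of noise edges added. This sketch assumes networkx; the motif shapes and `edge_ratio` are illustrative.

```python
import random
import networkx as nx

def make_sample_graph(n_base=300, motifs=("house", "cycle"), n_motifs=80, edge_ratio=0.1):
    # Base graph plus attached category motifs; BA model and shapes assumed.
    g = nx.barabasi_albert_graph(n_base, m=2)
    labels = {v: 0 for v in g.nodes}                 # label 0: plain base node
    for _ in range(n_motifs):
        kind = random.choice(motifs)
        motif = nx.house_graph() if kind == "house" else nx.cycle_graph(5)
        offset = g.number_of_nodes()
        g = nx.disjoint_union(g, motif)              # motif nodes get ids >= offset
        for v in range(offset, g.number_of_nodes()):
            labels[v] = motifs.index(kind) + 1       # category of the attached pattern
        g.add_edge(random.randrange(n_base), offset) # attach motif to the base graph
    n_extra = int(edge_ratio * g.number_of_nodes())  # preset number of noise edges
    while n_extra > 0:
        u, v = random.sample(sorted(g.nodes), 2)
        if not g.has_edge(u, v):
            g.add_edge(u, v)
            n_extra -= 1
    return g, labels
```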
4. The generalization model based graph data prediction method according to claim 1, wherein the generating a corresponding sample subgraph for each sample node in the sample graph data according to the sample node connection structure of the sample graph data comprises:
Acquiring node density of the sample graph data, and determining a sample association level of each sample node based on the node density;
and taking each sample node as a central sample node, searching, within the sample association level, for the sample subgraph formed by the target sample node connection structure between the central sample node and its adjacent sample nodes.
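For illustration only (not part of the claims): reading the sample association level as a hop count, claim 4 reduces to extracting one ego subgraph per central sample node. The density threshold below is an assumed heuristic; the claim does not fix how density maps to a level.

```python
import networkx as nx

def association_level(g):
    # Assumed heuristic: denser graphs need fewer hops of context.
    return 1 if nx.density(g) > 0.1 else 2

def sample_subgraphs(g):
    # One ego subgraph per central sample node, bounded by the level.
    hops = association_level(g)
    return {v: nx.ego_graph(g, v, radius=hops) for v in g.nodes}
```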
5. The generalization model based graph data prediction method according to claim 1, further comprising, before the acquiring an initial feature matrix and an adjacency matrix of each sample subgraph:
Acquiring a preset graph theory library;
acquiring preset feature indexes of the sample nodes, and generating, through the graph theory library, a node feature vector for each sample node under the corresponding feature indexes, wherein the feature indexes comprise at least one of node identification, node degree, clustering coefficient, betweenness centrality, and closeness centrality;
And generating a first feature matrix according to a plurality of node feature vectors corresponding to the plurality of sample nodes.
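For illustration only (not part of the claims): a minimal sketch of claim 5 using networkx as the graph theory library (an assumption; the claim names no specific library). Column 0 stores the node identification so that later claims can look rows up by it.

```python
import networkx as nx
import numpy as np

def first_feature_matrix(g):
    # One row per sample node: [id, degree, clustering, betweenness, closeness].
    deg = dict(g.degree())
    clust = nx.clustering(g)
    btw = nx.betweenness_centrality(g)
    close = nx.closeness_centrality(g)
    return np.array(
        [[v, deg[v], clust[v], btw[v], close[v]] for v in sorted(g.nodes)],
        dtype=float,
    )  # column 0 keeps the node identification for later row lookups
```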
6. The generalization model based graph data prediction method according to claim 5, wherein the acquiring an initial feature matrix and an adjacency matrix of each sample subgraph comprises:
determining node identification of each sample node in each sample subgraph;
determining a node feature vector corresponding to each sample node from the first feature matrix based on the node identification;
Generating an initial feature matrix of the sample subgraph according to a plurality of node feature vectors corresponding to a plurality of sample nodes of the sample subgraph;
and generating an adjacency matrix corresponding to the feature matrix based on the connection relationships between the sample nodes.
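For illustration only (not part of the claims): a sketch of claim 6, looking up each subgraph node's row in the first feature matrix by its node identifier and building the adjacency matrix in the same node ordering.

```python
import networkx as nx
import numpy as np

def subgraph_inputs(sub, first_features):
    # Row lookup by node identification (column 0 of the first feature
    # matrix), then an adjacency matrix built in the same node ordering.
    rows = {int(r[0]): r for r in first_features}
    ids = sorted(sub.nodes)
    x = np.stack([rows[v] for v in ids])        # initial feature matrix
    adj = nx.to_numpy_array(sub, nodelist=ids)  # matching adjacency matrix
    return x, adj
```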
7. The generalization model based graph data prediction method according to claim 5, wherein the sequentially inputting the initial feature matrices into the preset model for feature aggregation to obtain a fusion feature matrix of each sample subgraph comprises:
determining a node degree matrix of each sample subgraph from the first feature matrix according to the node identification of each sample subgraph;
acquiring the adjacency matrix of the sample subgraph, and normalizing the adjacency matrix through the node degree matrix to obtain a normalized adjacency matrix;
taking the central sample node as the fusion center, and performing layer-by-layer fusion of sample node features based on the normalized adjacency matrix to obtain an updated central node feature vector corresponding to the central sample node;
and updating the node feature vector of the central sample node in the initial feature matrix according to the updated central node feature vector to obtain the fusion feature matrix of each sample subgraph.
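For illustration only (not part of the claims): a sketch of claim 7. The symmetric D^-1/2 (A + I) D^-1/2 normalization is an assumption; the claim only states that the node degree matrix normalizes the adjacency matrix before layer-by-layer fusion.

```python
import numpy as np

def fuse_features(x, adj, layers=2):
    # Degree-normalized neighborhood averaging, one layer per hop.
    a_hat = adj + np.eye(adj.shape[0])           # keep each node's own features
    d_inv_sqrt = np.diag(1.0 / np.sqrt(a_hat.sum(axis=1)))
    a_norm = d_inv_sqrt @ a_hat @ d_inv_sqrt     # normalized adjacency matrix
    h = x
    for _ in range(layers):
        h = a_norm @ h                           # layer-by-layer fusion
    return h  # the central node's row is its updated feature vector
```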
8. The generalization model based graph data prediction method according to claim 1, wherein the determining the target loss of the preset model based on the core features and the redundant features comprises:
determining a first loss based on the difference between the predicted node class label output for the core features of the sample subgraph and the sample node class label of those core features;
determining a second loss based on the uniformity of the probability distribution produced when the preset model predicts each sample node class label from the redundant features of the sample subgraph;
determining a third loss based on the sample distance between the core features and the redundant features;
and constructing the target loss of the preset model as the sum of the first loss, the second loss, and the third loss.
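For illustration only (not part of the claims): one plausible form for each of the three terms of claim 8. The claim fixes what each term measures, not its formula; in particular the sign of the distance term is a reading, chosen here to push the core and redundant representations apart.

```python
import torch
import torch.nn.functional as F

def target_loss(logits_core, logits_red, label, core, redundant):
    # Assumed shapes: logits_* of size (num_classes,), label a scalar class
    # index, core/redundant of size (num_nodes, dim).
    # First loss: difference between the core-feature prediction and the label.
    l1 = F.cross_entropy(logits_core.unsqueeze(0), label.view(1))
    # Second loss: redundant features should carry no label signal, measured
    # here as KL divergence from the uniform distribution (small when the
    # prediction is close to uniform).
    log_q = F.log_softmax(logits_red, dim=-1)
    uniform = torch.full_like(log_q, 1.0 / log_q.numel())
    l2 = F.kl_div(log_q, uniform, reduction="sum")
    # Third loss: driven by the core/redundant sample distance; the minus
    # sign reads the claim as maximizing their separation.
    l3 = -F.pairwise_distance(core, redundant).mean()
    return l1 + l2 + l3
```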
9. A graph data prediction apparatus based on a generalization model, the apparatus comprising:
an acquisition module, used for acquiring target graph data, wherein the target graph data comprises a plurality of target nodes and a node connection structure formed by the plurality of target nodes;
a searching module, used for searching, within a target association level, for the target subgraph of the target connection structure of each target node, based on the node connection structure formed by the plurality of target nodes in the target graph data;
an input module, used for inputting the target subgraph into the generalization model to obtain a node class label of a target node in the target subgraph; wherein the generalization model performs causal prediction learning between the sample node connection structure and the sample node class labels in sample subgraphs through a training process that minimizes a target loss, the target loss being composed of a first loss, a second loss, and a third loss; the first loss is determined based on the difference between the predicted node class label output for the core features of a sample subgraph and the sample node class label of those core features; the second loss is determined based on the uniformity of the probability distribution produced when a preset model predicts each sample node class label from the redundant features of the sample subgraph; the third loss is determined according to the sample distance between the core features and the redundant features; the core features are features obtained by enhancing the representation of the sample node connection structure of a sample subgraph according to an edge mask matrix derived from an attention mechanism; the redundant features are features obtained by weakening the representation of the sample node connection structure of a sample subgraph according to the complement matrix of the edge mask matrix; the sample nodes of a sample subgraph comprise a central sample node and a plurality of adjacent sample nodes; and the generalization model is obtained by training in the following way: acquiring sample graph data for training a preset model, and generating a corresponding sample subgraph for each sample node in the sample graph data according to the sample node connection structure of the sample graph data; acquiring an initial feature matrix and an adjacency matrix of each sample subgraph, and sequentially inputting the initial feature matrices into the preset model for feature aggregation to obtain a fusion feature matrix of each sample subgraph; determining the core features and redundant features of each sample subgraph according to the adjacency matrix and the fusion feature matrix; determining the target loss of the preset model based on the core features and the redundant features; and updating parameters of the preset model based on the target loss, and obtaining the trained generalization model when the preset model converges.
10. A computer device, comprising a memory and a processor, the memory storing a computer program, wherein the processor implements the generalization model based graph data prediction method according to any one of claims 1 to 8 when executing the computer program.
11. A computer-readable storage medium storing a computer program, wherein the computer program, when executed by a processor, implements the generalization model-based graph data prediction method of any one of claims 1 to 8.
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title
---|---|---|---
CN202410649703.0A (CN118245638B) | 2024-05-24 | 2024-05-24 | Method, device, equipment and storage medium for predicting graph data based on generalization model
Publications (2)
Publication Number | Publication Date
---|---
CN118245638A (en) | 2024-06-25
CN118245638B (en) | 2024-08-27
Family
ID=91556756
Family Applications (1)
Application Number | Title | Priority Date | Filing Date
---|---|---|---
CN202410649703.0A (CN118245638B, Active) | Method, device, equipment and storage medium for predicting graph data based on generalization model | 2024-05-24 | 2024-05-24
Country Status (1)
Country | Link |
---|---|
CN (1) | CN118245638B (en) |
Citations (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2019001070A1 (en) * | 2017-06-28 | 2019-01-03 | 浙江大学 | Adjacency matrix-based connection information organization system, image feature extraction system, and image classification system and method |
CN112231527A (en) * | 2020-12-17 | 2021-01-15 | 北京百度网讯科技有限公司 | Method and device for predicting label information of graph node and electronic equipment |
Family Cites Families (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112214499B (en) * | 2020-12-03 | 2021-03-19 | 腾讯科技(深圳)有限公司 | Graph data processing method and device, computer equipment and storage medium |
CN114389966B (en) * | 2022-03-24 | 2022-06-21 | 合肥综合性国家科学中心人工智能研究院(安徽省人工智能实验室) | Network traffic identification method and system based on graph neural network and stream space-time correlation |
Also Published As
Publication number | Publication date |
---|---|
CN118245638A (en) | 2024-06-25 |
Similar Documents
Publication | Title
---|---
CN109145828B (en) | Method and apparatus for generating video category detection model
CN110008977B (en) | Clustering model construction method and device
CN114491263B (en) | Recommendation model training method and device, recommendation method and device
CN114418189A (en) | Water quality grade prediction method, system, terminal device and storage medium
CN117271899A (en) | Interest point recommendation method based on space-time perception
CN115062779A (en) | Event prediction method and device based on dynamic knowledge graph
CN118245822B (en) | Similarity set forecast optimization method, device, equipment and medium
CN114897085A (en) | Clustering method based on closed subgraph link prediction and computer equipment
CN114154564A (en) | Method and device for determining relevance based on heterogeneous graph, electronic equipment and storage medium
CN114330090A (en) | Defect detection method and device, computer equipment and storage medium
CN116909534B (en) | Operator flow generating method, operator flow generating device and storage medium
CN118245638B (en) | Method, device, equipment and storage medium for predicting graph data based on generalization model
CN113705293A (en) | Image scene recognition method, device, equipment and readable storage medium
CN114900435B (en) | Connection relation prediction method and related equipment
Jiang et al. | LibCity: A Unified Library Towards Efficient and Comprehensive Urban Spatial-Temporal Prediction
CN115730248A (en) | Machine account detection method, system, equipment and storage medium
WO2021115269A1 (en) | User cluster prediction method, apparatus, computer device, and storage medium
CN114882364A (en) | Data processing method, server and storage medium
CN116050508B (en) | Neural network training method and device
CN114510638B (en) | Information processing method, apparatus, device, storage medium, and program product
Su et al. | Automatic Completion of Underground Utility Topologies Using Graph Convolutional Networks
CN117011037A (en) | Risk account identification method, apparatus, device, storage medium and program product
CN116757475A (en) | Urban public transport user full life cycle loss risk assessment method and system
CN117454008A (en) | Information recommendation method, device, electronic equipment and storage medium
CN116866195A (en) | Model training method, device and equipment
Legal Events
Code | Title
---|---
PB01 | Publication
SE01 | Entry into force of request for substantive examination
GR01 | Patent grant