
Create custom Graph Extension Type #420

Open
dmadisetti opened this issue Feb 22, 2023 · 3 comments

@dmadisetti

A "Graph" Type is even defined in the example:

This lets the user carry around graph information as a TensorFlow object, and even allows ops at the object level. I think this would also allow for a batch of graphs, which would be great.

This would be a little nicer than carrying around multiple arrays for adjacency, features, etc.
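
For concreteness, a minimal sketch of what such a type could look like (the class and field names here are illustrative, not taken from the TF example):

import tensorflow as tf

class Graph(tf.experimental.ExtensionType):
    """A graph carried around as a single TensorFlow value."""
    node_features: tf.Tensor  # shape (num_nodes, num_features)
    adjacency: tf.Tensor      # shape (num_nodes, num_nodes)

# Extension types flow through tf.function like any other TF value:
@tf.function
def degree(g: Graph):
    return tf.reduce_sum(g.adjacency, axis=-1)
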

@dmadisetti (Author) commented Feb 22, 2023

Never mind: for batching there's an explicit BatchableTypeSpec:

https://github.com/tensorflow/tensorflow/blob/d5b57ca93e506df258271ea00fc29cf98383a374/tensorflow/python/framework/type_spec.py#L738-L751

but the custom extension type still makes sense.

There is tf.experimental.BatchableExtensionType.

It seems to work pretty straightforwardly; I'm using it for graphs generated by my pipeline.
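
For what it's worth, a minimal sketch of the batching side, assuming ragged fields so graphs of different sizes can share a batch (the class name and example values are mine; the tf.data usage follows the extension-type guide):

import tensorflow as tf

class RaggedGraph(tf.experimental.BatchableExtensionType):
    features: tf.RaggedTensor   # (num_graphs, num_nodes, num_features)
    adjacency: tf.RaggedTensor  # (num_graphs, num_nodes, num_nodes)

# Two graphs of different sizes stacked along axis 0:
batch = RaggedGraph(
    features=tf.ragged.constant([[[1.0], [2.0]], [[3.0]]], ragged_rank=1),
    adjacency=tf.ragged.constant([[[0.0, 1.0], [1.0, 0.0]], [[0.0]]]))

# Batchable extension types plug into tf.data: from_tensor_slices splits the
# batch into per-graph values, and .batch() stacks them back together.
ds = tf.data.Dataset.from_tensor_slices(batch).batch(2)
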

@danielegrattarola (Owner)

Hey,

Do you have a self-contained example of using this type with Spektral? Or does it require rewriting the layers?

Cheers

@dmadisetti (Author) commented Feb 28, 2023

This would likely just be another data mode. Here's what I have, but integration into the library would probably look a little different:

import tensorflow as tf
from spektral.layers import GCNConv

class WrappedGCN(tf.keras.layers.Layer):
    """Wraps GCNConv so it can consume a batched TensorGraph extension type."""

    def __init__(self, features, *args, **kwargs):
        super(WrappedGCN, self).__init__()
        self.features = features
        self.layer = GCNConv(features, *args, **kwargs)

    def hook(self, graph):
        # Densify one graph, add a batch dim of 1, run GCNConv, and re-wrap
        # the output as a ragged row so graphs of different sizes can stack.
        features = graph.features.to_tensor()
        features = tf.reshape(features, (1, -1, features.shape[1]))
        adj = tf.cast(graph.adjacency.to_tensor(), tf.float32)
        adj = tf.reshape(adj, (1, adj.shape[0], adj.shape[1]))
        return tf.RaggedTensor.from_tensor(tf.squeeze(self.layer([features, adj])))

    # Override __call__ (not call) to intercept the extension type before
    # Keras input processing; plain tensors fall through to GCNConv directly.
    def __call__(self, graph):
        if isinstance(graph, TensorGraph):
            # Map the single-graph hook over the batch of graphs.
            features = tf.map_fn(
                self.hook, graph,
                fn_output_signature=tf.RaggedTensorSpec(
                    shape=(None, self.features), dtype=tf.float32))
            return TensorGraph(
                features=features,
                adjacency=graph.adjacency)
        return self.layer(graph)

x0 = TensorGraph(adjacency=tf.ragged.stack(adjs),
                 features=tf.ragged.stack(features))

x1 = WrappedGCN(6)(x0)
x2 = WrappedGCN(6)(x1)

where

class TensorGraph(tf.experimental.BatchableExtensionType):
    """A collection of nodes with associated feature vectors."""
    features: tf.RaggedTensor
    adjacency: tf.RaggedTensor
    
    # TODO: Validation functions etc...

I could probably use tf.sparse.SparseTensor instead of tf.RaggedTensor fields, but I'm using ragged here because I need the dense adjacency matrices.
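
For the validation TODO, ExtensionType supports a __validate__ method that runs after construction; a small sketch of what it could check (the specific checks are assumptions, not something settled in this thread):

import tensorflow as tf

class TensorGraph(tf.experimental.BatchableExtensionType):
    """A collection of nodes with associated feature vectors."""
    features: tf.RaggedTensor
    adjacency: tf.RaggedTensor

    def __validate__(self):
        # Runs automatically after construction; keep the checks static and cheap.
        self.features.shape[:1].assert_is_compatible_with(self.adjacency.shape[:1])
        assert self.features.dtype.is_floating, "node features should be floating point"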
