sketchgraphs_models.graph.model.GraphModel

class sketchgraphs_models.graph.model.GraphModel(model_core, entity_label, entity_feature_readout, edge_post_embedding, edge_label, edge_feature_readout, edge_partner)

Main model class for the graph-based models.

model_core

This module is the core of the network: it computes the main node embeddings and the global embedding of the graph.

Type

message_passing.GraphModelCore
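To make the split between node embeddings and the global embedding concrete, here is a minimal sketch of what a message-passing core could compute. It is not the actual message_passing.GraphModelCore: the function name toy_core, the single round of mean aggregation over undirected edges, and the mean-pooled global readout are all assumptions, and plain Python lists stand in for tensors.

```python
from typing import Dict, List, Tuple

Vector = List[float]


def mean(vectors: List[Vector]) -> Vector:
    """Element-wise mean of a non-empty list of equally sized vectors."""
    n = len(vectors)
    return [sum(column) / n for column in zip(*vectors)]


def toy_core(node_features: List[Vector],
             edges: List[Tuple[int, int]]) -> Tuple[List[Vector], Vector]:
    """Hypothetical stand-in for a graph core.

    Returns (node_embeddings, global_embedding) after one round of
    mean-aggregation message passing followed by mean pooling.
    """
    neighbours: Dict[int, List[int]] = {i: [] for i in range(len(node_features))}
    for u, v in edges:  # edges are treated as undirected in this sketch
        neighbours[u].append(v)
        neighbours[v].append(u)

    node_embeddings = []
    for i, own in enumerate(node_features):
        # Each node averages its own features with those of its neighbours.
        messages = [own] + [node_features[j] for j in neighbours[i]]
        node_embeddings.append(mean(messages))

    # The global embedding summarizes the whole graph (here: a plain mean).
    global_embedding = mean(node_embeddings)
    return node_embeddings, global_embedding


nodes, graph = toy_core([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], [(0, 1)])
```

In the usage line, nodes 0 and 1 exchange messages while node 2 is isolated, so only the first two embeddings change; the global embedding is then the average of all three.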

entity_label

This module computes the predictions for the label of the next node based on the global embedding.

Type

torch.nn.Module
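A label head of this kind is often a linear layer followed by a softmax over the label vocabulary. The sketch below assumes exactly that; the helper names and the label list are hypothetical and are not taken from the SketchGraphs code.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def linear(weights: Matrix, bias: Vector, x: Vector) -> Vector:
    """Affine map W x + b, with one row of W per output."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]


def softmax(logits: Vector) -> Vector:
    """Numerically stable softmax (subtract the max before exponentiating)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_entity_label(global_embedding: Vector, weights: Matrix,
                         bias: Vector, labels: List[str]) -> str:
    """Pick the most probable next-node label from the global embedding."""
    probs = softmax(linear(weights, bias, global_embedding))
    return labels[max(range(len(labels)), key=probs.__getitem__)]


# Hypothetical label vocabulary and weights, for illustration only.
LABELS = ["Line", "Circle", "Stop"]
W = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]
B = [0.0, 0.0, 0.5]
prediction = predict_entity_label([2.0, 0.0], W, B, LABELS)
```

Here the logits are [2.0, 0.0, 0.5], so the sketch predicts "Line". During training one would keep the full probability vector for the loss rather than only the argmax.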

entity_feature_readout

This dictionary contains the feature readouts for each entity type, indexed by entity type.

Type

Dict[TargetType, numerical_features.NumericalFeatureReadout], optional
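Because the dictionary is indexed by type and is optional, a caller has to dispatch on the type and handle a missing entry. The sketch below shows that pattern with a hypothetical ToyTargetType enum and plain callables in place of the real TargetType and NumericalFeatureReadout classes; returning None for an unregistered type is an assumption about how missing readouts could be treated.

```python
from enum import Enum
from typing import Callable, Dict, List, Optional


class ToyTargetType(Enum):
    """Hypothetical stand-in for the real TargetType enum."""
    LINE = "line"
    CIRCLE = "circle"
    POINT = "point"


Readout = Callable[[List[float]], Dict[str, float]]


def read_features(readouts: Optional[Dict[ToyTargetType, Readout]],
                  target: ToyTargetType,
                  embedding: List[float]) -> Optional[Dict[str, float]]:
    """Run the readout registered for `target`, if there is one."""
    if readouts is None or target not in readouts:
        return None  # no numerical features are predicted for this type
    return readouts[target](embedding)


# Only circles have a (toy) numerical readout registered here.
readouts: Dict[ToyTargetType, Readout] = {
    ToyTargetType.CIRCLE: lambda e: {"radius": abs(e[0])},
}
circle_features = read_features(readouts, ToyTargetType.CIRCLE, [-0.5, 1.0])
```

The same dispatch applies to edge_feature_readout, with edge types as keys instead of entity types.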

edge_post_embedding

This module computes an edge post-embedding based on the node embeddings of the two entities forming the edge.

Type

torch.nn.Module
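One simple way to embed an edge from its two endpoints is to concatenate their node embeddings and apply an affine map with a nonlinearity. This is a sketch under that assumption; the actual edge_post_embedding module may combine the embeddings differently, and the weights below are made up.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def edge_post_embedding(node_embeddings: List[Vector], u: int, v: int,
                        weights: Matrix, bias: Vector) -> Vector:
    """Embed the edge (u, v) from its two endpoint embeddings.

    Assumption: concatenation, then an affine map, then ReLU.
    """
    pair = node_embeddings[u] + node_embeddings[v]  # list concatenation
    pre_activation = [sum(w * x for w, x in zip(row, pair)) + b
                      for row, b in zip(weights, bias)]
    return [max(0.0, z) for z in pre_activation]  # ReLU


embedding = edge_post_embedding([[1.0, 2.0], [3.0, 4.0]], 0, 1,
                                [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, -1.0]],
                                [0.0, 0.0])
```

Concatenation makes the result depend on the order of u and v, which suits directed edges; a symmetric combination such as an element-wise sum would be the alternative if edge direction should not matter.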

edge_label

This module computes the edge label from the embedding of the edge in question and the global embedding for the graph.

Type

torch.nn.Module
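The edge label head takes two inputs. A minimal sketch, assuming the edge embedding and the global embedding are concatenated and scored with one linear score per label; the constraint names are hypothetical placeholders.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def predict_edge_label(edge_embedding: Vector, global_embedding: Vector,
                       weights: Matrix, bias: Vector, labels: List[str]) -> str:
    """Score every edge label from [edge embedding ; global embedding]."""
    features = edge_embedding + global_embedding  # concatenate both inputs
    scores = [sum(w * x for w, x in zip(row, features)) + b
              for row, b in zip(weights, bias)]
    best = max(range(len(labels)), key=lambda i: scores[i])
    return labels[best]


# Hypothetical constraint vocabulary and weights, for illustration only.
EDGE_LABELS = ["Coincident", "Parallel", "Tangent"]
edge_label = predict_edge_label([1.0], [0.0, 2.0],
                                [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0]],
                                [0.0, 0.0, 0.0], EDGE_LABELS)
```

With these inputs the scores are [1.0, 2.0, 0.0], so the sketch returns "Parallel". Including the global embedding lets the same pair of entities receive different labels depending on the rest of the graph.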

edge_feature_readout

This dictionary contains the feature readouts for each edge type, indexed by edge type.

Type

Dict[TargetType, numerical_features.NumericalFeatureReadout], optional

edge_partner

This module computes the other vertex of a newly added edge (one vertex is always the last vertex).

Type

EdgePartnerNetwork
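Since one endpoint is fixed to the last vertex, choosing the partner reduces to a categorical choice over vertices. The sketch below assumes dot-product scores against the last vertex's embedding followed by a softmax; it is not the real EdgePartnerNetwork, and keeping the last vertex itself as a candidate (which would represent a self-loop) is also an assumption.

```python
import math
from typing import List

Vector = List[float]


def partner_distribution(node_embeddings: List[Vector]) -> List[float]:
    """Probability that each vertex is the partner of the last vertex.

    Assumption: dot-product scoring; every vertex, including the last one,
    is a candidate.
    """
    last = node_embeddings[-1]
    scores = [sum(a * b for a, b in zip(candidate, last))
              for candidate in node_embeddings]
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def pick_partner(node_embeddings: List[Vector]) -> int:
    """Index of the most likely partner (first index wins on ties)."""
    probs = partner_distribution(node_embeddings)
    return max(range(len(probs)), key=probs.__getitem__)


embeddings = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
probs = partner_distribution(embeddings)
partner = pick_partner(embeddings)
```

Vertices 0 and 2 share the last vertex's embedding, so they tie for the highest probability and pick_partner returns the first of them, index 0.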

__init__(model_core, entity_label, entity_feature_readout, edge_post_embedding, edge_label, edge_feature_readout, edge_partner)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(data)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance instead of calling forward directly, since calling the instance runs the registered hooks while calling forward silently ignores them.
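The difference between calling the instance and calling forward can be illustrated without torch. The toy class below imitates the relevant part of torch.nn.Module's dispatch: its __call__ runs forward and then any forward hooks, which may replace the output. It is a simplified analogue, not torch itself; for instance, torch passes the positional inputs to a forward hook as a tuple, whereas this sketch passes the single argument directly.

```python
from typing import Any, Callable, List


class ToyModule:
    """Framework-free imitation of how torch.nn.Module dispatches calls."""

    def __init__(self) -> None:
        self._forward_hooks: List[Callable[[Any, Any, Any], Any]] = []

    def register_forward_hook(self, hook: Callable[[Any, Any, Any], Any]) -> None:
        self._forward_hooks.append(hook)

    def forward(self, x: Any) -> Any:
        raise NotImplementedError  # subclasses define the computation

    def __call__(self, x: Any) -> Any:
        out = self.forward(x)
        for hook in self._forward_hooks:
            result = hook(self, x, out)
            if result is not None:
                out = result  # a hook may replace the output
        return out


class Doubler(ToyModule):
    def forward(self, x: int) -> int:
        return 2 * x


model = Doubler()
model.register_forward_hook(lambda module, inp, out: out + 1)
via_call = model(3)             # forward, then the hook: 2 * 3 + 1
via_forward = model.forward(3)  # the hook is silently skipped
```

The two results differ (7 versus 6), which is exactly why the instance, rather than forward, should be called.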