sketchgraphs_models.graph.model.GraphModel
-
class sketchgraphs_models.graph.model.GraphModel(model_core, entity_label, entity_feature_readout, edge_post_embedding, edge_label, edge_feature_readout, edge_partner)
Main model class for the graph-based models.
-
model_core
This module is the core of the network and computes the main node embeddings and the global embedding of the graph.
-
entity_label
This module computes the predictions for the label of the next node based on the global embedding.
- Type
-
entity_feature_readout
This dictionary contains a feature readout for each entity type, indexed by entity type.
- Type
Dict[TargetType, numerical_features.NumericalFeatureReadout], optional
-
edge_post_embedding
This module computes an edge post-embedding from the node embeddings of the two entities forming the edge.
- Type
-
edge_label
This module computes the edge label from the embedding of the edge in question and the global embedding of the graph.
- Type
-
edge_feature_readout
This dictionary contains a feature readout for each edge type, indexed by edge type.
- Type
Dict[TargetType, numerical_features.NumericalFeatureReadout], optional
-
edge_partner
This module predicts the other vertex of a newly added edge (one vertex of the new edge is always the last vertex).
- Type
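Read together, the attribute descriptions imply a data flow between the submodules. The sketch below traces one prediction step under stated assumptions: plain Python callables and lists stand in for the `torch.nn.Module` submodules and tensors so that it runs without PyTorch, and `sketch_predict_step`, `dot`, and the toy lambdas are hypothetical names, not part of SketchGraphs. The order of the steps, the greedy (argmax) choice of partner, and the dot-product scoring are assumptions; only the inputs and outputs of each submodule follow the descriptions above.

```python
from typing import Callable, Dict, List, Optional, Sequence, Tuple

Vector = List[float]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def sketch_predict_step(
    data: object,
    model_core: Callable[[object], Tuple[List[Vector], Vector]],
    entity_label: Callable[[Vector], Vector],
    edge_post_embedding: Callable[[Vector, Vector], Vector],
    edge_label: Callable[[Vector, Vector], Vector],
    edge_partner: Callable[[List[Vector], Vector], Vector],
    entity_feature_readout: Optional[Dict[str, Callable[[Vector], Vector]]] = None,
    entity_type: Optional[str] = None,
) -> Dict[str, Vector]:
    """Hypothetical trace of one prediction step through GraphModel-style submodules."""
    # 1. Core: one embedding per node plus a global (graph-level) embedding.
    node_emb, graph_emb = model_core(data)
    out: Dict[str, Vector] = {"entity_label": entity_label(graph_emb)}

    # 2. Optional numerical-feature readout, looked up by entity type.
    if entity_feature_readout is not None and entity_type in entity_feature_readout:
        out["entity_features"] = entity_feature_readout[entity_type](graph_emb)

    # 3. Score every vertex as the partner of the last vertex; keep the best (greedy).
    last = node_emb[-1]
    scores = edge_partner(node_emb, last)
    out["edge_partner"] = scores
    partner = max(range(len(scores)), key=lambda i: scores[i])

    # 4. Embed the (partner, last) edge, then predict its label with the global embedding.
    edge_emb = edge_post_embedding(node_emb[partner], last)
    out["edge_label"] = edge_label(edge_emb, graph_emb)
    return out


# Toy stand-ins: 3 nodes with 2-d embeddings and a fixed global embedding.
toy = sketch_predict_step(
    data=None,
    model_core=lambda d: ([[2.0, 0.0], [0.0, 1.0], [1.0, 0.0]], [0.5, 0.5]),
    entity_label=lambda g: [sum(g), -sum(g)],
    edge_post_embedding=lambda a, b: [x + y for x, y in zip(a, b)],
    edge_label=lambda e, g: [dot(e, g)],
    edge_partner=lambda nodes, last: [dot(n, last) for n in nodes],
)
print(toy)  # {'entity_label': [1.0, -1.0], 'edge_partner': [2.0, 0.0, 1.0], 'edge_label': [1.5]}
```

In a real model the submodules would operate on batched tensors, and during training the partner would plausibly come from the ground-truth edge rather than an argmax; the sketch makes no claim about how `GraphModel.forward` handles either case.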
-
__init__(model_core, entity_label, entity_feature_readout, edge_post_embedding, edge_label, edge_feature_readout, edge_partner)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
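The inherited `__init__` docstring refers to the bookkeeping that `torch.nn.Module` sets up so that submodules assigned as attributes get registered (and therefore show up in `parameters()`, `state_dict()`, `to()`, and so on). A subclass such as `GraphModel`, assuming it stores its constructor arguments as attributes of the same names as the attribute list suggests, therefore has to call `super().__init__()` before storing them. The sketch below is a simplified, dependency-free imitation of that mechanism; `MiniModule`, `GoodModel`, and `BadModel` are hypothetical classes, and only the error message is meant to match PyTorch.

```python
class MiniModule:
    """Toy stand-in for the child-module bookkeeping in torch.nn.Module."""

    def __init__(self) -> None:
        # The "internal Module state": a registry of child modules.
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name: str, value: object) -> None:
        if isinstance(value, MiniModule):
            modules = self.__dict__.get("_modules")
            if modules is None:
                # Mirrors the AttributeError PyTorch raises for this mistake.
                raise AttributeError("cannot assign module before Module.__init__() call")
            modules[name] = value
        object.__setattr__(self, name, value)

    def children(self) -> list:
        return list(self._modules.values())


class GoodModel(MiniModule):
    def __init__(self, model_core: MiniModule, entity_label: MiniModule) -> None:
        super().__init__()  # create the registry first...
        self.model_core = model_core  # ...so that submodules are recorded in it
        self.entity_label = entity_label


class BadModel(MiniModule):
    def __init__(self, model_core: MiniModule) -> None:
        self.model_core = model_core  # fails: the registry does not exist yet
        super().__init__()


good = GoodModel(MiniModule(), MiniModule())
print(len(good.children()))  # 2
try:
    BadModel(MiniModule())
except AttributeError as err:
    print(err)  # cannot assign module before Module.__init__() call
```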
-
forward(data)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
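The note about hooks can be made concrete with a toy imitation of `torch.nn.Module.__call__`, which wraps `forward()` with any registered hooks. `HookedModule` and `Doubler` are hypothetical classes; the registration method names echo PyTorch's `register_forward_pre_hook` and `register_forward_hook`, but unlike PyTorch these toy hooks can only observe, not modify, inputs or outputs.

```python
from typing import Any, Callable, List


class HookedModule:
    """Toy model of how torch.nn.Module.__call__ wraps forward() with hooks."""

    def __init__(self) -> None:
        self._forward_pre_hooks: List[Callable[..., None]] = []
        self._forward_hooks: List[Callable[..., None]] = []

    def register_forward_pre_hook(self, hook: Callable[..., None]) -> None:
        self._forward_pre_hooks.append(hook)

    def register_forward_hook(self, hook: Callable[..., None]) -> None:
        self._forward_hooks.append(hook)

    def forward(self, data: Any) -> Any:
        raise NotImplementedError("subclasses must override forward()")

    def __call__(self, data: Any) -> Any:
        # Calling the instance runs pre-hooks, then forward(), then forward hooks.
        for hook in self._forward_pre_hooks:
            hook(self, (data,))
        output = self.forward(data)
        for hook in self._forward_hooks:
            hook(self, (data,), output)
        return output


class Doubler(HookedModule):
    def forward(self, data: List[int]) -> List[int]:
        return [2 * x for x in data]


seen = []
model = Doubler()
model.register_forward_hook(lambda module, inputs, output: seen.append(output))

model([1, 2])       # goes through __call__: forward() runs and the hook fires
model.forward([3])  # calls forward() directly: the hook is silently skipped
print(seen)         # [[2, 4]]
```

For `GraphModel`, the same reasoning means calling `model(data)` rather than `model.forward(data)`.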
-