sketchgraphs_models.graph.model.losses
This module implements the necessary code to compute losses from the output of the graph model.
Functions
- sketchgraphs_models.graph.model.losses.compute_average_losses(losses, graph_counts)
Computes average losses from summed losses.
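The average is presumably each summed loss divided by the number of graphs that contributed to it. The sketch below illustrates that idea under stated assumptions. `average_losses_sketch` is a hypothetical helper, not the library function, and it assumes `graph_counts` maps each top-level loss key to a count. The library's actual key layout may differ.

```python
def average_losses_sketch(losses, graph_counts):
    """Hypothetical sketch: divide summed losses by their example counts.

    Assumes `losses` maps each key to a summed loss (a float or a scalar
    tensor) or to a dict of such values, and that `graph_counts` maps the same
    top-level keys to the number of examples that contributed to each sum.
    """
    averaged = {}
    for key, value in losses.items():
        count = max(graph_counts[key], 1)  # avoid dividing by zero for empty targets
        if isinstance(value, dict):
            # Assumption: nested components share the count of their top-level key.
            averaged[key] = {name: v / count for name, v in value.items()}
        else:
            averaged[key] = value / count
    return averaged


summed = {"node": 6.0, "edge": {"edge_label": 4.0, "edge_partner": 2.0}}
counts = {"node": 3, "edge": 2}
averaged = average_losses_sketch(summed, counts)
print(averaged)  # {'node': 2.0, 'edge': {'edge_label': 2.0, 'edge_partner': 1.0}}
```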
- sketchgraphs_models.graph.model.losses.compute_edge_losses(edge_partner_offsets, edge_label_logits, partner_logits, edge_label, edge_partner)
Computes losses associated with edge prediction problems.
This function computes the two losses associated with the edge problem:
  - edge_label: the loss for predicting the label of the edge.
  - edge_partner: the loss for predicting the other vertex of the edge.
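A minimal sketch of the two edge losses, assuming a layout that this page does not confirm: `edge_label_logits` is a `[num_edges, num_labels]` tensor scored with a summed cross-entropy, and `partner_logits` is a flat tensor split into one segment per edge by `edge_partner_offsets`, with `edge_partner` giving the true partner's position within each segment. The helper `edge_losses_sketch` is hypothetical.

```python
import torch
import torch.nn.functional as F


def edge_losses_sketch(edge_partner_offsets, edge_label_logits, partner_logits,
                       edge_label, edge_partner):
    """Hypothetical sketch of the two edge losses, summed over all edges.

    Assumptions (not taken from the library):
    - edge_label_logits: [num_edges, num_labels]; edge_label: [num_edges].
    - partner_logits is 1-d; segment i is
      partner_logits[edge_partner_offsets[i]:edge_partner_offsets[i + 1]].
    - edge_partner[i] is the true partner's index within segment i.
    """
    # edge_label: cross-entropy for predicting the label of each edge.
    label_loss = F.cross_entropy(edge_label_logits, edge_label, reduction="sum")

    # edge_partner: softmax cross-entropy over the candidates in each segment.
    partner_loss = partner_logits.new_zeros(())
    for i in range(len(edge_partner_offsets) - 1):
        start, end = int(edge_partner_offsets[i]), int(edge_partner_offsets[i + 1])
        log_probs = torch.log_softmax(partner_logits[start:end], dim=0)
        partner_loss = partner_loss - log_probs[edge_partner[i]]

    return {"edge_label": label_loss, "edge_partner": partner_loss}


# Uniform logits: label loss is 2 * log 3, partner loss is log 2 + log 3.
out = edge_losses_sketch(
    edge_partner_offsets=torch.tensor([0, 2, 5]),
    edge_label_logits=torch.zeros(2, 3),
    partner_logits=torch.zeros(5),
    edge_label=torch.tensor([0, 2]),
    edge_partner=torch.tensor([1, 0]),
)
print(out)
```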
- sketchgraphs_models.graph.model.losses.compute_feature_loss(feature_logits: torch.Tensor, feature_targets: torch.Tensor, feature_dimensions: Dict[Any, int])
Computes losses on the given features.
- Parameters
feature_logits (torch.Tensor) – 2-d tensor of feature logits
feature_targets (torch.Tensor) – 1-d integer tensor of true feature labels
feature_dimensions (Dict[str, int]) – A mapping giving the dimension of each feature.
- Returns
losses (torch.Tensor) – An array of losses on each feature
accuracies (torch.Tensor) – An array of accuracies on each feature
labels (torch.Tensor) – The provided array of targets
predictions (torch.Tensor) – An array of labels predicted by taking the arg-max of the logits.
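As a rough illustration of what the four return values contain, the hypothetical helper below handles a single block of logits of shape `[n, d]` with targets of shape `[n]`, computing an unreduced cross-entropy per row and arg-max predictions. It deliberately ignores `feature_dimensions`; how the real function uses that mapping to split the logits is not described here.

```python
import torch
import torch.nn.functional as F


def feature_loss_sketch(feature_logits, feature_targets):
    """Hypothetical sketch: per-row losses, accuracies, labels and predictions.

    Assumes feature_logits has shape [n, d] and feature_targets has shape [n].
    The library's use of feature_dimensions is not modeled.
    """
    # Unreduced cross-entropy, so every row keeps its own loss value.
    losses = F.cross_entropy(feature_logits, feature_targets, reduction="none")
    # Arg-max predictions and per-row correctness.
    predictions = feature_logits.argmax(dim=-1)
    accuracies = predictions == feature_targets
    return losses, accuracies, feature_targets, predictions


logits = torch.tensor([[2.0, 0.0], [0.0, 1.0]])
targets = torch.tensor([0, 0])
losses, accuracies, labels, predictions = feature_loss_sketch(logits, targets)
print(accuracies.tolist(), predictions.tolist())  # [True, False] [0, 1]
```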
- sketchgraphs_models.graph.model.losses.compute_losses(readout, batch, feature_dimensions, weights=None)
Computes losses for each component of the model, given the model output.
Note that this function computes losses using a sum reduction. To obtain the equivalent average losses, see compute_average_losses. This function also returns labels and predictions for numerical edge features.
- Parameters
readout (dict) – A dictionary containing the model output
batch (dict) – A dictionary containing the input data
feature_dimensions (dict) – A dictionary mapping each target to a list of ints describing the dimension of each of its features.
weights (dict, optional) – If not None, a mapping from TargetType to float giving the weight of each prediction endpoint. Target types that are not included are given a weight of 1.
- Returns
losses (dict) – Nested dictionary containing loss values for each component
accuracy (dict) – Nested dictionary containing accuracy for each component
efeat_labels (dict) – Dictionary containing labels for each numerical edge feature.
efeat_preds (dict) – Dictionary containing predictions for each numerical edge feature.
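The `weights` argument can be read as a per-target multiplier with a default of 1. The sketch below shows that rule when combining per-target summed losses into a single scalar. `weighted_total_loss_sketch` is hypothetical, plain string keys stand in for `TargetType` members, and whether `compute_losses` itself produces such a total is not stated above.

```python
def weighted_total_loss_sketch(losses, weights=None):
    """Hypothetical sketch: combine per-target summed losses into one scalar.

    Assumes `losses` maps each target to a scalar loss or to a dict of
    component losses (summed first). Targets missing from `weights` get
    weight 1, mirroring the documented default.
    """
    weights = weights or {}
    total = 0.0
    for target, value in losses.items():
        if isinstance(value, dict):
            value = sum(value.values())
        total = total + weights.get(target, 1.0) * value
    return total


losses = {"node": 3.0, "edge": {"edge_label": 1.0, "edge_partner": 2.0}}
print(weighted_total_loss_sketch(losses, weights={"edge": 0.5}))  # 4.5
```

The same code works unchanged with scalar tensors in place of floats, since it only adds and multiplies.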
- sketchgraphs_models.graph.model.losses.compute_node_losses(node_offsets, entity_logits, partner_logits, node_label)
- sketchgraphs_models.graph.model.losses.compute_subnode_losses(subnode_offsets, partner_logits)
- sketchgraphs_models.graph.model.losses.merge_losses_and_accuracy(losses, accuracy, updates, weight=None)
- sketchgraphs_models.graph.model.losses.segment_stop_accuracy(partner_logits, segment_offsets, target_idx=None)
Computes the accuracy for stop prediction for partner logits.
- Parameters
partner_logits (torch.Tensor) – The un-normalized logits for the segments.
segment_offsets (torch.Tensor) – A tensor of shape [num_segments + 1] indicating the segment offsets.
target_idx (torch.Tensor, optional) – If not None, a tensor of shape [num_segments] representing the target offset to predict in each segment. Otherwise, this is assumed to be the last (implicit) entry in each segment.
- Returns
A boolean tensor of shape [num_segments] representing the accuracy at each segment.
- Return type
torch.Tensor
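Stop accuracy can be sketched as an arg-max check over each segment's logits followed by one constant stop logit, mirroring the augmentation described for `segment_stop_loss`. The value of that constant is not documented here, so the hypothetical `segment_stop_accuracy_sketch` assumes 0 and exposes it as a `stop_logit` parameter.

```python
import torch


def segment_stop_accuracy_sketch(partner_logits, segment_offsets, target_idx=None,
                                 stop_logit=0.0):
    """Hypothetical sketch: per-segment arg-max accuracy with an implicit stop entry.

    Assumes the implicit last entry of every segment has the constant logit
    `stop_logit` (0 by default; the library's value is an assumption here).
    """
    results = []
    for i in range(len(segment_offsets) - 1):
        start, end = int(segment_offsets[i]), int(segment_offsets[i + 1])
        # Append the constant stop logit after the segment's own logits.
        logits = torch.cat([partner_logits[start:end],
                            partner_logits.new_tensor([stop_logit])])
        # Default target: the implicit last entry, at index end - start.
        target = end - start if target_idx is None else int(target_idx[i])
        results.append(int(logits.argmax()) == target)
    return torch.tensor(results, dtype=torch.bool)


logits = torch.tensor([-1.0, -2.0, 0.5, 3.0, -1.0])
offsets = torch.tensor([0, 2, 5])
print(segment_stop_accuracy_sketch(logits, offsets).tolist())  # [True, False]
print(segment_stop_accuracy_sketch(logits, offsets, torch.tensor([2, 1])).tolist())  # [True, True]
```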
- sketchgraphs_models.graph.model.losses.segment_stop_loss(partner_logits, segment_offsets, target_idx=None)
Computes the loss for stop prediction for partner logits.
This function effectively computes a softmax cross-entropy on each segment, where the segment's logits are augmented with one last constant logit. If target_idx is omitted, the target is assumed to be this last (implicit) entry.
- Parameters
partner_logits (torch.Tensor) – The un-normalized logits for the segments.
segment_offsets (torch.Tensor) – A tensor of shape [num_segments + 1] indicating the segment offsets.
target_idx (torch.Tensor, optional) – If not None, a tensor of shape [num_segments] representing the target offset to predict in each segment. Otherwise, this is assumed to be the last (implicit) entry in each segment.
- Returns
A tensor of shape [num_segments] representing the cross-entropy loss at each segment.
- Return type
torch.Tensor
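Written out, for segment i with logits x_j (j = o_i, ..., o_{i+1} - 1), constant stop logit c, and augmented vector z = (x_{o_i}, ..., x_{o_{i+1}-1}, c), the loss is loss_i = log(exp(c) + sum_j exp(x_j)) - z_{t_i}, where t_i defaults to the index of c. The hypothetical `segment_stop_loss_sketch` below implements this with a simple loop, assuming c = 0 (exposed as a `stop_logit` parameter, since the actual constant is not documented here).

```python
import torch


def segment_stop_loss_sketch(partner_logits, segment_offsets, target_idx=None,
                             stop_logit=0.0):
    """Hypothetical sketch: softmax cross-entropy per segment plus one constant logit."""
    per_segment = []
    for i in range(len(segment_offsets) - 1):
        start, end = int(segment_offsets[i]), int(segment_offsets[i + 1])
        # Augment the segment's logits with one constant "stop" logit at the end.
        logits = torch.cat([partner_logits[start:end],
                            partner_logits.new_tensor([stop_logit])])
        # Without target_idx, the implicit last (stop) entry is the target.
        target = end - start if target_idx is None else int(target_idx[i])
        # Cross-entropy: logsumexp over all augmented logits minus the target logit.
        per_segment.append(torch.logsumexp(logits, dim=0) - logits[target])
    if not per_segment:
        return partner_logits.new_zeros(0)
    return torch.stack(per_segment)


# All-zero logits: segment 0 has 1 + 1 entries, segment 1 has 2 + 1 entries.
loss = segment_stop_loss_sketch(torch.zeros(3), torch.tensor([0, 1, 3]))
print(loss)  # approximately log 2 and log 3
```

The loop keeps the segment logic explicit; a batched version would more likely compute per-segment maxima and sums with scatter operations, which is omitted here.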