sketchgraphs_models.graph.model.losses

This module contains the code that computes losses from the output of the graph model.

Functions

sketchgraphs_models.graph.model.losses.compute_average_losses(losses, graph_counts)

Computes average losses from summed losses.

Parameters
  • losses (dict) – Dictionary containing the summed components of each loss, as computed by compute_losses.

  • graph_counts (Mapping) – Mapping from target types to integers representing the number of aggregated examples for each target type.

Returns

A dictionary containing the average loss values

Return type

dict
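A minimal sketch of the averaging step, assuming that `losses` maps each target type either to a summed scalar loss or to a dictionary of summed sub-components, and that every count in `graph_counts` is positive. The function name and the string keys standing in for target types are hypothetical; this is not the module's implementation.

```python
from typing import Any, Dict, Mapping


def average_losses_sketch(losses: Dict[Any, Any], graph_counts: Mapping[Any, int]) -> Dict[Any, Any]:
    """Divide each summed loss by the number of examples aggregated for its target type."""
    averages = {}
    for target, value in losses.items():
        count = graph_counts[target]  # assumed to be > 0
        if isinstance(value, dict):
            # Nested components share the example count of their target type.
            averages[target] = {name: component / count for name, component in value.items()}
        else:
            averages[target] = value / count
    return averages


summed = {'node': 6.0, 'edge': {'label': 4.0, 'partner': 2.0}}
print(average_losses_sketch(summed, {'node': 3, 'edge': 2}))
# {'node': 2.0, 'edge': {'label': 2.0, 'partner': 1.0}}
```

Because the division is plain `/`, the same sketch works whether the summed values are Python floats or scalar tensors.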

sketchgraphs_models.graph.model.losses.compute_edge_losses(edge_partner_offsets, edge_label_logits, partner_logits, edge_label, edge_partner)

Computes losses associated with edge prediction problems.

This function computes the two losses associated with the edge problem:

  • edge_label – the loss for predicting the label of the edge.

  • edge_partner – the loss for predicting the other vertex of the edge.
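The two losses can be sketched as follows. This is a hypothetical illustration, not the module's code, and it rests on several assumptions: `edge_label_logits` has shape [num_edges, num_labels]; `partner_logits` is a flat tensor whose segments are delimited by `edge_partner_offsets` (shape [num_edges + 1]); `edge_partner` holds each true partner's index within its own segment; both losses use a sum reduction; and the result is returned as a dictionary.

```python
import torch
import torch.nn.functional as F


def edge_losses_sketch(edge_partner_offsets, edge_label_logits, partner_logits, edge_label, edge_partner):
    """Return summed edge_label and edge_partner losses for a batch of edges."""
    # edge_label: ordinary classification over edge labels, summed over the batch.
    label_loss = F.cross_entropy(edge_label_logits, edge_label, reduction='sum')

    # edge_partner: a softmax restricted to the candidate vertices of each edge's segment.
    partner_loss = torch.zeros((), dtype=partner_logits.dtype)
    for e in range(edge_partner_offsets.numel() - 1):
        start, end = int(edge_partner_offsets[e]), int(edge_partner_offsets[e + 1])
        segment = partner_logits[start:end]
        # Negative log-probability of the true partner within this segment.
        partner_loss = partner_loss + torch.logsumexp(segment, dim=0) - segment[int(edge_partner[e])]

    return {'edge_label': label_loss, 'edge_partner': partner_loss}
```

The per-segment loop keeps the idea readable; a batched implementation would more likely use a segmented (scatter-based) log-sum-exp instead.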

sketchgraphs_models.graph.model.losses.compute_feature_loss(feature_logits: torch.Tensor, feature_targets: torch.Tensor, feature_dimensions: Dict[Any, int])

Computes losses on the given features.

Parameters
  • feature_logits (torch.Tensor) – 2-d tensor of feature logits

  • feature_targets (torch.Tensor) – 1-d integer tensor of true feature labels

  • feature_dimensions (Dict[Any, int]) – A dictionary mapping each feature to its dimension.

Returns

  • losses (torch.Tensor) – An array of losses on each feature

  • accuracies (torch.Tensor) – An array of accuracies on each feature

  • labels (torch.Tensor) – The provided array of targets

  • predictions (torch.Tensor) – An array of predicted labels according to arg-max predictions.
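One way to realize these return values is sketched below. The tensor layout is an assumption made for illustration: the logits of all features are concatenated along the last dimension (shape [batch, sum of dimensions], in the order of `feature_dimensions`), and the 1-d `feature_targets` is stored row-major with one entry per (example, feature) pair. Losses use a sum reduction, and the function name is hypothetical.

```python
from typing import Any, Dict

import torch
import torch.nn.functional as F


def feature_loss_sketch(feature_logits: torch.Tensor, feature_targets: torch.Tensor,
                        feature_dimensions: Dict[Any, int]):
    dims = list(feature_dimensions.values())
    # Assumed layout: each feature owns a contiguous block of columns.
    logit_blocks = feature_logits.split(dims, dim=-1)
    # Assumed layout: one target per (example, feature), flattened row-major.
    targets = feature_targets.view(-1, len(dims))

    losses, accuracies, predictions = [], [], []
    for i, block in enumerate(logit_blocks):
        target = targets[:, i]
        losses.append(F.cross_entropy(block, target, reduction='sum'))
        prediction = block.argmax(dim=-1)  # arg-max prediction for this feature
        predictions.append(prediction)
        accuracies.append((prediction == target).float().mean())

    # Predictions are flattened to the same layout as the targets.
    return (torch.stack(losses), torch.stack(accuracies), feature_targets,
            torch.stack(predictions, dim=-1).view(-1))
```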

sketchgraphs_models.graph.model.losses.compute_losses(readout, batch, feature_dimensions, weights=None)

Computes losses for each component of the model, given the model output.

Note that this function computes losses using a sum reduction. To obtain equivalent average losses, see compute_average_losses.

This function also returns labels and predictions for the numerical features of numerical edges.

Parameters
  • readout (dict) – A dictionary containing the model output

  • batch (dict) – A dictionary containing the input data

  • feature_dimensions (dict) – A dictionary mapping each target to a list of ints describing the dimension of each of its features.

  • weights (dict, optional) – If not None, a mapping from TargetType to floats describing the weight of each prediction endpoint. TargetTypes that are not included have a weight of 1.

Returns

  • losses (dict) – Nested dictionary containing loss values for each component

  • accuracy (dict) – Nested dictionary containing accuracy for each component

  • efeat_labels (dict) – Dictionary containing labels for each numerical edge feature.

  • efeat_preds (dict) – Dictionary containing predictions for each numerical edge feature.
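The documented default for `weights` (missing target types count with weight 1) can be illustrated with the hypothetical helper below, which is not part of this module. It assumes that each entry of the losses dictionary is either a scalar or a dictionary of components to be added together; where exactly `compute_losses` applies the weights is an assumption here.

```python
from typing import Any, Dict, Mapping, Optional


def weighted_total_loss(losses: Dict[Any, Any], weights: Optional[Mapping[Any, float]] = None) -> float:
    """Hypothetical helper: combine per-target summed losses into one weighted scalar."""
    weights = weights or {}
    total = 0.0
    for target, value in losses.items():
        # Nested components (for example label and partner terms) are added together.
        target_loss = sum(value.values()) if isinstance(value, dict) else value
        # Targets missing from `weights` keep the default weight of 1.
        total += weights.get(target, 1.0) * target_loss
    return total


print(weighted_total_loss({'node': 2.0, 'edge': {'label': 1.0, 'partner': 3.0}}, {'edge': 0.5}))
# 4.0
```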

sketchgraphs_models.graph.model.losses.compute_node_losses(node_offsets, entity_logits, partner_logits, node_label)

sketchgraphs_models.graph.model.losses.compute_subnode_losses(subnode_offsets, partner_logits)

sketchgraphs_models.graph.model.losses.merge_losses_and_accuracy(losses, accuracy, updates, weight=None)

sketchgraphs_models.graph.model.losses.segment_stop_accuracy(partner_logits, segment_offsets, target_idx=None)

Computes the stop-prediction accuracy of the given partner logits.

Parameters
  • partner_logits (torch.Tensor) – The un-normalized logits for the segments.

  • segment_offsets (torch.Tensor) – A tensor of shape [num_segments + 1] indicating the segment offsets.

  • target_idx (torch.Tensor, optional) – If not None, a tensor of shape [num_segments] representing the target offset to predict in each segment. Otherwise, this is assumed to be the last (implicit) entry in each segment.

Returns

A boolean tensor of shape [num_segments] representing the accuracy at each segment.

Return type

torch.Tensor
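A loop-based sketch of this accuracy check, assuming the implicit stop entry carries a constant logit of 0 and that ties are broken by taking the first maximal index (the behavior of `torch.argmax`). The function name is hypothetical, and the real implementation is presumably vectorized.

```python
import torch


def segment_stop_accuracy_sketch(partner_logits, segment_offsets, target_idx=None):
    correct = []
    for s in range(segment_offsets.numel() - 1):
        start, end = int(segment_offsets[s]), int(segment_offsets[s + 1])
        # Augment the segment with the implicit stop logit (assumed to be 0).
        logits = torch.cat([partner_logits[start:end], partner_logits.new_zeros(1)])
        # By default the target is the implicit last entry, at index end - start.
        target = end - start if target_idx is None else int(target_idx[s])
        correct.append(bool(logits.argmax() == target))
    return torch.tensor(correct, dtype=torch.bool)


logits = torch.tensor([2.0, -1.0, -3.0])
offsets = torch.tensor([0, 2, 3])  # two segments: [2.0, -1.0] and [-3.0]
print(segment_stop_accuracy_sketch(logits, offsets).tolist())
# [False, True]
```

In the first segment the logit 2.0 beats the implicit stop logit, so predicting "stop" is wrong; in the second segment every explicit logit is below 0, so the stop entry wins.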

sketchgraphs_models.graph.model.losses.segment_stop_loss(partner_logits, segment_offsets, target_idx=None)

Computes the stop-prediction loss for the given partner logits.

This function is effectively a softmax cross-entropy over each segment, in which the segment's logits are augmented with one final constant logit. If target_idx is omitted, the target is assumed to be this last (implicit) entry.

Parameters
  • partner_logits (torch.Tensor) – The un-normalized logits for the segments.

  • segment_offsets (torch.Tensor) – A tensor of shape [num_segments + 1] indicating the segment offsets.

  • target_idx (torch.Tensor, optional) – If not None, a tensor of shape [num_segments] representing the target offset to predict in each segment. Otherwise, this is assumed to be the last (implicit) entry in each segment.

Returns

A tensor of shape [num_segments] representing the cross-entropy loss at each segment.

Return type

torch.Tensor
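The description above translates into a short loop, sketched here under the assumption that the constant logit appended to each segment is 0 (the text only says it is constant). For a segment with logits x and target t, the loss is logsumexp([x, 0]) - [x, 0][t]. The function name is hypothetical, and this is not the module's implementation.

```python
import torch


def segment_stop_loss_sketch(partner_logits, segment_offsets, target_idx=None):
    losses = []
    for s in range(segment_offsets.numel() - 1):
        start, end = int(segment_offsets[s]), int(segment_offsets[s + 1])
        # Append the implicit stop logit (assumed constant 0) to this segment's logits.
        logits = torch.cat([partner_logits[start:end], partner_logits.new_zeros(1)])
        # Without explicit targets, the target is the implicit last entry.
        target = end - start if target_idx is None else int(target_idx[s])
        # Softmax cross-entropy: log-partition minus the target logit.
        losses.append(torch.logsumexp(logits, dim=0) - logits[target])
    return torch.stack(losses)


loss = segment_stop_loss_sketch(torch.zeros(3), torch.tensor([0, 2, 3]))
# With all-zero logits, segments of length 2 and 1 give log(3) and log(2).
```

An empty segment reduces to the stop logit alone, so its loss is 0 when the target is the implicit entry.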