sketchgraphs_models.nn.summary.ClassificationSummary

class sketchgraphs_models.nn.summary.ClassificationSummary(num_outcomes=2, device=None)

A simple accumulator that keeps track of summary statistics for a classification problem.

__init__(num_outcomes=2, device=None)

Initializes a new summary object for a classification problem with the given number of outcomes.

Parameters
  • num_outcomes (int) – the number of possible outcomes of the classification problem.

  • device (torch.device, optional) – device on which to place the recorded statistics.

accuracy()

Computes the accuracy of the recorded predictions.
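Accuracy is the fraction of recorded examples whose predicted label equals the true label. The plain-Python sketch below assumes the recorded statistics can be viewed as a count matrix counts[true][predicted]. That layout, and the helper name accuracy_from_counts, are hypothetical and not taken from the library.

```python
def accuracy_from_counts(counts):
    """Fraction of correct predictions: diagonal sum divided by total count.

    Assumes (hypothetically) that counts[i][j] holds the number of examples
    with true label i and predicted label j.
    """
    total = sum(sum(row) for row in counts)
    correct = sum(counts[i][i] for i in range(len(counts)))
    return correct / total


# Two outcomes: 40 + 45 correct predictions out of 100 examples.
counts = [[40, 10],
          [5, 45]]
print(accuracy_from_counts(counts))  # 0.85
```

The sketch raises ZeroDivisionError if nothing has been recorded yet; how accuracy() behaves in that case is not documented here.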

cohen_kappa()

Computes Cohen's kappa, a measure of agreement between the true and predicted labels that corrects for the agreement expected by chance.
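Cohen's kappa is defined as kappa = (p_o - p_e) / (1 - p_e), where p_o is the observed agreement (the accuracy) and p_e is the agreement expected if true and predicted labels were independent, computed from their marginal distributions. The following sketch computes it from a count matrix counts[true][predicted]. The matrix layout and the function name are assumptions for illustration, not the library's internals.

```python
def cohen_kappa_from_counts(counts):
    """Cohen's kappa from a square count matrix counts[true][predicted] (assumed layout)."""
    n = len(counts)
    total = sum(sum(row) for row in counts)
    # Observed agreement p_o: fraction of examples on the diagonal.
    p_o = sum(counts[i][i] for i in range(n)) / total
    # Chance agreement p_e: for each class, multiply its frequency among the
    # true labels by its frequency among the predictions, then sum.
    true_marginal = [sum(counts[i]) / total for i in range(n)]
    pred_marginal = [sum(counts[i][j] for i in range(n)) / total for j in range(n)]
    p_e = sum(t * p for t, p in zip(true_marginal, pred_marginal))
    # Undefined when p_e == 1 (e.g. all labels and predictions in one class).
    return (p_o - p_e) / (1 - p_e)


# p_o = 0.85, p_e = 0.5 * 0.45 + 0.5 * 0.55 = 0.5, so kappa = 0.35 / 0.5.
kappa = cohen_kappa_from_counts([[40, 10], [5, 45]])
print(round(kappa, 4))  # 0.7
```

A kappa of 1 means perfect agreement, while 0 means agreement no better than chance, which is why it complements plain accuracy.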

confusion_matrix()

Returns a torch.Tensor representing the confusion matrix.
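A confusion matrix counts, for every pair of classes, how many examples with a given true label received a given predicted label. The sketch below builds one from plain lists. The row = true label, column = predicted label orientation is an assumption; check the library's output before relying on it. The helper name build_confusion_matrix is hypothetical.

```python
def build_confusion_matrix(labels, predictions, num_outcomes):
    """Return counts[true][predicted] for integer labels in 0 .. num_outcomes - 1."""
    counts = [[0] * num_outcomes for _ in range(num_outcomes)]
    for true, pred in zip(labels, predictions):
        counts[true][pred] += 1
    return counts


print(build_confusion_matrix([0, 1, 1, 2], [0, 1, 2, 2], num_outcomes=3))
# [[1, 0, 0], [0, 1, 1], [0, 0, 1]]
```

With tensors, one possible vectorized equivalent is torch.bincount(labels * num_outcomes + predictions, minlength=num_outcomes ** 2).reshape(num_outcomes, num_outcomes), which flattens each (true, predicted) pair into a single index before counting.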

marginal_labels()

Computes the empirical marginal distribution of the true labels.

marginal_predicted()

Computes the empirical marginal distribution of the predicted labels.
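Both empirical marginals can be read off a confusion matrix: the distribution of true labels comes from the row sums and the distribution of predicted labels from the column sums, each divided by the total count. This sketch assumes the hypothetical counts[true][predicted] layout used above; the function name is illustrative only.

```python
def marginal_distributions(counts):
    """Return (true-label marginal, predicted-label marginal) from counts[true][predicted]."""
    n = len(counts)
    total = sum(sum(row) for row in counts)
    # Row sums: how often each class appears as the true label.
    labels_dist = [sum(counts[i]) / total for i in range(n)]
    # Column sums: how often each class appears as the prediction.
    predicted_dist = [sum(counts[i][j] for i in range(n)) / total for j in range(n)]
    return labels_dist, predicted_dist


labels_dist, predicted_dist = marginal_distributions([[40, 10], [5, 45]])
print(labels_dist)     # [0.5, 0.5]
print(predicted_dist)  # [0.45, 0.55]
```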

property prediction_matrix

Returns a torch.Tensor representing the prediction matrix.

record_statistics(labels, predictions)

Records statistics for a batch of predictions.

Parameters
  • labels (torch.Tensor) – an array of true labels in integer format. Each label must be an integer in the range 0 to num_outcomes - 1, inclusive.

  • predictions (torch.Tensor) – an array of predicted labels. Must follow the same format as labels.
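Because statistics are recorded per batch, repeated calls accumulate counts rather than overwrite them. The class below is a hypothetical stand-in (not the library's implementation) that mimics this behaviour with plain lists. Its explicit range and length checks are an addition of this sketch; whether record_statistics validates its inputs is not stated here.

```python
class BatchCounter:
    """Hypothetical stand-in that accumulates a confusion matrix over batches."""

    def __init__(self, num_outcomes=2):
        self.num_outcomes = num_outcomes
        self.counts = [[0] * num_outcomes for _ in range(num_outcomes)]

    def record(self, labels, predictions):
        if len(labels) != len(predictions):
            raise ValueError("labels and predictions must have the same length")
        for true, pred in zip(labels, predictions):
            # Labels must lie in 0 .. num_outcomes - 1, inclusive.
            if not (0 <= true < self.num_outcomes and 0 <= pred < self.num_outcomes):
                raise ValueError(f"label out of range: true={true}, pred={pred}")
            self.counts[true][pred] += 1


counter = BatchCounter(num_outcomes=2)
counter.record([0, 1, 1], [0, 1, 0])  # first batch
counter.record([1, 0], [1, 1])        # second batch adds to the same counts
print(counter.counts)  # [[1, 1], [1, 2]]
```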

reset_statistics()

Resets statistics recorded in this accumulator.

write_tensorboard(writer, prefix='', global_step=None, **kwargs)

Writes the accuracy and kappa metrics to a TensorBoard writer.

Parameters
  • writer (torch.utils.tensorboard.SummaryWriter) – the writer to which the metrics are written.

  • prefix (str, optional) – prefix for the names under which the metrics are written.

  • global_step (int, optional) – global step at which the metrics are recorded.

  • **kwargs – additional keyword arguments passed to torch.utils.tensorboard.SummaryWriter.add_scalar.
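The sketch below illustrates how prefix, global_step and extra keyword arguments can flow through to add_scalar. To stay runnable without TensorBoard, it uses a stub writer that records calls with the same add_scalar(tag, scalar_value, global_step=None, ...) shape as SummaryWriter. The helper write_metrics and the "<prefix><name>" tag scheme are hypothetical; the actual tag names used by write_tensorboard are not documented here.

```python
class RecordingWriter:
    """Stub mimicking SummaryWriter.add_scalar's call shape; it only records calls."""

    def __init__(self):
        self.calls = []

    def add_scalar(self, tag, scalar_value, global_step=None, **kwargs):
        self.calls.append((tag, scalar_value, global_step, kwargs))


def write_metrics(writer, metrics, prefix='', global_step=None, **kwargs):
    # Hypothetical tag scheme: each metric is logged as "<prefix><name>",
    # and any extra keyword arguments are forwarded unchanged to add_scalar.
    for name, value in metrics.items():
        writer.add_scalar(prefix + name, value, global_step=global_step, **kwargs)


writer = RecordingWriter()
write_metrics(writer, {'accuracy': 0.85, 'kappa': 0.7}, prefix='validation/', global_step=100)
for call in writer.calls:
    print(call)
# ('validation/accuracy', 0.85, 100, {})
# ('validation/kappa', 0.7, 100, {})
```

With a real torch.utils.tensorboard.SummaryWriter, a keyword such as walltime passed through **kwargs would reach add_scalar in the same way.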