sketchgraphs_models.nn.summary.ClassificationSummary
-
class sketchgraphs_models.nn.summary.ClassificationSummary(num_outcomes=2, device=None)
Simple accumulator that keeps track of summary statistics for a classification problem.
-
__init__(num_outcomes=2, device=None)
Initializes a new summary for a classification problem with the given number of outcomes.
- Parameters
num_outcomes (int) – the number of possible outcomes of the classification problem.
device (torch.device, optional) – device on which to place the recorded statistics.
-
accuracy()
Computes the accuracy of the recorded predictions, that is, the fraction of recorded samples whose predicted label matches the true label.
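As a rough illustration, suppose the recorded statistics form a num_outcomes by num_outcomes count matrix whose rows index true labels and whose columns index predicted labels. This layout is an assumption, because the page does not state how the statistics are stored. Under it, accuracy is the sum of the diagonal divided by the total count. accuracy_from_counts is a hypothetical helper and is not part of sketchgraphs_models.

```python
import torch


def accuracy_from_counts(counts: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: accuracy from a square count matrix.

    Assumes counts[i, j] is the number of samples with true label i
    that were predicted as label j.
    """
    correct = counts.diagonal().sum()  # diagonal entries were classified correctly
    total = counts.sum()  # every recorded sample
    return correct.float() / total.float()


counts = torch.tensor([[5, 1], [2, 2]])
print(accuracy_from_counts(counts).item())  # 7 correct out of 10, approximately 0.7
```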
-
cohen_kappa()
Computes Cohen's kappa, a measure of agreement between the true and predicted labels that corrects for the agreement expected by chance.
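Cohen's kappa compares the observed agreement p_o (the accuracy) with the agreement p_e that would be expected if labels and predictions were independent with the same marginal distributions: kappa = (p_o - p_e) / (1 - p_e). The following is a minimal sketch that uses the same assumed count-matrix layout (rows are true labels, columns are predictions). cohen_kappa_from_counts is a hypothetical name and does not describe how the library computes the value.

```python
import torch


def cohen_kappa_from_counts(counts: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: Cohen's kappa from a square count matrix.

    Assumes rows index true labels and columns index predicted labels.
    The result is undefined (division by zero) when p_expected == 1.
    """
    counts = counts.float()
    n = counts.sum()
    p_observed = counts.diagonal().sum() / n  # observed agreement (accuracy)
    label_marginal = counts.sum(dim=1) / n  # distribution of true labels
    predicted_marginal = counts.sum(dim=0) / n  # distribution of predictions
    # Agreement expected by chance if labels and predictions were independent.
    p_expected = torch.dot(label_marginal, predicted_marginal)
    return (p_observed - p_expected) / (1 - p_expected)


kappa = cohen_kappa_from_counts(torch.tensor([[5, 1], [2, 2]]))
print(kappa.item())  # (0.70 - 0.54) / (1 - 0.54), roughly 0.348
```

A kappa of 1 means perfect agreement. A kappa of 0 means the predictions agree with the labels no more often than chance, given the two marginal distributions.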
-
confusion_matrix()
Returns a torch.Tensor representing the confusion matrix.
-
marginal_labels()
Computes the empirical marginal distribution of the true labels.
-
marginal_predicted()
Computes the empirical marginal distribution of the predicted labels.
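Under the same assumed layout (rows are true labels, columns are predictions), both empirical marginals come from normalized row and column sums of the count matrix. The following is a sketch with a hypothetical helper name. It is not the library's implementation.

```python
import torch


def marginals_from_counts(counts: torch.Tensor):
    """Hypothetical helper: empirical marginals of true and predicted labels.

    Assumes rows index true labels and columns index predicted labels.
    """
    counts = counts.float()
    total = counts.sum()
    marginal_labels = counts.sum(dim=1) / total  # row sums: frequency of each true label
    marginal_predicted = counts.sum(dim=0) / total  # column sums: frequency of each prediction
    return marginal_labels, marginal_predicted


labels_dist, predicted_dist = marginals_from_counts(torch.tensor([[5, 1], [2, 2]]))
print(labels_dist.tolist(), predicted_dist.tolist())  # approximately [0.6, 0.4] and [0.7, 0.3]
```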
-
property prediction_matrix
Returns a torch.Tensor representing the prediction matrix.
-
record_statistics(labels, predictions)
Records statistics for a batch of predictions.
- Parameters
labels (torch.Tensor) – a tensor of true labels in integer format. Each label must be an integer from 0 to num_outcomes - 1 inclusive.
predictions (torch.Tensor) – a tensor of predicted labels, in the same format as labels.
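The following sketch shows one way to vectorize per-batch recording. It assumes the statistics are stored as an integer count matrix, which the page does not confirm for ClassificationSummary. Each (label, prediction) pair is encoded as the flat index label * num_outcomes + prediction, and torch.bincount counts the indices. CountAccumulator, record and reset are hypothetical names chosen for illustration. They resemble record_statistics and reset_statistics but are not those methods.

```python
import torch


class CountAccumulator:
    """Hypothetical accumulator of (true label, prediction) counts."""

    def __init__(self, num_outcomes: int = 2, device=None):
        self.num_outcomes = num_outcomes
        # counts[i, j]: samples with true label i predicted as label j (assumed layout).
        self.counts = torch.zeros(num_outcomes, num_outcomes, dtype=torch.int64, device=device)

    def record(self, labels: torch.Tensor, predictions: torch.Tensor) -> None:
        k = self.num_outcomes
        # Encode each pair as one flat index; valid labels keep it in [0, k * k).
        flat = labels.long().flatten() * k + predictions.long().flatten()
        batch_counts = torch.bincount(flat, minlength=k * k)
        # Fold the flat histogram back into a k x k matrix and accumulate it.
        self.counts += batch_counts.view(k, k).to(self.counts.device)

    def reset(self) -> None:
        self.counts.zero_()


acc = CountAccumulator(num_outcomes=3)
acc.record(torch.tensor([0, 1, 2, 2]), torch.tensor([0, 2, 2, 1]))
print(acc.counts.tolist())  # [[1, 0, 0], [0, 0, 1], [0, 1, 1]]
```

A single bincount call per batch avoids a Python loop over samples. The call relies on the documented precondition that every label lies in 0 to num_outcomes - 1. A larger value would lengthen the histogram past num_outcomes * num_outcomes, and the reshape would then fail.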
-
reset_statistics()
Resets the statistics recorded in this accumulator.
-
write_tensorboard(writer, prefix='', global_step=None, **kwargs)
Writes the accuracy and Cohen kappa metrics to a TensorBoard writer.
- Parameters
writer (torch.utils.tensorboard.SummaryWriter) – the writer to which the metrics are written.
prefix (str, optional) – prefix for the names under which the metrics are written.
global_step (int, optional) – global step at which the metrics are recorded.
**kwargs – further keyword arguments passed to torch.utils.tensorboard.SummaryWriter.add_scalar.
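The prefix, global_step and **kwargs parameters appear to be forwarded to SummaryWriter.add_scalar. The following is a minimal sketch of that forwarding, assuming one scalar per metric. The helper write_metrics and the tag names 'accuracy' and 'kappa' are hypothetical, because this page does not state the tags actually used. Any object with an add_scalar(tag, scalar_value, global_step=None, ...) method can serve as the writer.

```python
from typing import Any, Optional


def write_metrics(writer: Any, accuracy: float, kappa: float, prefix: str = '',
                  global_step: Optional[int] = None, **kwargs: Any) -> None:
    """Hypothetical helper: log two scalar metrics under prefixed tags."""
    # Tag names are assumptions for illustration only.
    writer.add_scalar(prefix + 'accuracy', accuracy, global_step=global_step, **kwargs)
    writer.add_scalar(prefix + 'kappa', kappa, global_step=global_step, **kwargs)
```

With a real torch.utils.tensorboard.SummaryWriter, a call such as write_metrics(SummaryWriter(), acc, kappa, prefix='val/', global_step=step) would log both values at the same step. This requires the tensorboard package to be installed.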
-