sketchgraphs_models.nn.summary.ClassificationSummary

class sketchgraphs_models.nn.summary.ClassificationSummary(num_outcomes=2, device=None)
    Simple class to keep track of summary statistics for a classification problem.

__init__(num_outcomes=2, device=None)
    Initializes a new summary with the given number of possible outcomes.

    Parameters
        num_outcomes (int) – the number of possible outcomes of the classification problem.
        device (torch.device, optional) – the device on which to place the recorded statistics.

accuracy()
    Computes the accuracy of the recorded predictions.
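A minimal sketch of the accuracy computation, assuming (this page does not document it) that the recorded statistics amount to a square count matrix indexed as counts[label][prediction]. Accuracy is then the fraction of recorded examples on the diagonal. The helper name `accuracy_from_counts` is hypothetical and not part of sketchgraphs_models.

```python
from typing import Sequence


def accuracy_from_counts(counts: Sequence[Sequence[int]]) -> float:
    """Return the fraction of correct predictions in a square count matrix.

    Hypothetical helper, not part of sketchgraphs_models. Assumes
    counts[i][j] is the number of examples with true label i that were
    predicted as label j.
    """
    total = sum(sum(row) for row in counts)
    if total == 0:
        raise ValueError("no statistics have been recorded")
    # Correct predictions lie on the diagonal (label == prediction).
    correct = sum(counts[i][i] for i in range(len(counts)))
    return correct / total


# Two outcomes: 40 + 45 correct predictions out of 100 examples.
print(accuracy_from_counts([[40, 10], [5, 45]]))  # 0.85
```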

cohen_kappa()
    Computes Cohen's kappa measure of agreement between the true and predicted labels.
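Cohen's kappa compares the observed agreement p_o (the accuracy) with the agreement p_e expected by chance if labels and predictions were independent with the same marginals: kappa = (p_o - p_e) / (1 - p_e). The sketch below computes it from a count matrix indexed as counts[label][prediction]; that layout and the helper name `cohen_kappa_from_counts` are assumptions, not the library's internals.

```python
from typing import Sequence


def cohen_kappa_from_counts(counts: Sequence[Sequence[int]]) -> float:
    """Cohen's kappa from a square count matrix counts[label][prediction].

    Hypothetical helper; the count-matrix layout is an assumption.
    """
    n = len(counts)
    total = sum(sum(row) for row in counts)
    if total == 0:
        raise ValueError("no statistics have been recorded")
    # Observed agreement: fraction of examples on the diagonal.
    p_o = sum(counts[i][i] for i in range(n)) / total
    # Marginal totals of true labels (rows) and predictions (columns).
    row_totals = [sum(counts[i]) for i in range(n)]
    col_totals = [sum(counts[i][j] for i in range(n)) for j in range(n)]
    # Chance agreement if labels and predictions were independent.
    p_e = sum(r * c for r, c in zip(row_totals, col_totals)) / (total * total)
    if p_e == 1.0:
        raise ValueError("kappa is undefined when chance agreement is 1")
    return (p_o - p_e) / (1 - p_e)


# p_o = 0.85 and p_e = (50 * 45 + 50 * 55) / 100**2 = 0.5, so kappa = 0.35 / 0.5.
print(cohen_kappa_from_counts([[40, 10], [5, 45]]))  # 0.7
```

A kappa of 1 means perfect agreement, while 0 means the predictions agree with the labels no more often than chance would suggest given the marginals.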

confusion_matrix()
    Returns a torch.Tensor representing the confusion matrix.
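This page does not say how the confusion matrix is normalized. The sketch below shows two common conventions applied to a count matrix counts[label][prediction]: dividing by the grand total (joint frequencies) or dividing each row by its true-label total. Which one, if either, ClassificationSummary uses is an assumption left open here, and `normalize_counts` is a hypothetical helper.

```python
from typing import List, Sequence


def normalize_counts(counts: Sequence[Sequence[int]], by_row: bool = False) -> List[List[float]]:
    """Turn a count matrix counts[label][prediction] into frequencies.

    by_row=False divides every entry by the grand total, giving joint
    frequencies that sum to one. by_row=True divides each row by its own
    total, so row i is the distribution of predictions for true label i.
    Hypothetical helper; rows (or matrices) with no examples stay zero.
    """
    total = sum(sum(row) for row in counts)
    result = []
    for row in counts:
        denom = sum(row) if by_row else total
        result.append([c / denom if denom else 0.0 for c in row])
    return result


print(normalize_counts([[40, 10], [5, 45]]))               # [[0.4, 0.1], [0.05, 0.45]]
print(normalize_counts([[40, 10], [5, 45]], by_row=True))  # [[0.8, 0.2], [0.1, 0.9]]
```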

marginal_labels()
    Computes the empirical marginal distribution of the true labels.

marginal_predicted()
    Computes the empirical marginal distribution of the predicted labels.
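Both marginals can be read off a count matrix: assuming it is indexed as counts[label][prediction] (not documented on this page), the true-label marginal comes from the row sums and the predicted-label marginal from the column sums, each divided by the total number of examples. `marginals_from_counts` is a hypothetical helper for illustration.

```python
from typing import List, Sequence, Tuple


def marginals_from_counts(counts: Sequence[Sequence[int]]) -> Tuple[List[float], List[float]]:
    """Empirical marginals of the true and predicted labels.

    Hypothetical helper; assumes counts[label][prediction], so true-label
    marginals come from row sums and predicted-label marginals from
    column sums.
    """
    n = len(counts)
    total = sum(sum(row) for row in counts)
    if total == 0:
        raise ValueError("no statistics have been recorded")
    labels = [sum(counts[i]) / total for i in range(n)]
    predicted = [sum(counts[i][j] for i in range(n)) / total for j in range(n)]
    return labels, predicted


print(marginals_from_counts([[40, 10], [5, 45]]))  # ([0.5, 0.5], [0.45, 0.55])
```

Comparing the two marginals is a quick check for a classifier that over- or under-predicts a class; Cohen's kappa uses exactly these marginals for its chance-agreement term.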

property prediction_matrix
    Returns a torch.Tensor representing the prediction matrix.

record_statistics(labels, predictions)
    Records statistics for a batch of predictions.

    Parameters
        labels (torch.Tensor) – an array of true labels in integer format. Each label must be an integer from 0 to num_outcomes - 1, inclusive.
        predictions (torch.Tensor) – an array of predicted labels, in the same format as labels.
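A pure-Python sketch of what recording and resetting statistics could look like, assuming the accumulator keeps a num_outcomes x num_outcomes count matrix where counts[label][prediction] is incremented once per example. The class `CountAccumulator` and its methods are hypothetical stand-ins, not the library's implementation. The whole batch is validated before any count changes, so an out-of-range value leaves the statistics untouched.

```python
from typing import List, Sequence


class CountAccumulator:
    """Hypothetical stand-in for the statistics a classification summary records.

    Assumes counts[label][prediction] holds how often each
    (true label, predicted label) pair has been seen.
    """

    def __init__(self, num_outcomes: int = 2) -> None:
        self.num_outcomes = num_outcomes
        self.counts: List[List[int]] = [[0] * num_outcomes for _ in range(num_outcomes)]

    def record(self, labels: Sequence[int], predictions: Sequence[int]) -> None:
        if len(labels) != len(predictions):
            raise ValueError("labels and predictions must have the same length")
        pairs = list(zip(labels, predictions))
        # Validate first: every value must lie in 0 .. num_outcomes - 1 inclusive.
        for y, p in pairs:
            if not (0 <= y < self.num_outcomes and 0 <= p < self.num_outcomes):
                raise ValueError(f"label {y} or prediction {p} is out of range")
        for y, p in pairs:
            self.counts[y][p] += 1

    def reset(self) -> None:
        # Zero every count in place, keeping the matrix shape.
        for row in self.counts:
            for j in range(len(row)):
                row[j] = 0


acc = CountAccumulator(num_outcomes=2)
acc.record([0, 0, 1, 1], [0, 1, 1, 1])
print(acc.counts)  # [[1, 1], [0, 2]]
```

With tensors, the same counting could be vectorized, for example with torch.bincount(labels * num_outcomes + predictions, minlength=num_outcomes ** 2).view(num_outcomes, num_outcomes); whether ClassificationSummary does anything like this internally is not stated here.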

reset_statistics()
    Resets statistics recorded in this accumulator.

write_tensorboard(writer, prefix='', global_step=None, **kwargs)
    Writes the accuracy and Cohen kappa metrics to a TensorBoard writer.

    Parameters
        writer (torch.utils.tensorboard.SummaryWriter) – the writer to which the metrics are written.
        prefix (str, optional) – a prefix for the names under which the metrics are written.
        global_step (int, optional) – the global step at which the metrics are recorded.
        **kwargs – further arguments passed to torch.utils.tensorboard.SummaryWriter.add_scalar.
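A sketch of how two metrics might be forwarded to any object with an add_scalar(tag, scalar_value, global_step=..., **kwargs) method, such as torch.utils.tensorboard.SummaryWriter. The function `write_metrics`, the tag names 'accuracy' and 'kappa', and joining the prefix by plain concatenation are all assumptions; the actual tags written by write_tensorboard are not documented on this page. `RecordingWriter` is a hypothetical stand-in that captures calls so the sketch runs without TensorBoard.

```python
from typing import Any, List, Optional, Tuple


def write_metrics(writer: Any, accuracy: float, kappa: float, prefix: str = '',
                  global_step: Optional[int] = None, **kwargs: Any) -> None:
    """Hypothetical sketch: log accuracy and kappa as two scalars.

    Extra keyword arguments are forwarded unchanged to writer.add_scalar.
    """
    writer.add_scalar(prefix + 'accuracy', accuracy, global_step=global_step, **kwargs)
    writer.add_scalar(prefix + 'kappa', kappa, global_step=global_step, **kwargs)


class RecordingWriter:
    """Minimal stand-in exposing add_scalar; records each call for inspection."""

    def __init__(self) -> None:
        self.calls: List[Tuple[str, float, Optional[int], dict]] = []

    def add_scalar(self, tag: str, scalar_value: float,
                   global_step: Optional[int] = None, **kwargs: Any) -> None:
        self.calls.append((tag, scalar_value, global_step, kwargs))


w = RecordingWriter()
write_metrics(w, 0.85, 0.7, prefix='val/', global_step=10)
print([call[0] for call in w.calls])  # ['val/accuracy', 'val/kappa']
```

With a real SummaryWriter (from torch.utils.tensorboard import SummaryWriter), the same call pattern would place both scalars under the given prefix at the given step; keyword arguments such as walltime are passed through to add_scalar.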