sketchgraphs_models.autoconstraint.train.AutoconstraintHarness

class sketchgraphs_models.autoconstraint.train.AutoconstraintHarness(model, opt, config_train, config_eval, dist_config, scheduler=None, output_dir=None, profile_enabled=False, additional_model_information=None)
__init__(model, opt, config_train, config_eval, dist_config, scheduler=None, output_dir=None, profile_enabled=False, additional_model_information=None)

Creates a new harness for the given model.

Parameters
on_epoch_end(epoch, global_step)

This method is called once at the end of each epoch.
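One common use of a per-epoch hook is to advance a learning-rate schedule. The sketch below assumes, hypothetically, that the harness stores the optional `scheduler` constructor argument and steps it once per epoch. This page does not confirm that behavior. `ToyScheduler` and `EpochHookSketch` are illustrative stand-ins, not part of `sketchgraphs_models`; only the `on_epoch_end(epoch, global_step)` signature comes from the documentation above.

```python
class ToyScheduler:
    """Hypothetical stand-in for a learning-rate scheduler exposing step()."""

    def __init__(self, lr: float, gamma: float) -> None:
        self.lr = lr
        self.gamma = gamma

    def step(self) -> None:
        # Exponential decay: multiply the learning rate by gamma on each call.
        self.lr *= self.gamma


class EpochHookSketch:
    """Hypothetical harness fragment; only on_epoch_end's signature is documented."""

    def __init__(self, scheduler=None):
        self.scheduler = scheduler
        self.epoch_log = []  # (epoch, global_step) pairs recorded so far

    def on_epoch_end(self, epoch, global_step):
        # Record where training stood when this epoch finished.
        self.epoch_log.append((epoch, global_step))
        # Assumption: advance the schedule once per epoch if a scheduler was given.
        if self.scheduler is not None:
            self.scheduler.step()


sched = ToyScheduler(lr=1.0, gamma=0.5)
hook = EpochHookSketch(scheduler=sched)
for epoch in range(3):
    hook.on_epoch_end(epoch, global_step=(epoch + 1) * 100)

print(hook.epoch_log)  # [(0, 100), (1, 200), (2, 300)]
print(sched.lr)        # 0.125
```

Stepping per epoch rather than per batch is a design choice made here for illustration; a real harness could equally step its scheduler inside the per-batch step.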

single_step(batch, global_step)

Implements a single step of model training or evaluation.

Parameters
  • batch (dict) – Input batch from the dataloader

  • global_step (int) – Global step for this batch

Returns

  • losses (dict) – Dictionary of computed losses

  • accuracy (dict) – Dictionary of computed accuracy metrics
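The following minimal training-loop sketch illustrates how `single_step` and `on_epoch_end` fit together. It assumes, without confirmation from this page, that `single_step` returns the `(losses, accuracy)` pair as a tuple, that the caller advances `global_step` once per batch, and that `on_epoch_end` runs after the last batch of each epoch. `ToyHarness` and `run_epochs` are hypothetical; a real harness would run the model on each batch, whereas this one derives placeholder values from a `labels` key invented for the example.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


class ToyHarness:
    """Hypothetical harness mirroring the documented method signatures."""

    def __init__(self) -> None:
        self.epochs_finished: List[Tuple[int, int]] = []

    def single_step(self, batch: dict, global_step: int) -> Tuple[Dict[str, float], Dict[str, float]]:
        # Placeholder computation; "labels" is a hypothetical batch key.
        labels = batch["labels"]
        losses = {"total": float(len(labels))}
        accuracy = {"edge": sum(1 for y in labels if y == 1) / len(labels)}
        return losses, accuracy

    def on_epoch_end(self, epoch: int, global_step: int) -> None:
        self.epochs_finished.append((epoch, global_step))


def run_epochs(harness, batches: List[dict], num_epochs: int):
    """Hypothetical driver: one single_step per batch, one on_epoch_end per epoch."""
    global_step = 0
    history = []
    for epoch in range(num_epochs):
        loss_sums: Dict[str, float] = defaultdict(float)
        acc_sums: Dict[str, float] = defaultdict(float)
        for batch in batches:
            losses, accuracy = harness.single_step(batch, global_step)
            global_step += 1
            for name, value in losses.items():
                loss_sums[name] += value
            for name, value in accuracy.items():
                acc_sums[name] += value
        # Average each metric over the batches seen in this epoch.
        n = len(batches)
        history.append((
            {name: total / n for name, total in loss_sums.items()},
            {name: total / n for name, total in acc_sums.items()},
        ))
        harness.on_epoch_end(epoch, global_step)
    return history


batches = [{"labels": [1, 0, 1, 1]}, {"labels": [0, 0, 1, 1]}]
harness = ToyHarness()
history = run_epochs(harness, batches, num_epochs=2)
print(history[0])               # ({'total': 4.0}, {'edge': 0.625})
print(harness.epochs_finished)  # [(0, 2), (1, 4)]
```

Keeping losses and accuracies as name-keyed dictionaries lets the driver average any number of metrics without knowing their names in advance.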