sketchgraphs_models.training.TrainingHarness

class sketchgraphs_models.training.TrainingHarness(model, opt, config_train: sketchgraphs_models.training.TrainingConfig, config_eval: Optional[sketchgraphs_models.training.TrainingConfig] = None, dist_config: Optional[sketchgraphs_models.distributed_utils.DistributedTrainingInfo] = None)

This class implements the main training loop.

__init__(model, opt, config_train: sketchgraphs_models.training.TrainingConfig, config_eval: Optional[sketchgraphs_models.training.TrainingConfig] = None, dist_config: Optional[sketchgraphs_models.distributed_utils.DistributedTrainingInfo] = None)

Creates a new harness for the given model.

on_epoch_end(epoch, global_step)

This function is called at the end of each epoch.

run_holdout_eval(epoch, global_step)

Runs the holdout evaluation process.

Parameters
  • epoch (int) – The current epoch of training

  • global_step (int) – The current global step of training

abstract single_step(batch, global_step)

Implements a single step of model evaluation or training. This method is abstract, so subclasses must override it.

Parameters
  • batch (dict) – Input batch from the dataloader

  • global_step (int) – Global step for this batch

Returns

  • losses (dict) – Dictionary of computed losses

  • accuracy (dict) – Dictionary of computed accuracy
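The contract above can be illustrated with a small, self-contained sketch. It does not import sketchgraphs_models: HarnessSketch is a hypothetical stand-in that mirrors only the documented single_step signature, and the toy subclass, the batch keys ("inputs", "targets"), and the metric names ("mse", "within_half") are all assumptions made for illustration. It also assumes the two documented return values come back as a tuple.

```python
import abc


class HarnessSketch(abc.ABC):
    """Hypothetical stand-in mirroring only the documented single_step contract."""

    @abc.abstractmethod
    def single_step(self, batch, global_step):
        """Return a (losses, accuracy) pair of dictionaries."""


class ScalarModelHarness(HarnessSketch):
    """Toy subclass: the 'model' is one scalar weight, prediction = weight * x."""

    def __init__(self, weight):
        self.weight = weight

    def single_step(self, batch, global_step):
        predictions = [self.weight * x for x in batch["inputs"]]
        targets = batch["targets"]

        # Mean squared error over the batch, reported under a named key.
        squared_errors = [(p - t) ** 2 for p, t in zip(predictions, targets)]
        losses = {"mse": sum(squared_errors) / len(squared_errors)}

        # Count a prediction as correct when it lies within 0.5 of its target.
        hits = [abs(p - t) < 0.5 for p, t in zip(predictions, targets)]
        accuracy = {"within_half": sum(hits) / len(hits)}
        return losses, accuracy


harness = ScalarModelHarness(weight=2.0)
losses, accuracy = harness.single_step(
    {"inputs": [1.0, 2.0], "targets": [2.0, 5.0]}, global_step=0)
```

Returning dictionaries rather than bare numbers lets a step report several named quantities at once, which is presumably why the documented interface uses them.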

train_epochs(start_epoch=0, global_step=0)

Trains the model for a single iteration over the dataloader.

Usually, a single iteration over a dataloader corresponds to a single epoch. However, because starting a new epoch is very expensive for the dataloader, a dataloader is allowed to span multiple epochs in one iteration, so a single call to this method may advance training by more than one epoch.

Parameters
  • start_epoch (int) – The current epoch before training

  • global_step (int) – The current global step before training

Returns

  • epoch (int) – The current epoch after training

  • global_step (int) – The current global step after training
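The counter bookkeeping described above might look roughly like the following sketch. It is not the library's implementation: train_epochs_sketch, its batches_per_epoch argument, and the step_fn and on_epoch_end callables are hypothetical names introduced here, and the rule that an epoch ends after a fixed number of batches is an assumption.

```python
def train_epochs_sketch(batches, batches_per_epoch, step_fn,
                        start_epoch=0, global_step=0, on_epoch_end=None):
    """Hypothetical loop: one pass over `batches` may cover several epochs."""
    epoch = start_epoch
    for batch_idx, batch in enumerate(batches):
        step_fn(batch, global_step)
        global_step += 1

        # Assumed rule: an epoch boundary falls every `batches_per_epoch` batches.
        if (batch_idx + 1) % batches_per_epoch == 0:
            if on_epoch_end is not None:
                on_epoch_end(epoch, global_step)
            epoch += 1
    return epoch, global_step


epoch_end_calls = []
final_epoch, final_step = train_epochs_sketch(
    batches=range(6),            # six batches in a single dataloader pass
    batches_per_epoch=3,         # so the pass spans two epochs
    step_fn=lambda batch, step: None,
    start_epoch=4,
    global_step=100,
    on_epoch_end=lambda epoch, step: epoch_end_calls.append((epoch, step)),
)
```

Under these assumptions, the returned pair can be passed straight back in as start_epoch and global_step on the next call, so the counters stay continuous across repeated passes.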