sketchgraphs_models.training.TrainingHarness

class sketchgraphs_models.training.TrainingHarness(model, opt, config_train: sketchgraphs_models.training.TrainingConfig, config_eval: Optional[sketchgraphs_models.training.TrainingConfig] = None, dist_config: Optional[sketchgraphs_models.distributed_utils.DistributedTrainingInfo] = None)

    This class implements the main training loop.
__init__(model, opt, config_train: sketchgraphs_models.training.TrainingConfig, config_eval: Optional[sketchgraphs_models.training.TrainingConfig] = None, dist_config: Optional[sketchgraphs_models.distributed_utils.DistributedTrainingInfo] = None)

    Creates a new harness for the given model.

    Parameters
        model (torch.nn.Module) – The torch model to train.
        opt (torch.optim.Optimizer) – The optimizer to use during training.
        config_train (TrainingConfig) – The configuration to use for training.
        config_eval (TrainingConfig, optional) – The configuration to use for evaluation.
        dist_config (DistributedTrainingInfo, optional) – The configuration used for distributed training.
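
    A minimal sketch of assembling the constructor arguments is shown below. The model and optimizer are ordinary PyTorch objects; the TrainingConfig values are left as placeholders because their fields are not documented on this page, and a concrete subclass is required before a harness can be instantiated (see the sketch under single_step below).

        import torch

        from sketchgraphs_models.training import TrainingConfig

        # The model and optimizer are ordinary PyTorch objects.
        model = torch.nn.Linear(16, 1)
        opt = torch.optim.Adam(model.parameters(), lr=1e-3)

        # config_train / config_eval must be TrainingConfig instances; their fields
        # are not documented here, so they are left as placeholders in this sketch.
        config_train: TrainingConfig = ...   # e.g. built from your dataloader and batch settings
        config_eval: TrainingConfig = ...    # optional holdout-evaluation configuration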
on_epoch_end(epoch, global_step)

    This function is called at the end of each epoch.
run_holdout_eval(epoch, global_step)

    Runs the holdout evaluation process.
abstract single_step(batch, global_step)

    Implements a single step of the model evaluation / training.
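
    Because single_step is abstract, TrainingHarness must be subclassed before use. The sketch below illustrates one possible implementation; the attribute names self.model and self.opt, the (inputs, targets) structure of batch, and the returned metrics dict are assumptions for illustration, not part of the documented API.

        import torch.nn.functional as F

        from sketchgraphs_models.training import TrainingHarness

        class SimpleHarness(TrainingHarness):
            """Hypothetical concrete harness, for illustration only."""

            def single_step(self, batch, global_step):
                # Assumes the constructor arguments are stored as self.model and
                # self.opt, and that `batch` is an (inputs, targets) pair.
                inputs, targets = batch
                self.opt.zero_grad()
                loss = F.mse_loss(self.model(inputs), targets)
                loss.backward()
                self.opt.step()
                # What the surrounding loop expects back is not documented here;
                # returning a dict of scalar metrics is an assumption.
                return {'loss': loss.detach().item()}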
train_epochs(start_epoch=0, global_step=0)

    Trains the model for a single iteration over the dataloader.

    Note that a single iteration over a dataloader usually represents a single epoch. However, because starting a new epoch is very expensive for the dataloader, we instead allow dataloaders to iterate over multiple epochs at a time.
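
    Continuing the sketches above, a harness is constructed from the prepared arguments and driven through train_epochs. Whether train_epochs returns updated bookkeeping values is not documented here, so the call below ignores any return value.

        # model, opt, config_train and config_eval are defined in the earlier sketch.
        harness = SimpleHarness(model, opt, config_train, config_eval=config_eval)

        # A single call performs one iteration over the dataloader, which, per the
        # note above, may correspond to more than one logical epoch.
        harness.train_epochs(start_epoch=0, global_step=0)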