sketchgraphs_models.autoconstraint.train.AutoconstraintHarness
class sketchgraphs_models.autoconstraint.train.AutoconstraintHarness(model, opt, config_train, config_eval, dist_config, scheduler=None, output_dir=None, profile_enabled=False, additional_model_information=None)
__init__(model, opt, config_train, config_eval, dist_config, scheduler=None, output_dir=None, profile_enabled=False, additional_model_information=None)
Creates a new harness for the given model.
Parameters
model (torch.nn.Module) – The torch model to train.
opt (torch.optim.Optimizer) – The optimizer to use during training.
config_train (TrainingConfig) – The configuration to use for training.
config_eval (TrainingConfig, optional) – The configuration to use for evaluation.
dist_config (DistributedTrainingInfo, optional) – The configuration used for distributed training.
on_epoch_end(epoch, global_step)
This function is called at the end of each epoch.
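The reference above gives only the signature of on_epoch_end. The following is a minimal, hypothetical sketch of the kind of per-epoch work such a hook might do, assuming it is used to decide where a checkpoint would be written under output_dir. The class name CheckpointNamingHarness, the attribute planned_checkpoints, and the filename scheme are all invented for illustration; this is not the sketchgraphs_models implementation, and nothing in the documentation states that on_epoch_end writes checkpoints.

```python
from pathlib import Path
from typing import List, Optional


class CheckpointNamingHarness:
    """Hypothetical stand-in; only the on_epoch_end signature mirrors the docs."""

    def __init__(self, output_dir: Optional[str] = None) -> None:
        self.output_dir = output_dir
        # Paths a real harness might save to (assumption, for illustration only).
        self.planned_checkpoints: List[Path] = []

    def on_epoch_end(self, epoch: int, global_step: int) -> None:
        # With no output directory configured there is nowhere to save.
        if self.output_dir is None:
            return
        # Hypothetical naming scheme that encodes both counters in the filename.
        path = Path(self.output_dir) / f"model_epoch{epoch:03d}_step{global_step}.pt"
        self.planned_checkpoints.append(path)
```

Keeping output_dir optional mirrors the constructor's output_dir=None default: when it is unset, the hook in this sketch simply does nothing.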
single_step(batch, global_step)
Implements a single step of model training or evaluation.
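The two hooks documented here suggest a driving loop in which single_step runs once per batch and on_epoch_end runs once per epoch. Below is a minimal sketch of such a loop under stated assumptions: ToyHarness and run_epochs are hypothetical names, not part of sketchgraphs_models. The sketch also assumes that global_step counts batches across all epochs, starts at 0, and that on_epoch_end receives the step count after the epoch's last batch. The actual harness may count differently.

```python
from typing import Any, Iterable, List, Tuple


class ToyHarness:
    """Hypothetical stand-in that mimics the two documented hooks."""

    def __init__(self) -> None:
        self.steps: List[Tuple[Any, int]] = []
        self.epoch_ends: List[Tuple[int, int]] = []

    def single_step(self, batch: Any, global_step: int) -> None:
        # A real harness would run the model on `batch` here and, when
        # training, update parameters with the optimizer.
        self.steps.append((batch, global_step))

    def on_epoch_end(self, epoch: int, global_step: int) -> None:
        self.epoch_ends.append((epoch, global_step))


def run_epochs(harness: ToyHarness, batches: Iterable[Any], num_epochs: int) -> int:
    """Call single_step per batch and on_epoch_end per epoch; return the final step."""
    batches = list(batches)  # materialize so the same batches are reused each epoch
    global_step = 0
    for epoch in range(num_epochs):
        for batch in batches:
            harness.single_step(batch, global_step)
            global_step += 1  # global_step keeps counting across epoch boundaries
        harness.on_epoch_end(epoch, global_step)
    return global_step
```

For example, running two epochs over three batches gives six calls to single_step (steps 0 through 5) and two calls to on_epoch_end, with arguments (0, 3) and (1, 6).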