sketchgraphs_models.nn.distributed.SingleDeviceDistributedParallel
class sketchgraphs_models.nn.distributed.SingleDeviceDistributedParallel(module, device_id, find_unused_parameters=False)
A wrapper module similar to DistributedDataParallel, except that it accepts inputs of any shape and supports only a single device per instance.
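The contrast with a multi-device wrapper can be sketched in plain Python. This is a hypothetical illustration, not the SketchGraphs implementation: `PassThroughWrapper`, `scatter_batch` and `describe` are made-up names, plain lists stand in for tensors, and no real device placement happens. A multi-device wrapper in the style of DistributedDataParallel splits batch-shaped inputs along their first dimension, one chunk per device. A single-device wrapper can instead hand its inputs to the wrapped module unchanged, whatever their structure (for example, a dictionary describing a graph).

```python
from typing import Any, Callable, List


def scatter_batch(batch: List[Any], num_devices: int) -> List[List[Any]]:
    """Split a batch along its first dimension into one chunk per device."""
    if not batch:
        return []
    chunk = -(-len(batch) // num_devices)  # ceiling division
    return [batch[i:i + chunk] for i in range(0, len(batch), chunk)]


class PassThroughWrapper:
    """Hypothetical single-device wrapper: inputs reach the module untouched."""

    def __init__(self, module: Callable[..., Any], device_id: int) -> None:
        self.module = module
        self.device_id = device_id  # recorded only; nothing is moved anywhere

    def __call__(self, *inputs: Any, **kwargs: Any) -> Any:
        # No splitting or reshaping: any input structure is forwarded as-is.
        return self.module(*inputs, **kwargs)


def describe(graph: dict, *, label: str) -> str:
    # A module whose input is not batch-shaped at all.
    return f"{label}: {len(graph['nodes'])} nodes, {len(graph['edges'])} edges"


wrapped = PassThroughWrapper(describe, device_id=0)
print(wrapped({"nodes": [1, 2, 3], "edges": [(1, 2)]}, label="sketch"))
print(scatter_batch([10, 11, 12, 13, 14], num_devices=2))
```

The first call prints `sketch: 3 nodes, 1 edges`; the second splits five items into `[[10, 11, 12], [13, 14]]`, which only makes sense for inputs that have a batch dimension.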
__init__(module, device_id, find_unused_parameters=False)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
forward(*inputs, **kwargs)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass is defined within this function, call the Module instance rather than forward() directly: calling the instance runs the registered hooks, whereas calling forward() silently skips them.
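Why calling the instance differs from calling forward() can be sketched with a toy class. `TinyModule`, its `register_forward_hook` method and `Doubler` are simplified stand-ins written for this example; they mimic the behaviour described in the note but are not torch.nn.Module. Hooks are run inside `__call__`, so a direct `forward()` call never reaches them.

```python
from typing import Any, Callable, List

Hook = Callable[["TinyModule", tuple, Any], Any]


class TinyModule:
    """Toy stand-in for nn.Module: hooks live in __call__, not in forward()."""

    def __init__(self) -> None:
        self._forward_hooks: List[Hook] = []

    def register_forward_hook(self, hook: Hook) -> None:
        self._forward_hooks.append(hook)

    def forward(self, *inputs: Any) -> Any:
        raise NotImplementedError  # subclasses define the computation

    def __call__(self, *inputs: Any) -> Any:
        output = self.forward(*inputs)
        for hook in self._forward_hooks:
            result = hook(self, inputs, output)
            if result is not None:  # a hook may replace the output
                output = result
        return output


class Doubler(TinyModule):
    def forward(self, x: int) -> int:
        return 2 * x


calls: List[int] = []
model = Doubler()
model.register_forward_hook(lambda mod, inp, out: calls.append(out))

print(model(3))          # runs forward, then the hook
print(model.forward(3))  # runs forward only; the hook is skipped
print(calls)             # the hook fired once
```

Both calls return 6, but `calls` ends up as `[6]` because only the instance call ran the hook.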
load_state_dict(state_dict, strict=True)
Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
- Parameters
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True
- Returns
missing_keys is a list of str containing the missing keys
unexpected_keys is a list of str containing the unexpected keys
- Return type
NamedTuple with missing_keys and unexpected_keys fields
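The key check governed by strict can be sketched in plain Python. `check_keys` and `IncompatibleKeys` are hypothetical names for this example, and the sketch models only the key comparison, not the copying of parameter values. As in nn.Module.load_state_dict, a mismatch raises a RuntimeError when strict is True and is reported through the returned named tuple when strict is False.

```python
from typing import Iterable, List, Mapping, NamedTuple


class IncompatibleKeys(NamedTuple):
    missing_keys: List[str]
    unexpected_keys: List[str]


def check_keys(expected: Iterable[str], state_dict: Mapping[str, object],
               strict: bool = True) -> IncompatibleKeys:
    """Compare a module's own keys with the keys of an incoming state_dict."""
    expected_list = list(expected)
    expected_set = set(expected_list)
    missing = [k for k in expected_list if k not in state_dict]
    unexpected = [k for k in state_dict if k not in expected_set]
    if strict and (missing or unexpected):
        raise RuntimeError(
            f"missing keys: {missing}, unexpected keys: {unexpected}")
    return IncompatibleKeys(missing, unexpected)


own_keys = ["weight", "bias"]
checkpoint = {"weight": [1.0], "module.bias": [0.0]}  # note the stray prefix

print(check_keys(own_keys, checkpoint, strict=False))
try:
    check_keys(own_keys, checkpoint)  # strict=True by default
except RuntimeError as err:
    print("strict load failed:", err)
```

With strict=False the mismatch comes back as `IncompatibleKeys(missing_keys=['bias'], unexpected_keys=['module.bias'])`; with the default strict=True the same checkpoint is rejected.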
state_dict(destination=None, prefix='', keep_vars=False)
Returns a dictionary containing the whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are the corresponding parameter and buffer names.
- Returns
a dictionary containing the whole state of the module
- Return type
dict
Example:
>>> module.state_dict().keys()
['bias', 'weight']
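Keys are formed by joining submodule names with dots, so a wrapper that registers the original model as a child named `module` (as DistributedDataParallel does) would normally prefix every key with `module.`. The plain-Python sketch below uses nested dicts as a stand-in for a module tree and a hypothetical helper, `flat_state_dict`, to show how the keys are built and why returning the wrapped module's own state_dict() yields unprefixed keys. That SingleDeviceDistributedParallel overrides state_dict() and load_state_dict() for this reason is an assumption; this page does not say so.

```python
from typing import Any, Dict


def flat_state_dict(tree: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Flatten a nested {name: child_or_value} tree into dotted keys."""
    out: Dict[str, Any] = {}
    for name, value in tree.items():
        key = prefix + name
        if isinstance(value, dict):  # a submodule: recurse with a longer prefix
            out.update(flat_state_dict(value, key + "."))
        else:  # a parameter or buffer
            out[key] = value
    return out


inner = {"weight": [[0.5]], "bias": [0.1]}
wrapper = {"module": inner}  # the model registered as a child named "module"

print(sorted(flat_state_dict(wrapper)))  # keys as seen through the wrapper
print(sorted(flat_state_dict(inner)))    # keys if the wrapper delegates
```

The first line prints `['module.bias', 'module.weight']` and the second `['bias', 'weight']`, the same keys as in the example above, so a checkpoint saved through a delegating wrapper can be loaded into the bare model.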