sketchgraphs_models.nn.distributed.SingleDeviceDistributedParallel
class sketchgraphs_models.nn.distributed.SingleDeviceDistributedParallel(module, device_id, find_unused_parameters=False)
This module implements a module similar to DistributedDataParallel, but it accepts inputs of any shape, and only supports a single device per instance.
__init__(module, device_id, find_unused_parameters=False)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
forward(*inputs, **kwargs)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
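The difference between calling the module instance and calling forward directly can be illustrated with a forward hook. This is a minimal sketch using a plain nn.Module; the Doubler class and the seen list are hypothetical names introduced only for illustration, and the example does not use SingleDeviceDistributedParallel itself.

```python
import torch
from torch import nn


class Doubler(nn.Module):
    """Toy module (hypothetical, for illustration only)."""

    def forward(self, x):
        return 2 * x


module = Doubler()
seen = []  # collects every output observed by the hook

# Register a forward hook; it is run by Module.__call__, not by forward().
handle = module.register_forward_hook(
    lambda mod, inputs, output: seen.append(output)
)

x = torch.ones(3)
module(x)          # goes through __call__, so the hook fires
module.forward(x)  # bypasses __call__, so the hook is skipped
handle.remove()

hook_calls = len(seen)  # only the first call was recorded
```

Only the call through the instance reaches the hook, which is why forward should not be invoked directly in normal use.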
load_state_dict(state_dict, strict=True)
Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
Parameters
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True
Returns
missing_keys is a list of str containing the missing keys
unexpected_keys is a list of str containing the unexpected keys
Return type
NamedTuple with missing_keys and unexpected_keys fields
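The effect of strict can be sketched as follows. The example uses two plain nn.Linear modules rather than SingleDeviceDistributedParallel, and the "extra" key is a hypothetical entry added only to trigger an unexpected-key report; the key names a wrapped module exposes may differ.

```python
import torch
from torch import nn

source = nn.Linear(4, 2)
target = nn.Linear(4, 2)

# Build a state dict that lacks "bias" and carries one unknown key.
partial = {k: v for k, v in source.state_dict().items() if k != "bias"}
partial["extra"] = torch.zeros(1)  # hypothetical key, not part of nn.Linear

# With strict=True, mismatched keys raise a RuntimeError.
try:
    target.load_state_dict(partial, strict=True)
    strict_failed = False
except RuntimeError:
    strict_failed = True

# With strict=False, mismatches are reported in the returned named tuple.
result = target.load_state_dict(partial, strict=False)
missing = list(result.missing_keys)
unexpected = list(result.unexpected_keys)
```

Inspecting missing_keys and unexpected_keys after a non-strict load is a reasonable way to confirm that only the intended entries were skipped.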
state_dict(destination=None, prefix='', keep_vars=False)
Returns a dictionary containing the whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are the corresponding parameter and buffer names.
Returns
a dictionary containing the whole state of the module
Return type
dict
Example:
>>> module.state_dict().keys()
['bias', 'weight']
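To show that persistent buffers appear alongside parameters, the sketch below inspects the state of an nn.BatchNorm1d layer, which keeps running averages as buffers. It uses a plain PyTorch module as a stand-in, so the key names of a wrapped SingleDeviceDistributedParallel instance are not implied by it.

```python
import torch
from torch import nn

module = nn.BatchNorm1d(3)

state = module.state_dict()
keys = list(state.keys())

# Parameters (weight, bias) and persistent buffers (running statistics)
# are all present in the same dictionary.
has_parameters = "weight" in state and "bias" in state
has_buffers = "running_mean" in state and "running_var" in state

# By default the returned tensors are detached from autograd;
# keep_vars=True returns the original parameters instead.
detached_by_default = not state["weight"].requires_grad
kept = module.state_dict(keep_vars=True)
tracks_grad_when_kept = kept["weight"].requires_grad
```

The default detached form is usually what one wants when saving a checkpoint, while keep_vars=True is mainly useful when the entries must remain connected to autograd.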