FullyShardedDataParallel
- class torch.distributed.fsdp.FullyShardedDataParallel(module, process_group=None, sharding_strategy=None, cpu_offload=None, auto_wrap_policy=None, backward_prefetch=None, mixed_precision=None, ignored_modules=None, param_init_fn=None, device_id=None, sync_module_states=False, forward_prefetch=False, limit_all_gathers=False)
A wrapper for sharding Module parameters across data parallel workers. It is inspired by Xu et al. as well as by ZeRO Stage 3 from DeepSpeed. FullyShardedDataParallel is commonly shortened to FSDP.
Example:
>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> # Assumes the default process group is already initialized (e.g. via
>>> # torch.distributed.init_process_group("nccl")) and that device_id and my_module exist.
>>> torch.cuda.set_device(device_id)
>>> sharded_module = FSDP(my_module)
>>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
>>> x = sharded_module(x, y=3, z=torch.Tensor([1]))
>>> loss = x.sum()
>>> loss.backward()
>>> optim.step()
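You can picture the sharding described above without a cluster. The stdlib-only sketch below models roughly what full sharding does to one flattened parameter. The parameter is padded so its length divides evenly by the world size, each rank keeps one contiguous chunk, and an all-gather reassembles the whole thing for computation. Python lists stand in for tensors and a list of shards stands in for the ranks. shard_flat_param, all_gather_flat_param, and persistent_bytes_per_rank are hypothetical helpers, not FSDP APIs. The memory estimate assumes fp32 parameters, fp32 gradients, and two fp32 Adam states (16 bytes per parameter), and it ignores activations and temporarily unsharded parameters.

```python
import math
from typing import List

Shard = List[float]


def shard_flat_param(flat_param: List[float], world_size: int) -> List[Shard]:
    """Pad the flat parameter to a multiple of world_size and split it into equal chunks."""
    shard_size = math.ceil(len(flat_param) / world_size)
    num_padding = shard_size * world_size - len(flat_param)
    padded = flat_param + [0.0] * num_padding
    return [padded[r * shard_size:(r + 1) * shard_size] for r in range(world_size)]


def all_gather_flat_param(shards: List[Shard], numel: int) -> List[float]:
    """Concatenate every rank's shard (a simulated all-gather) and drop the padding."""
    full = [x for shard in shards for x in shard]
    return full[:numel]


def persistent_bytes_per_rank(num_params: int, world_size: int, full_shard: bool) -> float:
    """Rough per-rank bytes for fp32 params + fp32 grads + two fp32 Adam states."""
    bytes_per_param = 4 + 4 + 2 * 4
    total = num_params * bytes_per_param
    # Full sharding splits params, grads, and optimizer state across ranks;
    # a replicated layout keeps a full copy of everything on every rank.
    return total / world_size if full_shard else float(total)


flat = [float(i) for i in range(10)]   # 10 elements spread over 4 ranks
shards = shard_flat_param(flat, world_size=4)
print([len(s) for s in shards])        # [3, 3, 3, 3]; the last shard holds 2 padding zeros
print(all_gather_flat_param(shards, numel=len(flat)) == flat)               # True
print(persistent_bytes_per_rank(1_000_000_000, 8, full_shard=True) / 1e9)   # 2.0 (GB)
print(persistent_bytes_per_rank(1_000_000_000, 8, full_shard=False) / 1e9)  # 16.0 (GB)
```

In this simplified accounting, sharding a 1B-parameter model over 8 ranks cuts its persistent training state from 16 GB to 2 GB per rank. The price is an all-gather before each sharded unit is used.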
Warning
The optimizer must be initialized after the module has been wrapped, since FSDP will shard parameters in-place and this will break any previously initialized optimizers.
Warning
If the destination CUDA device has ID dev_id, there are three options: (1) module should already be placed on that device, (2) the device should be set using torch.cuda.set_device(dev_id), or (3) dev_id should be passed into the device_id constructor argument. This FSDP instance’s compute device will be that destination device. For (1) and (3), the FSDP initialization always occurs on GPU. For (2), the FSDP initialization happens on module’s current device, which may be CPU.
Warning
FSDP currently does not support gradient accumulation outside no_sync() when using CPU offloading. Trying to do so yields incorrect results, since FSDP will use the newly reduced gradient instead of accumulating it with any existing gradient.
Warning
Changing the original parameter variable names after construction will lead to undefined behavior.
Warning
Passing in the sync_module_states=True flag requires module to be placed on GPU, or the device_id argument to be used to specify a CUDA device that FSDP will move module to. This is because sync_module_states=True requires GPU communication.
Warning
As of PyTorch 1.12, FSDP only offers limited support for shared parameters (for example, setting one Linear layer’s weight to another’s). In particular, modules that share parameters must be wrapped as part of the same FSDP unit. If enhanced shared parameter support is needed for your use case, please ping https://github.com/pytorch/pytorch/issues/77724
Note
Inputs to the FSDP forward function will be moved to the compute device (the same device the FSDP module is on) before running forward, so the user does not have to manually move inputs from CPU to GPU.
- Parameters:
module (nn.Module) – module to be wrapped with FSDP.
process_group (Optional[ProcessGroup]) – process group for sharding
sharding_strategy (Optional[ShardingStrategy]) – Configures the sharding algorithm. Different sharding algorithms trade off memory savings against communication overhead. FULL_SHARD will be chosen if sharding_strategy is not specified.
cpu_offload (Optional[CPUOffload]) – CPU offloading config. Currently, only parameter and gradient CPU offload is supported. It can be enabled by passing in cpu_offload=CPUOffload(offload_params=True). Note that this currently implicitly enables gradient offloading to CPU so that params and grads are on the same device, as the optimizer requires. This API is subject to change. Default is None, in which case there will be no offloading.
auto_wrap_policy (Optional[Callable[[nn.Module, bool, int], bool]]) –
A callable specifying a policy to recursively wrap layers with FSDP. Note that this policy currently only applies to child modules of the passed in module. The remaining modules are always wrapped in the returned FSDP root instance.
size_based_auto_wrap_policy, defined in torch.distributed.fsdp.wrap, is an example of an auto_wrap_policy callable; by default this policy wraps layers with at least 100M parameters. transformer_auto_wrap_policy, also defined in torch.distributed.fsdp.wrap, is an example of an auto_wrap_policy callable for transformer-like model architectures. Users can supply a customized auto_wrap_policy callable that accepts the following arguments: module: nn.Module, recurse: bool, and unwrapped_params: int. It returns a bool specifying whether the passed in module should be wrapped (if recurse=False) or whether we should recurse down the subgraph of module’s children (if recurse=True). Extra customized arguments can be added to the customized auto_wrap_policy callable as well. It is good practice to print out the sharded model, check whether it is what the application wants, and then adjust accordingly.
Example:
>>> import functools
>>> from torch import nn
>>> def custom_auto_wrap_policy(
...     module: nn.Module,
...     recurse: bool,
...     unwrapped_params: int,
...     # These are customizable for this policy function.
...     min_num_params: int = int(1e8),
... ) -> bool:
...     return unwrapped_params >= min_num_params
>>> # Configure a custom min_num_params
>>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
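The interplay of recurse=True and recurse=False is easier to follow on a concrete tree. The stdlib-only sketch below is a simplified model of how a policy shaped like custom_auto_wrap_policy could be applied. The walk descends into a module only if the policy allows it for recurse=True. After the children are processed, it wraps whatever parameters the children did not claim, if the policy allows it for recurse=False. FakeModule and simulate_auto_wrap are hypothetical stand-ins for nn.Module and FSDP's internal wrapping logic. Details such as ignored_modules are left out.

```python
import functools
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class FakeModule:
    """Hypothetical stand-in for nn.Module: its own parameter count plus children."""
    name: str
    own_params: int
    children: List["FakeModule"] = field(default_factory=list)

    def num_params(self) -> int:
        return self.own_params + sum(c.num_params() for c in self.children)


def custom_auto_wrap_policy(module, recurse: bool, unwrapped_params: int,
                            min_num_params: int = int(1e8)) -> bool:
    return unwrapped_params >= min_num_params


def simulate_auto_wrap(module: FakeModule, policy: Callable[..., bool],
                       wrapped: List[str], only_wrap_children: bool = False) -> int:
    """Return how many of module's params ended up in a wrapped unit; record wrapped names."""
    num_params = module.num_params()
    # recurse=True: is it worth descending into this subtree at all?
    if not policy(module=module, recurse=True, unwrapped_params=num_params):
        return 0
    total_wrapped = sum(simulate_auto_wrap(c, policy, wrapped) for c in module.children)
    # recurse=False: should the params not claimed by children get their own unit?
    remainder = num_params - total_wrapped
    if not only_wrap_children and policy(module=module, recurse=False, unwrapped_params=remainder):
        wrapped.append(module.name)
        return num_params
    return total_wrapped


model = FakeModule("model", 1_000, [
    FakeModule("embed", 200_000),
    FakeModule("block0", 0, [FakeModule("attn", 150_000), FakeModule("mlp", 300_000)]),
    FakeModule("head", 50_000),
])
policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
wrapped: List[str] = []
# The root's leftover params always stay in the root FSDP instance, so the root is not wrapped here.
simulate_auto_wrap(model, policy, wrapped, only_wrap_children=True)
print(wrapped)  # ['embed', 'attn', 'mlp']; 'head' falls back to the root unit
```

Printing the wrapped model after construction, as recommended above, is the real-world counterpart of inspecting the wrapped list here.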
backward_prefetch (Optional[BackwardPrefetch]) – This is an experimental feature that is subject to change in the near future. It allows users to enable one of two backward prefetch algorithms to help overlap backward communication and computation. The pros and cons of each algorithm are explained in the class BackwardPrefetch.
mixed_precision (Optional[MixedPrecision]) – A MixedPrecision instance describing the mixed precision training config to be used. MixedPrecision supports configuring the parameter, buffer, and gradient communication dtypes. Note that only floating point data is cast to the reduced precision. This can save memory and speed up training at the cost of some accuracy during model training. If None, no mixed precision is applied. If mixed_precision is enabled for an FSDP model that contains BatchNorm and an auto_wrap_policy is used, FSDP disables mixed precision for BatchNorm units by wrapping them separately in their own FSDP unit with mixed_precision=None. This is done because several BatchNorm kernels do not implement reduced type support at the moment. If individually wrapping the model, users must take care to set mixed_precision=None for BatchNorm units. (Default: None)
ignored_modules (Optional[Iterable[torch.nn.Module]]) – Modules whose own parameters and child modules’ parameters and buffers are ignored by this instance. None of the modules directly in ignored_modules should be FullyShardedDataParallel instances, and any child modules that are already-constructed FullyShardedDataParallel instances will not be ignored if they are nested under this instance. This argument may be used to avoid sharding specific parameters at module granularity when using an auto_wrap_policy or if parameters’ sharding is not managed by FSDP. (Default: None)
param_init_fn (Optional[Callable[[nn.Module], None]]) –
A Callable[torch.nn.Module] -> None that specifies how modules that are currently on the meta device should be initialized onto an actual device. Note that as of v1.12, we detect modules on the meta device via an is_meta check. If param_init_fn is not specified, we apply a default initialization that calls the reset_parameters method on the passed in nn.Module; otherwise, we run param_init_fn to initialize the passed in nn.Module. In particular, this means that if is_meta=True for any parameters of modules that will be wrapped with FSDP and param_init_fn is not specified, we assume your module properly implements reset_parameters() and will throw errors if not. We additionally support modules initialized with torchdistX’s (https://github.com/pytorch/torchdistX) deferred_init API. In this case, deferred modules are initialized either by a default initialization function that calls torchdistX’s materialize_module, or by the passed in param_init_fn if it is not None. The same Callable is applied to initialize all meta modules. Note that this initialization function is applied before any FSDP sharding logic.
Example:
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
>>> module = MyModule(device="meta")
>>> def my_init_fn(module):
...     # responsible for initializing a module, such as with reset_parameters
...     ...
>>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy)
>>> print(next(fsdp_model.parameters()).device)  # current CUDA device
>>> # With torchdistX
>>> from torchdistx import deferred_init
>>> module = deferred_init.deferred_init(MyModule, device="cuda")
>>> # Will initialize via deferred_init.materialize_module().
>>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy)
device_id (Optional[Union[int, torch.device]]) – An int or torch.device describing the CUDA device the FSDP module should be moved to. This determines where initialization such as sharding takes place. If this argument is not specified and module is on CPU, we issue a warning mentioning that this argument can be specified for faster initialization. If specified, resulting FSDP instances will reside on this device, including moving ignored modules’ parameters if needed. Note that if device_id is specified but module is already on a different CUDA device, an error will be thrown. (Default: None)
sync_module_states (bool) – If True, each individually wrapped FSDP unit will broadcast module parameters from rank 0 to ensure they are the same across all ranks after initialization. This helps ensure model parameters are the same across ranks before starting training, but adds communication overhead to __init__, as at least one broadcast is triggered per individually wrapped FSDP unit. It can also help load checkpoints taken by state_dict with load_state_dict in a memory-efficient way. See the documentation for FullStateDictConfig for an example of this. (Default: False)
forward_prefetch (bool) – If True, then FSDP explicitly prefetches the next upcoming all-gather while executing in the forward pass. This may improve communication and computation overlap for CPU-bound workloads. This should only be used for static graph models, since the forward order is fixed based on the first iteration’s execution. (Default: False)
limit_all_gathers (bool) – If False, then FSDP allows the CPU thread to schedule all-gathers without any extra synchronization. If True, then FSDP explicitly synchronizes the CPU thread to prevent too many in-flight all-gathers. This bool only affects the sharded strategies that schedule all-gathers. Enabling this can help lower the number of CUDA malloc retries.
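You can see the accuracy side of the mixed_precision trade-off without a GPU. The stdlib-only sketch below emulates float16 rounding with the struct module's 'e' (IEEE 754 half-precision) format. It does not use FSDP or MixedPrecision, and to_fp16 is a hypothetical helper. The accumulation loop only illustrates why summing many small values in a reduced dtype can lose them entirely.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


# Half precision has a 10-bit mantissa, so values near 1.0 are spaced 2**-10 apart.
print(to_fp16(1.0001))   # 1.0 (the small increment is rounded away)

# Accumulating many small updates: each fp16 addition rounds back to 1.0.
total_fp16, total_fp64 = 1.0, 1.0
for _ in range(1000):
    total_fp16 = to_fp16(total_fp16 + 0.0001)
    total_fp64 += 0.0001
print(total_fp16, round(total_fp64, 6))   # 1.0 1.1
```

This is one reason why the dtype used for gradient communication deserves as much care as the parameter dtype.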
- apply(fn)
Applies fn recursively to every submodule (as returned by .children()) as well as self. Typical use includes initializing the parameters of a model (see also torch.nn.init).
Compared to torch.nn.Module.apply, this version additionally gathers the full parameters before applying fn. It should not be called from within another summon_full_params context.
- Parameters:
fn (Module -> None) – function to be applied to each submodule
- Returns:
self
- Return type:
- Module
- clip_grad_norm_(max_norm, norm_type=2.0)
Clip all gradients at this point in time. The norm is computed over all gradients together, as if they were concatenated into a single vector. Gradients are modified in-place.
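- Parameters:
max_norm (float or int) – max norm of the gradients
norm_type (float or int) – type of the used p-norm. Can be 'inf' for infinity norm.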
- Returns:
Total norm of the parameters (viewed as a single vector).
- Return type:
None
Note
This is analogous to torch.nn.utils.clip_grad_norm_ but handles the partitioning and multiple devices per rank under the hood. The default torch util is not applicable here, because each rank only has a partial view of all the grads in the model, so calling it for FSDP models would lead to different scaling being applied per subset of model parameters.
Warning
This needs to be called on all ranks, since synchronization primitives will be used.
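The note above explains why clipping each rank separately is wrong. The stdlib-only sketch below simulates the correct computation across ranks. Each rank contributes the sum of |g|^p over its local gradient shard, and these contributions are summed (standing in for an all-reduce; for the infinity norm, a max is taken instead). Every rank then applies the same clip coefficient. Lists of floats stand in for gradient shards, and sharded_clip_grad_norm_ is a hypothetical helper rather than FSDP code. The max_norm / (total_norm + 1e-6) coefficient mirrors the one used by torch.nn.utils.clip_grad_norm_.

```python
import math
from typing import List


def local_contribution(grad_shard: List[float], norm_type: float) -> float:
    """What one rank can compute from its own gradient shard."""
    if norm_type == math.inf:
        return max((abs(g) for g in grad_shard), default=0.0)
    return sum(abs(g) ** norm_type for g in grad_shard)


def sharded_clip_grad_norm_(rank_shards: List[List[float]], max_norm: float,
                            norm_type: float = 2.0) -> float:
    """Clip gradients held across ranks as if they formed one vector; return the total norm."""
    contributions = [local_contribution(s, norm_type) for s in rank_shards]
    if norm_type == math.inf:
        total_norm = max(contributions)                       # simulated all-reduce (MAX)
    else:
        total_norm = sum(contributions) ** (1.0 / norm_type)  # simulated all-reduce (SUM)
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        # Every rank scales by the same coefficient, preserving the overall gradient direction.
        for shard in rank_shards:
            for i, g in enumerate(shard):
                shard[i] = g * clip_coef
    return total_norm


grads = [[3.0], [4.0]]   # rank 0 holds 3.0, rank 1 holds 4.0
print(sharded_clip_grad_norm_(grads, max_norm=1.0))   # 5.0
print(grads)             # approximately [[0.6], [0.8]]
```

If each rank clipped its own shard to norm 1.0, both entries would become 1.0, which changes the gradient's direction. Consistent clipping requires a norm computed jointly across ranks.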
- static flatten_sharded_optim_state_dict(sharded_optim_state_dict, model, optim_input=None, optim=None)
The API is similar to shard_full_optim_state_dict(). The only difference is that the input sharded_optim_state_dict should be returned from sharded_optim_state_dict(). Therefore, there will be all-gather calls on each rank to gather ShardedTensors.
- Parameters:
sharded_optim_state_dict (Dict[str, Any]) – Optimizer state dict corresponding to the unflattened parameters and holding the sharded optimizer state.
model (torch.nn.Module) – Refer to shard_full_optim_state_dict().
- Returns:
Refer to shard_full_optim_state_dict().
- Return type:
Dict[str, Any]
- forward(*args, **kwargs)
Runs the forward pass for the wrapped module, inserting FSDP-specific pre- and post-forward sharding logic.
- Return type:
- static fsdp_modules(module, root_only=False)
Returns all nested FSDP instances, possibly including module itself, and only including FSDP root modules if root_only=True.
- Parameters:
module (torch.nn.Module) – Root module, which may or may not be an FSDP module.
root_only (bool) – Whether to return only FSDP root modules. (Default: False)
- Returns:
FSDP modules that are nested in the input module.
- Return type:
List[FullyShardedDataParallel]
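The root_only distinction can be pictured as a tree walk. Every FSDP instance in module is collected, including module itself, and a root is an FSDP instance with no FSDP ancestor. The stdlib-only sketch below models this with a hypothetical Node class that carries an is_fsdp flag. It is a simplified illustration, not the library's implementation.

```python
from dataclasses import dataclass, field
from typing import Iterator, List, Tuple


@dataclass
class Node:
    """Hypothetical module stand-in: a name, an FSDP flag, and children."""
    name: str
    is_fsdp: bool = False
    children: List["Node"] = field(default_factory=list)


def walk(node: Node, inside_fsdp: bool = False) -> Iterator[Tuple[Node, bool]]:
    """Pre-order traversal (like Module.modules()) yielding (node, has_fsdp_ancestor)."""
    yield node, inside_fsdp
    for child in node.children:
        yield from walk(child, inside_fsdp or node.is_fsdp)


def fsdp_modules_sketch(module: Node, root_only: bool = False) -> List[str]:
    return [n.name for n, has_fsdp_ancestor in walk(module)
            if n.is_fsdp and not (root_only and has_fsdp_ancestor)]


# A plain container holding two independent FSDP units, one of which nests another.
model = Node("model", children=[
    Node("encoder", is_fsdp=True, children=[Node("layer0", is_fsdp=True)]),
    Node("decoder", is_fsdp=True),
])
print(fsdp_modules_sketch(model))                  # ['encoder', 'layer0', 'decoder']
print(fsdp_modules_sketch(model, root_only=True))  # ['encoder', 'decoder']
```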
- static full_optim_state_dict(model, optim, optim_input=None, rank0_only=True, group=None)
Consolidates the full optimizer state on rank 0 and returns it as a dict following the convention of torch.optim.Optimizer.state_dict(), i.e. with keys "state" and "param_groups". The flattened parameters in FSDP modules contained in model are mapped back to their unflattened parameters.
Warning
This needs to be called on all ranks, since synchronization primitives are used. However, if rank0_only=True, then the state dict is only populated on rank 0, and all other ranks return an empty dict.
Warning
Unlike torch.optim.Optimizer.state_dict(), this method uses full parameter names as keys instead of parameter IDs.
Note
As in torch.optim.Optimizer.state_dict(), the tensors contained in the optimizer state dict are not cloned, so there may be aliasing surprises. For best practices, consider saving the returned optimizer state dict immediately, e.g. using torch.save().
- Parameters:
model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters were passed into the optimizer optim.
optim (torch.optim.Optimizer) – Optimizer for model’s parameters.
optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – Input passed into the optimizer optim representing either a list of parameter groups or an iterable of parameters; if None, then this method assumes the input was model.parameters(). This argument is deprecated, and there is no need to pass it in anymore. (Default: None)
rank0_only (bool) – If True, saves the populated dict only on rank 0; if False, saves it on all ranks. (Default: True)
group (dist.ProcessGroup) – Model’s process group or None if using the default process group. (Default: None)
- Returns:
A dict containing the optimizer state for model’s original unflattened parameters and including keys “state” and “param_groups” following the convention of torch.optim.Optimizer.state_dict(). If rank0_only=True, then nonzero ranks return an empty dict.
- Return type:
Dict[str, Any]
- load_state_dict(state_dict, *args, **kwargs)
The entry point of all three FSDP load_state_dict APIs. By default, calling load_state_dict on an FSDP module will result in FSDP attempting to load a “full” state_dict, i.e. a state_dict consisting of full, unsharded, unflattened original module parameters. This requires FSDP to load the full parameter context on each rank, which could result in GPU OOM. As a result, the state_dict_type() API is available to choose between load_state_dict implementations. Users can thus use the with self.state_dict_type(self, StateDictType.LOCAL_STATE_DICT) context manager to load a local state dict checkpoint that will restore only local shards of the module. Currently, the only supported implementations are StateDictType.LOCAL_STATE_DICT and StateDictType.FULL_STATE_DICT (default). Please see state_dict() for documentation on creating an FSDP checkpoint.
Example:
>>> import torch
>>> from torch import nn
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp import StateDictType
>>> torch.cuda.set_device(device_id)
>>> my_module = nn.Linear(...)
>>> sharded_module = FSDP(my_module)
>>> checkpoint = torch.load(PATH)
>>> full_state_dict = checkpoint['full_state_dict']
>>> with FSDP.state_dict_type(sharded_module, StateDictType.FULL_STATE_DICT):
...     sharded_module.load_state_dict(full_state_dict)
>>> full_state_dict.keys()
odict_keys(['weight', 'bias'])
>>> # using local state dict
>>> local_state_dict = checkpoint['local_state_dict']
>>> with FSDP.state_dict_type(sharded_module, StateDictType.LOCAL_STATE_DICT):
...     sharded_module.load_state_dict(local_state_dict)
>>> local_state_dict.keys()
odict_keys(['flat_param', 'inner.flat_param'])
Warning
This needs to be called on all ranks, since synchronization primitives may be used.
- Return type:
- named_buffers(*args, **kwargs)
Overrides named_buffers() to intercept buffer names and remove all occurrences of the FSDP-specific flattened buffer prefix when inside the summon_full_params() context manager.
- named_parameters(*args, **kwargs)
Overrides named_parameters() to intercept parameter names and remove all occurrences of the FSDP-specific flattened parameter prefix when inside the summon_full_params() context manager.
- no_sync()
A context manager to disable gradient synchronizations across FSDP instances. Within this context, gradients will be accumulated in module variables, which will later be synchronized in the first forward-backward pass after exiting the context. This should only be used on the root FSDP instance and will recursively apply to all children FSDP instances.
Note
This likely results in higher memory usage because FSDP will accumulate the full model gradients (instead of gradient shards) until the eventual sync.
Note
When used with CPU offloading, the gradients will not be offloaded to CPU when inside the context manager. Instead, they will only be offloaded right after the eventual sync.
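Skipping synchronization inside no_sync() is safe because gradient reduction is linear. Reducing once after accumulating locally gives the same result as reducing after every micro-batch, with fewer collectives. The stdlib-only sketch below checks this arithmetic. Plain numbers stand in for per-rank gradients, and reduce_mean is a hypothetical stand-in for the cross-rank reduction. It models the arithmetic only, not FSDP's implementation. In real code, the usual pattern runs the backward passes of all but the last micro-batch inside with model.no_sync():, runs the last one outside it, and then calls optimizer.step().

```python
from typing import List

# grads[rank][micro_batch]: hypothetical scalar gradient contributions.
grads = [
    [1.0, 2.0, 3.0],   # rank 0
    [5.0, 6.0, 8.0],   # rank 1
]
world_size = len(grads)
num_micro_batches = len(grads[0])


def reduce_mean(values: List[float]) -> float:
    """Stand-in for the cross-rank gradient reduction (sum, then divide by world size)."""
    return sum(values) / len(values)


# Option A: synchronize after every micro-batch, then accumulate the reduced results.
synced_every_step = 0.0
collectives_a = 0
for m in range(num_micro_batches):
    synced_every_step += reduce_mean([grads[r][m] for r in range(world_size)])
    collectives_a += 1

# Option B (the no_sync pattern): accumulate locally, synchronize once at the end.
local_accum = [sum(grads[r]) for r in range(world_size)]
synced_once = reduce_mean(local_accum)
collectives_b = 1

print(synced_every_step, synced_once)   # 12.5 12.5
print(collectives_a, collectives_b)     # 3 1
```

The saved collectives are paid for in memory: Option B keeps the full local accumulator until the end, which matches the higher memory usage described in the notes.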
- Return type:
- property params_with_grad: List[Parameter]
Recursively returns a list of all module parameters that have a gradient.
- register_comm_hook(state, hook)
Registers a communication hook, which gives users a flexible way to specify how FSDP aggregates gradients across multiple workers. This hook can be used to implement several algorithms, such as GossipGrad and gradient compression, that involve different communication strategies for parameter syncs while training with FullyShardedDataParallel.
Warning
FSDP communication hook should be registered before running an initial forward pass and only once.
- Parameters:
state (object) –
Passed to the hook to maintain any state information during the training process. Examples include error feedback in gradient compression, peers to communicate with next in GossipGrad, etc. It is locally stored by each worker and shared by all the gradient tensors on the worker.
hook (Callable) – Callable, which has one of the following signatures:
1) hook: Callable[torch.Tensor] -> None: This function takes in a Python tensor, which represents the full, flattened, unsharded gradient with respect to all variables corresponding to the model this FSDP unit is wrapping (that are not wrapped by other FSDP sub-units). It then performs all necessary processing and returns None.
2) hook: Callable[torch.Tensor, torch.Tensor] -> None: This function takes in two Python tensors. The first one represents the full, flattened, unsharded gradient with respect to all variables corresponding to the model this FSDP unit is wrapping (that are not wrapped by other FSDP sub-units). The second one represents a pre-sized tensor to store a chunk of a sharded gradient after reduction.
In both cases, the callable also receives state as its first argument, performs all necessary processing, and returns None. Callables with signature 1 are expected to handle gradient communication for the NO_SHARD case. Callables with signature 2 are expected to handle gradient communication for sharded cases.
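The state argument is what lets a hook carry information from one call to the next. The stdlib-only sketch below illustrates the error-feedback idea mentioned above with a signature-1-shaped hook that receives state first. The hook quantizes the gradient before it would be communicated, stores the quantization error in state, and adds that error back on the next call. Lists of floats stand in for tensors, ErrorFeedbackState and quantizing_hook are hypothetical names, and the communication step itself is omitted. A real FSDP hook would operate on torch.Tensor and perform the collective.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ErrorFeedbackState:
    """Hypothetical per-worker hook state: quantization step and carried-over error."""
    step: float
    residual: List[float] = field(default_factory=list)


def quantizing_hook(state: ErrorFeedbackState, grad: List[float]) -> None:
    """Quantize grad in place, remembering what was rounded away for the next call."""
    if not state.residual:
        state.residual = [0.0] * len(grad)
    for i, g in enumerate(grad):
        corrected = g + state.residual[i]                       # add back last call's error
        quantized = round(corrected / state.step) * state.step  # value that would be sent
        state.residual[i] = corrected - quantized
        grad[i] = quantized
    # A real hook would now all-reduce (NO_SHARD) the quantized gradient across workers.


state = ErrorFeedbackState(step=1.0)
sent = []
for _ in range(4):
    grad = [0.3]
    quantizing_hook(state, grad)
    sent.append(grad[0])
print(sent)   # [0.0, 1.0, 0.0, 0.0]
```

Nothing is lost permanently. The values sent so far plus the residual held in state add up to the true total of 1.2, up to float rounding.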
- static rekey_optim_state_dict(optim_state_dict, optim_state_key_type, model, optim_input=None, optim=None)
Re-keys the optimizer state dict optim_state_dict to use the key type optim_state_key_type. This can be used to achieve compatibility between optimizer state dicts from models with FSDP instances and ones without.
To re-key an FSDP full optimizer state dict (i.e. from full_optim_state_dict()) to use parameter IDs and be loadable to a non-wrapped model:
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, OptimStateKeyType
>>> wrapped_model, wrapped_optim = ...
>>> full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim)
>>> nonwrapped_model, nonwrapped_optim = ...
>>> rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model)
>>> nonwrapped_optim.load_state_dict(rekeyed_osd)
To re-key a normal optimizer state dict from a non-wrapped model to be loadable to a wrapped model:
>>> nonwrapped_model, nonwrapped_optim = ...
>>> osd = nonwrapped_optim.state_dict()
>>> rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model)
>>> wrapped_model, wrapped_optim = ...
>>> sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model)
>>> wrapped_optim.load_state_dict(sharded_osd)
- Returns:
The optimizer state dict re-keyed using the parameter keys specified by optim_state_key_type.
- Return type:
Dict[str, Any]
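Conceptually, re-keying swaps two sets of keys: the keys under "state", and the entries of each param group's "params" list. It converts them between integer parameter IDs and fully qualified parameter names. Integer IDs are the convention of torch.optim.Optimizer.state_dict(), which numbers parameters consecutively in the order they appear across param groups. The stdlib-only sketch below shows that mapping for a plain, unflattened optimizer state dict. ids_to_names and names_to_ids are hypothetical helpers, and the parameter names and state values are made up. The real method derives the ID-to-name mapping from model rather than from a hand-written list.

```python
from typing import Any, Dict, List


def ids_to_names(osd: Dict[str, Any], param_names: List[str]) -> Dict[str, Any]:
    """Re-key an ID-keyed optimizer state dict by parameter name.

    param_names[i] is the name of the parameter with ID i (optimizer order).
    """
    return {
        "state": {param_names[pid]: s for pid, s in osd["state"].items()},
        "param_groups": [
            {**group, "params": [param_names[pid] for pid in group["params"]]}
            for group in osd["param_groups"]
        ],
    }


def names_to_ids(osd: Dict[str, Any], param_names: List[str]) -> Dict[str, Any]:
    """Inverse of ids_to_names."""
    name_to_id = {name: i for i, name in enumerate(param_names)}
    return {
        "state": {name_to_id[name]: s for name, s in osd["state"].items()},
        "param_groups": [
            {**group, "params": [name_to_id[n] for n in group["params"]]}
            for group in osd["param_groups"]
        ],
    }


# Hypothetical two-parameter model; names listed in the order the optimizer saw them.
names = ["layer.weight", "layer.bias"]
id_keyed = {
    "state": {0: {"step": 3}, 1: {"step": 3}},
    "param_groups": [{"lr": 0.01, "params": [0, 1]}],
}
name_keyed = ids_to_names(id_keyed, names)
print(name_keyed["param_groups"])   # [{'lr': 0.01, 'params': ['layer.weight', 'layer.bias']}]
print(names_to_ids(name_keyed, names) == id_keyed)   # True
```

Name keys survive changes in how parameters are grouped or flattened, which makes them a stable interchange format between wrapped and non-wrapped models.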
- static scatter_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None, group=None)
Scatters the full optimizer state dict from rank 0 to all other ranks, returning the sharded optimizer state dict on each rank. The return value is the same as that of shard_full_optim_state_dict(), and on rank 0, the first argument should be the return value of full_optim_state_dict().
Example:
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> model, optim = ...
>>> full_osd = FSDP.full_optim_state_dict(model, optim)  # only non-empty on rank 0
>>> # Define new model with possibly different world size
>>> new_model, new_optim, new_group = ...
>>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group)
>>> new_optim.load_state_dict(sharded_osd)
Note
Both shard_full_optim_state_dict() and scatter_full_optim_state_dict() may be used to get the sharded optimizer state dict to load. Assuming that the full optimizer state dict resides in CPU memory, the former requires each rank to have the full dict in CPU memory, where each rank individually shards the dict without any communication. The latter requires only rank 0 to have the full dict in CPU memory, where rank 0 moves each shard to GPU memory (for NCCL) and communicates it to ranks appropriately. Hence, the former has higher aggregate CPU memory cost, while the latter has higher communication cost.
- Parameters:
full_optim_state_dict (Optional[Dict[str, Any]]) – Optimizer state dict corresponding to the unflattened parameters and holding the full non-sharded optimizer state if on rank 0; the argument is ignored on nonzero ranks.
model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters correspond to the optimizer state in full_optim_state_dict.
optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – Input passed into the optimizer representing either a list of parameter groups or an iterable of parameters; if None, then this method assumes the input was model.parameters(). This argument is deprecated, and there is no need to pass it in anymore. (Default: None)
optim (Optional[torch.optim.Optimizer]) – Optimizer that will load the state dict returned by this method. This is the preferred argument to use over optim_input. (Default: None)
group (dist.ProcessGroup) – Model’s process group or None if using the default process group. (Default: None)
- Returns:
The full optimizer state dict now remapped to flattened parameters instead of unflattened parameters and restricted to only include this rank’s part of the optimizer state.
- Return type:
Dict[str, Any]
- static shard_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None)
Shards the full optimizer state dict full_optim_state_dict by remapping the state to flattened parameters instead of unflattened parameters and restricting it to only this rank’s part of the optimizer state. The first argument should be the return value of full_optim_state_dict().
Example:
>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> model, optim = ...
>>> full_osd = FSDP.full_optim_state_dict(model, optim)
>>> torch.save(full_osd, PATH)
>>> # Define new model with possibly different world size
>>> new_model, new_optim = ...
>>> full_osd = torch.load(PATH)
>>> sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model)
>>> new_optim.load_state_dict(sharded_osd)
Note
Both shard_full_optim_state_dict() and scatter_full_optim_state_dict() may be used to get the sharded optimizer state dict to load. Assuming that the full optimizer state dict resides in CPU memory, the former requires each rank to have the full dict in CPU memory, where each rank individually shards the dict without any communication. The latter requires only rank 0 to have the full dict in CPU memory, where rank 0 moves each shard to GPU memory (for NCCL) and communicates it to ranks appropriately. Hence, the former has higher aggregate CPU memory cost, while the latter has higher communication cost.
- Parameters:
full_optim_state_dict (Dict[str, Any]) – Optimizer state dict corresponding to the unflattened parameters and holding the full non-sharded optimizer state.
model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters correspond to the optimizer state in full_optim_state_dict.
optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – Input passed into the optimizer representing either a list of parameter groups or an iterable of parameters; if None, then this method assumes the input was model.parameters(). This argument is deprecated, and there is no need to pass it in anymore. (Default: None)
optim (Optional[torch.optim.Optimizer]) – Optimizer that will load the state dict returned by this method. This is the preferred argument to use over optim_input. (Default: None)
- Returns:
The full optimizer state dict now remapped to flattened parameters instead of unflattened parameters and restricted to only include this rank’s part of the optimizer state.
- Return type:
Dict[str, Any]
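Remapping the state to flattened parameters and restricting it to this rank's part can be sketched with plain lists. For the parameters of one FSDP unit, each per-parameter tensor state (such as an Adam moment) is concatenated in flat-parameter order, padded like the flat parameter, and sliced to the rank's chunk. Scalar states such as a step count are kept whole. This is a simplified, stdlib-only model under those assumptions. shard_unit_optim_state is hypothetical, and the state names exp_avg and step are illustrative. The real method works on torch.Tensor and handles multiple FSDP units and param groups.

```python
import math
from typing import Dict, List, Union

State = Dict[str, Union[float, List[float]]]


def shard_unit_optim_state(per_param_state: List[State], rank: int, world_size: int) -> State:
    """Flatten one FSDP unit's per-parameter optimizer state and keep this rank's chunk."""
    sharded: State = {}
    for key in per_param_state[0]:
        values = [s[key] for s in per_param_state]
        if not isinstance(values[0], list):
            # Scalar state (e.g. a step count): assumed identical across the unit, kept whole.
            assert all(v == values[0] for v in values)
            sharded[key] = values[0]
            continue
        flat = [x for v in values for x in v]                  # concatenate in flat-param order
        shard_size = math.ceil(len(flat) / world_size)
        flat += [0.0] * (shard_size * world_size - len(flat))  # pad like the flat parameter
        sharded[key] = flat[rank * shard_size:(rank + 1) * shard_size]
    return sharded


# Hypothetical unit with a 2x2 weight (flattened to 4 values) and a 1-element bias.
full_state = [
    {"step": 5.0, "exp_avg": [0.1, 0.2, 0.3, 0.4]},   # weight
    {"step": 5.0, "exp_avg": [0.5]},                  # bias
]
print(shard_unit_optim_state(full_state, rank=0, world_size=2))  # {'step': 5.0, 'exp_avg': [0.1, 0.2, 0.3]}
print(shard_unit_optim_state(full_state, rank=1, world_size=2))  # {'step': 5.0, 'exp_avg': [0.4, 0.5, 0.0]}
```

In this model, each rank can run the sharding locally because only arithmetic on the full dict is involved. That is why the note above describes the shard-based approach as communication-free but memory-hungry.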
- static sharded_optim_state_dict(model, optim, optim_input=None, group=None)
The API is similar to full_optim_state_dict(), but this API chunks all non-zero-dimension states to ShardedTensor to save memory. This API should only be used when the model state_dict is derived inside the context manager with state_dict_type(SHARDED_STATE_DICT).
For the detailed usage, refer to full_optim_state_dict().
Warning
The returned state dict contains ShardedTensor and cannot be directly used by the regular optim.load_state_dict.
- state_dict(*args, **kwargs)
This is the entry point of all three FSDP state_dict APIs: full, local, and sharded. For the full state dict (StateDictType.FULL_STATE_DICT), FSDP attempts to unshard the model on all ranks, which may result in an OOM error if the full model cannot fit on a single GPU. In that case, users may pass in a FullStateDictConfig to only save the checkpoint on rank 0 and/or to offload it to CPU memory layer by layer, enabling much larger checkpoints. If the full model cannot fit in CPU memory, then users may instead take a local state dict (StateDictType.LOCAL_STATE_DICT) that only saves the local shard of the model. The sharded state dict (StateDictType.SHARDED_STATE_DICT) saves the model parameters as ShardedTensors. The state_dict type can be configured using the state_dict_type() context manager.
Example:
>>> import torch
>>> from torch import nn
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp import FullStateDictConfig, StateDictType
>>> torch.cuda.set_device(device_id)
>>> my_module = nn.Linear(...)
>>> sharded_module = FSDP(my_module)
>>> full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
>>> with FSDP.state_dict_type(sharded_module, StateDictType.FULL_STATE_DICT, full_state_dict_config):
...     full_dict = sharded_module.state_dict()
>>> full_dict.keys()
odict_keys(['weight', 'bias'])
>>> # using local state dict
>>> with FSDP.state_dict_type(sharded_module, StateDictType.LOCAL_STATE_DICT):
...     local_dict = sharded_module.state_dict()
>>> local_dict.keys()
odict_keys(['flat_param', 'inner.flat_param'])
Warning
This needs to be called on all ranks, since synchronization primitives may be used.
- static state_dict_type(module, state_dict_type, state_dict_config=None)
A context manager to set the state_dict_type of all the descendant FSDP modules of the target module. The target module does not have to be an FSDP module. If the target module is an FSDP module, its state_dict_type will also be changed.
Note
This API should be called for only the top-level (root) module.
Note
This API enables users to transparently use the conventional state_dict API to take model checkpoints in cases where the root FSDP module is wrapped by another nn.Module. For example, the following will ensure state_dict is called on all non-FSDP instances, while dispatching into the local_state_dict implementation for FSDP:
Example:
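>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
>>> model = DDP(FSDP(...))
>>> with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
...     checkpoint = model.state_dict()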
- Parameters:
module (torch.nn.Module) – Root module.
state_dict_type (StateDictType) – the desired state_dict_type to set.
- Return type:
- static summon_full_params(module, recurse=True, writeback=True, rank0_only=False, offload_to_cpu=False)
A context manager to expose full params for FSDP instances. This can be useful after forward/backward for a model to get the params for additional processing or checking. It can take a non-FSDP module and will summon full params for all contained FSDP modules as well as their children, depending on the recurse argument.
Note
This can be used on inner FSDPs.
Note
This cannot be used within a forward or backward pass, nor can forward and backward be started from within this context.
Note
Parameters will revert to their local shards after the context manager exits; storage behavior is the same as in forward.
Note
The full parameters can be modified, but only the portion corresponding to the local param shard will persist after the context manager exits (unless writeback=False, in which case changes will be discarded). In the case where FSDP does not shard the parameters, which is currently only when world_size == 1 or with the NO_SHARD config, the modification is persisted regardless of writeback.
Note
This method works on modules which are not FSDP themselves but may contain multiple independent FSDP units. In that case, the given arguments will apply to all contained FSDP units.
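The writeback rule in the notes above can be modeled with plain lists. Inside the context, each rank sees the concatenation of all shards (a simulated all-gather). On exit with writeback=True, it copies back only the slice that corresponds to its own shard. The sketch below is a stdlib-only illustration, not FSDP code. simulate_summon is a hypothetical helper, and padding is ignored by assuming equally sized shards.

```python
from typing import Callable, List


def simulate_summon(shards: List[List[float]], rank: int,
                    modify: Callable[[List[float]], None], writeback: bool = True) -> None:
    """One rank's view: gather the full param, let user code modify it, write back its shard."""
    full = [x for shard in shards for x in shard]   # simulated all-gather
    modify(full)                                     # user code inside the context
    if writeback:
        shard_size = len(shards[rank])
        start = rank * shard_size
        shards[rank][:] = full[start:start + shard_size]  # only the local portion persists
    # With writeback=False, the modified full copy is simply discarded.


def zero_everything(full: List[float]) -> None:
    for i in range(len(full)):
        full[i] = 0.0


shards = [[1.0, 2.0], [3.0, 4.0]]   # rank 0 and rank 1 shards of one flat parameter
simulate_summon(shards, rank=0, modify=zero_everything)
print(shards)   # [[0.0, 0.0], [3.0, 4.0]]: rank 0's edit to rank 1's region did not persist
simulate_summon(shards, rank=1, modify=zero_everything)
print(shards)   # [[0.0, 0.0], [0.0, 0.0]]
```

An edit persists everywhere only if every rank makes it. This is also why rank0_only=True cannot be combined with writeback=True: the other ranks would have nothing consistent to write back.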
Warning
Note that rank0_only=True in conjunction with writeback=True is not currently supported and will raise an error. This is because model parameter shapes would be different across ranks within the context, and writing to them can lead to inconsistency across ranks when the context is exited.
Warning
Note that using offload_to_cpu with rank0_only=False will result in full parameters being redundantly copied to CPU memory for GPUs that reside on the same machine, which may incur the risk of CPU OOM. It is recommended to use offload_to_cpu with rank0_only=True.
- Parameters:
recurse (bool, Optional) – recursively summon all params for nested FSDP instances (default: True).
writeback (bool, Optional) – if False, modifications to params are discarded after the context manager exits; disabling this can be slightly more efficient (default: True)
rank0_only (bool, Optional) – if True, full parameters are materialized on only global rank 0. This means that within the context, only rank 0 will have full parameters and the other ranks will have sharded parameters. Note that setting rank0_only=True with writeback=True is not supported, as model parameter shapes will be different across ranks within the context, and writing to them can lead to inconsistency across ranks when the context is exited.
offload_to_cpu (bool, Optional) – If True, full parameters are offloaded to CPU. Note that this offloading currently only occurs if the parameter is sharded (which is the case except for world_size = 1 or the NO_SHARD config). It is recommended to use offload_to_cpu with rank0_only=True to avoid redundant copies of model parameters being offloaded to the same CPU memory.
- Return type: