Shortcuts

Source code for torch.distributed.checkpoint.default_planner

# Copyright (c) Meta Platforms, Inc. and affiliates

import dataclasses
import io
import logging
import operator
from collections import ChainMap
from functools import reduce
from typing import Any, cast, Dict, List, Optional, Tuple, Union

import torch
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint._dedup_save_plans import dedup_save_plans
from torch.distributed.checkpoint._nested_dict import (
    FLATTEN_MAPPING,
    flatten_state_dict,
)
from torch.distributed.checkpoint._sharded_tensor_utils import _flatten_sharded_tensors
from torch.distributed.checkpoint._traverse import set_element
from torch.distributed.checkpoint.metadata import (
    BytesStorageMetadata,
    ChunkStorageMetadata,
    Metadata,
    MetadataIndex,
    STATE_DICT_TYPE,
    STORAGE_TYPES,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import (
    LoadPlan,
    LoadPlanner,
    ReadItem,
    SavePlan,
    SavePlanner,
    WriteItem,
    WriteItemType,
)
from torch.distributed.checkpoint.planner_helpers import (
    _create_default_metadata_only_plan,
    _create_read_items,
    _create_write_items,
    _init_state_dict,
)
from torch.distributed.checkpoint.utils import find_state_dict_object

# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)


# Names exported by ``from <this module> import *``: the two default planner
# classes and the four ``create_default_*`` plan-construction functions.
__all__ = [
    "DefaultSavePlanner",
    "DefaultLoadPlanner",
    "create_default_local_load_plan",
    "create_default_global_load_plan",
    "create_default_local_save_plan",
    "create_default_global_save_plan",
]


# TODO: Update docstrings for default_planner.py
[docs]class DefaultSavePlanner(SavePlanner): mappings: FLATTEN_MAPPING def __init__( self, flatten_state_dict: bool = True, flatten_sharded_tensors: bool = True, dedup_replicated_tensors: Optional[bool] = None, ) -> None: self.flatten_state_dict = flatten_state_dict self.flatten_sharded_tensors = flatten_sharded_tensors self.mappings = {} if dedup_replicated_tensors is not None: logger.warning( "DefaultSavePlanner's `dedup_replicated_tensors` argument is being " "deprecated, and no longer has any effect. Please remove this argument " "from your call." ) def set_up_planner(self, state_dict: STATE_DICT_TYPE, is_coordinator: bool) -> None: if self.flatten_state_dict: state_dict, self.mappings = flatten_state_dict(state_dict) if self.flatten_sharded_tensors: state_dict = _flatten_sharded_tensors(state_dict) self.state_dict = state_dict self.is_coordinator = is_coordinator def create_local_plan(self) -> SavePlan: plan = create_default_local_save_plan(self.state_dict, self.is_coordinator) if self.flatten_state_dict: plan = dataclasses.replace(plan, planner_data=self.mappings) self.plan = plan return self.plan def create_global_plan( self, all_plans: List[SavePlan] ) -> Tuple[List[SavePlan], Metadata]: all_plans = dedup_save_plans(all_plans) global_plan, metadata = create_default_global_save_plan(all_plans) if self.flatten_state_dict: # | does not work for Python 3.8 or older version. 
# merged_mappings = reduce( # lambda x, y: x | y, (p.planner_data for p in global_plan) # ) planner_data_dict = [p.planner_data for p in global_plan] merged_mappings = dict(ChainMap(*planner_data_dict)) metadata = dataclasses.replace(metadata, planner_data=merged_mappings) if not _validate_global_plan(global_plan, metadata): raise ValueError("Failed to validate global plan") self.global_plan = global_plan self.metadata = metadata return self.global_plan, self.metadata def finish_plan(self, new_plan: SavePlan) -> SavePlan: self.plan = new_plan return new_plan def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]: object = self.lookup_object(write_item.index) return self.transform_object(write_item, object)
[docs] def lookup_object(self, index: MetadataIndex) -> Any: """Extension from the planner interface to make it easy to extend the default planner.""" return find_state_dict_object(self.state_dict, index)
[docs] def transform_object(self, write_item: WriteItem, object: Any): """Extension from the planner interface to make it easy to extend the default planner.""" if write_item.type == WriteItemType.BYTE_IO: bytes = io.BytesIO() torch.save(object, bytes) object = bytes return object
[docs]class DefaultLoadPlanner(LoadPlanner): """ DefaultLoadPlanner that adds multiple features on top of LoadPlanner. In particular it adds the following: flatten_state_dict: Handle state_dict with nested dicts flatten_sharded_tensors: For FSDP in 2D parallel mode """ original_state_dict: STATE_DICT_TYPE mappings: FLATTEN_MAPPING def __init__( self, flatten_state_dict: bool = True, flatten_sharded_tensors: bool = True, ) -> None: self.flatten_state_dict = flatten_state_dict self.flatten_sharded_tensors = flatten_sharded_tensors self.original_state_dict = {} self.mappings = {} def set_up_planner( self, state_dict: STATE_DICT_TYPE, metadata: Metadata, is_coordinator: bool, ) -> None: _init_state_dict(state_dict) self.original_state_dict = state_dict if self.flatten_sharded_tensors: state_dict = _flatten_sharded_tensors(state_dict) if self.flatten_state_dict: state_dict, self.mappings = flatten_state_dict(state_dict) self.state_dict = state_dict self.metadata = metadata self.is_coordinator = is_coordinator def create_local_plan(self) -> LoadPlan: return create_default_local_load_plan(self.state_dict, self.metadata) def create_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]: return create_default_global_load_plan(global_plan) def finish_plan(self, new_plan: LoadPlan) -> LoadPlan: return new_plan def load_bytes(self, read_item: ReadItem, value: io.BytesIO) -> None: if self.flatten_state_dict: set_element( self.original_state_dict, self.mappings[read_item.dest_index.fqn], torch.load(value), ) else: self.state_dict[read_item.dest_index.fqn] = torch.load(value) def resolve_tensor(self, read_item: ReadItem): tensor = self.lookup_tensor(read_item.dest_index) return self.transform_tensor(read_item, tensor) def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None: pass
[docs] def lookup_tensor(self, index: MetadataIndex) -> torch.Tensor: """Extension from the planner interface to make it easy to extend the default planner.""" return find_state_dict_object(self.state_dict, index)
[docs] def transform_tensor(self, read_item: ReadItem, tensor: torch.Tensor): """Extension from the planner interface to make it easy to extend the default planner.""" return narrow_tensor_by_index(tensor, read_item.dest_offsets, read_item.lengths)
class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):
    """
    Extension of DefaultLoadPlanner, which rebuilds state_dict from the saved metadata.

    Useful for loading in state_dict without first initializing a model, such as
    when converting a DCP checkpoint into a Torch save file.

    N.B. `state_dict` must be an empty dictionary when used with this LoadPlanner

    .. warning::
        Because the entire state dict is initialized, It's recommended to only utilize
        this LoadPlanner on a single rank or process to avoid OOM.
    """

    def __init__(self, keys=None, *args, **kwargs):
        # ``keys``: optional collection of metadata keys to load; a falsy value
        # (``None`` or empty) means "load everything".
        self.keys = keys
        super().__init__(*args, **kwargs)

    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        metadata: Metadata,
        is_coordinator: bool,
    ) -> None:
        """
        Populate the empty ``state_dict`` from ``metadata`` and set up the planner.

        Tensor entries are allocated with ``torch.empty`` using the size and
        dtype recorded in the metadata; other entries are stored as their
        metadata object.
        """
        assert not state_dict

        # ``planner_data`` may be ``None`` when nothing populated it at save
        # time; in that case every key is stored flat instead of raising.
        planner_data = metadata.planner_data

        # rebuild the state dict from the metadata
        for k, v in metadata.state_dict_metadata.items():
            if self.keys and k not in self.keys:
                continue

            if isinstance(v, TensorStorageMetadata):
                v = torch.empty(v.size, dtype=v.properties.dtype)  # type: ignore[assignment]
            if planner_data is not None and k in planner_data:
                set_element(state_dict, planner_data[k], v)
            else:
                state_dict[k] = v

        super().set_up_planner(state_dict, metadata, is_coordinator)


def create_default_local_load_plan(
    state_dict: Dict[str, Any],
    metadata: Metadata,
) -> LoadPlan:
    """
    Create the ``LoadPlan`` used by DefaultLoadPlanner.

    It produces one read item per value in ``state_dict`` using the metadata in ``metadata``.

    The default behavior is to match key exactly between state_dict and metadata.
    It handles resharding by issuing multiple read requests against storage in order to match
    load requirements.

    Raises:
        KeyError: if a key of ``state_dict`` is missing from
            ``metadata.state_dict_metadata``.
    """
    requests = []
    for fqn, obj in state_dict.items():
        md = metadata.state_dict_metadata[fqn]
        # Since DTensor supports submesh, adding extra check to ensure _create_read_items()
        # gets called only when the current rank is part of the mesh for the corresponding DTensor.
        if isinstance(obj, DTensor):
            if obj.device_mesh.get_coordinate() is not None:
                requests += _create_read_items(fqn, md, obj)
        else:
            requests += _create_read_items(fqn, md, obj)

    return LoadPlan(requests)


def create_default_global_load_plan(
    all_plans: List[LoadPlan],
) -> List[LoadPlan]:
    """
    Create global load plan used by DefaultLoadPlanner.

    The default load behavior involved no global coordination and this function
    currently doesn't change the local plans.
    """
    return all_plans


def create_default_local_save_plan(
    state_dict: Dict[str, Any], is_coordinator: bool
) -> SavePlan:
    """
    Create the ``SavePlan`` used by DefaultSavePlanner.

    Write items are produced as follows:

    * ``DTensor`` values: only when the current rank is part of the tensor's
      device mesh.
    * any other ``torch.Tensor`` instance: on every rank.
    * every other object: only on the coordinator rank.
    """
    requests = []
    for fqn, obj in state_dict.items():
        # Since DTensor supports submesh, adding extra check to ensure _create_write_items()
        # gets called only when the current rank is part of the mesh for the corresponding DTensor.
        if isinstance(obj, DTensor):
            if obj.device_mesh.get_coordinate() is not None:
                requests += _create_write_items(fqn, obj)
        elif isinstance(obj, torch.Tensor) or is_coordinator:
            requests += _create_write_items(fqn, obj)

    return SavePlan(requests)


def create_default_global_save_plan(
    all_plans: List[SavePlan],
    rewrite_index_hints: bool = True,
) -> Tuple[List[SavePlan], Metadata]:
    """
    Create the global plan and metadata used by DefaultSavePlanner.

    Metadata is produced by concatenating the metadata of all ``WriteItem`` from the supplied plans.

    The only global planning change is to update index hints in all ``MetadataIndex`` objects if
    ``rewrite_index_hints`` is True.
    """
    md: Dict[str, STORAGE_TYPES] = {}
    new_plans = []
    for plan in all_plans:
        new_items = []
        for item in plan.items:
            # Only SHARD items may contribute several entries to the same key.
            if item.type != WriteItemType.SHARD:
                assert item.index.fqn not in md

            if item.type == WriteItemType.BYTE_IO:
                md[item.index.fqn] = BytesStorageMetadata()
                new_items.append(item)
            else:
                assert item.tensor_data is not None
                # The first item seen for a key creates its tensor metadata;
                # later items for the same key append their chunk to it.
                tensor_md = cast(
                    TensorStorageMetadata,
                    md.setdefault(
                        item.index.fqn,
                        TensorStorageMetadata(
                            properties=item.tensor_data.properties,
                            size=item.tensor_data.size,
                            chunks=[],
                        ),
                    ),
                )
                new_item = item
                if rewrite_index_hints:
                    # The hint is the position this chunk will take in the list.
                    new_index = dataclasses.replace(
                        item.index, index=len(tensor_md.chunks)
                    )
                    new_item = dataclasses.replace(item, index=new_index)
                new_items.append(new_item)

                assert (
                    item.tensor_data.chunk is not None
                ), f"""
                    Cannot create MD for tensor without bounds.
                    FQN: {item.index.fqn}
                """
                tensor_md.chunks.append(item.tensor_data.chunk)
        new_plans.append(dataclasses.replace(plan, items=new_items))
    return (new_plans, Metadata(md))


def _create_default_local_metadata(state_dict: STATE_DICT_TYPE) -> Metadata:
    """Return the ``Metadata`` if DefaultSavePlanner was used to checkpoint ``state_dict``."""
    plan = _create_default_metadata_only_plan(state_dict)
    _, md = create_default_global_save_plan([plan])
    return md


def _check_box_overlap(box0: ChunkStorageMetadata, box1: ChunkStorageMetadata) -> bool:
    """Check if two boxes overlap. Tuples are (offset, lengths)."""
    # For each dim of each shard, check if one shard resides on the other
    # end of second shard with respect to that dim. As an example for a 2D
    # shard, we would check if one shard is above or on the left of the
    # other shard.
    ndims = len(box0.offsets)
    for i in range(ndims):
        if box0.offsets[i] >= box1.offsets[i] + box1.sizes[i]:
            return False
        if box1.offsets[i] >= box0.offsets[i] + box0.sizes[i]:
            return False

    return True


def _check_box_bounds(
    outer_box_size: torch.Size, inner_box: ChunkStorageMetadata
) -> bool:
    """Return True when ``inner_box`` lies fully inside a box of size ``outer_box_size``."""
    for i, outer_dim in enumerate(outer_box_size):
        if inner_box.offsets[i] < 0:
            return False
        if inner_box.sizes[i] < 0:
            return False
        if inner_box.offsets[i] + inner_box.sizes[i] > outer_dim:
            return False

    return True


def _validate_global_plan(global_plan: List[SavePlan], metadata: Metadata) -> bool:
    """
    Check that the chunks recorded for every tensor are consistent.

    For each non-scalar tensor entry the chunks must be in bounds, must not
    overlap each other, and their combined volume must equal the tensor's
    volume. Every violation is logged as a warning; the function returns
    ``False`` if any was found. ``global_plan`` itself is not inspected.
    """
    all_good = True
    for key, value in metadata.state_dict_metadata.items():
        if isinstance(value, BytesStorageMetadata):
            continue
        # Zero-dimensional (scalar) tensors are skipped.
        if len(value.size) == 0:
            continue
        chunks_volume = 0
        for chunk_idx, chunk0 in enumerate(value.chunks):
            # Compute the volume
            if not _check_box_bounds(value.size, chunk0):
                logger.warning(
                    """
                        key:%s has out of bounds chunk:
                        tensor-size:%s chunk: %s
                    """,
                    key,
                    value.size,
                    chunk0,
                )
                all_good = False
            chunks_volume += reduce(operator.mul, chunk0.sizes, 1)

            # Check for overlap
            for chunk1 in value.chunks[chunk_idx + 1 :]:
                if _check_box_overlap(chunk0, chunk1):
                    logger.warning(
                        "key:%s has overlapping chunks: %s %s", key, chunk0, chunk1
                    )
                    all_good = False

        # Check whether combined chunk cover the whole tensor
        tensor_volume = reduce(operator.mul, value.size, 1)
        if chunks_volume != tensor_volume:
            logger.warning(
                """
                    key:%s invalid fill tensor-volume:
                    %s chunks-volume: %s
                """,
                key,
                tensor_volume,
                chunks_volume,
            )
            all_good = False

    return all_good

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources