Enable custom device support in fsdp checkpoint (pytorch#107289)
Fixes pytorch#104390
Enable custom device (privateuse1 backend) support in checkpointing by resolving an abstract device module dynamically instead of hard-coding torch.cuda.
Pull Request resolved: pytorch#107289
Approved by: https://github.com/wz337
dilililiwhy authored and pytorchmergebot committed Aug 25, 2023
1 parent b18e1b6 commit ff37f60
Showing 7 changed files with 71 additions and 35 deletions.
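
For orientation, a minimal sketch of what "custom device (privateuse1 backend)" means here; foo and foo_backend are hypothetical names, not part of this PR:

# Hedged sketch: how an out-of-tree backend typically becomes visible to the
# dynamic device-module lookup this PR introduces. "foo" / foo_backend are
# hypothetical; any privateuse1 backend would follow the same pattern.
import torch
import foo_backend  # hypothetical extension module implementing the device

torch.utils.rename_privateuse1_backend("foo")       # tensors can now use device="foo"
torch._register_device_module("foo", foo_backend)   # getattr(torch, "foo") now resolves

# With the module registered, the checkpointing code changed below can look up
# torch.foo at runtime instead of assuming torch.cuda.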
11 changes: 11 additions & 0 deletions torch/_utils.py
@@ -1,4 +1,5 @@
import copyreg
import functools
import sys
import traceback
import warnings
@@ -839,3 +840,13 @@ def classproperty(func):
# Whether we are compiling with torch.compile or not
def is_compiling():
return False


@functools.lru_cache(2)
def _get_device_module(device_type: str):
device_module = getattr(torch, device_type, None)
if device_module is None:
raise RuntimeError(
f"Device '{device_type}' does not have a corresponding module registered as 'torch.{device_type}'."
)
return device_module
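
A quick usage sketch of the new helper (not part of the diff); the "foo" lookup assumes a backend name with no torch.foo module registered:

# Sketch: _get_device_module simply resolves torch.<device_type> and fails
# loudly when nothing is registered under that name.
import torch
from torch._utils import _get_device_module

assert _get_device_module("cuda") is torch.cuda

try:
    _get_device_module("foo")   # hypothetical, unregistered backend name
except RuntimeError as err:
    print(err)  # "Device 'foo' does not have a corresponding module ..."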
19 changes: 11 additions & 8 deletions torch/distributed/checkpoint/_fsspec_filesystem.py
@@ -18,6 +18,7 @@
import torch
from fsspec.core import url_to_fs
from torch import Tensor
from torch._utils import _get_device_module

from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex
@@ -114,7 +115,7 @@ class _OverlappingCpuLoader(_TensorLoader):
def __init__(
self,
resolve_fun: Callable,
stream: Union[None, io.RawIOBase, torch._C._CudaStreamBase] = None,
stream: Union[None, io.RawIOBase, torch.Stream] = None,
inflight_threshhold: int = 1_000_000,
):
self.resolve_fun = resolve_fun
@@ -124,9 +125,11 @@ def __init__(
self.current_items: collections.deque = collections.deque()
self.idx = 0
self.started = False
self.stream = stream or torch.cuda.current_stream()
if self.stream != torch.cuda.current_stream():
self.stream.wait_stream(torch.cuda.current_stream())
self.device_type = stream.device_type if stream else torch.device("cuda").type
self.device_module = _get_device_module(self.device_type)
self.stream = stream or self.device_module.current_stream()
if self.stream != self.device_module.current_stream():
self.stream.wait_stream(self.device_module.current_stream())

@property
def _done(self):
@@ -143,15 +146,15 @@ def _drain(self):
return drained

def _refill(self):
with torch.cuda.stream(self.stream):
with self.device_module.stream(self.stream):
while (
not self._done
and self.in_flight_data < self.inflight_threshhold
):
_, obj = self.items[self.idx]
self.idx += 1
tensor = self.resolve_fun(obj).detach()
if tensor.is_cuda:
if tensor.device.type == self.device_type:
tensor = tensor.to(device="cpu", non_blocking=True)
elif tensor.device == torch.device("cpu"):
if tensor.storage().size() != tensor.numel():
@@ -232,7 +235,7 @@ def _split_by_size_and_type(


def _write_item(
stream: Optional[Union[io.RawIOBase, torch._C._CudaStreamBase]],
stream: Optional[Union[io.RawIOBase, torch.Stream]],
data: Union[io.BytesIO, torch.Tensor],
write_item: WriteItem,
storage_key: str,
@@ -294,7 +297,7 @@ def _write_files_from_queue(
)

for tensor, write_item in loader.values():
assert not tensor.is_cuda
assert tensor.is_cpu
write_results.append(
_write_item(stream, tensor, write_item, storage_key)
)
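
A condensed sketch of the stream selection above, assuming a stream object that exposes device_type (as the new annotation implies); the point is that every former torch.cuda.* stream call now goes through the resolved device module:

# Sketch of the device-agnostic stream selection in _OverlappingCpuLoader.
import torch
from torch._utils import _get_device_module

def _pick_stream(stream=None):
    device_type = stream.device_type if stream else torch.device("cuda").type
    device_module = _get_device_module(device_type)   # torch.cuda, torch.npu, ...
    stream = stream or device_module.current_stream()
    if stream != device_module.current_stream():
        # have the side stream wait for work already queued on the current stream
        stream.wait_stream(device_module.current_stream())
    return device_type, device_module, stream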
7 changes: 4 additions & 3 deletions torch/distributed/checkpoint/_sharded_tensor_utils.py
@@ -26,7 +26,7 @@
STATE_DICT_ITEM,
)

from .utils import _element_wise_add
from .utils import _element_wise_add, _normalize_device_info


# TODO: We need to refactor this code.
@@ -83,6 +83,7 @@ def rewrite_dict(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:

st_meta: ShardedTensorMetadata = copy.deepcopy(value.metadata())
other_rank = 0 if dist.get_rank() > 0 else 1
device_info = _normalize_device_info(inner_shard.tensor.device.type, 0)

# Remove the outer ST shard the inner ST covers
for i, shard_md in enumerate(st_meta.shards_metadata):
@@ -92,7 +93,7 @@ def rewrite_dict(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:

# Attribute other rank for the other shards
for shard_md in st_meta.shards_metadata:
shard_md.placement = _remote_device(f"rank:{other_rank}/cuda:0")
shard_md.placement = _remote_device(f"rank:{other_rank}/{device_info}")

# Add other inner shards from the inner tensor
for inner_md in inner_st.metadata().shards_metadata:
@@ -104,7 +105,7 @@ def rewrite_dict(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
inner_md.shard_offsets,
),
shard_sizes=inner_md.shard_sizes,
placement=f"rank:{other_rank}/cuda:0",
placement=f"rank:{other_rank}/{device_info}",
)
)

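
A small, self-contained sketch of the placement strings built above; _normalize_device_info drops the index for CPU so placements stay valid on CPU-only runs:

# Sketch: placement strings are now derived from the shard's own device type
# instead of the hard-coded "cuda:0".
from torch.distributed.checkpoint.utils import _normalize_device_info

assert _normalize_device_info("cuda", 0) == "cuda:0"
assert _normalize_device_info("cpu", 0) == "cpu"

other_rank = 1
device_info = _normalize_device_info("cuda", 0)
placement = f"rank:{other_rank}/{device_info}"   # -> "rank:1/cuda:0"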
15 changes: 9 additions & 6 deletions torch/distributed/checkpoint/filesystem.py
@@ -39,6 +39,7 @@
from .utils import _create_file_view

from torch.distributed._shard._utils import narrow_tensor_by_index
from torch._utils import _get_device_module

__all__ = [
"FileSystemWriter",
@@ -126,9 +127,11 @@ def __init__(self, resolve_fun, stream=None, inflight_threshhold=1_000_000):
self.current_items: collections.deque = collections.deque()
self.idx = 0
self.started = False
self.stream = stream or torch.cuda.current_stream()
if self.stream != torch.cuda.current_stream():
self.stream.wait_stream(torch.cuda.current_stream())
self.device_type = stream.device_type if stream else torch.device("cuda").type
self.device_module = _get_device_module(self.device_type)
self.stream = stream or self.device_module.current_stream()
if self.stream != self.device_module.current_stream():
self.stream.wait_stream(self.device_module.current_stream())

@property
def _done(self):
@@ -145,15 +148,15 @@ def _drain(self):
return drained

def _refill(self):
with torch.cuda.stream(self.stream):
with self.device_module.stream(self.stream):
while (
not self._done
and self.in_flight_data < self.inflight_threshhold
):
_, obj = self.items[self.idx]
self.idx += 1
tensor = self.resolve_fun(obj).detach()
if tensor.is_cuda:
if tensor.device.type == self.device_type:
tensor = tensor.to(device="cpu", non_blocking=True)
elif tensor.device == torch.device("cpu"):
if tensor.storage().size() != tensor.numel():
@@ -292,7 +295,7 @@ def _write_files_from_queue(
)

for tensor, write_item in loader.values():
assert not tensor.is_cuda
assert tensor.is_cpu
write_results.append(
_write_item(stream, tensor, write_item, storage_key)
)
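
Both writers now assert tensor.is_cpu instead of not tensor.is_cuda; a short sketch of why the old check is too weak once non-CUDA accelerators are in play (the "foo" device is hypothetical):

# Sketch: `not is_cuda` passes for a tensor on a custom backend even though it
# still lives in device memory, while `is_cpu` only passes for host tensors.
import torch

t = torch.zeros(2)                   # CPU tensor
assert t.is_cpu and not t.is_cuda    # both checks agree here

# For a hypothetical privateuse1 tensor torch.zeros(2, device="foo:0"):
#   not t.is_cuda -> True   (old assertion would pass, wrongly)
#   t.is_cpu      -> False  (new assertion fails, as intended)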
41 changes: 24 additions & 17 deletions torch/distributed/checkpoint/optimizer.py
@@ -38,8 +38,11 @@
from torch.distributed.checkpoint.utils import (
_element_wise_add,
_element_wise_sub,
_normalize_device_info
)

from torch._utils import _get_device_module

STATE_DICT_2D_LAYOUT = Dict[str, Tuple[Optional[Sequence[int]], Sequence[int]]]


@@ -49,23 +52,27 @@
]


def _gen_rank_device(global_rank: int) -> str:
if torch.cuda.is_available():
return f"cuda:{global_rank % torch.cuda.device_count()}"
def _gen_rank_device(global_rank: int, device_type: str = "cuda") -> str:
if device_type == "cpu":
return "cpu"
device_module = _get_device_module(device_type)
if device_module.is_available():
return _normalize_device_info(device_type, global_rank % device_module.device_count())
return "cpu"


def _create_colwise_spec(
pg: Optional[dist.ProcessGroup] = None,
) -> ChunkShardingSpec:
pg_device_type = dist.distributed_c10d._get_pg_default_device(pg).type
if pg is None:
placements = [
f"rank:{idx}/{_gen_rank_device(idx)}"
f"rank:{idx}/{_gen_rank_device(idx, pg_device_type)}"
for idx in range(dist.get_world_size())
]
else:
placements = [
f"rank:{idx}/{_gen_rank_device(dist.get_global_rank(pg, idx))}"
f"rank:{idx}/{_gen_rank_device(dist.get_global_rank(pg, idx), pg_device_type)}"
for idx in range(pg.size())
]
return ChunkShardingSpec(
@@ -92,14 +99,14 @@ def _is_nested_tensor(val: torch.Tensor) -> bool:
return False


def _alloc_tensor(props: TensorProperties, size: Sequence[int]) -> torch.Tensor:
def _alloc_tensor(props: TensorProperties, size: Sequence[int], device_type: str = "cuda") -> torch.Tensor:
return torch.empty(
size=size,
dtype=props.dtype,
layout=props.layout,
requires_grad=props.requires_grad,
pin_memory=props.pin_memory,
device=cast(torch.device, torch.cuda.current_device()),
device=cast(torch.device, _get_device_module(device_type).current_device()),
)


@@ -255,15 +262,15 @@ def load_sharded_optimizer_state_dict(
metadata = storage_reader.read_metadata()

layout_specs, dp_pg = _get_state_dict_2d_layout(model_state_dict)
dp_pg_device_type = dist.distributed_c10d._get_pg_default_device(dp_pg).type
device_module = _get_device_module(dp_pg_device_type)

if dp_pg is None:
sharding_spec = ChunkShardingSpec(
dim=0,
placements=[
f"rank:{i}/cuda:{i % torch.cuda.device_count()}"
for i in range(dist.get_world_size())
],
)
placements = []
for i in range(dist.get_world_size()):
device_info = _normalize_device_info(dp_pg_device_type, i % device_module.device_count())
placements.append(f"rank:{i}/{device_info}")
sharding_spec = ChunkShardingSpec(dim=0, placements=placements) # type: ignore[arg-type]
else:
sharding_spec = _create_colwise_spec(dp_pg)

@@ -282,10 +289,10 @@

# value: TensorStorageMetadata
if value.size.numel() == 1:
state_dict[key] = _alloc_tensor(value.properties, value.size)
state_dict[key] = _alloc_tensor(value.properties, value.size, dp_pg_device_type)
elif dp_pg is None:
state_dict[key] = _shard_tensor(
_alloc_tensor(value.properties, value.size), sharding_spec
_alloc_tensor(value.properties, value.size, dp_pg_device_type), sharding_spec
)
else:
spec_key = key_path[2]
@@ -305,7 +312,7 @@
local_shards.append(
Shard(
tensor=_alloc_tensor(
value.properties, shard_md.shard_sizes
value.properties, shard_md.shard_sizes, dp_pg_device_type
),
metadata=shard_md,
)
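
A self-contained sketch of the per-rank device string used for the ChunkShardingSpec placements above; "cuda" is only the default, and a registered privateuse1 backend name follows the same path:

# Sketch of _gen_rank_device-style placement selection, falling back to CPU
# when the requested accelerator is unavailable.
from torch._utils import _get_device_module
from torch.distributed.checkpoint.utils import _normalize_device_info

def gen_rank_device(global_rank: int, device_type: str = "cuda") -> str:
    if device_type == "cpu":
        return "cpu"
    device_module = _get_device_module(device_type)
    if device_module.is_available():
        return _normalize_device_info(
            device_type, global_rank % device_module.device_count()
        )
    return "cpu"

print(gen_rank_device(5))   # e.g. "cuda:1" on a 4-GPU host, "cpu" without GPUs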
11 changes: 11 additions & 0 deletions torch/distributed/checkpoint/utils.py
@@ -355,6 +355,7 @@ def _element_wise_add(a: Sequence[int], b: Sequence[int]) -> List[int]:
def _element_wise_sub(a: Sequence[int], b: Sequence[int]) -> List[int]:
return [i_a - i_b for i_a, i_b in zip(a, b)]


class _ReaderView(io.IOBase):
def __init__(self, base_stream: io.IOBase, offset: int, len: int):
super().__init__()
@@ -386,6 +387,16 @@ def readinto(self, b):
def read(self, size=-1):
return self.base_stream.read(size)


def _create_file_view(file: io.IOBase, offset: int, length: int) -> io.IOBase:
# FIXME (kumpera) torch.load fails if we wrap with io.BufferedReader
return _ReaderView(file, offset, length)


def _normalize_device_info(device_type: str, device_id: int) -> str:
"""
Device info normalization.
"""
if device_type == "cpu":
return "cpu"
return f"{device_type}:{device_id}"
2 changes: 1 addition & 1 deletion torch/distributed/utils.py
@@ -107,7 +107,7 @@ def to_map(obj):
with device_mod.stream(stream):
output = obj.to(target_device)
# synchronize with the copy stream
with torch.cuda.device(target_device.index):
with device_mod.device(target_device.index):
current_stream = device_mod.current_stream()
# Sync the current stream with the copy stream
current_stream.wait_stream(stream)
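
A hedged sketch of the copy/synchronization pattern that this last hunk makes device-agnostic; the cached _get_stream side stream of the original module is replaced here by a freshly created stream for brevity:

# Sketch: copy to the target device on a side stream, then make the target
# device's current stream wait for that copy, all through the device module.
import torch
from torch._utils import _get_device_module

def copy_on_side_stream(obj: torch.Tensor, target_device: torch.device) -> torch.Tensor:
    device_mod = _get_device_module(target_device.type)
    stream = device_mod.Stream(device=target_device)    # side stream for the copy
    with device_mod.stream(stream):
        output = obj.to(target_device)
    with device_mod.device(target_device.index):
        current_stream = device_mod.current_stream()
        current_stream.wait_stream(stream)               # sync with the copy stream
        output.record_stream(current_stream)             # keep the allocation alive
    return output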
