Skip to content

Commit

Permalink
[IR] Improve external data handling (#2020)
Browse files Browse the repository at this point in the history
1. Add an `external_data` option to `ir.save`. This will save
initializers as external tensors. It is robust against data loss when
overwriting, and is idempotent when the current model does not contain
external tensors already referencing the same path.
1. Expose `ir.external_data` module as a public module users can use to
manipulate external data.
    1. It defines the following methods
		```py
		[
		    "set_base_dir",
		    "unload_from_model",
		    "load_to_model",
		    "convert_tensors_to_external",
		    "convert_tensors_from_external",
		]
		```
I renamed `to_external_data` to `unload_from_model` for clarity.
**Reviewers please let me know if the naming sounds good.**
1. Support setting a threshold `size_threshold_bytes` to control which
tensors are offloaded.
1. Simplified the torch_apis logic by leveraging the updated `ir.save`
method.
1. Updated the to_external_data function to load data into memory when
the tensor references an external data file that is about to be
overwritten. This simplifies the logic and avoids creating and managing
temporary files.
1. Implemented a polyfill of the `zip()` function's strict mode to
support Python<=3.9

> [!NOTE]
> We **do not** need to add external data options to `ir.load`. The
external data is always loaded lazily in the IR. If users want to
transfer the data to memory at loading, they can use
`ir.external_data.load_to_model()`.

## Example usage

```py
ir.save(model, "model.onnx", external_data="model.onnx.data")
# Can save many times
ir.save(model, "model_copy.onnx", external_data="model_copy.onnx.data")
```

---------

Co-authored-by: Copilot <[email protected]>
  • Loading branch information
justinchuby and Copilot authored Jan 22, 2025
1 parent 0447822 commit b8d3179
Show file tree
Hide file tree
Showing 9 changed files with 670 additions and 434 deletions.
23 changes: 3 additions & 20 deletions onnxscript/_framework_apis/torch_2_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

from onnxscript import ir, optimizer, version_converter
from onnxscript.function_libs.torch_lib import registration
from onnxscript.ir import _external_data


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -68,32 +67,16 @@ def save_model_with_external_data(model: ir.Model, model_path: str | os.PathLike
"""Save the model with external data. The model is unchanged after saving."""

# TODO(#1835): Decide if we want to externalize large attributes as well
initializer_values = tuple(model.graph.initializers.values())
tensors = [v.const_value for v in initializer_values]
for tensor in tensors:
if tensor is None:
for value in model.graph.initializers.values():
if value.const_value is None:
raise ValueError(
"The model contains uninitialized initializer values. "
"Please make sure all initializer values are initialized."
)
destination_path = pathlib.Path(model_path)
base_dir = destination_path.parent
data_path = f"{destination_path.name}.data"

external_tensors = _external_data.convert_tensors_to_external(
tensors, # type: ignore[arg-type]
base_dir,
data_path,
)

# Replace the initializer values with external tensors and save the model
for initializer, external_tensor in zip(initializer_values, external_tensors):
initializer.const_value = external_tensor
ir.save(model, model_path)

# Restore the original initializer values so the model is unchanged
for initializer, tensor in zip(initializer_values, tensors):
initializer.const_value = tensor
ir.save(model, model_path, external_data=data_path)


def get_torchlib_ops() -> list[_OnnxFunctionMeta]:
Expand Down
5 changes: 3 additions & 2 deletions onnxscript/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
__all__ = [
# Modules
"serde",
"traversal",
"convenience",
"external_data",
# IR classes
"Tensor",
"ExternalTensor",
Expand Down Expand Up @@ -72,13 +74,12 @@
"tensor",
# Pass infrastructure
"passes",
"traversal",
# IO
"load",
"save",
]

from onnxscript.ir import convenience, passes, serde, traversal
from onnxscript.ir import convenience, external_data, passes, serde, traversal
from onnxscript.ir._convenience import tensor
from onnxscript.ir._core import (
Attr,
Expand Down
28 changes: 27 additions & 1 deletion onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@
import sys
import textwrap
import typing
from collections.abc import Hashable
from typing import (
AbstractSet,
Any,
Collection,
Generic,
Hashable,
Iterable,
Iterator,
NamedTuple,
Expand Down Expand Up @@ -516,6 +516,7 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
"_metadata_props",
"_offset",
"_shape",
"_valid",
"doc_string",
"name",
"raw",
Expand Down Expand Up @@ -568,6 +569,7 @@ def __init__(
self.raw: mmap.mmap | None = None
self._metadata_props = metadata_props
self._metadata: _metadata.MetadataStore | None = None
self._valid = True

@property
def base_dir(self) -> str | os.PathLike:
Expand Down Expand Up @@ -609,6 +611,7 @@ def shape(self) -> Shape:
return self._shape

def _load(self):
self._check_validity()
assert self._array is None, "Bug: The array should be loaded only once."
if self.size == 0:
# When the size is 0, mmap is impossible and meaningless
Expand Down Expand Up @@ -647,6 +650,7 @@ def _load(self):
self._array = self._array.reshape(shape)

def __array__(self, dtype: Any = None) -> np.ndarray:
self._check_validity()
if self._array is None:
self._load()
assert self._array is not None
Expand Down Expand Up @@ -675,6 +679,7 @@ def numpy(self) -> np.ndarray:
The data will be memory mapped into memory and will not taken up physical memory space.
"""
self._check_validity()
if self._array is None:
self._load()
assert self._array is not None
Expand All @@ -685,13 +690,34 @@ def tobytes(self) -> bytes:
This will load the tensor into memory.
"""
self._check_validity()
if self.raw is None:
self._load()
assert self.raw is not None
offset = self._offset or 0
length = self._length or self.nbytes
return self.raw[offset : offset + length]

def valid(self) -> bool:
    """Return ``True`` while the tensor has not been invalidated.

    An external tensor becomes invalid once :meth:`invalidate` has been
    called, i.e. when its backing data is known to be corrupted or deleted.
    """
    return self._valid

def _check_validity(self) -> None:
if not self.valid():
raise ValueError(
f"The external tensor '{self!r}' is invalidated. The data may be corrupted or deleted."
)

def invalidate(self) -> None:
    """Mark the tensor as invalid.

    Call this when the external data backing the tensor is known to be
    corrupted or deleted; subsequent accesses that check validity will
    then raise instead of reading stale data.
    """
    self._valid = False

def release(self) -> None:
"""Delete all references to the memory buffer and close the memory-mapped file."""
self._array = None
Expand Down
Loading

0 comments on commit b8d3179

Please sign in to comment.