Skip to content
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
dbc3890
[IR] Improve external data handling
justinchuby Jan 17, 2025
3c9d315
Also load into memory
justinchuby Jan 17, 2025
fe696f5
external
justinchuby Jan 17, 2025
63553c5
f
justinchuby Jan 17, 2025
03b90ca
Simplify logic
justinchuby Jan 17, 2025
ad169e1
format
justinchuby Jan 17, 2025
903b206
comments
justinchuby Jan 17, 2025
ffe270f
docs
justinchuby Jan 17, 2025
e8726fc
sim
justinchuby Jan 17, 2025
68c8e41
Apply suggestions from code review
justinchuby Jan 17, 2025
d78f407
fix samefile call
justinchuby Jan 18, 2025
81458ec
wip
justinchuby Jan 18, 2025
273919d
address comments
justinchuby Jan 21, 2025
67d8ab9
Merge branch 'main' into justinchu/ir-save
justinchuby Jan 21, 2025
357c41b
test
justinchuby Jan 21, 2025
a0a5b58
test
justinchuby Jan 21, 2025
5e5a59c
Merge branch 'main' into justinchu/ir-save
justinchuby Jan 21, 2025
f601264
test
justinchuby Jan 21, 2025
e6fc2fe
update
justinchuby Jan 22, 2025
7dfdfdf
wip
justinchuby Jan 22, 2025
f5e0724
Handle right
justinchuby Jan 22, 2025
06f2eeb
polyfill
justinchuby Jan 22, 2025
458e0ab
Rename
justinchuby Jan 22, 2025
074afa9
Rename functions and expose external data module
justinchuby Jan 22, 2025
e37f3ba
Update onnxscript/ir/_polyfill.py
justinchuby Jan 22, 2025
b9f8e80
name
justinchuby Jan 22, 2025
58049f4
rename
justinchuby Jan 22, 2025
319e5cc
typing
justinchuby Jan 22, 2025
5dee1cd
mypy
justinchuby Jan 22, 2025
cb25b6e
Hashable
justinchuby Jan 22, 2025
b4f8c8c
naming
justinchuby Jan 22, 2025
33a8345
sort
justinchuby Jan 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 3 additions & 20 deletions onnxscript/_framework_apis/torch_2_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

from onnxscript import ir, optimizer, version_converter
from onnxscript.function_libs.torch_lib import registration
from onnxscript.ir import _external_data


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -68,32 +67,16 @@ def save_model_with_external_data(model: ir.Model, model_path: str | os.PathLike
"""Save the model with external data. The model is unchanged after saving."""

# TODO(#1835): Decide if we want to externalize large attributes as well
initializer_values = tuple(model.graph.initializers.values())
tensors = [v.const_value for v in initializer_values]
for tensor in tensors:
if tensor is None:
for value in model.graph.initializers.values():
if value.const_value is None:
raise ValueError(
"The model contains uninitialized initializer values. "
"Please make sure all initializer values are initialized."
)
destination_path = pathlib.Path(model_path)
base_dir = destination_path.parent
data_path = f"{destination_path.name}.data"

external_tensors = _external_data.convert_tensors_to_external(
tensors, # type: ignore[arg-type]
base_dir,
data_path,
)

# Replace the initializer values with external tensors and save the model
for initializer, external_tensor in zip(initializer_values, external_tensors):
initializer.const_value = external_tensor
ir.save(model, model_path)

# Restore the original initializer values so the model is unchanged
for initializer, tensor in zip(initializer_values, tensors):
initializer.const_value = tensor
ir.save(model, model_path, external_data=data_path)


def get_torchlib_ops() -> list[_OnnxFunctionMeta]:
Expand Down
5 changes: 3 additions & 2 deletions onnxscript/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
__all__ = [
# Modules
"serde",
"traversal",
"convenience",
"external_data",
# IR classes
"Tensor",
"ExternalTensor",
Expand Down Expand Up @@ -72,13 +74,12 @@
"tensor",
# Pass infrastructure
"passes",
"traversal",
# IO
"load",
"save",
]

from onnxscript.ir import convenience, passes, serde, traversal
from onnxscript.ir import convenience, external_data, passes, serde, traversal
from onnxscript.ir._convenience import tensor
from onnxscript.ir._core import (
Attr,
Expand Down
26 changes: 26 additions & 0 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,7 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
"_metadata_props",
"_offset",
"_shape",
"_valid",
"doc_string",
"name",
"raw",
Expand Down Expand Up @@ -568,6 +569,7 @@ def __init__(
self.raw: mmap.mmap | None = None
self._metadata_props = metadata_props
self._metadata: _metadata.MetadataStore | None = None
self._valid = True

@property
def base_dir(self) -> str | os.PathLike:
Expand Down Expand Up @@ -609,6 +611,7 @@ def shape(self) -> Shape:
return self._shape

def _load(self):
self._check_validity()
assert self._array is None, "Bug: The array should be loaded only once."
if self.size == 0:
# When the size is 0, mmap is impossible and meaningless
Expand Down Expand Up @@ -647,6 +650,7 @@ def _load(self):
self._array = self._array.reshape(shape)

def __array__(self, dtype: Any = None) -> np.ndarray:
self._check_validity()
if self._array is None:
self._load()
assert self._array is not None
Expand Down Expand Up @@ -675,6 +679,7 @@ def numpy(self) -> np.ndarray:

The data will be memory-mapped and will not take up physical memory space.
"""
self._check_validity()
if self._array is None:
self._load()
assert self._array is not None
Expand All @@ -685,13 +690,34 @@ def tobytes(self) -> bytes:

This will load the tensor into memory.
"""
self._check_validity()
if self.raw is None:
self._load()
assert self.raw is not None
offset = self._offset or 0
length = self._length or self.nbytes
return self.raw[offset : offset + length]

def valid(self) -> bool:
    """Report whether this external tensor may still be read.

    Returns:
        The current value of the internal ``_valid`` flag: ``False`` once
        :meth:`invalidate` has been called on this tensor.
    """
    still_valid = self._valid
    return still_valid

def _check_validity(self) -> None:
    """Guard that refuses to continue when the tensor has been invalidated.

    Raises:
        ValueError: If :meth:`valid` reports ``False``.
    """
    if self.valid():
        # Nothing to complain about; the caller may proceed.
        return
    message = f"The external tensor '{self!r}' is invalidated. The data may be corrupted or deleted."
    raise ValueError(message)

def invalidate(self) -> None:
    """Mark this external tensor as no longer usable.

    Call this when the backing data is known to be corrupted or deleted.
    It clears the internal ``_valid`` flag, so that :meth:`valid` returns
    ``False`` from then on.
    """
    # Only the flag is changed here.
    self._valid = False

def release(self) -> None:
"""Delete all references to the memory buffer and close the memory-mapped file."""
self._array = None
Expand Down
Loading
Loading