Skip to content

Commit 179a155

Browse files
pytorchbot and lucylq authored
Introduce PTEFile class (#15864)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #15800 by @lucylq ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/lucylq/125/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/lucylq/125/head Merge bot PR base: https://github.com/pytorch/executorch/tree/main Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/lucylq/125/orig Differential Revision: [D86814175](https://our.internmc.facebook.com/intern/diff/D86814175/) @diff-train-skip-merge Co-authored-by: lucylq <[email protected]>
1 parent 47c08d9 commit 179a155

File tree

8 files changed

+227
-62
lines changed

8 files changed

+227
-62
lines changed

backends/cadence/runtime/runtime.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def get_op_names(program: et_schema.Program, execution_plan_id: int = 0) -> set[
4545
op_names |= get_op_names(
4646
deserialize_pte_binary(
4747
program.backend_delegate_data[delegate.processed.index].data
48-
)
48+
).program
4949
)
5050
return op_names
5151

backends/qualcomm/utils/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def dump_context_from_pte(pte_path) -> List[str]:
197197
with open(pte_path, "rb") as f:
198198
program_data = f.read()
199199

200-
program = deserialize_pte_binary(program_data)
200+
program = deserialize_pte_binary(program_data).program
201201

202202
ctx_path = os.path.dirname(pte_path)
203203
dummy_compiler_specs = generate_qnn_executorch_compiler_spec(

codegen/tools/gen_ops_def.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def get_operators(model_file: str) -> List[Operator]:
2323
print("Processing model file: ", model_file)
2424
with open(model_file, "rb") as f:
2525
flatbuffer = f.read()
26-
program = _deserialize_pte_binary(flatbuffer)
26+
program = _deserialize_pte_binary(flatbuffer).program
2727
print(f"Program loaded from model file: {model_file}")
2828
operators = program.execution_plan[0].operators
2929
return operators

examples/qualcomm/oss_scripts/llama/decoder_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def __init__( # noqa: C901
276276

277277
with open(pte_path, "rb") as f:
278278
program_data = f.read()
279-
program = deserialize_pte_binary(program_data)
279+
program = deserialize_pte_binary(program_data).program
280280

281281
# Retrieve vocab_size from get_metadata under static_llama that is passed to edge manager
282282
self.output_vocab_size = None

exir/_serialize/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@
88

99
from executorch.exir._serialize._program import (
1010
deserialize_pte_binary as _deserialize_pte_binary,
11+
PTEFile as _PTEFile,
1112
serialize_pte_binary as _serialize_pte_binary,
1213
)
1314

1415
# Internal APIs that should not be used outside of exir.
1516
__all__ = [
1617
"_deserialize_pte_binary",
1718
"_serialize_pte_binary",
19+
"_PTEFile",
1820
]

exir/_serialize/_program.py

Lines changed: 140 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@
2121
_program_flatbuffer_to_json,
2222
_program_json_to_flatbuffer,
2323
)
24-
from executorch.exir._serialize._named_data_store import NamedDataStoreOutput
24+
from executorch.exir._serialize._named_data_store import (
25+
NamedDataStore,
26+
NamedDataStoreOutput,
27+
)
2528

2629
from executorch.exir._serialize.data_serializer import DataEntry
2730

@@ -46,6 +49,19 @@
4649
_HEADER_BYTEORDER: Literal["little"] = "little"
4750

4851

52+
@dataclass
class PTEFile:
    """Bundle of the pieces that together make up a serialized PTE file.

    Attributes:
        program: The Program itself.
        mutable_data: Mutable tensor data kept outside of the Program, one
            Buffer per tensor. Defaults to None.
        named_data: Named data kept outside of the Program. Defaults to None.
    """

    program: Program
    # TODO(lfq): add constant data (currently restored in the program)
    # TODO(lfq): update this to List[bytes]
    mutable_data: Optional[List[Buffer]] = None
    named_data: Optional[NamedDataStoreOutput] = None
63+
64+
4965
@dataclass
5066
class AlignedData:
5167
"""
@@ -575,7 +591,91 @@ def serialize_pte_binary(
575591
return pte_data
576592

577593

578-
def _restore_segments(program: Program, segment_data: bytes) -> Program:
594+
def _restore_delegates(program: Program, segments: List[bytes]) -> Program:
    """Inline every segment-backed delegate blob back into the Program.

    Each delegate whose processed data lives in a segment gets a new entry
    appended to `program.backend_delegate_data`, and its reference is rewritten
    to point at that inline entry. Delegates that are already inline are left
    untouched.

    Args:
        program: The Program holding non-inlined delegates. Modified in-place.
        segments: Per-segment byte blobs, indexed by segment index. Not
            modified.

    Returns:
        The same Program object, with delegates restored.

    Raises:
        ValueError: If a delegate refers to a segment index that is out of
            range for `segments`.
    """
    num_segments = len(segments)
    for plan_idx, execution_plan in enumerate(program.execution_plan):
        for delegate_idx, backend_delegate in enumerate(execution_plan.delegates):
            # Already-inline delegates need no work.
            if backend_delegate.processed.location == DataLocation.INLINE:
                continue
            assert backend_delegate.processed.location == DataLocation.SEGMENT

            segment_idx = backend_delegate.processed.index
            if segment_idx >= num_segments:
                raise ValueError(
                    f"Plan {plan_idx} delegate {delegate_idx} "
                    f"segment index {segment_idx} >= num segments {num_segments}"
                )

            # Append the blob, then point the delegate at the slot it landed in.
            program.backend_delegate_data.append(
                BackendDelegateInlineData(data=segments[segment_idx])
            )
            backend_delegate.processed = BackendDelegateDataReference(
                location=DataLocation.INLINE,
                index=len(program.backend_delegate_data) - 1,
            )
    return program
624+
625+
626+
def _restore_constant_segment(
627+
constant_segment: SubsegmentOffsets, segment_data: bytes
628+
) -> List[Buffer]:
629+
"""Convert constant and mutable tensors from a single byte-blob into a list of individual tensors.
630+
631+
Args:
632+
constant_segment: SubsegmentOffset with the offsets of each tensor.
633+
segment_data: byte data containing the tensors and padding. Not modified.
634+
635+
Returns:
636+
List[Buffer] containing each tensor in a separate object.
637+
"""
638+
buffers: List[Buffer] = []
639+
for i in range(len(constant_segment.offsets)):
640+
start_offset = constant_segment.offsets[i]
641+
# Note: this is the original end offset plus any padding between it and the next start offset
642+
end_offset = (
643+
constant_segment.offsets[i + 1]
644+
if i < len(constant_segment.offsets) - 1
645+
else len(segment_data)
646+
)
647+
buffers.append(Buffer(storage=segment_data[start_offset:end_offset]))
648+
return buffers
649+
650+
651+
def _restore_named_data(
652+
program: Program,
653+
segments: List[bytes],
654+
) -> NamedDataStoreOutput:
655+
"""Moves named data from `segments` and `program` into the
656+
NamedDataStoreOutput class.
657+
658+
Args:
659+
program: The Program holding named data references. Not modified.
660+
segments: The data containing the segments. Not modified.
661+
"""
662+
named_data_store = NamedDataStore()
663+
for entry in program.named_data:
664+
if entry.segment_index >= len(segments):
665+
raise ValueError(
666+
"Named data segment index "
667+
f"{entry.segment_index} >= num segments {len(segments)}"
668+
)
669+
named_data_store.add_named_data(
670+
key=entry.key,
671+
data=segments[entry.segment_index],
672+
alignment=1, # Deserialization does not preserve alignment.
673+
tensor_layout=None, # PTE file currently does not serialize this.
674+
)
675+
return named_data_store.get_named_data_store_output()
676+
677+
678+
def _restore_segments(program: Program, segment_data: bytes) -> PTEFile:
579679
"""Moves segments from `segment_data` into `program`.
580680
581681
This should recreate the original Program that the segments were extracted
@@ -589,7 +689,7 @@ def _restore_segments(program: Program, segment_data: bytes) -> Program:
589689
the preceding data has been stripped off so that the first segment
590690
begins at offset zero.
591691
Returns:
592-
The Program with segments restored.
692+
PTEFile, containing the Program with delegate and constant segments restored, mutable data segment, and named data segment.
593693
"""
594694
# Extract the list of segment data blobs, which parallel program.segments.
595695
segments: List[bytes] = []
@@ -600,53 +700,51 @@ def _restore_segments(program: Program, segment_data: bytes) -> Program:
600700
)
601701
segments.append(segment_data[segment.offset : segment.offset + segment.size])
602702

603-
# Find and replace the Program's references to these segments, inlining the
604-
# data.
605-
for plan_index, plan in enumerate(program.execution_plan):
606-
for delegate_index, delegate in enumerate(plan.delegates):
607-
if delegate.processed.location == DataLocation.INLINE:
608-
continue
609-
assert delegate.processed.location == DataLocation.SEGMENT
610-
index = delegate.processed.index
611-
if index >= len(segments):
612-
raise ValueError(
613-
f"Plan {plan_index} delegate {delegate_index} "
614-
+ f"segment index {index} >= num segments {len(segments)}"
615-
)
616-
617-
data_index: int = len(program.backend_delegate_data)
618-
program.backend_delegate_data.append(
619-
BackendDelegateInlineData(data=segments[index])
620-
)
621-
delegate.processed = BackendDelegateDataReference(
622-
location=DataLocation.INLINE, index=data_index
623-
)
703+
# Restore delegate segments that weren't inlined previously.
704+
program = _restore_delegates(program, segments)
624705

625706
# Replace constants from constant_segment into constant_buffer.
626707
if program.constant_segment and len(program.constant_segment.offsets) > 0:
627-
buffers: List[Buffer] = []
628-
constant_segment = segments[program.constant_segment.segment_index]
629-
for i in range(len(program.constant_segment.offsets)):
630-
start_offset = program.constant_segment.offsets[i]
631-
# Note: this is the original end offset plus any padding between
632-
# it and the next start offset.
633-
end_offset = (
634-
program.constant_segment.offsets[i + 1]
635-
if i < len(program.constant_segment.offsets) - 1
636-
else len(constant_segment)
708+
if program.constant_segment.segment_index >= len(segments):
709+
raise ValueError(
710+
f"Constant segment index {program.constant_segment.segment_index} >= num segments {len(segments)}"
637711
)
638-
buffers.append(Buffer(storage=constant_segment[start_offset:end_offset]))
639-
program.constant_buffer = buffers
712+
program.constant_buffer = _restore_constant_segment(
713+
program.constant_segment, segments[program.constant_segment.segment_index]
714+
)
640715
program.constant_segment.segment_index = 0
641716
program.constant_segment.offsets = []
642717

643-
# Clear out the segments list since the original Program didn't have one.
718+
# Extract mutable segments.
719+
mutable_data = None
720+
if program.mutable_data_segments and len(program.mutable_data_segments) > 0:
721+
if len(program.mutable_data_segments) > 1:
722+
raise ValueError("Can't handle more than 1 mutable data segment.")
723+
segment_index = program.mutable_data_segments[0].segment_index
724+
if segment_index >= len(segments):
725+
raise ValueError(
726+
f"Mutable data segment index {segment_index} >= num segments {len(segments)}"
727+
)
728+
mutable_data = _restore_constant_segment(
729+
program.mutable_data_segments[0],
730+
segments[segment_index],
731+
)
732+
program.mutable_data_segments = None
733+
734+
# Extract named data.
735+
named_data = None
736+
if program.named_data:
737+
named_data = _restore_named_data(program, segments)
738+
739+
# Clear named_data and segments, which are empty pre-serialization.
740+
program.named_data = []
644741
program.segments = []
645-
return program
646742

743+
return PTEFile(program=program, mutable_data=mutable_data, named_data=named_data)
647744

648-
def deserialize_pte_binary(program_data: bytes) -> Program:
649-
"""Returns a Program deserialized from the given runtime binary data."""
745+
746+
def deserialize_pte_binary(program_data: bytes) -> PTEFile:
747+
"""Returns a PTEFile deserialized from the given runtime binary data."""
650748
program_size = len(program_data)
651749
segment_base_offset = 0
652750

@@ -664,8 +762,8 @@ def deserialize_pte_binary(program_data: bytes) -> Program:
664762

665763
if segment_base_offset != 0:
666764
# Move segment data back into the Program.
667-
program = _restore_segments(
765+
return _restore_segments(
668766
program=program, segment_data=program_data[segment_base_offset:]
669767
)
670768

671-
return program
769+
return PTEFile(program=program, mutable_data=None, named_data=None)

0 commit comments

Comments
 (0)