Skip to content

Commit 8fc19e4

Browse files
committed
Do not restore constant_buffer, and use bytes instead of Buffer
Pull Request resolved: #15802. All constants are serialized in the segment (none in the Program). This PR: (1) places constant data into the PTEFile class instead of restoring it into the Program; (2) uses List[bytes] instead of List[Buffer] for constant and mutable data. Buffer was initially used to maintain alignment; now that constants are serialized with alignment in the segment, Buffer is no longer required. Note on the non-const tensor placeholder: if there are no constants, no placeholder is emitted; if there are constants, the placeholder is included. Tests are updated accordingly. After this, we can mark 'constant_buffer' as deprecated, as it is no longer used in deserialization or in the emitter. Differential Revision: [D86913756](https://our.internmc.facebook.com/intern/diff/D86913756/) ghstack-source-id: 322868586
1 parent 082e69f commit 8fc19e4

File tree

9 files changed

+154
-126
lines changed

9 files changed

+154
-126
lines changed

devtools/bundled_program/test/test_bundle_data.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,16 @@ def test_bundled_program(self) -> None:
7070
method_test_case.expected_outputs,
7171
)
7272

73+
emitter_output = executorch_program._emitter_output
7374
self.assertEqual(
7475
bundled_program.serialize_to_schema().program,
7576
bytes(
7677
_serialize_pte_binary(
77-
pte_file=_PTEFile(program=executorch_program.executorch_program)
78+
pte_file=_PTEFile(
79+
program=executorch_program.executorch_program,
80+
constant_data=emitter_output.constant_data,
81+
mutable_data=emitter_output.mutable_data,
82+
)
7883
)
7984
),
8085
)
@@ -116,10 +121,18 @@ def test_bundled_program_from_pte(self) -> None:
116121
bundled_program_ioset.expected_outputs,
117122
method_test_case.expected_outputs,
118123
)
119-
124+
emitter_output = executorch_program._emitter_output
120125
self.assertEqual(
121126
bundled_program.serialize_to_schema().program,
122-
executorch_program.buffer,
127+
bytes(
128+
_serialize_pte_binary(
129+
pte_file=_PTEFile(
130+
program=executorch_program.executorch_program,
131+
constant_data=emitter_output.constant_data,
132+
mutable_data=emitter_output.mutable_data,
133+
)
134+
)
135+
),
123136
)
124137

125138
def test_bundled_miss_methods(self) -> None:

exir/_serialize/_program.py

Lines changed: 37 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import math
1212
import re
1313

14-
from dataclasses import dataclass
14+
from dataclasses import dataclass, field
1515
from typing import ClassVar, Dict, List, Literal, Optional, Sequence, Tuple
1616

1717
from executorch.exir._serialize._cord import Cord
@@ -33,7 +33,6 @@
3333
from executorch.exir.schema import (
3434
BackendDelegateDataReference,
3535
BackendDelegateInlineData,
36-
Buffer,
3736
DataLocation,
3837
DataSegment,
3938
NamedData,
@@ -56,9 +55,10 @@ class PTEFile:
5655
"""
5756

5857
program: Program
59-
# TODO(lfq): add constant data (currently restored in the program)
60-
# TODO(lfq): update this to List[bytes]
61-
mutable_data: Optional[List[Buffer]] = None
58+
# Placeholder for non-const tensors.
59+
constant_data: List[bytes] = field(default_factory=lambda: [b""])
60+
# Placeholder for non-const tensors.
61+
mutable_data: List[bytes] = field(default_factory=lambda: [b""])
6262
named_data: Optional[NamedDataStoreOutput] = None
6363

6464

@@ -346,14 +346,14 @@ def _extract_delegate_segments(
346346

347347

348348
def _extract_constant_segment(
349-
constant_buffer: List[Buffer],
349+
constant_buffer: List[bytes],
350350
tensor_alignment: Optional[int] = None,
351351
) -> Tuple[Cord, List[int]]:
352352
"""Copies the tensors from the provided list into a Cord and tracks the offsets
353353
of each tensor.
354354
355355
Args:
356-
constant_buffer: list of Buffers from which to extract constants from. Not modified.
356+
constant_buffer: list of bytes from which to extract constants from. Not modified.
357357
tensor_alignment: Alignment in bytes. Each tensor in the cord will be padded to align
358358
with this value. Defaults to ALIGNMENT.
359359
@@ -365,8 +365,8 @@ def _extract_constant_segment(
365365
current_offset: int = 0
366366
for i in range(len(constant_buffer)):
367367
buffer = constant_buffer[i]
368-
constant_segment_data.append(buffer.storage)
369-
buffer_length = len(buffer.storage)
368+
constant_segment_data.append(buffer)
369+
buffer_length = len(buffer)
370370
pad_length = (
371371
padding_required(buffer_length, tensor_alignment)
372372
if tensor_alignment is not None
@@ -460,25 +460,24 @@ def serialize_pte_binary(
460460
# This may be constant data, delegate data or named data.
461461
segments: List[AlignedData] = []
462462

463-
constant_segment_data, constant_segment_offsets = _extract_constant_segment(
464-
program.constant_buffer, tensor_alignment=constant_tensor_alignment
465-
)
466-
467-
# If there are no constants, len(constant_segment_data) = 0. However, there may
468-
# be non-constants, in which case len(constant_segment_offsets) = 1, containing
469-
# the placeholder value 0. Ensure the placeholder value is put into
470-
# program.constant_segment.offsets.
471-
if len(constant_segment_offsets) > 0:
472-
# Update program.constant_segment with constant subsegment offset information.
473-
program.constant_segment = SubsegmentOffsets(
474-
segment_index=len(segments), offsets=constant_segment_offsets
463+
if len(pte_file.constant_data) > 1:
464+
constant_segment_data, constant_segment_offsets = _extract_constant_segment(
465+
pte_file.constant_data, tensor_alignment=constant_tensor_alignment
475466
)
476-
# Clear the constant buffer, as constant data will be stored in segments.
477-
program.constant_buffer = []
478-
# Add to the aggregate segments cord.
479-
segments.append(AlignedData(constant_segment_data))
480467

481-
if pte_file.mutable_data is not None:
468+
# If there are no constants, len(constant_segment_data) = 0. However, there may
469+
# be non-constants, in which case len(constant_segment_offsets) = 1, containing
470+
# the placeholder value 0. Ensure the placeholder value is put into
471+
# program.constant_segment.offsets.
472+
if len(constant_segment_offsets) > 0:
473+
# Update program.constant_segment with constant subsegment offset information.
474+
program.constant_segment = SubsegmentOffsets(
475+
segment_index=len(segments), offsets=constant_segment_offsets
476+
)
477+
# Add to the aggregate segments cord.
478+
segments.append(AlignedData(constant_segment_data))
479+
480+
if len(pte_file.mutable_data) > 1:
482481
mutable_segment_data, mutable_segment_offsets = _extract_constant_segment(
483482
pte_file.mutable_data,
484483
tensor_alignment=None, # data is copied at Method load so no need to align.
@@ -637,8 +636,9 @@ def _restore_segments(program: Program, segment_data: bytes) -> PTEFile:
637636
)
638637

639638
# Replace constants from constant_segment into constant_buffer.
639+
constant_data = None
640640
if program.constant_segment and len(program.constant_segment.offsets) > 0:
641-
constant_buffers: List[Buffer] = []
641+
constant_buffers: List[bytes] = []
642642
constant_segment = segments[program.constant_segment.segment_index]
643643
for i in range(len(program.constant_segment.offsets)):
644644
start_offset = program.constant_segment.offsets[i]
@@ -649,17 +649,15 @@ def _restore_segments(program: Program, segment_data: bytes) -> PTEFile:
649649
if i < len(program.constant_segment.offsets) - 1
650650
else len(constant_segment)
651651
)
652-
constant_buffers.append(
653-
Buffer(storage=constant_segment[start_offset:end_offset])
654-
)
655-
program.constant_buffer = constant_buffers
652+
constant_buffers.append(constant_segment[start_offset:end_offset])
653+
constant_data = constant_buffers
656654
program.constant_segment.segment_index = 0
657655
program.constant_segment.offsets = []
658656

659657
# Extract mutable segments.
660658
mutable_data = None
661659
if program.mutable_data_segments and len(program.mutable_data_segments.offsets) > 0:
662-
mutable_buffers: List[Buffer] = []
660+
mutable_buffers: List[bytes] = []
663661
mutable_segment = segments[program.mutable_segment.segment_index]
664662
for i in range(len(program.mutable_segments.offsets)):
665663
start_offset = program.mutable_segment.offsets[i]
@@ -670,9 +668,7 @@ def _restore_segments(program: Program, segment_data: bytes) -> PTEFile:
670668
if i < len(program.mutable_segment.offsets) - 1
671669
else len(mutable_segment)
672670
)
673-
mutable_buffers.append(
674-
Buffer(storage=mutable_segment[start_offset:end_offset])
675-
)
671+
mutable_buffers.append(mutable_segment[start_offset:end_offset])
676672
mutable_data = mutable_buffers
677673
program.mutable_segment.segment_index = 0
678674
program.mutable_segment.offsets = []
@@ -699,7 +695,12 @@ def _restore_segments(program: Program, segment_data: bytes) -> PTEFile:
699695
named_data = named_data_store.get_named_data_store_output()
700696
program.named_data = []
701697
program.segments = []
702-
return PTEFile(program=program, mutable_data=mutable_data, named_data=named_data)
698+
return PTEFile(
699+
program=program,
700+
constant_data=constant_data,
701+
mutable_data=mutable_data,
702+
named_data=named_data,
703+
)
703704

704705

705706
def deserialize_pte_binary(program_data: bytes) -> PTEFile:

exir/_serialize/_serialize.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def serialize_for_executorch(
4949
pte: Cord = serialize_pte_binary(
5050
pte_file=PTEFile(
5151
program=emitter_output.program,
52+
constant_data=emitter_output.constant_data,
5253
mutable_data=emitter_output.mutable_data,
5354
named_data=pte_named_data,
5455
),

exir/_serialize/test/test_program.py

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,6 @@
4949
CONSTANT_TENSOR_ALIGNMENT: int = 16
5050

5151

52-
def add_constant_data(program: Program, blobs: Sequence[bytes]) -> None:
53-
"""Adds the provided constant data blobs to the program."""
54-
for blob in blobs:
55-
program.constant_buffer.append(Buffer(storage=blob))
56-
57-
5852
def add_delegate_data(
5953
program: Program, plan: ExecutionPlan, blobs: Sequence[bytes]
6054
) -> None:
@@ -169,12 +163,14 @@ def constant_segment_with_tensor_alignment(
169163
self.gen_blob_data(constant_tensor_alignment, b"\x30\x33\x03"),
170164
self.gen_blob_data(constant_tensor_alignment + 1, b"\x40\x44\x04"),
171165
)
172-
add_constant_data(program, blobs)
173166

174167
# Extract blobs into constant segment during serialization.
175168
pte_data = bytes(
176169
serialize_pte_binary(
177-
PTEFile(program=program),
170+
PTEFile(
171+
program=program,
172+
constant_data=blobs,
173+
),
178174
segment_alignment=SEGMENT_ALIGNMENT,
179175
constant_tensor_alignment=constant_tensor_alignment,
180176
)
@@ -289,9 +285,7 @@ def constant_segment_with_tensor_alignment(
289285
# during serialization.
290286
self.assertEqual(deserialized.program.execution_plan, program.execution_plan)
291287
# Number of constant tensors should be the same.
292-
self.assertEqual(
293-
len(deserialized.program.constant_buffer), len(program.constant_buffer)
294-
)
288+
self.assertEqual(len(deserialized.constant_data), len(blobs))
295289
self.assertEqual(deserialized.mutable_data, None)
296290
self.assertEqual(deserialized.named_data, None)
297291

@@ -647,12 +641,13 @@ def test_round_trip_with_segments(self) -> None:
647641

648642
def test_no_constants(self) -> None:
649643
program = get_test_program()
650-
# Insert placeholder for non-const tensors.
651-
add_constant_data(program, [b""])
652644

653645
pte_data = bytes(
654646
serialize_pte_binary(
655-
PTEFile(program=program),
647+
PTEFile(
648+
program=program,
649+
constant_data=[b""], # placeholder for non-const tensors.
650+
),
656651
extract_delegate_segments=True,
657652
segment_alignment=SEGMENT_ALIGNMENT,
658653
constant_tensor_alignment=CONSTANT_TENSOR_ALIGNMENT,
@@ -667,10 +662,9 @@ def test_no_constants(self) -> None:
667662
# Constant buffer should be empty.
668663
self.assertEqual(len(flatbuffer_program.constant_buffer), 0)
669664

670-
# Constant segment should contain the placeholder.
665+
# Constant segment also empty
671666
self.assertEqual(flatbuffer_program.constant_segment.segment_index, 0)
672-
self.assertEqual(len(flatbuffer_program.constant_segment.offsets), 1)
673-
self.assertEqual(flatbuffer_program.constant_segment.offsets[0], 0)
667+
self.assertEqual(len(flatbuffer_program.constant_segment.offsets), 0)
674668

675669
def test_unused_inline_delegate_blobs_with_segments(self) -> None:
676670
# Create a program with some delegate data blobs.
@@ -736,7 +730,6 @@ def test_constant_delegate_and_named_data_segments(self) -> None:
736730
self.gen_blob_data(SEGMENT_ALIGNMENT // 2, b"\x30\x33\x03"),
737731
self.gen_blob_data(SEGMENT_ALIGNMENT + 1, b"\x40\x44\x04"),
738732
)
739-
add_constant_data(program, constant_blobs)
740733
add_delegate_data(program, program.execution_plan[0], delegate_blobs)
741734

742735
# Create named data segment.
@@ -755,7 +748,9 @@ def test_constant_delegate_and_named_data_segments(self) -> None:
755748
# Extract the blobs into segments during serialization.
756749
pte_data = bytes(
757750
serialize_pte_binary(
758-
PTEFile(program=program, named_data=named_data),
751+
PTEFile(
752+
program=program, constant_data=constant_blobs, named_data=named_data
753+
),
759754
extract_delegate_segments=True,
760755
segment_alignment=SEGMENT_ALIGNMENT,
761756
constant_tensor_alignment=CONSTANT_TENSOR_ALIGNMENT,
@@ -933,9 +928,7 @@ def test_constant_delegate_and_named_data_segments(self) -> None:
933928
# during serialization.
934929
self.assertEqual(deserialized.program.execution_plan, program.execution_plan)
935930
# Number of constant tensors should be the same.
936-
self.assertEqual(
937-
len(deserialized.program.constant_buffer), len(program.constant_buffer)
938-
)
931+
self.assertEqual(len(deserialized.constant_data), len(constant_blobs))
939932
self.assertEqual(deserialized.mutable_data, None)
940933
self._check_named_data_store_output(deserialized.named_data, named_data)
941934

exir/emit/_emit_program.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
# pyre-strict
8-
from dataclasses import dataclass
8+
from dataclasses import dataclass, field
99
from typing import Any, Dict, List, Optional, Union
1010

1111
import torch
@@ -50,14 +50,17 @@ class EmitterOutput:
5050
# generated by each instruction.
5151
instruction_id_to_num_outs_map: Dict[str, Dict[int, int]]
5252

53-
mutable_data: Optional[List[Buffer]]
53+
# Constant data stored in the PTE file.
54+
constant_data: List[bytes] = field(default_factory=list)
55+
# Mutable data stored in the PTE file.
56+
mutable_data: List[bytes] = field(default_factory=list)
5457

5558
# Constants are optionally stored in external files.
5659
# Aggregate unique external constants into one buffer.
57-
external_constant_buffer: List[bytes]
60+
external_constant_buffer: Optional[List[bytes]] = None
5861
# Each constant_tag groups a set of constants together.
5962
# {constant_tag: {fqn: index into external_constant_buffer}}
60-
external_constant_map: Optional[Dict[str, Dict[str, int]]]
63+
external_constant_map: Optional[Dict[str, Dict[str, int]]] = None
6164

6265

6366
def _remove_non_user_outputs(exported_program: ExportedProgram) -> torch.fx.GraphModule:
@@ -198,18 +201,23 @@ def emit_program(
198201
program=Program(
199202
version=EXECUTORCH_SCHEMA_VERSION,
200203
execution_plan=plans,
201-
constant_buffer=program_state.constant_buffer,
204+
constant_buffer=[], # Do not add constants here anymore.
202205
backend_delegate_data=program_state.backend_delegate_data,
203206
# Segments may be added at serialization time.
204207
segments=[],
205208
# Subsegment offsets may be added at serialization time.
206209
constant_segment=SubsegmentOffsets(segment_index=0, offsets=[]),
207210
mutable_data_segments=None, # Will be filled in during serialization
208211
),
212+
constant_data=(
213+
program_state.constant_buffer
214+
if len(program_state.constant_buffer) > 1
215+
else []
216+
),
209217
mutable_data=(
210218
program_state.mutable_buffer
211219
if len(program_state.mutable_buffer) > 1
212-
else None
220+
else []
213221
),
214222
external_constant_buffer=program_state.external_constant_buffer,
215223
external_constant_map=program_state.external_constant_map,

0 commit comments

Comments
 (0)