2121 _program_flatbuffer_to_json ,
2222 _program_json_to_flatbuffer ,
2323)
24- from executorch .exir ._serialize ._named_data_store import NamedDataStoreOutput
24+ from executorch .exir ._serialize ._named_data_store import (
25+ NamedDataStore ,
26+ NamedDataStoreOutput ,
27+ )
2528
2629from executorch .exir ._serialize .data_serializer import DataEntry
2730
4649_HEADER_BYTEORDER : Literal ["little" ] = "little"
4750
4851
52+ @dataclass
53+ class PTEFile :
54+ """
55+ Wraps together the data required to serialize into a PTE file.
56+ """
57+
58+ program : Program
59+ # TODO(lfq): add constant data (currently restored in the program)
60+ # TODO(lfq): update this to List[bytes]
61+ mutable_data : Optional [List [Buffer ]] = None
62+ named_data : Optional [NamedDataStoreOutput ] = None
63+
64+
4965@dataclass
5066class AlignedData :
5167 """
@@ -575,7 +591,91 @@ def serialize_pte_binary(
575591 return pte_data
576592
577593
578- def _restore_segments (program : Program , segment_data : bytes ) -> Program :
594+ def _restore_delegates (program : Program , segments : List [bytes ]) -> Program :
595+ """Find and replace the Program's references to these segments, inlining
596+ the data.
597+
598+ Args:
599+ program: The Program holding non-inlined delegates. Modified in-place.
600+ segments: List of bytes containing the delegate data. Not modified.
601+
602+ Returns: The Program with delegates restored.
603+ """
604+ for plan_index , plan in enumerate (program .execution_plan ):
605+ for delegate_index , delegate in enumerate (plan .delegates ):
606+ if delegate .processed .location == DataLocation .INLINE :
607+ continue
608+ assert delegate .processed .location == DataLocation .SEGMENT
609+ index = delegate .processed .index
610+ if index >= len (segments ):
611+ raise ValueError (
612+ f"Plan { plan_index } delegate { delegate_index } "
613+ + f"segment index { index } >= num segments { len (segments )} "
614+ )
615+
616+ data_index : int = len (program .backend_delegate_data )
617+ program .backend_delegate_data .append (
618+ BackendDelegateInlineData (data = segments [index ])
619+ )
620+ delegate .processed = BackendDelegateDataReference (
621+ location = DataLocation .INLINE , index = data_index
622+ )
623+ return program
624+
625+
626+ def _restore_constant_segment (
627+ constant_segment : SubsegmentOffsets , segment_data : bytes
628+ ) -> List [Buffer ]:
629+ """Convert constant and mutable tensors from a single byte-blob into a list of individual tensors.
630+
631+ Args:
632+ constant_segment: SubsegmentOffset with the offsets of each tensor.
633+ segment_data: byte data containing the tensors and padding. Not modified.
634+
635+ Returns:
636+ List[Buffer] containing each tensor in a separate object.
637+ """
638+ buffers : List [Buffer ] = []
639+ for i in range (len (constant_segment .offsets )):
640+ start_offset = constant_segment .offsets [i ]
641+ # Note: this is the original end offset plus any padding between it and the next start offset
642+ end_offset = (
643+ constant_segment .offsets [i + 1 ]
644+ if i < len (constant_segment .offsets ) - 1
645+ else len (segment_data )
646+ )
647+ buffers .append (Buffer (storage = segment_data [start_offset :end_offset ]))
648+ return buffers
649+
650+
651+ def _restore_named_data (
652+ program : Program ,
653+ segments : List [bytes ],
654+ ) -> NamedDataStoreOutput :
655+ """Moves named data from `segments` and `program` into the
656+ NamedDataStoreOutput class.
657+
658+ Args:
659+ program: The Program holding named data references. Not modified.
660+ segments: The data containing the segments. Not modified.
661+ """
662+ named_data_store = NamedDataStore ()
663+ for entry in program .named_data :
664+ if entry .segment_index >= len (segments ):
665+ raise ValueError (
666+ "Named data segment index "
667+ f"{ entry .segment_index } >= num segments { len (segments )} "
668+ )
669+ named_data_store .add_named_data (
670+ key = entry .key ,
671+ data = segments [entry .segment_index ],
672+ alignment = 1 , # Deserialization does not preserve alignment.
673+ tensor_layout = None , # PTE file currently does not serialize this.
674+ )
675+ return named_data_store .get_named_data_store_output ()
676+
677+
678+ def _restore_segments (program : Program , segment_data : bytes ) -> PTEFile :
579679 """Moves segments from `segment_data` into `program`.
580680
581681 This should recreate the original Program that the segments were extracted
@@ -589,7 +689,7 @@ def _restore_segments(program: Program, segment_data: bytes) -> Program:
589689 the preceding data has been stripped off so that the first segment
590690 begins at offset zero.
591691 Returns:
592- The Program with segments restored.
692+ PTEFile, containing the Program with delegate and constant segments restored, mutable data segment, and named data segment .
593693 """
594694 # Extract the list of segment data blobs, which parallel program.segments.
595695 segments : List [bytes ] = []
@@ -600,53 +700,51 @@ def _restore_segments(program: Program, segment_data: bytes) -> Program:
600700 )
601701 segments .append (segment_data [segment .offset : segment .offset + segment .size ])
602702
603- # Find and replace the Program's references to these segments, inlining the
604- # data.
605- for plan_index , plan in enumerate (program .execution_plan ):
606- for delegate_index , delegate in enumerate (plan .delegates ):
607- if delegate .processed .location == DataLocation .INLINE :
608- continue
609- assert delegate .processed .location == DataLocation .SEGMENT
610- index = delegate .processed .index
611- if index >= len (segments ):
612- raise ValueError (
613- f"Plan { plan_index } delegate { delegate_index } "
614- + f"segment index { index } >= num segments { len (segments )} "
615- )
616-
617- data_index : int = len (program .backend_delegate_data )
618- program .backend_delegate_data .append (
619- BackendDelegateInlineData (data = segments [index ])
620- )
621- delegate .processed = BackendDelegateDataReference (
622- location = DataLocation .INLINE , index = data_index
623- )
703+ # Restore delegate segments that weren't inlined previously.
704+ program = _restore_delegates (program , segments )
624705
625706 # Replace constants from constant_segment into constant_buffer.
626707 if program .constant_segment and len (program .constant_segment .offsets ) > 0 :
627- buffers : List [Buffer ] = []
628- constant_segment = segments [program .constant_segment .segment_index ]
629- for i in range (len (program .constant_segment .offsets )):
630- start_offset = program .constant_segment .offsets [i ]
631- # Note: this is the original end offset plus any padding between
632- # it and the next start offset.
633- end_offset = (
634- program .constant_segment .offsets [i + 1 ]
635- if i < len (program .constant_segment .offsets ) - 1
636- else len (constant_segment )
708+ if program .constant_segment .segment_index >= len (segments ):
709+ raise ValueError (
710+ f"Constant segment index { program .constant_segment .segment_index } >= num segments { len (segments )} "
637711 )
638- buffers .append (Buffer (storage = constant_segment [start_offset :end_offset ]))
639- program .constant_buffer = buffers
712+ program .constant_buffer = _restore_constant_segment (
713+ program .constant_segment , segments [program .constant_segment .segment_index ]
714+ )
640715 program .constant_segment .segment_index = 0
641716 program .constant_segment .offsets = []
642717
643- # Clear out the segments list since the original Program didn't have one.
718+ # Extract mutable segments.
719+ mutable_data = None
720+ if program .mutable_data_segments and len (program .mutable_data_segments ) > 0 :
721+ if len (program .mutable_data_segments ) > 1 :
722+ raise ValueError ("Can't handle more than 1 mutable data segment." )
723+ segment_index = program .mutable_data_segments [0 ].segment_index
724+ if segment_index >= len (segments ):
725+ raise ValueError (
726+ f"Mutable data segment index { segment_index } >= num segments { len (segments )} "
727+ )
728+ mutable_data = _restore_constant_segment (
729+ program .mutable_data_segments [0 ],
730+ segments [segment_index ],
731+ )
732+ program .mutable_data_segments = None
733+
734+ # Extract named data.
735+ named_data = None
736+ if program .named_data :
737+ named_data = _restore_named_data (program , segments )
738+
739+ # Clear named_data and segments, which are empty pre-serialization.
740+ program .named_data = []
644741 program .segments = []
645- return program
646742
743+ return PTEFile (program = program , mutable_data = mutable_data , named_data = named_data )
647744
648- def deserialize_pte_binary (program_data : bytes ) -> Program :
649- """Returns a Program deserialized from the given runtime binary data."""
745+
746+ def deserialize_pte_binary (program_data : bytes ) -> PTEFile :
747+ """Returns a PTEFile deserialized from the given runtime binary data."""
650748 program_size = len (program_data )
651749 segment_base_offset = 0
652750
@@ -664,8 +762,8 @@ def deserialize_pte_binary(program_data: bytes) -> Program:
664762
665763 if segment_base_offset != 0 :
666764 # Move segment data back into the Program.
667- program = _restore_segments (
765+ return _restore_segments (
668766 program = program , segment_data = program_data [segment_base_offset :]
669767 )
670768
671- return program
769+ return PTEFile ( program = program , mutable_data = None , named_data = None )
0 commit comments