@@ -108,34 +108,62 @@ def preprocess(
108108 options : dict [str , typing .Any ] = {
109109 # Do not link against the full PyTorch/libtorch library
110110 "aot_inductor.link_libtorch" : False ,
111- # Package model constants and other generated files directly in the shared object (.so) file
112- "aot_inductor.package_constants_in_so" : True ,
111+ # Separate weight constants from the .so file
112+ "aot_inductor.package" : True ,
113+ "aot_inductor.package_constants_in_so" : False ,
114+ # Store weight constants on disk in a binary blob
115+ "aot_inductor.package_constants_on_disk_format" : "binary_blob" ,
113116 # Enable maximum automatic tuning for optimal performance
114117 "max_autotune" : True ,
115118 # "aot_inductor.debug_compile": True,
116119 # "aot_inductor.force_mmap_weights": False,
117120 }
118121
119122 with collect_unsupported_fallback_kernels ():
120- so_path = torch ._inductor .aot_compile (edge_program_module , tuple (user_input_placeholders ), options = options ) # type: ignore[arg-type]
123+ paths = torch ._inductor .aot_compile (edge_program_module , tuple (user_input_placeholders ), options = options ) # type: ignore[arg-type]
121124 if len (missing_fallback_kernels ) > 0 :
122125 formatted_kernels = "\n - " .join (sorted (missing_fallback_kernels ))
123126 raise RuntimeError (
124127 f"Missing fallback kernels ({ len (missing_fallback_kernels )} total):\n - { formatted_kernels } \n "
125128 "Please add them to the AOTI backend."
126129 )
127130
131+ # Extract the .so and .blob paths from the returned list
132+ so_path = None
133+ blob_path = None
134+ for path in paths :
135+ if path .endswith (".wrapper.so" ):
136+ so_path = path
137+ elif path .endswith (".wrapper_weights.blob" ):
138+ blob_path = path
139+
140+ if so_path is None or blob_path is None :
141+ raise RuntimeError (
142+ f"Could not find required files in compiled paths, got { paths } "
143+ )
144+
128145 # pyre-ignorep[6]: Incompatible parameter type
129146 with open (so_path , "rb" ) as f :
130147 so_data = f .read ()
131148
132149 named_data_store = NamedDataStore ()
133150 method_name = MetalBackend .method_name_from_compile_specs (compile_specs )
151+
152+ # Keep the so file in the NamedDataStore, so that it can be packaged into the .pte file.
153+ named_data_store .add_named_data (method_name + "_so_blob" , so_data , 1 , None )
154+
155+ # Add weights blob to named data store
156+ with open (blob_path , "rb" ) as f :
157+ blob_data = f .read ()
158+
134159 named_data_store .add_named_data (
135- method_name + "_so_blob " , so_data , 1 , "aoti_metal_blob"
160+ method_name + "_weights_blob " , blob_data , 1 , "aoti_metal_blob"
136161 )
137162
138- # Clean up the generated so file; it has been packaged into the NamdeDataStore
163+ # Clean up the weights blob file
164+ os .remove (blob_path )
165+
166+ # Clean up the generated so file; it has been packaged into the NamedDataStore
139167 # pyre-ignorep[6]: Incompatible parameter type
140168 os .remove (so_path )
141169
0 commit comments