@@ -177,14 +177,13 @@ def _build_function_memory_map(function_metadata):
177177 """
178178 device_max_workspace = dict ()
179179 main_func_metadata = function_metadata [MAIN_FUNC_NAME_STR ]
180- num_targets = len (main_func_metadata .workspace_sizes . items () )
180+ main_targets = dict (main_func_metadata .workspace_sizes ). keys ( )
181181 from tvm .driver import tvmc # pylint: disable=import-outside-toplevel
182182
183183 external_codegens = tvmc .composite_target .get_codegen_names ()
184184 func_entries = []
185185 target_local_entries = dict ()
186- for i in range (num_targets ):
187- main_target = main_func_metadata .workspace_sizes .items ()[i ][0 ]
186+ for main_target in main_targets :
188187 device_max_workspace [main_target ] = 0
189188 for func_name , finfo in function_metadata .items ():
190189 if func_name == MAIN_FUNC_NAME_STR :
@@ -197,22 +196,18 @@ def _build_function_memory_map(function_metadata):
197196 # 2. BYOC operator implementations do not currently export useful FunctionInfo.
198197 if func_name == MAIN_FUNC_NAME_STR or not finfo .tir_primfuncs :
199198 continue
200- assert (
201- len (finfo .constant_sizes .items ()) == num_targets
202- ), f"{ func_name } : found { finfo .constant_sizes !r} vs { num_targets } "
203- assert len (finfo .io_sizes .items ()) == num_targets
204- target = finfo .workspace_sizes .items ()[i ][0 ]
205- workspace_size = finfo .workspace_sizes .items ()[i ][1 ]
206- target_entry = {
207- "device" : int (target .kind .device_type ),
208- "workspace_size_bytes" : int (workspace_size ),
209- }
210- target_local_entries [func_name ].append (target_entry )
211- if workspace_size > device_max_workspace .get (target , 0 ):
212- device_max_workspace [target ] = workspace_size
213- # TODO(Mousius) - Remove this massive hack when Targets are unified
214- if target .kind .name in external_codegens :
215- device_max_workspace [main_target ] += int (workspace_size )
199+ if main_target in finfo .workspace_sizes .keys ():
200+ workspace_size = finfo .workspace_sizes [main_target ]
201+ target_entry = {
202+ "device" : int (main_target .kind .device_type ),
203+ "workspace_size_bytes" : int (workspace_size ),
204+ }
205+ target_local_entries [func_name ].append (target_entry )
206+ if workspace_size > device_max_workspace .get (main_target , 0 ):
207+ device_max_workspace [main_target ] = workspace_size
208+ # TODO(Mousius) - Remove this massive hack when Targets are unified
209+ if main_target .kind .name in external_codegens :
210+ device_max_workspace [main_target ] += int (workspace_size )
216211
217212 for func_name , target_entries_ in target_local_entries .items ():
218213 func_entry = {
@@ -222,15 +217,22 @@ def _build_function_memory_map(function_metadata):
222217 func_entries .append (func_entry )
223218
224219 target_main_entries = list ()
225- for i in range (num_targets ):
226- target = main_func_metadata .workspace_sizes .items ()[i ][0 ]
227- main_func_local_workspace = main_func_metadata .workspace_sizes .items ()[i ][1 ]
228- main_func_constants = main_func_metadata .constant_sizes .items ()[i ][1 ]
229- main_func_io = main_func_metadata .io_sizes .items ()[i ][1 ]
220+ for main_target in main_targets :
221+ main_func_local_workspace = main_func_metadata .workspace_sizes [main_target ]
222+ main_func_constants = (
223+ main_func_metadata .constant_sizes [main_target ]
224+ if main_target in main_func_metadata .constant_sizes .keys ()
225+ else 0
226+ )
227+ main_func_io = (
228+ main_func_metadata .io_sizes [main_target ]
229+ if main_target in main_func_metadata .io_sizes .keys ()
230+ else 0
231+ )
230232 target_main_entries .append (
231233 {
232- "device" : int (target .kind .device_type ),
233- "workspace_size_bytes" : int (device_max_workspace [target ])
234+ "device" : int (main_target .kind .device_type ),
235+ "workspace_size_bytes" : int (device_max_workspace [main_target ])
234236 + int (main_func_local_workspace ),
235237 "constants_size_bytes" : int (main_func_constants ),
236238 "io_size_bytes" : int (main_func_io ),
0 commit comments