@@ -13,7 +13,6 @@
 from collections import defaultdict
 
 import torch
-from functorch import make_fx
 from graph_net.torch import utils
 
 
@@ -376,62 +375,6 @@ def collect_op_stats_with_symbolic_trace(model, sample_inputs, device):
     return meta_executor.is_complete, meta_executor.op_stats
 
 
-def collect_op_stats_with_make_fx(model, sample_inputs):
-    # Use meta tensors as input to avoid actually running the model
-    meta_input_list = convert_real_to_meta(sample_inputs)
-
-    try:
-        # Generate FX Graph, and automatically fill in meta information
-        fx_model = make_fx(model)(*meta_input_list)
-    except Exception:
-        print("Failed to execute make_fx")
-        return False, None
-
-    is_complete = True
-    op_stats = {}
-    for node in fx_model.graph.nodes:
-        op_name = None
-        if node.op == "call_module":
-            # classname of module
-            submod = fx_model.get_submodule(node.target)
-            op_name = submod.__class__.__name__
-        elif node.op == "call_function":
-            op_name = node.target.__name__
-        elif node.op == "call_method":
-            op_name = node.target
-        elif node.op in ["placeholder", "output", "get_attr"]:
-            op_name = node.op
-        else:
-            assert False, f"node.op: {node.op}"
-
-        dtype = None
-        if node.op not in ["placeholder", "output"]:
-            if "tensor_meta" in node.meta:
-                tensor_meta = node.meta["tensor_meta"]
-                dtype = tensor_meta.dtype
-                # print(f"node.op={node.op}, node.target={node.target}, dtype={tensor_meta.dtype}")
-            else:
-                print(
-                    f"node.op={node.op}, node.target={node.target} has no tensor_meta!"
-                )
-                is_complete = False
-
-        op_name = (
-            op_name.replace(".default", "")
-            .replace(".Tensor", "")
-            .replace(".Scalar", "")
-        )
-        dtype_str = str(dtype).replace("torch.", "")
-        if op_stats.get(op_name, None) is None:
-            op_stats[op_name] = OpStat(op_name, {dtype_str: 1}, 1)
-        else:
-            op_stats[op_name].op_dtypes[dtype_str] = (
-                op_stats[op_name].op_dtypes.get(dtype_str, 0) + 1
-            )
-            op_stats[op_name].count = op_stats[op_name].count + 1
-    return is_complete, op_stats
-
-
 def collect_op_stats(model, sample_inputs, device):
     is_complete_symbolic, op_stats_symbolic = collect_op_stats_with_symbolic_trace(
         model, sample_inputs, device
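For context, the deleted helper traced the model with make_fx over meta tensors so op and dtype statistics could be gathered without running any real computation (its own comment: "Use meta tensors as input to avoid actually running the model"). Below is a minimal standalone sketch of that idea, not this repository's code: TinyModel and count_ops_on_meta are made-up names for illustration, and make_fx is imported from torch.fx.experimental.proxy_tensor, which in current PyTorch hosts the same function the removed functorch import provided.

# Minimal sketch (illustrative, not from this repository): trace a model with
# make_fx on meta-tensor inputs and tally the traced call_function ops,
# mirroring what the removed collect_op_stats_with_make_fx helper did.
from collections import Counter

import torch
from torch.fx.experimental.proxy_tensor import make_fx  # same function the functorch import exposed


class TinyModel(torch.nn.Module):  # hypothetical example model
    def forward(self, x):
        return torch.relu(x) + 1


def count_ops_on_meta(model, sample_input):  # hypothetical helper name
    # Meta tensors carry only shape/dtype, so tracing never touches real data.
    meta_input = sample_input.to("meta")
    fx_model = make_fx(model)(meta_input)

    counts = Counter()
    for node in fx_model.graph.nodes:
        if node.op == "call_function":
            # aten overload names look like "relu.default" or "add.Tensor";
            # the removed helper stripped these suffixes before counting.
            counts[node.target.__name__] += 1
    return counts


print(count_ops_on_meta(TinyModel(), torch.randn(4, 4)))
# Typical output on recent PyTorch: Counter({'relu.default': 1, 'add.Tensor': 1})

Because the trace runs entirely on the meta device, this path is cheap and device-independent, which is what made it attractive alongside the symbolic-trace collector that remains in the file.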