|
9 | 9 | from collections import namedtuple |
10 | 10 | from contextlib import contextmanager |
11 | 11 | from types import MethodType |
12 | | -from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Union |
| 12 | +from typing import Any, Callable, cast, List, Optional, Tuple |
13 | 13 |
|
14 | 14 | import torch |
15 | 15 | from executorch.exir.capture._config import CaptureConfig |
16 | 16 | from executorch.exir.error import ExportError, ExportErrorType, InternalError |
17 | | -from executorch.exir.program import ExirExportedProgram, MultiMethodExirExportedProgram |
| 17 | +from executorch.exir.program import ExirExportedProgram |
18 | 18 | from executorch.exir.program._program import _transform, HackedUpExportedProgramDONOTUSE |
19 | 19 | from executorch.exir.tracer import ( |
20 | 20 | _default_decomposition_table, |
@@ -360,137 +360,6 @@ def convert_to_fake(x): |
360 | 360 | return ExirExportedProgram(ep, False) |
361 | 361 |
|
362 | 362 |
|
363 | | -@compatibility(is_backward_compatible=False) |
364 | | -def capture_multiple( |
365 | | - m: Union[torch.nn.Module, Callable[..., Any]], |
366 | | - args: Union[Dict[str, Tuple[Value, ...]], Tuple[Value, ...]], |
367 | | - config: Optional[CaptureConfig] = None, |
368 | | - prim_getters: Optional[Set[str]] = None, |
369 | | - dynamic_shapes: Optional[Union[Dict[str, Any], List[Any]]] = None, |
370 | | -): |
371 | | - """ |
372 | | - capture_multiple traces either an nn.Module or just a callable with PyTorch |
373 | | - operations inside and produce a single MultiMethodExirExportedProgram that |
374 | | - can potentially have multiple entry points. When multiple entry points |
375 | | - are traced, each of them is stored separately in the resulting |
376 | | - MultiMethodExirExportedProgram without sharing state. |
377 | | -
|
378 | | - Args: |
379 | | - m: the `nn.Module` or callable to trace. |
380 | | -
|
381 | | - args: Tracing example inputs. |
382 | | -
|
383 | | - When `m` is an nn.Module, `args` can be |
384 | | - 1) A dictionary that maps names of method to their tracing example inputs. |
385 | | - in this case, all specified methods will be captured. |
386 | | - 2) A tuple. In this case, `forward` method of `m` will be captured. It is |
387 | | - equivalent to passing {"forward", tuple-type-args} |
388 | | -
|
389 | | - When `m` is a non-Module callable, `args` must be a Tuple containing |
390 | | - tracing example inputs. |
391 | | -
|
392 | | - config: A CaptureConfig object that specifies how to interpret the |
393 | | - program being captured. |
394 | | -
|
395 | | - prim_getters: A set of primitive getter functions to capture the return values of |
396 | | -
|
397 | | - dynamic_shapes: Input dynamic shapes. |
398 | | -
|
399 | | - When `m` is an nn.Module, `dynamic_shapes` is a dictionary that maps names of method |
400 | | - to their input dynamic shapes. |
401 | | -
|
402 | | - When `m` is a non-Module callable, `dynamic_shapes` is a list of input dynamic shapes. |
403 | | -
|
404 | | - Returns: |
405 | | - A MultiMethodExirExportedProgram. |
406 | | -
|
407 | | - if `m` is an nn.Module, returned program would have multiple |
408 | | - captured methods, each corresponding to one entry in args dictionary. |
409 | | -
|
410 | | - if `m` is a non-Module callable, returned program would have a single |
411 | | - captured method named `forward`. |
412 | | -
|
413 | | - Raises: |
414 | | - AssertionError if given method name do not reference a valid method |
415 | | - on the given nn.Module. |
416 | | - """ |
417 | | - warnings.warn( |
418 | | - "This function is now deprecated, please use `torch.export and exir.to_edge` instead.", |
419 | | - DeprecationWarning, |
420 | | - stacklevel=1, |
421 | | - ) |
422 | | - # Normalize m and args. |
423 | | - compile_specs = [] |
424 | | - prim_getter_cache: Optional[Dict[str, Any]] = None |
425 | | - if isinstance(m, torch.nn.Module): |
426 | | - if dynamic_shapes is not None: |
427 | | - assert isinstance( |
428 | | - dynamic_shapes, dict |
429 | | - ), f"Expected a dict for dynamic_shapes, got {type(dynamic_shapes)}" |
430 | | - |
431 | | - if isinstance(args, tuple): |
432 | | - compile_specs.append( |
433 | | - CompileSpec( |
434 | | - "forward", |
435 | | - m.forward, |
436 | | - args, |
437 | | - ( |
438 | | - dynamic_shapes["forward"] |
439 | | - if dynamic_shapes and "forward" in dynamic_shapes |
440 | | - else None |
441 | | - ), |
442 | | - ) |
443 | | - ) |
444 | | - else: |
445 | | - assert isinstance( |
446 | | - args, dict |
447 | | - ), f"Expected a tuple or Dict[str, tuple], got {type(args)}" |
448 | | - for method_name, method_args in args.items(): |
449 | | - compile_specs.append( |
450 | | - CompileSpec( |
451 | | - method_name, |
452 | | - getattr(m, method_name), |
453 | | - method_args, |
454 | | - ( |
455 | | - dynamic_shapes[method_name] |
456 | | - if dynamic_shapes and method_name in dynamic_shapes |
457 | | - else None |
458 | | - ), |
459 | | - ) |
460 | | - ) |
461 | | - if prim_getters is not None: |
462 | | - prim_getter_cache = {} |
463 | | - for getter in prim_getters: |
464 | | - prim_getter_cache[getter] = getattr(m, getter)() |
465 | | - else: |
466 | | - # Reaching here means `m` is a non-Module callable. |
467 | | - assert isinstance( |
468 | | - m, Callable |
469 | | - ), f"Only nn.Module or callable allowed, got {type(m)}" |
470 | | - assert isinstance( |
471 | | - args, tuple |
472 | | - ), f"When tracing a non-Module callable, `args` must be a tuple of tracing inputs, but got {type(args)}" |
473 | | - assert ( |
474 | | - prim_getters is None |
475 | | - ), "Caller should not specify primitive getter functions when only providing a callable as input" |
476 | | - if dynamic_shapes is not None: |
477 | | - assert isinstance( |
478 | | - dynamic_shapes, list |
479 | | - ), f"Expected a list for constraints, got {type(dynamic_shapes)}" |
480 | | - compile_specs.append(CompileSpec("forward", m, args, dynamic_shapes)) |
481 | | - |
482 | | - method_name_to_prog = {} |
483 | | - for compile_spec in compile_specs: |
484 | | - method_name_to_prog[compile_spec.method_name] = capture( |
485 | | - compile_spec.callable, |
486 | | - compile_spec.args, |
487 | | - config, |
488 | | - compile_spec.dynamic_shapes, |
489 | | - ) |
490 | | - |
491 | | - return MultiMethodExirExportedProgram(method_name_to_prog, prim_getter_cache) |
492 | | - |
493 | | - |
494 | 363 | # This is to bootstrap the missing meta["val"] when 1. ph consists of scalar |
495 | 364 | # 2. meta["val"] is not properly set in dispatch_trace. |
496 | 365 | def _instantiate_missing_placeholder_val_with_real_inputs(gm, args): |
|
0 commit comments