Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 13 additions & 17 deletions aiter/jit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def update_config_files(self, file_path: str, merge_name: str):

for i, path in enumerate(path_list):
if not os.path.exists(path):
logger.info(f"path {i+1}: {path} (not exist)")
logger.info(f"path {i + 1}: {path} (not exist)")
continue

df = pd.read_csv(path)
Expand Down Expand Up @@ -916,7 +916,7 @@ def exec_blob(blob_gen_cmd, op_dir, src_dir, sources):

def FinalFunc():
logger.info(
f"\033[32mfinish build [{md_name}], cost {time.perf_counter()-startTS:.1f}s \033[0m"
f"\033[32mfinish build [{md_name}], cost {time.perf_counter() - startTS:.1f}s \033[0m"
)

mp_lock(lockPath=lock_path, MainFunc=MainFunc, FinalFunc=FinalFunc)
Expand Down Expand Up @@ -1130,6 +1130,8 @@ def _ctypes_call(func, fc_name, md_name):

_cache = {}
_arg_checked = False
_sig = inspect.signature(func)
_hints = typing.get_type_hints(func)

def _ensure_loaded():
if _cache:
Expand Down Expand Up @@ -1168,8 +1170,7 @@ def _opt_sym(name, argtypes=(), restype=None):
err_getter = _opt_sym("aiter_get_last_error", restype=ctypes.c_char_p)
err_clear = _opt_sym("aiter_clear_last_error")

hints = typing.get_type_hints(func)
ret_hint = hints.get("return")
ret_hint = _hints.get("return")
ctypes_data_return = ctypes_status_mode and ret_hint is int

if ctypes_status_mode:
Expand All @@ -1183,8 +1184,8 @@ def _opt_sym(name, argtypes=(), restype=None):

argtypes = []
has_tensor = False
for pname in inspect.signature(func).parameters:
hint = hints.get(pname)
for pname in _sig.parameters:
hint = _hints.get(pname)
origin = typing.get_origin(hint)
type_args = typing.get_args(hint)
if hint is torch.Tensor:
Expand Down Expand Up @@ -1252,20 +1253,17 @@ def _check_args_before_convert(bound_args, hints):
elif hint is str:
if not isinstance(value, str):
raise TypeError(
f"{fc_name}: '{pname}' expects str, "
f"got {type(value).__name__}"
f"{fc_name}: '{pname}' expects str, got {type(value).__name__}"
)
elif hint is bool:
if not isinstance(value, (bool, int)):
raise TypeError(
f"{fc_name}: '{pname}' expects bool, "
f"got {type(value).__name__}"
f"{fc_name}: '{pname}' expects bool, got {type(value).__name__}"
)
elif hint is int:
if not isinstance(value, int):
raise TypeError(
f"{fc_name}: '{pname}' expects int, "
f"got {type(value).__name__}"
f"{fc_name}: '{pname}' expects int, got {type(value).__name__}"
)
elif hint is float:
if not isinstance(value, (float, int)):
Expand All @@ -1287,21 +1285,19 @@ def caller(*args, **kwargs):
from ..test_common import log_args

log_args(func, *args, **kwargs)
sig = inspect.signature(func)
bound = sig.bind(*args, **kwargs)
bound = _sig.bind(*args, **kwargs)
bound.apply_defaults()
hints = typing.get_type_hints(func)

if not _arg_checked:
_check_args_before_convert(bound.arguments, hints)
_check_args_before_convert(bound.arguments, _hints)
_arg_checked = True

c_args = []
aiter_refs = []
tensor_device = None

for pname, value in bound.arguments.items():
hint = hints.get(pname)
hint = _hints.get(pname)
origin = typing.get_origin(hint)
type_args = typing.get_args(hint)

Expand Down
Loading