Skip to content

Commit

Permalink
native upcasts (#113)
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener authored Sep 18, 2023
1 parent c2b235d commit 66e0f1c
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 107 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ jobs:
- name: Build wheels
uses: pypa/[email protected]
# (here: set these in pyproject.toml to the extent possible)
# env:
# CIBW_SOME_OPTION: value
env:
# CIBW_SOME_OPTION: value
CMAKE_BUILD_PARALLEL_LEVEL: 3
# ...
# with:
# package-dir: .
Expand Down
88 changes: 87 additions & 1 deletion gen_wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -1414,6 +1414,9 @@ def write_exposer(outf, meth, arg_names, doc_str):
# }}}


wrapped_isl_functions = set()


def write_wrappers(expf, wrapf, methods):
undoc = []

Expand Down Expand Up @@ -1448,7 +1451,7 @@ def write_wrappers(expf, wrapf, methods):
_, e, _ = sys.exc_info()
print(f"SKIP (sig not supported: {e}): {meth}")
else:
# print("WRAPPED:", meth)
wrapped_isl_functions.add(meth.name)
pass

print("SKIP ({} undocumented methods): {}".format(len(undoc), ", ".join(undoc)))
Expand All @@ -1463,6 +1466,56 @@ def write_wrappers(expf, wrapf, methods):
}


upcasts = {}


def add_upcasts(basic_class, special_class, fmap, expf):

def my_ismethod(method):
if method.name.endswith("_si") or method.name.endswith("_ui"):
return False

if method.name not in wrapped_isl_functions:
return False

if method.is_static:
return False

return True

expf.write(f"\n// {{{{{{ Upcasts from {basic_class} to {special_class}\n\n")

for special_method in fmap[special_class]:
if not my_ismethod(special_method):
continue

found = False

for basic_method in fmap[basic_class]:
if basic_method.name == special_method.name:
found = True
break

if found:
if not my_ismethod(basic_method):
continue

else:
if (basic_class in upcasts
and special_method.name in upcasts[basic_class]):
continue

upcasts.setdefault(basic_class, []).append(special_method.name)

doc_str = (f'"\\n\\nUpcast from :class:`{to_py_class(basic_class)}`'
+ f' to :class:`{to_py_class(special_class)}`\\n"')

expf.write(f'wrap_{basic_class}.def("{special_method.name}", '
f"isl::{special_class}_{special_method.name}, {doc_str});\n")

expf.write("\n// }}}\n\n")


def gen_wrapper(include_dirs, include_barvinok=False, isl_version=None):
fdata = FunctionData(["."] + include_dirs)
fdata.read_header("isl/ctx.h")
Expand Down Expand Up @@ -1512,6 +1565,39 @@ def gen_wrapper(include_dirs, include_barvinok=False, isl_version=None):
for cls in classes
for meth in fdata.classes_to_methods.get(cls, [])])

# {{{ add automatic 'self' upcasts

# note: automatic upcasts for method arguments are provided through
# 'implicitly_convertible'.

if part == "part1":
add_upcasts("aff", "pw_aff", fdata.classes_to_methods, expf)
add_upcasts("pw_aff", "union_pw_aff", fdata.classes_to_methods, expf)
add_upcasts("aff", "union_pw_aff", fdata.classes_to_methods, expf)

add_upcasts("space", "local_space", fdata.classes_to_methods, expf)

add_upcasts("multi_aff", "pw_multi_aff", fdata.classes_to_methods, expf)
add_upcasts("pw_multi_aff", "union_pw_multi_aff",
fdata.classes_to_methods, expf)
add_upcasts("multi_aff", "union_pw_multi_aff",
fdata.classes_to_methods, expf)

elif part == "part2":
add_upcasts("basic_set", "set", fdata.classes_to_methods, expf)
add_upcasts("set", "union_set", fdata.classes_to_methods, expf)
add_upcasts("basic_set", "union_set", fdata.classes_to_methods, expf)

add_upcasts("basic_map", "map", fdata.classes_to_methods, expf)
add_upcasts("map", "union_map", fdata.classes_to_methods, expf)
add_upcasts("basic_map", "union_map", fdata.classes_to_methods, expf)

elif part == "part3":
# empty
pass

# }}}

expf.close()
wrapf.close()

Expand Down
106 changes: 2 additions & 104 deletions islpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,7 @@ def set_get_basic_sets(self):
return result

Set.get_basic_sets = set_get_basic_sets
BasicSet.get_basic_sets = set_get_basic_sets

# }}}

Expand Down Expand Up @@ -688,6 +689,7 @@ def pw_get_aggregate_domain(self):
return result

PwAff.get_pieces = pwaff_get_pieces
Aff.get_pieces = pwaff_get_pieces
PwAff.get_aggregate_domain = pw_get_aggregate_domain

PwQPolynomial.get_pieces = pwqpolynomial_get_pieces
Expand Down Expand Up @@ -855,110 +857,6 @@ def val_to_python(self):

# }}}

# {{{ add automatic 'self' upcasts

# note: automatic upcasts for method arguments are provided through
# 'implicitly_convertible' on the C++ side of the wrapper.

def make_new_upcast_wrapper(method, upcast):
# This function provides a scope in which method and upcast
# are not changed from one iteration of the enclosing for
# loop to the next.

def wrapper(basic_instance, *args, **kwargs):
special_instance = upcast(basic_instance)
return method(special_instance, *args, **kwargs)

return wrapper

def make_existing_upcast_wrapper(basic_method, special_method, upcast):
# This function provides a scope in which method and upcast
# are not changed from one iteration of the enclosing for
# loop to the next.

def wrapper(basic_instance, *args, **kwargs):
try:
return basic_method(basic_instance, *args, **kwargs)
except TypeError:
pass

special_instance = upcast(basic_instance)
return special_method(special_instance, *args, **kwargs)

return wrapper

def add_upcasts(basic_class, special_class, upcast_method):
from functools import update_wrapper

def my_ismethod(class_, method_name):
if method_name.startswith("_"):
return False

method = getattr(class_, method_name)

if not callable(method):
return False

# Here we're desperately trying to filter out static methods,
# based on what seems to be a common feature.
if type(method).__name__ == "nb_func":
return False

return True

for method_name in dir(special_class):
special_method = getattr(special_class, method_name)

if not my_ismethod(special_class, method_name):
continue

if hasattr(basic_class, method_name):
# method already exists in basic class
basic_method = getattr(basic_class, method_name)

if not my_ismethod(basic_class, method_name):
continue

wrapper = make_existing_upcast_wrapper(
basic_method, special_method, upcast_method)
setattr(
basic_class, method_name,
update_wrapper(wrapper, basic_method))
else:
# method does not yet exists in basic class

wrapper = make_new_upcast_wrapper(special_method, upcast_method)
setattr(
basic_class, method_name,
update_wrapper(wrapper, special_method))

for args_triple in [
(BasicSet, Set, Set.from_basic_set),
(Set, UnionSet, UnionSet.from_set),
(BasicSet, UnionSet, lambda x: UnionSet.from_set(Set.from_basic_set(x))),

(BasicMap, Map, Map.from_basic_map),
(Map, UnionMap, UnionMap.from_map),
(BasicMap, UnionMap, lambda x: UnionMap.from_map(Map.from_basic_map(x))),

(Aff, PwAff, PwAff.from_aff),
(PwAff, UnionPwAff, UnionPwAff.from_pw_aff),
(Aff, UnionPwAff, UnionPwAff.from_aff),

(MultiAff, PwMultiAff, PwMultiAff.from_multi_aff),
(PwMultiAff, UnionPwMultiAff, UnionPwMultiAff.from_pw_multi_aff),
(MultiAff, UnionPwMultiAff, UnionPwMultiAff.from_multi_aff),

(Space, LocalSpace, LocalSpace.from_space),
]:
add_upcasts(*args_triple)

# }}}

# ORDERING DEPENDENCY: The availability of some of the 'is_equal'
# used by rich comparison below depends on the self upcasts created
# above.

# {{{ rich comparisons

def obj_eq(self, other):
Expand Down

0 comments on commit 66e0f1c

Please sign in to comment.