Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

native upcasts #113

Merged
merged 12 commits into from
Sep 18, 2023
5 changes: 3 additions & 2 deletions .github/workflows/wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ jobs:
- name: Build wheels
uses: pypa/[email protected]
# (here: set these in pyproject.toml to the extent possible)
# env:
# CIBW_SOME_OPTION: value
env:
# CIBW_SOME_OPTION: value
CMAKE_BUILD_PARALLEL_LEVEL: 3
# ...
# with:
# package-dir: .
Expand Down
88 changes: 87 additions & 1 deletion gen_wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -1414,6 +1414,9 @@ def write_exposer(outf, meth, arg_names, doc_str):
# }}}


wrapped_isl_functions = set()


def write_wrappers(expf, wrapf, methods):
undoc = []

Expand Down Expand Up @@ -1448,7 +1451,7 @@ def write_wrappers(expf, wrapf, methods):
_, e, _ = sys.exc_info()
print(f"SKIP (sig not supported: {e}): {meth}")
else:
# print("WRAPPED:", meth)
wrapped_isl_functions.add(meth.name)
pass

print("SKIP ({} undocumented methods): {}".format(len(undoc), ", ".join(undoc)))
Expand All @@ -1463,6 +1466,56 @@ def write_wrappers(expf, wrapf, methods):
}


upcasts = {}


def add_upcasts(basic_class, special_class, fmap, expf):

def my_ismethod(method):
if method.name.endswith("_si") or method.name.endswith("_ui"):
return False

if method.name not in wrapped_isl_functions:
return False

if method.is_static:
return False

return True

expf.write(f"\n// {{{{{{ Upcasts from {basic_class} to {special_class}\n\n")

for special_method in fmap[special_class]:
if not my_ismethod(special_method):
continue

found = False

for basic_method in fmap[basic_class]:
if basic_method.name == special_method.name:
found = True
break

if found:
if not my_ismethod(basic_method):
continue

else:
if (basic_class in upcasts
and special_method.name in upcasts[basic_class]):
continue

upcasts.setdefault(basic_class, []).append(special_method.name)

doc_str = (f'"\\n\\nUpcast from :class:`{to_py_class(basic_class)}`'
+ f' to :class:`{to_py_class(special_class)}`\\n"')

expf.write(f'wrap_{basic_class}.def("{special_method.name}", '
f"isl::{special_class}_{special_method.name}, {doc_str});\n")

expf.write("\n// }}}\n\n")


def gen_wrapper(include_dirs, include_barvinok=False, isl_version=None):
fdata = FunctionData(["."] + include_dirs)
fdata.read_header("isl/ctx.h")
Expand Down Expand Up @@ -1512,6 +1565,39 @@ def gen_wrapper(include_dirs, include_barvinok=False, isl_version=None):
for cls in classes
for meth in fdata.classes_to_methods.get(cls, [])])

# {{{ add automatic 'self' upcasts

# note: automatic upcasts for method arguments are provided through
# 'implicitly_convertible'.

if part == "part1":
add_upcasts("aff", "pw_aff", fdata.classes_to_methods, expf)
add_upcasts("pw_aff", "union_pw_aff", fdata.classes_to_methods, expf)
add_upcasts("aff", "union_pw_aff", fdata.classes_to_methods, expf)

add_upcasts("space", "local_space", fdata.classes_to_methods, expf)

add_upcasts("multi_aff", "pw_multi_aff", fdata.classes_to_methods, expf)
add_upcasts("pw_multi_aff", "union_pw_multi_aff",
fdata.classes_to_methods, expf)
add_upcasts("multi_aff", "union_pw_multi_aff",
fdata.classes_to_methods, expf)

elif part == "part2":
add_upcasts("basic_set", "set", fdata.classes_to_methods, expf)
add_upcasts("set", "union_set", fdata.classes_to_methods, expf)
add_upcasts("basic_set", "union_set", fdata.classes_to_methods, expf)

add_upcasts("basic_map", "map", fdata.classes_to_methods, expf)
add_upcasts("map", "union_map", fdata.classes_to_methods, expf)
add_upcasts("basic_map", "union_map", fdata.classes_to_methods, expf)

elif part == "part3":
# empty
pass

# }}}

expf.close()
wrapf.close()

Expand Down
106 changes: 2 additions & 104 deletions islpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,7 @@ def set_get_basic_sets(self):
return result

Set.get_basic_sets = set_get_basic_sets
BasicSet.get_basic_sets = set_get_basic_sets

# }}}

Expand Down Expand Up @@ -688,6 +689,7 @@ def pw_get_aggregate_domain(self):
return result

PwAff.get_pieces = pwaff_get_pieces
Aff.get_pieces = pwaff_get_pieces
PwAff.get_aggregate_domain = pw_get_aggregate_domain

PwQPolynomial.get_pieces = pwqpolynomial_get_pieces
Expand Down Expand Up @@ -855,110 +857,6 @@ def val_to_python(self):

# }}}

# {{{ add automatic 'self' upcasts

# note: automatic upcasts for method arguments are provided through
# 'implicitly_convertible' on the C++ side of the wrapper.

def make_new_upcast_wrapper(method, upcast):
# This function provides a scope in which method and upcast
# are not changed from one iteration of the enclosing for
# loop to the next.

def wrapper(basic_instance, *args, **kwargs):
special_instance = upcast(basic_instance)
return method(special_instance, *args, **kwargs)

return wrapper

def make_existing_upcast_wrapper(basic_method, special_method, upcast):
# This function provides a scope in which method and upcast
# are not changed from one iteration of the enclosing for
# loop to the next.

def wrapper(basic_instance, *args, **kwargs):
try:
return basic_method(basic_instance, *args, **kwargs)
except TypeError:
pass

special_instance = upcast(basic_instance)
return special_method(special_instance, *args, **kwargs)

return wrapper

def add_upcasts(basic_class, special_class, upcast_method):
from functools import update_wrapper

def my_ismethod(class_, method_name):
if method_name.startswith("_"):
return False

method = getattr(class_, method_name)

if not callable(method):
return False

# Here we're desperately trying to filter out static methods,
# based on what seems to be a common feature.
if type(method).__name__ == "nb_func":
return False

return True

for method_name in dir(special_class):
special_method = getattr(special_class, method_name)

if not my_ismethod(special_class, method_name):
continue

if hasattr(basic_class, method_name):
# method already exists in basic class
basic_method = getattr(basic_class, method_name)

if not my_ismethod(basic_class, method_name):
continue

wrapper = make_existing_upcast_wrapper(
basic_method, special_method, upcast_method)
setattr(
basic_class, method_name,
update_wrapper(wrapper, basic_method))
else:
# method does not yet exists in basic class

wrapper = make_new_upcast_wrapper(special_method, upcast_method)
setattr(
basic_class, method_name,
update_wrapper(wrapper, special_method))

for args_triple in [
(BasicSet, Set, Set.from_basic_set),
(Set, UnionSet, UnionSet.from_set),
(BasicSet, UnionSet, lambda x: UnionSet.from_set(Set.from_basic_set(x))),

(BasicMap, Map, Map.from_basic_map),
(Map, UnionMap, UnionMap.from_map),
(BasicMap, UnionMap, lambda x: UnionMap.from_map(Map.from_basic_map(x))),

(Aff, PwAff, PwAff.from_aff),
(PwAff, UnionPwAff, UnionPwAff.from_pw_aff),
(Aff, UnionPwAff, UnionPwAff.from_aff),

(MultiAff, PwMultiAff, PwMultiAff.from_multi_aff),
(PwMultiAff, UnionPwMultiAff, UnionPwMultiAff.from_pw_multi_aff),
(MultiAff, UnionPwMultiAff, UnionPwMultiAff.from_multi_aff),

(Space, LocalSpace, LocalSpace.from_space),
]:
add_upcasts(*args_triple)

# }}}

# ORDERING DEPENDENCY: The availability of some of the 'is_equal'
# used by rich comparison below depends on the self upcasts created
# above.

# {{{ rich comparisons

def obj_eq(self, other):
Expand Down