Skip to content

Commit d6f9b69

Browse files
authored
add bool (apache#12)
2 parents 7d8ccad + f9af194 commit d6f9b69

File tree

3 files changed

+29
-13
lines changed

3 files changed

+29
-13
lines changed

frontend/guard_tracker.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from .pycode_generator import GraphFnCodegen, GuardFnCodegen
2424
from .fx_graph import FxGraph, get_frame_root, is_leaf_module, NodeArgs
2525
from .bytecode_analysis import livevars_analysis
26-
from .variables.tuple_ import TupleVar
2726
from .variables.base import Variable
2827

2928

@@ -35,6 +34,23 @@ class PartialVar:
3534
inplace_ref: Any = None # None if not inplace
3635

3736

37+
class Bool_():
38+
value: bool
39+
40+
def __init__(self, value: bool) -> None:
41+
self.value = value
42+
43+
#todo: overwirte xor, or, and, others uses super()
44+
def __and__(self, operator: bool) -> "Bool_":
45+
return Bool_(self.value and operator)
46+
47+
def __or__(self, operator: bool) -> "Bool_":
48+
return Bool_(self.value or operator)
49+
50+
def __not__(self) -> "Bool_":
51+
return Bool_(not self.value)
52+
53+
3854
class State:
3955
objects: ObjectTable
4056
start_pc: int
@@ -639,8 +655,14 @@ def process_last_inst(self) -> None:
639655
if self.state.num_new_refs == -1:
640656
self.state.num_new_refs = get_value_stack_size(self.frame)
641657
for i in range(self.state.num_new_refs):
642-
self.state.object_refs.append(
643-
get_value_stack_from_top(self.frame, i))
658+
obj = get_value_stack_from_top(self.frame, i)
659+
if isinstance(obj, bool):
660+
new_bool = Bool_(obj)
661+
var_bool = vs.ScalarVar(obj, True, False)
662+
self.state.objects.update_by_id(var_bool, id(new_bool))
663+
self.state.object_refs.append(new_bool)
664+
else:
665+
self.state.object_refs.append(obj)
644666
self.state.num_new_refs = 0
645667
for i, obj in enumerate(self.state.inplace_update_objs):
646668
assert not isinstance(obj, torch.Tensor)

frontend/object_table.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,7 @@ def __init__(self) -> None:
1717
self.objs_no_id = []
1818

1919
def add(self, var: Variable, value: Any) -> None:
20-
if isinstance(value, bool):
21-
self.objs_no_id.append(var)
22-
elif id(value) in self.objs:
20+
if id(value) in self.objs:
2321
old_var = self.objs[id(value)]
2422
old_var.extract_code_at_start.extend(var.extract_code_at_start)
2523
old_var.need_guard_check |= var.need_guard_check
@@ -47,9 +45,7 @@ def get_all_with_id(self) -> list[Tuple[int, Variable]]:
4745
return list(self.objs.items())
4846

4947
def get(self, value: Any, allow_unexist_const: bool = False) -> Variable:
50-
if isinstance(value, bool):
51-
return ScalarVar(value, True, False)
52-
elif id(value) in self.objs:
48+
if id(value) in self.objs:
5349
return self.objs[id(value)]
5450
elif allow_unexist_const:
5551
if isinstance(value, get_args(CONST_TYPES)) or isinstance(
@@ -74,10 +70,7 @@ def get_or_make_var(self,
7470
need_guard_check: bool,
7571
fx_graph: Optional[FxGraph] = None,
7672
extract_code_at_start: list[StorePos] = []) -> Variable:
77-
if isinstance(value, bool):
78-
return ScalarVar(value, True, need_guard_check,
79-
extract_code_at_start)
80-
elif id(value) in self.objs:
73+
if id(value) in self.objs:
8174
return self.objs[id(value)]
8275
else:
8376
return make_var_from_value(value, need_guard_check,

frontend/variables/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
float: ScalarVar,
2020
int: ScalarVar,
2121
str: ScalarVar,
22+
bool: ScalarVar,
2223
torch.Tensor: TensorVar,
2324
NullObject: NullVar,
2425
type(None): NoneVar,

0 commit comments

Comments
 (0)