Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test removing merge mask stuff #450

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 56 additions & 54 deletions tinygrad/shape/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,60 +163,62 @@ def __add__(self, vm1:View) -> Optional[View]:
if not (b < e): return View.create(vm1.shape, (0,) * len(vm1.shape), 0, ((0,0),) * len(vm1.shape))
return (merged := vm2 + vm1.shrink(vm1.mask)) and merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))

# Project vm1's offset and strides on to vm2.
origin = un1d(vm2.shape, vm1.offset)
terms: List[List[Tuple[int, sint]]] = [[] for _ in origin]
strides: List[sint] = [0] * len(vm1.shape)
for d1, st in enumerate(vm1.strides):
if st == 0: continue
for d2, (o, s1) in enumerate(zip(origin, un1d(vm2.shape, vm1.offset + st))):
if (s1 := s1 - o) == 0: continue
terms[d2].append((d1, s1))
strides[d1] += s1 * vm2.strides[d2]

# Merge dimensions in vm2 if required.
# NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
idxs: List[UOp] = [UOp.variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
merged_size, merged_term = 1, UOp.const(dtypes.int, 0)
extents: List[Tuple[sint, UOp]] = []
for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)):
merged_term += sum([idxs[d1] * (s1 * merged_size) for d1, s1 in term]) + o * merged_size
merged_size *= s
if not resolve(merged_term >= merged_size) and not resolve(merged_term < 0):
extents.append((merged_size, merged_term))
merged_size, merged_term = 1, UOp.const(dtypes.int, 0)
if resolve(merged_term != 0): return None
if (vm2_shape := tuple(s for s,_ in reversed(extents))) != vm2.shape:
return (reshaped_vm2 := vm2.reshape(vm2_shape)) and reshaped_vm2 + vm1

if vm2.mask:
# Try to project vm2's mask on to vm1.
newb, newe, bad = [0] * len(vm1.shape), list(vm1.shape), False
for d2, ((b, e), o, (_, t)) in enumerate(zip(vm2.mask, origin, reversed(extents))):
if not (t.vmin < b or t.vmax >= e): continue
if not isinstance(o, int) or not isinstance(b, int) or not isinstance(e, int):
bad = True
continue
term = terms[d2]
if len(term) != 1:
if not term and newe: newe[0] = 0
else: bad = True
continue
d1, s1 = term[0]
if not isinstance(s1, int) or not isinstance(newe[d1], int):
bad = True
continue
newb[d1] = max(newb[d1], math.ceil((b - o if s1 > 0 else e - o - 1) / s1))
newe[d1] = min(newe[d1], (b - o if s1 < 0 else e - o - 1) // s1 + 1)

# If any of vm1 was masked off, try again with that mask in place.
for b, e, s in zip(newb, newe, vm1.shape):
if b != 0 or e != s:
return vm2 + View.create(vm1.shape, vm1.strides, vm1.offset, tuple(zip(newb, newe)))
# Otherwise if vm2's mask was violated, then cannot merge.
if bad: return None

return View.create(vm1.shape, tuple(strides), sum(o * s for o, s in zip(origin, vm2.strides)) + vm2.offset)
return None

# # Project vm1's offset and strides on to vm2.
# origin = un1d(vm2.shape, vm1.offset)
# terms: List[List[Tuple[int, sint]]] = [[] for _ in origin]
# strides: List[sint] = [0] * len(vm1.shape)
# for d1, st in enumerate(vm1.strides):
# if st == 0: continue
# for d2, (o, s1) in enumerate(zip(origin, un1d(vm2.shape, vm1.offset + st))):
# if (s1 := s1 - o) == 0: continue
# terms[d2].append((d1, s1))
# strides[d1] += s1 * vm2.strides[d2]

# # Merge dimensions in vm2 if required.
# # NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
# idxs: List[UOp] = [UOp.variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
# merged_size, merged_term = 1, UOp.const(dtypes.int, 0)
# extents: List[Tuple[sint, UOp]] = []
# for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)):
# merged_term += sum([idxs[d1] * (s1 * merged_size) for d1, s1 in term]) + o * merged_size
# merged_size *= s
# if not resolve(merged_term >= merged_size) and not resolve(merged_term < 0):
# extents.append((merged_size, merged_term))
# merged_size, merged_term = 1, UOp.const(dtypes.int, 0)
# if resolve(merged_term != 0): return None
# if (vm2_shape := tuple(s for s,_ in reversed(extents))) != vm2.shape:
# return (reshaped_vm2 := vm2.reshape(vm2_shape)) and reshaped_vm2 + vm1

# if vm2.mask:
# # Try to project vm2's mask on to vm1.
# newb, newe, bad = [0] * len(vm1.shape), list(vm1.shape), False
# for d2, ((b, e), o, (_, t)) in enumerate(zip(vm2.mask, origin, reversed(extents))):
# if not (t.vmin < b or t.vmax >= e): continue
# if not isinstance(o, int) or not isinstance(b, int) or not isinstance(e, int):
# bad = True
# continue
# term = terms[d2]
# if len(term) != 1:
# if not term and newe: newe[0] = 0
# else: bad = True
# continue
# d1, s1 = term[0]
# if not isinstance(s1, int) or not isinstance(newe[d1], int):
# bad = True
# continue
# newb[d1] = max(newb[d1], math.ceil((b - o if s1 > 0 else e - o - 1) / s1))
# newe[d1] = min(newe[d1], (b - o if s1 < 0 else e - o - 1) // s1 + 1)

# # If any of vm1 was masked off, try again with that mask in place.
# for b, e, s in zip(newb, newe, vm1.shape):
# if b != 0 or e != s:
# return vm2 + View.create(vm1.shape, vm1.strides, vm1.offset, tuple(zip(newb, newe)))
# # Otherwise if vm2's mask was violated, then cannot merge.
# if bad: return None

# return View.create(vm1.shape, tuple(strides), sum(o * s for o, s in zip(origin, vm2.strides)) + vm2.offset)

@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def invert(self, out_shape:Tuple[sint, ...]) -> Optional[View]:
Expand Down
Loading