Skip to content

Commit

Permalink
setitem with grad
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyuxyz committed May 17, 2024
1 parent ca1df20 commit 74a2fae
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 5 deletions.
4 changes: 2 additions & 2 deletions tinygrad/shape/shapetracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def __add__(self, st:ShapeTracker) -> ShapeTracker:
for v in st.views: ret = ShapeTracker(ret.views + (v,)).simplify() # one view at a time = better simplification
return ret

def invert(self, out_shape:Tuple[sint, ...]) -> Optional[ShapeTracker]:
ret = tuple(v.invert(s) for v,s in zip(self.views[::-1], [x.shape for x in self.views[::-1][1:]]+[out_shape]))
def invert(self, out_shape:Tuple[sint, ...], unsafe=False) -> Optional[ShapeTracker]:
ret = tuple(v.invert(s, unsafe) for v,s in zip(self.views[::-1], [x.shape for x in self.views[::-1][1:]]+[out_shape]))
return ShapeTracker(cast(Tuple[View, ...], ret)).reshape(out_shape) if all(x is not None for x in ret) else None

@staticmethod
Expand Down
9 changes: 7 additions & 2 deletions tinygrad/shape/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,11 +192,16 @@ def __add__(self, vm1:View) -> Optional[View]:
return View.create(vm1.shape, tuple(strides), sum(o * s for o, s in zip(origin, vm2.strides)) + vm2.offset)

@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def invert(self, out_shape:Tuple[sint, ...]) -> Optional[View]:
def invert(self, out_shape:Tuple[sint, ...], unsafe=False) -> Optional[View]:
ret = View.create(self.shape)
if self.mask: ret = ret.shrink(self.mask)
ret = ret.stride(tuple(-1 if x < 0 else 1 for x in self.strides)).permute(argsort(tuple(-x if x > 0 else x for x in self.strides)))
return ret if prod(ret.shape) == prod(out_shape) else None # don't support shrink, expand, or stride != (-1, 1)
if prod(ret.shape) == prod(out_shape): return ret
if not unsafe: return None # doesn't support shrink, expand, or stride != (-1, 1)
# support shrink
offsets = un1d(self.shape, self.offset)
ret = ret.pad(tuple((offset, s-r-offset) for offset,s,r in zip(offsets,out_shape,ret.shape)))
return ret if prod(ret.shape) == prod(out_shape) else None

@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def minify(self):
Expand Down
15 changes: 14 additions & 1 deletion tinygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,20 @@ def __setitem__(self, indices, v:Union[Tensor, ConstType]) -> None:
assert all(lb.st.contiguous for lb in self.lazydata.lbs), "setitem target needs to be contiguous"
if not isinstance(v, (Tensor, float, int, bool)): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")

if self.requires_grad or v.requires_grad:
# requires_grad can be None though
v = v.cast(self.dtype)
st = self.__getitem__(indices).lazydata.st
st = st.invert(self.shape, unsafe=True)
assert st is not None, "not supported"
mask = Tensor(v.full_like(True, dtype=dtypes.bool).lazydata._view(st))
inverted = Tensor(v.lazydata._view(st))
print(f"{mask.numpy()=}")
print(f"{inverted.numpy()=}")
self.assign(mask.where(inverted, self))
return

if isinstance(indices, (Tensor, list)) or (isinstance(indices, tuple) and any(isinstance(i, (Tensor, list)) for i in indices)):
raise NotImplementedError("Advanced indexing setitem is not currently supported")

Expand Down

0 comments on commit 74a2fae

Please sign in to comment.