From 74a2fae8078f0e3e6e85f4088dfaa58120313461 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Thu, 16 May 2024 09:18:57 -0700
Subject: [PATCH] setitem with grad

---
 tinygrad/shape/shapetracker.py |  4 ++--
 tinygrad/shape/view.py         |  9 +++++++--
 tinygrad/tensor.py             | 15 ++++++++++++++-
 3 files changed, 23 insertions(+), 5 deletions(-)

diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py
index 1569d85d61acd..8ff5150b0ab6b 100644
--- a/tinygrad/shape/shapetracker.py
+++ b/tinygrad/shape/shapetracker.py
@@ -24,8 +24,8 @@ def __add__(self, st:ShapeTracker) -> ShapeTracker:
     for v in st.views: ret = ShapeTracker(ret.views + (v,)).simplify() # one view at a time = better simplification
     return ret

-  def invert(self, out_shape:Tuple[sint, ...]) -> Optional[ShapeTracker]:
-    ret = tuple(v.invert(s) for v,s in zip(self.views[::-1], [x.shape for x in self.views[::-1][1:]]+[out_shape]))
+  def invert(self, out_shape:Tuple[sint, ...], unsafe=False) -> Optional[ShapeTracker]:
+    ret = tuple(v.invert(s, unsafe) for v,s in zip(self.views[::-1], [x.shape for x in self.views[::-1][1:]]+[out_shape]))
     return ShapeTracker(cast(Tuple[View, ...], ret)).reshape(out_shape) if all(x is not None for x in ret) else None

   @staticmethod
diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py
index e2caece155dfd..28f7a93cb5b37 100644
--- a/tinygrad/shape/view.py
+++ b/tinygrad/shape/view.py
@@ -192,11 +192,16 @@ def __add__(self, vm1:View) -> Optional[View]:
     return View.create(vm1.shape, tuple(strides), sum(o * s for o, s in zip(origin, vm2.strides)) + vm2.offset)

   @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
-  def invert(self, out_shape:Tuple[sint, ...]) -> Optional[View]:
+  def invert(self, out_shape:Tuple[sint, ...], unsafe=False) -> Optional[View]:
     ret = View.create(self.shape)
     if self.mask: ret = ret.shrink(self.mask)
     ret = ret.stride(tuple(-1 if x < 0 else 1 for x in self.strides)).permute(argsort(tuple(-x if x > 0 else x for x in self.strides)))
-    return ret if prod(ret.shape) == prod(out_shape) else None  # don't support shrink, expand, or stride != (-1, 1)
+    if prod(ret.shape) == prod(out_shape): return ret
+    if not unsafe: return None  # doesn't support shrink, expand, or stride != (-1, 1)
+    # support shrink
+    offsets = un1d(self.shape, self.offset)
+    ret = ret.pad(tuple((offset, s-r-offset) for offset,s,r in zip(offsets,out_shape,ret.shape)))
+    return ret if prod(ret.shape) == prod(out_shape) else None

   @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
   def minify(self):
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 71589842a51c1..7feceaf118156 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -822,7 +822,20 @@ def __setitem__(self, indices, v:Union[Tensor, ConstType]) -> None:
     assert all(lb.st.contiguous for lb in self.lazydata.lbs), "setitem target needs to be contiguous"
     if not isinstance(v, (Tensor, float, int, bool)): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
     if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
-    if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")
+
+    if self.requires_grad or v.requires_grad:
+      # requires_grad can be None though
+      v = v.cast(self.dtype)
+      st = self.__getitem__(indices).lazydata.st
+      st = st.invert(self.shape, unsafe=True)
+      assert st is not None, "not supported"
+      mask = Tensor(v.full_like(True, dtype=dtypes.bool).lazydata._view(st))
+      inverted = Tensor(v.lazydata._view(st))
+      print(f"{mask.numpy()=}")
+      print(f"{inverted.numpy()=}")
+      self.assign(mask.where(inverted, self))
+      return
+
     if isinstance(indices, (Tensor, list)) or (isinstance(indices, tuple) and any(isinstance(i, (Tensor, list)) for i in indices)):
       raise NotImplementedError("Advanced indexing setitem is not currently supported")
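
For illustration, a hypothetical usage sketch of what "setitem with grad" is working toward. This snippet is not part of the diff, and the exact gradient behavior depends on this work-in-progress branch:

    from tinygrad import Tensor

    t = Tensor.ones(3, 3, requires_grad=True).contiguous()
    v = Tensor.full((3,), 2.0, requires_grad=True)
    t[1] = v            # previously raised: "setitem with requires_grad is not supported"
    t.sum().backward()  # intended to flow gradients through mask.where(inverted, self)

The new path builds a mask and a value tensor in the full output shape by pushing v through the inverted ShapeTracker of the indexed view (invert(..., unsafe=True) pads the shrunk view back out), then assigns mask.where(inverted, self) so the update stays inside the autograd graph.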