Skip to content

Commit

Permalink
remove add in pad when padding with non-zero
Browse files Browse the repository at this point in the history
this is simpler before shapetracker is shared
  • Loading branch information
chenyuxyz committed Nov 16, 2024
1 parent 22da31b commit 6a10483
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion tinygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,7 +1039,7 @@ def pad(self, padding:Union[Sequence[sint], Sequence[Optional[Tuple[sint, sint]]
pX = ((0,0),)*(self.ndim - len(padding)//2) + tuple(zip(padding[-2::-2], padding[::-2])) if flat else padding
if len(pX) != self.ndim: raise ValueError(f"padding length is improper, {padding=} {self.ndim=}")
X, pX = self, cast(Tuple[Tuple[sint, sint]], tuple((0,0) if p is None else p for p in pX))
def _constant(x,px,v): return F.Pad.apply(x, arg=px) if v == 0 else F.Pad.apply(x, arg=px) + F.Pad.apply(Tensor.ones_like(x), arg=px).where(0, v)
def _constant(x,px,v): return F.Pad.apply(x, arg=px) if v == 0 else F.Pad.apply(Tensor.ones_like(x), arg=px).where(F.Pad.apply(x, arg=px), v)
# early return for symbolic with positive pads (no need to max)
if all(resolve(p >= 0) for p in flatten(pX)): return _constant(X, pX, value)
pads, shrinks = tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX), tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, self.shape))
Expand Down

0 comments on commit 6a10483

Please sign in to comment.