diff --git a/skrample/common.py b/skrample/common.py index 2df6746..4f9adaa 100644 --- a/skrample/common.py +++ b/skrample/common.py @@ -113,7 +113,8 @@ def exp[T: Sample](x: T) -> T: def sigmoid[T: Sample](array: T) -> T: - return 1 / (1 + exp(-array)) # type: ignore + arrexp: T = exp(array) + return arrexp / (1 + arrexp) # type: ignore def softmax[T: tuple[Sample, ...]](elems: T) -> T: diff --git a/skrample/pytorch/noise.py b/skrample/pytorch/noise.py index 7da69e3..9e81e47 100644 --- a/skrample/pytorch/noise.py +++ b/skrample/pytorch/noise.py @@ -9,7 +9,7 @@ def schedule_to_ramp(schedule: NDArray[np.float64]) -> NDArray[np.float64]: - return np.concatenate([[0], np.flip(schedule[:, 1])]) + return np.concatenate([schedule[:, 1], [0]]) @dataclass(frozen=True) @@ -212,13 +212,8 @@ def generate(self) -> torch.Tensor: return noise / noise.std() # Scaled back to roughly unit variance -@dataclass(frozen=True) -class BrownianProps(TensorNoiseProps): - reverse: bool = False - - @dataclass -class Brownian(TensorNoiseCommon[BrownianProps]): +class Brownian(TensorNoiseCommon[None]): """Uses torchsde.BrownianInterval to generate noise along a fixed timestep. generate() will raise StopIteration at the end of the ramp.""" @@ -241,23 +236,17 @@ def __post_init__(self) -> None: self._step: int = 0 # Basic sanitization to normalize 0->1 + if self.ramp[0] > self.ramp[-1]: + self.ramp = -self.ramp self.ramp -= self.ramp.min() self.ramp /= self.ramp.max() - if self.ramp[0] > self.ramp[-1]: - self.ramp = np.flip(self.ramp) def generate(self) -> torch.Tensor: if self._step + 1 >= len(self.ramp): raise StopIteration - if self.props.reverse: - # - 2 because you still get the next sequentiall - step = len(self.ramp) - self._step - 2 - else: - step = self._step - - sigma = self.ramp[step] - sigma_next = self.ramp[step + 1] + sigma = self.ramp[self._step] + sigma_next = self.ramp[self._step + 1] self._step += 1 return self._tree(sigma, sigma_next) / abs(sigma_next - sigma) ** 0.5 @@ -267,7 +256,7 @@ def from_inputs( cls, shape: tuple[int, ...], seed: torch.Generator, - props: BrownianProps = BrownianProps(), + props: None = None, dtype: torch.dtype = torch.float32, ramp: NDArray[np.float64] = np.linspace(0, 1, 1000, dtype=np.float64), ) -> Self: diff --git a/tests/miscellaneous.py b/tests/miscellaneous.py index 14ce907..dc6d4ac 100644 --- a/tests/miscellaneous.py +++ b/tests/miscellaneous.py @@ -185,7 +185,10 @@ def test_bashforth() -> None: def test_sigmoid() -> None: - assert abs(torch.sigmoid(torch.tensor(1.5, dtype=torch.float64)).item() - sigmoid(1.5)) < 1e-12 + items = spowf(torch.linspace(-2, 2, 9, dtype=torch.float64), 2) + a = torch.sigmoid(items) + b = sigmoid(items) + assert torch.allclose(a, b, rtol=0, atol=1e-12), (a.tolist(), b.tolist()) def test_softmax() -> None: