Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion skrample/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ def exp[T: Sample](x: T) -> T:


def sigmoid[T: Sample](array: T) -> T:
return 1 / (1 + exp(-array)) # type: ignore
arrexp: T = exp(array)
return arrexp / (1 + arrexp) # type: ignore


def softmax[T: tuple[Sample, ...]](elems: T) -> T:
Expand Down
25 changes: 7 additions & 18 deletions skrample/pytorch/noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


def schedule_to_ramp(schedule: NDArray[np.float64]) -> NDArray[np.float64]:
    """Extract the sigma column (index 1) of a schedule and append a trailing 0.

    The appended zero marks the terminal (noise-free) point of the ramp.
    """
    sigmas = schedule[:, 1]
    return np.append(sigmas, 0.0)


@dataclass(frozen=True)
Expand Down Expand Up @@ -212,13 +212,8 @@ def generate(self) -> torch.Tensor:
return noise / noise.std() # Scaled back to roughly unit variance


@dataclass(frozen=True)
class BrownianProps(TensorNoiseProps):
reverse: bool = False


@dataclass
class Brownian(TensorNoiseCommon[BrownianProps]):
class Brownian(TensorNoiseCommon[None]):
"""Uses torchsde.BrownianInterval to generate noise along a fixed timestep.
generate() will raise StopIteration at the end of the ramp."""

Expand All @@ -241,23 +236,17 @@ def __post_init__(self) -> None:
self._step: int = 0

# Basic sanitization to normalize 0->1
if self.ramp[0] > self.ramp[-1]:
self.ramp = -self.ramp
self.ramp -= self.ramp.min()
self.ramp /= self.ramp.max()
if self.ramp[0] > self.ramp[-1]:
self.ramp = np.flip(self.ramp)

def generate(self) -> torch.Tensor:
    """Return the next noise sample along the fixed ramp.

    Raises:
        StopIteration: once the final ramp interval has been consumed.
    """
    step = self._step
    if step + 1 >= len(self.ramp):
        raise StopIteration

    lo = self.ramp[step]
    hi = self.ramp[step + 1]
    self._step = step + 1

    # Divide by sqrt of the interval width — presumably to rescale the
    # Brownian increment back to roughly unit variance (TODO confirm).
    return self._tree(lo, hi) / abs(hi - lo) ** 0.5
Expand All @@ -267,7 +256,7 @@ def from_inputs(
cls,
shape: tuple[int, ...],
seed: torch.Generator,
props: BrownianProps = BrownianProps(),
props: None = None,
dtype: torch.dtype = torch.float32,
ramp: NDArray[np.float64] = np.linspace(0, 1, 1000, dtype=np.float64),
) -> Self:
Expand Down
5 changes: 4 additions & 1 deletion tests/miscellaneous.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,10 @@ def test_bashforth() -> None:


def test_sigmoid() -> None:
    """Compare skrample's generic sigmoid against torch.sigmoid elementwise."""
    xs = spowf(torch.linspace(-2, 2, 9, dtype=torch.float64), 2)
    reference = torch.sigmoid(xs)
    computed = sigmoid(xs)
    assert torch.allclose(reference, computed, rtol=0, atol=1e-12), (reference.tolist(), computed.tolist())


def test_softmax() -> None:
Expand Down
Loading