Skip to content

Commit

Permalink
Fix: fix device array handling in NoiseModels (#288)
Browse files Browse the repository at this point in the history
related to #285, missed commit

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: federico-carrara <[email protected]>
Co-authored-by: Joran Deschamps <[email protected]>
  • Loading branch information
4 people authored Dec 5, 2024
1 parent b4fa28f commit 9241452
Showing 1 changed file with 29 additions and 20 deletions.
49 changes: 29 additions & 20 deletions src/careamics/models/lvae/noise_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING, Optional

import numpy as np
Expand Down Expand Up @@ -140,7 +141,7 @@ def train_gm_noise_model(
# TODO any training params ? Different channels ?
noise_model = GaussianMixtureNoiseModel(model_config)
# TODO revisit config unpacking
noise_model.train_nm(signal, observation)
noise_model.fit(signal, observation)
return noise_model


Expand Down Expand Up @@ -280,6 +281,7 @@ def __init__(self, config: GaussianMixtureNMConfig):

self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if config.path is None:
self.mode = "train"
# TODO this is (probably) to train a nm. We leave it for later refactoring
weight = config.weight
n_gaussian = config.n_gaussian
Expand All @@ -304,14 +306,12 @@ def __init__(self, config: GaussianMixtureNMConfig):
# TODO refactor to train on CPU!
else:
params = np.load(config.path)
# self.device = kwargs.get('device')
self.mode = "inference" # TODO better name?

self.min_signal = torch.Tensor(params["min_signal"])
self.max_signal = torch.Tensor(params["max_signal"])

self.weight = torch.nn.Parameter(
torch.Tensor(params["trained_weight"]), requires_grad=False
)
self.weight = torch.Tensor(params["trained_weight"])
self.min_sigma = params["min_sigma"].item()
self.n_gaussian = self.weight.shape[0] // 3 # TODO why // 3 ?
self.n_coeff = self.weight.shape[1]
Expand All @@ -321,12 +321,6 @@ def __init__(self, config: GaussianMixtureNMConfig):

print(f"[{self.__class__.__name__}] min_sigma: {self.min_sigma}")

def set_tolerance(self, tol):
"""Sets the tolerance for the likelihood evaluation."""
print("Setting tolerance to: ", tol)
self.tol = torch.Tensor([tol]).to(self.device)
# self.maxval = 0

def polynomialRegressor(self, weightParams, signals):
"""Combines `weightParams` and signal `signals` to regress for the gaussian parameter values.
Expand All @@ -337,6 +331,7 @@ def polynomialRegressor(self, weightParams, signals):
signals : torch.cuda.FloatTensor
Signals
Returns
-------
value : torch.cuda.FloatTensor
Expand All @@ -360,13 +355,13 @@ def normalDens(self, x, m_=0.0, std_=None):
Mean
std_: torch.cuda.FloatTensor
Standard-deviation
Returns
-------
tmp: torch.cuda.FloatTensor
Normal probability density of `x` given `m_` and `std_`
"""

tmp = -((x - m_) ** 2)
tmp = tmp / (2.0 * std_ * std_)
tmp = torch.exp(tmp)
Expand All @@ -383,12 +378,21 @@ def likelihood(self, observations, signals):
Noisy observations
signals : torch.cuda.FloatTensor
Underlying signals
Returns
-------
value :p + self.tol
Likelihood of observations given the signals and the GMM noise model
"""
if self.mode != "train":
signals = signals.cpu()
observations = observations.cpu()
self.weight = self.weight.to(signals.device)
self.min_signal = self.min_signal.to(signals.device)
self.max_signal = self.max_signal.to(signals.device)
self.tol = self.tol.to(signals.device)

gaussianParameters = self.getGaussianParameters(signals)
p = 0
for gaussian in range(self.n_gaussian):
Expand All @@ -409,11 +413,11 @@ def getGaussianParameters(self, signals):
----------
signals : torch.cuda.FloatTensor
Underlying signals
Returns
-------
noiseModel: list of torch.cuda.FloatTensor
Contains a list of `mu`, `sigma` and `alpha` for the `signals`
"""
noiseModel = []
mu = []
Expand All @@ -429,14 +433,10 @@ def getGaussianParameters(self, signals):
sigmaTemp = torch.clamp(sigmaTemp, min=self.min_sigma)
sigma.append(torch.sqrt(sigmaTemp))

# expval = torch.exp(
# torch.clamp(
# self.polynomialRegressor(self.weight[2 * kernels + num, :], signals) + self.tol, MAX_ALPHA_W))
expval = torch.exp(
self.polynomialRegressor(self.weight[2 * kernels + num, :], signals)
+ self.tol
)
# self.maxval = max(self.maxval, expval.max().item())
alpha.append(expval)

sum_alpha = 0
Expand All @@ -452,7 +452,6 @@ def getGaussianParameters(self, signals):
for ker in range(kernels):
sum_means = alpha[ker] * mu[ker] + sum_means

mu_shifted = []
# subtracting the alpha weighted average of the means from the means
# ensures that the GMM has the inclination to have the mean=signals.
# its like a residual conection. I don't understand why we need to learn the mean?
Expand Down Expand Up @@ -504,7 +503,7 @@ def getSignalObservationPairs(self, signal, observation, lowerClip, upperClip):
]
return fastShuffle(sig_obs_pairs, 2)

def train_nm(
def fit(
self,
signal,
observation,
Expand Down Expand Up @@ -589,8 +588,18 @@ def train_nm(
print("===================\n")

def save(self, path: str, name: str):
"""Save the trained parameters on the noise model.
Parameters
----------
path : str
Path to save the trained parameters.
name : str
File name to save the trained parameters.
"""
os.makedirs(path, exist_ok=True)
np.savez(
path + name,
os.path.join(path, name),
trained_weight=self.trained_weight,
min_signal=self.min_signal,
max_signal=self.max_signal,
Expand Down

0 comments on commit 9241452

Please sign in to comment.