Skip to content

Commit

Permalink
updated denoise3d to add optional Gaussian filter postprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
tbepler committed Jul 20, 2020
1 parent efb9e4a commit 11f7ecb
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 37 deletions.
58 changes: 21 additions & 37 deletions topaz/commands/denoise3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import topaz.mrc as mrc
import topaz.cuda

from topaz.denoise import UDenoiseNet3D
from topaz.denoise import UDenoiseNet3D, GaussianDenoise3d

name = 'denoise3d'
help = 'denoise 3D volumes with various denoising algorithms'
Expand Down Expand Up @@ -59,6 +59,7 @@ def add_arguments(parser):


## denoising parameters
parser.add_argument('-g', '--gaussian', type=float, default=0, help='standard deviation of Gaussian filter postprocessing, 0 means no postprocessing (default: 0)')
parser.add_argument('-s', '--patch-size', type=int, default=96, help='denoises volumes in patches of this size. not used if <1 (default: 96)')
parser.add_argument('-p', '--patch-padding', type=int, default=48, help='padding around each patch to remove edge artifacts (default: 48)')

Expand Down Expand Up @@ -400,34 +401,8 @@ def train_model(even_path, odd_path, save_prefix, save_interval, device

# initialize the model
print('# initializing model...', file=log)
model = UDenoiseNet3D(base_width=base_kernel_width)

# set the device or devices
d = device
use_cuda = (d != -1) and torch.cuda.is_available()
num_devices = 1
if use_cuda:
device_count = torch.cuda.device_count()
try:
if d >= 0:
assert d < device_count
torch.cuda.set_device(d)
print('# using CUDA device:', d, file=log)
elif d == -2:
print('# using all available CUDA devices:', device_count, file=log)
model = nn.DataParallel(model)
num_devices = device_count
else:
raise ValueError
except (AssertionError, ValueError):
print('ERROR: Invalid device id or format', file=log)
sys.exit(1)
except Exception:
print('ERROR: Something went wrong with setting the compute device', file=log)
sys.exit(2)

if use_cuda:
model.cuda()
model_base = UDenoiseNet3D(base_width=base_kernel_width)
model,use_cuda,num_devices = set_device(model_base, device)

if cost_func == 'L2':
cost_func = nn.MSELoss()
Expand Down Expand Up @@ -520,7 +495,7 @@ def train_model(even_path, odd_path, save_prefix, save_interval, device
print("# ending time: {:02d}/{:02d}/{:04d} {:02d}h:{:02d}m:{:02d}s".format(now.month,now.day,now.year,now.hour,now.minute,now.second), file=log)
print("# total time:", time.strftime("%Hh:%Mm:%Ss", time.gmtime(end_time - start_time)), file=log)

return model, num_devices
return model_base, num_devices


def save_model(model, epoch, save_prefix, digits=3):
Expand All @@ -532,7 +507,7 @@ def save_model(model, epoch, save_prefix, digits=3):
torch.save(model, path)


def load_model(path, device, base_kernel_width=11):
def load_model(path, base_kernel_width=11):
from collections import OrderedDict
log = sys.stderr

Expand Down Expand Up @@ -570,7 +545,10 @@ def load_model(path, device, base_kernel_width=11):
model.load_state_dict(state)
model.eval()

return model


def set_device(model, device, log=sys.stderr):
# set the device or devices
d = device
use_cuda = (d != -1) and torch.cuda.is_available()
Expand Down Expand Up @@ -598,7 +576,7 @@ def load_model(path, device, base_kernel_width=11):
if use_cuda:
model.cuda()

return model, num_devices
return model, use_cuda, num_devices


class PatchDataset:
Expand Down Expand Up @@ -666,7 +644,6 @@ def denoise(model, path, outdir, patch_size=128, padding=128, batch_size=1

mu = tomo.mean()
std = tomo.std()

# denoise in patches
d = next(iter(model.parameters())).device
denoised = np.zeros_like(tomo)
Expand All @@ -690,7 +667,8 @@ def denoise(model, path, outdir, patch_size=128, padding=128, batch_size=1
x = x.unsqueeze(1) # batch x channel

# denoise
x = model(x).squeeze(1).cpu().numpy()
x = model(x)
x = x.squeeze(1).cpu().numpy()

# stitch into denoised volume
for b in range(len(x)):
Expand Down Expand Up @@ -751,8 +729,15 @@ def main(args):

if len(args.volumes) > 0: # tomograms to denoise!
if model is None: # need to load model
model,num_devices = load_model(args.model, args.device
, base_kernel_width=args.base_kernel_width)
model = load_model(args.model, base_kernel_width=args.base_kernel_width)

gaussian_sigma = args.gaussian
if gaussian_sigma > 0:
print('# apply Gaussian filter postprocessing with sigma={}'.format(gaussian_sigma), file=sys.stderr)
model = nn.Sequential(model, GaussianDenoise3d(gaussian_sigma))
model.eval()

model, use_cuda, num_devices = set_device(model, args.device)

#batch_size = args.batch_size
#batch_size *= num_devices
Expand All @@ -761,7 +746,6 @@ def main(args):
patch_size = args.patch_size
padding = args.patch_padding
print('# denoising with patch size={} and padding={}'.format(patch_size, padding), file=sys.stderr)

# denoise the volumes
total = len(args.volumes)
count = 0
Expand Down
40 changes: 40 additions & 0 deletions topaz/denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,14 @@ def gaussian_filter(sigma, s=11):
return f


def gaussian_filter_3d(sigma, s=11):
dim = s//2
xx,yy,zz = np.meshgrid(np.arange(-dim, dim+1), np.arange(-dim, dim+1), np.arange(-dim,dim+1))
d = xx**2 + yy**2 + zz**2
f = np.exp(-0.5*d/sigma**2)
return f


def inverse_filter(w):
F = np.fft.rfft2(np.fft.ifftshift(w))
F = np.fft.fftshift(np.fft.irfft2(1/F, s=w.shape))
Expand Down Expand Up @@ -325,6 +333,21 @@ def forward(self, x):
return self.filter(x)


class GaussianDenoise3d(nn.Module):
def __init__(self, sigma, scale=5):
super(GaussianDenoise3d, self).__init__()
width = 1 + 2*int(np.ceil(sigma*scale))
f = gaussian_filter_3d(sigma, s=width)
f /= f.sum()

self.filter = nn.Conv3d(1, 1, width, padding=width//2)
self.filter.weight.data[:] = torch.from_numpy(f).float()
self.filter.bias.data.zero_()

def forward(self, x):
return self.filter(x)


class InvGaussianFilter(nn.Module):
def __init__(self, sigma, scale=5):
super(InvGaussianFilter, self).__init__()
Expand Down Expand Up @@ -1365,6 +1388,23 @@ def gaussian(x, sigma=1, scale=5, use_cuda=False):
y = f(x).squeeze().cpu().numpy()
return y


def gaussian3d(x, sigma=1, scale=5, use_cuda=False):
"""
Apply Gaussian filter with sigma to volume. Truncates the kernel at scale times sigma pixels
"""

f = GaussianDenoise3d(sigma, scale=scale)
if use_cuda:
f.cuda()

with torch.no_grad():
x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0)
if use_cuda:
x = x.cuda()
y = f(x).squeeze().cpu().numpy()
return y




0 comments on commit 11f7ecb

Please sign in to comment.