Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chan_vese: pass all constants to _fused_variance_kernel2 as device scalars #764

Merged
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
apply review suggestion
Co-authored-by: jakirkham <jakirkham@gmail.com>
grlee77 and jakirkham committed Aug 21, 2024
commit 2f6a5aa9cb5486a21fa3276bcd20fba3b835f0b2
10 changes: 5 additions & 5 deletions python/cucim/src/cucim/skimage/segmentation/_chan_vese.py
Original file line number Diff line number Diff line change
@@ -7,8 +7,6 @@
from .._shared.utils import _supported_float_type
from .._vendored import pad

_one = cp.asarray(1.0, dtype=cp.float32)


@cp.fuse()
def _fused_variance_kernel1(eta, x_start, x_mid, x_end, y_start, y_mid, y_end):
@@ -55,7 +53,7 @@ def _fused_hphi_hinv(phi):

@cp.fuse()
def _fused_variance_kernel2(
image, c1, c2, lam1, lam2, phi, K, dt, mu, delta_phi, Csum, one
image, c1, c2, lam1, lam2, phi, K, dt, mu, delta_phi, Csum
):
difference_term = image - c1
difference_term *= difference_term
@@ -67,7 +65,9 @@ def _fused_variance_kernel2(
difference_term += term2

new_phi = phi + (dt * delta_phi) * (mu * K + difference_term)
out = new_phi / (one + mu * dt * delta_phi * Csum)
out_denom = mu * dt * delta_phi * Csum
out_denom += out_denom.dtype.type(1)
out = new_phi / out_denom
return out


@@ -107,7 +107,7 @@ def _cv_calculate_variation(image, phi, mu, lambda1, lambda2, dt):
c1, c2 = _cv_calculate_averages(image, Hphi, Hinv)
delta_phi = _cv_delta(phi)
out = _fused_variance_kernel2(
image, c1, c2, lambda1, lambda2, phi, K, dt, mu, delta_phi, Csum, _one
image, c1, c2, lambda1, lambda2, phi, K, dt, mu, delta_phi, Csum
)
return out