Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use minimum spanning tree to simplify junctions #135

Merged
merged 2 commits into from
Sep 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 94 additions & 5 deletions skan/csr.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from enum import Enum
import warnings

import numpy as np
import pandas as pd
from scipy import sparse, ndimage as ndi
Expand All @@ -10,6 +13,20 @@
from .summary_utils import find_main_branches


class JunctionModes(Enum):
    """Strategies for simplifying junctions in a skeleton graph.

    NONE: junctions are left untouched. In skan < 0.9, this is equivalent
        to ``unique_junctions=False``.
    Centroid: junctions are consolidated into the centroid of the
        contributing nodes. In skan < 0.9, this is equivalent to
        ``unique_junctions=True``.
    MST: junctions are replaced with their minimum spanning tree.
    """

    NONE = 'none'
    Centroid = 'centroid'
    MST = 'mst'


## NBGraph and Numba-based implementation

csr_spec = [
Expand Down Expand Up @@ -321,10 +338,12 @@ class Skeleton:
"""
def __init__(self, skeleton_image, *, spacing=1, source_image=None,
_buffer_size_offset=None, keep_images=True,
unique_junctions=True):
junction_mode=JunctionModes.MST,
unique_junctions=None):
graph, coords = skeleton_to_csgraph(
skeleton_image,
spacing=spacing,
junction_mode=junction_mode,
unique_junctions=unique_junctions,
)
if np.issubdtype(skeleton_image.dtype, np.float_):
Expand Down Expand Up @@ -480,7 +499,10 @@ def prune_paths(self, indices) -> 'Skeleton':
# warning: slow
image_cp = np.copy(self.skeleton_image)
for i in indices:
coords_to_wipe = self.path_coordinates(i)
pixel_ids_to_wipe = self.path(i)
junctions = self.degrees[pixel_ids_to_wipe] > 2
pixel_ids_to_wipe = pixel_ids_to_wipe[~junctions]
coords_to_wipe = self.coordinates[pixel_ids_to_wipe]
coords_idxs = tuple(np.round(coords_to_wipe).astype(int).T)
image_cp[coords_idxs] = 0
# optional cleanup:
Expand Down Expand Up @@ -568,6 +590,49 @@ def _path_distance(graph, path):
return d


def _mst_junctions(csmat):
"""Replace clustered pixels with degree > 2 by their minimum spanning tree.

This function performs the operation in place.

Parameters
----------
csmat : NBGraph
The input graph.
pixel_indices : array of int
The raveled index in the image of every pixel represented in csmat.
spacing : float, or array-like of float, shape `len(shape)`, optional
The spacing between pixels in the source image along each dimension.

Returns
-------
final_graph : NBGraph
The output csmat.
"""
# make copy
# mask out all degree < 3 entries
# find MST
# replace edges not in MST with zeros
# use .eliminate_zeros() to get a new matrix
csc_graph = csmat.tocsc()
degrees = np.asarray(csmat.astype(bool).astype(int).sum(axis=0))
non_junction = np.flatnonzero(degrees < 3)
non_junction_column_start = csc_graph.indptr[non_junction]
non_junction_column_end = csc_graph.indptr[non_junction+1]
for start, end in zip(non_junction_column_start, non_junction_column_end):
csc_graph.data[start:end] = 0
csr_graph = csc_graph.tocsr()
non_junction_row_start = csr_graph.indptr[non_junction]
non_junction_row_end = csr_graph.indptr[non_junction+1]
for start, end in zip(non_junction_row_start, non_junction_row_end):
csr_graph.data[start:end] = 0
csr_graph.eliminate_zeros()
mst = csgraph.minimum_spanning_tree(csr_graph)
non_tree_edges = csr_graph - (mst + mst.T)
final_graph = csmat - non_tree_edges
return final_graph


def _uniquify_junctions(csmat, pixel_indices, junction_labels,
junction_centroids, *, spacing=1):
"""Replace clustered pixels with degree > 2 by a single "floating" pixel.
Expand Down Expand Up @@ -600,7 +665,8 @@ def _uniquify_junctions(csmat, pixel_indices, junction_labels,


def skeleton_to_csgraph(skel, *, spacing=1, value_is_height=False,
unique_junctions=True):
junction_mode=JunctionModes.MST,
unique_junctions=None):
"""Convert a skeleton image of thin lines to a graph of neighbor pixels.

Parameters
Expand All @@ -622,7 +688,14 @@ def skeleton_to_csgraph(skel, *, spacing=1, value_is_height=False,
considered to be a height measurement, and this height will be
incorporated into skeleton branch lengths. Used for analysis of
atomic force microscopy (AFM) images.
junction_mode : JunctionModes.{NONE,MST,CENTROID}
If NONE, junction pixels are not collapsed.
If MST, junction pixels are replaced by their minimum spanning tree,
resulting in a single junction pixel.
If CENTROID, junction pixels are collapsed to their centroid.
unique_junctions : bool, optional
**DEPRECATED**: Use junction_mode=JunctionModes.Centroid to get
behavior equivalent to ``unique_junctions=True``.
If True, adjacent junction nodes get collapsed into a single
conceptual node, with position at the centroid of all the connected
initial nodes.
Expand Down Expand Up @@ -662,7 +735,21 @@ def skeleton_to_csgraph(skel, *, spacing=1, value_is_height=False,
degree_image = ndi.convolve(skel.astype(int), degree_kernel,
mode='constant') * skel

if unique_junctions:
if unique_junctions is not None:
warnings.warn('unique junctions in deprecated, see junction_modes')
junction_mode = (
JunctionModes.Centroid if unique_junctions else JunctionModes.NONE
)

if not isinstance(junction_mode, JunctionModes):
try:
junction_mode = JunctionModes(junction_mode.lower())
except ValueError:
raise ValueError(f"{junction_mode} is an invalid junction_mode. Should be 'none', 'centroid', or 'mst'")
except AttributeError:
raise TypeError('junction_mode should be a string or a JunctionModes')

if junction_mode == JunctionModes.Centroid:
# group all connected junction nodes into "meganodes".
junctions = degree_image > 2
junction_ids = skelint[junctions]
Expand All @@ -679,9 +766,11 @@ def skeleton_to_csgraph(skel, *, spacing=1, value_is_height=False,
spacing=spacing)
graph = _pixel_graph(skelint, steps, distances, num_edges, height)

if unique_junctions:
if junction_mode == JunctionModes.Centroid:
_uniquify_junctions(graph, pixel_indices,
labeled_junctions, centroids, spacing=spacing)
elif junction_mode == JunctionModes.MST:
graph = _mst_junctions(graph)
return graph, pixel_indices


Expand Down
51 changes: 43 additions & 8 deletions skan/test/test_csr.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os, sys
import numpy as np
from numpy.testing import assert_equal, assert_almost_equal
import pytest
from skan import csr
from skan.csr import JunctionModes

rundir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(rundir)
Expand All @@ -11,7 +13,7 @@


def test_tiny_cycle():
g, idxs = csr.skeleton_to_csgraph(tinycycle)
g, idxs = csr.skeleton_to_csgraph(tinycycle, junction_mode='centroid')
expected_indptr = [0, 0, 2, 4, 6, 8]
expected_indices = [2, 3, 1, 4, 1, 4, 2, 3]
expected_data = np.sqrt(2)
Expand All @@ -25,7 +27,7 @@ def test_tiny_cycle():


def test_skeleton1_stats():
g, idxs = csr.skeleton_to_csgraph(skeleton1)
g, idxs = csr.skeleton_to_csgraph(skeleton1, junction_mode='centroid')
stats = csr.branch_statistics(g)
assert_equal(stats.shape, (4, 4))
keys = map(tuple, stats[:, :2].astype(int))
Expand Down Expand Up @@ -60,20 +62,20 @@ def test_summarise_spacing():


def test_line():
g, idxs = csr.skeleton_to_csgraph(tinyline)
g, idxs = csr.skeleton_to_csgraph(tinyline, junction_mode='centroid')
assert_equal(np.ravel(idxs), [0, 1, 2, 3])
assert_equal(g.shape, (4, 4))
assert_equal(csr.branch_statistics(g), [[1, 3, 2, 0]])


def test_cycle_stats():
stats = csr.branch_statistics(csr.skeleton_to_csgraph(tinycycle)[0],
stats = csr.branch_statistics(csr.skeleton_to_csgraph(tinycycle, junction_mode='centroid')[0],
buffer_size_offset=1)
assert_almost_equal(stats, [[1, 1, 4*np.sqrt(2), 3]])


def test_3d_spacing():
g, idxs = csr.skeleton_to_csgraph(skeleton3d, spacing=[5, 1, 1])
g, idxs = csr.skeleton_to_csgraph(skeleton3d, spacing=[5, 1, 1], junction_mode='centroid')
stats = csr.branch_statistics(g)
assert_equal(stats.shape, (5, 4))
assert_almost_equal(stats[0], [1, 5, 10.467, 1], decimal=3)
Expand All @@ -82,7 +84,7 @@ def test_3d_spacing():

def test_topograph():
g, idxs = csr.skeleton_to_csgraph(topograph1d,
value_is_height=True)
value_is_height=True, junction_mode='centroid')
stats = csr.branch_statistics(g)
assert stats.shape == (1, 4)
assert_almost_equal(stats[0], [1, 3, 2 * np.sqrt(2), 0])
Expand All @@ -98,9 +100,9 @@ def test_topograph_summary():

def test_junction_multiplicity():
"""Test correct distances when a junction has more than one pixel."""
g, idxs = csr.skeleton_to_csgraph(skeleton0)
g, idxs = csr.skeleton_to_csgraph(skeleton0, junction_mode='centroid')
assert_almost_equal(g[3, 5], 2.0155644)
g, idxs = csr.skeleton_to_csgraph(skeleton0, unique_junctions=False)
g, idxs = csr.skeleton_to_csgraph(skeleton0, junction_mode='none')
assert_almost_equal(g[2, 3], 1.0)
assert_almost_equal(g[3, 6], np.sqrt(2))

Expand All @@ -124,3 +126,36 @@ def test_pixel_values():
def test_tip_junction_edges():
stats1 = csr.summarise(skeleton4)
assert stats1.shape[0] == 3 # ensure all three branches are counted


@pytest.mark.parametrize(
    'mst_mode,none_mode',
    [
        ('mst', 'none'),
        ('MST', 'NONE'),
        ('MsT', 'NoNe'),
        (JunctionModes.MST, JunctionModes.NONE),
    ],
)
def test_mst_junctions(mst_mode, none_mode):
    """Check MST junction cleanup, called directly and via junction_mode.

    The junction mode is given both as strings in various letter cases and
    as JunctionModes members.
    """
    raw_graph, _ = csr.skeleton_to_csgraph(skeleton0, junction_mode=none_mode)
    direct = csr._mst_junctions(raw_graph)
    via_mode, _ = csr.skeleton_to_csgraph(skeleton0, junction_mode=mst_mode)

    # Expected result: the uncleaned graph with every edge longer than 1.1
    # removed.
    expected = raw_graph.todense()
    expected[expected > 1.1] = 0

    for result in (direct, via_mode):
        np.testing.assert_equal(expected, result.todense())


def test_junction_mode_type_error():
    """Test that giving the wrong type of junction_mode raises a TypeError."""
    with pytest.raises(TypeError):
        csr.skeleton_to_csgraph(skeleton0, junction_mode=4)


def test_junction_mode_value_error():
    """Test that giving an invalid junction_mode raises a ValueError."""
    with pytest.raises(ValueError):
        csr.skeleton_to_csgraph(skeleton0, junction_mode='not a mode')
11 changes: 1 addition & 10 deletions skan/test/test_prune.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,9 @@
from skan import Skeleton


@pytest.mark.parametrize('branch_num', [0])
def test_pruning(branch_num):
skeleton = Skeleton(skeleton0)
pruned = skeleton.prune_paths([branch_num])
print(pruned.skeleton_image.astype(int))
assert pruned.n_paths == 1


@pytest.mark.xfail
@pytest.mark.parametrize('branch_num', [0, 1, 2])
def test_pruning_comprehensive(branch_num):
skeleton = Skeleton(skeleton0)
pruned = skeleton.prune_paths([branch_num])
print(pruned.skeleton_image.astype(int))
assert pruned.n_paths == 1
assert pruned.n_paths == 1
26 changes: 13 additions & 13 deletions skan/test/test_skeleton_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@


def test_tiny_cycle():
skeleton = Skeleton(tinycycle)
skeleton = Skeleton(tinycycle, junction_mode='centroid')
assert skeleton.paths.shape == (1, 5)


def test_skeleton1_topo():
skeleton = Skeleton(skeleton1)
skeleton = Skeleton(skeleton1, junction_mode='centroid')
assert skeleton.paths.shape == (4, 21)
paths_list = skeleton.paths_list()
reference_paths = [
Expand All @@ -35,19 +35,19 @@ def test_skeleton1_topo():
def test_skeleton1_float():
image = np.zeros(skeleton1.shape, dtype=float)
image[skeleton1] = 1 + np.random.random(np.sum(skeleton1))
skeleton = Skeleton(image)
skeleton = Skeleton(image, junction_mode='centroid')
path, data = skeleton.path_with_data(0)
assert 1.0 < np.mean(data) < 2.0


def test_skeleton_coordinates():
skeleton = Skeleton(skeleton1)
skeleton = Skeleton(skeleton1, junction_mode='centroid')
last_path_coordinates = skeleton.path_coordinates(3)
assert_allclose(last_path_coordinates, [[3, 3], [4, 4], [4, 5], [4, 6]])


def test_path_length_caching():
skeleton = Skeleton(skeleton3d)
skeleton = Skeleton(skeleton3d, junction_mode='centroid')
t0 = process_time()
distances = skeleton.path_lengths()
t1 = process_time()
Expand All @@ -59,7 +59,7 @@ def test_path_length_caching():


def test_tip_junction_edges():
skeleton = Skeleton(skeleton4)
skeleton = Skeleton(skeleton4, junction_mode='centroid')
reference_paths = [[1, 2], [2, 4, 5], [2, 7]]
paths_list = skeleton.paths_list()
for path in reference_paths:
Expand All @@ -69,14 +69,14 @@ def test_tip_junction_edges():
def test_path_stdev():
image = np.zeros(skeleton1.shape, dtype=float)
image[skeleton1] = 1 + np.random.random(np.sum(skeleton1))
skeleton = Skeleton(image)
skeleton = Skeleton(image, junction_mode='centroid')
# longest_path should be 0, but could change.
longest_path = np.argmax(skeleton.path_lengths())
dev = skeleton.path_stdev()[longest_path]
assert 0.09 < dev < 0.44 # chance is < 1/10K that this will fail

# second check: first principles.
skeleton2 = Skeleton(image**2)
skeleton2 = Skeleton(image**2, junction_mode='centroid')
# (Var = StDev**2 = E(X**2) - (E(X))**2)
assert_allclose(skeleton.path_stdev()**2,
skeleton2.path_means() - skeleton.path_means()**2)
Expand All @@ -90,13 +90,13 @@ def test_junction_first():
before any of its adjacent branches. This turns out to be tricky to achieve
but not impossible in 2D.
"""
assert [1, 1] not in Skeleton(junction_first).paths_list()
assert [1, 1] not in Skeleton(junction_first, junction_mode='centroid').paths_list()


def test_skeleton_summarize():
image = np.zeros(skeleton2.shape, dtype=float)
image[skeleton2] = 1 + np.random.random(np.sum(skeleton2))
skeleton = Skeleton(image)
skeleton = Skeleton(image, junction_mode='centroid')
summary = summarize(skeleton)
assert set(summary['skeleton-id']) == {1, 2}
assert (np.all(summary['mean-pixel-value'] < 2)
Expand All @@ -115,7 +115,7 @@ def test_skeleton_label_image_strict():
This is expected to fail due to the current junction representation.
See: https://github.com/jni/skan/issues/133
"""
skeleton = Skeleton(skeleton4, unique_junctions=False)
skeleton = Skeleton(skeleton4, junction_mode='none')
label_image = np.asarray(skeleton)
expected = np.array([
[1, 0, 0, 0, 0],
Expand All @@ -133,10 +133,10 @@ def test_skeleton_label_image():
"""Simple test that the skeleton label image covers the same
pixels as the expected label image.
"""
skeleton = Skeleton(skeleton4, unique_junctions=False)
skeleton = Skeleton(skeleton4, junction_mode='none')
label_image = np.asarray(skeleton)
expected = np.array([
[1, 0, 0, 0, 0],
[1, 0, 0, 0, 0],
[0, 1, 2, 2, 2],
[0, 3, 0, 0, 0],
[0, 3, 0, 0, 0],
Expand Down