Skip to content

Support for anisotropic data. #8

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 80 additions & 36 deletions dash_3d_viewer/slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,25 @@
from dash.dependencies import Input, Output, State, ALL
from dash_core_components import Graph, Slider, Store

from .utils import img_array_to_uri, get_thumbnail_size_from_shape
from .utils import img_array_to_uri, get_thumbnail_size_from_shape, shape3d_to_size2d


class DashVolumeSlicer:
"""A slicer to show 3D image data in Dash.

Parameters:
app (dash.Dash): the Dash application instance.
volume (ndarray): the 3D numpy array to slice through.
volume (ndarray): the 3D numpy array to slice through. The dimensions
are assumed to be in zyx order. If this is not the case, you can
use ``np.swapaxes`` to make it so.
spacing (tuple of floats): The distance between voxels for each dimension (zyx).
The spacing and origin are applied to make the slice drawn in
"scene space" rather than "voxel space".
origin (tuple of floats): The offset for each dimension (zyx).
axis (int): the dimension to slice in. Default 0.
reverse_y (bool): Whether to reverse the y-axis, so that the origin of
the slice is in the top-left, rather than bottom-left. Default True.
(This sets the figure's yaxes ``autorange`` to either "reversed" or True.)
Comment on lines +23 to +25
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also added this.

scene_id (str): the scene that this slicer is part of. Slicers
that have the same scene-id show each-other's positions with
line indicators. By default this is a hash of ``id(volume)``.
Expand All @@ -38,14 +47,29 @@ class DashVolumeSlicer:

_global_slicer_counter = 0

def __init__(self, app, volume, axis=0, scene_id=None):
def __init__(
self,
app,
volume,
*,
spacing=None,
origin=None,
axis=0,
reverse_y=True,
scene_id=None
):
# todo: also implement xyz dim order?
if not isinstance(app, Dash):
raise TypeError("Expect first arg to be a Dash app.")
self._app = app
# Check and store volume
if not (isinstance(volume, np.ndarray) and volume.ndim == 3):
raise TypeError("Expected volume to be a 3D numpy array")
self._volume = volume
spacing = (1, 1, 1) if spacing is None else spacing
spacing = float(spacing[0]), float(spacing[1]), float(spacing[2])
origin = (0, 0, 0) if origin is None else origin
origin = float(origin[0]), float(origin[1]), float(origin[2])
# Check and store axis
if not (isinstance(axis, int) and 0 <= axis <= 2):
raise ValueError("The given axis must be 0, 1, or 2.")
Expand All @@ -60,20 +84,26 @@ def __init__(self, app, volume, axis=0, scene_id=None):
DashVolumeSlicer._global_slicer_counter += 1
self.context_id = "slicer_" + str(DashVolumeSlicer._global_slicer_counter)

# Get the slice size (width, height), and max index
arr_shape = list(volume.shape)
arr_shape.pop(self._axis)
self._slice_size = tuple(reversed(arr_shape))
self._max_index = self._volume.shape[self._axis] - 1
# Prepare slice info
info = {
"shape": tuple(volume.shape),
"axis": self._axis,
"size": shape3d_to_size2d(volume.shape, axis),
"origin": shape3d_to_size2d(origin, axis),
"spacing": shape3d_to_size2d(spacing, axis),
}

# Prep low-res slices
thumbnail_size = get_thumbnail_size_from_shape(arr_shape, 32)
thumbnail_size = get_thumbnail_size_from_shape(
(info["size"][1], info["size"][0]), 32
)
thumbnails = [
img_array_to_uri(self._slice(i), thumbnail_size)
for i in range(self._max_index + 1)
for i in range(info["size"][2])
]
info["lowres_size"] = thumbnail_size

# Create a placeholder trace
# Create traces
# todo: can add "%{z[0]}", but that would be the scaled value ...
image_trace = Image(
source="", dx=1, dy=1, hovertemplate="(%{x}, %{y})<extra></extra>"
Expand All @@ -97,6 +127,7 @@ def __init__(self, app, volume, axis=0, scene_id=None):
scaleanchor="x",
showticklabels=False,
zeroline=False,
autorange="reversed" if reverse_y else True,
)
# Wrap the figure in a graph
# todo: or should the user provide this?
Expand All @@ -106,22 +137,20 @@ def __init__(self, app, volume, axis=0, scene_id=None):
config={"scrollZoom": True},
)
# Create a slider object that the user can put in the layout (or not)
# todo: use tooltip to show current value?
self.slider = Slider(
id=self._subid("slider"),
min=0,
max=self._max_index,
max=info["size"][2] - 1,
step=1,
value=self._max_index // 2,
value=info["size"][2] // 2,
tooltip={"always_visible": False, "placement": "left"},
updatemode="drag",
)
# Create the stores that we need (these must be present in the layout)
self.stores = [
Store(
id=self._subid("_slice-size"), data=self._slice_size + thumbnail_size
),
Store(id=self._subid("info"), data=info),
Store(id=self._subid("index"), data=volume.shape[self._axis] // 2),
Store(id=self._subid("position"), data=0),
Store(id=self._subid("_requested-slice-index"), data=0),
Store(id=self._subid("_slice-data"), data=""),
Store(id=self._subid("_slice-data-lowres"), data=thumbnails),
Expand Down Expand Up @@ -175,6 +204,17 @@ def _create_client_callbacks(self):
[Input(self._subid("slider"), "value")],
)

app.clientside_callback(
"""
function update_position(index, info) {
return info.origin[2] + index * info.spacing[2];
}
""",
Output(self._subid("position"), "data"),
[Input(self._subid("index"), "data")],
[State(self._subid("info"), "data")],
)

app.clientside_callback(
"""
function handle_slice_index(index) {
Expand Down Expand Up @@ -205,7 +245,7 @@ def _create_client_callbacks(self):

app.clientside_callback(
"""
function handle_incoming_slice(index, index_and_data, indicators, ori_figure, lowres, slice_size) {
function handle_incoming_slice(index, index_and_data, indicators, ori_figure, lowres, info) {
let new_index = index_and_data[0];
let new_data = index_and_data[1];
// Store data in cache
Expand All @@ -214,18 +254,18 @@ def _create_client_callbacks(self):
slice_cache[new_index] = new_data;
// Get the data we need *now*
let data = slice_cache[index];
let x0 = 0, y0 = 0, dx = 1, dy = 1;
let x0 = info.origin[0], y0 = info.origin[1];
let dx = info.spacing[0], dy = info.spacing[1];
//slice_cache[new_index] = undefined; // todo: disabled cache for now!
// Maybe we do not need an update
if (!data) {
data = lowres[index];
// Scale the image to take the exact same space as the full-res
// version. It's not correct, but it looks better ...
// slice_size = full_w, full_h, low_w, low_h
dx = slice_size[0] / slice_size[2];
dy = slice_size[1] / slice_size[3];
x0 = 0.5 * dx - 0.5;
y0 = 0.5 * dy - 0.5;
dx *= info.size[0] / info.lowres_size[0];
dy *= info.size[1] / info.lowres_size[1];
x0 += 0.5 * dx - 0.5 * info.spacing[0];
y0 += 0.5 * dy - 0.5 * info.spacing[1];
}
if (data == ori_figure.data[0].source && indicators.version == ori_figure.data[1].version) {
return window.dash_clientside.no_update;
Expand Down Expand Up @@ -253,7 +293,7 @@ def _create_client_callbacks(self):
[
State(self._subid("graph"), "figure"),
State(self._subid("_slice-data-lowres"), "data"),
State(self._subid("_slice-size"), "data"),
State(self._subid("info"), "data"),
],
)

Expand All @@ -266,18 +306,22 @@ def _create_client_callbacks(self):
# * match any of the selected axii
app.clientside_callback(
"""
function handle_indicator(indices1, indices2, slice_size, current) {
let w = slice_size[0], h = slice_size[1];
let dx = w / 20, dy = h / 20;
function handle_indicator(positions1, positions2, info, current) {
let x0 = info.origin[0], y0 = info.origin[1];
let x1 = x0 + info.size[0] * info.spacing[0], y1 = y0 + info.size[1] * info.spacing[1];
x0 = x0 - info.spacing[0], y0 = y0 - info.spacing[1];
let d = ((x1 - x0) + (y1 - y0)) * 0.5 * 0.05;
let version = (current.version || 0) + 1;
let x = [], y = [];
for (let index of indices1) {
x.push(...[-dx, -1, null, w, w + dx, null]);
y.push(...[index, index, index, index, index, index]);
for (let pos of positions1) {
// x relative to our slice, y in scene-coords
x.push(...[x0 - d, x0, null, x1, x1 + d, null]);
y.push(...[pos, pos, pos, pos, pos, pos]);
}
for (let index of indices2) {
x.push(...[index, index, index, index, index, index]);
y.push(...[-dy, -1, null, h, h + dy, null]);
for (let pos of positions2) {
// x in scene-coords, y relative to our slice
x.push(...[pos, pos, pos, pos, pos, pos]);
y.push(...[y0 - d, y0, null, y1, y1 + d, null]);
}
return {
type: 'scatter',
Expand All @@ -296,15 +340,15 @@ def _create_client_callbacks(self):
{
"scene": self.scene_id,
"context": ALL,
"name": "index",
"name": "position",
"axis": axis,
},
"data",
)
for axis in axii
],
[
State(self._subid("_slice-size"), "data"),
State(self._subid("info"), "data"),
State(self._subid("_indicators"), "data"),
],
)
11 changes: 11 additions & 0 deletions dash_3d_viewer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,14 @@ def get_thumbnail_size_from_shape(shape, base_size):
img_pil = PIL.Image.fromarray(img_array)
img_pil.thumbnail((base_size, base_size))
return img_pil.size


def shape3d_to_size2d(shape, axis):
"""Turn a 3d shape (z, y, x) into a local (x', y', z'),
where z' represents the dimension indicated by axis.
"""
shape = list(shape)
axis_value = shape.pop(axis)
size = list(reversed(shape))
size.append(axis_value)
return tuple(size)
31 changes: 22 additions & 9 deletions examples/slicer_with_1_plus_2_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
This demonstrates how multiple indicators can be shown per axis.

Sharing the same scene_id is enough for the slicers to show each-others
position. If the same volume object is given, it works by default,
position. If the same volume object would be given, it works by default,
because the default scene_id is a hash of the volume object. Specifying
a scene_id provides slice position indicators even when slicing through
different volumes.

Further, this example has one slider showing data with different spacing.
Note how the indicators represent the actual position in "scene coordinates".

"""

import dash
Expand All @@ -17,22 +21,33 @@

app = dash.Dash(__name__)

vol = imageio.volread("imageio:stent.npz")
slicer1 = DashVolumeSlicer(app, vol, axis=1, scene_id="myscene")
slicer2 = DashVolumeSlicer(app, vol, axis=0, scene_id="myscene")
slicer3 = DashVolumeSlicer(app, vol, axis=0, scene_id="myscene")
vol1 = imageio.volread("imageio:stent.npz")

vol2 = vol1[::3, ::2, :]
spacing = 3, 2, 1
ori = 110, 120, 140


slicer1 = DashVolumeSlicer(
app, vol1, axis=1, origin=ori, reverse_y=False, scene_id="scene1"
)
slicer2 = DashVolumeSlicer(
app, vol1, axis=0, origin=ori, reverse_y=False, scene_id="scene1"
)
slicer3 = DashVolumeSlicer(
app, vol2, axis=0, origin=ori, spacing=spacing, reverse_y=False, scene_id="scene1"
)

app.layout = html.Div(
style={
"display": "grid",
"grid-template-columns": "40% 40%",
"gridTemplateColumns": "40% 40%",
},
children=[
html.Div(
[
html.H1("Coronal"),
slicer1.graph,
html.Br(),
slicer1.slider,
*slicer1.stores,
]
Expand All @@ -41,7 +56,6 @@
[
html.H1("Transversal 1"),
slicer2.graph,
html.Br(),
slicer2.slider,
*slicer2.stores,
]
Expand All @@ -51,7 +65,6 @@
[
html.H1("Transversal 2"),
slicer3.graph,
html.Br(),
slicer3.slider,
*slicer3.stores,
]
Expand Down
2 changes: 1 addition & 1 deletion examples/slicer_with_2_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
app.layout = html.Div(
style={
"display": "grid",
"grid-template-columns": "40% 40%",
"gridTemplateColumns": "40% 40%",
},
children=[
html.Div(
Expand Down
8 changes: 4 additions & 4 deletions examples/slicer_with_3_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@

# Read volumes and create slicer objects
vol = imageio.volread("imageio:stent.npz")
slicer1 = DashVolumeSlicer(app, vol, axis=0)
slicer2 = DashVolumeSlicer(app, vol, axis=1)
slicer3 = DashVolumeSlicer(app, vol, axis=2)
slicer1 = DashVolumeSlicer(app, vol, reverse_y=False, axis=0)
slicer2 = DashVolumeSlicer(app, vol, reverse_y=False, axis=1)
slicer3 = DashVolumeSlicer(app, vol, reverse_y=False, axis=2)

# Calculate isosurface and create a figure with a mesh object
verts, faces, _, _ = marching_cubes(vol, 300, step_size=2)
Expand All @@ -30,7 +30,7 @@
app.layout = html.Div(
style={
"display": "grid",
"grid-template-columns": "40% 40%",
"gridTemplateColumns": "40% 40%",
},
children=[
html.Div(
Expand Down
14 changes: 14 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from dash_3d_viewer.utils import shape3d_to_size2d

from pytest import raises


def test_shape3d_to_size2d():
# shape -> z, y, x
# size -> x, y, out-of-plane
assert shape3d_to_size2d((12, 13, 14), 0) == (14, 13, 12)
assert shape3d_to_size2d((12, 13, 14), 1) == (14, 12, 13)
assert shape3d_to_size2d((12, 13, 14), 2) == (13, 12, 14)

with raises(IndexError):
shape3d_to_size2d((12, 13, 14), 3)