Skip to content

Commit b32784a

Browse files
committed
Support for anisotropic data.
1 parent cebbc83 commit b32784a

File tree

6 files changed

+111
-46
lines changed

6 files changed

+111
-46
lines changed

dash_3d_viewer/slicer.py

Lines changed: 66 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from dash.dependencies import Input, Output, State, ALL
55
from dash_core_components import Graph, Slider, Store
66

7-
from .utils import img_array_to_uri, get_thumbnail_size_from_shape
7+
from .utils import img_array_to_uri, get_thumbnail_size_from_shape, shape3d_to_size2d
88

99

1010
class DashVolumeSlicer:
@@ -13,6 +13,11 @@ class DashVolumeSlicer:
1313
Parameters:
1414
app (dash.Dash): the Dash application instance.
1515
volume (ndarray): the 3D numpy array to slice through.
16+
The dimensions are assumed to be in zyx order.
17+
spacing (tuple of floats): The voxel size for each dimension (zyx).
18+
The spacing and origin are applied to make the slice drawn in
19+
"scene space" rather than "voxel space".
20+
origin (tuple of floats): The offset for each dimension (zyx).
1621
axis (int): the dimension to slice in. Default 0.
1722
scene_id (str): the scene that this slicer is part of. Slicers
1823
that have the same scene-id show each-other's positions with
@@ -38,14 +43,21 @@ class DashVolumeSlicer:
3843

3944
_global_slicer_counter = 0
4045

41-
def __init__(self, app, volume, axis=0, scene_id=None):
46+
def __init__(
47+
self, app, volume, *, spacing=None, origin=None, axis=0, scene_id=None
48+
):
49+
# todo: also implement xyz dim order?
4250
if not isinstance(app, Dash):
4351
raise TypeError("Expect first arg to be a Dash app.")
4452
self._app = app
4553
# Check and store volume
4654
if not (isinstance(volume, np.ndarray) and volume.ndim == 3):
4755
raise TypeError("Expected volume to be a 3D numpy array")
4856
self._volume = volume
57+
spacing = (1, 1, 1) if spacing is None else spacing
58+
spacing = float(spacing[0]), float(spacing[1]), float(spacing[2])
59+
origin = (0, 0, 0) if origin is None else origin
60+
origin = float(origin[0]), float(origin[1]), float(origin[2])
4961
# Check and store axis
5062
if not (isinstance(axis, int) and 0 <= axis <= 2):
5163
raise ValueError("The given axis must be 0, 1, or 2.")
@@ -60,20 +72,26 @@ def __init__(self, app, volume, axis=0, scene_id=None):
6072
DashVolumeSlicer._global_slicer_counter += 1
6173
self.context_id = "slicer_" + str(DashVolumeSlicer._global_slicer_counter)
6274

63-
# Get the slice size (width, height), and max index
64-
arr_shape = list(volume.shape)
65-
arr_shape.pop(self._axis)
66-
self._slice_size = tuple(reversed(arr_shape))
67-
self._max_index = self._volume.shape[self._axis] - 1
75+
# Prepare slice info
76+
info = {
77+
"shape": tuple(volume.shape),
78+
"axis": self._axis,
79+
"size": shape3d_to_size2d(volume.shape, axis),
80+
"origin": shape3d_to_size2d(origin, axis),
81+
"spacing": shape3d_to_size2d(spacing, axis),
82+
}
6883

6984
# Prep low-res slices
70-
thumbnail_size = get_thumbnail_size_from_shape(arr_shape, 32)
85+
thumbnail_size = get_thumbnail_size_from_shape(
86+
(info["size"][1], info["size"][0]), 32
87+
)
7188
thumbnails = [
7289
img_array_to_uri(self._slice(i), thumbnail_size)
73-
for i in range(self._max_index + 1)
90+
for i in range(info["size"][2])
7491
]
92+
info["lowres_size"] = thumbnail_size
7593

76-
# Create a placeholder trace
94+
# Create traces
7795
# todo: can add "%{z[0]}", but that would be the scaled value ...
7896
image_trace = Image(
7997
source="", dx=1, dy=1, hovertemplate="(%{x}, %{y})<extra></extra>"
@@ -106,22 +124,20 @@ def __init__(self, app, volume, axis=0, scene_id=None):
106124
config={"scrollZoom": True},
107125
)
108126
# Create a slider object that the user can put in the layout (or not)
109-
# todo: use tooltip to show current value?
110127
self.slider = Slider(
111128
id=self._subid("slider"),
112129
min=0,
113-
max=self._max_index,
130+
max=info["size"][2] - 1,
114131
step=1,
115-
value=self._max_index // 2,
132+
value=info["size"][2] // 2,
116133
tooltip={"always_visible": False, "placement": "left"},
117134
updatemode="drag",
118135
)
119136
# Create the stores that we need (these must be present in the layout)
120137
self.stores = [
121-
Store(
122-
id=self._subid("_slice-size"), data=self._slice_size + thumbnail_size
123-
),
138+
Store(id=self._subid("info"), data=info),
124139
Store(id=self._subid("index"), data=volume.shape[self._axis] // 2),
140+
Store(id=self._subid("position"), data=0),
125141
Store(id=self._subid("_requested-slice-index"), data=0),
126142
Store(id=self._subid("_slice-data"), data=""),
127143
Store(id=self._subid("_slice-data-lowres"), data=thumbnails),
@@ -175,6 +191,17 @@ def _create_client_callbacks(self):
175191
[Input(self._subid("slider"), "value")],
176192
)
177193

194+
app.clientside_callback(
195+
"""
196+
function update_position(index, info) {
197+
return info.origin[2] + index * info.spacing[2];
198+
}
199+
""",
200+
Output(self._subid("position"), "data"),
201+
[Input(self._subid("index"), "data")],
202+
[State(self._subid("info"), "data")],
203+
)
204+
178205
app.clientside_callback(
179206
"""
180207
function handle_slice_index(index) {
@@ -205,7 +232,7 @@ def _create_client_callbacks(self):
205232

206233
app.clientside_callback(
207234
"""
208-
function handle_incoming_slice(index, index_and_data, indicators, ori_figure, lowres, slice_size) {
235+
function handle_incoming_slice(index, index_and_data, indicators, ori_figure, lowres, info) {
209236
let new_index = index_and_data[0];
210237
let new_data = index_and_data[1];
211238
// Store data in cache
@@ -214,18 +241,18 @@ def _create_client_callbacks(self):
214241
slice_cache[new_index] = new_data;
215242
// Get the data we need *now*
216243
let data = slice_cache[index];
217-
let x0 = 0, y0 = 0, dx = 1, dy = 1;
244+
let x0 = info.origin[0], y0 = info.origin[1];
245+
let dx = info.spacing[0], dy = info.spacing[1];
218246
//slice_cache[new_index] = undefined; // todo: disabled cache for now!
219247
// Maybe we do not need an update
220248
if (!data) {
221249
data = lowres[index];
222250
// Scale the image to take the exact same space as the full-res
223251
// version. It's not correct, but it looks better ...
224-
// slice_size = full_w, full_h, low_w, low_h
225-
dx = slice_size[0] / slice_size[2];
226-
dy = slice_size[1] / slice_size[3];
227-
x0 = 0.5 * dx - 0.5;
228-
y0 = 0.5 * dy - 0.5;
252+
dx *= info.size[0] / info.lowres_size[0];
253+
dy *= info.size[1] / info.lowres_size[1];
254+
x0 += 0.5 * dx - 0.5 * info.spacing[0];
255+
y0 += 0.5 * dy - 0.5 * info.spacing[1];
229256
}
230257
if (data == ori_figure.data[0].source && indicators.version == ori_figure.data[1].version) {
231258
return window.dash_clientside.no_update;
@@ -253,7 +280,7 @@ def _create_client_callbacks(self):
253280
[
254281
State(self._subid("graph"), "figure"),
255282
State(self._subid("_slice-data-lowres"), "data"),
256-
State(self._subid("_slice-size"), "data"),
283+
State(self._subid("info"), "data"),
257284
],
258285
)
259286

@@ -266,18 +293,22 @@ def _create_client_callbacks(self):
266293
# * match any of the selected axii
267294
app.clientside_callback(
268295
"""
269-
function handle_indicator(indices1, indices2, slice_size, current) {
270-
let w = slice_size[0], h = slice_size[1];
271-
let dx = w / 20, dy = h / 20;
296+
function handle_indicator(positions1, positions2, info, current) {
297+
let x0 = info.origin[0], y0 = info.origin[1];
298+
let x1 = x0 + info.size[0] * info.spacing[0], y1 = y0 + info.size[1] * info.spacing[1];
299+
x0 = x0 - info.spacing[0], y0 = y0 - info.spacing[1];
300+
let d = ((x1 - x0) + (y1 - y0)) * 0.5 * 0.05;
272301
let version = (current.version || 0) + 1;
273302
let x = [], y = [];
274-
for (let index of indices1) {
275-
x.push(...[-dx, -1, null, w, w + dx, null]);
276-
y.push(...[index, index, index, index, index, index]);
303+
for (let pos of positions1) {
304+
// x relative to our slice, y in scene-coords
305+
x.push(...[x0 - d, x0, null, x1, x1 + d, null]);
306+
y.push(...[pos, pos, pos, pos, pos, pos]);
277307
}
278-
for (let index of indices2) {
279-
x.push(...[index, index, index, index, index, index]);
280-
y.push(...[-dy, -1, null, h, h + dy, null]);
308+
for (let pos of positions2) {
309+
// x in scene-coords, y relative to our slice
310+
x.push(...[pos, pos, pos, pos, pos, pos]);
311+
y.push(...[y0 - d, y0, null, y1, y1 + d, null]);
281312
}
282313
return {
283314
type: 'scatter',
@@ -296,15 +327,15 @@ def _create_client_callbacks(self):
296327
{
297328
"scene": self.scene_id,
298329
"context": ALL,
299-
"name": "index",
330+
"name": "position",
300331
"axis": axis,
301332
},
302333
"data",
303334
)
304335
for axis in axii
305336
],
306337
[
307-
State(self._subid("_slice-size"), "data"),
338+
State(self._subid("info"), "data"),
308339
State(self._subid("_indicators"), "data"),
309340
],
310341
)

dash_3d_viewer/utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,14 @@ def get_thumbnail_size_from_shape(shape, base_size):
3232
img_pil = PIL.Image.fromarray(img_array)
3333
img_pil.thumbnail((base_size, base_size))
3434
return img_pil.size
35+
36+
37+
def shape3d_to_size2d(shape, axis):
38+
"""Turn a 3d shape (z, y, x) into a local (x', y', z'),
39+
where z' represents the dimension indicated by axis.
40+
"""
41+
shape = list(shape)
42+
axis_value = shape.pop(axis)
43+
size = list(reversed(shape))
44+
size.append(axis_value)
45+
return tuple(size)

examples/slicer_with_1_plus_2_views.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,14 @@
33
This demonstrates how multiple indicators can be shown per axis.
44
55
Sharing the same scene_id is enough for the slicers to show each-others
6-
position. If the same volume object is given, it works by default,
6+
position. If the same volume object would be given, it works by default,
77
because the default scene_id is a hash of the volume object. Specifying
88
a scene_id provides slice position indicators even when slicing through
99
different volumes.
10+
11+
Further, this example has one slider showing data with different spacing.
12+
Note how the indicators represent the actual position in "scene coordinates".
13+
1014
"""
1115

1216
import dash
@@ -17,22 +21,29 @@
1721

1822
app = dash.Dash(__name__)
1923

20-
vol = imageio.volread("imageio:stent.npz")
21-
slicer1 = DashVolumeSlicer(app, vol, axis=1, scene_id="myscene")
22-
slicer2 = DashVolumeSlicer(app, vol, axis=0, scene_id="myscene")
23-
slicer3 = DashVolumeSlicer(app, vol, axis=0, scene_id="myscene")
24+
vol1 = imageio.volread("imageio:stent.npz")
25+
26+
vol2 = vol1[::3, ::2, :]
27+
spacing = 3, 2, 1
28+
origin = 110, 120, 140
29+
30+
31+
slicer1 = DashVolumeSlicer(app, vol1, axis=1, origin=origin, scene_id="myscene")
32+
slicer2 = DashVolumeSlicer(app, vol1, axis=0, origin=origin, scene_id="myscene")
33+
slicer3 = DashVolumeSlicer(
34+
app, vol2, axis=0, origin=origin, spacing=spacing, scene_id="myscene"
35+
)
2436

2537
app.layout = html.Div(
2638
style={
2739
"display": "grid",
28-
"grid-template-columns": "40% 40%",
40+
"gridTemplateColumns": "40% 40%",
2941
},
3042
children=[
3143
html.Div(
3244
[
3345
html.H1("Coronal"),
3446
slicer1.graph,
35-
html.Br(),
3647
slicer1.slider,
3748
*slicer1.stores,
3849
]
@@ -41,7 +52,6 @@
4152
[
4253
html.H1("Transversal 1"),
4354
slicer2.graph,
44-
html.Br(),
4555
slicer2.slider,
4656
*slicer2.stores,
4757
]
@@ -51,7 +61,6 @@
5161
[
5262
html.H1("Transversal 2"),
5363
slicer3.graph,
54-
html.Br(),
5564
slicer3.slider,
5665
*slicer3.stores,
5766
]

examples/slicer_with_2_views.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
app.layout = html.Div(
1818
style={
1919
"display": "grid",
20-
"grid-template-columns": "40% 40%",
20+
"gridTemplateColumns": "40% 40%",
2121
},
2222
children=[
2323
html.Div(

examples/slicer_with_3_views.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
app.layout = html.Div(
3131
style={
3232
"display": "grid",
33-
"grid-template-columns": "40% 40%",
33+
"gridTemplateColumns": "40% 40%",
3434
},
3535
children=[
3636
html.Div(

tests/test_utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from dash_3d_viewer.utils import shape3d_to_size2d
2+
3+
from pytest import raises
4+
5+
6+
def test_shape3d_to_size2d():
7+
# shape -> z, y, x
8+
# size -> x, y, out-of-plane
9+
assert shape3d_to_size2d((12, 13, 14), 0) == (14, 13, 12)
10+
assert shape3d_to_size2d((12, 13, 14), 1) == (14, 12, 13)
11+
assert shape3d_to_size2d((12, 13, 14), 2) == (13, 12, 14)
12+
13+
with raises(IndexError):
14+
shape3d_to_size2d((12, 13, 14), 3)

0 commit comments

Comments
 (0)