Skip to content

Commit dcd0a6a

Browse files
authored
Implement indicators (#4)
* Implement indicators * rename volume-id -> scene-id, and clarify example
1 parent ed9deeb commit dcd0a6a

File tree

5 files changed

+226
-36
lines changed

5 files changed

+226
-36
lines changed

dash_3d_viewer/slicer.py

Lines changed: 148 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,47 @@
11
import numpy as np
2-
from plotly.graph_objects import Figure, Image
2+
from plotly.graph_objects import Figure, Image, Scatter
33
from dash import Dash
4-
from dash.dependencies import Input, Output, State
4+
from dash.dependencies import Input, Output, State, ALL
55
from dash_core_components import Graph, Slider, Store
66

7-
from .utils import gen_random_id, img_array_to_uri
7+
from .utils import img_array_to_uri, get_thumbnail_size_from_shape
88

99

1010
class DashVolumeSlicer:
11-
"""A slicer to show 3D image data in Dash."""
11+
"""A slicer to show 3D image data in Dash.
1212
13-
def __init__(self, app, volume, axis=0, id=None):
13+
Parameters:
14+
app (dash.Dash): the Dash application instance.
15+
volume (ndarray): the 3D numpy array to slice through.
16+
axis (int): the dimension to slice in. Default 0.
17+
scene_id (str): the scene that this slicer is part of. Slicers
18+
that have the same scene-id show each-other's positions with
19+
line indicators. By default this is a hash of ``id(volume)``.
20+
21+
This is a placeholder object, not a Dash component. The components
22+
that make up the slicer can be accessed as attributes:
23+
24+
* ``graph``: the Graph object.
25+
* ``slider``: the Slider object.
26+
* ``stores``: a list of Store objects. Some are "public" values, others
27+
used internally. Make sure to put them somewhere in the layout.
28+
29+
Each component is given a dict-id with the following keys:
30+
31+
* "context": a unique string id for this slicer instance.
32+
* "scene": the scene_id.
33+
* "axis": the int axis.
34+
* "name": the name of the (sub) component.
35+
36+
TODO: iron out these details, list the stores that are public
37+
"""
38+
39+
_global_slicer_counter = 0
40+
41+
def __init__(self, app, volume, axis=0, scene_id=None):
1442
if not isinstance(app, Dash):
1543
raise TypeError("Expect first arg to be a Dash app.")
44+
self._app = app
1645
# Check and store volume
1746
if not (isinstance(volume, np.ndarray) and volume.ndim == 3):
1847
raise TypeError("Expected volume to be a 3D numpy array")
@@ -22,29 +51,36 @@ def __init__(self, app, volume, axis=0, id=None):
2251
raise ValueError("The given axis must be 0, 1, or 2.")
2352
self._axis = int(axis)
2453
# Check and store id
25-
if id is None:
26-
id = gen_random_id()
27-
elif not isinstance(id, str):
28-
raise TypeError("Id must be a string")
29-
self._id = id
54+
if scene_id is None:
55+
scene_id = "volume_" + hex(id(volume))[2:]
56+
elif not isinstance(scene_id, str):
57+
raise TypeError("scene_id must be a string")
58+
self.scene_id = scene_id
59+
# Get unique id scoped to this slicer object
60+
DashVolumeSlicer._global_slicer_counter += 1
61+
self.context_id = "slicer_" + str(DashVolumeSlicer._global_slicer_counter)
3062

3163
# Get the slice size (width, height), and max index
32-
# arr_shape = list(volume.shape)
33-
# arr_shape.pop(self._axis)
34-
# slice_size = list(reversed(arr_shape))
64+
arr_shape = list(volume.shape)
65+
arr_shape.pop(self._axis)
66+
self._slice_size = tuple(reversed(arr_shape))
3567
self._max_index = self._volume.shape[self._axis] - 1
3668

3769
# Prep low-res slices
70+
thumbnail_size = get_thumbnail_size_from_shape(arr_shape, 32)
3871
thumbnails = [
39-
img_array_to_uri(self._slice(i), (32, 32))
72+
img_array_to_uri(self._slice(i), thumbnail_size)
4073
for i in range(self._max_index + 1)
4174
]
4275

4376
# Create a placeholder trace
4477
# todo: can add "%{z[0]}", but that would be the scaled value ...
45-
trace = Image(source="", hovertemplate="(%{x}, %{y})<extra></extra>")
78+
image_trace = Image(
79+
source="", dx=1, dy=1, hovertemplate="(%{x}, %{y})<extra></extra>"
80+
)
81+
scatter_trace = Scatter(x=[], y=[]) # placeholder
4682
# Create the figure object
47-
fig = Figure(data=[trace])
83+
self._fig = fig = Figure(data=[image_trace, scatter_trace])
4884
fig.update_layout(
4985
template=None,
5086
margin=dict(l=0, r=0, b=0, t=0, pad=4),
@@ -70,6 +106,7 @@ def __init__(self, app, volume, axis=0, id=None):
70106
config={"scrollZoom": True},
71107
)
72108
# Create a slider object that the user can put in the layout (or not)
109+
# todo: use tooltip to show current value?
73110
self.slider = Slider(
74111
id=self._subid("slider"),
75112
min=0,
@@ -81,18 +118,29 @@ def __init__(self, app, volume, axis=0, id=None):
81118
)
82119
# Create the stores that we need (these must be present in the layout)
83120
self.stores = [
84-
Store(id=self._subid("slice-index"), data=volume.shape[self._axis] // 2),
121+
Store(
122+
id=self._subid("_slice-size"), data=self._slice_size + thumbnail_size
123+
),
124+
Store(id=self._subid("index"), data=volume.shape[self._axis] // 2),
85125
Store(id=self._subid("_requested-slice-index"), data=0),
86126
Store(id=self._subid("_slice-data"), data=""),
87127
Store(id=self._subid("_slice-data-lowres"), data=thumbnails),
128+
Store(id=self._subid("_indicators"), data=[]),
88129
]
89130

90-
self._create_server_callbacks(app)
91-
self._create_client_callbacks(app)
131+
self._create_server_callbacks()
132+
self._create_client_callbacks()
92133

93-
def _subid(self, subid):
134+
def _subid(self, name):
94135
"""Given a subid, get the full id including the slicer's prefix."""
95-
return self._id + "-" + subid
136+
# return self.context_id + "-" + name
137+
# todo: is there a penalty for using a dict-id vs a string-id?
138+
return {
139+
"context": self.context_id,
140+
"scene": self.scene_id,
141+
"axis": self._axis,
142+
"name": name,
143+
}
96144

97145
def _slice(self, index):
98146
"""Sample a slice from the volume."""
@@ -101,8 +149,9 @@ def _slice(self, index):
101149
im = self._volume[tuple(indices)]
102150
return (im.astype(np.float32) * (255 / im.max())).astype(np.uint8)
103151

104-
def _create_server_callbacks(self, app):
152+
def _create_server_callbacks(self):
105153
"""Create the callbacks that run server-side."""
154+
app = self._app
106155

107156
@app.callback(
108157
Output(self._subid("_slice-data"), "data"),
@@ -112,16 +161,17 @@ def upload_requested_slice(slice_index):
112161
slice = self._slice(slice_index)
113162
return [slice_index, img_array_to_uri(slice)]
114163

115-
def _create_client_callbacks(self, app):
164+
def _create_client_callbacks(self):
116165
"""Create the callbacks that run client-side."""
166+
app = self._app
117167

118168
app.clientside_callback(
119169
"""
120170
function handle_slider_move(index) {
121171
return index;
122172
}
123173
""",
124-
Output(self._subid("slice-index"), "data"),
174+
Output(self._subid("index"), "data"),
125175
[Input(self._subid("slider"), "value")],
126176
)
127177

@@ -138,24 +188,24 @@ def _create_client_callbacks(self, app):
138188
}
139189
}
140190
""".replace(
141-
"{{ID}}", self._id
191+
"{{ID}}", self.context_id
142192
),
143193
Output(self._subid("_requested-slice-index"), "data"),
144-
[Input(self._subid("slice-index"), "data")],
194+
[Input(self._subid("index"), "data")],
145195
)
146196

147197
# app.clientside_callback("""
148198
# function update_slider_pos(index) {
149199
# return index;
150200
# }
151201
# """,
152-
# [Output("slice-index", "data")],
202+
# [Output("index", "data")],
153203
# [State("slider", "value")],
154204
# )
155205

156206
app.clientside_callback(
157207
"""
158-
function handle_incoming_slice(index, index_and_data, ori_figure, lowres) {
208+
function handle_incoming_slice(index, index_and_data, indicators, ori_figure, lowres, slice_size) {
159209
let new_index = index_and_data[0];
160210
let new_data = index_and_data[1];
161211
// Store data in cache
@@ -164,30 +214,97 @@ def _create_client_callbacks(self, app):
164214
slice_cache[new_index] = new_data;
165215
// Get the data we need *now*
166216
let data = slice_cache[index];
217+
let x0 = 0, y0 = 0, dx = 1, dy = 1;
167218
//slice_cache[new_index] = undefined; // todo: disabled cache for now!
168219
// Maybe we do not need an update
169220
if (!data) {
170221
data = lowres[index];
222+
// Scale the image to take the exact same space as the full-res
223+
// version. It's not correct, but it looks better ...
224+
// slice_size = full_w, full_h, low_w, low_h
225+
dx = slice_size[0] / slice_size[2];
226+
dy = slice_size[1] / slice_size[3];
227+
x0 = 0.5 * dx - 0.5;
228+
y0 = 0.5 * dy - 0.5;
171229
}
172-
if (data == ori_figure.data[0].source) {
230+
if (data == ori_figure.data[0].source && indicators.version == ori_figure.data[1].version) {
173231
return window.dash_clientside.no_update;
174232
}
175233
// Otherwise, perform update
176234
console.log("updating figure");
177235
let figure = {...ori_figure};
178236
figure.data[0].source = data;
237+
figure.data[0].x0 = x0;
238+
figure.data[0].y0 = y0;
239+
figure.data[0].dx = dx;
240+
figure.data[0].dy = dy;
241+
figure.data[1] = indicators;
179242
return figure;
180243
}
181244
""".replace(
182-
"{{ID}}", self._id
245+
"{{ID}}", self.context_id
183246
),
184247
Output(self._subid("graph"), "figure"),
185248
[
186-
Input(self._subid("slice-index"), "data"),
249+
Input(self._subid("index"), "data"),
187250
Input(self._subid("_slice-data"), "data"),
251+
Input(self._subid("_indicators"), "data"),
188252
],
189253
[
190254
State(self._subid("graph"), "figure"),
191255
State(self._subid("_slice-data-lowres"), "data"),
256+
State(self._subid("_slice-size"), "data"),
257+
],
258+
)
259+
260+
# Select the *other* axii
261+
axii = [0, 1, 2]
262+
axii.pop(self._axis)
263+
264+
# Create a callback to create a trace representing all slice-indices that:
265+
# * corresponding to the same volume data
266+
# * match any of the selected axii
267+
app.clientside_callback(
268+
"""
269+
function handle_indicator(indices1, indices2, slice_size, current) {
270+
let w = slice_size[0], h = slice_size[1];
271+
let dx = w / 20, dy = h / 20;
272+
let version = (current.version || 0) + 1;
273+
let x = [], y = [];
274+
for (let index of indices1) {
275+
x.push(...[-dx, -1, null, w, w + dx, null]);
276+
y.push(...[index, index, index, index, index, index]);
277+
}
278+
for (let index of indices2) {
279+
x.push(...[index, index, index, index, index, index]);
280+
y.push(...[-dy, -1, null, h, h + dy, null]);
281+
}
282+
return {
283+
type: 'scatter',
284+
mode: 'lines',
285+
line: {color: '#ff00aa'},
286+
x: x,
287+
y: y,
288+
hoverinfo: 'skip',
289+
version: version
290+
};
291+
}
292+
""",
293+
Output(self._subid("_indicators"), "data"),
294+
[
295+
Input(
296+
{
297+
"scene": self.scene_id,
298+
"context": ALL,
299+
"name": "index",
300+
"axis": axis,
301+
},
302+
"data",
303+
)
304+
for axis in axii
305+
],
306+
[
307+
State(self._subid("_slice-size"), "data"),
308+
State(self._subid("_indicators"), "data"),
192309
],
193310
)

dash_3d_viewer/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import random
33
import base64
44

5+
import numpy as np
56
import PIL.Image
67
import skimage
78

@@ -23,3 +24,11 @@ def img_array_to_uri(img_array, new_size=None):
2324
img_pil.save(f, format="PNG")
2425
base64_str = base64.b64encode(f.getvalue()).decode()
2526
return "data:image/png;base64," + base64_str
27+
28+
29+
def get_thumbnail_size_from_shape(shape, base_size):
30+
base_size = int(base_size)
31+
img_array = np.zeros(shape, np.uint8)
32+
img_pil = PIL.Image.fromarray(img_array)
33+
img_pil.thumbnail((base_size, base_size))
34+
return img_pil.size
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
"""
2+
An example with two slicers at the same axis, and one on another axis.
3+
This demonstrates how multiple indicators can be shown per axis.
4+
5+
Sharing the same scene_id is enough for the slicers to show each-others
6+
position. If the same volume object is given, it works by default,
7+
because the default scene_id is a hash of the volume object. Specifying
8+
a scene_id provides slice position indicators even when slicing through
9+
different volumes.
10+
"""
11+
12+
import dash
13+
import dash_html_components as html
14+
from dash_3d_viewer import DashVolumeSlicer
15+
import imageio
16+
17+
18+
app = dash.Dash(__name__)
19+
20+
vol = imageio.volread("imageio:stent.npz")
21+
slicer1 = DashVolumeSlicer(app, vol, axis=1, scene_id="myscene")
22+
slicer2 = DashVolumeSlicer(app, vol, axis=0, scene_id="myscene")
23+
slicer3 = DashVolumeSlicer(app, vol, axis=0, scene_id="myscene")
24+
25+
app.layout = html.Div(
26+
style={
27+
"display": "grid",
28+
"grid-template-columns": "40% 40%",
29+
},
30+
children=[
31+
html.Div(
32+
[
33+
html.H1("Coronal"),
34+
slicer1.graph,
35+
html.Br(),
36+
slicer1.slider,
37+
*slicer1.stores,
38+
]
39+
),
40+
html.Div(
41+
[
42+
html.H1("Transversal 1"),
43+
slicer2.graph,
44+
html.Br(),
45+
slicer2.slider,
46+
*slicer2.stores,
47+
]
48+
),
49+
html.Div(),
50+
html.Div(
51+
[
52+
html.H1("Transversal 2"),
53+
slicer3.graph,
54+
html.Br(),
55+
slicer3.slider,
56+
*slicer3.stores,
57+
]
58+
),
59+
],
60+
)
61+
62+
63+
if __name__ == "__main__":
64+
app.run_server(debug=True)

0 commit comments

Comments
 (0)