diff --git a/dash_3d_viewer/slicer.py b/dash_3d_viewer/slicer.py
index 13417c6..0a0f456 100644
--- a/dash_3d_viewer/slicer.py
+++ b/dash_3d_viewer/slicer.py
@@ -1,18 +1,47 @@
import numpy as np
-from plotly.graph_objects import Figure, Image
+from plotly.graph_objects import Figure, Image, Scatter
from dash import Dash
-from dash.dependencies import Input, Output, State
+from dash.dependencies import Input, Output, State, ALL
from dash_core_components import Graph, Slider, Store
-from .utils import gen_random_id, img_array_to_uri
+from .utils import img_array_to_uri, get_thumbnail_size_from_shape
class DashVolumeSlicer:
- """A slicer to show 3D image data in Dash."""
+ """A slicer to show 3D image data in Dash.
- def __init__(self, app, volume, axis=0, id=None):
+ Parameters:
+ app (dash.Dash): the Dash application instance.
+ volume (ndarray): the 3D numpy array to slice through.
+ axis (int): the dimension to slice in. Default 0.
+ scene_id (str): the scene that this slicer is part of. Slicers
+ that have the same scene-id show each-other's positions with
+ line indicators. By default this is a hash of ``id(volume)``.
+
+ This is a placeholder object, not a Dash component. The components
+ that make up the slicer can be accessed as attributes:
+
+ * ``graph``: the Graph object.
+ * ``slider``: the Slider object.
+ * ``stores``: a list of Store objects. Some are "public" values, others
+ used internally. Make sure to put them somewhere in the layout.
+
+ Each component is given a dict-id with the following keys:
+
+ * "context": a unique string id for this slicer instance.
+ * "scene": the scene_id.
+ * "axis": the int axis.
+ * "name": the name of the (sub) component.
+
+ TODO: iron out these details, list the stores that are public
+ """
+
+ _global_slicer_counter = 0
+
+ def __init__(self, app, volume, axis=0, scene_id=None):
if not isinstance(app, Dash):
raise TypeError("Expect first arg to be a Dash app.")
+ self._app = app
# Check and store volume
if not (isinstance(volume, np.ndarray) and volume.ndim == 3):
raise TypeError("Expected volume to be a 3D numpy array")
@@ -22,29 +51,36 @@ def __init__(self, app, volume, axis=0, id=None):
raise ValueError("The given axis must be 0, 1, or 2.")
self._axis = int(axis)
# Check and store id
- if id is None:
- id = gen_random_id()
- elif not isinstance(id, str):
- raise TypeError("Id must be a string")
- self._id = id
+ if scene_id is None:
+ scene_id = "volume_" + hex(id(volume))[2:]
+ elif not isinstance(scene_id, str):
+ raise TypeError("scene_id must be a string")
+ self.scene_id = scene_id
+ # Get unique id scoped to this slicer object
+ DashVolumeSlicer._global_slicer_counter += 1
+ self.context_id = "slicer_" + str(DashVolumeSlicer._global_slicer_counter)
# Get the slice size (width, height), and max index
- # arr_shape = list(volume.shape)
- # arr_shape.pop(self._axis)
- # slice_size = list(reversed(arr_shape))
+ arr_shape = list(volume.shape)
+ arr_shape.pop(self._axis)
+ self._slice_size = tuple(reversed(arr_shape))
self._max_index = self._volume.shape[self._axis] - 1
# Prep low-res slices
+ thumbnail_size = get_thumbnail_size_from_shape(arr_shape, 32)
thumbnails = [
- img_array_to_uri(self._slice(i), (32, 32))
+ img_array_to_uri(self._slice(i), thumbnail_size)
for i in range(self._max_index + 1)
]
# Create a placeholder trace
# todo: can add "%{z[0]}", but that would be the scaled value ...
- trace = Image(source="", hovertemplate="(%{x}, %{y})")
+ image_trace = Image(
+ source="", dx=1, dy=1, hovertemplate="(%{x}, %{y})"
+ )
+ scatter_trace = Scatter(x=[], y=[]) # placeholder
# Create the figure object
- fig = Figure(data=[trace])
+ self._fig = fig = Figure(data=[image_trace, scatter_trace])
fig.update_layout(
template=None,
margin=dict(l=0, r=0, b=0, t=0, pad=4),
@@ -70,6 +106,7 @@ def __init__(self, app, volume, axis=0, id=None):
config={"scrollZoom": True},
)
# Create a slider object that the user can put in the layout (or not)
+ # todo: use tooltip to show current value?
self.slider = Slider(
id=self._subid("slider"),
min=0,
@@ -81,18 +118,29 @@ def __init__(self, app, volume, axis=0, id=None):
)
# Create the stores that we need (these must be present in the layout)
self.stores = [
- Store(id=self._subid("slice-index"), data=volume.shape[self._axis] // 2),
+ Store(
+ id=self._subid("_slice-size"), data=self._slice_size + thumbnail_size
+ ),
+ Store(id=self._subid("index"), data=volume.shape[self._axis] // 2),
Store(id=self._subid("_requested-slice-index"), data=0),
Store(id=self._subid("_slice-data"), data=""),
Store(id=self._subid("_slice-data-lowres"), data=thumbnails),
+ Store(id=self._subid("_indicators"), data=[]),
]
- self._create_server_callbacks(app)
- self._create_client_callbacks(app)
+ self._create_server_callbacks()
+ self._create_client_callbacks()
- def _subid(self, subid):
+ def _subid(self, name):
"""Given a subid, get the full id including the slicer's prefix."""
- return self._id + "-" + subid
+ # return self.context_id + "-" + name
+ # todo: is there a penalty for using a dict-id vs a string-id?
+ return {
+ "context": self.context_id,
+ "scene": self.scene_id,
+ "axis": self._axis,
+ "name": name,
+ }
def _slice(self, index):
"""Sample a slice from the volume."""
@@ -101,8 +149,9 @@ def _slice(self, index):
im = self._volume[tuple(indices)]
return (im.astype(np.float32) * (255 / im.max())).astype(np.uint8)
- def _create_server_callbacks(self, app):
+ def _create_server_callbacks(self):
"""Create the callbacks that run server-side."""
+ app = self._app
@app.callback(
Output(self._subid("_slice-data"), "data"),
@@ -112,8 +161,9 @@ def upload_requested_slice(slice_index):
slice = self._slice(slice_index)
return [slice_index, img_array_to_uri(slice)]
- def _create_client_callbacks(self, app):
+ def _create_client_callbacks(self):
"""Create the callbacks that run client-side."""
+ app = self._app
app.clientside_callback(
"""
@@ -121,7 +171,7 @@ def _create_client_callbacks(self, app):
return index;
}
""",
- Output(self._subid("slice-index"), "data"),
+ Output(self._subid("index"), "data"),
[Input(self._subid("slider"), "value")],
)
@@ -138,10 +188,10 @@ def _create_client_callbacks(self, app):
}
}
""".replace(
- "{{ID}}", self._id
+ "{{ID}}", self.context_id
),
Output(self._subid("_requested-slice-index"), "data"),
- [Input(self._subid("slice-index"), "data")],
+ [Input(self._subid("index"), "data")],
)
# app.clientside_callback("""
@@ -149,13 +199,13 @@ def _create_client_callbacks(self, app):
# return index;
# }
# """,
- # [Output("slice-index", "data")],
+ # [Output("index", "data")],
# [State("slider", "value")],
# )
app.clientside_callback(
"""
- function handle_incoming_slice(index, index_and_data, ori_figure, lowres) {
+ function handle_incoming_slice(index, index_and_data, indicators, ori_figure, lowres, slice_size) {
let new_index = index_and_data[0];
let new_data = index_and_data[1];
// Store data in cache
@@ -164,30 +214,97 @@ def _create_client_callbacks(self, app):
slice_cache[new_index] = new_data;
// Get the data we need *now*
let data = slice_cache[index];
+ let x0 = 0, y0 = 0, dx = 1, dy = 1;
//slice_cache[new_index] = undefined; // todo: disabled cache for now!
// Maybe we do not need an update
if (!data) {
data = lowres[index];
+ // Scale the image to take the exact same space as the full-res
+ // version. It's not correct, but it looks better ...
+ // slice_size = full_w, full_h, low_w, low_h
+ dx = slice_size[0] / slice_size[2];
+ dy = slice_size[1] / slice_size[3];
+ x0 = 0.5 * dx - 0.5;
+ y0 = 0.5 * dy - 0.5;
}
- if (data == ori_figure.data[0].source) {
+ if (data == ori_figure.data[0].source && indicators.version == ori_figure.data[1].version) {
return window.dash_clientside.no_update;
}
// Otherwise, perform update
console.log("updating figure");
let figure = {...ori_figure};
figure.data[0].source = data;
+ figure.data[0].x0 = x0;
+ figure.data[0].y0 = y0;
+ figure.data[0].dx = dx;
+ figure.data[0].dy = dy;
+ figure.data[1] = indicators;
return figure;
}
""".replace(
- "{{ID}}", self._id
+ "{{ID}}", self.context_id
),
Output(self._subid("graph"), "figure"),
[
- Input(self._subid("slice-index"), "data"),
+ Input(self._subid("index"), "data"),
Input(self._subid("_slice-data"), "data"),
+ Input(self._subid("_indicators"), "data"),
],
[
State(self._subid("graph"), "figure"),
State(self._subid("_slice-data-lowres"), "data"),
+ State(self._subid("_slice-size"), "data"),
+ ],
+ )
+
+ # Select the *other* axii
+ axii = [0, 1, 2]
+ axii.pop(self._axis)
+
+ # Create a callback to create a trace representing all slice-indices that:
+ # * corresponding to the same volume data
+ # * match any of the selected axii
+ app.clientside_callback(
+ """
+ function handle_indicator(indices1, indices2, slice_size, current) {
+ let w = slice_size[0], h = slice_size[1];
+ let dx = w / 20, dy = h / 20;
+ let version = (current.version || 0) + 1;
+ let x = [], y = [];
+ for (let index of indices1) {
+ x.push(...[-dx, -1, null, w, w + dx, null]);
+ y.push(...[index, index, index, index, index, index]);
+ }
+ for (let index of indices2) {
+ x.push(...[index, index, index, index, index, index]);
+ y.push(...[-dy, -1, null, h, h + dy, null]);
+ }
+ return {
+ type: 'scatter',
+ mode: 'lines',
+ line: {color: '#ff00aa'},
+ x: x,
+ y: y,
+ hoverinfo: 'skip',
+ version: version
+ };
+ }
+ """,
+ Output(self._subid("_indicators"), "data"),
+ [
+ Input(
+ {
+ "scene": self.scene_id,
+ "context": ALL,
+ "name": "index",
+ "axis": axis,
+ },
+ "data",
+ )
+ for axis in axii
+ ],
+ [
+ State(self._subid("_slice-size"), "data"),
+ State(self._subid("_indicators"), "data"),
],
)
diff --git a/dash_3d_viewer/utils.py b/dash_3d_viewer/utils.py
index 68ab52c..583e8b2 100644
--- a/dash_3d_viewer/utils.py
+++ b/dash_3d_viewer/utils.py
@@ -2,6 +2,7 @@
import random
import base64
+import numpy as np
import PIL.Image
import skimage
@@ -23,3 +24,11 @@ def img_array_to_uri(img_array, new_size=None):
img_pil.save(f, format="PNG")
base64_str = base64.b64encode(f.getvalue()).decode()
return "data:image/png;base64," + base64_str
+
+
+def get_thumbnail_size_from_shape(shape, base_size):
+ base_size = int(base_size)
+ img_array = np.zeros(shape, np.uint8)
+ img_pil = PIL.Image.fromarray(img_array)
+ img_pil.thumbnail((base_size, base_size))
+ return img_pil.size
diff --git a/examples/slicer_with_1_plus_2_views.py b/examples/slicer_with_1_plus_2_views.py
new file mode 100644
index 0000000..312adf3
--- /dev/null
+++ b/examples/slicer_with_1_plus_2_views.py
@@ -0,0 +1,64 @@
+"""
+An example with two slicers at the same axis, and one on another axis.
+This demonstrates how multiple indicators can be shown per axis.
+
+Sharing the same scene_id is enough for the slicers to show each-others
+position. If the same volume object is given, it works by default,
+because the default scene_id is a hash of the volume object. Specifying
+a scene_id provides slice position indicators even when slicing through
+different volumes.
+"""
+
+import dash
+import dash_html_components as html
+from dash_3d_viewer import DashVolumeSlicer
+import imageio
+
+
+app = dash.Dash(__name__)
+
+vol = imageio.volread("imageio:stent.npz")
+slicer1 = DashVolumeSlicer(app, vol, axis=1, scene_id="myscene")
+slicer2 = DashVolumeSlicer(app, vol, axis=0, scene_id="myscene")
+slicer3 = DashVolumeSlicer(app, vol, axis=0, scene_id="myscene")
+
+app.layout = html.Div(
+ style={
+ "display": "grid",
+ "grid-template-columns": "40% 40%",
+ },
+ children=[
+ html.Div(
+ [
+ html.H1("Coronal"),
+ slicer1.graph,
+ html.Br(),
+ slicer1.slider,
+ *slicer1.stores,
+ ]
+ ),
+ html.Div(
+ [
+ html.H1("Transversal 1"),
+ slicer2.graph,
+ html.Br(),
+ slicer2.slider,
+ *slicer2.stores,
+ ]
+ ),
+ html.Div(),
+ html.Div(
+ [
+ html.H1("Transversal 2"),
+ slicer3.graph,
+ html.Br(),
+ slicer3.slider,
+ *slicer3.stores,
+ ]
+ ),
+ ],
+)
+
+
+if __name__ == "__main__":
+ app.run_server(debug=True)
diff --git a/examples/slicer_with_2_views.py b/examples/slicer_with_2_views.py
index 7913e2d..73b829c 100644
--- a/examples/slicer_with_2_views.py
+++ b/examples/slicer_with_2_views.py
@@ -11,8 +11,8 @@
app = dash.Dash(__name__)
vol = imageio.volread("imageio:stent.npz")
-slicer1 = DashVolumeSlicer(app, vol, axis=1, id="slicer1")
-slicer2 = DashVolumeSlicer(app, vol, axis=2, id="slicer2")
+slicer1 = DashVolumeSlicer(app, vol, axis=1)
+slicer2 = DashVolumeSlicer(app, vol, axis=2)
app.layout = html.Div(
style={
diff --git a/examples/slicer_with_3_views.py b/examples/slicer_with_3_views.py
index 93e3906..54b1a3f 100644
--- a/examples/slicer_with_3_views.py
+++ b/examples/slicer_with_3_views.py
@@ -15,9 +15,9 @@
# Read volumes and create slicer objects
vol = imageio.volread("imageio:stent.npz")
-slicer1 = DashVolumeSlicer(app, vol, axis=0, id="slicer1")
-slicer2 = DashVolumeSlicer(app, vol, axis=1, id="slicer2")
-slicer3 = DashVolumeSlicer(app, vol, axis=2, id="slicer3")
+slicer1 = DashVolumeSlicer(app, vol, axis=0)
+slicer2 = DashVolumeSlicer(app, vol, axis=1)
+slicer3 = DashVolumeSlicer(app, vol, axis=2)
# Calculate isosurface and create a figure with a mesh object
verts, faces, _, _ = marching_cubes(vol, 300, step_size=2)