diff --git a/dash_3d_viewer/slicer.py b/dash_3d_viewer/slicer.py index 13417c6..0a0f456 100644 --- a/dash_3d_viewer/slicer.py +++ b/dash_3d_viewer/slicer.py @@ -1,18 +1,47 @@ import numpy as np -from plotly.graph_objects import Figure, Image +from plotly.graph_objects import Figure, Image, Scatter from dash import Dash -from dash.dependencies import Input, Output, State +from dash.dependencies import Input, Output, State, ALL from dash_core_components import Graph, Slider, Store -from .utils import gen_random_id, img_array_to_uri +from .utils import img_array_to_uri, get_thumbnail_size_from_shape class DashVolumeSlicer: - """A slicer to show 3D image data in Dash.""" + """A slicer to show 3D image data in Dash. - def __init__(self, app, volume, axis=0, id=None): + Parameters: + app (dash.Dash): the Dash application instance. + volume (ndarray): the 3D numpy array to slice through. + axis (int): the dimension to slice in. Default 0. + scene_id (str): the scene that this slicer is part of. Slicers + that have the same scene-id show each-other's positions with + line indicators. By default this is a hash of ``id(volume)``. + + This is a placeholder object, not a Dash component. The components + that make up the slicer can be accessed as attributes: + + * ``graph``: the Graph object. + * ``slider``: the Slider object. + * ``stores``: a list of Store objects. Some are "public" values, others + used internally. Make sure to put them somewhere in the layout. + + Each component is given a dict-id with the following keys: + + * "context": a unique string id for this slicer instance. + * "scene": the scene_id. + * "axis": the int axis. + * "name": the name of the (sub) component. + + TODO: iron out these details, list the stores that are public + """ + + _global_slicer_counter = 0 + + def __init__(self, app, volume, axis=0, scene_id=None): if not isinstance(app, Dash): raise TypeError("Expect first arg to be a Dash app.") + self._app = app # Check and store volume if not (isinstance(volume, np.ndarray) and volume.ndim == 3): raise TypeError("Expected volume to be a 3D numpy array") @@ -22,29 +51,36 @@ def __init__(self, app, volume, axis=0, id=None): raise ValueError("The given axis must be 0, 1, or 2.") self._axis = int(axis) # Check and store id - if id is None: - id = gen_random_id() - elif not isinstance(id, str): - raise TypeError("Id must be a string") - self._id = id + if scene_id is None: + scene_id = "volume_" + hex(id(volume))[2:] + elif not isinstance(scene_id, str): + raise TypeError("scene_id must be a string") + self.scene_id = scene_id + # Get unique id scoped to this slicer object + DashVolumeSlicer._global_slicer_counter += 1 + self.context_id = "slicer_" + str(DashVolumeSlicer._global_slicer_counter) # Get the slice size (width, height), and max index - # arr_shape = list(volume.shape) - # arr_shape.pop(self._axis) - # slice_size = list(reversed(arr_shape)) + arr_shape = list(volume.shape) + arr_shape.pop(self._axis) + self._slice_size = tuple(reversed(arr_shape)) self._max_index = self._volume.shape[self._axis] - 1 # Prep low-res slices + thumbnail_size = get_thumbnail_size_from_shape(arr_shape, 32) thumbnails = [ - img_array_to_uri(self._slice(i), (32, 32)) + img_array_to_uri(self._slice(i), thumbnail_size) for i in range(self._max_index + 1) ] # Create a placeholder trace # todo: can add "%{z[0]}", but that would be the scaled value ... - trace = Image(source="", hovertemplate="(%{x}, %{y})") + image_trace = Image( + source="", dx=1, dy=1, hovertemplate="(%{x}, %{y})" + ) + scatter_trace = Scatter(x=[], y=[]) # placeholder # Create the figure object - fig = Figure(data=[trace]) + self._fig = fig = Figure(data=[image_trace, scatter_trace]) fig.update_layout( template=None, margin=dict(l=0, r=0, b=0, t=0, pad=4), @@ -70,6 +106,7 @@ def __init__(self, app, volume, axis=0, id=None): config={"scrollZoom": True}, ) # Create a slider object that the user can put in the layout (or not) + # todo: use tooltip to show current value? self.slider = Slider( id=self._subid("slider"), min=0, @@ -81,18 +118,29 @@ def __init__(self, app, volume, axis=0, id=None): ) # Create the stores that we need (these must be present in the layout) self.stores = [ - Store(id=self._subid("slice-index"), data=volume.shape[self._axis] // 2), + Store( + id=self._subid("_slice-size"), data=self._slice_size + thumbnail_size + ), + Store(id=self._subid("index"), data=volume.shape[self._axis] // 2), Store(id=self._subid("_requested-slice-index"), data=0), Store(id=self._subid("_slice-data"), data=""), Store(id=self._subid("_slice-data-lowres"), data=thumbnails), + Store(id=self._subid("_indicators"), data=[]), ] - self._create_server_callbacks(app) - self._create_client_callbacks(app) + self._create_server_callbacks() + self._create_client_callbacks() - def _subid(self, subid): + def _subid(self, name): """Given a subid, get the full id including the slicer's prefix.""" - return self._id + "-" + subid + # return self.context_id + "-" + name + # todo: is there a penalty for using a dict-id vs a string-id? + return { + "context": self.context_id, + "scene": self.scene_id, + "axis": self._axis, + "name": name, + } def _slice(self, index): """Sample a slice from the volume.""" @@ -101,8 +149,9 @@ def _slice(self, index): im = self._volume[tuple(indices)] return (im.astype(np.float32) * (255 / im.max())).astype(np.uint8) - def _create_server_callbacks(self, app): + def _create_server_callbacks(self): """Create the callbacks that run server-side.""" + app = self._app @app.callback( Output(self._subid("_slice-data"), "data"), @@ -112,8 +161,9 @@ def upload_requested_slice(slice_index): slice = self._slice(slice_index) return [slice_index, img_array_to_uri(slice)] - def _create_client_callbacks(self, app): + def _create_client_callbacks(self): """Create the callbacks that run client-side.""" + app = self._app app.clientside_callback( """ @@ -121,7 +171,7 @@ def _create_client_callbacks(self, app): return index; } """, - Output(self._subid("slice-index"), "data"), + Output(self._subid("index"), "data"), [Input(self._subid("slider"), "value")], ) @@ -138,10 +188,10 @@ def _create_client_callbacks(self, app): } } """.replace( - "{{ID}}", self._id + "{{ID}}", self.context_id ), Output(self._subid("_requested-slice-index"), "data"), - [Input(self._subid("slice-index"), "data")], + [Input(self._subid("index"), "data")], ) # app.clientside_callback(""" @@ -149,13 +199,13 @@ def _create_client_callbacks(self, app): # return index; # } # """, - # [Output("slice-index", "data")], + # [Output("index", "data")], # [State("slider", "value")], # ) app.clientside_callback( """ - function handle_incoming_slice(index, index_and_data, ori_figure, lowres) { + function handle_incoming_slice(index, index_and_data, indicators, ori_figure, lowres, slice_size) { let new_index = index_and_data[0]; let new_data = index_and_data[1]; // Store data in cache @@ -164,30 +214,97 @@ def _create_client_callbacks(self, app): slice_cache[new_index] = new_data; // Get the data we need *now* let data = slice_cache[index]; + let x0 = 0, y0 = 0, dx = 1, dy = 1; //slice_cache[new_index] = undefined; // todo: disabled cache for now! // Maybe we do not need an update if (!data) { data = lowres[index]; + // Scale the image to take the exact same space as the full-res + // version. It's not correct, but it looks better ... + // slice_size = full_w, full_h, low_w, low_h + dx = slice_size[0] / slice_size[2]; + dy = slice_size[1] / slice_size[3]; + x0 = 0.5 * dx - 0.5; + y0 = 0.5 * dy - 0.5; } - if (data == ori_figure.data[0].source) { + if (data == ori_figure.data[0].source && indicators.version == ori_figure.data[1].version) { return window.dash_clientside.no_update; } // Otherwise, perform update console.log("updating figure"); let figure = {...ori_figure}; figure.data[0].source = data; + figure.data[0].x0 = x0; + figure.data[0].y0 = y0; + figure.data[0].dx = dx; + figure.data[0].dy = dy; + figure.data[1] = indicators; return figure; } """.replace( - "{{ID}}", self._id + "{{ID}}", self.context_id ), Output(self._subid("graph"), "figure"), [ - Input(self._subid("slice-index"), "data"), + Input(self._subid("index"), "data"), Input(self._subid("_slice-data"), "data"), + Input(self._subid("_indicators"), "data"), ], [ State(self._subid("graph"), "figure"), State(self._subid("_slice-data-lowres"), "data"), + State(self._subid("_slice-size"), "data"), + ], + ) + + # Select the *other* axii + axii = [0, 1, 2] + axii.pop(self._axis) + + # Create a callback to create a trace representing all slice-indices that: + # * corresponding to the same volume data + # * match any of the selected axii + app.clientside_callback( + """ + function handle_indicator(indices1, indices2, slice_size, current) { + let w = slice_size[0], h = slice_size[1]; + let dx = w / 20, dy = h / 20; + let version = (current.version || 0) + 1; + let x = [], y = []; + for (let index of indices1) { + x.push(...[-dx, -1, null, w, w + dx, null]); + y.push(...[index, index, index, index, index, index]); + } + for (let index of indices2) { + x.push(...[index, index, index, index, index, index]); + y.push(...[-dy, -1, null, h, h + dy, null]); + } + return { + type: 'scatter', + mode: 'lines', + line: {color: '#ff00aa'}, + x: x, + y: y, + hoverinfo: 'skip', + version: version + }; + } + """, + Output(self._subid("_indicators"), "data"), + [ + Input( + { + "scene": self.scene_id, + "context": ALL, + "name": "index", + "axis": axis, + }, + "data", + ) + for axis in axii + ], + [ + State(self._subid("_slice-size"), "data"), + State(self._subid("_indicators"), "data"), ], ) diff --git a/dash_3d_viewer/utils.py b/dash_3d_viewer/utils.py index 68ab52c..583e8b2 100644 --- a/dash_3d_viewer/utils.py +++ b/dash_3d_viewer/utils.py @@ -2,6 +2,7 @@ import random import base64 +import numpy as np import PIL.Image import skimage @@ -23,3 +24,11 @@ def img_array_to_uri(img_array, new_size=None): img_pil.save(f, format="PNG") base64_str = base64.b64encode(f.getvalue()).decode() return "data:image/png;base64," + base64_str + + +def get_thumbnail_size_from_shape(shape, base_size): + base_size = int(base_size) + img_array = np.zeros(shape, np.uint8) + img_pil = PIL.Image.fromarray(img_array) + img_pil.thumbnail((base_size, base_size)) + return img_pil.size diff --git a/examples/slicer_with_1_plus_2_views.py b/examples/slicer_with_1_plus_2_views.py new file mode 100644 index 0000000..312adf3 --- /dev/null +++ b/examples/slicer_with_1_plus_2_views.py @@ -0,0 +1,64 @@ +""" +An example with two slicers at the same axis, and one on another axis. +This demonstrates how multiple indicators can be shown per axis. + +Sharing the same scene_id is enough for the slicers to show each-others +position. If the same volume object is given, it works by default, +because the default scene_id is a hash of the volume object. Specifying +a scene_id provides slice position indicators even when slicing through +different volumes. +""" + +import dash +import dash_html_components as html +from dash_3d_viewer import DashVolumeSlicer +import imageio + + +app = dash.Dash(__name__) + +vol = imageio.volread("imageio:stent.npz") +slicer1 = DashVolumeSlicer(app, vol, axis=1, scene_id="myscene") +slicer2 = DashVolumeSlicer(app, vol, axis=0, scene_id="myscene") +slicer3 = DashVolumeSlicer(app, vol, axis=0, scene_id="myscene") + +app.layout = html.Div( + style={ + "display": "grid", + "grid-template-columns": "40% 40%", + }, + children=[ + html.Div( + [ + html.H1("Coronal"), + slicer1.graph, + html.Br(), + slicer1.slider, + *slicer1.stores, + ] + ), + html.Div( + [ + html.H1("Transversal 1"), + slicer2.graph, + html.Br(), + slicer2.slider, + *slicer2.stores, + ] + ), + html.Div(), + html.Div( + [ + html.H1("Transversal 2"), + slicer3.graph, + html.Br(), + slicer3.slider, + *slicer3.stores, + ] + ), + ], +) + + +if __name__ == "__main__": + app.run_server(debug=True) diff --git a/examples/slicer_with_2_views.py b/examples/slicer_with_2_views.py index 7913e2d..73b829c 100644 --- a/examples/slicer_with_2_views.py +++ b/examples/slicer_with_2_views.py @@ -11,8 +11,8 @@ app = dash.Dash(__name__) vol = imageio.volread("imageio:stent.npz") -slicer1 = DashVolumeSlicer(app, vol, axis=1, id="slicer1") -slicer2 = DashVolumeSlicer(app, vol, axis=2, id="slicer2") +slicer1 = DashVolumeSlicer(app, vol, axis=1) +slicer2 = DashVolumeSlicer(app, vol, axis=2) app.layout = html.Div( style={ diff --git a/examples/slicer_with_3_views.py b/examples/slicer_with_3_views.py index 93e3906..54b1a3f 100644 --- a/examples/slicer_with_3_views.py +++ b/examples/slicer_with_3_views.py @@ -15,9 +15,9 @@ # Read volumes and create slicer objects vol = imageio.volread("imageio:stent.npz") -slicer1 = DashVolumeSlicer(app, vol, axis=0, id="slicer1") -slicer2 = DashVolumeSlicer(app, vol, axis=1, id="slicer2") -slicer3 = DashVolumeSlicer(app, vol, axis=2, id="slicer3") +slicer1 = DashVolumeSlicer(app, vol, axis=0) +slicer2 = DashVolumeSlicer(app, vol, axis=1) +slicer3 = DashVolumeSlicer(app, vol, axis=2) # Calculate isosurface and create a figure with a mesh object verts, faces, _, _ = marching_cubes(vol, 300, step_size=2)