Skip to content

Implement indicators #4

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 148 additions & 31 deletions dash_3d_viewer/slicer.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,47 @@
import numpy as np
from plotly.graph_objects import Figure, Image
from plotly.graph_objects import Figure, Image, Scatter
from dash import Dash
from dash.dependencies import Input, Output, State
from dash.dependencies import Input, Output, State, ALL
from dash_core_components import Graph, Slider, Store

from .utils import gen_random_id, img_array_to_uri
from .utils import img_array_to_uri, get_thumbnail_size_from_shape


class DashVolumeSlicer:
"""A slicer to show 3D image data in Dash."""
"""A slicer to show 3D image data in Dash.

def __init__(self, app, volume, axis=0, id=None):
Parameters:
app (dash.Dash): the Dash application instance.
volume (ndarray): the 3D numpy array to slice through.
axis (int): the dimension to slice in. Default 0.
scene_id (str): the scene that this slicer is part of. Slicers
that have the same scene-id show each-other's positions with
line indicators. By default this is a hash of ``id(volume)``.

This is a placeholder object, not a Dash component. The components
that make up the slicer can be accessed as attributes:

* ``graph``: the Graph object.
* ``slider``: the Slider object.
* ``stores``: a list of Store objects. Some are "public" values, others
used internally. Make sure to put them somewhere in the layout.

Each component is given a dict-id with the following keys:

* "context": a unique string id for this slicer instance.
* "scene": the scene_id.
* "axis": the int axis.
* "name": the name of the (sub) component.

TODO: iron out these details, list the stores that are public
"""

_global_slicer_counter = 0

def __init__(self, app, volume, axis=0, scene_id=None):
if not isinstance(app, Dash):
raise TypeError("Expect first arg to be a Dash app.")
self._app = app
# Check and store volume
if not (isinstance(volume, np.ndarray) and volume.ndim == 3):
raise TypeError("Expected volume to be a 3D numpy array")
Expand All @@ -22,29 +51,36 @@ def __init__(self, app, volume, axis=0, id=None):
raise ValueError("The given axis must be 0, 1, or 2.")
self._axis = int(axis)
# Check and store id
if id is None:
id = gen_random_id()
elif not isinstance(id, str):
raise TypeError("Id must be a string")
self._id = id
if scene_id is None:
scene_id = "volume_" + hex(id(volume))[2:]
elif not isinstance(scene_id, str):
raise TypeError("scene_id must be a string")
self.scene_id = scene_id
# Get unique id scoped to this slicer object
DashVolumeSlicer._global_slicer_counter += 1
self.context_id = "slicer_" + str(DashVolumeSlicer._global_slicer_counter)

# Get the slice size (width, height), and max index
# arr_shape = list(volume.shape)
# arr_shape.pop(self._axis)
# slice_size = list(reversed(arr_shape))
arr_shape = list(volume.shape)
arr_shape.pop(self._axis)
self._slice_size = tuple(reversed(arr_shape))
self._max_index = self._volume.shape[self._axis] - 1

# Prep low-res slices
thumbnail_size = get_thumbnail_size_from_shape(arr_shape, 32)
thumbnails = [
img_array_to_uri(self._slice(i), (32, 32))
img_array_to_uri(self._slice(i), thumbnail_size)
for i in range(self._max_index + 1)
]

# Create a placeholder trace
# todo: can add "%{z[0]}", but that would be the scaled value ...
trace = Image(source="", hovertemplate="(%{x}, %{y})<extra></extra>")
image_trace = Image(
source="", dx=1, dy=1, hovertemplate="(%{x}, %{y})<extra></extra>"
)
scatter_trace = Scatter(x=[], y=[]) # placeholder
# Create the figure object
fig = Figure(data=[trace])
self._fig = fig = Figure(data=[image_trace, scatter_trace])
fig.update_layout(
template=None,
margin=dict(l=0, r=0, b=0, t=0, pad=4),
Expand All @@ -70,6 +106,7 @@ def __init__(self, app, volume, axis=0, id=None):
config={"scrollZoom": True},
)
# Create a slider object that the user can put in the layout (or not)
# todo: use tooltip to show current value?
self.slider = Slider(
id=self._subid("slider"),
min=0,
Expand All @@ -81,18 +118,29 @@ def __init__(self, app, volume, axis=0, id=None):
)
# Create the stores that we need (these must be present in the layout)
self.stores = [
Store(id=self._subid("slice-index"), data=volume.shape[self._axis] // 2),
Store(
id=self._subid("_slice-size"), data=self._slice_size + thumbnail_size
),
Store(id=self._subid("index"), data=volume.shape[self._axis] // 2),
Store(id=self._subid("_requested-slice-index"), data=0),
Store(id=self._subid("_slice-data"), data=""),
Store(id=self._subid("_slice-data-lowres"), data=thumbnails),
Store(id=self._subid("_indicators"), data=[]),
]

self._create_server_callbacks(app)
self._create_client_callbacks(app)
self._create_server_callbacks()
self._create_client_callbacks()

def _subid(self, subid):
def _subid(self, name):
"""Given a subid, get the full id including the slicer's prefix."""
return self._id + "-" + subid
# return self.context_id + "-" + name
# todo: is there a penalty for using a dict-id vs a string-id?
return {
"context": self.context_id,
"scene": self.scene_id,
"axis": self._axis,
"name": name,
}

def _slice(self, index):
"""Sample a slice from the volume."""
Expand All @@ -101,8 +149,9 @@ def _slice(self, index):
im = self._volume[tuple(indices)]
return (im.astype(np.float32) * (255 / im.max())).astype(np.uint8)

def _create_server_callbacks(self, app):
def _create_server_callbacks(self):
"""Create the callbacks that run server-side."""
app = self._app

@app.callback(
Output(self._subid("_slice-data"), "data"),
Expand All @@ -112,16 +161,17 @@ def upload_requested_slice(slice_index):
slice = self._slice(slice_index)
return [slice_index, img_array_to_uri(slice)]

def _create_client_callbacks(self, app):
def _create_client_callbacks(self):
"""Create the callbacks that run client-side."""
app = self._app

app.clientside_callback(
"""
function handle_slider_move(index) {
return index;
}
""",
Output(self._subid("slice-index"), "data"),
Output(self._subid("index"), "data"),
[Input(self._subid("slider"), "value")],
)

Expand All @@ -138,24 +188,24 @@ def _create_client_callbacks(self, app):
}
}
""".replace(
"{{ID}}", self._id
"{{ID}}", self.context_id
),
Output(self._subid("_requested-slice-index"), "data"),
[Input(self._subid("slice-index"), "data")],
[Input(self._subid("index"), "data")],
)

# app.clientside_callback("""
# function update_slider_pos(index) {
# return index;
# }
# """,
# [Output("slice-index", "data")],
# [Output("index", "data")],
# [State("slider", "value")],
# )

app.clientside_callback(
"""
function handle_incoming_slice(index, index_and_data, ori_figure, lowres) {
function handle_incoming_slice(index, index_and_data, indicators, ori_figure, lowres, slice_size) {
let new_index = index_and_data[0];
let new_data = index_and_data[1];
// Store data in cache
Expand All @@ -164,30 +214,97 @@ def _create_client_callbacks(self, app):
slice_cache[new_index] = new_data;
// Get the data we need *now*
let data = slice_cache[index];
let x0 = 0, y0 = 0, dx = 1, dy = 1;
//slice_cache[new_index] = undefined; // todo: disabled cache for now!
// Maybe we do not need an update
if (!data) {
data = lowres[index];
// Scale the image to take the exact same space as the full-res
// version. It's not correct, but it looks better ...
// slice_size = full_w, full_h, low_w, low_h
dx = slice_size[0] / slice_size[2];
dy = slice_size[1] / slice_size[3];
x0 = 0.5 * dx - 0.5;
y0 = 0.5 * dy - 0.5;
}
if (data == ori_figure.data[0].source) {
if (data == ori_figure.data[0].source && indicators.version == ori_figure.data[1].version) {
return window.dash_clientside.no_update;
}
// Otherwise, perform update
console.log("updating figure");
let figure = {...ori_figure};
figure.data[0].source = data;
figure.data[0].x0 = x0;
figure.data[0].y0 = y0;
figure.data[0].dx = dx;
figure.data[0].dy = dy;
figure.data[1] = indicators;
return figure;
}
""".replace(
"{{ID}}", self._id
"{{ID}}", self.context_id
),
Output(self._subid("graph"), "figure"),
[
Input(self._subid("slice-index"), "data"),
Input(self._subid("index"), "data"),
Input(self._subid("_slice-data"), "data"),
Input(self._subid("_indicators"), "data"),
],
[
State(self._subid("graph"), "figure"),
State(self._subid("_slice-data-lowres"), "data"),
State(self._subid("_slice-size"), "data"),
],
)

# Select the *other* axii
axii = [0, 1, 2]
axii.pop(self._axis)

# Create a callback to create a trace representing all slice-indices that:
# * corresponding to the same volume data
# * match any of the selected axii
app.clientside_callback(
"""
function handle_indicator(indices1, indices2, slice_size, current) {
let w = slice_size[0], h = slice_size[1];
let dx = w / 20, dy = h / 20;
let version = (current.version || 0) + 1;
let x = [], y = [];
for (let index of indices1) {
x.push(...[-dx, -1, null, w, w + dx, null]);
y.push(...[index, index, index, index, index, index]);
}
for (let index of indices2) {
x.push(...[index, index, index, index, index, index]);
y.push(...[-dy, -1, null, h, h + dy, null]);
}
return {
type: 'scatter',
mode: 'lines',
line: {color: '#ff00aa'},
x: x,
y: y,
hoverinfo: 'skip',
version: version
};
}
""",
Output(self._subid("_indicators"), "data"),
[
Input(
{
"scene": self.scene_id,
"context": ALL,
"name": "index",
"axis": axis,
},
"data",
)
for axis in axii
],
[
State(self._subid("_slice-size"), "data"),
State(self._subid("_indicators"), "data"),
],
)
9 changes: 9 additions & 0 deletions dash_3d_viewer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import random
import base64

import numpy as np
import PIL.Image
import skimage

Expand All @@ -23,3 +24,11 @@ def img_array_to_uri(img_array, new_size=None):
img_pil.save(f, format="PNG")
base64_str = base64.b64encode(f.getvalue()).decode()
return "data:image/png;base64," + base64_str


def get_thumbnail_size_from_shape(shape, base_size):
base_size = int(base_size)
img_array = np.zeros(shape, np.uint8)
img_pil = PIL.Image.fromarray(img_array)
img_pil.thumbnail((base_size, base_size))
return img_pil.size
64 changes: 64 additions & 0 deletions examples/slicer_with_1_plus_2_views.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""
An example with two slicers at the same axis, and one on another axis.
This demonstrates how multiple indicators can be shown per axis.

Sharing the same scene_id is enough for the slicers to show each-others
position. If the same volume object is given, it works by default,
because the default scene_id is a hash of the volume object. Specifying
a scene_id provides slice position indicators even when slicing through
different volumes.
"""

import dash
import dash_html_components as html
from dash_3d_viewer import DashVolumeSlicer
import imageio


app = dash.Dash(__name__)

vol = imageio.volread("imageio:stent.npz")
slicer1 = DashVolumeSlicer(app, vol, axis=1, scene_id="myscene")
slicer2 = DashVolumeSlicer(app, vol, axis=0, scene_id="myscene")
slicer3 = DashVolumeSlicer(app, vol, axis=0, scene_id="myscene")

app.layout = html.Div(
style={
"display": "grid",
"grid-template-columns": "40% 40%",
},
children=[
html.Div(
[
html.H1("Coronal"),
slicer1.graph,
html.Br(),
slicer1.slider,
*slicer1.stores,
]
),
html.Div(
[
html.H1("Transversal 1"),
slicer2.graph,
html.Br(),
slicer2.slider,
*slicer2.stores,
]
),
html.Div(),
html.Div(
[
html.H1("Transversal 2"),
slicer3.graph,
html.Br(),
slicer3.slider,
*slicer3.stores,
]
),
],
)


if __name__ == "__main__":
app.run_server(debug=True)
Loading