-
-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Add facet_col and animation_frame argument to imshow #2746
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 13 commits
afb5c4d
8be8ca0
d236bc2
c8e852e
12cec34
ab427ae
7a3a9f4
b689a2f
fbb3f65
882810f
ba65990
cf644e5
72674b7
bd42385
fc2375b
a431fad
b652039
59c6622
c7285a3
91c066e
ac5aa1f
36b9f98
8cdc6af
cf1c2b9
5d1d8d8
c27f88a
502fdfd
135b01b
6ac3e36
a5a2252
77cb5cd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -6,7 +6,7 @@ jupyter: | |||||
extension: .md | ||||||
format_name: markdown | ||||||
format_version: '1.2' | ||||||
jupytext_version: 1.4.2 | ||||||
jupytext_version: 1.3.0 | ||||||
kernelspec: | ||||||
display_name: Python 3 | ||||||
language: python | ||||||
|
@@ -20,7 +20,7 @@ jupyter: | |||||
name: python | ||||||
nbconvert_exporter: python | ||||||
pygments_lexer: ipython3 | ||||||
version: 3.7.7 | ||||||
version: 3.7.3 | ||||||
plotly: | ||||||
description: How to display image data in Python with Plotly. | ||||||
display_as: scientific | ||||||
|
@@ -399,9 +399,73 @@ for compression_level in range(0, 9): | |||||
fig.show() | ||||||
``` | ||||||
|
||||||
### Exploring 3-D images and timeseries with `facet_col` | ||||||
|
||||||
*Introduced in plotly 4.11* | ||||||
|
||||||
For three-dimensional image datasets, obtained for example by MRI or CT in medical imaging, one can explore the dataset by representing its different planes as facets. The `facet_col` argument specifies along which axes the image is sliced through to make the facets. With `facet_col_wrap` , one can set the maximum number of columns. For image datasets passed as xarrays, it is also possible to give an axis name as a string for `facet_col`. | ||||||
|
||||||
It is recommended to use `binary_string=True` for facetted plots of images in order to keep a small figure size and a short rendering time. | ||||||
|
||||||
See the [tutorial on facet plots](/python/facet-plots/) for more information on creating and styling facet plots. | ||||||
|
||||||
```python | ||||||
import plotly.express as px | ||||||
from skimage import io | ||||||
from skimage.data import image_fetcher | ||||||
path = image_fetcher.fetch('data/cells.tif') | ||||||
data = io.imread(path) | ||||||
mkcor marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
img = data[25:40] | ||||||
fig = px.imshow(img, facet_col=0, binary_string=True, facet_col_wrap=5, height=700) | ||||||
fig.show() | ||||||
``` | ||||||
|
||||||
```python | ||||||
import plotly.express as px | ||||||
from skimage import io | ||||||
from skimage.data import image_fetcher | ||||||
path = image_fetcher.fetch('data/cells.tif') | ||||||
data = io.imread(path) | ||||||
img = data[25:40] | ||||||
fig = px.imshow(img, facet_col=0, binary_string=True, facet_col_wrap=5) | ||||||
# To have square facets one needs to unmatch axes | ||||||
fig.update_xaxes(matches=None) | ||||||
fig.update_yaxes(matches=None) | ||||||
fig.show() | ||||||
``` | ||||||
|
||||||
### Exploring 3-D images and timeseries with `animation_frame` | ||||||
|
||||||
*Introduced in plotly 4.11* | ||||||
|
||||||
For three-dimensional image datasets, obtained for example by MRI or CT in medical imaging, one can explore the dataset by sliding through its different planes in an animation. The `animation_frame` argument of `px.imshow` sets the axis along which the 3-D image is sliced in the animation. | ||||||
|
||||||
```python | ||||||
import plotly.express as px | ||||||
from skimage import io | ||||||
from skimage.data import image_fetcher | ||||||
path = image_fetcher.fetch('data/cells.tif') | ||||||
data = io.imread(path) | ||||||
img = data[25:40] | ||||||
fig = px.imshow(img, animation_frame=0, binary_string=True) | ||||||
fig.show() | ||||||
``` | ||||||
|
||||||
### Animations of xarray datasets | ||||||
|
||||||
*Introduced in plotly 4.11* | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
For xarray datasets, one can pass either an axis number or an axis name to `animation_frame`. Axis names and coordinates are automatically used for the labels, ticks and animation controls of the figure. | ||||||
|
||||||
```python | ||||||
import plotly.express as px | ||||||
import xarray as xr | ||||||
# Load xarray from dataset included in the xarray tutorial | ||||||
ds = xr.tutorial.open_dataset('air_temperature').air[:20] | ||||||
fig = px.imshow(ds, animation_frame='time', zmin=220, zmax=300, color_continuous_scale='RdBu_r') | ||||||
fig.show() | ||||||
``` | ||||||
|
||||||
#### Reference | ||||||
<<<<<<< HEAD | ||||||
See https://plotly.com/python/reference/#image for more information and chart attribute options! | ||||||
======= | ||||||
|
||||||
See https://plotly.com/python/reference/image/ for more information and chart attribute options! | ||||||
>>>>>>> doc-prod |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,3 +28,4 @@ pyarrow | |
cufflinks==0.17.3 | ||
kaleido | ||
umap-learn | ||
pooch |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -1,6 +1,6 @@ | ||||||
import plotly.graph_objs as go | ||||||
from _plotly_utils.basevalidators import ColorscaleValidator | ||||||
from ._core import apply_default_cascade | ||||||
from ._core import apply_default_cascade, init_figure, configure_animation_controls | ||||||
from io import BytesIO | ||||||
import base64 | ||||||
from .imshow_utils import rescale_intensity, _integer_ranges, _integer_types | ||||||
|
@@ -133,6 +133,9 @@ def imshow( | |||||
labels={}, | ||||||
x=None, | ||||||
y=None, | ||||||
animation_frame=None, | ||||||
facet_col=None, | ||||||
facet_col_wrap=None, | ||||||
nicolaskruchten marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
color_continuous_scale=None, | ||||||
color_continuous_midpoint=None, | ||||||
range_color=None, | ||||||
|
@@ -186,6 +189,14 @@ def imshow( | |||||
their lengths must match the lengths of the second and first dimensions of the | ||||||
img argument. They are auto-populated if the input is an xarray. | ||||||
|
||||||
facet_col: int, optional (default None) | ||||||
axis number along which the image array is slices to create a facetted plot. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
I'm not entirely positive about my suggestion to remove 'number' in 'axis number'... From https://numpy.org/doc/stable/glossary.html, it looks like conventional terminology would be just 'axis' (as in 'axis 0' and 'axis 1'); I was tempted by 'axis index' but this would be confusing with dataframes (as in, 'index' vs 'columns'). Maybe 'axis number' is conventional terminology after all? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. axis is fine indeed. It could also be "axis position" (as in the docstring of np.moveaxis). We can also ask other opinions, what do you think about the terminology @nicolaskruchten ? |
||||||
|
||||||
facet_col_wrap: int | ||||||
Maximum number of facet columns. Wraps the column variable at this width, | ||||||
so that the column facets span multiple rows. | ||||||
Ignored if `facet_col` is None. | ||||||
|
||||||
color_continuous_scale : str or list of str | ||||||
colormap used to map scalar data to colors (for a 2D image). This parameter is | ||||||
not used for RGB or RGBA images. If a string is provided, it should be the name | ||||||
|
@@ -277,15 +288,38 @@ def imshow( | |||||
args = locals() | ||||||
apply_default_cascade(args) | ||||||
labels = labels.copy() | ||||||
nslices = 1 | ||||||
if facet_col is not None: | ||||||
if isinstance(facet_col, str): | ||||||
facet_col = img.dims.index(facet_col) | ||||||
nslices = img.shape[facet_col] | ||||||
ncols = int(facet_col_wrap) if facet_col_wrap is not None else nslices | ||||||
nrows = nslices // ncols + 1 if nslices % ncols else nslices // ncols | ||||||
else: | ||||||
nrows = 1 | ||||||
ncols = 1 | ||||||
if animation_frame is not None: | ||||||
if isinstance(animation_frame, str): | ||||||
animation_frame = img.dims.index(animation_frame) | ||||||
nslices = img.shape[animation_frame] | ||||||
slice_through = (facet_col is not None) or (animation_frame is not None) | ||||||
slice_label = None | ||||||
slices = range(nslices) | ||||||
# ----- Define x and y, set labels if img is an xarray ------------------- | ||||||
if xarray_imported and isinstance(img, xarray.DataArray): | ||||||
if binary_string: | ||||||
raise ValueError( | ||||||
"It is not possible to use binary image strings for xarrays." | ||||||
"Please pass your data as a numpy array instead using" | ||||||
"`img.values`" | ||||||
) | ||||||
y_label, x_label = img.dims[0], img.dims[1] | ||||||
# if binary_string: | ||||||
# raise ValueError( | ||||||
# "It is not possible to use binary image strings for xarrays." | ||||||
# "Please pass your data as a numpy array instead using" | ||||||
# "`img.values`" | ||||||
# ) | ||||||
dims = list(img.dims) | ||||||
if slice_through: | ||||||
slice_index = facet_col if facet_col is not None else animation_frame | ||||||
slices = img.coords[img.dims[slice_index]].values | ||||||
_ = dims.pop(slice_index) | ||||||
slice_label = img.dims[slice_index] | ||||||
y_label, x_label = dims[0], dims[1] | ||||||
# np.datetime64 is not handled correctly by go.Heatmap | ||||||
for ax in [x_label, y_label]: | ||||||
if np.issubdtype(img.coords[ax].dtype, np.datetime64): | ||||||
|
@@ -300,6 +334,8 @@ def imshow( | |||||
labels["x"] = x_label | ||||||
if labels.get("y", None) is None: | ||||||
labels["y"] = y_label | ||||||
if labels.get("slice", None) is None: | ||||||
labels["slice"] = slice_label | ||||||
if labels.get("color", None) is None: | ||||||
labels["color"] = xarray.plot.utils.label_from_attrs(img) | ||||||
labels["color"] = labels["color"].replace("\n", "<br>") | ||||||
|
@@ -334,10 +370,22 @@ def imshow( | |||||
|
||||||
# --------------- Starting from here img is always a numpy array -------- | ||||||
img = np.asanyarray(img) | ||||||
if facet_col is not None: | ||||||
img = np.moveaxis(img, facet_col, 0) | ||||||
facet_col = True | ||||||
if animation_frame is not None: | ||||||
img = np.moveaxis(img, animation_frame, 0) | ||||||
animation_frame = True | ||||||
args["animation_frame"] = ( | ||||||
"slice" if labels.get("slice") is None else labels["slice"] | ||||||
) | ||||||
|
||||||
# Default behaviour of binary_string: True for RGB images, False for 2D | ||||||
if binary_string is None: | ||||||
binary_string = img.ndim >= 3 and not is_dataframe | ||||||
if slice_through: | ||||||
binary_string = img.ndim >= 4 and not is_dataframe | ||||||
else: | ||||||
binary_string = img.ndim >= 3 and not is_dataframe | ||||||
|
||||||
# Cast bools to uint8 (also one byte) | ||||||
if img.dtype == np.bool: | ||||||
|
@@ -349,7 +397,11 @@ def imshow( | |||||
|
||||||
# -------- Contrast rescaling: either minmax or infer ------------------ | ||||||
if contrast_rescaling is None: | ||||||
contrast_rescaling = "minmax" if img.ndim == 2 else "infer" | ||||||
contrast_rescaling = ( | ||||||
"minmax" | ||||||
if (img.ndim == 2 or (img.ndim == 3 and slice_through)) | ||||||
else "infer" | ||||||
) | ||||||
|
||||||
# We try to set zmin and zmax only if necessary, because traces have good defaults | ||||||
if contrast_rescaling == "minmax": | ||||||
|
@@ -366,18 +418,26 @@ def imshow( | |||||
zmin = 0 | ||||||
|
||||||
# For 2d data, use Heatmap trace, unless binary_string is True | ||||||
if img.ndim == 2 and not binary_string: | ||||||
if y is not None and img.shape[0] != len(y): | ||||||
if (img.ndim == 2 or (img.ndim == 3 and slice_through)) and not binary_string: | ||||||
y_index = 1 if slice_through else 0 | ||||||
if y is not None and img.shape[y_index] != len(y): | ||||||
raise ValueError( | ||||||
"The length of the y vector must match the length of the first " | ||||||
+ "dimension of the img matrix." | ||||||
) | ||||||
if x is not None and img.shape[1] != len(x): | ||||||
x_index = 2 if slice_through else 1 | ||||||
if x is not None and img.shape[x_index] != len(x): | ||||||
raise ValueError( | ||||||
"The length of the x vector must match the length of the second " | ||||||
+ "dimension of the img matrix." | ||||||
) | ||||||
trace = go.Heatmap(x=x, y=y, z=img, coloraxis="coloraxis1") | ||||||
if slice_through: | ||||||
traces = [ | ||||||
go.Heatmap(x=x, y=y, z=img_slice, coloraxis="coloraxis1", name=str(i)) | ||||||
for i, img_slice in enumerate(img) | ||||||
] | ||||||
else: | ||||||
traces = [go.Heatmap(x=x, y=y, z=img, coloraxis="coloraxis1")] | ||||||
autorange = True if origin == "lower" else "reversed" | ||||||
layout = dict(yaxis=dict(autorange=autorange)) | ||||||
if aspect == "equal": | ||||||
|
@@ -396,7 +456,9 @@ def imshow( | |||||
layout["coloraxis1"]["colorbar"] = dict(title_text=labels["color"]) | ||||||
|
||||||
# For 2D+RGB data, use Image trace | ||||||
elif img.ndim == 3 and img.shape[-1] in [3, 4] or (img.ndim == 2 and binary_string): | ||||||
elif ( | ||||||
img.ndim >= 3 and (img.shape[-1] in [3, 4] or slice_through and binary_string) | ||||||
) or (img.ndim == 2 and binary_string): | ||||||
rescale_image = True # to check whether image has been modified | ||||||
if zmin is not None and zmax is not None: | ||||||
zmin, zmax = ( | ||||||
|
@@ -407,40 +469,75 @@ def imshow( | |||||
if zmin is None and zmax is None: # no rescaling, faster | ||||||
img_rescaled = img | ||||||
rescale_image = False | ||||||
elif img.ndim == 2: | ||||||
elif img.ndim == 2 or (img.ndim == 3 and slice_through): | ||||||
img_rescaled = rescale_intensity( | ||||||
img, in_range=(zmin[0], zmax[0]), out_range=np.uint8 | ||||||
) | ||||||
else: | ||||||
img_rescaled = np.dstack( | ||||||
img_rescaled = np.stack( | ||||||
[ | ||||||
rescale_intensity( | ||||||
img[..., ch], | ||||||
in_range=(zmin[ch], zmax[ch]), | ||||||
out_range=np.uint8, | ||||||
) | ||||||
for ch in range(img.shape[-1]) | ||||||
] | ||||||
], | ||||||
axis=-1, | ||||||
) | ||||||
img_str = _array_to_b64str( | ||||||
img_rescaled, | ||||||
backend=binary_backend, | ||||||
compression=binary_compression_level, | ||||||
ext=binary_format, | ||||||
) | ||||||
trace = go.Image(source=img_str) | ||||||
if slice_through: | ||||||
img_str = [ | ||||||
_array_to_b64str( | ||||||
img_rescaled_slice, | ||||||
backend=binary_backend, | ||||||
compression=binary_compression_level, | ||||||
ext=binary_format, | ||||||
) | ||||||
for img_rescaled_slice in img_rescaled | ||||||
] | ||||||
|
||||||
else: | ||||||
img_str = [ | ||||||
_array_to_b64str( | ||||||
img_rescaled, | ||||||
backend=binary_backend, | ||||||
compression=binary_compression_level, | ||||||
ext=binary_format, | ||||||
) | ||||||
] | ||||||
traces = [ | ||||||
go.Image(source=img_str_slice, name=str(i)) | ||||||
for i, img_str_slice in enumerate(img_str) | ||||||
] | ||||||
else: | ||||||
colormodel = "rgb" if img.shape[-1] == 3 else "rgba256" | ||||||
trace = go.Image(z=img, zmin=zmin, zmax=zmax, colormodel=colormodel) | ||||||
if slice_through: | ||||||
traces = [ | ||||||
go.Image(z=img_slice, zmin=zmin, zmax=zmax, colormodel=colormodel) | ||||||
for img_slice in img | ||||||
] | ||||||
else: | ||||||
traces = [go.Image(z=img, zmin=zmin, zmax=zmax, colormodel=colormodel)] | ||||||
layout = {} | ||||||
if origin == "lower": | ||||||
layout["yaxis"] = dict(autorange=True) | ||||||
else: | ||||||
raise ValueError( | ||||||
"px.imshow only accepts 2D single-channel, RGB or RGBA images. " | ||||||
"An image of shape %s was provided" % str(img.shape) | ||||||
"An image of shape %s was provided" | ||||||
emmanuelle marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
"Alternatively, 3-D single or multichannel datasets can be" | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 3- or 4-D ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||||||
"visualized using the `facet_col` or `animation_frame` arguments." | ||||||
% str(img.shape) | ||||||
) | ||||||
|
||||||
# Now build figure | ||||||
col_labels = [] | ||||||
if facet_col is not None: | ||||||
slice_label = "slice" if labels.get("slice") is None else labels["slice"] | ||||||
if slices is None: | ||||||
slices = range(nslices) | ||||||
col_labels = ["%s = %d" % (slice_label, i) for i in slices] | ||||||
fig = init_figure(args, "xy", [], nrows, ncols, col_labels, []) | ||||||
layout_patch = dict() | ||||||
for attr_name in ["height", "width"]: | ||||||
if args[attr_name]: | ||||||
|
@@ -449,7 +546,16 @@ def imshow( | |||||
layout_patch["title_text"] = args["title"] | ||||||
elif args["template"].layout.margin.t is None: | ||||||
layout_patch["margin"] = {"t": 60} | ||||||
fig = go.Figure(data=trace, layout=layout) | ||||||
|
||||||
frame_list = [] | ||||||
for index, (slice_index, trace) in enumerate(zip(slices, traces)): | ||||||
if facet_col or index == 0: | ||||||
fig.add_trace(trace, row=nrows - index // ncols, col=index % ncols + 1) | ||||||
if animation_frame: | ||||||
frame_list.append(dict(data=trace, layout=layout, name=str(slice_index))) | ||||||
if animation_frame: | ||||||
fig.frames = frame_list | ||||||
fig.update_layout(layout) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it's a bit odd to have |
||||||
fig.update_layout(layout_patch) | ||||||
# Hover name, z or color | ||||||
if binary_string and rescale_image and not np.all(img == img_rescaled): | ||||||
|
@@ -479,5 +585,6 @@ def imshow( | |||||
fig.update_xaxes(title_text=labels["x"]) | ||||||
if labels["y"]: | ||||||
fig.update_yaxes(title_text=labels["y"]) | ||||||
fig.update_layout(template=args["template"], overwrite=True) | ||||||
configure_animation_controls(args, go.Image, fig) | ||||||
# fig.update_layout(template=args["template"], overwrite=True) | ||||||
return fig |
Uh oh!
There was an error while loading. Please reload this page.