Skip to content

Commit ed9948e

Browse files
committed
fix plotting with transposed nondim coords.
Fixes #3138
1 parent 652dd3c commit ed9948e

File tree

3 files changed

+30
-13
lines changed

3 files changed

+30
-13
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ Bug fixes
4444
- Sync with cftime by removing `dayofwk=-1` for cftime>=1.0.4.
4545
By `Anderson Banihirwe <https://github.com/andersy005>`_.
4646

47+
- Fix plotting with transposed 2D non-dimensional coordinates. (:issue:`3138`)
48+
By `Deepak Cherian <https://github.com/dcherian>`_.
4749

4850
Documentation
4951
~~~~~~~~~~~~~

xarray/plot/plot.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import numpy as np
1212
import pandas as pd
1313

14+
from ..core.alignment import broadcast
1415
from .facetgrid import _easy_facetgrid
1516
from .utils import (
1617
_add_colorbar,
@@ -666,17 +667,6 @@ def newplotfunc(
666667
darray=darray, x=x, y=y, imshow=imshow_rgb, rgb=rgb
667668
)
668669

669-
# better to pass the ndarrays directly to plotting functions
670-
xval = darray[xlab].values
671-
yval = darray[ylab].values
672-
673-
# check if we need to broadcast one dimension
674-
if xval.ndim < yval.ndim:
675-
xval = np.broadcast_to(xval, yval.shape)
676-
677-
if yval.ndim < xval.ndim:
678-
yval = np.broadcast_to(yval, xval.shape)
679-
680670
# May need to transpose for correct x, y labels
681671
# xlab may be the name of a coord, we have to check for dim names
682672
if imshow_rgb:
@@ -690,8 +680,17 @@ def newplotfunc(
690680
elif darray[xlab].dims[-1] == darray.dims[0]:
691681
darray = darray.transpose(transpose_coords=True)
692682

693-
# Pass the data as a masked ndarray too
694-
zval = darray.to_masked_array(copy=False)
683+
# better to pass the ndarrays directly to plotting functions
684+
# Pass the data as a masked ndarray
685+
if darray[xlab].ndim == 1 and darray[ylab].ndim == 1:
686+
xval = darray[xlab].values
687+
yval = darray[ylab].values
688+
zval = darray.to_masked_array(copy=False)
689+
else:
690+
xval, yval, zval = map(
691+
lambda x: x.values, broadcast(darray[xlab], darray[ylab], darray)
692+
)
693+
zval = np.ma.masked_array(zval, mask=pd.isnull(zval), copy=False)
695694

696695
# Replace pd.Intervals if contained in xval or yval.
697696
xplt, xlab_extra = _resolve_intervals_2dplot(xval, plotfunc.__name__)

xarray/tests/test_plot.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2145,3 +2145,19 @@ def test_yticks_kwarg(self, da):
21452145
da.plot(yticks=np.arange(5))
21462146
expected = np.arange(5)
21472147
assert np.all(plt.gca().get_yticks() == expected)
2148+
2149+
2150+
@requires_matplotlib
2151+
@pytest.mark.parametrize("plotfunc", ["pcolormesh", "contourf", "contour"])
2152+
def test_plot_transposed_nondim_coord(plotfunc):
2153+
x = np.linspace(0, 10, 101)
2154+
h = np.linspace(3, 7, 101)
2155+
s = np.linspace(0, 1, 51)
2156+
z = s[:, np.newaxis] * h[np.newaxis, :]
2157+
da = xr.DataArray(
2158+
np.sin(x) * np.cos(z),
2159+
dims=["s", "x"],
2160+
coords={"x": x, "s": s, "z": (("s", "x"), z), "zt": (("x", "s"), z.T)},
2161+
)
2162+
getattr(da.plot, plotfunc)(x="x", y="zt")
2163+
getattr(da.plot, plotfunc)(x="zt", y="x")

0 commit comments

Comments
 (0)