Skip to content

Commit 577d3a7

Browse files
authored
fix plotting with transposed nondim coords. (#3441)
* make plotting work with transposed nondim coords. * Additional test. * Test to make sure transpose is right * Undo the transpose change and add test to make sure transposition is right. * fix whats-new merge. * proper fix. * fix whats-new * Fix whats-new
1 parent 308bb37 commit 577d3a7

File tree

3 files changed

+49
-7
lines changed

3 files changed

+49
-7
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ New Features
3434

3535
Bug fixes
3636
~~~~~~~~~
37+
- Fix plotting with transposed 2D non-dimensional coordinates. (:issue:`3138`, :pull:`3441`)
38+
By `Deepak Cherian <https://github.com/dcherian>`_.
3739

3840

3941
Documentation

xarray/plot/plot.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -672,10 +672,22 @@ def newplotfunc(
672672

673673
# check if we need to broadcast one dimension
674674
if xval.ndim < yval.ndim:
675-
xval = np.broadcast_to(xval, yval.shape)
675+
dims = darray[ylab].dims
676+
if xval.shape[0] == yval.shape[0]:
677+
xval = np.broadcast_to(xval[:, np.newaxis], yval.shape)
678+
else:
679+
xval = np.broadcast_to(xval[np.newaxis, :], yval.shape)
676680

677-
if yval.ndim < xval.ndim:
678-
yval = np.broadcast_to(yval, xval.shape)
681+
elif yval.ndim < xval.ndim:
682+
dims = darray[xlab].dims
683+
if yval.shape[0] == xval.shape[0]:
684+
yval = np.broadcast_to(yval[:, np.newaxis], xval.shape)
685+
else:
686+
yval = np.broadcast_to(yval[np.newaxis, :], xval.shape)
687+
elif xval.ndim == 2:
688+
dims = darray[xlab].dims
689+
else:
690+
dims = (darray[ylab].dims[0], darray[xlab].dims[0])
679691

680692
# May need to transpose for correct x, y labels
681693
# xlab may be the name of a coord, we have to check for dim names
@@ -685,10 +697,9 @@ def newplotfunc(
685697
# we transpose to (y, x, color) to make this work.
686698
yx_dims = (ylab, xlab)
687699
dims = yx_dims + tuple(d for d in darray.dims if d not in yx_dims)
688-
if dims != darray.dims:
689-
darray = darray.transpose(*dims, transpose_coords=True)
690-
elif darray[xlab].dims[-1] == darray.dims[0]:
691-
darray = darray.transpose(transpose_coords=True)
700+
701+
if dims != darray.dims:
702+
darray = darray.transpose(*dims, transpose_coords=True)
692703

693704
# Pass the data as a masked ndarray too
694705
zval = darray.to_masked_array(copy=False)

xarray/tests/test_plot.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ def test2d_1d_2d_coordinates_contourf(self):
265265
)
266266

267267
a.plot.contourf(x="time", y="depth")
268+
a.plot.contourf(x="depth", y="time")
268269

269270
def test3d(self):
270271
self.darray.plot()
@@ -2149,3 +2150,31 @@ def test_yticks_kwarg(self, da):
21492150
da.plot(yticks=np.arange(5))
21502151
expected = np.arange(5)
21512152
assert np.all(plt.gca().get_yticks() == expected)
2153+
2154+
2155+
@requires_matplotlib
2156+
@pytest.mark.parametrize("plotfunc", ["pcolormesh", "contourf", "contour"])
2157+
def test_plot_transposed_nondim_coord(plotfunc):
2158+
x = np.linspace(0, 10, 101)
2159+
h = np.linspace(3, 7, 101)
2160+
s = np.linspace(0, 1, 51)
2161+
z = s[:, np.newaxis] * h[np.newaxis, :]
2162+
da = xr.DataArray(
2163+
np.sin(x) * np.cos(z),
2164+
dims=["s", "x"],
2165+
coords={"x": x, "s": s, "z": (("s", "x"), z), "zt": (("x", "s"), z.T)},
2166+
)
2167+
getattr(da.plot, plotfunc)(x="x", y="zt")
2168+
getattr(da.plot, plotfunc)(x="zt", y="x")
2169+
2170+
2171+
@requires_matplotlib
2172+
@pytest.mark.parametrize("plotfunc", ["pcolormesh", "imshow"])
2173+
def test_plot_transposes_properly(plotfunc):
2174+
# test that we aren't mistakenly transposing when the 2 dimensions have equal sizes.
2175+
da = xr.DataArray([np.sin(2 * np.pi / 10 * np.arange(10))] * 10, dims=("y", "x"))
2176+
hdl = getattr(da.plot, plotfunc)(x="x", y="y")
2177+
# get_array doesn't work for contour, contourf. It returns the colormap intervals.
2178+
# pcolormesh returns 1D array but imshow returns a 2D array so it is necessary
2179+
# to ravel() on the LHS
2180+
assert np.all(hdl.get_array().ravel() == da.to_masked_array().ravel())

0 commit comments

Comments
 (0)