diff --git a/plotly/_subplots.py b/plotly/_subplots.py index a1bb4219c9..223324096b 100644 --- a/plotly/_subplots.py +++ b/plotly/_subplots.py @@ -754,8 +754,12 @@ def _check_hv_spacing(dimsize, spacing, name, dimvarname, dimname): ) grid_ref[r][c] = subplot_refs - _configure_shared_axes(layout, grid_ref, specs, "x", shared_xaxes, row_dir) - _configure_shared_axes(layout, grid_ref, specs, "y", shared_yaxes, row_dir) + _configure_shared_axes(layout, grid_ref, specs, "x", shared_xaxes, row_dir, False) + _configure_shared_axes(layout, grid_ref, specs, "y", shared_yaxes, row_dir, False) + if secondary_y: + _configure_shared_axes( + layout, grid_ref, specs, "y", shared_yaxes, row_dir, True + ) # Build inset reference # --------------------- @@ -887,7 +891,9 @@ def _check_hv_spacing(dimsize, spacing, name, dimvarname, dimname): return figure -def _configure_shared_axes(layout, grid_ref, specs, x_or_y, shared, row_dir): +def _configure_shared_axes( + layout, grid_ref, specs, x_or_y, shared, row_dir, secondary_y +): rows = len(grid_ref) cols = len(grid_ref[0]) @@ -898,6 +904,13 @@ def _configure_shared_axes(layout, grid_ref, specs, x_or_y, shared, row_dir): else: rows_iter = range(rows) + if secondary_y: + cols_iter = range(cols - 1, -1, -1) + axis_index = 1 + else: + cols_iter = range(cols) + axis_index = 0 + def update_axis_matches(first_axis_id, subplot_ref, spec, remove_label): if subplot_ref is None: return first_axis_id @@ -921,13 +934,15 @@ def update_axis_matches(first_axis_id, subplot_ref, spec, remove_label): return first_axis_id if shared == "columns" or (x_or_y == "x" and shared is True): - for c in range(cols): + for c in cols_iter: first_axis_id = None ok_to_remove_label = x_or_y == "x" for r in rows_iter: if not grid_ref[r][c]: continue - subplot_ref = grid_ref[r][c][0] + if axis_index >= len(grid_ref[r][c]): + continue + subplot_ref = grid_ref[r][c][axis_index] spec = specs[r][c] first_axis_id = update_axis_matches( first_axis_id, subplot_ref, spec, ok_to_remove_label @@ -937,10 +952,12 @@ def update_axis_matches(first_axis_id, subplot_ref, spec, remove_label): for r in rows_iter: first_axis_id = None ok_to_remove_label = x_or_y == "y" - for c in range(cols): + for c in cols_iter: if not grid_ref[r][c]: continue - subplot_ref = grid_ref[r][c][0] + if axis_index >= len(grid_ref[r][c]): + continue + subplot_ref = grid_ref[r][c][axis_index] spec = specs[r][c] first_axis_id = update_axis_matches( first_axis_id, subplot_ref, spec, ok_to_remove_label @@ -948,15 +965,17 @@ def update_axis_matches(first_axis_id, subplot_ref, spec, remove_label): elif shared == "all": first_axis_id = None - for c in range(cols): - for ri, r in enumerate(rows_iter): + for ri, r in enumerate(rows_iter): + for c in cols_iter: if not grid_ref[r][c]: continue - subplot_ref = grid_ref[r][c][0] + if axis_index >= len(grid_ref[r][c]): + continue + subplot_ref = grid_ref[r][c][axis_index] spec = specs[r][c] if x_or_y == "y": - ok_to_remove_label = c > 0 + ok_to_remove_label = c < cols - 1 if secondary_y else c > 0 else: ok_to_remove_label = ri > 0 if row_dir > 0 else r < rows - 1 diff --git a/tests/test_core/test_subplots/test_make_subplots.py b/tests/test_core/test_subplots/test_make_subplots.py index 868ee0e316..fecbec2f9a 100644 --- a/tests/test_core/test_subplots/test_make_subplots.py +++ b/tests/test_core/test_subplots/test_make_subplots.py @@ -1828,115 +1828,127 @@ def test_secondary_y_traces(self): self.assertEqual(fig.to_plotly_json(), expected.to_plotly_json()) def test_secondary_y_subplots(self): - fig = subplots.make_subplots( - rows=2, - cols=2, - specs=[ - [{"secondary_y": True}, {"secondary_y": True}], - [{"secondary_y": True}, {"secondary_y": True}], - ], - ) + for shared_y_axes in [False, True]: + fig = subplots.make_subplots( + rows=2, + cols=2, + shared_yaxes=shared_y_axes, + specs=[ + [{"secondary_y": True}, {"secondary_y": True}], + [{"secondary_y": True}, {"secondary_y": True}], + ], + ) - fig.add_scatter(y=[1, 3, 2], name="First", row=1, col=1) - fig.add_scatter(y=[2, 1, 3], name="Second", row=1, col=1, secondary_y=True) + fig.add_scatter(y=[1, 3, 2], name="First", row=1, col=1) + fig.add_scatter(y=[2, 1, 3], name="Second", row=1, col=1, secondary_y=True) - fig.add_scatter(y=[4, 3, 2], name="Third", row=1, col=2) - fig.add_scatter(y=[8, 1, 3], name="Forth", row=1, col=2, secondary_y=True) + fig.add_scatter(y=[4, 3, 2], name="Third", row=1, col=2) + fig.add_scatter(y=[8, 1, 3], name="Forth", row=1, col=2, secondary_y=True) - fig.add_scatter(y=[0, 2, 4], name="Fifth", row=2, col=1) - fig.add_scatter(y=[2, 1, 3], name="Sixth", row=2, col=1, secondary_y=True) + fig.add_scatter(y=[0, 2, 4], name="Fifth", row=2, col=1) + fig.add_scatter(y=[2, 1, 3], name="Sixth", row=2, col=1, secondary_y=True) - fig.add_scatter(y=[2, 4, 0], name="Fifth", row=2, col=2) - fig.add_scatter(y=[2, 3, 6], name="Sixth", row=2, col=2, secondary_y=True) + fig.add_scatter(y=[2, 4, 0], name="Fifth", row=2, col=2) + fig.add_scatter(y=[2, 3, 6], name="Sixth", row=2, col=2, secondary_y=True) - fig.update_traces(uid=None) + fig.update_traces(uid=None) - expected = Figure( - { - "data": [ - { - "name": "First", - "type": "scatter", - "xaxis": "x", - "y": [1, 3, 2], - "yaxis": "y", - }, - { - "name": "Second", - "type": "scatter", - "xaxis": "x", - "y": [2, 1, 3], - "yaxis": "y2", - }, - { - "name": "Third", - "type": "scatter", - "xaxis": "x2", - "y": [4, 3, 2], - "yaxis": "y3", - }, - { - "name": "Forth", - "type": "scatter", - "xaxis": "x2", - "y": [8, 1, 3], - "yaxis": "y4", - }, - { - "name": "Fifth", - "type": "scatter", - "xaxis": "x3", - "y": [0, 2, 4], - "yaxis": "y5", - }, - { - "name": "Sixth", - "type": "scatter", - "xaxis": "x3", - "y": [2, 1, 3], - "yaxis": "y6", - }, - { - "name": "Fifth", - "type": "scatter", - "xaxis": "x4", - "y": [2, 4, 0], - "yaxis": "y7", - }, - { - "name": "Sixth", - "type": "scatter", - "xaxis": "x4", - "y": [2, 3, 6], - "yaxis": "y8", - }, - ], - "layout": { - "xaxis": {"anchor": "y", "domain": [0.0, 0.37]}, - "xaxis2": { - "anchor": "y3", - "domain": [0.5700000000000001, 0.9400000000000001], - }, - "xaxis3": {"anchor": "y5", "domain": [0.0, 0.37]}, - "xaxis4": { - "anchor": "y7", - "domain": [0.5700000000000001, 0.9400000000000001], + expected = Figure( + { + "data": [ + { + "name": "First", + "type": "scatter", + "xaxis": "x", + "y": [1, 3, 2], + "yaxis": "y", + }, + { + "name": "Second", + "type": "scatter", + "xaxis": "x", + "y": [2, 1, 3], + "yaxis": "y2", + }, + { + "name": "Third", + "type": "scatter", + "xaxis": "x2", + "y": [4, 3, 2], + "yaxis": "y3", + }, + { + "name": "Forth", + "type": "scatter", + "xaxis": "x2", + "y": [8, 1, 3], + "yaxis": "y4", + }, + { + "name": "Fifth", + "type": "scatter", + "xaxis": "x3", + "y": [0, 2, 4], + "yaxis": "y5", + }, + { + "name": "Sixth", + "type": "scatter", + "xaxis": "x3", + "y": [2, 1, 3], + "yaxis": "y6", + }, + { + "name": "Fifth", + "type": "scatter", + "xaxis": "x4", + "y": [2, 4, 0], + "yaxis": "y7", + }, + { + "name": "Sixth", + "type": "scatter", + "xaxis": "x4", + "y": [2, 3, 6], + "yaxis": "y8", + }, + ], + "layout": { + "xaxis": {"anchor": "y", "domain": [0.0, 0.37]}, + "xaxis2": { + "anchor": "y3", + "domain": [0.5700000000000001, 0.9400000000000001], + }, + "xaxis3": {"anchor": "y5", "domain": [0.0, 0.37]}, + "xaxis4": { + "anchor": "y7", + "domain": [0.5700000000000001, 0.9400000000000001], + }, + "yaxis": {"anchor": "x", "domain": [0.575, 1.0]}, + "yaxis2": {"anchor": "x", "overlaying": "y", "side": "right"}, + "yaxis3": {"anchor": "x2", "domain": [0.575, 1.0]}, + "yaxis4": {"anchor": "x2", "overlaying": "y3", "side": "right"}, + "yaxis5": {"anchor": "x3", "domain": [0.0, 0.425]}, + "yaxis6": {"anchor": "x3", "overlaying": "y5", "side": "right"}, + "yaxis7": {"anchor": "x4", "domain": [0.0, 0.425]}, + "yaxis8": {"anchor": "x4", "overlaying": "y7", "side": "right"}, }, - "yaxis": {"anchor": "x", "domain": [0.575, 1.0]}, - "yaxis2": {"anchor": "x", "overlaying": "y", "side": "right"}, - "yaxis3": {"anchor": "x2", "domain": [0.575, 1.0]}, - "yaxis4": {"anchor": "x2", "overlaying": "y3", "side": "right"}, - "yaxis5": {"anchor": "x3", "domain": [0.0, 0.425]}, - "yaxis6": {"anchor": "x3", "overlaying": "y5", "side": "right"}, - "yaxis7": {"anchor": "x4", "domain": [0.0, 0.425]}, - "yaxis8": {"anchor": "x4", "overlaying": "y7", "side": "right"}, - }, - } - ) + } + ) - expected.update_traces(uid=None) + if shared_y_axes: + expected["layout"]["yaxis2"]["matches"] = "y4" + expected["layout"]["yaxis2"]["showticklabels"] = False + expected["layout"]["yaxis3"]["matches"] = "y" + expected["layout"]["yaxis3"]["showticklabels"] = False + expected["layout"]["yaxis6"]["matches"] = "y8" + expected["layout"]["yaxis6"]["showticklabels"] = False + expected["layout"]["yaxis7"]["matches"] = "y5" + expected["layout"]["yaxis7"]["showticklabels"] = False - self.assertEqual(fig.to_plotly_json(), expected.to_plotly_json()) + expected.update_traces(uid=None) + + self.assertEqual(fig.to_plotly_json(), expected.to_plotly_json()) def test_if_passed_figure(self): # assert it returns the same figure it was passed