Skip to content

Commit c84e4ef

Browse files
committed
Fix stide related logic in call to xp.take with 1d args
1 parent cf86c45 commit c84e4ef

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

sklearn/utils/extmath.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -806,18 +806,18 @@ def svd_flip(u, v, u_based_decision=True):
806806
xp, _ = get_namespace(u)
807807

808808
if u_based_decision:
809-
# columns of u, rows of v
810-
max_abs_cols = xp.argmax(xp.abs(u), axis=0)
811-
indices = xp.arange(u.shape[1]) * u.shape[0] + max_abs_cols
812-
signs = xp.sign(xp.take(xp.reshape(u, (-1,)), indices))
813-
u *= signs
809+
# columns of u, rows of v, or equivalently rows of u.T and v
810+
max_abs_u_cols = xp.argmax(xp.abs(u.T), axis=1)
811+
indices = max_abs_u_cols + xp.arange(u.T.shape[0]) * u.T.shape[1]
812+
signs = xp.sign(xp.take(xp.reshape(u.T, (-1,)), indices))
813+
u *= signs[np.newaxis, :]
814814
v *= signs[:, np.newaxis]
815815
else:
816816
# rows of v, columns of u
817-
max_abs_rows = xp.argmax(xp.abs(v), axis=1)
818-
indices = xp.arange(v.shape[0]) * v.shape[1] + max_abs_rows
817+
max_abs_v_rows = xp.argmax(xp.abs(v), axis=1)
818+
indices = max_abs_v_rows + xp.arange(v.shape[0]) * v.shape[1]
819819
signs = xp.sign(xp.take(xp.reshape(v, (-1,)), indices))
820-
u *= signs
820+
u *= signs[np.newaxis, :]
821821
v *= signs[:, np.newaxis]
822822
return u, v
823823

0 commit comments

Comments
 (0)