@@ -806,18 +806,18 @@ def svd_flip(u, v, u_based_decision=True):
806
806
xp , _ = get_namespace (u )
807
807
808
808
if u_based_decision :
809
- # columns of u, rows of v
810
- max_abs_cols = xp .argmax (xp .abs (u ), axis = 0 )
811
- indices = xp .arange (u .shape [1 ]) * u .shape [0 ] + max_abs_cols
812
- signs = xp .sign (xp .take (xp .reshape (u , (- 1 ,)), indices ))
813
- u *= signs
809
+ # columns of u, rows of v, or equivalently rows of u.T and v
810
+ max_abs_u_cols = xp .argmax (xp .abs (u . T ), axis = 1 )
811
+ indices = max_abs_u_cols + xp .arange (u .T . shape [0 ]) * u .T . shape [1 ]
812
+ signs = xp .sign (xp .take (xp .reshape (u . T , (- 1 ,)), indices ))
813
+ u *= signs [ np . newaxis , :]
814
814
v *= signs [:, np .newaxis ]
815
815
else :
816
816
# rows of v, columns of u
817
- max_abs_rows = xp .argmax (xp .abs (v ), axis = 1 )
818
- indices = xp .arange (v .shape [0 ]) * v .shape [1 ] + max_abs_rows
817
+ max_abs_v_rows = xp .argmax (xp .abs (v ), axis = 1 )
818
+ indices = max_abs_v_rows + xp .arange (v .shape [0 ]) * v .shape [1 ]
819
819
signs = xp .sign (xp .take (xp .reshape (v , (- 1 ,)), indices ))
820
- u *= signs
820
+ u *= signs [ np . newaxis , :]
821
821
v *= signs [:, np .newaxis ]
822
822
return u , v
823
823
0 commit comments