Skip to content

Commit 83858f0

Browse files
authored
Implements dpctl.tensor.matrix_transpose (#1356)
* Implements matrix_transpose - Function wrapper for call to dpctl.tensor.usm_ndarray.mT attribute * Add arg validation tests for matrix_transpose * Added a test for matrix_transpose for coverage
1 parent 96293fd commit 83858f0

File tree

3 files changed

+98
-0
lines changed

3 files changed

+98
-0
lines changed

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
from dpctl.tensor._device import Device
6161
from dpctl.tensor._dlpack import from_dlpack
6262
from dpctl.tensor._indexing_functions import extract, nonzero, place, put, take
63+
from dpctl.tensor._linear_algebra_functions import matrix_transpose
6364
from dpctl.tensor._manipulation_functions import (
6465
broadcast_arrays,
6566
broadcast_to,
@@ -199,6 +200,7 @@
199200
"tril",
200201
"triu",
201202
"where",
203+
"matrix_transpose",
202204
"all",
203205
"any",
204206
"dtype",
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2023 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import dpctl.tensor as dpt
18+
19+
20+
def matrix_transpose(x):
21+
"""matrix_transpose(x)
22+
23+
Transposes the innermost two dimensions of `x`, where `x` is a
24+
2-dimensional matrix or a stack of 2-dimensional matrices.
25+
26+
To convert from a 1-dimensional array to a 2-dimensional column
27+
vector, use x[:, dpt.newaxis].
28+
29+
Args:
30+
x (usm_ndarray):
31+
Input array with shape (..., m, n).
32+
33+
Returns:
34+
usm_ndarray:
35+
Array with shape (..., n, m).
36+
"""
37+
38+
if not isinstance(x, dpt.usm_ndarray):
39+
raise TypeError(
40+
"Expected instance of `dpt.usm_ndarray`, got `{}`.".format(type(x))
41+
)
42+
if x.ndim < 2:
43+
raise ValueError(
44+
"dpctl.tensor.matrix_transpose requires array to have"
45+
"at least 2 dimensions"
46+
)
47+
48+
return x.mT
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2023 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import pytest
18+
19+
import dpctl.tensor as dpt
20+
from dpctl.tests.helper import get_queue_or_skip
21+
22+
23+
def test_matrix_transpose():
24+
get_queue_or_skip()
25+
26+
X = dpt.reshape(dpt.arange(2 * 3, dtype="i4"), (2, 3))
27+
res = dpt.matrix_transpose(X)
28+
expected_res = X.mT
29+
30+
assert expected_res.shape == res.shape
31+
assert expected_res.flags["C"] == res.flags["C"]
32+
assert expected_res.flags["F"] == res.flags["F"]
33+
assert dpt.all(X.mT == res)
34+
35+
36+
def test_matrix_transpose_arg_validation():
37+
get_queue_or_skip()
38+
39+
X = dpt.empty(5, dtype="i4")
40+
with pytest.raises(ValueError):
41+
dpt.matrix_transpose(X)
42+
43+
X = dict()
44+
with pytest.raises(TypeError):
45+
dpt.matrix_transpose(X)
46+
47+
X = dpt.empty((5, 5), dtype="i4")
48+
assert isinstance(dpt.matrix_transpose(X), dpt.usm_ndarray)

0 commit comments

Comments
 (0)