Skip to content

Commit d09abf0

Browse files
kshitij12345facebook-github-bot
authored andcommitted
OpInfo: narrow (#58082)
Summary: Reference: #54261 Pull Request resolved: #58082 Reviewed By: agolynski Differential Revision: D28379371 Pulled By: mruberry fbshipit-source-id: 484e560b1e6ceba234e497585ed308a27cd8b7a0
1 parent 9148f19 commit d09abf0

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

test/test_torch.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7967,8 +7967,6 @@ def tmp(dtype, device):
79677967
('ndimension', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
79687968
('nelement', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
79697969
('numel', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
7970-
('narrow', '', _small_3d, lambda t, d: [1, 3, 2], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
7971-
('narrow', 'neg_dim', _small_3d, lambda t, d: [-1, 3, 2], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
79727970
('nonzero', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False),
79737971
('norm', '', _small_3d, lambda t, d: [], 1e-1, 1e-1, 1e-5, torch.testing.get_all_fp_dtypes(), _cpu_types, False),
79747972
('norm', '3_norm', _small_3d, lambda t, d: [3], 1e-1, 1e-1, 1e-5, torch.testing.get_all_fp_dtypes(), _cpu_types, False),

torch/testing/_internal/common_methods_invocations.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1831,6 +1831,24 @@ def sample_repeat_tile(op_info, device, dtype, requires_grad, **kwargs):
18311831

18321832
return samples
18331833

1834+
1835+
def sample_inputs_narrow(op_info, device, dtype, requires_grad, **kwargs):
1836+
shapes_and_args = (
1837+
((S, S, S), (1, 2, 2)),
1838+
((S, S, S), (-1, 2, 2)),
1839+
((S, S, S), (1, 0, 0)),
1840+
((S, S, S), (-1, 0, 0)),
1841+
)
1842+
1843+
def generator():
1844+
for shape, args in shapes_and_args:
1845+
tensor = make_tensor(shape, device, dtype, low=None, high=None,
1846+
requires_grad=requires_grad)
1847+
yield SampleInput(tensor, args=args)
1848+
1849+
return list(generator())
1850+
1851+
18341852
def sample_unsqueeze(op_info, device, dtype, requires_grad, **kwargs):
18351853
shapes_and_axes = [
18361854
((3, 4, 5), 0),
@@ -4974,6 +4992,10 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
49744992
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
49754993
supports_autograd=False,
49764994
sample_inputs_func=sample_inputs_comparison_ops),
4995+
OpInfo('narrow',
4996+
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
4997+
supports_out=False,
4998+
sample_inputs_func=sample_inputs_narrow),
49774999
UnaryUfuncInfo('neg',
49785000
aliases=('negative', ),
49795001
ref=np.negative,
@@ -6347,8 +6369,6 @@ def method_tests():
63476369
('fill_', (S, S, S), (1,), 'number'),
63486370
('fill_', (), (1,), 'number_scalar'),
63496371
('fill_', (S, S, S), ((),), 'variable'),
6350-
('narrow', (S, S, S), (1, 2, 2), 'dim', (), [0]),
6351-
('narrow', (S, S, S), (1, 0, 0), 'empty_dim', (), [0]),
63526372
('squeeze', (S, 1, S, 1), NO_ARGS, '', (True,)),
63536373
('squeeze', (1, 1, 1, 1), NO_ARGS, 'input_sizes_are_ones', (True,)),
63546374
('squeeze', (S, 1, S, 1), (1,), '1_dim', (True,), [0]),

0 commit comments

Comments
 (0)