Skip to content

Commit 0c86622

Browse files
committed
Use tuples for indexing
1 parent 550ba28 commit 0c86622

File tree

9 files changed

+222
-75
lines changed

9 files changed

+222
-75
lines changed

xarray/backends/common.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from xarray.core.dataset import Dataset
2424
from xarray.core.types import NestedSequence
25+
from xarray.namedarray._typing import _IndexerKey
2526

2627
# Create a logger object, but don't add any handlers. Leave that to user code.
2728
logger = logging.getLogger(__name__)
@@ -219,18 +220,18 @@ def robust_getitem(array, key, catch=Exception, max_retries=6, initial_delay=500
219220

220221

221222
class BackendArray(NdimSizeLenMixin, indexing.ExplicitlyIndexed):
222-
__slots__ = ()
223+
__slots__ = ("indexing_support",)
223224

224225
def get_duck_array(self, dtype: np.typing.DTypeLike = None):
225-
key = indexing.BasicIndexer((slice(None),) * self.ndim)
226+
key = (slice(None),) * self.ndim
226227
return self[key] # type: ignore [index]
227228

228-
def _oindex_get(self, key: indexing.OuterIndexer):
229+
def _oindex_get(self, key: _IndexerKey) -> Any:
229230
raise NotImplementedError(
230231
f"{self.__class__.__name__}._oindex_get method should be overridden"
231232
)
232233

233-
def _vindex_get(self, key: indexing.VectorizedIndexer):
234+
def _vindex_get(self, key: _IndexerKey) -> Any:
234235
raise NotImplementedError(
235236
f"{self.__class__.__name__}._vindex_get method should be overridden"
236237
)

xarray/backends/h5netcdf_.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,26 +44,33 @@
4444
from xarray.backends.common import AbstractDataStore
4545
from xarray.core.dataset import Dataset
4646
from xarray.core.datatree import DataTree
47+
from xarray.namedarray._typing import (
48+
_BasicIndexerKey,
49+
_OuterIndexerKey,
50+
_VectorizedIndexerKey,
51+
)
4752

4853

4954
class H5NetCDFArrayWrapper(BaseNetCDF4Array):
55+
indexing_support = indexing.IndexingSupport.OUTER_1VECTOR
56+
5057
def get_array(self, needs_lock=True):
5158
ds = self.datastore._acquire(needs_lock)
5259
return ds.variables[self.variable_name]
5360

54-
def _oindex_get(self, key: indexing.OuterIndexer):
55-
return indexing.explicit_indexing_adapter(
56-
key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, self._getitem
61+
def _oindex_get(self, key: _OuterIndexerKey) -> Any:
62+
return indexing.outer_indexing_adapter(
63+
key, self.shape, self.indexing_support, self._getitem
5764
)
5865

59-
def _vindex_get(self, key: indexing.VectorizedIndexer):
60-
return indexing.explicit_indexing_adapter(
61-
key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, self._getitem
66+
def _vindex_get(self, key: _VectorizedIndexerKey) -> Any:
67+
return indexing.vectorized_indexing_adapter(
68+
key, self.shape, self.indexing_support, self._getitem
6269
)
6370

64-
def __getitem__(self, key: indexing.BasicIndexer):
65-
return indexing.explicit_indexing_adapter(
66-
key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, self._getitem
71+
def __getitem__(self, key: _BasicIndexerKey) -> Any:
72+
return indexing.basic_indexing_adapter(
73+
key, self.shape, self.indexing_support, self._getitem
6774
)
6875

6976
def _getitem(self, key):

xarray/backends/netCDF4_.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@
4949
from xarray.backends.common import AbstractDataStore
5050
from xarray.core.dataset import Dataset
5151
from xarray.core.datatree import DataTree
52+
from xarray.namedarray._typing import (
53+
_BasicIndexerKey,
54+
_OuterIndexerKey,
55+
_VectorizedIndexerKey,
56+
)
5257

5358
# This lookup table maps from dtype.byteorder to a readable endian
5459
# string used by netCDF4.
@@ -89,7 +94,7 @@ def get_array(self, needs_lock=True):
8994

9095

9196
class NetCDF4ArrayWrapper(BaseNetCDF4Array):
92-
__slots__ = ()
97+
indexing_support = indexing.IndexingSupport.OUTER
9398

9499
def get_array(self, needs_lock=True):
95100
ds = self.datastore._acquire(needs_lock)
@@ -100,19 +105,19 @@ def get_array(self, needs_lock=True):
100105
variable.set_auto_chartostring(False)
101106
return variable
102107

103-
def _oindex_get(self, key: indexing.OuterIndexer):
104-
return indexing.explicit_indexing_adapter(
105-
key, self.shape, indexing.IndexingSupport.OUTER, self._getitem
108+
def _oindex_get(self, key: _OuterIndexerKey):
109+
return indexing.outer_indexing_adapter(
110+
key, self.shape, self.indexing_support, self._getitem
106111
)
107112

108-
def _vindex_get(self, key: indexing.VectorizedIndexer):
109-
return indexing.explicit_indexing_adapter(
110-
key, self.shape, indexing.IndexingSupport.OUTER, self._getitem
113+
def _vindex_get(self, key: _VectorizedIndexerKey):
114+
return indexing.vectorized_indexing_adapter(
115+
key, self.shape, self.indexing_support, self._getitem
111116
)
112117

113-
def __getitem__(self, key: indexing.BasicIndexer):
114-
return indexing.explicit_indexing_adapter(
115-
key, self.shape, indexing.IndexingSupport.OUTER, self._getitem
118+
def __getitem__(self, key: _BasicIndexerKey):
119+
return indexing.basic_indexing_adapter(
120+
key, self.shape, self.indexing_support, self._getitem
116121
)
117122

118123
def _getitem(self, key):

xarray/backends/pydap_.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,17 @@
2929
from io import BufferedIOBase
3030

3131
from xarray.core.dataset import Dataset
32+
from xarray.namedarray._typing import (
33+
_BasicIndexerKey,
34+
_OuterIndexerKey,
35+
_VectorizedIndexerKey,
36+
)
3237

3338

3439
class PydapArrayWrapper(BackendArray):
35-
def __init__(self, array):
40+
indexing_support = indexing.IndexingSupport.BASIC
41+
42+
def __init__(self, array) -> None:
3643
self.array = array
3744

3845
@property
@@ -43,19 +50,19 @@ def shape(self) -> tuple[int, ...]:
4350
def dtype(self):
4451
return self.array.dtype
4552

46-
def _oindex_get(self, key: indexing.OuterIndexer):
47-
return indexing.explicit_indexing_adapter(
48-
key, self.shape, indexing.IndexingSupport.BASIC, self._getitem
53+
def _oindex_get(self, key: _OuterIndexerKey) -> Any:
54+
return indexing.outer_indexing_adapter(
55+
key, self.shape, self.indexing_support, self._getitem
4956
)
5057

51-
def _vindex_get(self, key: indexing.VectorizedIndexer):
52-
return indexing.explicit_indexing_adapter(
53-
key, self.shape, indexing.IndexingSupport.BASIC, self._getitem
58+
def _vindex_get(self, key: _VectorizedIndexerKey) -> Any:
59+
return indexing.vectorized_indexing_adapter(
60+
key, self.shape, self.indexing_support, self._getitem
5461
)
5562

56-
def __getitem__(self, key: indexing.BasicIndexer):
57-
return indexing.explicit_indexing_adapter(
58-
key, self.shape, indexing.IndexingSupport.BASIC, self._getitem
63+
def __getitem__(self, key: _BasicIndexerKey) -> Any:
64+
return indexing.basic_indexing_adapter(
65+
key, self.shape, self.indexing_support, self._getitem
5966
)
6067

6168
def _getitem(self, key):

xarray/backends/scipy_.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@
3838

3939
from xarray.backends.common import AbstractDataStore
4040
from xarray.core.dataset import Dataset
41+
from xarray.namedarray._typing import (
42+
_BasicIndexerKey,
43+
_OuterIndexerKey,
44+
_VectorizedIndexerKey,
45+
)
4146

4247

4348
HAS_NUMPY_2_0 = module_available("numpy", minversion="2.0.0.dev0")
@@ -56,6 +61,8 @@ def _decode_attrs(d):
5661

5762

5863
class ScipyArrayWrapper(BackendArray):
64+
indexing_support = indexing.IndexingSupport.OUTER_1VECTOR
65+
5966
def __init__(self, variable_name, datastore):
6067
self.datastore = datastore
6168
self.variable_name = variable_name
@@ -85,25 +92,25 @@ def _getitem(self, key):
8592
data = self.get_variable(needs_lock=False).data
8693
return data[key]
8794

88-
def _vindex_get(self, key: indexing.VectorizedIndexer):
89-
data = indexing.explicit_indexing_adapter(
90-
key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, self._getitem
95+
def _vindex_get(self, key: _VectorizedIndexerKey) -> Any:
96+
data = indexing.vectorized_indexing_adapter(
97+
key, self.shape, self.indexing_support, self._getitem
9198
)
9299
return self._finalize_result(data)
93100

94-
def _oindex_get(self, key: indexing.OuterIndexer):
95-
data = indexing.explicit_indexing_adapter(
96-
key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, self._getitem
101+
def _oindex_get(self, key: _OuterIndexerKey) -> Any:
102+
data = indexing.outer_indexing_adapter(
103+
key, self.shape, self.indexing_support, self._getitem
97104
)
98105
return self._finalize_result(data)
99106

100-
def __getitem__(self, key):
101-
data = indexing.explicit_indexing_adapter(
102-
key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, self._getitem
107+
def __getitem__(self, key: _BasicIndexerKey) -> Any:
108+
data = indexing.basic_indexing_adapter(
109+
key, self.shape, self.indexing_support, self._getitem
103110
)
104111
return self._finalize_result(data)
105112

106-
def __setitem__(self, key, value):
113+
def __setitem__(self, key, value) -> None:
107114
with self.datastore.lock:
108115
data = self.get_variable(needs_lock=False)
109116
try:

xarray/backends/zarr.py

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@
4343
from xarray.backends.common import AbstractDataStore
4444
from xarray.core.dataset import Dataset
4545
from xarray.core.datatree import DataTree
46+
from xarray.namedarray._typing import (
47+
_BasicIndexerKey,
48+
_OuterIndexerKey,
49+
_VectorizedIndexerKey,
50+
)
4651

4752

4853
def _get_mappers(*, storage_options, store, chunk_store):
@@ -182,7 +187,7 @@ def encode_zarr_attr_value(value):
182187

183188

184189
class ZarrArrayWrapper(BackendArray):
185-
__slots__ = ("_array", "dtype", "shape")
190+
indexing_support = indexing.IndexingSupport.VECTORIZED
186191

187192
def __init__(self, zarr_array):
188193
# some callers attempt to evaluate an array if an `array` property exists on the object.
@@ -205,37 +210,28 @@ def __init__(self, zarr_array):
205210
def get_array(self):
206211
return self._array
207212

208-
def _oindex_get(self, key: indexing.OuterIndexer):
213+
def _oindex_get(self, key: _OuterIndexerKey) -> Any:
209214
def raw_indexing_method(key):
210215
return self._array.oindex[key]
211216

212-
return indexing.explicit_indexing_adapter(
213-
key,
214-
self._array.shape,
215-
indexing.IndexingSupport.VECTORIZED,
216-
raw_indexing_method,
217+
return indexing.outer_indexing_adapter(
218+
key, self._array.shape, self.indexing_support, raw_indexing_method
217219
)
218220

219-
def _vindex_get(self, key: indexing.VectorizedIndexer):
221+
def _vindex_get(self, key: _VectorizedIndexerKey) -> Any:
220222
def raw_indexing_method(key):
221223
return self._array.vindex[key]
222224

223-
return indexing.explicit_indexing_adapter(
224-
key,
225-
self._array.shape,
226-
indexing.IndexingSupport.VECTORIZED,
227-
raw_indexing_method,
225+
return indexing.vectorized_indexing_adapter(
226+
key, self._array.shape, self.indexing_support, raw_indexing_method
228227
)
229228

230-
def __getitem__(self, key: indexing.BasicIndexer):
229+
def __getitem__(self, key: _BasicIndexerKey) -> Any:
231230
def raw_indexing_method(key):
232231
return self._array[key]
233232

234-
return indexing.explicit_indexing_adapter(
235-
key,
236-
self._array.shape,
237-
indexing.IndexingSupport.VECTORIZED,
238-
raw_indexing_method,
233+
return indexing.basic_indexing_adapter(
234+
key, self._array.shape, self.indexing_support, raw_indexing_method
239235
)
240236

241237
# if self.ndim == 0:

xarray/coding/strings.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -256,10 +256,8 @@ def _oindex_get(self, key: indexing.OuterIndexer):
256256
return _numpy_char_to_bytes(self.array.oindex[key])
257257

258258
def __getitem__(self, key: _IndexerKey):
259-
from xarray.core.indexing import BasicIndexer
260-
261259
# require slicing the last dimension completely
262260
indexer = indexing.expanded_indexer(key, self.array.ndim)
263261
if indexer[-1] != slice(None):
264262
raise IndexError("too many indices")
265-
return _numpy_char_to_bytes(self.array[BasicIndexer(indexer)])
263+
return _numpy_char_to_bytes(self.array[indexer])

0 commit comments

Comments
 (0)