@@ -1045,6 +1045,36 @@ def test_comparisons(self, func, variation, unit, dtype):
1045
1045
1046
1046
assert expected == result
1047
1047
1048
+ @pytest .mark .xfail (reason = "blocked by `where`" )
1049
+ @pytest .mark .parametrize (
1050
+ "unit" ,
1051
+ (
1052
+ pytest .param (1 , id = "no_unit" ),
1053
+ pytest .param (unit_registry .dimensionless , id = "dimensionless" ),
1054
+ pytest .param (unit_registry .s , id = "incompatible_unit" ),
1055
+ pytest .param (unit_registry .cm , id = "compatible_unit" ),
1056
+ pytest .param (unit_registry .m , id = "identical_unit" ),
1057
+ ),
1058
+ )
1059
+ def test_broadcast_like (self , unit , dtype ):
1060
+ array1 = np .linspace (1 , 2 , 2 * 1 ).reshape (2 , 1 ).astype (dtype ) * unit_registry .Pa
1061
+ array2 = np .linspace (0 , 1 , 2 * 3 ).reshape (2 , 3 ).astype (dtype ) * unit_registry .Pa
1062
+
1063
+ x1 = np .arange (2 ) * unit_registry .m
1064
+ x2 = np .arange (2 ) * unit
1065
+ y1 = np .array ([0 ]) * unit_registry .m
1066
+ y2 = np .arange (3 ) * unit
1067
+
1068
+ arr1 = xr .DataArray (data = array1 , coords = {"x" : x1 , "y" : y1 }, dims = ("x" , "y" ))
1069
+ arr2 = xr .DataArray (data = array2 , coords = {"x" : x2 , "y" : y2 }, dims = ("x" , "y" ))
1070
+
1071
+ expected = attach_units (
1072
+ strip_units (arr1 ).broadcast_like (strip_units (arr2 )), extract_units (arr1 )
1073
+ )
1074
+ result = arr1 .broadcast_like (arr2 )
1075
+
1076
+ assert_equal_with_units (expected , result )
1077
+
1048
1078
@pytest .mark .parametrize (
1049
1079
"unit" ,
1050
1080
(
@@ -1303,6 +1333,49 @@ def test_squeeze(self, shape, dtype):
1303
1333
np .squeeze (array , axis = index ), data_array .squeeze (dim = name )
1304
1334
)
1305
1335
1336
+ @pytest .mark .xfail (
1337
+ reason = "indexes strip units and head / tail / thin only support integers"
1338
+ )
1339
+ @pytest .mark .parametrize (
1340
+ "unit,error" ,
1341
+ (
1342
+ pytest .param (1 , DimensionalityError , id = "no_unit" ),
1343
+ pytest .param (
1344
+ unit_registry .dimensionless , DimensionalityError , id = "dimensionless"
1345
+ ),
1346
+ pytest .param (unit_registry .s , DimensionalityError , id = "incompatible_unit" ),
1347
+ pytest .param (unit_registry .cm , None , id = "compatible_unit" ),
1348
+ pytest .param (unit_registry .m , None , id = "identical_unit" ),
1349
+ ),
1350
+ )
1351
+ @pytest .mark .parametrize (
1352
+ "func" ,
1353
+ (method ("head" , x = 7 , y = 3 ), method ("tail" , x = 7 , y = 3 ), method ("thin" , x = 7 , y = 3 )),
1354
+ ids = repr ,
1355
+ )
1356
+ def test_head_tail_thin (self , func , unit , error , dtype ):
1357
+ array = np .linspace (1 , 2 , 10 * 5 ).reshape (10 , 5 ) * unit_registry .degK
1358
+
1359
+ coords = {
1360
+ "x" : np .arange (10 ) * unit_registry .m ,
1361
+ "y" : np .arange (5 ) * unit_registry .m ,
1362
+ }
1363
+
1364
+ arr = xr .DataArray (data = array , coords = coords , dims = ("x" , "y" ))
1365
+
1366
+ kwargs = {name : value * unit for name , value in func .kwargs .items ()}
1367
+
1368
+ if error is not None :
1369
+ with pytest .raises (error ):
1370
+ func (arr , ** kwargs )
1371
+
1372
+ return
1373
+
1374
+ expected = attach_units (func (strip_units (arr )), extract_units (arr ))
1375
+ result = func (arr , ** kwargs )
1376
+
1377
+ assert_equal_with_units (expected , result )
1378
+
1306
1379
@pytest .mark .parametrize (
1307
1380
"unit,error" ,
1308
1381
(
@@ -2472,6 +2545,40 @@ def test_comparisons(self, func, variation, unit, dtype):
2472
2545
2473
2546
assert expected == result
2474
2547
2548
+ @pytest .mark .xfail (reason = "blocked by `where`" )
2549
+ @pytest .mark .parametrize (
2550
+ "unit" ,
2551
+ (
2552
+ pytest .param (1 , id = "no_unit" ),
2553
+ pytest .param (unit_registry .dimensionless , id = "dimensionless" ),
2554
+ pytest .param (unit_registry .s , id = "incompatible_unit" ),
2555
+ pytest .param (unit_registry .cm , id = "compatible_unit" ),
2556
+ pytest .param (unit_registry .m , id = "identical_unit" ),
2557
+ ),
2558
+ )
2559
+ def test_broadcast_like (self , unit , dtype ):
2560
+ array1 = np .linspace (1 , 2 , 2 * 1 ).reshape (2 , 1 ).astype (dtype ) * unit_registry .Pa
2561
+ array2 = np .linspace (0 , 1 , 2 * 3 ).reshape (2 , 3 ).astype (dtype ) * unit_registry .Pa
2562
+
2563
+ x1 = np .arange (2 ) * unit_registry .m
2564
+ x2 = np .arange (2 ) * unit
2565
+ y1 = np .array ([0 ]) * unit_registry .m
2566
+ y2 = np .arange (3 ) * unit
2567
+
2568
+ ds1 = xr .Dataset (
2569
+ data_vars = {"a" : (("x" , "y" ), array1 )}, coords = {"x" : x1 , "y" : y1 }
2570
+ )
2571
+ ds2 = xr .Dataset (
2572
+ data_vars = {"a" : (("x" , "y" ), array2 )}, coords = {"x" : x2 , "y" : y2 }
2573
+ )
2574
+
2575
+ expected = attach_units (
2576
+ strip_units (ds1 ).broadcast_like (strip_units (ds2 )), extract_units (ds1 )
2577
+ )
2578
+ result = ds1 .broadcast_like (ds2 )
2579
+
2580
+ assert_equal_with_units (expected , result )
2581
+
2475
2582
@pytest .mark .parametrize (
2476
2583
"unit" ,
2477
2584
(
0 commit comments