@@ -1779,6 +1779,7 @@ def __init__(
1779
1779
index_per_key : Optional [Dict [str , int ]] = None ,
1780
1780
jt_dict : Optional [Dict [str , JaggedTensor ]] = None ,
1781
1781
inverse_indices : Optional [Tuple [List [str ], torch .Tensor ]] = None ,
1782
+ stride_per_key_per_rank_tensor : Optional [torch .Tensor ] = None ,
1782
1783
) -> None :
1783
1784
"""
1784
1785
This is the constructor for KeyedJaggedTensor is jit.scriptable and PT2 compatible.
@@ -1795,6 +1796,11 @@ def __init__(
1795
1796
self ._stride_per_key_per_rank : Optional [List [List [int ]]] = (
1796
1797
stride_per_key_per_rank
1797
1798
)
1799
+
1800
+ self ._stride_per_key_per_rank_tensor : torch .Tensor = torch .empty (0 )
1801
+ if stride_per_key_per_rank_tensor is not None :
1802
+ self ._stride_per_key_per_rank_tensor = stride_per_key_per_rank_tensor
1803
+
1798
1804
self ._stride_per_key : Optional [List [int ]] = stride_per_key
1799
1805
self ._length_per_key : Optional [List [int ]] = length_per_key
1800
1806
self ._offset_per_key : Optional [List [int ]] = offset_per_key
@@ -2184,7 +2190,7 @@ def stride_per_key(self) -> List[int]:
2184
2190
"""
2185
2191
stride_per_key = _maybe_compute_stride_per_key (
2186
2192
self ._stride_per_key ,
2187
- self ._stride_per_key_per_rank ,
2193
+ self ._stride_per_key_per_rank_optional ,
2188
2194
self .stride (),
2189
2195
self ._keys ,
2190
2196
)
@@ -2199,7 +2205,27 @@ def stride_per_key_per_rank(self) -> List[List[int]]:
2199
2205
List[List[int]]: stride per key per rank of the KeyedJaggedTensor.
2200
2206
"""
2201
2207
stride_per_key_per_rank = self ._stride_per_key_per_rank
2202
- return stride_per_key_per_rank if stride_per_key_per_rank is not None else []
2208
+
2209
+ if stride_per_key_per_rank is not None :
2210
+ return stride_per_key_per_rank
2211
+
2212
+ if self ._stride_per_key_per_rank_tensor .numel () > 0 :
2213
+ return self ._stride_per_key_per_rank_tensor .tolist ()
2214
+
2215
+ return []
2216
+
2217
+ @property
2218
+ def _stride_per_key_per_rank_optional (self ) -> Optional [List [List [int ]]]:
2219
+ if self ._stride_per_key_per_rank is not None :
2220
+ return self ._stride_per_key_per_rank
2221
+
2222
+ if self ._stride_per_key_per_rank_tensor .numel () > 0 :
2223
+ stride_per_key_per_rank : List [List [int ]] = (
2224
+ self ._stride_per_key_per_rank_tensor .tolist ()
2225
+ )
2226
+ return stride_per_key_per_rank
2227
+
2228
+ return None
2203
2229
2204
2230
def variable_stride_per_key (self ) -> bool :
2205
2231
"""
@@ -2210,7 +2236,7 @@ def variable_stride_per_key(self) -> bool:
2210
2236
"""
2211
2237
if self ._variable_stride_per_key is not None :
2212
2238
return self ._variable_stride_per_key
2213
- return self ._stride_per_key_per_rank is not None
2239
+ return self ._stride_per_key_per_rank_optional is not None
2214
2240
2215
2241
def inverse_indices (self ) -> Tuple [List [str ], torch .Tensor ]:
2216
2242
"""
@@ -2375,6 +2401,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
2375
2401
index_per_key = self ._index_per_key ,
2376
2402
jt_dict = self ._jt_dict ,
2377
2403
inverse_indices = None ,
2404
+ stride_per_key_per_rank_tensor = None ,
2378
2405
)
2379
2406
)
2380
2407
elif segment == 0 :
@@ -2411,6 +2438,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
2411
2438
index_per_key = None ,
2412
2439
jt_dict = None ,
2413
2440
inverse_indices = None ,
2441
+ stride_per_key_per_rank_tensor = None ,
2414
2442
)
2415
2443
)
2416
2444
else :
@@ -2457,6 +2485,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
2457
2485
index_per_key = None ,
2458
2486
jt_dict = None ,
2459
2487
inverse_indices = None ,
2488
+ stride_per_key_per_rank_tensor = None ,
2460
2489
)
2461
2490
)
2462
2491
else :
@@ -2493,6 +2522,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
2493
2522
index_per_key = None ,
2494
2523
jt_dict = None ,
2495
2524
inverse_indices = None ,
2525
+ stride_per_key_per_rank_tensor = None ,
2496
2526
)
2497
2527
)
2498
2528
start = end
@@ -2599,12 +2629,15 @@ def permute(
2599
2629
index_per_key = None ,
2600
2630
jt_dict = None ,
2601
2631
inverse_indices = None ,
2632
+ stride_per_key_per_rank_tensor = None ,
2602
2633
)
2603
2634
return kjt
2604
2635
2605
2636
def flatten_lengths (self ) -> "KeyedJaggedTensor" :
2606
2637
stride_per_key_per_rank = (
2607
- self ._stride_per_key_per_rank if self .variable_stride_per_key () else None
2638
+ self ._stride_per_key_per_rank_optional
2639
+ if self .variable_stride_per_key ()
2640
+ else None
2608
2641
)
2609
2642
return KeyedJaggedTensor (
2610
2643
keys = self ._keys ,
@@ -2621,6 +2654,7 @@ def flatten_lengths(self) -> "KeyedJaggedTensor":
2621
2654
index_per_key = None ,
2622
2655
jt_dict = None ,
2623
2656
inverse_indices = None ,
2657
+ stride_per_key_per_rank_tensor = None ,
2624
2658
)
2625
2659
2626
2660
def __getitem__ (self , key : str ) -> JaggedTensor :
@@ -2760,7 +2794,9 @@ def to(
2760
2794
lengths = self ._lengths
2761
2795
offsets = self ._offsets
2762
2796
stride_per_key_per_rank = (
2763
- self ._stride_per_key_per_rank if self .variable_stride_per_key () else None
2797
+ self ._stride_per_key_per_rank_optional
2798
+ if self .variable_stride_per_key ()
2799
+ else None
2764
2800
)
2765
2801
length_per_key = self ._length_per_key
2766
2802
lengths_offset_per_key = self ._lengths_offset_per_key
@@ -2805,6 +2841,7 @@ def to(
2805
2841
index_per_key = index_per_key ,
2806
2842
jt_dict = jt_dict ,
2807
2843
inverse_indices = inverse_indices ,
2844
+ stride_per_key_per_rank_tensor = None ,
2808
2845
)
2809
2846
2810
2847
def __str__ (self ) -> str :
@@ -2836,7 +2873,9 @@ def pin_memory(self) -> "KeyedJaggedTensor":
2836
2873
lengths = self ._lengths
2837
2874
offsets = self ._offsets
2838
2875
stride_per_key_per_rank = (
2839
- self ._stride_per_key_per_rank if self .variable_stride_per_key () else None
2876
+ self ._stride_per_key_per_rank_optional
2877
+ if self .variable_stride_per_key ()
2878
+ else None
2840
2879
)
2841
2880
inverse_indices = self ._inverse_indices
2842
2881
if inverse_indices is not None :
@@ -2857,6 +2896,7 @@ def pin_memory(self) -> "KeyedJaggedTensor":
2857
2896
index_per_key = self ._index_per_key ,
2858
2897
jt_dict = None ,
2859
2898
inverse_indices = inverse_indices ,
2899
+ stride_per_key_per_rank_tensor = None ,
2860
2900
)
2861
2901
2862
2902
def dist_labels (self ) -> List [str ]:
0 commit comments