Skip to content

Commit 96ee8d5

Browse files
jd7-trfacebook-github-bot
authored and committed
Create a new tensor arg for stride_per_key_per_rank to facilitate torch.export (#2950)
Summary:
# Context
* Currently, the torchrec IR serializer can't handle the variable-batch KJT use case.
* To support VBE KJT, the `stride_per_key_per_rank` field needs to be flattened as a variable in the pytree flatten spec for a VBE KJT to be unflattened correctly by `torch.export`.
* Currently, `stride_per_key_per_rank` is a List. To flatten the `stride_per_key_per_rank` info as a variable, we have to add a new tensor field for it.

# Ref

Differential Revision: D74207283
1 parent 2d0a0bf commit 96ee8d5

File tree

2 files changed

+65
-9
lines changed

2 files changed

+65
-9
lines changed

torchrec/sparse/jagged_tensor.py

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1779,6 +1779,7 @@ def __init__(
17791779
index_per_key: Optional[Dict[str, int]] = None,
17801780
jt_dict: Optional[Dict[str, JaggedTensor]] = None,
17811781
inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None,
1782+
stride_per_key_per_rank_tensor: Optional[torch.Tensor] = None,
17821783
) -> None:
17831784
"""
17841785
This is the constructor for KeyedJaggedTensor; it is jit.scriptable and PT2 compatible.
@@ -1795,6 +1796,11 @@ def __init__(
17951796
self._stride_per_key_per_rank: Optional[List[List[int]]] = (
17961797
stride_per_key_per_rank
17971798
)
1799+
1800+
self._stride_per_key_per_rank_tensor: torch.Tensor = torch.empty(0)
1801+
if stride_per_key_per_rank_tensor is not None:
1802+
self._stride_per_key_per_rank_tensor = stride_per_key_per_rank_tensor
1803+
17981804
self._stride_per_key: Optional[List[int]] = stride_per_key
17991805
self._length_per_key: Optional[List[int]] = length_per_key
18001806
self._offset_per_key: Optional[List[int]] = offset_per_key
@@ -2184,7 +2190,7 @@ def stride_per_key(self) -> List[int]:
21842190
"""
21852191
stride_per_key = _maybe_compute_stride_per_key(
21862192
self._stride_per_key,
2187-
self._stride_per_key_per_rank,
2193+
self._stride_per_key_per_rank_optional,
21882194
self.stride(),
21892195
self._keys,
21902196
)
@@ -2199,7 +2205,27 @@ def stride_per_key_per_rank(self) -> List[List[int]]:
21992205
List[List[int]]: stride per key per rank of the KeyedJaggedTensor.
22002206
"""
22012207
stride_per_key_per_rank = self._stride_per_key_per_rank
2202-
return stride_per_key_per_rank if stride_per_key_per_rank is not None else []
2208+
2209+
if stride_per_key_per_rank is not None:
2210+
return stride_per_key_per_rank
2211+
2212+
if self._stride_per_key_per_rank_tensor.numel() > 0:
2213+
return self._stride_per_key_per_rank_tensor.tolist()
2214+
2215+
return []
2216+
2217+
@property
2218+
def _stride_per_key_per_rank_optional(self) -> Optional[List[List[int]]]:
2219+
if self._stride_per_key_per_rank is not None:
2220+
return self._stride_per_key_per_rank
2221+
2222+
if self._stride_per_key_per_rank_tensor.numel() > 0:
2223+
stride_per_key_per_rank: List[List[int]] = (
2224+
self._stride_per_key_per_rank_tensor.tolist()
2225+
)
2226+
return stride_per_key_per_rank
2227+
2228+
return None
22032229

22042230
def variable_stride_per_key(self) -> bool:
22052231
"""
@@ -2210,7 +2236,7 @@ def variable_stride_per_key(self) -> bool:
22102236
"""
22112237
if self._variable_stride_per_key is not None:
22122238
return self._variable_stride_per_key
2213-
return self._stride_per_key_per_rank is not None
2239+
return self._stride_per_key_per_rank_optional is not None
22142240

22152241
def inverse_indices(self) -> Tuple[List[str], torch.Tensor]:
22162242
"""
@@ -2375,6 +2401,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
23752401
index_per_key=self._index_per_key,
23762402
jt_dict=self._jt_dict,
23772403
inverse_indices=None,
2404+
stride_per_key_per_rank_tensor=None,
23782405
)
23792406
)
23802407
elif segment == 0:
@@ -2411,6 +2438,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
24112438
index_per_key=None,
24122439
jt_dict=None,
24132440
inverse_indices=None,
2441+
stride_per_key_per_rank_tensor=None,
24142442
)
24152443
)
24162444
else:
@@ -2457,6 +2485,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
24572485
index_per_key=None,
24582486
jt_dict=None,
24592487
inverse_indices=None,
2488+
stride_per_key_per_rank_tensor=None,
24602489
)
24612490
)
24622491
else:
@@ -2493,6 +2522,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
24932522
index_per_key=None,
24942523
jt_dict=None,
24952524
inverse_indices=None,
2525+
stride_per_key_per_rank_tensor=None,
24962526
)
24972527
)
24982528
start = end
@@ -2599,12 +2629,15 @@ def permute(
25992629
index_per_key=None,
26002630
jt_dict=None,
26012631
inverse_indices=None,
2632+
stride_per_key_per_rank_tensor=None,
26022633
)
26032634
return kjt
26042635

26052636
def flatten_lengths(self) -> "KeyedJaggedTensor":
26062637
stride_per_key_per_rank = (
2607-
self._stride_per_key_per_rank if self.variable_stride_per_key() else None
2638+
self._stride_per_key_per_rank_optional
2639+
if self.variable_stride_per_key()
2640+
else None
26082641
)
26092642
return KeyedJaggedTensor(
26102643
keys=self._keys,
@@ -2621,6 +2654,7 @@ def flatten_lengths(self) -> "KeyedJaggedTensor":
26212654
index_per_key=None,
26222655
jt_dict=None,
26232656
inverse_indices=None,
2657+
stride_per_key_per_rank_tensor=None,
26242658
)
26252659

26262660
def __getitem__(self, key: str) -> JaggedTensor:
@@ -2760,7 +2794,9 @@ def to(
27602794
lengths = self._lengths
27612795
offsets = self._offsets
27622796
stride_per_key_per_rank = (
2763-
self._stride_per_key_per_rank if self.variable_stride_per_key() else None
2797+
self._stride_per_key_per_rank_optional
2798+
if self.variable_stride_per_key()
2799+
else None
27642800
)
27652801
length_per_key = self._length_per_key
27662802
lengths_offset_per_key = self._lengths_offset_per_key
@@ -2805,6 +2841,7 @@ def to(
28052841
index_per_key=index_per_key,
28062842
jt_dict=jt_dict,
28072843
inverse_indices=inverse_indices,
2844+
stride_per_key_per_rank_tensor=None,
28082845
)
28092846

28102847
def __str__(self) -> str:
@@ -2836,7 +2873,9 @@ def pin_memory(self) -> "KeyedJaggedTensor":
28362873
lengths = self._lengths
28372874
offsets = self._offsets
28382875
stride_per_key_per_rank = (
2839-
self._stride_per_key_per_rank if self.variable_stride_per_key() else None
2876+
self._stride_per_key_per_rank_optional
2877+
if self.variable_stride_per_key()
2878+
else None
28402879
)
28412880
inverse_indices = self._inverse_indices
28422881
if inverse_indices is not None:
@@ -2857,6 +2896,7 @@ def pin_memory(self) -> "KeyedJaggedTensor":
28572896
index_per_key=self._index_per_key,
28582897
jt_dict=None,
28592898
inverse_indices=inverse_indices,
2899+
stride_per_key_per_rank_tensor=None,
28602900
)
28612901

28622902
def dist_labels(self) -> List[str]:

torchrec/sparse/tests/test_keyed_jagged_tensor.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,16 +1018,32 @@ def test_meta_device_compatibility(self) -> None:
10181018
)
10191019

10201020
def test_vbe_kjt_stride(self) -> None:
1021+
stride_per_key_per_rank = [[2], [1]]
10211022
inverse_indices = torch.tensor([[0, 1, 0], [0, 0, 0]])
1022-
kjt = KeyedJaggedTensor(
1023+
kjt_1 = KeyedJaggedTensor(
10231024
keys=["f1", "f2", "f3"],
10241025
values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
10251026
lengths=torch.tensor([3, 3, 2]),
1026-
stride_per_key_per_rank=[[2], [1]],
1027+
stride_per_key_per_rank=stride_per_key_per_rank,
1028+
inverse_indices=(["f1", "f2"], inverse_indices),
1029+
)
1030+
kjt_2 = KeyedJaggedTensor(
1031+
keys=["f1", "f2", "f3"],
1032+
values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
1033+
lengths=torch.tensor([3, 3, 2]),
1034+
stride_per_key_per_rank_tensor=torch.tensor(stride_per_key_per_rank),
10271035
inverse_indices=(["f1", "f2"], inverse_indices),
10281036
)
10291037

1030-
self.assertEqual(kjt.stride(), inverse_indices.shape[1])
1038+
self.assertEqual(kjt_1.stride(), inverse_indices.shape[1])
1039+
self.assertEqual(kjt_1.stride_per_key_per_rank(), stride_per_key_per_rank)
1040+
self.assertEqual(
1041+
kjt_1._stride_per_key_per_rank_optional, stride_per_key_per_rank
1042+
)
1043+
self.assertEqual(kjt_2.stride_per_key_per_rank(), stride_per_key_per_rank)
1044+
self.assertEqual(
1045+
kjt_2._stride_per_key_per_rank_optional, stride_per_key_per_rank
1046+
)
10311047

10321048

10331049
class TestKeyedJaggedTensorScripting(unittest.TestCase):

0 commit comments

Comments
 (0)