Skip to content

Commit 5fd5092

Browse files
committed
[Test] Str2StrEnv test
ghstack-source-id: 45a0e5f Pull Request resolved: #2725
1 parent cac93eb commit 5fd5092

File tree

3 files changed

+141
-2
lines changed

3 files changed

+141
-2
lines changed

test/mocking_classes.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,6 +1068,11 @@ def _step(
10681068
return tensordict
10691069

10701070

1071+
def get_random_string(min_size, max_size):
    """Return a random string of lowercase ASCII letters.

    The length is drawn with ``random.randint`` and therefore lies in
    ``[min_size, max_size]``, both bounds included. Every character is then
    drawn independently with ``random.choice``.

    Args:
        min_size: smallest length the string may have.
        max_size: largest length the string may have.

    Returns:
        The generated string.
    """
    length = random.randint(min_size, max_size)
    alphabet = string.ascii_lowercase
    # One ``random.choice`` call per character, in order.
    letters = [random.choice(alphabet) for _ in range(length)]
    return "".join(letters)
1074+
1075+
10711076
class CountingEnvWithString(CountingEnv):
10721077
def __init__(self, *args, **kwargs):
10731078
self.max_size = kwargs.pop("max_size", 30)
@@ -1083,8 +1088,7 @@ def __init__(self, *args, **kwargs):
10831088
)
10841089

10851090
def get_random_string(self):
1086-
size = random.randint(self.min_size, self.max_size)
1087-
return "".join(random.choice(string.ascii_lowercase) for _ in range(size))
1091+
return get_random_string(self.min_size, self.max_size)
10881092

10891093
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
10901094
res = super()._reset(tensordict, **kwargs)
@@ -2202,3 +2206,39 @@ def _step(
22022206

22032207
def _set_seed(self, seed):
22042208
...
2209+
2210+
2211+
class Str2StrEnv(EnvBase):
    """Mock environment that takes string actions and emits random string observations.

    Observations are produced by the module-level ``get_random_string`` helper.
    The ``done`` and ``reward`` entries are single-element tensors sampled from
    a Bernoulli distribution with probability ``0.01``.

    Args:
        min_size: minimum length of the generated observation strings.
            Defaults to ``4``.
        max_size: maximum length of the generated observation strings.
            Defaults to ``10``.
        **kwargs: forwarded to the parent class constructor.
    """

    def __init__(self, min_size=4, max_size=10, **kwargs):
        # Both the observation and the action are non-tensor (string) entries.
        self.observation_spec = Composite(
            observation=NonTensor(example_data="an observation!", shape=())
        )
        self.full_action_spec = Composite(
            action=NonTensor(example_data="an action!", shape=())
        )
        self.reward_spec = Unbounded(shape=(1,), dtype=torch.float)
        self.min_size = min_size
        self.max_size = max_size
        super().__init__(**kwargs)

    def _step(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
        """Check that the action is a string and return a freshly sampled transition."""
        assert isinstance(tensordict["action"], str)
        out = tensordict.empty()
        out.set("observation", self.get_random_string())
        out.set("done", torch.zeros(1, dtype=torch.bool).bernoulli_(0.01))
        out.set("reward", torch.zeros(1, dtype=torch.float).bernoulli_(0.01))
        return out

    def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
        """Return an initial random observation along with a sampled ``done`` flag."""
        out = tensordict.empty() if tensordict is not None else TensorDict()
        out.set("observation", self.get_random_string())
        # NOTE(review): ``done`` can be sampled as True right after a reset;
        # confirm that this is intended for this mock.
        out.set("done", torch.zeros(1, dtype=torch.bool).bernoulli_(0.01))
        return out

    def get_random_string(self):
        """Return a random string whose length lies between ``min_size`` and ``max_size``."""
        return get_random_string(self.min_size, self.max_size)

    def _set_seed(self, seed: Optional[int]):
        """Seed the two random generators this environment draws from.

        ``random`` drives the observation strings and ``torch`` drives the
        ``done``/``reward`` draws, so both are seeded with the same value.

        Args:
            seed: the seed to use. With ``None``, ``random`` is re-seeded from
                its default entropy source and the torch generator is left
                untouched, since ``torch.manual_seed`` requires an integer.

        Returns:
            The seed that was passed in.
        """
        random.seed(seed)
        if seed is not None:
            # Previously hard-coded to 0, which made the torch-driven draws
            # identical for every seed.
            torch.manual_seed(seed)
        return seed

test/test_env.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
MultiKeyCountingEnv,
5858
MultiKeyCountingEnvPolicy,
5959
NestedCountingEnv,
60+
Str2StrEnv,
6061
)
6162
else:
6263
from _utils_internal import (
@@ -95,6 +96,7 @@
9596
MultiKeyCountingEnv,
9697
MultiKeyCountingEnvPolicy,
9798
NestedCountingEnv,
99+
Str2StrEnv,
98100
)
99101
from packaging import version
100102
from tensordict import (
@@ -133,6 +135,7 @@
133135
AutoResetTransform,
134136
Tokenizer,
135137
Transform,
138+
UnsqueezeTransform,
136139
)
137140
from torchrl.envs.utils import (
138141
_StepMDP,
@@ -174,6 +177,7 @@
174177
_has_chess = importlib.util.find_spec("chess") is not None
175178
_has_tv = importlib.util.find_spec("torchvision") is not None
176179
_has_cairosvg = importlib.util.find_spec("cairosvg") is not None
180+
_has_transformers = importlib.util.find_spec("transformers") is not None
177181
## TO BE FIXED: DiscreteActionProjection queries a randint on each worker, which leads to divergent results between
178182
## the serial and parallel batched envs
179183
# def _make_atari_env(atari_env):
@@ -2614,6 +2618,7 @@ def test_parallel(
26142618
NestedCountingEnv,
26152619
HeterogeneousCountingEnv,
26162620
MultiKeyCountingEnv,
2621+
Str2StrEnv,
26172622
],
26182623
)
26192624
def test_mocking_envs(envclass):
@@ -3441,6 +3446,96 @@ def test_partial_rest(self, batched):
34413446
assert s_["string"] == ["0", "6"]
34423447
assert s["next", "string"] == ["6", "6"]
34433448

3449+
@pytest.mark.skipif(not _has_transformers, reason="transformers required")
def test_str2str_env_tokenizer(self):
    """Checks specs and rollout entry types of a ``Str2StrEnv`` wrapped with a ``Tokenizer``."""
    base_env = Str2StrEnv()
    base_env.set_seed(0)
    tokenizer = Tokenizer(
        in_keys=["observation"],
        out_keys=["obs_tokens"],
        in_keys_inv=["action"],
        out_keys_inv=["action_tokens"],
    )
    env = base_env.append_transform(tokenizer)
    env.check_env_specs()
    assert env._has_dynamic_specs

    # Non-contiguous rollout: the raw observations are returned as a list.
    rollout = env.rollout(3, return_contiguous=False)
    assert len(rollout) == 3
    assert isinstance(rollout["observation"], list)

    # After densifying with a jagged layout, the token entries are tensors.
    dense = rollout.densify(layout=torch.jagged)
    assert isinstance(dense["observation"], list)
    assert isinstance(dense["obs_tokens"], torch.Tensor)
    assert isinstance(dense["action_tokens"], torch.Tensor)
3470+
3471+
@pytest.mark.skipif(not _has_transformers, reason="transformers required")
def test_str2str_env_tokenizer_catframes(self):
    """Tests that we can use Unsqueeze + CatFrames with tokenized strings of variable lengths."""
    base_env = Str2StrEnv()
    base_env.set_seed(0)

    # We must use max_length otherwise we can't call cat
    # Perhaps we could use NJT here?
    tokenizer = Tokenizer(
        in_keys=["observation"],
        out_keys=["obs_tokens"],
        in_keys_inv=["action"],
        out_keys_inv=["action_tokens"],
        max_length=10,
    )
    env = base_env.append_transform(tokenizer)

    # Insert the dimension along which CatFrames stacks the frames.
    unsqueeze = UnsqueezeTransform(
        dim=-2, in_keys=["obs_tokens"], out_keys=["obs_tokens_cat"]
    )
    env = env.append_transform(unsqueeze)

    cat_frames = CatFrames(N=4, dim=-2, in_keys=["obs_tokens_cat"])
    env = env.append_transform(cat_frames)

    rollout = env.rollout(3)
    # 3 steps, 4 stacked frames, 10 tokens per frame.
    assert rollout["obs_tokens_cat"].shape == (3, 4, 10)
3495+
3496+
@pytest.mark.skipif(not _has_transformers, reason="transformers required")
def test_str2str_rb_slicesampler(self):
    """Dedicated test for replay buffer sampling of trajectories with variable token length"""
    from torchrl.data import LazyStackStorage, ReplayBuffer, SliceSampler
    from torchrl.envs import TrajCounter

    env = Str2StrEnv()
    env.set_seed(0)
    env = env.append_transform(
        Tokenizer(
            in_keys=["observation"],
            out_keys=["obs_tokens"],
            in_keys_inv=["action"],
            out_keys_inv=["action_tokens"],
        )
    )
    # StepCounter caps the step count at 10; TrajCounter adds the
    # "traj_count" entry read below.
    env = env.append_transform(StepCounter(max_steps=10))
    env = env.append_transform(TrajCounter())
    rb = ReplayBuffer(
        storage=LazyStackStorage(100),
        sampler=SliceSampler(slice_len=10, end_key=("next", "done")),
    )
    r0 = env.rollout(20, break_when_any_done=False)
    rb.extend(r0)
    has_0 = False
    has_1 = False
    for _ in range(100):
        v0 = rb.sample(10)
        # Each sampled slice must be a full trajectory, counted 0..9.
        assert (v0["step_count"].squeeze() == torch.arange(10)).all()
        assert (v0["next", "step_count"].squeeze() == torch.arange(1, 11)).all()
        try:
            # ``.item()`` fails if the slice mixes several trajectory ids.
            traj = v0["traj_count"].unique().item()
        except Exception as err:
            # Chain the original error so the underlying cause stays visible.
            raise RuntimeError(
                f"More than one traj found in single slice: {v0['traj_count']}"
            ) from err
        has_0 |= traj == 0
        has_1 |= traj == 1
        if has_0 and has_1:
            break
    else:
        # The loop ended without seeing both trajectories.
        raise RuntimeError("Failed to sample both trajs")
3538+
34443539

34453540
# fen strings for board positions generated with:
34463541
# https://lichess.org/editor
@@ -3676,6 +3771,7 @@ def test_reward(
36763771
assert td["reward"] == expected_reward
36773772
assert td["turn"] == (not expected_turn)
36783773

3774+
@pytest.mark.skipif(not _has_transformers, reason="transformers required")
36793775
def test_chess_tokenized(self):
36803776
env = ChessEnv(include_fen=True, stateful=True, include_san=True)
36813777
assert isinstance(env.observation_spec["fen"], NonTensor)

torchrl/data/tensor_specs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2450,6 +2450,9 @@ class NonTensor(TensorSpec):
24502450
24512451
:meth:`.rand` will return a :class:`~tensordict.NonTensorData` object with `None` data value.
24522452
(same will go for :meth:`.zero` and :meth:`.one`).
2453+
2454+
.. note:: The default shape of `NonTensor` is `(1,)`.
2455+
24532456
"""
24542457

24552458
example_data: Any = None

0 commit comments

Comments
 (0)