|
57 | 57 | MultiKeyCountingEnv,
|
58 | 58 | MultiKeyCountingEnvPolicy,
|
59 | 59 | NestedCountingEnv,
|
| 60 | + Str2StrEnv, |
60 | 61 | )
|
61 | 62 | else:
|
62 | 63 | from _utils_internal import (
|
|
95 | 96 | MultiKeyCountingEnv,
|
96 | 97 | MultiKeyCountingEnvPolicy,
|
97 | 98 | NestedCountingEnv,
|
| 99 | + Str2StrEnv, |
98 | 100 | )
|
99 | 101 | from packaging import version
|
100 | 102 | from tensordict import (
|
|
133 | 135 | AutoResetTransform,
|
134 | 136 | Tokenizer,
|
135 | 137 | Transform,
|
| 138 | + UnsqueezeTransform, |
136 | 139 | )
|
137 | 140 | from torchrl.envs.utils import (
|
138 | 141 | _StepMDP,
|
|
174 | 177 | _has_chess = importlib.util.find_spec("chess") is not None
|
175 | 178 | _has_tv = importlib.util.find_spec("torchvision") is not None
|
176 | 179 | _has_cairosvg = importlib.util.find_spec("cairosvg") is not None
|
| 180 | +_has_transformers = importlib.util.find_spec("transformers") is not None |
177 | 181 | ## TO BE FIXED: DiscreteActionProjection queries a randint on each worker, which leads to divergent results between
|
178 | 182 | ## the serial and parallel batched envs
|
179 | 183 | # def _make_atari_env(atari_env):
|
@@ -2614,6 +2618,7 @@ def test_parallel(
|
2614 | 2618 | NestedCountingEnv,
|
2615 | 2619 | HeterogeneousCountingEnv,
|
2616 | 2620 | MultiKeyCountingEnv,
|
| 2621 | + Str2StrEnv, |
2617 | 2622 | ],
|
2618 | 2623 | )
|
2619 | 2624 | def test_mocking_envs(envclass):
|
@@ -3441,6 +3446,96 @@ def test_partial_rest(self, batched):
|
3441 | 3446 | assert s_["string"] == ["0", "6"]
|
3442 | 3447 | assert s["next", "string"] == ["6", "6"]
|
3443 | 3448 |
|
@pytest.mark.skipif(not _has_transformers, reason="transformers required")
def test_str2str_env_tokenizer(self):
    """A Tokenizer appended to a string env yields tensor token keys.

    Checks that the env reports dynamic specs, that a non-contiguous
    rollout keeps string observations as lists, and that densifying with a
    jagged layout turns the token entries into tensors while the raw
    string entries remain lists.
    """
    base_env = Str2StrEnv()
    base_env.set_seed(0)
    tokenizer = Tokenizer(
        in_keys=["observation"],
        out_keys=["obs_tokens"],
        in_keys_inv=["action"],
        out_keys_inv=["action_tokens"],
    )
    env = base_env.append_transform(tokenizer)
    env.check_env_specs()
    # Variable-length token sequences imply dynamic (non-static) specs.
    assert env._has_dynamic_specs
    rollout = env.rollout(3, return_contiguous=False)
    assert len(rollout) == 3
    assert isinstance(rollout["observation"], list)
    rollout = rollout.densify(layout=torch.jagged)
    # Densification leaves string data as lists but stacks tokens as tensors.
    assert isinstance(rollout["observation"], list)
    assert isinstance(rollout["obs_tokens"], torch.Tensor)
    assert isinstance(rollout["action_tokens"], torch.Tensor)
| 3470 | + |
@pytest.mark.skipif(not _has_transformers, reason="transformers required")
def test_str2str_env_tokenizer_catframes(self):
    """Tests that we can use Unsqueeze + CatFrames with tokenized strings of variable lengths."""
    base_env = Str2StrEnv()
    base_env.set_seed(0)
    # We must use max_length otherwise we can't call cat.
    # Perhaps we could use NJT here?
    tokenizer = Tokenizer(
        in_keys=["observation"],
        out_keys=["obs_tokens"],
        in_keys_inv=["action"],
        out_keys_inv=["action_tokens"],
        max_length=10,
    )
    # Add a frame dim (-2) so CatFrames can stack the last N observations.
    unsqueeze = UnsqueezeTransform(
        dim=-2, in_keys=["obs_tokens"], out_keys=["obs_tokens_cat"]
    )
    env = base_env.append_transform(tokenizer)
    env = env.append_transform(unsqueeze)
    env = env.append_transform(CatFrames(N=4, dim=-2, in_keys=["obs_tokens_cat"]))
    rollout = env.rollout(3)
    # 3 steps, 4 stacked frames, 10 (padded) tokens each.
    assert rollout["obs_tokens_cat"].shape == (3, 4, 10)
| 3495 | + |
@pytest.mark.skipif(not _has_transformers, reason="transformers required")
def test_str2str_rb_slicesampler(self):
    """Dedicated test for replay buffer sampling of trajectories with variable token length.

    Rolls out two back-to-back trajectories (StepCounter caps each at 10
    steps), stores them in a LazyStackStorage-backed buffer, and checks
    that SliceSampler only ever returns slices drawn from a single
    trajectory, eventually sampling both of them.
    """
    from torchrl.data import LazyStackStorage, ReplayBuffer, SliceSampler
    from torchrl.envs import TrajCounter

    env = Str2StrEnv()
    env.set_seed(0)
    env = env.append_transform(
        Tokenizer(
            in_keys=["observation"],
            out_keys=["obs_tokens"],
            in_keys_inv=["action"],
            out_keys_inv=["action_tokens"],
        )
    )
    env = env.append_transform(StepCounter(max_steps=10))
    env = env.append_transform(TrajCounter())
    rb = ReplayBuffer(
        storage=LazyStackStorage(100),
        sampler=SliceSampler(slice_len=10, end_key=("next", "done")),
    )
    r0 = env.rollout(20, break_when_any_done=False)
    rb.extend(r0)
    has_0 = False
    has_1 = False
    for _ in range(100):
        v0 = rb.sample(10)
        # Each slice must be a contiguous 10-step window of one trajectory.
        assert (v0["step_count"].squeeze() == torch.arange(10)).all()
        assert (v0["next", "step_count"].squeeze() == torch.arange(1, 11)).all()
        try:
            # .item() raises when the slice mixes more than one traj id.
            traj = v0["traj_count"].unique().item()
        except Exception as err:
            # Chain the original error so the real failure stays visible
            # in the traceback (avoids masking the cause, cf. B904).
            raise RuntimeError(
                f"More than one traj found in single slice: {v0['traj_count']}"
            ) from err
        has_0 |= traj == 0
        has_1 |= traj == 1
        if has_0 and has_1:
            break
    else:
        raise RuntimeError("Failed to sample both trajs")
| 3538 | + |
3444 | 3539 |
|
3445 | 3540 | # fen strings for board positions generated with:
|
3446 | 3541 | # https://lichess.org/editor
|
@@ -3676,6 +3771,7 @@ def test_reward(
|
3676 | 3771 | assert td["reward"] == expected_reward
|
3677 | 3772 | assert td["turn"] == (not expected_turn)
|
3678 | 3773 |
|
| 3774 | + @pytest.mark.skipif(not _has_transformers, reason="transformers required") |
3679 | 3775 | def test_chess_tokenized(self):
|
3680 | 3776 | env = ChessEnv(include_fen=True, stateful=True, include_san=True)
|
3681 | 3777 | assert isinstance(env.observation_spec["fen"], NonTensor)
|
|
0 commit comments