Skip to content

Commit d08e488

Browse files
committed
[Feature] IsaacLab wrapper
ghstack-source-id: 1794a63 Pull-Request-resolved: #2937
1 parent 3dbd84c commit d08e488

File tree

8 files changed

+173
-38
lines changed

8 files changed

+173
-38
lines changed

docs/source/reference/envs.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1417,6 +1417,7 @@ the following function will return ``1`` when queried:
14171417
HabitatEnv
14181418
IsaacGymEnv
14191419
IsaacGymWrapper
1420+
IsaacLabWrapper
14201421
JumanjiEnv
14211422
JumanjiWrapper
14221423
MeltingpotEnv

test/test_libs.py

Lines changed: 61 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -32,32 +32,6 @@
3232
import pytest
3333
import torch
3434

35-
if os.getenv("PYTORCH_TEST_FBCODE"):
36-
from pytorch.rl.test._utils_internal import (
37-
_make_multithreaded_env,
38-
CARTPOLE_VERSIONED,
39-
get_available_devices,
40-
get_default_devices,
41-
HALFCHEETAH_VERSIONED,
42-
PENDULUM_VERSIONED,
43-
PONG_VERSIONED,
44-
rand_reset,
45-
retry,
46-
rollout_consistency_assertion,
47-
)
48-
else:
49-
from _utils_internal import (
50-
_make_multithreaded_env,
51-
CARTPOLE_VERSIONED,
52-
get_available_devices,
53-
get_default_devices,
54-
HALFCHEETAH_VERSIONED,
55-
PENDULUM_VERSIONED,
56-
PONG_VERSIONED,
57-
rand_reset,
58-
retry,
59-
rollout_consistency_assertion,
60-
)
6135
from packaging import version
6236
from tensordict import (
6337
assert_allclose_td,
@@ -155,6 +129,33 @@
155129
ValueOperator,
156130
)
157131

132+
if os.getenv("PYTORCH_TEST_FBCODE"):
133+
from pytorch.rl.test._utils_internal import (
134+
_make_multithreaded_env,
135+
CARTPOLE_VERSIONED,
136+
get_available_devices,
137+
get_default_devices,
138+
HALFCHEETAH_VERSIONED,
139+
PENDULUM_VERSIONED,
140+
PONG_VERSIONED,
141+
rand_reset,
142+
retry,
143+
rollout_consistency_assertion,
144+
)
145+
else:
146+
from _utils_internal import (
147+
_make_multithreaded_env,
148+
CARTPOLE_VERSIONED,
149+
get_available_devices,
150+
get_default_devices,
151+
HALFCHEETAH_VERSIONED,
152+
PENDULUM_VERSIONED,
153+
PONG_VERSIONED,
154+
rand_reset,
155+
retry,
156+
rollout_consistency_assertion,
157+
)
158+
158159
_has_d4rl = importlib.util.find_spec("d4rl") is not None
159160

160161
_has_mo = importlib.util.find_spec("mo_gymnasium") is not None
@@ -166,6 +167,9 @@
166167
_has_minari = importlib.util.find_spec("minari") is not None
167168

168169
_has_gymnasium = importlib.util.find_spec("gymnasium") is not None
170+
171+
_has_isaaclab = importlib.util.find_spec("isaaclab") is not None
172+
169173
_has_gym_regular = importlib.util.find_spec("gym") is not None
170174
if _has_gymnasium:
171175
set_gym_backend("gymnasium").set()
@@ -4541,6 +4545,37 @@ def test_render(self, rollout_steps):
45414545
assert not torch.equal(rollout_penultimate_image, image_from_env)
45424546

45434547

4548+
@pytest.mark.skipif(not _has_isaaclab, reason="Isaaclab not found")
class TestIsaacLab:
    """Smoke tests for ``IsaacLabWrapper`` on the ``Isaac-Ant-v0`` task."""

    @staticmethod
    def _make_wrapped_ant():
        """Create ``Isaac-Ant-v0`` with a default ``AntEnvCfg`` and wrap it.

        The imports are local so that collecting this module does not require
        IsaacLab to be installed (the class is skipped in that case).
        """
        import gymnasium as gym

        # Imported for its side effects only (hence the F401 waiver);
        # presumably this is what registers the "Isaac-*" ids -- TODO confirm.
        import isaaclab_tasks  # noqa: F401
        from isaaclab_tasks.manager_based.classic.ant.ant_env_cfg import AntEnvCfg
        from torchrl.envs.libs.isaac_lab import IsaacLabWrapper

        raw_env = gym.make("Isaac-Ant-v0", cfg=AntEnvCfg())
        return IsaacLabWrapper(raw_env)

    def test_isaaclab(self):
        """The wrapped env reports a batch size of 4096 and passes the spec check."""
        wrapped = self._make_wrapped_ant()
        assert wrapped.batch_size == (4096,)
        wrapped.check_env_specs(break_when_any_done="both")

    def test_isaac_collector(self):
        """A ``SyncDataCollector`` driven by random actions yields a first batch."""
        wrapped = self._make_wrapped_ant()
        collector = SyncDataCollector(
            wrapped,
            wrapped.rand_action,
            frames_per_batch=1000,
            total_frames=100_000_000,
        )
        # Only the first batch is pulled; the frame budget is never exhausted.
        for _batch in collector:
            break

    def test_isaaclab_reset(self):
        """Placeholder: no reset behavior is exercised yet."""
        ...
4578+
45444579
if __name__ == "__main__":
45454580
args, unknown = argparse.ArgumentParser().parse_known_args()
45464581
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/envs/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
HabitatEnv,
2121
IsaacGymEnv,
2222
IsaacGymWrapper,
23+
IsaacLabWrapper,
2324
JumanjiEnv,
2425
JumanjiWrapper,
2526
MeltingpotEnv,
@@ -131,6 +132,7 @@
131132
"ActionDiscretizer",
132133
"ActionMask",
133134
"VecNormV2",
135+
"IsaacLabWrapper",
134136
"AutoResetEnv",
135137
"AutoResetTransform",
136138
"AsyncEnvPool",

torchrl/envs/gym_like.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,8 @@ def validated(self, value):
515515
def _reset(
516516
self, tensordict: TensorDictBase | None = None, **kwargs
517517
) -> TensorDictBase:
518+
if tensordict is not None and "_reset" in tensordict and not tensordict["_reset"].all():
519+
raise RuntimeError("Partial resets are not handled at this level.")
518520
obs, info = self._reset_output_transform(self._env.reset(**kwargs))
519521

520522
source = self.read_obs(obs)

torchrl/envs/libs/__init__.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
set_gym_backend,
1717
)
1818
from .habitat import HabitatEnv
19+
from .isaac_lab import IsaacLabWrapper
1920
from .isaacgym import IsaacGymEnv, IsaacGymWrapper
2021
from .jumanji import JumanjiEnv, JumanjiWrapper
2122
from .meltingpot import MeltingpotEnv, MeltingpotWrapper
@@ -32,22 +33,20 @@
3233
"BraxWrapper",
3334
"DMControlEnv",
3435
"DMControlWrapper",
35-
"MultiThreadedEnv",
36-
"MultiThreadedEnvWrapper",
37-
"gym_backend",
3836
"GymEnv",
3937
"GymWrapper",
40-
"MOGymEnv",
41-
"MOGymWrapper",
42-
"register_gym_spec_conversion",
43-
"set_gym_backend",
4438
"HabitatEnv",
4539
"IsaacGymEnv",
4640
"IsaacGymWrapper",
41+
"IsaacLabWrapper",
4742
"JumanjiEnv",
4843
"JumanjiWrapper",
44+
"MOGymEnv",
45+
"MOGymWrapper",
4946
"MeltingpotEnv",
5047
"MeltingpotWrapper",
48+
"MultiThreadedEnv",
49+
"MultiThreadedEnvWrapper",
5150
"OpenMLEnv",
5251
"OpenSpielEnv",
5352
"OpenSpielWrapper",
@@ -60,4 +59,7 @@
6059
"UnityMLAgentsWrapper",
6160
"VmasEnv",
6261
"VmasWrapper",
62+
"gym_backend",
63+
"register_gym_spec_conversion",
64+
"set_gym_backend",
6365
]

torchrl/envs/libs/gym.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353

5454
_has_mo = importlib.util.find_spec("mo_gymnasium") is not None
5555
_has_sb3 = importlib.util.find_spec("stable_baselines3") is not None
56+
_has_isaaclab = importlib.util.find_spec("isaaclab") is not None
5657
_has_minigrid = importlib.util.find_spec("minigrid") is not None
5758

5859

@@ -803,6 +804,11 @@ def __call__(cls, *args, **kwargs):
803804
VecGymEnvTransform,
804805
)
805806

807+
if _has_isaaclab:
808+
from isaaclab.envs import ManagerBasedRLEnv
809+
if isinstance(instance._env.unwrapped, ManagerBasedRLEnv):
810+
return TransformedEnv(instance, VecGymEnvTransform())
811+
806812
if _has_sb3:
807813
from stable_baselines3.common.vec_env.base_vec_env import VecEnv
808814

@@ -1069,12 +1075,15 @@ def _post_init(self):
10691075

10701076
@property
10711077
def _is_batched(self):
1078+
tuple_of_classes = ()
10721079
if _has_sb3:
10731080
from stable_baselines3.common.vec_env.base_vec_env import VecEnv
10741081

1075-
tuple_of_classes = (VecEnv,)
1076-
else:
1077-
tuple_of_classes = ()
1082+
tuple_of_classes = tuple_of_classes + (VecEnv,)
1083+
if _has_isaaclab:
1084+
from isaaclab.envs import ManagerBasedRLEnv
1085+
1086+
tuple_of_classes = tuple_of_classes + (ManagerBasedRLEnv,)
10781087
return isinstance(
10791088
self._env, tuple_of_classes + (gym_backend("vector").VectorEnv,)
10801089
)

torchrl/envs/libs/isaac_lab.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
import torch
6+
from torchrl.envs.libs.gym import GymWrapper
7+
8+
9+
class IsaacLabWrapper(GymWrapper):
    """Wrapper for an already-constructed IsaacLab environment.

    This is a :class:`GymWrapper` whose constructor defaults are the values
    needed for IsaacLab, plus a step-output hook that copies the termination
    flags before they are handed on.

    Args:
        env: the IsaacLab environment instance to wrap (annotated as
            ``ManagerBasedRLEnv``).
        categorical_action_encoding: forwarded to :class:`GymWrapper`.
            Defaults to ``False``.
        allow_done_after_reset: forwarded to :class:`GymWrapper`.
            Defaults to ``True``.
        convert_actions_to_numpy: forwarded to :class:`GymWrapper`.
            Defaults to ``False``.
        device: forwarded to :class:`GymWrapper`.
            Defaults to ``torch.device("cuda:0")``.
        **kwargs: any other keyword argument, forwarded unchanged to
            :class:`GymWrapper`.
    """

    def __init__(
        self,
        env: "ManagerBasedRLEnv",  # noqa: F821
        categorical_action_encoding=False,
        allow_done_after_reset=True,
        convert_actions_to_numpy=False,
        device=torch.device("cuda:0"),
        **kwargs,
    ):
        # Nothing IsaacLab-specific happens here beyond the default values in
        # the signature: everything is passed straight to the parent class.
        super().__init__(
            env,
            device=device,
            categorical_action_encoding=categorical_action_encoding,
            allow_done_after_reset=allow_done_after_reset,
            convert_actions_to_numpy=convert_actions_to_numpy,
            **kwargs,
        )

    # The annotation is quoted on purpose: an unquoted ``int | None`` is
    # evaluated when the class body runs and raises ``TypeError`` on
    # Python < 3.10 unless ``from __future__ import annotations`` is active.
    def seed(self, seed: "int | None"):
        """Seed the environment by delegating to ``self._set_seed``.

        Args:
            seed: the seed value, or ``None``; passed through unchanged.
        """
        self._set_seed(seed)

    def _output_transform(self, step_outputs_tuple):  # noqa: F811
        """Convert a 5-item step output into the 6-item form with ``done``.

        Args:
            step_outputs_tuple: ``(observations, reward, terminated,
                truncated, info)`` as returned by the wrapped env's step.

        Returns:
            ``(observations, reward, terminated, truncated, done, info)``
            where ``done = terminated | truncated``, ``reward`` has a trailing
            singleton dimension added, and ``terminated`` / ``truncated`` are
            copies of the inputs.
        """
        # IsaacLab will modify the `terminated` and `truncated` tensors
        # in-place. We clone them here to make sure data doesn't inadvertently
        # get modified. The variable naming follows torchrl's convention here.
        observations, reward, terminated, truncated, info = step_outputs_tuple
        # `|` allocates a new tensor, so `done` shares no storage with the
        # env-owned buffers and needs no additional clone.
        done = terminated | truncated
        reward = reward.unsqueeze(-1)  # to get to (num_envs, 1)
        return (
            observations,
            reward,
            terminated.clone(),
            truncated.clone(),
            done,
            info,
        )
49+
50+
51+
# Manual smoke test: run a random-action collector on Isaac-Ant-v0 through
# IsaacLabWrapper and log the collector's frame counter to Weights & Biases.
if __name__ == "__main__":
    import argparse

    from isaaclab.app import AppLauncher

    # `IsaacLabWrapper` is defined above in this very module, so it is used
    # directly instead of being re-imported under its package path.
    parser = argparse.ArgumentParser(
        description="Run a random-action collector on Isaac-Ant-v0 "
        "through IsaacLabWrapper."
    )
    AppLauncher.add_app_launcher_args(parser)
    # Unrecognized arguments are tolerated and ignored.
    args_cli, _unused_args = parser.parse_known_args()

    # The launcher is created before gymnasium / isaaclab_tasks are imported,
    # preserving the original ordering (presumably the app has to be running
    # before those imports -- TODO confirm against the IsaacLab docs).
    app_launcher = AppLauncher(args_cli)

    import gymnasium as gym
    import isaaclab_tasks  # noqa: F401
    from isaaclab_tasks.manager_based.classic.ant.ant_env_cfg import AntEnvCfg

    env = gym.make("Isaac-Ant-v0", cfg=AntEnvCfg())
    env = IsaacLabWrapper(env)

    import tqdm
    from torchrl.collectors import SyncDataCollector
    from torchrl.record.loggers.wandb import WandbLogger

    logger = WandbLogger(exp_name="test_isaac")
    col = SyncDataCollector(
        env, env.rand_action, frames_per_batch=1000, total_frames=100_000_000
    )
    # The batches themselves are discarded; only the frame counter is logged.
    for _ in tqdm.tqdm(col):
        logger.log_scalar("frames", col._frames)

torchrl/envs/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import re
1515
import warnings
1616
from enum import Enum
17-
from typing import Any
17+
from typing import Any, Literal
1818

1919
import torch
2020

@@ -687,7 +687,7 @@ def check_env_specs(
687687
check_dtype=True,
688688
seed: int | None = None,
689689
tensordict: TensorDictBase | None = None,
690-
break_when_any_done: bool | str = None,
690+
break_when_any_done: bool | Literal["both"] = None,
691691
):
692692
"""Tests an environment specs against the results of short rollout.
693693

0 commit comments

Comments
 (0)