Skip to content

Commit f0f6e73

Browse files
committed
[Feature] IsaacLab wrapper
ghstack-source-id: f99fab8 Pull-Request-resolved: #2937
1 parent 3dbd84c commit f0f6e73

File tree

11 files changed

+287
-44
lines changed

11 files changed

+287
-44
lines changed

.github/unittest/linux_libs/scripts_gym/setup_env.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ set -e
1010
this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
1111
# Avoid error: "fatal: unsafe repository"
1212
apt-get update && apt-get install -y git wget gcc g++
13-
1413
apt-get install -y libglfw3 libgl1-mesa-glx libosmesa6 libglew-dev libsdl2-dev libsdl2-2.0-0
1514
apt-get install -y libglvnd0 libgl1 libglx0 libegl1 libgles2 xvfb libegl-dev libx11-dev freeglut3-dev
1615

.github/workflows/test-linux-libs.yml

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,93 @@ jobs:
230230
./.github/unittest/linux_libs/scripts_gym/batch_scripts.sh
231231
./.github/unittest/linux_libs/scripts_gym/post_process.sh
232232
233+
unittests-isaaclab:
234+
strategy:
235+
matrix:
236+
python_version: ["3.10"]
237+
cuda_arch_version: ["12.8"]
238+
if: ${{ github.event_name == 'push' }} # || contains(github.event.pull_request.labels.*.name, 'Environments') }}
239+
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
240+
with:
241+
repository: pytorch/rl
242+
runner: "linux.g5.4xlarge.nvidia.gpu"
243+
gpu-arch-type: cuda
244+
gpu-arch-version: "12.8"
245+
docker-image: "nvidia/cuda:12.4.0-devel-ubuntu22.04"
246+
timeout: 120
247+
script: |
248+
if [[ "${{ github.ref }}" =~ release/* ]]; then
249+
export RELEASE=1
250+
export TORCH_VERSION=stable
251+
else
252+
export RELEASE=0
253+
export TORCH_VERSION=nightly
254+
fi
255+
256+
set -euo pipefail
257+
export PYTHON_VERSION="3.10"
258+
export CU_VERSION="12.8"
259+
export TAR_OPTIONS="--no-same-owner"
260+
export UPLOAD_CHANNEL="nightly"
261+
export TF_CPP_MIN_LOG_LEVEL=0
262+
export BATCHED_PIPE_TIMEOUT=60
263+
export TD_GET_DEFAULTS_TO_NONE=1
264+
265+
nvidia-smi
266+
267+
# Setup
268+
apt-get update && apt-get install -y git wget gcc g++
269+
apt-get install -y libglfw3 libgl1-mesa-glx libosmesa6 libglew-dev libsdl2-dev libsdl2-2.0-0
270+
apt-get install -y libglvnd0 libgl1 libglx0 libegl1 libgles2 xvfb libegl-dev libx11-dev freeglut3-dev
271+
272+
git config --global --add safe.directory '*'
273+
root_dir="$(git rev-parse --show-toplevel)"
274+
conda_dir="${root_dir}/conda"
275+
env_dir="${root_dir}/env"
276+
lib_dir="${env_dir}/lib"
277+
278+
cd "${root_dir}"
279+
280+
# install conda
281+
printf "* Installing conda\n"
282+
wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh"
283+
bash ./miniconda.sh -b -f -p "${conda_dir}"
284+
eval "$(${conda_dir}/bin/conda shell.bash hook)"
285+
286+
287+
conda create -n env_isaaclab python=3.10 -y
288+
conda activate env_isaaclab
289+
pip install --upgrade pip
290+
pip install 'isaacsim[all,extscache]==4.5.0' --extra-index-url https://pypi.nvidia.com
291+
git clone [email protected]:isaac-sim/IsaacLab.git
292+
conda install "conda-forge::cmake>3.22" -y
293+
cd IsaacLab
294+
./isaaclab.sh --install
295+
cd ../
296+
297+
# install tensordict
298+
if [[ "$RELEASE" == 0 ]]; then
299+
conda install "anaconda::cmake>=3.22" -y
300+
pip3 install "pybind11[global]"
301+
pip3 install git+https://github.com/pytorch/tensordict.git
302+
else
303+
pip3 install tensordict
304+
fi
305+
306+
# smoke test
307+
python -c "import tensordict"
308+
309+
printf "* Installing torchrl\n"
310+
python setup.py develop
311+
python -c "import torchrl"
312+
313+
# Install pytest
314+
pip install pytest pytest-cov pytest-mock pytest-instafail pytest-rerunfailures pytest-error-for-skips pytest-asyncio
315+
316+
# Run tests
317+
pytest test/test_libs.py -k isaac
318+
319+
233320
unittests-jumanji:
234321
strategy:
235322
matrix:

docs/source/reference/envs.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1417,6 +1417,7 @@ the following function will return ``1`` when queried:
14171417
HabitatEnv
14181418
IsaacGymEnv
14191419
IsaacGymWrapper
1420+
IsaacLabWrapper
14201421
JumanjiEnv
14211422
JumanjiWrapper
14221423
MeltingpotEnv

test/test_libs.py

Lines changed: 62 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -32,32 +32,6 @@
3232
import pytest
3333
import torch
3434

35-
if os.getenv("PYTORCH_TEST_FBCODE"):
36-
from pytorch.rl.test._utils_internal import (
37-
_make_multithreaded_env,
38-
CARTPOLE_VERSIONED,
39-
get_available_devices,
40-
get_default_devices,
41-
HALFCHEETAH_VERSIONED,
42-
PENDULUM_VERSIONED,
43-
PONG_VERSIONED,
44-
rand_reset,
45-
retry,
46-
rollout_consistency_assertion,
47-
)
48-
else:
49-
from _utils_internal import (
50-
_make_multithreaded_env,
51-
CARTPOLE_VERSIONED,
52-
get_available_devices,
53-
get_default_devices,
54-
HALFCHEETAH_VERSIONED,
55-
PENDULUM_VERSIONED,
56-
PONG_VERSIONED,
57-
rand_reset,
58-
retry,
59-
rollout_consistency_assertion,
60-
)
6135
from packaging import version
6236
from tensordict import (
6337
assert_allclose_td,
@@ -155,6 +129,33 @@
155129
ValueOperator,
156130
)
157131

132+
if os.getenv("PYTORCH_TEST_FBCODE"):
133+
from pytorch.rl.test._utils_internal import (
134+
_make_multithreaded_env,
135+
CARTPOLE_VERSIONED,
136+
get_available_devices,
137+
get_default_devices,
138+
HALFCHEETAH_VERSIONED,
139+
PENDULUM_VERSIONED,
140+
PONG_VERSIONED,
141+
rand_reset,
142+
retry,
143+
rollout_consistency_assertion,
144+
)
145+
else:
146+
from _utils_internal import (
147+
_make_multithreaded_env,
148+
CARTPOLE_VERSIONED,
149+
get_available_devices,
150+
get_default_devices,
151+
HALFCHEETAH_VERSIONED,
152+
PENDULUM_VERSIONED,
153+
PONG_VERSIONED,
154+
rand_reset,
155+
retry,
156+
rollout_consistency_assertion,
157+
)
158+
158159
_has_d4rl = importlib.util.find_spec("d4rl") is not None
159160

160161
_has_mo = importlib.util.find_spec("mo_gymnasium") is not None
@@ -166,6 +167,9 @@
166167
_has_minari = importlib.util.find_spec("minari") is not None
167168

168169
_has_gymnasium = importlib.util.find_spec("gymnasium") is not None
170+
171+
_has_isaaclab = importlib.util.find_spec("isaaclab") is not None
172+
169173
_has_gym_regular = importlib.util.find_spec("gym") is not None
170174
if _has_gymnasium:
171175
set_gym_backend("gymnasium").set()
@@ -4541,6 +4545,38 @@ def test_render(self, rollout_steps):
45414545
assert not torch.equal(rollout_penultimate_image, image_from_env)
45424546

45434547

4548+
@pytest.mark.skipif(not _has_isaaclab, reason="Isaaclab not found")
4549+
class TestIsaacLab:
4550+
@pytest.fixture(scope="class")
4551+
def env(self):
4552+
import gymnasium as gym
4553+
import isaaclab_tasks # noqa: F401
4554+
from isaaclab_tasks.manager_based.classic.ant.ant_env_cfg import AntEnvCfg
4555+
from torchrl.envs.libs.isaac_lab import IsaacLabWrapper
4556+
4557+
env = gym.make("Isaac-Ant-v0", cfg=AntEnvCfg())
4558+
env = IsaacLabWrapper(env)
4559+
yield env
4560+
4561+
def test_isaaclab(self, env):
4562+
assert env.batch_size == (4096,)
4563+
assert env._is_batched
4564+
env.check_env_specs(break_when_any_done="both")
4565+
4566+
def test_isaac_collector(self, env):
4567+
col = SyncDataCollector(
4568+
env, env.rand_action, frames_per_batch=1000, total_frames=100_000_000
4569+
)
4570+
for _ in col:
4571+
break
4572+
4573+
def test_isaaclab_reset(self):
4574+
# Make a rollout that will stop as soon as a trajectory reaches a done state
4575+
r = env.rollout(1_000_000)
4576+
# Check that done obs are None
4577+
assert (r["next", "policy"][r["next", "done"].squeeze(-1)] == np.nan).all()
4578+
4579+
45444580
if __name__ == "__main__":
45454581
args, unknown = argparse.ArgumentParser().parse_known_args()
45464582
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/envs/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
HabitatEnv,
2121
IsaacGymEnv,
2222
IsaacGymWrapper,
23+
IsaacLabWrapper,
2324
JumanjiEnv,
2425
JumanjiWrapper,
2526
MeltingpotEnv,
@@ -131,6 +132,7 @@
131132
"ActionDiscretizer",
132133
"ActionMask",
133134
"VecNormV2",
135+
"IsaacLabWrapper",
134136
"AutoResetEnv",
135137
"AutoResetTransform",
136138
"AsyncEnvPool",

torchrl/envs/gym_like.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,12 @@ def validated(self, value):
515515
def _reset(
516516
self, tensordict: TensorDictBase | None = None, **kwargs
517517
) -> TensorDictBase:
518+
if (
519+
tensordict is not None
520+
and "_reset" in tensordict
521+
and not tensordict["_reset"].all()
522+
):
523+
raise RuntimeError("Partial resets are not handled at this level.")
518524
obs, info = self._reset_output_transform(self._env.reset(**kwargs))
519525

520526
source = self.read_obs(obs)

torchrl/envs/libs/__init__.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
set_gym_backend,
1717
)
1818
from .habitat import HabitatEnv
19+
from .isaac_lab import IsaacLabWrapper
1920
from .isaacgym import IsaacGymEnv, IsaacGymWrapper
2021
from .jumanji import JumanjiEnv, JumanjiWrapper
2122
from .meltingpot import MeltingpotEnv, MeltingpotWrapper
@@ -32,22 +33,20 @@
3233
"BraxWrapper",
3334
"DMControlEnv",
3435
"DMControlWrapper",
35-
"MultiThreadedEnv",
36-
"MultiThreadedEnvWrapper",
37-
"gym_backend",
3836
"GymEnv",
3937
"GymWrapper",
40-
"MOGymEnv",
41-
"MOGymWrapper",
42-
"register_gym_spec_conversion",
43-
"set_gym_backend",
4438
"HabitatEnv",
4539
"IsaacGymEnv",
4640
"IsaacGymWrapper",
41+
"IsaacLabWrapper",
4742
"JumanjiEnv",
4843
"JumanjiWrapper",
44+
"MOGymEnv",
45+
"MOGymWrapper",
4946
"MeltingpotEnv",
5047
"MeltingpotWrapper",
48+
"MultiThreadedEnv",
49+
"MultiThreadedEnvWrapper",
5150
"OpenMLEnv",
5251
"OpenSpielEnv",
5352
"OpenSpielWrapper",
@@ -60,4 +59,7 @@
6059
"UnityMLAgentsWrapper",
6160
"VmasEnv",
6261
"VmasWrapper",
62+
"gym_backend",
63+
"register_gym_spec_conversion",
64+
"set_gym_backend",
6365
]

torchrl/envs/libs/gym.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353

5454
_has_mo = importlib.util.find_spec("mo_gymnasium") is not None
5555
_has_sb3 = importlib.util.find_spec("stable_baselines3") is not None
56+
_has_isaaclab = importlib.util.find_spec("isaaclab") is not None
5657
_has_minigrid = importlib.util.find_spec("minigrid") is not None
5758

5859

@@ -793,6 +794,7 @@ class PixelObservationWrapper:
793794

794795
class _GymAsyncMeta(_EnvPostInit):
795796
def __call__(cls, *args, **kwargs):
797+
missing_obs_value = kwargs.pop("missing_obs_value", None)
796798
instance: GymWrapper = super().__call__(*args, **kwargs)
797799

798800
# before gym 0.22, there was no final_observation
@@ -803,6 +805,15 @@ def __call__(cls, *args, **kwargs):
803805
VecGymEnvTransform,
804806
)
805807

808+
if _has_isaaclab:
809+
from isaaclab.envs import ManagerBasedRLEnv
810+
811+
kwargs = {}
812+
if missing_obs_value is not None:
813+
kwargs["missing_obs_value"] = missing_obs_value
814+
if isinstance(instance._env.unwrapped, ManagerBasedRLEnv):
815+
return TransformedEnv(instance, VecGymEnvTransform(**kwargs))
816+
806817
if _has_sb3:
807818
from stable_baselines3.common.vec_env.base_vec_env import VecEnv
808819

@@ -845,7 +856,10 @@ def __call__(cls, *args, **kwargs):
845856
instance.observation_spec, backend=backend
846857
)
847858
)
848-
return TransformedEnv(instance, VecGymEnvTransform())
859+
kwargs = {}
860+
if missing_obs_value is not None:
861+
kwargs["missing_obs_value"] = missing_obs_value
862+
return TransformedEnv(instance, VecGymEnvTransform(**kwargs))
849863
return instance
850864

851865

@@ -892,6 +906,10 @@ class GymWrapper(GymLikeEnv, metaclass=_GymAsyncMeta):
892906
env step function. Set this to ``False`` if the environment is evaluated
893907
on GPU, such as IsaacLab.
894908
Defaults to ``True``.
909+
missing_obs_value (Any, optional): default value to use as placeholder for missing observations, when
910+
the environment is auto-resetting and missing observations cannot be found in the info dictionary
911+
(e.g., with IsaacLab). This argument is passed to :class:`~torchrl.envs.VecGymEnvTransform` by
912+
the metaclass.
895913
896914
Attributes:
897915
available_envs (List[str]): a list of environments to build.
@@ -1069,14 +1087,17 @@ def _post_init(self):
10691087

10701088
@property
10711089
def _is_batched(self):
1090+
tuple_of_classes = ()
10721091
if _has_sb3:
10731092
from stable_baselines3.common.vec_env.base_vec_env import VecEnv
10741093

1075-
tuple_of_classes = (VecEnv,)
1076-
else:
1077-
tuple_of_classes = ()
1094+
tuple_of_classes = tuple_of_classes + (VecEnv,)
1095+
if _has_isaaclab:
1096+
from isaaclab.envs import ManagerBasedRLEnv
1097+
1098+
tuple_of_classes = tuple_of_classes + (ManagerBasedRLEnv,)
10781099
return isinstance(
1079-
self._env, tuple_of_classes + (gym_backend("vector").VectorEnv,)
1100+
self._env.unwrapped, tuple_of_classes + (gym_backend("vector").VectorEnv,)
10801101
)
10811102

10821103
@implement_for("gym")
@@ -1562,7 +1583,10 @@ def _replace_reset(self, reset, kwargs): # noqa
15621583
def _replace_reset(self, reset, kwargs): # noqa
15631584
import gymnasium as gym
15641585

1565-
if self._env.autoreset_mode == gym.vector.AutoresetMode.DISABLED:
1586+
if (
1587+
getattr(self._env, "autoreset_mode", None)
1588+
== gym.vector.AutoresetMode.DISABLED
1589+
):
15661590
options = {"reset_mask": reset.view(self.batch_size).numpy()}
15671591
kwargs.setdefault("options", {}).update(options)
15681592
return kwargs

0 commit comments

Comments
 (0)