Commit 9c7a8f5

[Feature] IsaacLab wrapper
ghstack-source-id: 66fe16a Pull-Request-resolved: #2937
1 parent a31dca3 commit 9c7a8f5

File tree

13 files changed: +374 -58 lines changed

.github/unittest/linux_libs/scripts_gym/setup_env.sh

Lines changed: 0 additions & 1 deletion
@@ -10,7 +10,6 @@ set -e
 this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
 # Avoid error: "fatal: unsafe repository"
 apt-get update && apt-get install -y git wget gcc g++
-
 apt-get install -y libglfw3 libgl1-mesa-glx libosmesa6 libglew-dev libsdl2-dev libsdl2-2.0-0
 apt-get install -y libglvnd0 libgl1 libglx0 libegl1 libgles2 xvfb libegl-dev libx11-dev freeglut3-dev

.github/unittest/linux_libs/scripts_isaaclab/isaac.sh (new file; path inferred from the workflow below)

Lines changed: 80 additions & 0 deletions

@@ -0,0 +1,80 @@
+#!/usr/bin/env bash
+
+set -e
+set -v
+
+#if [[ "${{ github.ref }}" =~ release/* ]]; then
+# export RELEASE=1
+# export TORCH_VERSION=stable
+#else
+export RELEASE=0
+export TORCH_VERSION=nightly
+#fi
+
+set -euo pipefail
+export PYTHON_VERSION="3.10"
+export CU_VERSION="12.8"
+export TAR_OPTIONS="--no-same-owner"
+export UPLOAD_CHANNEL="nightly"
+export TF_CPP_MIN_LOG_LEVEL=0
+export BATCHED_PIPE_TIMEOUT=60
+export TD_GET_DEFAULTS_TO_NONE=1
+
+nvidia-smi
+
+# Setup
+apt-get update && apt-get install -y git wget gcc g++
+apt-get install -y libglfw3 libgl1-mesa-glx libosmesa6 libglew-dev libsdl2-dev libsdl2-2.0-0
+apt-get install -y libglvnd0 libgl1 libglx0 libegl1 libgles2 xvfb libegl-dev libx11-dev freeglut3-dev
+
+git config --global --add safe.directory '*'
+root_dir="$(git rev-parse --show-toplevel)"
+conda_dir="${root_dir}/conda"
+env_dir="${root_dir}/env"
+lib_dir="${env_dir}/lib"
+
+cd "${root_dir}"
+
+# install conda
+printf "* Installing conda\n"
+wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh"
+bash ./miniconda.sh -b -f -p "${conda_dir}"
+eval "$(${conda_dir}/bin/conda shell.bash hook)"
+
+
+conda create -n env_isaaclab python=3.10 -y
+conda activate env_isaaclab
+
+# Pin pytorch to 2.5.1 for IsaacLab
+conda install pytorch==2.5.1 torchvision==0.20.1 pytorch-cuda=12.4 -c pytorch -c nvidia -y
+
+pip install --upgrade pip
+pip install 'isaacsim[all,extscache]==4.5.0' --extra-index-url https://pypi.nvidia.com
+conda install conda-forge::"cmake>3.22" -y
+
+git clone https://github.com/isaac-sim/IsaacLab.git
+cd IsaacLab
+conda run -p ${conda_dir} ./isaaclab.sh --install sb3
+cd ../
+
+# install tensordict
+if [[ "$RELEASE" == 0 ]]; then
+  conda install "anaconda::cmake>=3.22" -y
+  conda run -p ${conda_dir} python3 -m pip install "pybind11[global]"
+  conda run -p ${conda_dir} python3 -m pip install git+https://github.com/pytorch/tensordict.git
+else
+  conda run -p ${conda_dir} python3 -m pip install tensordict
+fi
+
+# smoke test
+conda run -p ${conda_dir} python -c "import tensordict"
+
+printf "* Installing torchrl\n"
+conda run -p ${conda_dir} python setup.py develop
+conda run -p ${conda_dir} python -c "import torchrl"
+
+# Install pytest
+conda run -p ${conda_dir} python -m pip install pytest pytest-cov pytest-mock pytest-instafail pytest-rerunfailures pytest-error-for-skips pytest-asyncio
+
+# Run tests
+conda run -p ${conda_dir} python -m pytest test/test_libs.py -k isaac

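As a quick post-install check, the imports that the script smoke-tests can also be exercised together with the new wrapper class. A minimal sketch, assuming it is run inside the env_isaaclab environment created above; the wrapper's import path is the one used by the tests added in this commit:

# Post-install smoke check mirroring the `python -c "import ..."` calls above.
import tensordict
import torchrl
from torchrl.envs.libs.isaac_lab import IsaacLabWrapper  # wrapper added by this commit

print("tensordict:", tensordict.__version__)
print("torchrl:", torchrl.__version__)
print("wrapper:", IsaacLabWrapper.__name__)
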
.github/workflows/test-linux-libs.yml

Lines changed: 18 additions & 0 deletions
@@ -230,6 +230,24 @@ jobs:
         ./.github/unittest/linux_libs/scripts_gym/batch_scripts.sh
         ./.github/unittest/linux_libs/scripts_gym/post_process.sh
 
+  unittests-isaaclab:
+    strategy:
+      matrix:
+        python_version: ["3.10"]
+        cuda_arch_version: ["12.8"]
+    if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments/Isaac') }}
+    uses: vmoens/test-infra/.github/workflows/isaac_linux_job_v2.yml@main
+    with:
+      repository: pytorch/rl
+      runner: "linux.g5.4xlarge.nvidia.gpu"
+      docker-image: "nvcr.io/nvidia/isaac-lab:2.1.0"
+      gpu-arch-type: cuda
+      gpu-arch-version: "12.8"
+      timeout: 120
+      test-infra-repository: vmoens/test-infra
+      script: |
+        ./.github/unittest/linux_libs/scripts_isaaclab/isaac.sh
+
   unittests-jumanji:
     strategy:
       matrix:

docs/source/reference/envs.rst

Lines changed: 1 addition & 0 deletions
@@ -1417,6 +1417,7 @@ the following function will return ``1`` when queried:
     HabitatEnv
     IsaacGymEnv
     IsaacGymWrapper
+    IsaacLabWrapper
     JumanjiEnv
     JumanjiWrapper
     MeltingpotEnv

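For a usage example of the newly documented IsaacLabWrapper, the sketch below follows the TestIsaacLab fixture added to test/test_libs.py in this commit; the headless AppLauncher bootstrap, the Isaac-Ant-v0 task and AntEnvCfg are IsaacLab-side assumptions (IsaacLab 2.x), not part of the TorchRL API:

# Minimal usage sketch for IsaacLabWrapper, adapted from the test fixture below.
import argparse

from isaaclab.app import AppLauncher

# The Isaac app must be launched (here headless) before importing any task module.
parser = argparse.ArgumentParser()
AppLauncher.add_app_launcher_args(parser)
args_cli, _ = parser.parse_known_args(["--headless"])
AppLauncher(args_cli)

import gymnasium as gym
import isaaclab_tasks  # noqa: F401  # registers the Isaac-* tasks with gymnasium
from isaaclab_tasks.manager_based.classic.ant.ant_env_cfg import AntEnvCfg

from torchrl.envs.libs.isaac_lab import IsaacLabWrapper

env = IsaacLabWrapper(gym.make("Isaac-Ant-v0", cfg=AntEnvCfg()))
td = env.reset()          # TensorDict batched over the vectorized sub-envs
rollout = env.rollout(3)  # the tests below expect env.batch_size == (4096,) for this config
env.close()
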
test/test_libs.py

Lines changed: 89 additions & 26 deletions
@@ -32,32 +32,6 @@
 import pytest
 import torch
 
-if os.getenv("PYTORCH_TEST_FBCODE"):
-    from pytorch.rl.test._utils_internal import (
-        _make_multithreaded_env,
-        CARTPOLE_VERSIONED,
-        get_available_devices,
-        get_default_devices,
-        HALFCHEETAH_VERSIONED,
-        PENDULUM_VERSIONED,
-        PONG_VERSIONED,
-        rand_reset,
-        retry,
-        rollout_consistency_assertion,
-    )
-else:
-    from _utils_internal import (
-        _make_multithreaded_env,
-        CARTPOLE_VERSIONED,
-        get_available_devices,
-        get_default_devices,
-        HALFCHEETAH_VERSIONED,
-        PENDULUM_VERSIONED,
-        PONG_VERSIONED,
-        rand_reset,
-        retry,
-        rollout_consistency_assertion,
-    )
 from packaging import version
 from tensordict import (
     assert_allclose_td,
@@ -155,6 +129,33 @@
     ValueOperator,
 )
 
+if os.getenv("PYTORCH_TEST_FBCODE"):
+    from pytorch.rl.test._utils_internal import (
+        _make_multithreaded_env,
+        CARTPOLE_VERSIONED,
+        get_available_devices,
+        get_default_devices,
+        HALFCHEETAH_VERSIONED,
+        PENDULUM_VERSIONED,
+        PONG_VERSIONED,
+        rand_reset,
+        retry,
+        rollout_consistency_assertion,
+    )
+else:
+    from _utils_internal import (
+        _make_multithreaded_env,
+        CARTPOLE_VERSIONED,
+        get_available_devices,
+        get_default_devices,
+        HALFCHEETAH_VERSIONED,
+        PENDULUM_VERSIONED,
+        PONG_VERSIONED,
+        rand_reset,
+        retry,
+        rollout_consistency_assertion,
+    )
+
 _has_d4rl = importlib.util.find_spec("d4rl") is not None
 
 _has_mo = importlib.util.find_spec("mo_gymnasium") is not None
@@ -166,6 +167,9 @@
 _has_minari = importlib.util.find_spec("minari") is not None
 
 _has_gymnasium = importlib.util.find_spec("gymnasium") is not None
+
+_has_isaaclab = importlib.util.find_spec("scripts_isaaclab") is not None
+
 _has_gym_regular = importlib.util.find_spec("gym") is not None
 if _has_gymnasium:
     set_gym_backend("gymnasium").set()
@@ -4541,6 +4545,65 @@ def test_render(self, rollout_steps):
         assert not torch.equal(rollout_penultimate_image, image_from_env)
 
 
+@pytest.mark.skipif(not _has_isaaclab, reason="Isaaclab not found")
+class TestIsaacLab:
+    @pytest.fixture(scope="class")
+    def env(self):
+        torch.manual_seed(0)
+        import argparse
+
+        # This code block ensures that the Isaac app is started in headless mode
+        from isaaclab.app import AppLauncher
+
+        parser = argparse.ArgumentParser(description="Train an RL agent with TorchRL.")
+        AppLauncher.add_app_launcher_args(parser)
+        args_cli, hydra_args = parser.parse_known_args(["--headless"])
+        AppLauncher(args_cli)
+
+        # Imports and env
+        import gymnasium as gym
+        import isaaclab_tasks  # noqa: F401
+        from isaaclab_tasks.manager_based.classic.ant.ant_env_cfg import AntEnvCfg
+        from torchrl.envs.libs.isaac_lab import IsaacLabWrapper
+
+        torchrl_logger.info("Making IsaacLab env...")
+        env = gym.make("Isaac-Ant-v0", cfg=AntEnvCfg())
+        torchrl_logger.info("Wrapping IsaacLab env...")
+        try:
+            env = IsaacLabWrapper(env)
+            yield env
+        finally:
+            torchrl_logger.info("Closing IsaacLab env...")
+            env.close()
+            torchrl_logger.info("Closed")
+
+    def test_isaaclab(self, env):
+        assert env.batch_size == (4096,)
+        assert env._is_batched
+        torchrl_logger.info("Checking env specs...")
+        env.check_env_specs(break_when_any_done="both")
+        torchrl_logger.info("Check succeeded!")
+
+    def test_isaac_collector(self, env):
+        col = SyncDataCollector(
+            env, env.rand_action, frames_per_batch=1000, total_frames=100_000_000
+        )
+        try:
+            for data in col:
+                assert data.shape == (4096, 1)
+                break
+        finally:
+            # We must do that, otherwise `__del__` calls `shutdown` and the next test will fail
+            col.shutdown(close_env=False)
+
+    def test_isaaclab_reset(self, env):
+        # Make a rollout that will stop as soon as a trajectory reaches a done state
+        r = env.rollout(1_000_000)
+
+        # Check that done obs are None
+        assert not r["next", "policy"][r["next", "done"].squeeze(-1)].isfinite().any()
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/collectors/collectors.py

Lines changed: 43 additions & 14 deletions
@@ -272,15 +272,19 @@ def pause(self):
             f"Collector pause() is not implemented for {type(self).__name__}."
         )
 
-    def async_shutdown(self, timeout: float | None = None) -> None:
+    def async_shutdown(
+        self, timeout: float | None = None, close_env: bool = True
+    ) -> None:
         """Shuts down the collector when started asynchronously with the `start` method.
 
         Arg:
             timeout (float, optional): The maximum time to wait for the collector to shutdown.
+            close_env (bool, optional): If True, the collector will close the contained environment.
+                Defaults to `True`.
 
         .. seealso:: :meth:`~.start`
         """
-        return self.shutdown(timeout=timeout)
+        return self.shutdown(timeout=timeout, close_env=close_env)
 
     def update_policy_weights_(
         self,
@@ -336,7 +340,7 @@ def next(self):
         return None
 
     @abc.abstractmethod
-    def shutdown(self, timeout: float | None = None) -> None:
+    def shutdown(self, timeout: float | None = None, close_env: bool = True) -> None:
         raise NotImplementedError
 
     @abc.abstractmethod
@@ -1311,12 +1315,14 @@ def _run_iterator(self):
             if self._stop:
                 return
 
-    def async_shutdown(self, timeout: float | None = None) -> None:
+    def async_shutdown(
+        self, timeout: float | None = None, close_env: bool = True
+    ) -> None:
         """Finishes processes started by ray.init() during async execution."""
         self._stop = True
         if hasattr(self, "_thread") and self._thread.is_alive():
             self._thread.join(timeout=timeout)
-        self.shutdown()
+        self.shutdown(close_env=close_env)
 
     def _postproc(self, tensordict_out):
         if self.split_trajs:
@@ -1576,14 +1582,20 @@ def reset(self, index=None, **kwargs) -> None:
             )
             self._shuttle["collector"] = collector_metadata
 
-    def shutdown(self, timeout: float | None = None) -> None:
-        """Shuts down all workers and/or closes the local environment."""
+    def shutdown(self, timeout: float | None = None, close_env: bool = True) -> None:
+        """Shuts down all workers and/or closes the local environment.
+
+        Args:
+            timeout (float, optional): The timeout for closing pipes between workers.
+                No effect for this class.
+            close_env (bool, optional): Whether to close the environment. Defaults to `True`.
+        """
         if not self.closed:
             self.closed = True
             del self._shuttle
             if self._use_buffers:
                 del self._final_rollout
-            if not self.env.is_closed:
+            if close_env and not self.env.is_closed:
                 self.env.close()
             del self.env
         return
@@ -2385,8 +2397,17 @@ def __del__(self):
             # __del__ will not affect the program.
             pass
 
-    def shutdown(self, timeout: float | None = None) -> None:
-        """Shuts down all processes. This operation is irreversible."""
+    def shutdown(self, timeout: float | None = None, close_env: bool = True) -> None:
+        """Shuts down all processes. This operation is irreversible.
+
+        Args:
+            timeout (float, optional): The timeout for closing pipes between workers.
+            close_env (bool, optional): Whether to close the environment. Defaults to `True`.
+        """
+        if not close_env:
+            raise RuntimeError(
+                f"Cannot shutdown {type(self).__name__} collector without environment being closed."
+            )
         self._shutdown_main(timeout)
 
     def _shutdown_main(self, timeout: float | None = None) -> None:
@@ -2659,7 +2680,11 @@ def next(self):
         return super().next()
 
     # for RPC
-    def shutdown(self, timeout: float | None = None) -> None:
+    def shutdown(self, timeout: float | None = None, close_env: bool = True) -> None:
+        if not close_env:
+            raise RuntimeError(
+                f"Cannot shutdown {type(self).__name__} collector without environment being closed."
+            )
         if hasattr(self, "out_buffer"):
             del self.out_buffer
         if hasattr(self, "buffers"):
@@ -3032,9 +3057,13 @@ def next(self):
         return super().next()
 
     # for RPC
-    def shutdown(self, timeout: float | None = None) -> None:
+    def shutdown(self, timeout: float | None = None, close_env: bool = True) -> None:
         if hasattr(self, "out_tensordicts"):
            del self.out_tensordicts
+        if not close_env:
+            raise RuntimeError(
+                f"Cannot shutdown {type(self).__name__} collector without environment being closed."
+            )
         return super().shutdown(timeout=timeout)
 
     # for RPC
@@ -3376,8 +3405,8 @@ def next(self):
         return super().next()
 
     # for RPC
-    def shutdown(self, timeout: float | None = None) -> None:
-        return super().shutdown(timeout=timeout)
+    def shutdown(self, timeout: float | None = None, close_env: bool = True) -> None:
+        return super().shutdown(timeout=timeout, close_env=close_env)
 
     # for RPC
     def set_seed(self, seed: int, static_seed: bool = False) -> int:

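To illustrate the new close_env flag on the collector shutdown methods, here is a short sketch following test_isaac_collector above; the Pendulum-v1 GymEnv is only a lightweight stand-in for illustration and is not part of this commit:

# Shut a collector down without closing its environment (the behaviour the
# close_env flag adds in this commit). GymEnv/"Pendulum-v1" are stand-ins.
from torchrl.collectors import SyncDataCollector
from torchrl.envs import GymEnv

env = GymEnv("Pendulum-v1")
collector = SyncDataCollector(
    env, env.rand_action, frames_per_batch=200, total_frames=1_000
)
for batch in collector:
    break  # collect a single batch

# Keep the environment alive for later reuse (e.g. by another collector or test);
# the default close_env=True would close it here.
collector.shutdown(close_env=False)
assert not env.is_closed
env.close()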