-
Notifications
You must be signed in to change notification settings - Fork 256
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* MVP docs overview * Apply suggestions from code review Co-authored-by: Adam Gleave <[email protected]> * Address comments Co-authored-by: Adam Gleave <[email protected]>
- Loading branch information
1 parent
d0a04ce
commit 9ddd6fa
Showing
18 changed files
with
249 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
================================================= | ||
Adversarial Inverse Reinforcement Learning (AIRL) | ||
================================================= | ||
|
||
Implements `Learning Robust Rewards with Adversarial Inverse Reinforcement Learning <https://arxiv.org/abs/1710.11248>`_. | ||
|
||
|
||
API | ||
=== | ||
.. autoclass:: imitation.algorithms.adversarial.AIRL | ||
:members: | ||
:inherited-members: | ||
:noindex: | ||
|
||
.. autoclass:: imitation.algorithms.adversarial.AdversarialTrainer | ||
:members: | ||
:noindex: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
======================= | ||
Behavioral Cloning (BC) | ||
======================= | ||
|
||
Supervised learning on observation-action pairs. | ||
|
||
API | ||
=== | ||
.. autoclass:: imitation.algorithms.bc.BC | ||
:members: | ||
:noindex: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
======================= | ||
DAgger | ||
======================= | ||
|
||
Implements `A Reduction of Imitation Learning and Structured Prediction to No-Regret Online Learning <https://arxiv.org/abs/1011.0686>`_. | ||
|
||
API | ||
=== | ||
.. autoclass:: imitation.algorithms.dagger.InteractiveTrajectoryCollector | ||
:members: | ||
:inherited-members: | ||
:noindex: | ||
|
||
.. autoclass:: imitation.algorithms.dagger.DAggerTrainer | ||
:members: | ||
:inherited-members: | ||
:noindex: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
================================================ | ||
Generative Adversarial Imitation Learning (GAIL) | ||
================================================ | ||
|
||
Implements `Generative Adversarial Imitation Learning <https://arxiv.org/abs/1606.03476>`_. | ||
|
||
API | ||
=== | ||
.. autoclass:: imitation.algorithms.adversarial.GAIL | ||
:members: | ||
:inherited-members: | ||
:noindex: | ||
|
||
.. autoclass:: imitation.algorithms.adversarial.AdversarialTrainer | ||
:members: | ||
:inherited-members: | ||
:noindex: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
=============== | ||
Getting Started | ||
=============== | ||
|
||
|
||
CLI Quickstart | ||
============== | ||
|
||
We provide several CLI scripts as front-ends to the algorithms implemented in ``imitation``. | ||
These use `Sacred <https://github.com/idsia/sacred>`_ for configuration and replicability. | ||
|
||
For information on how to configure Sacred CLI options, see the `Sacred docs <https://sacred.readthedocs.io/en/stable/>`_. | ||
|
||
.. code-block:: bash | ||
# Train PPO agent on cartpole and collect expert demonstrations. Tensorboard logs saved | ||
# in `quickstart/rl/` | ||
python -m imitation.scripts.expert_demos with fast cartpole log_dir=quickstart/rl/ | ||
# Train GAIL from demonstrations. Tensorboard logs saved in output/ (default log directory). | ||
python -m imitation.scripts.train_adversarial with fast gail cartpole \ | ||
rollout_path=quickstart/rl/rollouts/final.pkl | ||
# Train AIRL from demonstrations. Tensorboard logs saved in output/ (default log directory). | ||
python -m imitation.scripts.train_adversarial with fast airl cartpole \ | ||
rollout_path=quickstart/rl/rollouts/final.pkl | ||
.. note:: | ||
Remove the ``fast`` option from the commands above to allow training run to completion. | ||
|
||
.. tip:: | ||
``python -m imitation.scripts.expert_demos print_config`` will list Sacred script options. | ||
These configuration options are also documented in each script's docstrings. | ||
|
||
|
||
Python Interface Quickstart | ||
=========================== | ||
|
||
Here's an `example script`_ that loads CartPole-v1 demonstrations and trains BC, GAIL, and | ||
AIRL models on that data. | ||
|
||
.. _example script: https://github.com/HumanCompatibleAI/imitation/blob/master/examples/quickstart.py | ||
|
||
.. code-block:: python | ||
"""Loads CartPole-v1 demonstrations and trains BC, GAIL, and AIRL models on that data. | ||
""" | ||
import pathlib | ||
import pickle | ||
import tempfile | ||
import stable_baselines3 as sb3 | ||
from imitation.algorithms import adversarial, bc | ||
from imitation.data import rollout | ||
from imitation.util import logger, util | ||
# Load pickled test demonstrations. | ||
with open("tests/data/expert_models/cartpole_0/rollouts/final.pkl", "rb") as f: | ||
# This is a list of `imitation.data.types.Trajectory`, where | ||
# every instance contains observations and actions for a single expert | ||
# demonstration. | ||
trajectories = pickle.load(f) | ||
# Convert List[types.Trajectory] to an instance of `imitation.data.types.Transitions`. | ||
# This is a more general dataclass containing unordered | ||
# (observation, actions, next_observation) transitions. | ||
transitions = rollout.flatten_trajectories(trajectories) | ||
venv = util.make_vec_env("CartPole-v1", n_envs=2) | ||
tempdir = tempfile.TemporaryDirectory(prefix="quickstart") | ||
tempdir_path = pathlib.Path(tempdir.name) | ||
print(f"All Tensorboards and logging are being written inside {tempdir_path}/.") | ||
# Train BC on expert data. | ||
# BC also accepts as `expert_data` any PyTorch-style DataLoader that iterates over | ||
# dictionaries containing observations and actions. | ||
logger.configure(tempdir_path / "BC/") | ||
bc_trainer = bc.BC(venv.observation_space, venv.action_space, expert_data=transitions) | ||
bc_trainer.train(n_epochs=1) | ||
# Train GAIL on expert data. | ||
# GAIL, and AIRL also accept as `expert_data` any Pytorch-style DataLoader that | ||
# iterates over dictionaries containing observations, actions, and next_observations. | ||
logger.configure(tempdir_path / "GAIL/") | ||
gail_trainer = adversarial.GAIL( | ||
venv, | ||
expert_data=transitions, | ||
expert_batch_size=32, | ||
gen_algo=sb3.PPO("MlpPolicy", venv, verbose=1, n_steps=1024), | ||
) | ||
gail_trainer.train(total_timesteps=2048) | ||
# Train AIRL on expert data. | ||
logger.configure(tempdir_path / "AIRL/") | ||
airl_trainer = adversarial.AIRL( | ||
venv, | ||
expert_data=transitions, | ||
expert_batch_size=32, | ||
gen_algo=sb3.PPO("MlpPolicy", venv, verbose=1, n_steps=1024), | ||
) | ||
airl_trainer.train(total_timesteps=2048) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
============ | ||
Installation | ||
============ | ||
|
||
**Installing PyPI release** | ||
|
||
.. code-block:: bash | ||
pip install imitation | ||
**Install latest commit** | ||
|
||
.. code-block:: bash | ||
git clone http://github.com/HumanCompatibleAI/imitation | ||
cd imitation | ||
pip install -e . | ||
**Optional Mujoco Dependency** | ||
|
||
Follow instructions to install `mujoco\_py v1.5 here`_. | ||
|
||
.. _mujoco_py v1.5 here: | ||
https://github.com/openai/mujoco-py/tree/498b451a03fb61e5bdfcb6956d8d7c881b1098b5#install-mujoco |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,4 +14,3 @@ Module contents | |
--------------- | ||
|
||
.. automodule:: imitation.envs.examples | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,4 +21,3 @@ Module contents | |
--------------- | ||
|
||
.. automodule:: imitation.envs | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,4 +18,4 @@ Subpackages | |
Module contents | ||
--------------- | ||
|
||
.. automodule:: imitation | ||
.. automodule:: imitation |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.