MVP docs (#267)
* MVP docs overview

* Apply suggestions from code review

Co-authored-by: Adam Gleave <[email protected]>

* Address comments

Co-authored-by: Adam Gleave <[email protected]>
shwang and AdamGleave authored Feb 12, 2021
1 parent d0a04ce commit 9ddd6fa
Showing 18 changed files with 249 additions and 33 deletions.
17 changes: 17 additions & 0 deletions docs/algorithms/airl.rst
@@ -0,0 +1,17 @@
=================================================
Adversarial Inverse Reinforcement Learning (AIRL)
=================================================

Implements `Learning Robust Rewards with Adversarial Inverse Reinforcement Learning <https://arxiv.org/abs/1710.11248>`_.


API
===
.. autoclass:: imitation.algorithms.adversarial.AIRL
    :members:
    :inherited-members:
    :noindex:

.. autoclass:: imitation.algorithms.adversarial.AdversarialTrainer
    :members:
    :noindex:
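
For orientation: the AIRL paper parameterises the discriminator as D(s, a) = exp(f(s, a)) / (exp(f(s, a)) + pi(a|s)) and gives the generator the reward log D - log(1 - D), which simplifies to f(s, a) - log pi(a|s). The sketch below checks that identity numerically in plain Python. The function names and the numeric values are illustrative assumptions, not part of the ``imitation`` API.

```python
import math


def airl_discriminator(f_value: float, log_pi: float) -> float:
    """D = exp(f) / (exp(f) + pi), rewritten as a sigmoid of (f - log pi)."""
    return 1.0 / (1.0 + math.exp(-(f_value - log_pi)))


def airl_reward(f_value: float, log_pi: float) -> float:
    """Generator reward log D - log(1 - D), computed from the discriminator."""
    d = airl_discriminator(f_value, log_pi)
    return math.log(d) - math.log(1.0 - d)


# Hypothetical values: a learned f(s, a) and the generator's log pi(a|s).
f_value, log_pi = 0.7, math.log(0.25)
reward = airl_reward(f_value, log_pi)
print(reward, f_value - log_pi)  # the two numbers agree up to rounding
```

Working with ``f - log pi`` directly avoids forming D at all, which is one reason implementations tend to operate on logits.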
11 changes: 11 additions & 0 deletions docs/algorithms/bc.rst
@@ -0,0 +1,11 @@
=======================
Behavioral Cloning (BC)
=======================

Supervised learning on observation-action pairs.

API
===
.. autoclass:: imitation.algorithms.bc.BC
    :members:
    :noindex:
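
To make "supervised learning on observation-action pairs" concrete, here is a minimal, library-free sketch: a logistic policy fitted by gradient descent on a cross-entropy loss against expert actions. The toy demonstrations, ``action_prob``, and ``train_bc`` are hypothetical names chosen for illustration. The ``BC`` class documented above trains Stable Baselines 3 policies instead.

```python
import math

# Toy expert demonstrations: (observation, action) pairs in which the
# hypothetical expert picks action 1 whenever the observation is positive.
demos = [(-2.0, 0), (-1.0, 0), (-0.5, 0), (0.5, 1), (1.0, 1), (2.0, 1)]


def action_prob(w: float, b: float, obs: float) -> float:
    """Probability the policy assigns to action 1 for a given observation."""
    return 1.0 / (1.0 + math.exp(-(w * obs + b)))


def train_bc(data, lr: float = 0.5, n_epochs: int = 200):
    """Fit a logistic policy by minimising cross-entropy against expert actions."""
    w, b = 0.0, 0.0
    for _ in range(n_epochs):
        grad_w = grad_b = 0.0
        for obs, act in data:
            err = action_prob(w, b, obs) - act  # gradient of the loss w.r.t. the logit
            grad_w += err * obs
            grad_b += err
        w -= lr * grad_w / len(data)
        b -= lr * grad_b / len(data)
    return w, b


w, b = train_bc(demos)
predicted = [int(action_prob(w, b, obs) > 0.5) for obs, _ in demos]
print(predicted)
```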
17 changes: 17 additions & 0 deletions docs/algorithms/dagger.rst
@@ -0,0 +1,17 @@
=======================
DAgger
=======================

Implements `A Reduction of Imitation Learning and Structured Prediction to No-Regret Online Learning <https://arxiv.org/abs/1011.0686>`_.

API
===
.. autoclass:: imitation.algorithms.dagger.InteractiveTrajectoryCollector
    :members:
    :inherited-members:
    :noindex:

.. autoclass:: imitation.algorithms.dagger.DAggerTrainer
    :members:
    :inherited-members:
    :noindex:
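
The loop described in the DAgger paper can be sketched on a toy problem: roll out a mixture of expert and learner, label every visited state with the expert's action, aggregate the labels, and retrain. Everything below (the chain environment, ``expert_action``, ``fit_policy``, ``dagger``) is a hypothetical stand-in written for this sketch. It is not how ``DAggerTrainer`` is implemented.

```python
import random
from collections import Counter, defaultdict

GOAL, N_STATES, HORIZON = 7, 10, 12


def expert_action(state: int) -> int:
    """Hypothetical expert: step toward GOAL, stay put once there."""
    return (state < GOAL) - (state > GOAL)


def step(state: int, action: int) -> int:
    """Move along a chain of N_STATES cells, clipping at both ends."""
    return min(max(state + action, 0), N_STATES - 1)


def fit_policy(dataset):
    """Tabular 'supervised learning': majority expert label per state."""
    votes = defaultdict(Counter)
    for state, action in dataset:
        votes[state][action] += 1
    table = {s: c.most_common(1)[0][0] for s, c in votes.items()}
    return lambda state: table.get(state, 0)  # unseen states: stay put


def dagger(n_rounds: int = 5, beta_decay: float = 0.5, seed: int = 0):
    rng = random.Random(seed)
    dataset, policy, beta = [], fit_policy([]), 1.0
    for _ in range(n_rounds):
        state = rng.randrange(N_STATES)
        for _ in range(HORIZON):
            # Always record the expert's label for the visited state...
            dataset.append((state, expert_action(state)))
            # ...but follow the expert only with probability beta.
            act = expert_action(state) if rng.random() < beta else policy(state)
            state = step(state, act)
        policy = fit_policy(dataset)  # retrain on the aggregated dataset
        beta *= beta_decay
    return policy, dataset


policy, dataset = dagger()
print(len(dataset))
```

The key difference from plain behavioral cloning is that later rounds visit states chosen by the learner, so the dataset covers the learner's own mistakes.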
17 changes: 17 additions & 0 deletions docs/algorithms/gail.rst
@@ -0,0 +1,17 @@
================================================
Generative Adversarial Imitation Learning (GAIL)
================================================

Implements `Generative Adversarial Imitation Learning <https://arxiv.org/abs/1606.03476>`_.

API
===
.. autoclass:: imitation.algorithms.adversarial.GAIL
    :members:
    :inherited-members:
    :noindex:

.. autoclass:: imitation.algorithms.adversarial.AdversarialTrainer
    :members:
    :inherited-members:
    :noindex:
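
GAIL trains a discriminator to tell expert state-action pairs from generator ones. A minimal sketch of that binary cross-entropy objective follows, assuming the convention that label 1 means "expert" (the paper and individual implementations differ on this sign convention) and using hypothetical logits. ``bce_with_logits`` and ``discriminator_loss`` are illustrative names, not ``imitation`` functions.

```python
import math


def bce_with_logits(logit: float, label: float) -> float:
    """Numerically stable binary cross-entropy for a single logit."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def discriminator_loss(expert_logits, gen_logits) -> float:
    """Mean loss with expert samples labelled 1 and generator samples labelled 0."""
    losses = [bce_with_logits(z, 1.0) for z in expert_logits]
    losses += [bce_with_logits(z, 0.0) for z in gen_logits]
    return sum(losses) / len(losses)


# An uninformed discriminator (all logits 0) versus one that separates the data.
chance_loss = discriminator_loss([0.0, 0.0], [0.0, 0.0])
separated_loss = discriminator_loss([3.0, 2.5], [-3.0, -2.0])
print(chance_loss, separated_loss)
```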
104 changes: 104 additions & 0 deletions docs/guide/gettingstarted.rst
@@ -0,0 +1,104 @@
===============
Getting Started
===============


CLI Quickstart
==============

We provide several CLI scripts as front-ends to the algorithms implemented in ``imitation``.
These use `Sacred <https://github.com/idsia/sacred>`_ for configuration and replicability.

For information on how to configure Sacred CLI options, see the `Sacred docs <https://sacred.readthedocs.io/en/stable/>`_.

.. code-block:: bash

    # Train PPO agent on cartpole and collect expert demonstrations. Tensorboard logs saved
    # in `quickstart/rl/`
    python -m imitation.scripts.expert_demos with fast cartpole log_dir=quickstart/rl/

    # Train GAIL from demonstrations. Tensorboard logs saved in output/ (default log directory).
    python -m imitation.scripts.train_adversarial with fast gail cartpole \
        rollout_path=quickstart/rl/rollouts/final.pkl

    # Train AIRL from demonstrations. Tensorboard logs saved in output/ (default log directory).
    python -m imitation.scripts.train_adversarial with fast airl cartpole \
        rollout_path=quickstart/rl/rollouts/final.pkl

.. note::
    Remove the ``fast`` option from the commands above to allow training to run to completion.

.. tip::
    ``python -m imitation.scripts.expert_demos print_config`` will list Sacred script options.
    These configuration options are also documented in each script's docstrings.


Python Interface Quickstart
===========================

Here's an `example script`_ that loads CartPole-v1 demonstrations and trains BC, GAIL, and
AIRL models on that data.

.. _example script: https://github.com/HumanCompatibleAI/imitation/blob/master/examples/quickstart.py

.. code-block:: python

    """Loads CartPole-v1 demonstrations and trains BC, GAIL, and AIRL models on that data.
    """

    import pathlib
    import pickle
    import tempfile

    import stable_baselines3 as sb3

    from imitation.algorithms import adversarial, bc
    from imitation.data import rollout
    from imitation.util import logger, util

    # Load pickled test demonstrations.
    with open("tests/data/expert_models/cartpole_0/rollouts/final.pkl", "rb") as f:
        # This is a list of `imitation.data.types.Trajectory`, where
        # every instance contains observations and actions for a single expert
        # demonstration.
        trajectories = pickle.load(f)

    # Convert List[types.Trajectory] to an instance of `imitation.data.types.Transitions`.
    # This is a more general dataclass containing unordered
    # (observation, actions, next_observation) transitions.
    transitions = rollout.flatten_trajectories(trajectories)

    venv = util.make_vec_env("CartPole-v1", n_envs=2)

    tempdir = tempfile.TemporaryDirectory(prefix="quickstart")
    tempdir_path = pathlib.Path(tempdir.name)
    print(f"All Tensorboards and logging are being written inside {tempdir_path}/.")

    # Train BC on expert data.
    # BC also accepts as `expert_data` any PyTorch-style DataLoader that iterates over
    # dictionaries containing observations and actions.
    logger.configure(tempdir_path / "BC/")
    bc_trainer = bc.BC(venv.observation_space, venv.action_space, expert_data=transitions)
    bc_trainer.train(n_epochs=1)

    # Train GAIL on expert data.
    # GAIL and AIRL also accept as `expert_data` any PyTorch-style DataLoader that
    # iterates over dictionaries containing observations, actions, and next_observations.
    logger.configure(tempdir_path / "GAIL/")
    gail_trainer = adversarial.GAIL(
        venv,
        expert_data=transitions,
        expert_batch_size=32,
        gen_algo=sb3.PPO("MlpPolicy", venv, verbose=1, n_steps=1024),
    )
    gail_trainer.train(total_timesteps=2048)

    # Train AIRL on expert data.
    logger.configure(tempdir_path / "AIRL/")
    airl_trainer = adversarial.AIRL(
        venv,
        expert_data=transitions,
        expert_batch_size=32,
        gen_algo=sb3.PPO("MlpPolicy", venv, verbose=1, n_steps=1024),
    )
    airl_trainer.train(total_timesteps=2048)
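
The trajectory-to-transitions conversion in the script above can be pictured with plain dataclasses. This is a rough sketch under stated assumptions: ``ToyTrajectory``, ``ToyTransition``, and ``flatten`` are hypothetical stand-ins, and the sketch assumes each trajectory stores one more observation than it has actions. The real ``rollout.flatten_trajectories`` and ``Transitions`` types may carry additional fields.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class ToyTrajectory:
    """Stand-in for one demonstration: one more observation than actions."""

    obs: Sequence[float]
    acts: Sequence[int]


@dataclass
class ToyTransition:
    """A single (observation, action, next_observation) step."""

    obs: float
    act: int
    next_obs: float


def flatten(trajectories: List[ToyTrajectory]) -> List[ToyTransition]:
    """Split each trajectory into individual steps and concatenate them."""
    out = []
    for traj in trajectories:
        assert len(traj.obs) == len(traj.acts) + 1
        for t, act in enumerate(traj.acts):
            out.append(ToyTransition(traj.obs[t], act, traj.obs[t + 1]))
    return out


demo = [ToyTrajectory([0.0, 0.1, 0.2], [1, 0]), ToyTrajectory([1.0, 0.9], [0])]
flat = flatten(demo)
print(len(flat))
```

Once flattened, the steps no longer record which trajectory they came from, which is why the script's comments describe the transitions as unordered.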
26 changes: 26 additions & 0 deletions docs/guide/install.rst
@@ -0,0 +1,26 @@
============
Installation
============

**Installing PyPI release**

.. code-block:: bash

    pip install imitation

**Install latest commit**

.. code-block:: bash

    git clone http://github.com/HumanCompatibleAI/imitation
    cd imitation
    pip install -e .

**Optional MuJoCo Dependency**

Follow the instructions to install `mujoco\_py v1.5 here`_.

.. _mujoco_py v1.5 here:
    https://github.com/openai/mujoco-py/tree/498b451a03fb61e5bdfcb6956d8d7c881b1098b5#install-mujoco
47 changes: 42 additions & 5 deletions docs/index.rst
@@ -3,18 +3,55 @@
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
Welcome to imitation's documentation!
=====================================
=================================================================
Imitation: Clean implementations of Imitation Learning algorithms
=================================================================

``imitation`` is available on GitHub at http://github.com/HumanCompatibleAI/imitation


Main Features
~~~~~~~~~~~~~

- Built on and compatible with Stable Baselines 3 (SB3).
- Modular PyTorch implementations of Behavioral Cloning, DAgger, GAIL, and AIRL that can
  train arbitrary SB3 policies.
- GAIL and AIRL have customizable reward and discriminator networks.
- Scripts to train policies using SB3 and save rollouts from these policies as synthetic "expert" demonstrations.
- Data structures and scripts for loading and storing expert demonstrations.


.. toctree::
:maxdepth: 2
:caption: Contents:
:caption: User Guide
:hidden:

guide/install
guide/gettingstarted


.. toctree::
:maxdepth: 2
:caption: Algorithms
:hidden:

algorithms/bc
algorithms/gail
algorithms/airl
algorithms/dagger


.. toctree::
:maxdepth: 3
:caption: API
:hidden:

modules/imitation



Indices and tables
Index
==================

* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`
9 changes: 0 additions & 9 deletions docs/modules.rst

This file was deleted.

File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,3 @@ Module contents
---------------

.. automodule:: imitation.envs.examples

Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,3 @@ Module contents
---------------

.. automodule:: imitation.envs

Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ imitation.policies.base module
------------------------------

.. automodule:: imitation.policies.base

imitation.policies.serialize module
-----------------------------------

Expand All @@ -19,4 +19,3 @@ Module contents
---------------

.. automodule:: imitation.policies

Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@ Submodules
imitation.rewards.discrim\_nets module
--------------------------------------


.. automodule:: imitation.rewards.discrim_nets

imitation.rewards.reward\_nets module
-------------------------------------

.. automodule:: imitation.rewards.reward_nets

imitation.rewards.serialize
---------------------------

Expand All @@ -24,4 +25,3 @@ Module contents
---------------

.. automodule:: imitation.rewards

2 changes: 1 addition & 1 deletion docs/imitation.rst → docs/modules/imitation.rst
Expand Up @@ -18,4 +18,4 @@ Subpackages
Module contents
---------------

.. automodule:: imitation
.. automodule:: imitation
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,27 @@ imitation.scripts.config.analyze module
----------------------------------------

.. automodule:: imitation.scripts.config.analyze

imitation.scripts.config.common module
--------------------------------------

.. automodule:: imitation.scripts.config.common

imitation.scripts.config.eval\_policy module
--------------------------------------------

.. automodule:: imitation.scripts.config.eval_policy

imitation.scripts.config.expert\_demos module
---------------------------------------------

.. automodule:: imitation.scripts.config.expert_demos

imitation.scripts.config.parallel module
----------------------------------------

.. automodule:: imitation.scripts.config.parallel

imitation.scripts.config.train\_adversarial module
--------------------------------------------------

Expand All @@ -38,4 +38,4 @@ imitation.scripts.config.train\_adversarial module
Module contents
---------------

.. automodule:: imitation.scripts.config
.. automodule:: imitation.scripts.config
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,17 @@ imitation.scripts.analyze module
--------------------------------

.. automodule:: imitation.scripts.analyze

imitation.scripts.eval\_policy module
-------------------------------------

.. automodule:: imitation.scripts.eval_policy

imitation.scripts.expert\_demos module
--------------------------------------

.. automodule:: imitation.scripts.expert_demos

imitation.scripts.parallel module
-------------------------------------------

Expand All @@ -35,10 +35,9 @@ imitation.scripts.train\_adversarial module
-------------------------------------------

.. automodule:: imitation.scripts.train_adversarial


Module contents
---------------

.. automodule:: imitation.scripts

File renamed without changes.
