Agents: Short-Horizon Actor Critic #262

Open
wants to merge 10 commits into
base: main

Conversation

peabody124
Contributor

@peabody124 peabody124 commented Nov 19, 2022

This is an implementation of https://arxiv.org/pdf/2204.07137.pdf

Not sure if there is interest merging this into the main branch. This might be an algorithm
worth supporting as it leverages the differentiable simulator to outperform PPO according
to the paper.

Note that many of the environments don't actually have rewards that are differentiable
w.r.t. the actions, in which case this algorithm performs poorly. For example, the fast
environment used for testing APG and SHAC isn't. I added a fast_differentiable env and
also made APG use this by default, after which the performance is much better.
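
To illustrate the point about rewards that are not differentiable w.r.t. the actions, here is a minimal sketch. It is a toy 1-D point mass, not the actual brax `fast` environment; `DT`, `rollout_return`, and the dynamics are assumptions made up for illustration. When the control signal is a thresholded action (`action > 0`), the gradient of the return with respect to the action is exactly zero, so gradient-based methods such as APG or SHAC get no learning signal.

```python
import jax
import jax.numpy as jnp

DT = 0.02  # hypothetical time step


def rollout_return(action, thresholded, steps=10):
    """Sum of positions of a 1-D point mass driven by a constant action."""
    if thresholded:
        # Control only depends on the sign of the action (piecewise constant).
        control = jnp.where(action > 0, 1.0, 0.0)
    else:
        # Control is the action itself (smooth in the action).
        control = action

    def step(carry, _):
        pos, vel = carry
        vel = vel + control * DT
        pos = pos + vel * DT
        return (pos, vel), pos

    init = (jnp.zeros(()), jnp.zeros(()))
    _, positions = jax.lax.scan(step, init, None, length=steps)
    return positions.sum()


grad_thresholded = jax.grad(rollout_return)(0.5, True)
grad_smooth = jax.grad(rollout_return)(0.5, False)
print(grad_thresholded, grad_smooth)
```

In this toy setup the thresholded variant yields a zero gradient while the smooth variant yields a nonzero one, which matches the behaviour described above for the original versus the differentiable fast env.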

Still could do with tuning for environments and replicating the performance benefits seen
in the original manuscript.

Addressed #247

Still needs a target network.
Need to separate the policy learning so we can differentiate
through the experience back to the policy network.
The original fast env is not differentiable, due to using action>0 as the control
signal. Thus APG and other differentiable solvers like
SHAC perform poorly.
This gives a performance score comparable to other
algos.
@peabody124
Contributor Author

BTW I would love any feedback on this. It seems much slower than I had hoped. I'm not sure if there are some functions that should be jitted, such as the policy gradient function https://github.com/peabody124/brax/blob/shac/brax/training/agents/shac/train.py#L185

Starting to see progress training the ant environment.
@cdfreeman-google
Collaborator

Oh wow this is great! Thanks for the clean implementation!

Could you provide us some benchmarking of this method versus the other ones we support? Something like training curves vs. Ant for this / APG without this / PPO / SAC with reward and wallclock time on a public TPU colab runtime would be fantastic. Basically, we want to make sure it's competitive before committing to supporting this.

@EelcoHoogendoorn

EelcoHoogendoorn commented Jan 18, 2023

Hey, nice indeed to see this. I'm also curious whether the performance benefits claimed in the paper replicate in this implementation, on a scenario that is not too terribly synthetic. Should we read your comment in the original post up top as saying that it does not appear to provide a benefit over PPO in any of the currently implemented environments? It's been a while since I read the paper, but that would kinda surprise me; or are you saying that isn't surprising, given the characteristics of the environments as formulated now?

@peabody124
Contributor Author

Great to see the interest! I'm pretty busy right now but hopefully in Feb I can get some benchmarking numbers. I'm not blown away by the performance I saw. I was using it with a humanoid environment, and have the impression that differentiating through the physics engine causes quite an increase in wall time. This may not be the SHAC aspect, because I see similar things with a PPO modification I made to try to differentiate and optimize some anthropomorphic properties of the models, and enabling this easily halves the steps per second (sps).

However, I should tune it up on ant and cheetah and produce some benchmarks to really evaluate it and see if it is worth merging, or something to motivate some optimization.

@EelcoHoogendoorn

Re-reading the paper, the only place they claim noticeable gains in wall clock time is the very high degree-of-freedom walker env they try, which I'm not sure has an analogue in brax. For the ant there does appear to be a modest objective benefit, but of a magnitude that I can easily see vaporizing under an alternative implementation, as these things often do. Still, when you say you are not blown away by the wall time, what are we talking about? A similar order of magnitude as PPO, or something much worse?

If it's true that the backprop is not really worth the trouble, it's interesting to reflect on that. In theory, backprop should have a computational cost similar to a forward pass. The results of DiffTaichi also show this in practice. Is this then more the fault of JAX than anything? Frankly, after working with JAX quite a bit in the past year, I'm still kinda uncomfortable with the black-box nature of some things; the fact that differentiating through a scan 'just works' is pretty cool, but the fact that I have no control over checkpointing, or good visibility into what is optimized and what isn't, is a little eerie.

Even if the wall clock time is not an obvious win given the constraints of brax/JAX, I still think there is value in merging/maintaining this, since it remains a very valuable benchmark for probing that exact question: whether there actually is a point to this whole differentiable physics in JAX thing. If the answer is no, that's also good to know. And aside from wall time, I can see both theoretical reasons, and empirical support in their paper, that a hybrid of a learned critic plus BPTT should converge better / more robustly in some types of environments; which can be a fantastic thing even if you have to wait 10x longer for the result.
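
For readers unfamiliar with the "learned critic plus BPTT" hybrid, the following is a rough sketch of the kind of policy objective involved: a short differentiable rollout whose discounted rewards are summed, with a critic estimating the value beyond the horizon. Everything here (`dynamics`, `reward_fn`, `critic`, the linear policy, `GAMMA`, `HORIZON`) is a hypothetical stand-in, not the code in this PR or the paper's reference implementation.

```python
import jax
import jax.numpy as jnp

GAMMA, HORIZON = 0.99, 8  # hypothetical discount and unroll length


def dynamics(state, action):
    return state + 0.1 * action  # stand-in for a differentiable simulator step


def reward_fn(state, action):
    return -(state ** 2) - 0.01 * action ** 2


def critic(critic_params, state):
    return critic_params * state ** 2  # stand-in for a learned value network


def policy_loss(policy_params, critic_params, start_state):
    def step(state, t):
        action = policy_params * state  # linear policy, for illustration only
        reward = reward_fn(state, action)
        return dynamics(state, action), GAMMA ** t * reward

    final_state, discounted = jax.lax.scan(
        step, start_state, jnp.arange(HORIZON, dtype=jnp.float32))
    # The critic bootstraps the return beyond the short horizon.
    terminal_value = GAMMA ** HORIZON * critic(critic_params, final_state)
    return -(discounted.sum() + terminal_value)


# Gradient w.r.t. the policy parameter flows through the rollout and the critic.
loss, grad = jax.value_and_grad(policy_loss)(
    jnp.float32(-0.5), jnp.float32(-1.0), jnp.float32(1.0))
print(loss, grad)
```

The short horizon keeps the backpropagation chain small, while the critic term is what distinguishes this from plain truncated BPTT.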

@EelcoHoogendoorn

EelcoHoogendoorn commented Feb 10, 2023

@peabody124: hope you don't mind the ping, but I'm still curious; can you give a one-sentence impression of the kind of performance you did observe, compared to PPO for instance?

@cdfreeman-google: Last time I did a deep dive into brax, a year ago, I don't think its contact implementation was particularly differentiable. Where would you say it stands today with respect to the known issues with differentiable contact handling? Can we opt in to a moderate penalty-based form of contact that should preserve differentiability? Or are there other solutions, or solutions on the horizon? Is there any currently supported (nontrivial) env where you feel SHAC 'should' work fine; or would it perhaps be good to implement something like an inverted cartpole in brax, as a trivially differentiable but not entirely trivial benchmark for this kind of algorithm?

EDIT: ah, I see there is an inverted cartpole already; seems like that should be differentiable, and a good benchmark for this algorithm? I was thinking of adding an even simpler 2-dof pure pendulum, without a cart, so that the observation space is purely 2D and one can actually nicely visualize 2D maps of the learning progress of the value function.

@EelcoHoogendoorn

EelcoHoogendoorn commented Feb 12, 2023

I refamiliarized myself with the latest brax code; looking at this line here, I would expect brax to indeed fare poorly in envs that involve contact. Since there is no configurable compliance term here, this line implies that all contact handling is hard contact handling, and without time-of-impact logic you will run into issues with your gradients.

It would be a simple fix to add a compliance term to the linked line though; if you can tune your settings such that a compliant contact resolves over several timesteps, allowing for some interpenetration, you will recover usable gradients. But with contact handling like this I wouldn't expect SHAC to work very well, if at all, on something like a walker environment.
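
As a rough illustration of what a compliant (penalty-based) contact looks like, here is a toy 1-D sketch. It is not brax's contact model; the constants, the `final_height` function, and the spring-only penalty force are assumptions chosen so that a contact takes several timesteps to resolve. Because the contact force is a continuous function of the penetration depth, gradients can be propagated through the bounce.

```python
import jax
import jax.numpy as jnp

DT, GRAVITY, STEPS = 0.01, -9.81, 100  # hypothetical constants


def final_height(initial_height, stiffness):
    """Drop a unit point mass onto the ground (height 0) with penalty contact."""

    def step(carry, _):
        pos, vel = carry
        # Interpenetration is allowed; the ground pushes back like a spring.
        penetration = jnp.maximum(0.0, -pos)
        accel = GRAVITY + stiffness * penetration
        vel = vel + accel * DT
        pos = pos + vel * DT
        return (pos, vel), None

    init = (jnp.asarray(initial_height, dtype=jnp.float32), jnp.zeros(()))
    (pos, _), _ = jax.lax.scan(step, init, None, length=STEPS)
    return pos


height = final_height(0.2, 1000.0)
# Gradient of the final height w.r.t. the drop height, through the contacts.
height_grad = jax.grad(final_height)(0.2, 1000.0)
print(height, height_grad)
```

With these numbers the spring's half period is roughly ten timesteps, so the contact is "well resolved"; raising `stiffness` shortens that and moves the model toward hard contact.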

That being said, contactless systems like a cartpole should still provide a good benchmark, I think; and my expectation would be that SHAC would outperform APG or PPO, perhaps in terms of wall time or total reward.

On a related note, I've been working on really reviewing this PR in detail, and it sure is not easy on me! It's pretty clean code, but all these nested function definitions are hard to pull apart and test in isolation, either formally or just mentally. While I do love JAX's functional paradigm, I suppose I am still not as fluent in reasoning about it as I am with more traditional imperative code. For the record, I have not found anything wrong with it yet; though my code has diverged a lot from the code as it is here in my attempts to pull it apart into more unit-testable components.

@peabody124
Contributor Author

Thanks @EelcoHoogendoorn for looking into it.

I totally relate to you on the challenges of understanding the different nesting layers of rollouts etc. Quite confusing to track. For another project, I made a modified PPO where various anthropomorphic properties had to be propagated through the rollouts, which also really pushed me to delve into it.

I'm still trying to carve out some time to benchmark this more fully. I know that in my humanoid testing, the performance benefits were underwhelming. However, I think there is a lot to be said for a benchmark like you described. The fast differentiable unit test and humanoid were my two main prior tests.

@EelcoHoogendoorn

I've been moving on from the unit testing to some cartpole benchmarking myself; but to give an initial impression, I am not having much luck yet.

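# Imports assume the brax v1 module layout.
from brax import jumpy as jp
from brax.envs.inverted_pendulum import InvertedPendulum

class HardInvertedPendulum(InvertedPendulum):
  """Inverted pendulum whose noise amplitude is set by a hardness parameter h."""
  def __init__(self, h=0.1):
    super().__init__()
    self.h = h
  def _noise(self, rng):
    # Uniform noise in [-h, h] for each joint degree of freedom.
    return jp.random_uniform(rng, (self.sys.num_joint_dof,), -1, 1) * self.h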

I've added a tunable hardness parameter since the default cartpole is kind of trivial and solved in one epoch by PPO. It's probably just my hyperparams, but even with h=0.3 PPO isn't making much progress on this problem, so that does make it an interesting benchmark.

I suppose reverse-engineering the actual intended meaning of the high-level parameters would be the first order of business... just wasted 3 hours coming to realize the importance of the num_evals parameter (I guess in my time the kids used to call that n_epochs; or maybe I still don't understand its true meaning?). Good times. I realize you've just been following the brax conventions, which in turn just follow what appears to be the broader RL community spirit of 'real men don't need docstrings'. But yeah.

@EelcoHoogendoorn

One thing I'm struggling with is the intended meaning of num_minibatches. I imagined it's only intended to impact the value function training, but it also appears in the length of the policy rollout? It feels wrong that you can completely screw over the sensibility of your policy rollout length by changing a parameter which shouldn't really have anything to do with policy training; but it's perfectly possible I misunderstand the actual intent.

@EelcoHoogendoorn

EelcoHoogendoorn commented Feb 13, 2023

Indeed, I'm also getting massive slowdowns in both compile time and runtime compared to the same simple env using PPO (and zero convergence as well so far, though that could very well be a matter of hyperparameters).

I'm kinda wondering if the compiler isn't choking a bit here on the sheer number of nested scans. There is the action_repeat scan of the env_wrapper, and the num_training_steps_per_epoch unroll scan (which I kinda doubt adds much value when a single train step already takes a solid amount of time?). And then there is also the extra scan in rollout_loss_fn to chain multiple scanned unroll_lengths together. The purpose of this nested unrolling is not obvious to me, by the way. Why not just train on / process a single unroll_length trajectory of data at a time? It's not clear to me that the original paper is arguing in favor of processing multiple such chained rollouts at the same time, but I could be missing something.
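
To make the nesting concrete, here is a stripped-down sketch of three scans stacked the way the comment describes: an action-repeat scan inside a single rollout scan inside a scan that chains rollouts. It is only a schematic with made-up names and sizes (`ACTION_REPEAT`, `UNROLL_LENGTH`, `NUM_UNROLLS`, `physics_step`, and a fixed stand-in policy), not the structure of the PR's `train.py`.

```python
import jax
import jax.numpy as jnp

ACTION_REPEAT, UNROLL_LENGTH, NUM_UNROLLS = 2, 5, 3  # hypothetical sizes


def physics_step(state, action):
    return state + 0.1 * action  # stand-in for a simulator step


def env_step(state, action):
    # Innermost scan: apply the same action for several physics steps.
    def body(s, _):
        return physics_step(s, action), None

    state, _ = jax.lax.scan(body, state, None, length=ACTION_REPEAT)
    return state


def unroll(state, _):
    # Middle scan: one short rollout under a fixed stand-in policy (a = -s).
    def body(s, _):
        s = env_step(s, -s)
        return s, -(s ** 2)  # (next state, reward)

    return jax.lax.scan(body, state, None, length=UNROLL_LENGTH)


def chained_rewards(state):
    # Outer scan: chain rollouts, each starting where the previous one ended.
    _, rewards = jax.lax.scan(unroll, state, None, length=NUM_UNROLLS)
    return rewards


rewards = chained_rewards(jnp.float32(1.0))
grad = jax.grad(lambda s: chained_rewards(s).sum())(jnp.float32(1.0))
print(rewards.shape, grad)
```

Differentiating the summed reward traverses all three scan levels in reverse, which is the part whose compile and run cost is being questioned here.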

So I count 3 nested scans just to do a rollout, plus what's specific to the training loop. I don't have any concrete evidence that this might be an issue; but if it's half as confusing to the JAX compiler as it is to me, it might spell trouble, and I am not aware of any other uses of JAX where this amount of nested scanning is going on. I've seen the JAX compiler stack freak out about more mundane constructions; it's not that mature yet. But I could be completely off-base here.

I'll do a set of experiments versus APG; I've had decent experiences with that myself, and it'd be a nice comparison on this env to tease apart any specific trouble with using grads versus other SHAC-specific factors.

@EelcoHoogendoorn

Heh; I've moved on to setting up the simplest possible SHAC first, and I've gotten some quite nice results on an inverted pendulum. It did teach me a lot of lessons about JAX development though. I've become quite adamant about putting shape asserts on all input and output shapes. It's really easy to get broadcasting errors that are disastrous but hard to notice under JIT, with things like your action/reward returning either shape (1,) or shape (), and so on.
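
A small sketch of the kind of silent broadcasting error meant here, and of a shape assert that catches it. The function name `td_error` and the array sizes are made up for illustration. Shapes are static in JAX, so an assert like this also runs at trace time under `jax.jit`.

```python
import jax.numpy as jnp


def td_error(rewards, values):
    # Both inputs are expected to be per-environment vectors of equal shape.
    assert rewards.shape == values.shape, (rewards.shape, values.shape)
    return rewards - values


# Four environments; the reward carries a trailing axis of size 1.
rewards_column = jnp.ones((4, 1))
values = jnp.zeros((4,))

# Without a check, (4, 1) - (4,) broadcasts to (4, 4) and raises no error.
silent = rewards_column - values
print(silent.shape)

# With the trailing axis squeezed out, the checked version passes.
checked = td_error(rewards_column[:, 0], values)
print(checked.shape)
```

A mean over `silent` would still return a plausible-looking scalar, which is why this class of bug is easy to miss.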

@UltronAI

UltronAI commented Jul 26, 2023

Hi @peabody124 and @cdfreeman-google,

I am wondering if there are any recent benchmarking results comparing how SHAC (or APG) performs on the current Brax environments versus PPO or SAC. I tried training the built-in APG agent but was unable to achieve comparable results to PPO - the agent didn't learn as well.

I attempted the parameters suggested in this issue but ran into OOM errors even on an A100 (80GB) with 1024 parallel environments. The issue comments also indicate APG is much worse than PPO, although with different evaluation settings.

I also tried the implemented SHAC code from the PR but was unable to get the expected level of performance, especially in contact-rich environments.

Have there been any promising benchmarking results yet showing these differentiability-leveraging algorithms like SHAC or APG can match the performance of PPO or SAC on the Brax environments? I'm curious if they can achieve parity.

Also, NaNs and exploding gradients are common problems even with a short truncation length (e.g. 10). While I can detect NaN gradients and set them to zero and clip the gradient norm, I'm not sure if this is the right approach. Frankly, I'm not sure if some reward shaping or reward smoothing is needed before applying algorithms like APG and SHAC.
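
For reference, the workaround described above (zeroing non-finite gradients, then clipping the gradient norm) could look roughly like the sketch below. The helper name `sanitize_and_clip` and the example values are hypothetical; whether this is the right remedy, rather than a way of hiding the underlying instability, is exactly the open question.

```python
import jax
import jax.numpy as jnp


def sanitize_and_clip(grads, max_norm):
    """Zero non-finite gradient entries, then rescale to a maximum global norm."""
    grads = jax.tree_util.tree_map(
        lambda g: jnp.where(jnp.isfinite(g), g, 0.0), grads)
    leaves = jax.tree_util.tree_leaves(grads)
    global_norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
    # Scale is 1 when the norm is already below max_norm.
    scale = jnp.minimum(1.0, max_norm / (global_norm + 1e-8))
    return jax.tree_util.tree_map(lambda g: g * scale, grads)


grads = {"w": jnp.array([3.0, jnp.nan]), "b": jnp.array([4.0])}
clipped = sanitize_and_clip(grads, max_norm=1.0)
print(clipped)
```

Note that zeroing a NaN entry discards the update for that parameter entirely, so it is worth logging how often it happens.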

Please let me know if you have any insights on tuning these algorithms or can point me to recent benchmark comparisons. Thanks!

@EelcoHoogendoorn

I didn't try any further for a literal reproduction of the paper. However, I have iterated quite a bit within the conceptual space of learned critics and short-horizon differentiable rollouts. My general impression is that many of the specifics of the SHAC paper are not essential and can readily be tuned to specific applications. I did not get around to any formal or comprehensive benchmarking, but qualitatively I was quite happy, getting reasonable performance and decent robustness.

@UltronAI

Thank you for your response! @EelcoHoogendoorn

I was wondering if you could provide further details about your personal experiences with SHAC and APG. Have you achieved promising results using Brax or any other environments? Lately, I've been facing challenges while attempting to apply SHAC to more extensive benchmarks. If you have any advice or insights, I would greatly appreciate it. Thank you!

@EelcoHoogendoorn

EelcoHoogendoorn commented Jul 28, 2023

It's not part of a public repo unfortunately; but it's a context with a simulator qualitatively similar to Brax, the most salient difference being that the compliance of contacts is configurable using XPBD logic. The problems I have been applying it to have either been contact-free, or had 'well resolved' contacts, meaning the contacts were sufficiently compliant that they'd be resolved over multiple simulation timesteps. That's a significant difference from brax, or from the hard-contact context I've seen in many papers trying to use differentiable simulation, so I cannot really comment on that context. But for what it's worth, in such a context, the advantages over pure APG / no learned critic, or things like PPO, seem quite solid.

4 participants