QF_Loss backprops policy network #5
Comments
Is this the error you are talking about? I have been trying to debug this too, and I can add the full outputs if helpful.
I don't think the Qf_loss backprops the policy_loss, because they use different optimizers for the policy and the Q networks, respectively. In any case, have you tried to move the:

just after the computation of the policy loss?
I'm currently testing this (small change) PR. It blocks the gradient flow to the Q functions in the policy update, which prevents the error.
I am afraid that change will break the learning of the policy itself, because

```python
policy_loss = (alpha*log_pi - q_new_actions.detach()).mean()
```

will also block the gradient flow to the policy, since q_new_actions is computed as below:

```python
if self.num_qs == 1:
    q_new_actions = self.qf1(obs, new_obs_actions)
else:
    q_new_actions = torch.min(
        self.qf1(obs, new_obs_actions),
        self.qf2(obs, new_obs_actions),
    )
```

and the new_obs_actions come from the policy:

```python
new_obs_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy(
    obs, reparameterize=True, return_log_prob=True,
)
```

Therefore, the policy weights will only be updated to minimize the alpha*log_pi (entropy) term.

(As a personal anecdote, I did the exact same thing when implementing SAC a while ago. It is critical not to detach the Q values when updating the policy. I think that is also the main reason the optimizers are separated for the policy and the Q networks: so that the action value of the policy can be backproped through the Q functions without altering the weights of the latter.)
@olliejday I think the error is caused by the PyTorch version. If you try torch 1.4, that could fix it; something else might break, though. Could you please confirm whether this is the issue or not?
@olliejday @dosssman The Q-function detach will not work, since then the policy is not trained using the Q-function, which is incorrect.
I'm looking at what @dosssman says, but in reverse (i.e. moving the Q functions rather than the policy update). So moving

to just before the policy update but after the Q updates, i.e. directly above here:

This stops the error and seems to match the order in the paper. I'm trying it now; otherwise I will test which torch versions work. Thanks
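A sketch of the reordering described here, with names assumed from the snippets quoted above (compute_qf_losses is a hypothetical placeholder for the Q-loss computation): the Q updates run first, and the policy forward pass and update are built only afterwards, so nothing in the policy-loss graph is modified in place between its forward and backward pass.

```python
import torch

# 1) Q-function losses and optimizer steps come first.
qf1_loss, qf2_loss = compute_qf_losses(batch)  # hypothetical helper
qf1_optimizer.zero_grad()
qf1_loss.backward()
qf1_optimizer.step()
qf2_optimizer.zero_grad()
qf2_loss.backward()
qf2_optimizer.step()

# 2) The policy forward pass and loss are built only after the Q
#    optimizers have stepped, so no tensor saved for
#    policy_loss.backward() is modified in place before it runs.
new_obs_actions, policy_mean, policy_log_std, log_pi, *_ = policy(
    obs, reparameterize=True, return_log_prob=True,
)
q_new_actions = torch.min(
    qf1(obs, new_obs_actions),
    qf2(obs, new_obs_actions),
)
policy_loss = (alpha * log_pi - q_new_actions).mean()
policy_optimizer.zero_grad()
policy_loss.backward()
policy_optimizer.step()
```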
Hello, any updates on this issue? Unmodified code w/ torch==1.4: 3638.71
Hi, I ended up just reverting the torch version to 1.4.
Adding
@dssrgu Thanks for your contribution! I have adapted your commits, but I have some questions.
@sweetice Hi, I did test the modified version against the original version (which was run on torch == 1.4), and the two versions had similar performance on the D4RL datasets. I do not have the actual values right now, though. Note: you may have to additionally correct the
@dssrgu Did you get results similar to the values reported in the D4RL paper? I tried both the paper hyperparameters (policy_lr=3e-5, lagrange_thresh=10.0) and the ones recommended in this GitHub repo (policy_lr=1e-4, lagrange_thresh=-1.0), on PyTorch 1.4 and 1.7+, but I cannot obtain similar values in some environments. For example, there is a big difference on 'halfcheetah-medium-expert-v0', and a huge difference on the Adroit tasks, like 'pen-human', 'hammer-human' and 'door-human'. Do you know how to set the hyperparameters to make CQL work in most cases? Thanks!
@Zhendong-Wang I found policy_lr=1e-4, min_q_weight=10.0, lagrange_thresh=-1.0 to work fairly well on most of the Gym environments, though I used '*-v2' datasets. Exceptionally, for 'halfcheetah-random-v2', policy_lr=1e-4, min_q_weight=1.0, lagrange_thresh=10.0 works well. If the problem is only with the medium-expert datasets, it seems the algorithm needs to run 3000 epochs to converge. For the Adroit tasks, I also could not reproduce the results...
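Collecting the settings reported in this comment as a plain dictionary for reference (the key names simply mirror the hyperparameter names used in this thread; they are not tied to the repo's actual argument parser):

```python
# Settings reported above to work fairly well on most '*-v2' Gym datasets.
gym_v2_settings = {
    "policy_lr": 1e-4,
    "min_q_weight": 10.0,
    "lagrange_thresh": -1.0,
}

# Reported exception for 'halfcheetah-random-v2'.
halfcheetah_random_v2_settings = {
    "policy_lr": 1e-4,
    "min_q_weight": 1.0,
    "lagrange_thresh": 10.0,
}
```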
@dssrgu Could you give me some advice? I used the hyperparameters you recommended, but the results on the 'medium' envs do not keep in line with the CQL paper results. I believe there must be something wrong in my settings, which are:

Do you have any suggestions for me?
@cangcn Actually, with the GitHub code and the hyperparameters recommended in the README file, I cannot reproduce the results reported in the D4RL paper, even on the Gym tasks. I tried both 'v2' and 'v0'. The performance on 'v2' is generally better than on 'v0', but it still cannot match most of the reported results, though it is mentioned in D4RL that they used 'v0' for a fair comparison.
Hi @olliejday, I think this issue shouldn't be resolved by just switching back to torch versions below 1.5 (i.e. <= 1.4), because then the reproducibility relies on a bug in the torch code (see this thread). According to the linked discussion, in torch < 1.5, even when the code runs and trains the network parameters, the computed gradients can be incorrect, which is fixed in torch >= 1.5. Hopefully the PR that @dssrgu posted can solve this issue, but for some tasks it seems the results cannot be reproduced. I hope the original author @aviralkumar2907 can provide some feedback on this matter :) In the meantime, I think I'll use @dssrgu's modifications to make the code runnable. Thanks!
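For context, the torch >= 1.5 behavior mentioned here can be reproduced in isolation. The sketch below is not CQL code; it only shows the general pattern: an optimizer step modifies parameters in place between a forward pass and the corresponding backward pass, which torch >= 1.5 reports as an error and torch <= 1.4 silently accepts.

```python
import torch

q = torch.nn.Linear(4, 1)   # stands in for a Q network
pi = torch.nn.Linear(4, 4)  # stands in for a policy network
q_opt = torch.optim.SGD(q.parameters(), lr=1e-2)
pi_opt = torch.optim.SGD(pi.parameters(), lr=1e-2)

obs = torch.randn(8, 4)

# "Policy loss": its graph goes through q, so q's weights are saved for backward.
policy_loss = -q(pi(obs)).mean()

# "Q loss": computed and stepped before the policy backward pass.
q_loss = (q(obs) - torch.randn(8, 1)).pow(2).mean()
q_opt.zero_grad()
q_loss.backward()
q_opt.step()  # in-place update of q's parameters

pi_opt.zero_grad()
# torch >= 1.5: RuntimeError about a variable modified by an inplace operation.
# torch <= 1.4: runs, but the backward pass sees the already-updated q weights,
# not the ones used in the forward pass.
policy_loss.backward()
pi_opt.step()
```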
In the CQL trainer, the policy_loss is formulated before the QF_Loss is, but the QF_Loss backprops the policy network before the policy_loss does, which causes a Torch error. Would the intended use be to optimize the policy network on the policy_loss before formulating the QF_Loss (and still optimize the policy using the QF_Loss), or to not reparametrize the policy output when formulating the QF_Loss (e.g. line 201)?
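Regarding the second alternative raised here (not backpropagating the policy output through the QF_Loss), a minimal sketch of what that could look like, reusing the policy/qf1/qf2/obs names from the snippets quoted in this thread; whether this matches the intended CQL behavior is exactly the open question of this issue.

```python
# Hypothetical variant: the policy actions fed into the Q losses are
# detached, which has the same gradient-blocking effect as sampling them
# without reparameterization, so qf_loss.backward() does not backprop
# through the policy network.
new_curr_actions, policy_mean, policy_log_std, curr_log_pi, *_ = policy(
    obs, reparameterize=True, return_log_prob=True,
)
q1_curr = qf1(obs, new_curr_actions.detach())
q2_curr = qf2(obs, new_curr_actions.detach())
# q1_curr / q2_curr would then enter the Q-loss terms without touching
# the policy's computation graph.
```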