add LSTM layer #2
Hi Chih, I have implemented an LSTM and will probably release the code once I have some time. I did not do extensive tests, so I can't really speak to its performance. However, it did learn.
Hello Alfredo, I have also begun implementing an LSTM (not finished yet, just as an exercise). Later I may try another architecture (an echo state network?) and see whether it is comparable to the LSTM. In any case it is good to … Thanks a lot; I look forward to your good news. Best,
Just a quick tip: it may be a little tricky to implement an LSTM here. A minibatch can contain experience from different episodes (of the same environment), which means you have to reset the state of the LSTM within a minibatch, something TensorFlow's dynamic_rnn does not support. I will attach the code for the LSTM network; I hope it helps:

```python
import tensorflow as tf  # TensorFlow 1.x (tf.placeholder, tf.contrib)
from networks import Network  # base network class of this repo (assumed module path)


class LSTMNetwork(Network):
    def __init__(self, conf):
        super(LSTMNetwork, self).__init__(conf)
        self.lstm_size = conf['lstm_size']
        self.lstm_state = (tf.zeros((self.emulator_counts, self.lstm_size), dtype=tf.float32),
                           tf.zeros((self.emulator_counts, self.lstm_size), dtype=tf.float32))
        with tf.device(self.device):
            with tf.name_scope(self.name):
                # Shape [max_time, emulator_counts]. Row t is 0.0 if the episode was over
                # on the previous timestep (the state entering step t is reset), else 1.0.
                self.prev_episode_over_mask_ph = tf.placeholder(
                    tf.float32, [None, self.emulator_counts], name='episode_over_mask')

                # Persistent, non-trainable LSTM state carried across calls.
                self.lstm_state_c = tf.Variable(self.lstm_state[0], trainable=False, dtype=tf.float32)
                self.lstm_state_h = tf.Variable(self.lstm_state[1], trainable=False, dtype=tf.float32)
                self.reset_state = tf.group(tf.assign(self.lstm_state_c, self.lstm_state[0]),
                                            tf.assign(self.lstm_state_h, self.lstm_state[1]))

                # Snapshot taken before a rollout, used to roll back before the update.
                stored_lstm_state_c = tf.Variable(self.lstm_state[0], trainable=False, dtype=tf.float32)
                stored_lstm_state_h = tf.Variable(self.lstm_state[1], trainable=False, dtype=tf.float32)

                # The recurrence starts from the current persistent state.
                initial_lstm_state = tf.contrib.rnn.LSTMStateTuple(self.lstm_state_c, self.lstm_state_h)

                # The reshape assumes self.output rows are ordered time-major
                # (all emulators for t=0, then all emulators for t=1, ...).
                input_dim = self.output.get_shape().as_list()[1]
                reshaped_input = tf.reshape(self.output, [-1, self.emulator_counts, input_dim])
                max_time = tf.shape(reshaped_input)[0]
                self.max_time = max_time

                inputs_ta = tf.TensorArray(dtype=tf.float32, size=1, dynamic_size=True,
                                           name='tensor_array_inputs', clear_after_read=False)
                inputs_ta = inputs_ta.unstack(reshaped_input, 'unstack_inputs')
                episode_over_ta = tf.TensorArray(dtype=tf.float32, size=1, dynamic_size=True,
                                                 name='tensor_array_episode_over')
                episode_over_ta = episode_over_ta.unstack(self.prev_episode_over_mask_ph,
                                                          'unstack_episode_over_mask')
                cell = tf.contrib.rnn.BasicLSTMCell(self.lstm_size)

                def loop_fn(time, cell_output, cell_state, loop_state):
                    emit_output = cell_output  # None at time == 0
                    if cell_output is None:  # time == 0: start from the persistent state
                        use_cell_state = initial_lstm_state
                    else:
                        use_cell_state = cell_state
                    elements_finished = tf.greater_equal(time, max_time, name='elements_finished')
                    # Row `time` of the mask gates the state entering step `time`. At
                    # time == max_time there is no such row (reading it would be out of
                    # range), so the final state passes through unmasked; row 0 of the
                    # next call resets it if the episode ended on the last step.
                    episode_over = tf.cond(
                        elements_finished,
                        lambda: tf.ones([self.emulator_counts, 1], dtype=tf.float32),
                        lambda: tf.expand_dims(episode_over_ta.read(time), axis=1))
                    # If the episode ended, the next cell state is zero.
                    next_cell_state = tf.contrib.rnn.LSTMStateTuple(
                        tf.multiply(use_cell_state[0], episode_over),
                        tf.multiply(use_cell_state[1], episode_over))
                    # When time == max_time the loop finishes and next_input is unused;
                    # the minimum replaces a conditional by reusing the last input.
                    next_input = inputs_ta.read(tf.minimum(time, max_time - 1))
                    next_loop_state = None
                    return (elements_finished, next_input, next_cell_state,
                            emit_output, next_loop_state)

                outputs_ta, final_state, _ = tf.nn.raw_rnn(cell, loop_fn)
                # Persist the final state so the next call continues from it.
                update_lstm_state = tf.group(tf.assign(self.lstm_state_c, final_state[0]),
                                             tf.assign(self.lstm_state_h, final_state[1]))
                with tf.control_dependencies([update_lstm_state]):
                    self.output_stack = outputs_ta.stack()
                self.output = tf.reshape(self.output_stack, [-1, self.lstm_size])

                self.store_lstm_state = tf.group(tf.assign(stored_lstm_state_c, self.lstm_state_c),
                                                 tf.assign(stored_lstm_state_h, self.lstm_state_h))
                self.rollback_lstm_state = tf.group(tf.assign(self.lstm_state_c, stored_lstm_state_c),
                                                    tf.assign(self.lstm_state_h, stored_lstm_state_h))

                # Normal flow:
                # 1. store_lstm_state
                # 2. output + update_lstm_state
                # 3. repeat #2 for max_local_steps timesteps
                # 4. output + update_lstm_state
                # 5. rollback_lstm_state
                # 6. output + update weights + update_lstm_state
```
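The `prev_episode_over_mask_ph` placeholder expects a [max_time, emulator_counts] array whose row t is 0.0 when the episode was over on the previous timestep. A minimal sketch of building it from per-step done flags, assuming `dones[t, e]` marks an episode that ended at step t of the current rollout and that the done flags from the last step of the previous rollout are kept around. `prev_episode_over_mask` is a hypothetical helper, not part of the repository:

```python
import numpy as np


def prev_episode_over_mask(dones, last_done_before_rollout):
    """Hypothetical helper: build the [max_time, emulator_counts] mask.

    dones[t, e] is True if emulator e's episode ended at step t of this rollout.
    last_done_before_rollout[e] is True if it ended on the step just before the
    rollout started. Row t is 0.0 if the episode was over on the previous
    timestep, else 1.0 (so it is the done array shifted down by one step).
    """
    dones = np.asarray(dones, dtype=bool)
    first_row = np.asarray(last_done_before_rollout, dtype=bool)[None, :]
    prev_dones = np.concatenate([first_row, dones[:-1]], axis=0)
    return 1.0 - prev_dones.astype(np.float32)


dones = np.array([[False, False],
                  [True, False],
                  [False, False]])
mask = prev_episode_over_mask(dones, last_done_before_rollout=[False, True])
print(mask)
# [[1. 0.]
#  [1. 1.]
#  [0. 1.]]
```

The last row of `dones` is not used in this rollout's mask; under this assumption it becomes `last_done_before_rollout` for the next one.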
Hello Alfredo, thanks a lot; I really learned something from your reply. I will continue working on it and may be able to provide new information later. Best,
@Alfredvc I don't understand the need to reset the LSTM with every minibatch. I would assume we want the LSTM cell to learn across minibatches; otherwise the LSTM is being retrained on every minibatch. Am I missing something?
What happens is that each minibatch contains the experience from the different environments for a set number of timesteps. Because these environments are episodic, once an episode is over you must reset the state of the LSTM; it makes no sense to backpropagate across episodes. If episodes only terminated at the end of a minibatch, you could just run the minibatch and then reset the LSTM state of the environments whose episodes have terminated. However, episodes may terminate at any point in the minibatch, so you must be able to reset the state of the LSTM within a minibatch. So the idea is not to reset the state after each minibatch, but after each episode ends. I hope this makes sense.
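The within-minibatch reset can be illustrated with a small NumPy sketch. To keep it short, a plain tanh RNN stands in for the LSTM (for an LSTM the same mask is applied to both c and h, as the `loop_fn` in the code above does), and `masked_rnn` is a hypothetical name. It assumes a mask value of 0.0 at row t means the episode ended before step t:

```python
import numpy as np


def masked_rnn(inputs, mask, h0, W_x, W_h):
    """Run a tanh RNN over [T, E, D] inputs, resetting state where mask is 0.

    mask[t, e] == 0.0 means emulator e's episode ended before step t,
    so its hidden state is zeroed before step t is processed.
    """
    h = h0
    outputs = []
    for t in range(inputs.shape[0]):
        h = h * mask[t][:, None]                 # reset finished episodes only
        h = np.tanh(inputs[t] @ W_x + h @ W_h)   # one recurrent step
        outputs.append(h)
    return np.stack(outputs), h


rng = np.random.default_rng(0)
T, E, D, H = 4, 2, 3, 5
inputs = rng.normal(size=(T, E, D))
W_x = rng.normal(size=(D, H))
W_h = rng.normal(size=(H, H))
mask = np.ones((T, E))
mask[2, 0] = 0.0                   # emulator 0's episode ended after step 1
h0 = rng.normal(size=(E, H))       # state carried over from the previous minibatch

out, _ = masked_rnn(inputs, mask, h0, W_x, W_h)

# Emulator 0's steps 2-3 behave exactly like a fresh episode from a zero state.
fresh, _ = masked_rnn(inputs[2:, :1], np.ones((2, 1)), np.zeros((1, H)), W_x, W_h)
print(np.allclose(out[2:, 0], fresh[:, 0]))  # True
```

Emulator 1 is unaffected by the mask and keeps its carried-over state for the whole minibatch.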
Yes, that is how you would implement experience replay for the LSTM architecture, and it is similar to what the code does. However, it currently handles only on-policy data, so the "replay memory" is just the experience gathered since the last update. I think I may have misunderstood what @zencoding meant by resetting the state after each minibatch. Perhaps you are referring to the rollback in the code comments? In that case, before you do one step of optimization, as @pisiiki says, you must roll back the state of the LSTM to what it was before it encountered the first transition in the minibatch.
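Why the rollback is needed can be checked with a toy model. This sketch assumes a tanh RNN in place of the LSTM, and `StatefulRNN` is a hypothetical class whose `store`/`rollback` mimic `store_lstm_state`/`rollback_lstm_state`. Acting one step at a time advances the persistent state; rolling back before the training pass lets the whole rollout be recomputed from the same starting state, reproducing the outputs seen while acting (step 4, the bootstrap pass, is omitted):

```python
import numpy as np


class StatefulRNN:
    """Toy stand-in for the LSTM network's state handling (tanh RNN, not an LSTM).

    `state` plays the role of lstm_state_c/h and `stored` of stored_lstm_state_c/h.
    """

    def __init__(self, W_x, W_h, num_envs):
        self.W_x, self.W_h = W_x, W_h
        self.state = np.zeros((num_envs, W_h.shape[0]))
        self.stored = self.state.copy()

    def store(self):
        self.stored = self.state.copy()

    def rollback(self):
        self.state = self.stored.copy()

    def run(self, inputs, mask):
        """Process [T, E, D] inputs from the current state; persist the final state.

        mask[t, e] == 0.0 zeroes emulator e's state before step t (episode boundary).
        """
        h, outputs = self.state, []
        for t in range(inputs.shape[0]):
            h = h * mask[t][:, None]
            h = np.tanh(inputs[t] @ self.W_x + h @ self.W_h)
            outputs.append(h)
        self.state = h
        return np.stack(outputs)


rng = np.random.default_rng(1)
T, E, D, H = 5, 3, 4, 6
net = StatefulRNN(rng.normal(size=(D, H)), rng.normal(size=(H, H)), E)
net.run(rng.normal(size=(2, E, D)), np.ones((2, E)))  # earlier rollouts: non-zero state

inputs = rng.normal(size=(T, E, D))
mask = np.ones((T, E))
mask[3, 1] = 0.0  # emulator 1's episode ended after step 2

net.store()                                    # 1. snapshot the state
acting = np.concatenate([net.run(inputs[t:t + 1], mask[t:t + 1])
                         for t in range(T)])   # 2-3. act one step at a time
net.rollback()                                 # 5. restore the snapshot
training = net.run(inputs, mask)               # 6. recompute the rollout for the update
print(np.allclose(acting, training))           # True
```

Without the rollback, the training pass would start from the state at the end of the rollout and the recomputed outputs would not match the ones used to act.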
@Alfredvc thanks for the clarification, this helps. I am wondering why this is not done in most implementations that use an LSTM, for example https://github.com/zencoding/DeepRL-Agents/blob/master/A3C-Doom.ipynb, which is an on-policy A3C implementation. Is it due to the batching performed in paac? I was trying to get MountainCar working with paac, but I gave up after numerous hyperparameter changes; it was just not learning anything. From reading online and looking at other GitHub implementations for MountainCar, it seems it is not meant to be solved without experience replay (or some kind of stronger exploration than on-policy learning provides). I looked at DDPG and ACER; I liked ACER, and it is closer to paac. I am going to attempt to reimplement ACER (from https://github.com/Kaixhin/ACER) in paac after I get ACER working for MountainCar. This code will be very helpful, as an LSTM helps a lot with temporal learning.
Sorry for the very late response, I was on vacation for a while! The issue with the LSTM is usually resolved in one of two ways: using a dynamic batch size, or padding. I have done experiments with experience replay using something similar to ACER, and also with a different technique that I have not yet presented. I will probably present both of these at some point in the future.
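One way the padding approach could look: cut each emulator's trajectory at episode boundaries and zero-pad the pieces to a common length, so a standard RNN can process them with per-sequence lengths (for example via the `sequence_length` argument of `tf.nn.dynamic_rnn`). This is a sketch under stated assumptions, not the method used in the repository; `split_and_pad` is a hypothetical helper and `dones[t, e]` is assumed to mark an episode that ended at step t:

```python
import numpy as np


def split_and_pad(features, dones):
    """Split a [T, E, D] rollout into per-episode segments and zero-pad them.

    dones[t, e] is True if emulator e's episode ended at step t. Every segment
    after an episode boundary starts a new episode (zero initial state); the
    first segment of each emulator continues the previous rollout's episode.
    Returns a [num_segments, max_len, D] array and each segment's true length.
    """
    T, E, D = features.shape
    segments = []
    for e in range(E):
        start = 0
        for t in range(T):
            if dones[t, e] or t == T - 1:   # cut at an episode end or the rollout end
                segments.append(features[start:t + 1, e])
                start = t + 1
    lengths = np.array([len(s) for s in segments])
    padded = np.zeros((len(segments), lengths.max(), D), dtype=features.dtype)
    for i, seg in enumerate(segments):
        padded[i, :len(seg)] = seg
    return padded, lengths


features = np.arange(8, dtype=np.float32).reshape(4, 2, 1)  # T=4, E=2, D=1
dones = np.array([[False, False],
                  [True, False],
                  [False, False],
                  [False, False]])
padded, lengths = split_and_pad(features, dones)
print(lengths)  # [2 2 4]
```

The trade-off is that the number of sequences (the effective batch size) now varies from rollout to rollout, and padded steps must be excluded from the loss.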
Hello,
May I ask a naive question: did you try to implement an LSTM on this architecture? Or did you already do it and find that it is not as efficient (maybe too time-consuming?) as people think?
In any case, thanks for an idea/architecture that is not so hardware-demanding.
Best,
Chih-Chieh