The indexing in build_functions is wrong. You can run the code below to verify the incorrect indexing in VS[:, A].
The indexing should be written as in lines 51 and 52 of https://github.com/ShibiHe/DQN_OpenAI_keras/blob/master/agents.py or as in https://github.com/spragunr/deep_q_rl/blob/master/deep_q_rl/q_network.py

    def build_functions(self):
        S = Input(shape=self.state_size)
        NS = Input(shape=self.state_size)
        A = Input(shape=(1,), dtype='int32')
        R = Input(shape=(1,), dtype='float32')
        T = Input(shape=(1,), dtype='int32')
        self.build_model()
        self.value_fn = K.function([S], self.model(S))

        VS = self.model(S)
        VNS = disconnected_grad(self.model(NS))
        future_value = (1-T) * VNS.max(axis=1, keepdims=True)
        discounted_future_value = self.discount * future_value
        target = R + discounted_future_value
        cost0 = VS[:, A] - target
        cost = ((VS[:, A] - target)**2).mean()
        opt = RMSprop(0.0001)
        params = self.model.trainable_weights
        updates = opt.get_updates(params, [], cost)
        self.train_fn = K.function([S, NS, A, R, T], [cost, cost0, target, A], updates=updates)
        # import numpy as np
        # t = self.train_fn([np.random.rand(10, *self.state_size), np.random.rand(10, *self.state_size), np.ones((10, 1)), np.ones((10, 1)), np.zeros((10, 1))])
        # print('cost=', t[0])
        # print('cost0=', t[1])
        # print('target=', t[2])
        # print('A=', t[3])
        # raw_input()

Hi @ShibiHe, thanks for your comment. You're right. This is a bug. I should be using np.arange(n) instead of :.
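
To illustrate the difference, here is a minimal NumPy sketch rather than the actual Theano/Keras graph from the issue. The array contents, `n`, and `num_actions` are made up for the example, and flattening `A` before indexing is an assumption modelled on the linked deep_q_rl code, not a confirmed description of the eventual fix.

```python
import numpy as np

n, num_actions = 4, 3  # hypothetical batch size and action count

# Stand-in for the Q-values VS: one row per sample, one column per action.
VS = np.arange(n * num_actions, dtype=np.float32).reshape(n, num_actions)
# Stand-in for the action input A, shaped (n, 1) like the Keras Input.
A = np.array([[0], [2], [1], [0]], dtype=np.int32)

# Slice plus index array: every row is indexed with every action in A.
wrong = VS[:, A]

# Row indices paired with flattened actions: one Q-value per sample.
right = VS[np.arange(n), A.reshape(-1)].reshape(-1, 1)

print(wrong.shape)    # (4, 4, 1)
print(right.shape)    # (4, 1)
print(right.ravel())  # [0. 5. 7. 9.]
```

With the `:` form, subtracting a target of shape `(n, 1)` broadcasts against an `(n, n, 1)` array, so the mean squared error would mix each sample's Q-values with the actions taken in other samples.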