
Advanced Deep Learning with Keras


Deep Reinforcement Learning

# compute Q_max
# use of target Q Network solves the non-stationarity problem
def get_target_q_value(self, next_state, reward):
    # max Q value among next state's actions
    if self.ddqn:
        # DDQN
        # current Q Network selects the action
        # a'_max = argmax_a' Q(s', a')
        action = np.argmax(self.q_model.predict(next_state)[0])

        # target Q Network evaluates the action
        # Q_max = Q_target(s', a'_max)
        q_value = self.target_q_model.predict(next_state)[0][action]
    else:
        # DQN chooses the max Q value among next actions
        # selection and evaluation of action is on the target Q Network
        # Q_max = max_a' Q_target(s', a')
        q_value = np.amax(self.target_q_model.predict(next_state)[0])

    # Q_max = reward + gamma * Q_max
    q_value *= self.gamma
    q_value += reward
    return q_value
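To make the difference between the two branches concrete, here is a small self-contained illustration (not from the book) using made-up Q-value vectors for a single next_state with three actions:

import numpy as np

# hypothetical network outputs for one next_state, purely illustrative
q_online = np.array([1.0, 3.0, 2.0])   # Q(s', a') from the online network
q_target = np.array([1.5, 2.5, 4.0])   # Q_target(s', a') from the target network

reward, gamma = 1.0, 0.95

# DQN: the target network both selects and evaluates the action
dqn_target = reward + gamma * np.amax(q_target)        # 1.0 + 0.95 * 4.0 = 4.8

# DDQN: the online network selects a'_max, the target network evaluates it
a_max = np.argmax(q_online)                            # action 1
ddqn_target = reward + gamma * q_target[a_max]         # 1.0 + 0.95 * 2.5 = 3.375

print(dqn_target, ddqn_target)

Because the selected action is no longer the one the target network itself scores highest, the DDQN target is less prone to overestimation. The target network's weights are typically refreshed only periodically by copying them from the online network (for example with target_q_model.set_weights(q_model.get_weights())); between refreshes they stay fixed, which is what makes the regression target stationary, as the comment at the top of get_target_q_value() notes.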

# experience replay addresses the correlation issue between samples
def replay(self, batch_size):
    # sars = state, action, reward, state' (next_state)
    sars_batch = random.sample(self.memory, batch_size)
    state_batch, q_values_batch = [], []

    # fixme: for speedup, this could be done on the tensor level,
    # but it is easier to understand using a loop
    for state, action, reward, next_state, done in sars_batch:
        # policy prediction for a given state
        q_values = self.q_model.predict(state)

        # get Q_max
        q_value = self.get_target_q_value(next_state, reward)
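For context, the self.memory that replay() samples from is an experience buffer filled with one transition per environment step. A minimal sketch of such a buffer and the helper that feeds it is given below, assuming the usual DQN pattern; the class name DQNAgentSketch and the maxlen value are illustrative, not taken from the book's listing:

from collections import deque

class DQNAgentSketch:
    """Illustrative skeleton only: shows how the replay buffer is fed."""
    def __init__(self, maxlen=100000):
        # bounded buffer so old transitions eventually drop out
        self.memory = deque(maxlen=maxlen)

    def remember(self, state, action, reward, next_state, done):
        # store one sars tuple; replay() later draws a random mini-batch from it
        self.memory.append((state, action, reward, next_state, done))

Because random.sample(self.memory, batch_size) draws transitions scattered across many episodes, consecutive training samples are no longer temporally correlated, which is the issue the comment at the top of replay() refers to.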
