
Listing 9.6.1 shows us the DQN implementation for CartPole-v0. The DQNAgent class represents the agent using DQN. Two Q-Networks are created:

1. Q-Network, or Q, in Algorithm 9.6.1
2. Target Q-Network, or Q_target, in Algorithm 9.6.1


Both networks are MLPs with three hidden layers of 256 units each. The Q-Network is trained during experience replay, replay(). At a regular interval of C = 10 training steps, the Q-Network parameters are copied to the Target Q-Network by update_weights(). This implements line 13, Q_target = Q, in Algorithm 9.6.1.
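As a rough sketch of how the two networks and the weight copy might be set up (build_model is a hypothetical helper; the optimizer and activations are assumptions, only the three 256-unit hidden layers come from the text):

from keras.layers import Dense, Input
from keras.models import Model
from keras.optimizers import Adam

def build_model(n_inputs, n_outputs):
    # MLP with three hidden layers of 256 units each,
    # mapping a state vector to one Q value per action
    inputs = Input(shape=(n_inputs,), name='state')
    x = Dense(256, activation='relu')(inputs)
    x = Dense(256, activation='relu')(x)
    x = Dense(256, activation='relu')(x)
    q_values = Dense(n_outputs, activation='linear', name='q_values')(x)
    model = Model(inputs, q_values)
    model.compile(loss='mse', optimizer=Adam())
    return model

# Q-Network and Target Q-Network share the same architecture
# q_model = build_model(state_dim, action_dim)
# target_q_model = build_model(state_dim, action_dim)

def update_weights(q_model, target_q_model):
    # copy Q-Network parameters to the Target Q-Network
    # (line 13 of Algorithm 9.6.1), done every C = 10 training steps
    target_q_model.set_weights(q_model.get_weights())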

After every episode, the exploration-exploitation ratio is decreased by update_epsilon() to take advantage of the learned policy.
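A minimal sketch of such a decay, assuming epsilon, epsilon_decay, and epsilon_min attributes on the agent (the specific decay schedule is an assumption, not taken from the listing):

def update_epsilon(self):
    # decrease exploration, increase exploitation after each episode
    if self.epsilon > self.epsilon_min:
        self.epsilon = max(self.epsilon * self.epsilon_decay,
                           self.epsilon_min)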

To implement line 10 in Algorithm 9.6.1 during experience replay, replay(), for each experience unit (s_j, a_j, r_j+1, s_j+1), the Q value for the action a_j is set to Q_max. All other actions have their Q values unchanged.

This is implemented by the following lines:

# policy prediction for a given state

q_values = self.q_model.predict(state)

# get Q_max

q_value = self.get_target_q_value(next_state)

# correction on the Q value for the action used

q_values[0][action] = reward if done else q_value

Only the action a_j has a non-zero loss, equal to (Q_max - Q(s_j, a_j; θ))^2, as shown by line 11 of Algorithm 9.6.1. Note that the experience replay is called by the perception-action-learning loop in Listing 9.6.2 after the end of each episode, assuming that there is sufficient data in the buffer (that is, the buffer size is greater than or equal to the batch size).

During the experience replay, one batch of experience units is randomly sampled

and used to train the Q-Network.
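Putting the pieces together, replay() might look roughly like the sketch below. The buffer self.memory, the batch assembly, and the explicit reward argument to get_target_q_value() are assumptions made to keep the sketch self-contained; the quoted call above passes only next_state.

import random
import numpy as np

def replay(self, batch_size):
    # sample one random batch of experience units from the buffer
    sars_batch = random.sample(self.memory, batch_size)
    state_batch, q_values_batch = [], []
    # states are assumed to be row vectors of shape (1, state_dim)
    for state, action, reward, next_state, done in sars_batch:
        # policy prediction for a given state
        q_values = self.q_model.predict(state)
        # get Q_max (reward argument is an assumption, see above)
        q_value = self.get_target_q_value(next_state, reward)
        # correction on the Q value for the action used
        q_values[0][action] = reward if done else q_value
        state_batch.append(state[0])
        q_values_batch.append(q_values[0])
    # train the Q-Network on the corrected Q values (line 11)
    self.q_model.fit(np.array(state_batch),
                     np.array(q_values_batch),
                     batch_size=batch_size,
                     epochs=1,
                     verbose=0)

In the perception-action-learning loop of Listing 9.6.2, replay() would only be invoked once len(self.memory) reaches batch_size, per the note above.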

Similar to the Q-Table, act() implements the ε-greedy policy, Equation 9.6.1. Experiences are stored by remember() in the replay buffer. The computation of Q_max is done by the get_target_q_value() function. On average over 10 runs, CartPole-v0 is solved by DQN within 822 episodes. Note that the results may vary every time the training runs.
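These three DQNAgent methods might be sketched as follows, under the same assumptions as before; self.action_space (the environment's action space) and the explicit reward argument to get_target_q_value() are assumptions, not taken verbatim from the listing.

import numpy as np

def remember(self, state, action, reward, next_state, done):
    # store an experience unit in the replay buffer (e.g. a deque)
    self.memory.append((state, action, reward, next_state, done))

def act(self, state):
    # epsilon-greedy policy (Equation 9.6.1)
    if np.random.rand() < self.epsilon:
        # explore: random action from the environment's action space
        return self.action_space.sample()
    # exploit: action with the highest predicted Q value
    q_values = self.q_model.predict(state)
    return np.argmax(q_values[0])

def get_target_q_value(self, next_state, reward):
    # Q_max = max_a' Q_target(s', a'), evaluated by the Target Q-Network
    q_value = np.amax(self.target_q_model.predict(next_state)[0])
    # target = reward + gamma * Q_max
    # (reward passed explicitly here; an assumption about the listing)
    q_value *= self.gamma
    q_value += reward
    return q_value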

Listing 9.6.1, dqn-cartpole-9.6.1.py, shows us the DQN implementation within Keras:

from keras.layers import Dense, Input

