
This is a code snippet that trains a Keras model on a minibatch of transitions:

    for state, action, reward, next_state, done in minibatch:
        target = reward
        if not done:
            target = (reward + self.gamma *
                      np.amax(self.model.predict(next_state)[0]))
        target_f = self.model.predict(state)
        #print (target_f)
        target_f[0][action] = target
        self.model.fit(state, target_f, epochs=1, verbose=0)

I am trying to vectorize it. The only approach I can think of is:

1. Create a NumPy table where each row is (state, action, reward, next_state, done, target), so the table has as many rows as there are transitions in the mini-batch.
2. Update the target column from the other columns (using masked arrays):

        target[done == True] = reward
        target[done == False] = reward + self.gamma * np.amax(self.model.predict(next_state)[0])

3. Finally, call self.model.fit(state, target_f, epochs=1, verbose=0).

NB: state is 8-D, so state vector has 8 elements.
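Instead of a single mixed-type table, it is often easier to keep one NumPy array per field. Below is a minimal sketch, assuming `minibatch` is a list of `(state, action, reward, next_state, done)` tuples and each state is stored with shape `(1, 8)`, as the `model.predict(state)` call in the loop implies. The helper name `unpack_minibatch` is hypothetical.

```python
import numpy as np

def unpack_minibatch(minibatch):
    """Hypothetical helper: split a list of (state, action, reward,
    next_state, done) tuples into one NumPy array per field."""
    states, actions, rewards, next_states, dones = zip(*minibatch)
    return (
        np.vstack(states),                 # (batch, 8): each (1, 8) state becomes one row
        np.asarray(actions, dtype=int),    # (batch,)
        np.asarray(rewards, dtype=float),  # (batch,)
        np.vstack(next_states),            # (batch, 8)
        np.asarray(dones, dtype=bool),     # (batch,)
    )

# Usage with random 8-D states (made-up data)
rng = np.random.default_rng(0)
minibatch = [
    (rng.random((1, 8)), i % 4, float(i), rng.random((1, 8)), i % 2 == 0)
    for i in range(5)
]
states, actions, rewards, next_states, dones = unpack_minibatch(minibatch)
print(states.shape, dones)
```

With the states stacked into a `(batch, 8)` matrix, a single `model.predict(states)` call returns the Q-values for the whole batch at once.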

Despite hours of effort, I have not been able to code this properly. Is it possible to vectorize this piece of code?

Answer

You are very close! Assuming that minibatch is a 2-D np.array with one row per transition:

First, find all the rows where done is true (assuming done is in column 4):

    minibatch_done = minibatch[np.where(minibatch[:, 4] == True)]
    minibatch_not_done = minibatch[np.where(minibatch[:, 4] == False)]
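If `minibatch` really is a 2-D object array, a boolean mask built once from the `done` column can select both subsets, and the per-row `next_state` arrays have to be stacked with `np.vstack` before they can go to `predict` as one batch. A minimal sketch with made-up rows; the columns follow the question's tuple order (state, action, reward, next_state, done).

```python
import numpy as np

# Toy 2-D object array, one row per transition:
# columns are (state, action, reward, next_state, done)
rows = [
    (np.full(8, 0.0), 0, 0.0, np.full(8, 1.0), False),
    (np.full(8, 1.0), 1, 10.0, np.full(8, 2.0), True),
    (np.full(8, 2.0), 2, 20.0, np.full(8, 3.0), False),
    (np.full(8, 3.0), 3, 30.0, np.full(8, 4.0), True),
]
minibatch = np.empty((len(rows), 5), dtype=object)
for i, row in enumerate(rows):
    for j, value in enumerate(row):
        minibatch[i, j] = value  # store each field as a Python object

# Boolean mask over the done column; ~done_mask selects the remaining rows
done_mask = minibatch[:, 4].astype(bool)
minibatch_done = minibatch[done_mask]
minibatch_not_done = minibatch[~done_mask]

# Boolean indexing keeps the original row order inside each subset;
# vstack turns the object column of (8,) arrays into a (n_not_done, 8) float matrix
next_states = np.vstack(minibatch_not_done[:, 3])
```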

Now use these subsets to build the target vector. Assuming column 2 is reward and column 3 is next_state:

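    done_mask = minibatch[:, 4].astype(bool)
    target = np.empty(minibatch.shape[0])
    # Terminal transitions: the target is just the reward
    target[done_mask] = minibatch_done[:, 2]
    # Non-terminal transitions: reward + gamma * max Q-value of each next_state
    next_q = self.model.predict(np.vstack(minibatch_not_done[:, 3]))
    target[~done_mask] = minibatch_not_done[:, 2] + self.gamma * np.amax(next_q, axis=1)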

And there you have it :)

Edit: Fixed index error in target problems
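The target vector on its own is not what `model.fit` receives in the loop: there, each state is fitted against `target_f`, the model's full row of Q-values with only the taken action's entry replaced. A minimal sketch of the whole update with one `fit` call, assuming the minibatch has already been split into per-field arrays (`states` and `next_states` of shape `(batch, 8)`, integer `actions`, float `rewards`, boolean `dones`). `ToyModel` and `replay_batch` are hypothetical names; `ToyModel` only mimics the `predict`/`fit` call shape of a Keras model and does not train. One caveat: with a real model the loop calls `fit` after every sample, so later predictions see updated weights, whereas this batch version computes all targets from the same weights (the usual minibatch update).

```python
import numpy as np

class ToyModel:
    """Hypothetical stand-in for a Keras model: a fixed linear map from
    8-D states to Q-values with Keras-like predict() and fit() methods."""

    def __init__(self, n_actions=4, seed=0):
        self.w = np.random.default_rng(seed).random((8, n_actions))
        self.fit_calls = 0

    def predict(self, x):
        return x @ self.w  # shape (batch, n_actions)

    def fit(self, x, y, epochs=1, verbose=0):
        self.fit_calls += 1  # no training; just count calls


def replay_batch(model, gamma, states, actions, rewards, next_states, dones):
    """One vectorized replay step with a single fit call."""
    # Terminal rows keep only the reward; others bootstrap from max Q(next_state)
    next_q = model.predict(next_states)  # (batch, n_actions)
    targets = np.where(dones, rewards, rewards + gamma * np.amax(next_q, axis=1))
    # Start from the current predictions and overwrite only the taken actions
    target_f = model.predict(states)
    target_f[np.arange(len(actions)), actions] = targets
    model.fit(states, target_f, epochs=1, verbose=0)
    return target_f


# Usage: compare against the original per-sample loop (made-up data)
rng = np.random.default_rng(1)
batch, gamma = 6, 0.95
states = rng.random((batch, 8))
next_states = rng.random((batch, 8))
actions = rng.integers(0, 4, size=batch)
rewards = rng.random(batch)
dones = np.array([True, False, False, True, False, True])

model = ToyModel()
target_f = replay_batch(model, gamma, states, actions, rewards, next_states, dones)

# The loop from the question, minus fit (ToyModel.fit would not change weights anyway)
loop_target_f = np.empty_like(target_f)
for i in range(batch):
    target = rewards[i]
    if not dones[i]:
        target = rewards[i] + gamma * np.amax(model.predict(next_states[i:i + 1])[0])
    row = model.predict(states[i:i + 1])
    row[0][actions[i]] = target
    loop_target_f[i] = row[0]
```

Because `ToyModel` never changes its weights, `target_f` and `loop_target_f` agree up to floating-point rounding, which makes the comparison a quick way to check the indexing.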


Comments

Thanks. But when updating the model with model.fit(), each state and its target must be in the same row. Does target preserve the order of state?
It does preserve the order of state. Just do target.reshape((-1,1)) to make it a column vector
Thanks Andrew. Would you please also take a look at stackoverflow.com/questions/51068981/…
Hi Andrew, I found a problem. The target only contained the values from minibatch_not_done. I think target is being overwritten; maybe a masked array on target would help:
    target[np.where(minibatch[:,3]==True)]=minibatch_done[:,1]+1*(minibatch_done[:,2])
    target[np.where(minibatch[:,3]==False)]=minibatch_not_done[:,1]
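The ordering question in these comments can be checked directly: filling `target[:n_done]` with the terminal targets and `target[n_done:]` with the rest groups the rows done-first, while assigning through the `done` mask keeps each target on the same row as its state. A small sketch with made-up rewards; `next_value` is a hypothetical stand-in for `gamma * max Q(next_state)`.

```python
import numpy as np

rewards = np.array([1.0, 2.0, 3.0, 4.0])
dones = np.array([False, True, False, True])
next_value = np.array([10.0, 20.0, 30.0, 40.0])  # stand-in for gamma * max Q(next_state)

# Slice-based fill: terminal rows first, then the rest (row order is lost)
n_done = dones.sum()
sliced = np.empty_like(rewards)
sliced[:n_done] = rewards[dones]
sliced[n_done:] = rewards[~dones] + next_value[~dones]

# Mask-based fill: each target stays aligned with its own state
masked = np.empty_like(rewards)
masked[dones] = rewards[dones]
masked[~dones] = rewards[~dones] + next_value[~dones]

print(sliced.tolist())  # [2.0, 4.0, 11.0, 33.0]
print(masked.tolist())  # [11.0, 2.0, 33.0, 4.0]
```

Only the masked version can be passed to `model.fit` alongside the states in their original order.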
