
To preface, I am a complete Julia newbie... I am trying to implement PPO (Proximal Policy Optimization) for the first time, and I've been having issues updating the actor (and, by extension, critic) network parameters using the gradient from Flux.jl.

Below is just a small snippet of my code. I think it should be sufficient (?), but if not, please let me know and I am happy to provide more.

batch_obs, batch_acts, _, batch_rtgo, _ = rollout(ppo_network)
V, curr_log_probs = evaluate(ppo_network, batch_obs, batch_acts)
V_scaled = V*maximum(abs.(batch_rtgo)) / maximum(abs.(V))

ratios = exp.(curr_log_probs - transpose(hcat(old_batch_log_probs...)))
ratios = ratios .+ 1e-8

A_k = batch_rtgo - deepcopy(V_scaled)
A_k = (A_k .- mean(A_k)) ./ (std(A_k) .+ 1e-10)

# surrogate objectives
surr1 = ratios .* A_k
surr2 = clamp.(ratios, 1-clip, 1+clip) .* A_k

actor_loss = -mean(min.(surr1, surr2))
actor_opt = Adam(lr)
actor_gs = gradient(() -> actor_loss, params(ppo.actor.model, ppo.actor.mean, ppo.actor.logstd))

# update parameters
update!(actor_opt, params([ppo.actor.model, ppo.actor.mean, ppo.actor.logstd]), actor_gs)

I checked the gradient and found that the dict has keys, but all of the values are nothing.

How can I properly define the gradient so that I can perform this update?

I did some research and found that the loss function requires the parameters to be passed into it; otherwise, the gradient does not know which parameters to take the gradient with respect to, and thus ends up returning nothing... Is this correct?

I tried to create a custom loss function like this, which did seem to give me nonzero gradients, but the problem is that it re-runs rollout and evaluate inside the gradient call, which can be costly. It also runs into issues with large named tuples as the batch size of the rollout becomes large.

# define actor parameters
actor_ps = params(ppo.actor.model, ppo.actor.mean, ppo.actor.logstd)

# define gradient function
actor_gs = gradient(actor_ps) do 
   batch_obs, batch_acts, _, batch_rtgo, _ = rollout(ppo_network)
   V, curr_log_probs = evaluate(ppo_network, batch_obs, batch_acts)
   V_scaled = V*maximum(abs.(batch_rtgo)) / maximum(abs.(V))
   ratios = exp.(curr_log_probs - transpose(hcat(old_batch_log_probs...)))
   ratios = ratios .+ 1e-8

   A_k = batch_rtgo - deepcopy(V_scaled)
   A_k = (A_k .- mean(A_k)) ./ (std(A_k) .+ 1e-10)
   surr1 = ratios .* A_k
   surr2 = clamp.(ratios, 1-clip, 1+clip) .* A_k

   actor_loss = -mean(min.(surr1, surr2))
   return actor_loss
end
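One way to avoid re-running the expensive parts is to notice that only the computations that actually use the network's parameters need to happen inside the gradient closure. Below is a sketch, assuming the `rollout`/`evaluate`/`ppo` API from the question; the rollout and the advantage estimate can stay outside, since in PPO the advantages are treated as constants with respect to the actor update:

```julia
# Sketch, assuming the question's `rollout`, `evaluate`, `ppo`, `clip`, etc.
# The rollout is parameter-free data collection, so it stays outside.
batch_obs, batch_acts, _, batch_rtgo, _ = rollout(ppo_network)
old_lp = transpose(hcat(old_batch_log_probs...))

# Advantages are constants for the actor loss, so compute them outside too.
V, _ = evaluate(ppo_network, batch_obs, batch_acts)
A_k = batch_rtgo .- V
A_k = (A_k .- mean(A_k)) ./ (std(A_k) .+ 1e-10)

actor_ps = params(ppo.actor.model, ppo.actor.mean, ppo.actor.logstd)
actor_gs = gradient(actor_ps) do
    # Only this part uses the actor's parameters, so only it is traced.
    _, curr_log_probs = evaluate(ppo_network, batch_obs, batch_acts)
    ratios = exp.(curr_log_probs .- old_lp)
    surr1 = ratios .* A_k
    surr2 = clamp.(ratios, 1 - clip, 1 + clip) .* A_k
    -mean(min.(surr1, surr2))
end
```
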

1 Answer

I checked the gradient and found that the dict has keys, but all of the values are nothing.

Here's the problem:

actor_loss = -mean(min.(surr1, surr2))

actor_gs = gradient(() -> actor_loss, params(...

The variable actor_loss is just a floating-point number; Flux has no idea where it came from. This is because Flux's automatic differentiation engine, Zygote.jl, does not work by wrapping values in special tracked types which carry extra meaning.

Instead, only calculations which happen within the call to gradient are observed. Zygote only sees operations which occur while evaluating the function () -> actor_loss. But here that function does nothing: it just returns a captured variable.
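A tiny self-contained demonstration of this (using plain Zygote, with a made-up vector x rather than your network):

```julia
using Zygote  # Flux's AD backend

x = [1.0, 2.0, 3.0]
loss = sum(x .^ 2)   # computed *outside* gradient: just a number now

# The closure returns a pre-computed number; Zygote observes no
# operations involving x, so the gradient entry is `nothing`.
g_outside = gradient(() -> loss, Params([x]))
g_outside[x]  # nothing

# Computing the loss *inside* the closure lets Zygote trace it.
g_inside = gradient(() -> sum(x .^ 2), Params([x]))
g_inside[x]   # [2.0, 4.0, 6.0]
```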

found that the loss function requires the parameters to be passed into it,

Roughly. A tricky part is that there are two ways to use Zygote, and the code above is the old (deprecated) way, called "implicit", with Dict-like Params and Grads objects. In the correct use of this style, the function made by () -> ... (or by do) accepts zero arguments, but must perform all calculations involving the parameters inside its body.

The other (recommended) way to use Zygote is called "explicit", and demands that you pass the parameters as arguments to a function. Something like this:

# calculations not involving parameters here
actor_gs = gradient(ppo.actor) do local_actor
  # calculations which use `ppo.actor.model` should use `local_actor.model` here
  actor_loss = -mean(min.(surr1, surr2)) 
end

I don't understand the other steps of your code well enough to comment. Possibly this does not solve your problem with rollout.
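For reference, here is a complete runnable example of the explicit style on a toy model (the model, data, and mean-squared loss are stand-ins, not your PPO network):

```julia
using Flux, Statistics

model = Dense(4 => 2)               # toy stand-in for ppo.actor
obs = randn(Float32, 4, 8)          # 8 fake observations
target = randn(Float32, 2, 8)

opt_state = Flux.setup(Adam(1e-3), model)

# The model is passed *into* gradient, and the loss must use the
# argument `m`, not the outer `model`, so Zygote can trace it.
grads = gradient(model) do m
    mean((m(obs) .- target) .^ 2)
end

# grads[1] is a NamedTuple mirroring the model's structure.
Flux.update!(opt_state, model, grads[1])
```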
