ygo-agent · commit 0e9969c5

Commit 0e9969c5, authored May 06, 2024 by sbl1996@126.com
Unify PPO and Impala to Cleanba
parent 907b51bc

Showing 3 changed files, with 1010 additions and 54 deletions:

  scripts/cleanba.py   +954   -0
  scripts/impala.py    +31    -28
  scripts/ppo.py       +25    -26

scripts/cleanba.py  (new file, 0 → 100644, view file @ 0e9969c5)
This diff is collapsed in the original view (+954 lines).
scripts/impala.py  (view file @ 0e9969c5)

@@ -73,12 +73,12 @@ class Args:
     """the maximum number of options"""
     n_history_actions: int = 32
     """the number of history actions to use"""
-    greedy_reward: bool = True
+    greedy_reward: bool = False
     """whether to use greedy reward (faster kill higher reward)"""
-    total_timesteps: int = 5000000000
+    total_timesteps: int = 50000000000
     """total timesteps of the experiments"""
-    learning_rate: float = 1e-3
+    learning_rate: float = 3e-4
     """the learning rate of the optimizer"""
     local_num_envs: int = 128
     """the number of parallel game environments"""

@@ -92,12 +92,12 @@ class Args:
     """Toggle learning rate annealing for policy and value networks"""
     gamma: float = 1.0
     """the discount factor gamma"""
-    upgo: bool = False
-    """Toggle the use of UPGO for advantages"""
-    num_minibatches: int = 8
+    num_minibatches: int = 64
     """the number of mini-batches"""
     update_epochs: int = 2
     """the K epochs to update the policy"""
+    upgo: bool = True
+    """Toggle the use of UPGO for advantages"""
     c_clip_min: float = 0.001
     """the minimum value of the importance sampling clipping"""
     c_clip_max: float = 1.007

@@ -141,9 +141,9 @@ class Args:
     eval_checkpoint: Optional[str] = None
     """the path to the model checkpoint to evaluate"""
-    local_eval_episodes: int = 32
+    local_eval_episodes: int = 128
     """the number of episodes to evaluate the model"""
-    eval_interval: int = 50
+    eval_interval: int = 100
     """the number of iterations to evaluate the model"""

     # runtime arguments to be filled in
@@ -193,6 +193,7 @@ class Transition(NamedTuple):
     logits: list
     rewards: list
     mains: list
+    next_dones: list


 def create_agent(args, multi_step=False):

@@ -203,6 +204,7 @@ def create_agent(args, multi_step=False):
         dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
         param_dtype=jnp.float32,
         lstm_channels=args.rnn_channels,
+        switch=False,
         multi_step=multi_step,
         freeze_id=args.freeze_id,
     )
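Only the tail of the Transition container is visible in the hunk above. A minimal sketch of where the new next_dones field sits, assuming nothing about the fields that come before logits (they are elided in this diff):

    # Sketch under assumptions: the diff shows only the last fields of Transition;
    # any fields above `logits` (observations, dones, actions, ...) are not shown here.
    from typing import NamedTuple

    class Transition(NamedTuple):
        logits: list
        rewards: list
        mains: list
        next_dones: list  # new in this commit; filled per step by the rollout (see the rollout hunk below)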
@@ -373,6 +375,7 @@ def rollout(
                 actions=action,
                 logits=logits,
                 rewards=next_reward,
+                next_dones=next_done,
             )
         )

@@ -405,7 +408,7 @@ def rollout(
             lambda x1, x2: jnp.where(next_main[:, None], x1, x2), next_rstate1, next_rstate2)
         sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
             np.split(x, len(learner_devices)), devices=learner_devices),
-            (init_rstate1, init_rstate2, (next_rstate, next_obs), next_done, next_main))
+            (init_rstate1, init_rstate2, (next_rstate, next_obs), next_main))
         if args.eval_interval and update % args.eval_interval == 0:
             _start = time.time()
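With next_dones now recorded in every Transition, the rollout no longer ships the trailing next_done to the learner separately, and the loss_fn hunk below drops the line that used to rebuild it. A small, self-contained illustration of the bookkeeping that moves (toy shapes and values, not taken from the scripts):

    import jax.numpy as jnp

    # (num_steps, num_envs) toy data. Previously next_dones was reconstructed inside
    # loss_fn from the stored per-step dones plus the final next_done; now the rollout
    # simply stores next_dones=next_done at each step.
    dones = jnp.array([[0., 0.], [0., 1.], [0., 0.]])
    next_done = jnp.array([1., 0.])

    next_dones_rebuilt = jnp.concatenate([dones[1:], next_done[None, :]], axis=0)
    print(next_dones_rebuilt)  # what the removed loss_fn line used to compute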
@@ -616,33 +619,36 @@ if __name__ == "__main__":
             return logits, value.squeeze(-1)

         def loss_fn(
-            params, rstate1, rstate2, obs, dones, mains,
-            actions, logits, rewards, mask, next_value, next_done):
+            params, rstate1, rstate2, obs, dones, next_dones,
+            mains, actions, logits, rewards, mask, next_value):
             # (num_steps * local_num_envs // n_mb))
             num_envs = next_value.shape[0]
             num_steps = dones.shape[0] // num_envs

+            def reshape_time_series(x):
+                return jnp.reshape(x, (num_steps, num_envs) + x.shape[1:])
+
             mask = mask * (1.0 - dones)
             n_valids = jnp.sum(mask)

             inputs = (rstate1, rstate2, obs, dones, mains)
             new_logits, new_values = get_logits_and_value(params, inputs)
-            new_logits, new_values, logits, actions, rewards, dones, mains, mask = jax.tree.map(
-                lambda x: jnp.reshape(x, (num_steps, num_envs) + x.shape[1:]),
-                (new_logits, new_values, logits, actions, rewards, dones, mains, mask),
-            )
-            next_dones = jnp.concatenate([dones[1:], next_done[None, :]], axis=0)

             ratios = distrax.importance_sampling_ratios(distrax.Categorical(
                 new_logits), distrax.Categorical(logits), actions)
+            logratio = jnp.log(ratios)
+            approx_kl = (((ratios - 1) - logratio) * mask).sum() / n_valids

+            ratios_, new_values_, rewards, next_dones, mains = jax.tree.map(
+                reshape_time_series,
+                (ratios, new_values, rewards, next_dones, mains),
+            )
             # TODO: TD(lambda) for multi-step
             target_values, advantages = vtrace_2p0s(
-                next_value, ratios, new_values, rewards, next_dones, mains, args.gamma,
+                next_value, ratios_, new_values_, rewards, next_dones, mains, args.gamma,
                 args.rho_clip_min, args.rho_clip_max, args.c_clip_min, args.c_clip_max)
-            logratio = jnp.log(ratios)
-            approx_kl = (((ratios - 1) - logratio) * mask).sum() / n_valids
+            target_values, advantages = jax.tree.map(
+                lambda x: jnp.reshape(x, (-1,)), (target_values, advantages))

             if args.ppo_clip:
                 pg_loss = clipped_surrogate_pg_loss(
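The reshape_time_series helper added above undoes the minibatch flattening before the V-trace call: arrays arrive with a leading axis of num_steps * num_envs and are put back into a time-major (num_steps, num_envs, ...) layout. A standalone shape check, with illustrative sizes only:

    import jax.numpy as jnp

    num_steps, num_envs = 4, 3

    def reshape_time_series(x):
        # same expression as in the new loss_fn: split the flat leading axis into
        # (num_steps, num_envs) and keep any trailing feature axes unchanged
        return jnp.reshape(x, (num_steps, num_envs) + x.shape[1:])

    flat_rewards = jnp.zeros((num_steps * num_envs,))       # shape (12,)
    flat_logits = jnp.zeros((num_steps * num_envs, 8))      # 8 is a made-up action count
    print(reshape_time_series(flat_rewards).shape)          # (4, 3)
    print(reshape_time_series(flat_logits).shape)           # (4, 3, 8)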
@@ -671,7 +677,6 @@ if __name__ == "__main__":
         sharded_init_rstate1: List,
         sharded_init_rstate2: List,
         sharded_next_inputs: List,
-        sharded_next_done: List,
         sharded_next_main: List,
         key: jax.random.PRNGKey,
         learn_opponent: bool = False,

@@ -682,9 +687,7 @@ if __name__ == "__main__":
             jax.tree.map(lambda *x: jnp.concatenate(x), *x)
             for x in [sharded_next_inputs, sharded_init_rstate1, sharded_init_rstate2]
         ]
-        next_main, next_done = [
-            jnp.concatenate(x) for x in [sharded_next_main, sharded_next_done]
-        ]
+        next_main = jnp.concatenate(sharded_next_main)

         # reorder storage of individual players
         # main first, opponent second

@@ -713,8 +716,8 @@ if __name__ == "__main__":
             return x

         shuffled_init_rstate1, shuffled_init_rstate2, \
-            shuffled_next_value, shuffled_next_done = jax.tree.map(
-                partial(convert_data, num_steps=1), (init_rstate1, init_rstate2, next_value, next_done))
+            shuffled_next_value = jax.tree.map(
+                partial(convert_data, num_steps=1), (init_rstate1, init_rstate2, next_value))
         shuffled_storage = jax.tree.map(
             partial(convert_data, num_steps=num_steps), storage)
         shuffled_mask = jnp.ones_like(shuffled_storage.mains)

@@ -734,13 +737,13 @@ if __name__ == "__main__":
                 shuffled_init_rstate2,
                 shuffled_storage.obs,
                 shuffled_storage.dones,
+                shuffled_storage.next_dones,
                 shuffled_storage.mains,
                 shuffled_storage.actions,
                 shuffled_storage.logits,
                 shuffled_storage.rewards,
                 shuffled_mask,
                 shuffled_next_value,
-                shuffled_next_done,
             ),
         )
         return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl)

@@ -765,7 +768,7 @@ if __name__ == "__main__":
         single_device_update,
         axis_name="local_devices",
         devices=global_learner_decices,
-        static_broadcasted_argnums=(8,),
+        static_broadcasted_argnums=(7,),
     )

     params_queues = []
scripts/ppo.py  (view file @ 0e9969c5)

@@ -74,12 +74,12 @@ class Args:
     """the maximum number of options"""
     n_history_actions: int = 32
     """the number of history actions to use"""
-    greedy_reward: bool = True
+    greedy_reward: bool = False
     """whether to use greedy reward (faster kill higher reward)"""
-    total_timesteps: int = 5000000000
+    total_timesteps: int = 50000000000
     """total timesteps of the experiments"""
-    learning_rate: float = 1e-3
+    learning_rate: float = 3e-4
     """the learning rate of the optimizer"""
     local_num_envs: int = 128
     """the number of parallel game environments"""

@@ -93,16 +93,16 @@ class Args:
     """Toggle learning rate annealing for policy and value networks"""
     gamma: float = 1.0
     """the discount factor gamma"""
-    gae_lambda: float = 0.95
-    """the lambda for the general advantage estimation"""
-    upgo: bool = False
-    """Toggle the use of UPGO for advantages"""
-    num_minibatches: int = 8
+    num_minibatches: int = 64
     """the number of mini-batches"""
     update_epochs: int = 2
     """the K epochs to update the policy"""
     norm_adv: bool = False
     """Toggles advantages normalization"""
+    upgo: bool = True
+    """Toggle the use of UPGO for advantages"""
+    gae_lambda: float = 0.95
+    """the lambda for the general advantage estimation"""
     clip_coef: float = 0.25
     """the surrogate clipping coefficient"""
     dual_clip_coef: Optional[float] = 3.0

@@ -113,7 +113,7 @@ class Args:
     """the logits threshold for NeuRD and ACH, typically 2.0-6.0"""
     ent_coef: float = 0.01
     """coefficient of the entropy"""
-    vf_coef: float = 0.5
+    vf_coef: float = 1.0
     """coefficient of the value function"""
     max_grad_norm: float = 1.0
     """the maximum norm for the gradient clipping"""

@@ -140,9 +140,9 @@ class Args:
     eval_checkpoint: Optional[str] = None
     """the path to the model checkpoint to evaluate"""
-    local_eval_episodes: int = 32
+    local_eval_episodes: int = 128
     """the number of episodes to evaluate the model"""
-    eval_interval: int = 50
+    eval_interval: int = 100
     """the number of iterations to evaluate the model"""

     # runtime arguments to be filled in

@@ -203,6 +203,7 @@ def create_agent(args, multi_step=False):
         dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
         param_dtype=jnp.float32,
         lstm_channels=args.rnn_channels,
+        switch=True,
         multi_step=multi_step,
         freeze_id=args.freeze_id,
     )
@@ -632,28 +633,30 @@ if __name__ == "__main__":
             num_envs = next_value.shape[0]
             num_steps = dones.shape[0] // num_envs

+            def reshape_time_series(x):
+                return jnp.reshape(x, (num_steps, num_envs) + x.shape[1:])
+
             mask = mask * (1.0 - dones)
             n_valids = jnp.sum(mask)
-            real_dones = dones | next_dones
-            inputs = (rstate1, rstate2, obs, real_dones, switch)
+            dones = dones | next_dones
+            inputs = (rstate1, rstate2, obs, dones, switch)
             new_logits, new_values = get_logits_and_value(params, inputs)
-            new_values_, rewards, next_dones, switch = jax.tree.map(
-                lambda x: jnp.reshape(x, (num_steps, num_envs) + x.shape[1:]),
-                (new_values, rewards, next_dones, switch),
-            )

             ratios = distrax.importance_sampling_ratios(distrax.Categorical(
                 new_logits), distrax.Categorical(logits), actions)
+            logratio = jnp.log(ratios)
+            approx_kl = (((ratios - 1) - logratio) * mask).sum() / n_valids

+            new_values_, rewards, next_dones, switch = jax.tree.map(
+                reshape_time_series,
+                (new_values, rewards, next_dones, switch),
+            )
             target_values, advantages = truncated_gae_2p0s(
                 next_value, new_values_, rewards, next_dones, switch,
                 args.gamma, args.gae_lambda, args.upgo)
             target_values, advantages = jax.tree.map(
                 lambda x: jnp.reshape(x, (-1,)), (target_values, advantages))
-            logratio = jnp.log(ratios)
-            approx_kl = (((ratios - 1) - logratio) * mask).sum() / n_valids

             if args.norm_adv:
                 advantages = masked_normalize(advantages, mask, eps=1e-8)
...
@@ -699,9 +702,7 @@ if __name__ == "__main__":
jax
.
tree
.
map
(
lambda
*
x
:
jnp
.
concatenate
(
x
),
*
x
)
jax
.
tree
.
map
(
lambda
*
x
:
jnp
.
concatenate
(
x
),
*
x
)
for
x
in
[
sharded_next_inputs
,
sharded_init_rstate1
,
sharded_init_rstate2
]
for
x
in
[
sharded_next_inputs
,
sharded_init_rstate1
,
sharded_init_rstate2
]
]
]
next_main
,
=
[
next_main
=
jnp
.
concatenate
(
sharded_next_main
)
jnp
.
concatenate
(
x
)
for
x
in
[
sharded_next_main
]
]
# reorder storage of individual players
# reorder storage of individual players
# main first, opponent second
# main first, opponent second
...
@@ -722,9 +723,7 @@ if __name__ == "__main__":
...
@@ -722,9 +723,7 @@ if __name__ == "__main__":
next_value
=
create_agent
(
args
)
.
apply
(
next_value
=
create_agent
(
args
)
.
apply
(
agent_state
.
params
,
next_inputs
)[
2
]
.
squeeze
(
-
1
)
agent_state
.
params
,
next_inputs
)[
2
]
.
squeeze
(
-
1
)
# TODO: check if this is correct
next_value
=
jnp
.
where
(
next_main
,
-
next_value
,
next_value
)
sign
=
jnp
.
where
(
switch_steps
<=
num_steps
,
1.0
,
-
1.0
)
next_value
=
jnp
.
where
(
next_main
,
-
sign
*
next_value
,
sign
*
next_value
)
def
convert_data
(
x
:
jnp
.
ndarray
,
num_steps
):
def
convert_data
(
x
:
jnp
.
ndarray
,
num_steps
):
if
args
.
update_epochs
>
1
:
if
args
.
update_epochs
>
1
:
...
...