Biluo Shen / ygo-agent / Commits / 7c8b11c3

Commit 7c8b11c3 authored Jun 14, 2024 by sbl1996@126.com

Add collect_steps

parent e1ff8f92

Showing 2 changed files with 281 additions and 142 deletions (+281 −142):

scripts/cleanba.py        +150 −92
ygoai/rl/jax/__init__.py  +131 −50
scripts/cleanba.py

@@ -97,6 +97,8 @@ class Args:
     """the number of actor threads to use"""
     num_steps: int = 128
     """the number of steps to run in each environment per policy rollout"""
+    collect_steps: Optional[int] = None
+    """the number of steps to compute the advantages"""
     segment_length: Optional[int] = None
     """the length of the segment for training"""
     anneal_lr: bool = False
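Side note (not part of the commit itself): the new collect_steps is expected to satisfy collect_steps >= num_steps. The actor now gathers collect_steps transitions per iteration, trains on the first num_steps of them, and keeps the surplus to bootstrap the advantage computation. A minimal illustrative sketch of that split, with hypothetical names:

    # Illustrative sketch only: split a rollout of `collect_steps` transitions
    # into a training chunk of `num_steps` and a tail used for bootstrapping.
    def split_rollout(transitions, num_steps):
        train_chunk = transitions[:num_steps]   # shipped to the learner
        tail = transitions[num_steps:]          # used to estimate the bootstrap carry
        return train_chunk, tail

    collect_steps, num_steps = 160, 128
    train_chunk, tail = split_rollout(list(range(collect_steps)), num_steps)
    assert len(train_chunk) == num_steps and len(tail) == collect_steps - num_steps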
@@ -226,6 +228,7 @@ class Transition(NamedTuple):
     dones: list
     actions: list
     logits: list
+    values: list
     rewards: list
     mains: list
     next_dones: list
@@ -304,6 +307,31 @@ def reshape_minibatch(
     return x


+def advantage_fn(args, next_v, values, rewards, next_dones, switch_or_mains,
+                 ratios=None, return_carry=False):
+    if args.switch:
+        if args.value == "vtrace" or args.sep_value or return_carry:
+            raise NotImplementedError
+        return gae_sep_switch(
+            next_v, values, rewards, next_dones, switch_or_mains,
+            args.gamma, args.gae_lambda, args.upgo)
+    else:
+        # TODO: TD(lambda) for multi-step
+        if args.value == "gae":
+            adv_fn = truncated_gae_sep if args.sep_value else truncated_gae
+            return adv_fn(
+                next_v, values, rewards, next_dones, switch_or_mains,
+                args.gamma, args.gae_lambda, args.upgo,
+                return_carry=return_carry)
+        else:
+            adv_fn = vtrace_sep if args.sep_value else vtrace
+            if ratios is None:
+                ratios = jnp.ones_like(values)
+            return adv_fn(
+                next_v, ratios, values, rewards, next_dones, switch_or_mains,
+                args.gamma, args.rho_clip_min, args.rho_clip_max,
+                args.c_clip_min, args.c_clip_max,
+                return_carry=return_carry)
+
+
 def rollout(
     key: jax.random.PRNGKey,
     args: Args,
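For orientation (this is not the repository's code), the estimators that advantage_fn dispatches to — truncated_gae, truncated_gae_sep, vtrace, vtrace_sep — all follow the same pattern: a backward jax.lax.scan over the time axis. A self-contained sketch of truncated GAE in that style, with assumed [T, B] shapes:

    import jax
    import jax.numpy as jnp

    def gae_sketch(next_value, values, rewards, next_dones, gamma=0.99, lam=0.95):
        # values, rewards, next_dones: [T, B]; next_value: [B] bootstrap after the last step.
        def step(carry, inp):
            lastgaelam, next_value = carry
            value, reward, next_done = inp
            nonterminal = 1.0 - next_done
            delta = reward + gamma * next_value * nonterminal - value
            lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
            return (lastgaelam, value), lastgaelam

        init = (jnp.zeros_like(next_value), next_value)
        _, advantages = jax.lax.scan(step, init, (values, rewards, next_dones), reverse=True)
        return advantages, advantages + values  # advantages, target values

    adv, targets = gae_sketch(jnp.zeros(2), jnp.ones((4, 2)), jnp.ones((4, 2)), jnp.zeros((4, 2)))
    print(adv.shape, targets.shape)  # (4, 2) (4, 2)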
@@ -370,10 +398,17 @@ def rollout(
     @jax.jit
     def sample_action(
             params, next_obs, rstate1, rstate2, main, done, key):
-        (rstate1, rstate2), logits = agent.apply(
-            params, next_obs, (rstate1, rstate2), done, main)[:2]
+        (rstate1, rstate2), logits, value = agent.apply(
+            params, next_obs, (rstate1, rstate2), done, main)[:3]
+        value = jnp.squeeze(value, axis=-1)
         action, key = categorical_sample(logits, key)
-        return next_obs, done, main, rstate1, rstate2, action, logits, key
+        return next_obs, done, main, rstate1, rstate2, action, logits, value, key
+
+    @jax.jit
+    def compute_advantage_carry(next_value, values, rewards, next_dones, mains):
+        return advantage_fn(
+            args, next_value, values, rewards, next_dones, mains, return_carry=True)

     deck_names = args.deck_names
     deck_avg_times = {name: 0 for name in deck_names}
@@ -400,11 +435,17 @@ def rollout(
         np.ones(args.local_num_envs // 2, dtype=np.int64)
     ])
     np.random.shuffle(main_player)
+    start_step = 0
     storage = []
+    init_rstates = []

+    # @jax.jit
+    # def prepare_data(storage: List[Transition]) -> Transition:
+    #     return jax.tree.map(lambda *xs: jnp.split(jnp.stack(xs), len(learner_devices), axis=1), *storage)
     @jax.jit
     def prepare_data(storage: List[Transition]) -> Transition:
-        return jax.tree.map(lambda *xs: jnp.split(jnp.stack(xs), len(learner_devices), axis=1), *storage)
+        return jax.tree.map(lambda *xs: jnp.stack(xs), *storage)

     for update in range(1, args.num_updates + 2):
         if update == 10:
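Note (illustrative, not part of the commit): prepare_data now only stacks the per-step transitions along a new leading time axis; the split across learner devices is applied later, to the selected training chunk. A minimal sketch of the stack-then-split pattern on a toy pytree:

    import jax
    import jax.numpy as jnp

    # Five toy "transitions", each a pytree of per-step arrays with a batch axis of 4.
    storage = [{"obs": jnp.ones((4, 3)), "reward": jnp.zeros((4,))} for _ in range(5)]

    # Stack along a new leading time axis, as the new prepare_data does.
    stacked = jax.tree.map(lambda *xs: jnp.stack(xs), *storage)
    assert stacked["obs"].shape == (5, 4, 3)

    # Split along the batch axis afterwards, e.g. across 2 learner devices.
    partitioned = jax.tree.map(lambda x: jnp.split(x, 2, axis=1), stacked)
    assert partitioned["obs"][0].shape == (5, 2, 3)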
@@ -426,16 +467,18 @@ def rollout(
         params_queue_get_time.append(time.time() - params_queue_get_time_start)
         rollout_time_start = time.time()

-        init_rstate1, init_rstate2 = jax.tree.map(
-            lambda x: x.copy(), (next_rstate1, next_rstate2))
-        for k in range(args.num_steps):
+        for k in range(start_step, args.collect_steps):
+            if k % args.num_steps == 0:
+                init_rstate1, init_rstate2 = jax.tree.map(
+                    lambda x: x.copy(), (next_rstate1, next_rstate2))
+                init_rstates.append((init_rstate1, init_rstate2))
             global_step += args.local_num_envs * n_actors * args.world_size
             main = next_to_play == main_player
             inference_time_start = time.time()
             cached_next_obs, cached_next_done, cached_main, \
-                next_rstate1, next_rstate2, action, logits, key = sample_action(
+                next_rstate1, next_rstate2, action, logits, value, key = sample_action(
                     params, next_obs, next_rstate1, next_rstate2, main, next_done, key)
             cpu_action = np.array(action)
@@ -453,6 +496,7 @@ def rollout(
                 mains=cached_main,
                 actions=action,
                 logits=logits,
+                values=value,
                 rewards=next_reward,
                 next_dones=next_done,
             )
@@ -495,8 +539,29 @@ def rollout(
         rollout_time.append(time.time() - rollout_time_start)

-        partitioned_storage = prepare_data(storage)
-        storage = []
+        start_step = args.collect_steps - args.num_steps
+        next_main = main_player == next_to_play
+        if args.collect_steps == args.num_steps:
+            storage_t = storage
+            storage = []
+            next_data = (next_obs, next_main)
+        else:
+            storage_t = storage[:args.num_steps]
+            storage = storage[args.num_steps:]
+            values, rewards, next_dones, mains = prepare_data([
+                (t.values, t.rewards, t.next_dones, t.mains) for t in storage])
+            next_value = sample_action(
+                params, next_obs, next_rstate1, next_rstate2, next_main, next_done, key)[-2]
+            next_value = jnp.where(next_main, next_value, -next_value)
+            adv_carry = compute_advantage_carry(
+                next_value, values, rewards, next_dones, mains)
+            next_data = adv_carry
+        partitioned_storage = jax.tree.map(
+            lambda x: jnp.split(x, len(learner_devices), axis=1), prepare_data(storage_t))
         sharded_storage = []
         for x in partitioned_storage:
             if isinstance(x, dict):
@@ -508,10 +573,11 @@ def rollout(
                 x = jax.device_put_sharded(x, devices=learner_devices)
             sharded_storage.append(x)
         sharded_storage = Transition(*sharded_storage)
-        next_main = main_player == next_to_play
+        init_rstate = init_rstates.pop(0)
         sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
             np.split(x, len(learner_devices)), devices=learner_devices),
-            (init_rstate1, init_rstate2, next_obs, next_main))
+            (init_rstate, next_data))

         if args.eval_interval and update % args.eval_interval == 0:
             _start = time.time()
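For reference (generic sketch, not the repository's code), jax.device_put_sharded, as used for sharded_data above, takes a list of per-device shards and returns a single array whose leading axis indexes the devices:

    import jax
    import numpy as np

    devices = jax.local_devices()
    batch = np.arange(4 * len(devices))        # host-side data
    shards = np.split(batch, len(devices))     # one piece per device, each of shape (4,)
    sharded = jax.device_put_sharded(shards, devices)
    print(sharded.shape)                       # (len(devices), 4)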
@@ -594,6 +660,8 @@ def main():
     args.local_env_threads = args.local_env_threads or args.local_num_envs
     if args.segment_length is not None:
         assert args.num_steps % args.segment_length == 0, "num_steps must be divisible by segment_length"
+    args.collect_steps = args.collect_steps or args.num_steps
+    assert args.collect_steps >= args.num_steps, "collect_steps must be greater than or equal to num_steps"

     if args.embedding_file:
         embeddings = load_embeddings(args.embedding_file, args.code_list_file)
@@ -734,10 +802,10 @@ def main():
     else:
         eval_variables = None

-    def advantage_fn(
+    def compute_advantage(
             new_logits, new_values, next_dones, switch_or_mains,
-            actions, logits, rewards, next_value):
-        num_envs = jax.tree.leaves(next_value)[0].shape[0]
+            actions, logits, rewards, next_v):
+        num_envs = jax.tree.leaves(next_v)[0].shape[0]
         num_steps = next_dones.shape[0] // num_envs

         def reshape_time_series(x):
@@ -745,37 +813,20 @@ def main():
         ratios = distrax.importance_sampling_ratios(distrax.Categorical(
             new_logits), distrax.Categorical(logits), actions)
+        ratios = reshape_time_series(ratios)

         new_values_, rewards, next_dones, switch_or_mains = jax.tree.map(
             reshape_time_series, (new_values, rewards, next_dones, switch_or_mains),
         )

         # Advantages and target values
-        if args.switch:
-            if args.value == "vtrace" or args.sep_value:
-                raise NotImplementedError
-            target_values, advantages = gae_sep_switch(
-                next_value, new_values_, rewards, next_dones, switch_or_mains,
-                args.gamma, args.gae_lambda, args.upgo)
-        else:
-            # TODO: TD(lambda) for multi-step
-            ratios_ = reshape_time_series(ratios)
-            if args.value == "gae":
-                adv_fn = truncated_gae_sep if args.sep_value else truncated_gae
-                target_values, advantages = adv_fn(
-                    next_value, new_values_, rewards, next_dones, switch_or_mains,
-                    args.gamma, args.gae_lambda, args.upgo)
-            else:
-                adv_fn = vtrace_sep if args.sep_value else vtrace
-                target_values, advantages = adv_fn(
-                    next_value, ratios_, new_values_, rewards, next_dones, switch_or_mains,
-                    args.gamma, args.rho_clip_min, args.rho_clip_max,
-                    args.c_clip_min, args.c_clip_max)
+        target_values, advantages = advantage_fn(
+            args, next_v, new_values_, rewards, next_dones, switch_or_mains, ratios)

         target_values, advantages = jax.tree.map(
             lambda x: jnp.reshape(x, (-1,)), (target_values, advantages))
         return target_values, advantages

-    def loss_fn(
+    def compute_loss(
             new_logits, new_values, actions, logits, target_values, advantages,
             mask, num_steps=None):
         ratios = distrax.importance_sampling_ratios(distrax.Categorical(
@@ -820,17 +871,24 @@ def main():
         loss = pg_loss - args.ent_coef * ent_loss + v_loss * args.vf_coef
         return loss, pg_loss, v_loss, ent_loss, approx_kl

-    def apply_fn(variables, obs, rstate1, rstate2, dones, next_dones, switch_or_mains):
+    def apply_fn(variables, obs, init_rstate, dones, next_dones, switch_or_mains, train=True):
         if args.switch:
             dones = dones | next_dones
-        ((rstate1, rstate2), new_logits, new_values, _), state_updates = agent.apply(
-            variables, obs, (rstate1, rstate2), dones, switch_or_mains,
-            train=True, mutable=["batch_stats"])
+        mutable = ["batch_stats"] if train else False
+        rets = agent.apply(
+            variables, obs, init_rstate, dones, switch_or_mains,
+            train=train, mutable=mutable)
+        if train:
+            ((rstate1, rstate2), new_logits, new_values, _), state_updates = rets
+        else:
+            (rstate1, rstate2), new_logits, new_values, _ = rets
+            state_updates = {}
         new_values = jax.tree.map(lambda x: x.squeeze(-1), new_values)
         return ((rstate1, rstate2), new_logits, new_values), state_updates

     def compute_next_value(
-            variables, rstate1, rstate2, next_obs, next_main):
+            variables, next_rstate, next_obs, next_main):
+        rstate1, rstate2 = next_rstate
         rstate = jax.tree.map(
             lambda x1, x2: jnp.where(next_main[:, None], x1, x2), rstate1, rstate2)
         next_value = agent.apply(variables, next_obs, rstate)[2]
@@ -840,39 +898,39 @@ def main():
         next_value = jnp.where(next_main, sign * next_value, -sign * next_value)
         return next_value

-    def compute_advantage(
-            variables, rstate1, rstate2, obs, dones, next_dones,
+    def get_advantage(
+            variables, init_rstate, obs, dones, next_dones,
             switch_or_mains, actions, logits, rewards, next_obs, next_main):
-        segment_length = dones.shape[0]
+        num_steps = dones.shape[0]
         obs, dones, next_dones, switch_or_mains, actions, logits, rewards = \
             jax.tree.map(
                 lambda x: jnp.reshape(x, (-1,) + x.shape[2:]),
                 (obs, dones, next_dones, switch_or_mains, actions, logits, rewards))
-        ((rstate1, rstate2), new_logits, new_values), state_updates = apply_fn(
-            variables, obs, rstate1, rstate2, dones, next_dones, switch_or_mains)
+        (next_rstate, new_logits, new_values), state_updates = apply_fn(
+            variables, obs, init_rstate, dones, next_dones, switch_or_mains, train=False)
         next_value = compute_next_value(
-            variables, rstate1, rstate2, next_obs, next_main)
-        target_values, advantages = advantage_fn(
+            variables, next_rstate, next_obs, next_main)
+        target_values, advantages = compute_advantage(
             new_logits, new_values, next_dones, switch_or_mains,
             actions, logits, rewards, next_value)
         target_values, advantages = jax.tree.map(
-            lambda x: jnp.reshape(x, (segment_length, -1) + x.shape[2:]),
+            lambda x: jnp.reshape(x, (num_steps, -1) + x.shape[2:]),
             (target_values, advantages))
         return target_values, advantages

-    def compute_loss(
-            params, batch_stats, rstate1, rstate2, obs, dones, next_dones,
+    def get_loss(
+            params, batch_stats, init_rstate, obs, dones, next_dones,
             switch_or_mains, actions, logits, target_values, advantages, mask):
         variables = {'params': params, 'batch_stats': batch_stats}
         ((rstate1, rstate2), new_logits, new_values), state_updates = apply_fn(
-            variables, obs, rstate1, rstate2, dones, next_dones, switch_or_mains)
+            variables, obs, init_rstate, dones, next_dones, switch_or_mains)
-        loss, pg_loss, v_loss, ent_loss, approx_kl = loss_fn(
+        loss, pg_loss, v_loss, ent_loss, approx_kl = compute_loss(
             new_logits, new_values, actions, logits, target_values, advantages,
             mask, num_steps=None)
@@ -881,23 +939,27 @@ def main():
             jax.lax.stop_gradient, (approx_kl, rstate1, rstate2))
         return loss, (state_updates, pg_loss, v_loss, ent_loss, approx_kl, rstate1, rstate2)

-    def compute_advantage_loss(
-            params, batch_stats, rstate1, rstate2, obs, dones, next_dones,
-            switch_or_mains, actions, logits, rewards, mask, next_obs, next_main):
-        num_envs = jax.tree.leaves(next_main)[0].shape[0]
+    def get_advantage_loss(
+            params, batch_stats, init_rstate, obs, dones, next_dones,
+            switch_or_mains, actions, logits, rewards, mask, next_data):
+        num_envs = jax.tree.leaves(next_data)[0].shape[0]
         variables = {'params': params, 'batch_stats': batch_stats}
-        ((rstate1, rstate2), new_logits, new_values), state_updates = apply_fn(
-            variables, obs, rstate1, rstate2, dones, next_dones, switch_or_mains)
-        variables = {'params': params, 'batch_stats': state_updates['batch_stats']}
-        next_value = compute_next_value(
-            variables, rstate1, rstate2, next_obs, next_main)
+        (next_rstate, new_logits, new_values), state_updates = apply_fn(
+            variables, obs, init_rstate, dones, next_dones, switch_or_mains)
+        if args.collect_steps == args.num_steps:
+            next_obs, next_main = next_data
+            variables = {'params': params, 'batch_stats': state_updates['batch_stats']}
+            next_v = compute_next_value(
+                variables, next_rstate, next_obs, next_main)
+        else:
+            next_v = next_data
-        target_values, advantages = advantage_fn(
+        target_values, advantages = compute_advantage(
             new_logits, new_values, next_dones, switch_or_mains,
-            actions, logits, rewards, next_value)
+            actions, logits, rewards, next_v)
-        loss, pg_loss, v_loss, ent_loss, approx_kl = loss_fn(
+        loss, pg_loss, v_loss, ent_loss, approx_kl = compute_loss(
             new_logits, new_values, actions, logits, target_values, advantages,
             mask, num_steps=dones.shape[0] // num_envs)
@@ -908,18 +970,15 @@ def main():
     def single_device_update(
         agent_state: TrainState,
         sharded_storages: List,
-        sharded_init_rstate1: List,
-        sharded_init_rstate2: List,
-        sharded_next_obs: List,
-        sharded_next_main: List,
+        sharded_init_rstate: List,
+        sharded_next_data: List,
         key: jax.random.PRNGKey,
     ):
         storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
-        next_obs, init_rstate1, init_rstate2 = [
+        next_data, init_rstate = [
             jax.tree.map(lambda *x: jnp.concatenate(x), *x)
-            for x in [sharded_next_obs, sharded_init_rstate1, sharded_init_rstate2]
+            for x in [sharded_next_data, sharded_init_rstate]
         ]
-        next_main = jnp.concatenate(sharded_next_main)
         # reorder storage of individual players
         # main first, opponent second
@@ -934,9 +993,10 @@ def main():
         storage = jax.tree.map(lambda x: x[indices, B[None, :]], storage)

         if args.segment_length is None:
-            loss_grad_fn = jax.value_and_grad(compute_advantage_loss, has_aux=True)
+            loss_grad_fn = jax.value_and_grad(get_advantage_loss, has_aux=True)
         else:
-            loss_grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
+            # TODO: fix it
+            loss_grad_fn = jax.value_and_grad(get_loss, has_aux=True)

         def update_epoch(carry, _):
             agent_state, key = carry
@@ -947,9 +1007,9 @@ def main():
                 return reshape_minibatch(
                     x, multi_step, args.num_minibatches, num_steps, args.segment_length, key=key)

-            b_init_rstate1, b_init_rstate2, b_next_obs, b_next_main = \
+            b_init_rstate, b_next_data = \
                 jax.tree.map(partial(convert_data, multi_step=False),
-                             (init_rstate1, init_rstate2, next_obs, next_main))
+                             (init_rstate, next_data))
             b_storage = jax.tree.map(convert_data, storage)
             if args.switch:
                 switch_or_mains = convert_data(switch)
@@ -969,31 +1029,30 @@ def main():
             else:
                 def update_minibatch(carry, minibatch):
                     def update_minibatch_t(carry, minibatch_t):
-                        agent_state, rstate1, rstate2 = carry
-                        minibatch_t = rstate1, rstate2, *minibatch_t
-                        (loss, (state_updates, pg_loss, v_loss, ent_loss, approx_kl, rstate1, rstate2)), \
+                        agent_state, init_rstate = carry
+                        minibatch_t = init_rstate, *minibatch_t
+                        (loss, (state_updates, pg_loss, v_loss, ent_loss, approx_kl, next_rstate)), \
                             grads = loss_grad_fn(
                                 agent_state.params, agent_state.batch_stats, *minibatch_t)
                         grads = jax.lax.pmean(grads, axis_name="local_devices")
                         agent_state = agent_state.apply_gradients(grads=grads)
                         agent_state = agent_state.replace(batch_stats=state_updates['batch_stats'])
-                        return (agent_state, rstate1, rstate2), (loss, pg_loss, v_loss, ent_loss, approx_kl)
+                        return (agent_state, next_rstate), (loss, pg_loss, v_loss, ent_loss, approx_kl)

-                    rstate1, rstate2, *minibatch_t, mask = minibatch
-                    target_values, advantages = compute_advantage(
-                        get_variables(carry), rstate1, rstate2, *minibatch_t)
+                    init_rstate, *minibatch_t, mask = minibatch
+                    target_values, advantages = get_advantage(
+                        get_variables(carry), init_rstate, *minibatch_t)
                     minibatch_t = *minibatch_t[:-2], target_values, advantages, mask
-                    (carry, _rstate1, _rstate2), \
+                    (carry, _next_rstate), \
                         (loss, pg_loss, v_loss, ent_loss, approx_kl) = jax.lax.scan(
-                            update_minibatch_t, (carry, rstate1, rstate2), minibatch_t)
+                            update_minibatch_t, (carry, init_rstate), minibatch_t)
                     return carry, (loss, pg_loss, v_loss, ent_loss, approx_kl)

             agent_state, (loss, pg_loss, v_loss, ent_loss, approx_kl) = jax.lax.scan(
                 update_minibatch,
                 agent_state,
                 (
-                    b_init_rstate1,
-                    b_init_rstate2,
+                    b_init_rstate,
                     b_storage.obs,
                     b_storage.dones,
                     b_storage.next_dones,
@@ -1002,8 +1061,7 @@ def main():
...
@@ -1002,8 +1061,7 @@ def main():
b_storage
.
logits
,
b_storage
.
logits
,
b_rewards
,
b_rewards
,
b_mask
,
b_mask
,
b_next_obs
,
b_next_data
,
b_next_main
,
),
),
)
)
return
(
agent_state
,
key
),
(
loss
,
pg_loss
,
v_loss
,
ent_loss
,
approx_kl
)
return
(
agent_state
,
key
),
(
loss
,
pg_loss
,
v_loss
,
ent_loss
,
approx_kl
)
@@ -1023,7 +1081,7 @@ def main():
         axis_name="main_devices",
         devices=global_main_devices,
     )
     multi_device_update = jax.pmap(
         single_device_update,
         axis_name="local_devices",
ygoai/rl/jax/__init__.py

 from functools import partial
+from typing import NamedTuple

 import jax
 import jax.numpy as jnp
@@ -193,6 +194,14 @@ def vtrace_rnad(
     return targets, q_estimate


+class VtraceCarry(NamedTuple):
+    v: jnp.ndarray
+    next_value: jnp.ndarray
+    last_return: jnp.ndarray
+    next_q: jnp.ndarray
+    next_main: jnp.ndarray
+
+
 def vtrace_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
     v, next_value, last_return, next_q, next_main = carry
     ratio, cur_value, next_done, reward, main = inp
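The refactor replaces the long positional carry tuples with NamedTuples, which jax.lax.scan handles like any other pytree. A self-contained sketch (a generic discounted-return example, not the repository's V-trace) of a NamedTuple carry in a reversed scan:

    from typing import NamedTuple
    import jax
    import jax.numpy as jnp

    class ReturnCarry(NamedTuple):
        ret: jnp.ndarray  # running discounted return

    def discounted_returns(rewards, gamma=0.99):
        # rewards: [T, B]; scan backwards so each step sees the future return.
        def step(carry, reward):
            ret = reward + gamma * carry.ret
            return ReturnCarry(ret), ret

        init = ReturnCarry(jnp.zeros_like(rewards[0]))
        _, returns = jax.lax.scan(step, init, rewards, reverse=True)
        return returns

    print(discounted_returns(jnp.ones((3, 2))).shape)  # (3, 2)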
@@ -221,22 +230,29 @@ def vtrace_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
     next_q = reward + discount * next_value

-    carry = v, cur_value, last_return, next_q, main
+    carry = VtraceCarry(v, cur_value, last_return, next_q, main)
     return carry, (v, q_t, last_return)


 def vtrace(
-    next_value, ratios, values, rewards, next_dones, mains,
+    next_v, ratios, values, rewards, next_dones, mains,
     gamma, rho_min=0.001, rho_max=1.0, c_min=0.001, c_max=1.0, upgo=False,
+    return_carry=False,
 ):
-    v = last_return = next_q = next_value
-    next_main = jnp.ones_like(next_value, dtype=jnp.bool_)
-    carry = v, next_value, last_return, next_q, next_main
-    _, (targets, q_estimate, return_t) = jax.lax.scan(
+    if isinstance(next_v, (tuple, list)):
+        carry = next_v
+    else:
+        next_value = next_v
+        v = last_return = next_q = next_value
+        next_main = jnp.ones_like(next_value, dtype=jnp.bool_)
+        carry = VtraceCarry(v, next_value, last_return, next_q, next_main)
+    carry, (targets, q_estimate, return_t) = jax.lax.scan(
         partial(vtrace_loop, gamma=gamma, rho_min=rho_min, rho_max=rho_max, c_min=c_min, c_max=c_max),
         carry, (ratios, values, next_dones, rewards, mains), reverse=True
     )
+    if return_carry:
+        return carry
     advantages = q_estimate - values
     if upgo:
         advantages += return_t - values
@@ -244,9 +260,18 @@ def vtrace(
     return targets, advantages


+class VtraceSepCarry(NamedTuple):
+    v: jnp.ndarray
+    next_values: jnp.ndarray
+    reward: jnp.ndarray
+    xi: jnp.ndarray
+    last_return: jnp.ndarray
+    next_q: jnp.ndarray
+
+
 def vtrace_sep_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
-    v1, v2, next_values1, next_values2, reward1, reward2, xi1, xi2, \
-        last_return1, last_return2, next_q1, next_q2 = carry
+    (v1, next_values1, reward1, xi1, last_return1, next_q1), \
+        (v2, next_values2, reward2, xi2, last_return2, next_q2) = carry
     ratio, cur_values, next_done, r_t, main = inp

     v1 = jnp.where(next_done, 0, v1)
@@ -296,28 +321,46 @@ def vtrace_sep_loop(carry, inp, gamma, rho_min, rho_max, c_min, c_max):
     xi1 = jnp.where(main, 1, ratio * xi1)
     xi2 = jnp.where(main, ratio * xi2, 1)

-    carry = v1, v2, next_values1, next_values2, reward1, reward2, xi1, xi2, \
-        last_return1, last_return2, next_q1, next_q2
-    return carry, (v, q_t, return_t)
+    carry1 = VtraceSepCarry(v1, next_values1, reward1, xi1, last_return1, next_q1)
+    carry2 = VtraceSepCarry(v2, next_values2, reward2, xi2, last_return2, next_q2)
+    return (carry1, carry2), (v, q_t, return_t)


 def vtrace_sep(
-    next_value, ratios, values, rewards, next_dones, mains,
+    next_v, ratios, values, rewards, next_dones, mains,
     gamma, rho_min=0.001, rho_max=1.0, c_min=0.001, c_max=1.0, upgo=False,
+    return_carry=False,
 ):
-    next_value1 = next_value
-    next_value2 = -next_value1
-    v1 = return1 = next_q1 = next_value1
-    v2 = return2 = next_q2 = next_value2
-    reward1 = reward2 = jnp.zeros_like(next_value)
-    xi1 = xi2 = jnp.ones_like(next_value)
-    carry = v1, v2, next_value1, next_value2, reward1, reward2, xi1, xi2, \
-        return1, return2, next_q1, next_q2
-    _, (targets, q_estimate, return_t) = jax.lax.scan(
+    if isinstance(next_v, (tuple, list)):
+        carry = next_v
+    else:
+        next_value = next_v
+        next_value1 = next_value
+        carry1 = VtraceSepCarry(
+            v=next_value1,
+            next_values=next_value1,
+            reward=jnp.zeros_like(next_value),
+            xi=jnp.ones_like(next_value),
+            last_return=next_value1,
+            next_q=next_value1,
+        )
+        next_value2 = -next_value1
+        carry2 = VtraceSepCarry(
+            v=next_value2,
+            next_values=next_value2,
+            reward=jnp.zeros_like(next_value),
+            xi=jnp.ones_like(next_value),
+            last_return=next_value2,
+            next_q=next_value2,
+        )
+        carry = carry1, carry2
+    carry, (targets, q_estimate, return_t) = jax.lax.scan(
         partial(vtrace_sep_loop, gamma=gamma, rho_min=rho_min, rho_max=rho_max, c_min=c_min, c_max=c_max),
         carry, (ratios, values, next_dones, rewards, mains), reverse=True
     )
+    if return_carry:
+        return carry
     advantages = q_estimate - values
     if upgo:
         advantages += return_t - values
@@ -325,9 +368,18 @@ def vtrace_sep(
     return targets, advantages


+class GAESepCarry(NamedTuple):
+    lastgaelam: jnp.ndarray
+    next_value: jnp.ndarray
+    reward: jnp.ndarray
+    done_used: jnp.ndarray
+    last_return: jnp.ndarray
+    next_q: jnp.ndarray
+
+
 def truncated_gae_sep_loop(carry, inp, gamma, gae_lambda):
-    lastgaelam1, lastgaelam2, next_value1, next_value2, reward1, reward2, \
-        done_used1, done_used2, last_return1, last_return2, next_q1, next_q2 = carry
+    (lastgaelam1, next_value1, reward1, done_used1, last_return1, next_q1), \
+        (lastgaelam2, next_value2, reward2, done_used2, last_return2, next_q2) = carry
     cur_value, next_done, reward, main = inp
     main1 = main
     main2 = ~main
@@ -370,29 +422,40 @@ def truncated_gae_sep_loop(carry, inp, gamma, gae_lambda):
     lastgaelam1 = jnp.where(main1, lastgaelam1_, lastgaelam1)
     lastgaelam2 = jnp.where(main2, lastgaelam2_, lastgaelam2)

-    carry = lastgaelam1, lastgaelam2, next_value1, next_value2, reward1, reward2, \
-        done_used1, done_used2, last_return1, last_return2, next_q1, next_q2
-    return carry, (advantages, returns)
+    carry1 = GAESepCarry(lastgaelam1, next_value1, reward1, done_used1, last_return1, next_q1)
+    carry2 = GAESepCarry(lastgaelam2, next_value2, reward2, done_used2, last_return2, next_q2)
+    return (carry1, carry2), (advantages, returns)


 def truncated_gae_sep(
-    next_value, values, rewards, next_dones, mains, gamma, gae_lambda, upgo,
+    next_v, values, rewards, next_dones, mains, gamma, gae_lambda, upgo,
+    return_carry=False,
 ):
-    next_value1 = next_value
-    next_value2 = -next_value1
-    last_return1 = next_q1 = next_value1
-    last_return2 = next_q2 = next_value2
-    done_used1 = jnp.ones_like(next_dones[-1])
-    done_used2 = jnp.ones_like(next_dones[-1])
-    reward1 = reward2 = jnp.zeros_like(next_value)
-    lastgaelam1 = lastgaelam2 = jnp.zeros_like(next_value)
-    carry = lastgaelam1, lastgaelam2, next_value1, next_value2, reward1, reward2, \
-        done_used1, done_used2, last_return1, last_return2, next_q1, next_q2
-    _, (advantages, returns) = jax.lax.scan(
+    if isinstance(next_v, (tuple, list)):
+        carry = next_v
+    else:
+        next_value = next_v
+        carry1 = GAESepCarry(
+            lastgaelam=jnp.zeros_like(next_value),
+            next_value=next_value,
+            reward=jnp.zeros_like(next_value),
+            done_used=jnp.ones_like(next_dones[-1]),
+            last_return=next_value,
+            next_q=next_value,
+        )
+        carry2 = GAESepCarry(
+            lastgaelam=jnp.zeros_like(next_value),
+            next_value=-next_value,
+            reward=jnp.zeros_like(next_value),
+            done_used=jnp.ones_like(next_dones[-1]),
+            last_return=-next_value,
+            next_q=-next_value,
+        )
+        carry = carry1, carry2
+    carry, (advantages, returns) = jax.lax.scan(
         partial(truncated_gae_sep_loop, gamma=gamma, gae_lambda=gae_lambda),
         carry, (values, next_dones, rewards, mains), reverse=True
     )
+    if return_carry:
+        return carry
     targets = values + advantages
     if upgo:
         advantages += returns - values
@@ -400,6 +463,14 @@ def truncated_gae_sep(
     return targets, advantages


+class GAECarry(NamedTuple):
+    lastgaelam: jnp.ndarray
+    next_value: jnp.ndarray
+    last_return: jnp.ndarray
+    next_q: jnp.ndarray
+    next_main: jnp.ndarray
+
+
 def truncated_gae_loop(carry, inp, gamma, gae_lambda):
     lastgaelam, next_value, last_return, next_q, next_main = carry
     cur_value, next_done, reward, main = inp
@@ -424,20 +495,30 @@ def truncated_gae_loop(carry, inp, gamma, gae_lambda):
     next_q = reward + discount * next_value

-    carry = lastgaelam, cur_value, last_return, next_q, main
+    carry = GAECarry(lastgaelam, cur_value, last_return, next_q, main)
     return carry, (lastgaelam, last_return)


 def truncated_gae(
-        next_value, values, rewards, next_dones, mains, gamma, gae_lambda, upgo=False):
-    lastgaelam = jnp.zeros_like(next_value)
-    last_return = next_q = next_value
-    next_main = jnp.ones_like(next_value, dtype=jnp.bool_)
-    carry = lastgaelam, next_value, last_return, next_q, next_main
-    _, (advantages, return_t) = jax.lax.scan(
+        next_v, values, rewards, next_dones, mains, gamma, gae_lambda,
+        upgo=False, return_carry=False):
+    if isinstance(next_v, (tuple, list)):
+        carry = next_v
+    else:
+        next_value = next_v
+        carry = GAECarry(
+            lastgaelam=jnp.zeros_like(next_value),
+            next_value=next_value,
+            last_return=next_value,
+            next_q=next_value,
+            next_main=jnp.ones_like(next_value, dtype=jnp.bool_),
+        )
+    carry, (advantages, return_t) = jax.lax.scan(
         partial(truncated_gae_loop, gamma=gamma, gae_lambda=gae_lambda),
         carry, (values, next_dones, rewards, mains), reverse=True
     )
+    if return_carry:
+        return carry
     targets = values + advantages
     if upgo:
         advantages += return_t - values
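The return_carry flag added throughout these estimators lets a caller stop after the backward scan and reuse the carry as the bootstrap for an earlier chunk, which is how the actor turns the extra collect_steps - num_steps transitions into next_data for the learner. A generic sketch (assumed example, not the repository's code) of resuming a reversed scan from a saved carry:

    import jax
    import jax.numpy as jnp

    def backward_returns(next_carry, rewards, gamma=0.99, return_carry=False):
        # rewards: [T, B]; next_carry: [B] running return from the steps after this chunk.
        def step(carry, reward):
            carry = reward + gamma * carry
            return carry, carry

        carry, returns = jax.lax.scan(step, next_carry, rewards, reverse=True)
        return carry if return_carry else returns

    rewards = jnp.ones((6, 2))
    # Equivalent to scanning the full sequence: first reduce the tail to a carry,
    # then resume from that carry on the head chunk.
    tail_carry = backward_returns(jnp.zeros(2), rewards[4:], return_carry=True)
    head_returns = backward_returns(tail_carry, rewards[:4])
    full_returns = backward_returns(jnp.zeros(2), rewards)
    assert jnp.allclose(head_returns, full_returns[:4])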