Biluo Shen / ygo-agent / Commits / 43ca871e

Commit 43ca871e authored Apr 18, 2024 by sbl1996@126.com
Refactor switch

parent 93bc3723
Showing 8 changed files with 502 additions and 481 deletions (+502, -481)
  scripts/jax/battle.py       +11   -1
  scripts/jax/impala.py       +125  -101
  scripts/jax/ppo.py          +72   -65
  ygoai/rl/jax/__init__.py    +236  -303
  ygoai/rl/jax/agent2.py      +5    -0
  ygoai/rl/jax/switch.py      +39   -0
  ygoai/rl/utils.py           +0    -10
  ygoai/utils.py              +14   -1
scripts/jax/battle.py

@@ -4,6 +4,7 @@ import os
 import random
 from typing import Optional, Literal
 from dataclasses import dataclass
+from tqdm import tqdm

 import ygoenv
 import numpy as np

@@ -220,6 +221,9 @@ if __name__ == "__main__":
     ])
     rstate1 = rstate2 = init_rnn_state(num_envs, args.rnn_channels)

+    if not args.verbose:
+        pbar = tqdm(total=args.num_episodes)
+
     model_time = env_time = 0
     while True:
         if start_step == 0 and len(episode_lengths) > int(args.num_episodes * 0.1):

@@ -255,7 +259,11 @@ if __name__ == "__main__":
             episode_rewards.append(episode_reward)
             win_rates.append(win)
             win_reasons.append(1 if win_reason == 1 else 0)
-            sys.stderr.write(f"Episode {len(episode_lengths)}: length={episode_length}, reward={episode_reward}, win={win}, win_reason={win_reason}\n")
+            if args.verbose:
+                print(f"Episode {len(episode_lengths)}: length={episode_length}, reward={episode_reward}, win={win}, win_reason={win_reason}\n")
+            else:
+                pbar.set_postfix(len=np.mean(episode_lengths), reward=np.mean(episode_rewards), win_rate=np.mean(win_rates))
+                pbar.update(1)

             # Only when num_envs=1, we switch the player here
             if args.verbose:

@@ -264,6 +272,8 @@ if __name__ == "__main__":
         if len(episode_lengths) >= args.num_episodes:
             break

+    if not args.verbose:
+        pbar.close()
     print(f"len={np.mean(episode_lengths)}, reward={np.mean(episode_rewards)}, win_rate={np.mean(win_rates)}, win_reason={np.mean(win_reasons)}")

     total_time = time.time() - start
scripts/jax/impala.py

@@ -16,18 +16,17 @@ import jax
 import jax.numpy as jnp
 import numpy as np
 import optax
 import rlax
 import distrax
 import tyro

 from flax.training.train_state import TrainState
 from rich.pretty import pprint
 from tensorboardX import SummaryWriter

-from ygoai.utils import init_ygopro
+from ygoai.utils import init_ygopro, load_embeddings
 from ygoai.rl.jax.agent2 import PPOLSTMAgent
 from ygoai.rl.jax.utils import RecordEpisodeStatistics, categorical_sample
-from ygoai.rl.jax.eval import evaluate
-from ygoai.rl.jax import upgo_return, vtrace, clipped_surrogate_pg_loss
+from ygoai.rl.jax.eval import evaluate, battle
+from ygoai.rl.jax import vtrace_2p0s, clipped_surrogate_pg_loss, policy_gradient_loss, mse_loss, entropy_loss

 os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"

@@ -63,10 +62,12 @@ class Args:
     """the maximum number of options"""
     n_history_actions: int = 32
     """the number of history actions to use"""
+    greedy_reward: bool = True
+    """whether to use greedy reward (faster kill higher reward)"""

     total_timesteps: int = 5000000000
     """total timesteps of the experiments"""
-    learning_rate: float = 1e-4
+    learning_rate: float = 1e-3
     """the learning rate of the optimizer"""
     local_num_envs: int = 128
     """the number of parallel game environments"""

@@ -74,15 +75,15 @@ class Args:
     """the number of threads to use for environment"""
     num_actor_threads: int = 2
     """the number of actor threads to use"""
-    num_steps: int = 32
+    num_steps: int = 128
     """the number of steps to run in each environment per policy rollout"""
-    collect_length: Optional[int] = None
-    """the number of steps to compute the advantages"""
     anneal_lr: bool = False
     """Toggle learning rate annealing for policy and value networks"""
     gamma: float = 1.0
     """the discount factor gamma"""
-    num_minibatches: int = 4
+    upgo: bool = False
+    """Toggle the use of UPGO for advantages"""
+    num_minibatches: int = 8
     """the number of mini-batches"""
     update_epochs: int = 2
     """the K epochs to update the policy"""

@@ -94,12 +95,12 @@ class Args:
     """the minimum value of the importance sampling clipping"""
     rho_clip_max: float = 1.007
     """the maximum value of the importance sampling clipping"""
-    upgo: bool = False
-    """whether to use UPGO for policy update"""
     ppo_clip: bool = True
     """whether to use the PPO clipping to replace V-Trace surrogate clipping"""
     clip_coef: float = 0.25
     """the PPO surrogate clipping coefficient"""
+    dual_clip_coef: Optional[float] = None
+    """the dual surrogate clipping coefficient"""
     ent_coef: float = 0.01
     """coefficient of the entropy"""
     vf_coef: float = 0.5

@@ -122,11 +123,13 @@ class Args:
     """whether to use `jax.distirbuted`"""
     concurrency: bool = True
     """whether to run the actor and learner concurrently"""
-    bfloat16: bool = True
+    bfloat16: bool = False
     """whether to use bfloat16 for the agent"""
     thread_affinity: bool = False
     """whether to use thread affinity for the environment"""

+    eval_checkpoint: Optional[str] = None
+    """the path to the model checkpoint to evaluate"""
     local_eval_episodes: int = 32
     """the number of episodes to evaluate the model"""
     eval_interval: int = 50

@@ -145,6 +148,7 @@ class Args:
     actor_devices: Optional[List[str]] = None
     learner_devices: Optional[List[str]] = None
     num_embeddings: Optional[int] = None
+    freeze_id: bool = False


 def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1):

@@ -164,6 +168,7 @@ def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_off
         max_options=args.max_options,
         n_history_actions=args.n_history_actions,
         async_reset=False,
+        greedy_reward=args.greedy_reward if mode == 'self' else True,
         play_mode=mode,
     )
     envs.num_envs = num_envs

@@ -177,7 +182,6 @@ class Transition(NamedTuple):
     logits: list
     rewards: list
     mains: list
     next_dones: list


 def create_agent(args, multi_step=False):

@@ -189,6 +193,7 @@ def create_agent(args, multi_step=False):
         param_dtype=jnp.float32,
         lstm_channels=args.rnn_channels,
         multi_step=multi_step,
+        freeze_id=args.freeze_id,
     )

@@ -209,6 +214,10 @@ def rollout(
     learner_devices,
     device_thread_id,
 ):
+    eval_mode = 'self' if args.eval_checkpoint else 'bot'
+    if eval_mode != 'bot':
+        eval_params = params_queue.get()
+
     envs = make_env(
         args,
         args.seed + jax.process_index() + device_thread_id,

@@ -222,7 +231,7 @@ def rollout(
         args,
         args.seed + jax.process_index() + device_thread_id,
         args.local_eval_episodes,
-        args.local_eval_episodes // 4, mode='bot')
+        args.local_eval_episodes // 4, mode=eval_mode)
     eval_envs = RecordEpisodeStatistics(eval_envs)

     len_actor_device_ids = len(args.actor_device_ids)

@@ -249,6 +258,17 @@ def rollout(
         rstate, logits = get_logits(params, inputs, done)
         return rstate, logits.argmax(axis=1)

+    @jax.jit
+    def get_action_battle(params1, params2, rstate1, rstate2, obs, main, done):
+        next_rstate1, logits1 = get_logits(params1, (rstate1, obs), done)
+        next_rstate2, logits2 = get_logits(params2, (rstate2, obs), done)
+        logits = jnp.where(main[:, None], logits1, logits2)
+        rstate1 = jax.tree.map(
+            lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
+        rstate2 = jax.tree.map(
+            lambda x1, x2: jnp.where(main[:, None], x2, x1), next_rstate2, rstate2)
+        return rstate1, rstate2, logits.argmax(axis=1)
+
     @jax.jit
     def sample_action(
         params: flax.core.FrozenDict,

@@ -281,7 +301,6 @@ def rollout(
         np.ones(args.local_num_envs // 2, dtype=np.int64)
     ])
     np.random.shuffle(main_player)
     start_step = 0
     storage = []

     @jax.jit

@@ -312,7 +331,7 @@ def rollout(
         rollout_time_start = time.time()
         init_rstate1, init_rstate2 = jax.tree.map(
             lambda x: x.copy(), (next_rstate1, next_rstate2))
-        for _ in range(start_step, args.collect_length):
+        for _ in range(args.num_steps):
             global_step += args.local_num_envs * n_actors * args.world_size
             cached_next_obs = next_obs

@@ -340,7 +359,6 @@ def rollout(
                     actions=action,
                     logits=logits,
                     rewards=next_reward,
                     next_dones=next_done,
                 )
             )

@@ -348,15 +366,6 @@ def rollout(
                 if not d:
                     continue
                 cur_main = main[idx]
-                for j in reversed(range(len(storage) - 1)):
-                    t = storage[j]
-                    if t.next_dones[idx]:
-                        # For OTK where player may not switch
-                        break
-                    if t.mains[idx] != cur_main:
-                        t.next_dones[idx] = True
-                        t.rewards[idx] = -next_reward[idx]
-                        break
                 episode_reward = info['r'][idx] * (1 if cur_main else -1)
                 win = 1 if episode_reward > 0 else 0
                 avg_ep_returns.append(episode_reward)

@@ -364,10 +373,8 @@ def rollout(
         rollout_time.append(time.time() - rollout_time_start)
-        start_step = args.collect_length - args.num_steps

         partitioned_storage = prepare_data(storage)
-        storage = storage[args.num_steps:]
+        storage = []
         sharded_storage = []
         for x in partitioned_storage:
             if isinstance(x, dict):

@@ -384,7 +391,7 @@ def rollout(
             lambda x1, x2: jnp.where(next_main[:, None], x1, x2), next_rstate1, next_rstate2)
         sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
             np.split(x, len(learner_devices)), devices=learner_devices),
-            (init_rstate1, init_rstate2, (next_rstate, next_obs), next_main))
+            (init_rstate1, init_rstate2, (next_rstate, next_obs), next_done, next_main))
         learn_opponent = False
         payload = (
             global_step,

@@ -403,10 +410,13 @@ def rollout(
             SPS_update = int(args.batch_size / (time.time() - update_time_start))
             if device_thread_id == 0:
                 print(
-                    f"global_step={global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}, rollout_time={rollout_time[-1]:.2f}"
+                    f"global_step={global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}"
                 )
                 time_now = datetime.now(timezone(timedelta(hours=8))).strftime("%H:%M:%S")
-                print(f"{time_now} SPS: {SPS}, update: {SPS_update}")
+                print(
+                    f"{time_now} SPS: {SPS}, update: {SPS_update}, "
+                    f"rollout_time={rollout_time[-1]:.2f}, params_time={params_queue_get_time[-1]:.2f}"
+                )
             writer.add_scalar("stats/rollout_time", np.mean(rollout_time), global_step)
             writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step)
             writer.add_scalar("charts/avg_episodic_length", avg_episodic_length, global_step)

@@ -419,19 +429,28 @@ def rollout(
             if args.eval_interval and update % args.eval_interval == 0:
                 # Eval with rule-based policy
                 _start = time.time()
-                eval_return = evaluate(eval_envs, get_action, params, eval_rstate)[0]
+                if eval_mode == 'bot':
+                    predict_fn = lambda x: get_action(params, x)
+                    eval_stat = evaluate(
+                        eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)[0]
+                    metric_name = "eval_return"
+                else:
+                    predict_fn = lambda *x: get_action_battle(params, eval_params, *x)
+                    eval_stat = battle(
+                        eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)[2]
+                    metric_name = "eval_win_rate"
                 if device_thread_id != 0:
-                    eval_queue.put(eval_return)
+                    eval_queue.put(eval_stat)
                 else:
                     eval_stats = []
-                    eval_stats.append(eval_return)
+                    eval_stats.append(eval_stat)
                     for _ in range(1, n_actors):
                         eval_stats.append(eval_queue.get())
                     eval_stats = np.mean(eval_stats)
-                    writer.add_scalar("charts/eval_return", eval_stats, global_step)
+                    writer.add_scalar(f"charts/{metric_name}", eval_stats, global_step)
                 if device_thread_id == 0:
                     eval_time = time.time() - _start
-                    print(f"eval_time={eval_time:.4f}, eval_ep_return={eval_stats:.4f}")
+                    print(f"eval_time={eval_time:.4f}, {metric_name}={eval_stats:.4f}")
                     other_time += eval_time

@@ -461,8 +480,15 @@ if __name__ == "__main__":
     args.minibatch_size = args.local_minibatch_size * args.world_size
     args.num_updates = args.total_timesteps // (args.local_batch_size * args.world_size)
     args.local_env_threads = args.local_env_threads or args.local_num_envs
-    args.collect_length = args.collect_length or args.num_steps
-    assert args.collect_length >= args.num_steps, "collect_length must be greater than or equal to num_steps"
+
+    if args.embedding_file:
+        embeddings = load_embeddings(args.embedding_file, args.code_list_file)
+        embedding_shape = embeddings.shape
+        args.num_embeddings = embedding_shape
+        args.freeze_id = True if args.freeze_id is None else args.freeze_id
+    else:
+        embeddings = None
+        embedding_shape = None

     local_devices = jax.local_devices()
     global_devices = jax.devices()

@@ -517,6 +543,13 @@ if __name__ == "__main__":
     rstate = init_rnn_state(1, args.rnn_channels)
     agent = create_agent(args)
     params = agent.init(agent_key, (rstate, sample_obs))
+    if embeddings is not None:
+        unknown_embed = embeddings.mean(axis=0)
+        embeddings = np.concatenate([unknown_embed[None, :], embeddings], axis=0)
+        params = flax.core.unfreeze(params)
+        params['params']['Encoder_0']['Embed_0']['embedding'] = jax.device_put(embeddings)
+        params = flax.core.freeze(params)
+
     tx = optax.MultiSteps(
         optax.chain(
             optax.clip_by_global_norm(args.max_grad_norm),

@@ -541,6 +574,13 @@ if __name__ == "__main__":
     agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices)
     # print(agent.tabulate(agent_key, sample_obs))

+    if args.eval_checkpoint:
+        with open(args.eval_checkpoint, "rb") as f:
+            eval_params = flax.serialization.from_bytes(params, f.read())
+        print(f"loaded eval checkpoint from {args.eval_checkpoint}")
+    else:
+        eval_params = None
+
     @jax.jit
     def get_logits_and_value(
         params: flax.core.FrozenDict, inputs,

@@ -550,67 +590,54 @@ if __name__ == "__main__":
         return logits, value.squeeze(-1)

     def ppo_loss(
-        params, rstate1, rstate2, obs, dones, next_dones,
-        switch, actions, logits, rewards, mask, next_value):
+        params, rstate1, rstate2, obs, dones, mains,
+        actions, logits, rewards, mask, next_value, next_done):
         # (num_steps * local_num_envs // n_mb))
         num_envs = next_value.shape[0]
         num_steps = dones.shape[0] // num_envs
-        mask = mask & (~dones)
+        mask = mask * (1.0 - dones)
         n_valids = jnp.sum(mask)
-        real_dones = dones | next_dones
-        inputs = (rstate1, rstate2, obs, real_dones, switch)
+        inputs = (rstate1, rstate2, obs, dones, mains)
         new_logits, new_values = get_logits_and_value(params, inputs)

-        new_logits, v_tm1, logits, actions, rewards, next_dones, switch, mask = jax.tree.map(
+        new_logits, new_values, logits, actions, rewards, dones, mains, mask = jax.tree.map(
             lambda x: jnp.reshape(x, (num_steps, num_envs) + x.shape[1:]),
-            (new_logits, new_values, logits, actions, rewards, next_dones, switch, mask),
+            (new_logits, new_values, logits, actions, rewards, dones, mains, mask),
         )
+        next_dones = jnp.concatenate([dones[1:], next_done[None, :]], axis=0)

-        v_t = jnp.concatenate([v_tm1[1:], next_value[None, :]], axis=0)
-        discounts = (1.0 - next_dones) * args.gamma
-
-        ratio = distrax.importance_sampling_ratios(distrax.Categorical(
+        ratios = distrax.importance_sampling_ratios(distrax.Categorical(
             new_logits), distrax.Categorical(logits), actions)
-        logratio = jnp.log(ratio)
-        approx_kl = (((ratio - 1) - logratio) * mask).sum() / n_valids
-
-        # TODO: use switch to calculate the correct value
-        vtrace_fn = partial(
-            vtrace, c_clip_min=args.c_clip_min, c_clip_max=args.c_clip_max,
-            rho_clip_min=args.rho_clip_min, rho_clip_max=args.rho_clip_max)
-        vtrace_returns = jax.vmap(vtrace_fn, in_axes=1, out_axes=1)(
-            v_tm1, v_t, rewards, discounts, ratio)
-        if args.upgo:
-            advs = jax.vmap(upgo_return, in_axes=1, out_axes=1)(
-                rewards, v_t, discounts) - v_tm1
-        else:
-            advs = vtrace_returns.q_estimate - v_tm1
+        # TODO: TD(lambda) for multi-step
+        target_values, advantages = vtrace_2p0s(
+            next_value, ratios, new_values, rewards, next_dones, mains, args.gamma,
+            args.rho_clip_min, args.rho_clip_max, args.c_clip_min, args.c_clip_max)
+
+        logratio = jnp.log(ratios)
+        approx_kl = (((ratios - 1) - logratio) * mask).sum() / n_valids

         if args.ppo_clip:
-            pg_loss = jax.vmap(
-                partial(clipped_surrogate_pg_loss, epsilon=args.clip_coef), in_axes=1)(
-                ratio, advs, mask) * num_steps
-            pg_loss = jnp.sum(pg_loss)
+            pg_loss = clipped_surrogate_pg_loss(
+                ratios, advantages, args.clip_coef, args.dual_clip_coef)
         else:
-            pg_advs = jnp.minimum(args.rho_clip_max, ratio) * advs
-            pg_loss = jax.vmap(
-                rlax.policy_gradient_loss, in_axes=1)(
-                new_logits, actions, pg_advs, mask) * num_steps
-            pg_loss = jnp.sum(pg_loss)
+            pg_advs = jnp.clip(ratios, args.rho_clip_min, args.rho_clip_max) * advantages
+            pg_loss = policy_gradient_loss(new_logits, actions, pg_advs)
+        pg_loss = jnp.sum(pg_loss * mask)

-        v_loss = 0.5 * (vtrace_returns.errors ** 2)
+        v_loss = mse_loss(new_values, target_values)
         v_loss = jnp.sum(v_loss * mask)

-        entropy_loss = distrax.Softmax(new_logits).entropy()
-        entropy_loss = jnp.sum(entropy_loss * mask)
+        ent_loss = entropy_loss(new_logits)
+        ent_loss = jnp.sum(ent_loss * mask)

         pg_loss = pg_loss / n_valids
         v_loss = v_loss / n_valids
-        entropy_loss = entropy_loss / n_valids
+        ent_loss = ent_loss / n_valids

-        loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
-        return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl))
+        loss = pg_loss - args.ent_coef * ent_loss + v_loss * args.vf_coef
+        return loss, (pg_loss, v_loss, ent_loss, jax.lax.stop_gradient(approx_kl))

     def single_device_update(
         agent_state: TrainState,

@@ -618,6 +645,7 @@ if __name__ == "__main__":
         sharded_init_rstate1: List,
         sharded_init_rstate2: List,
         sharded_next_inputs: List,
+        sharded_next_done: List,
         sharded_next_main: List,
         key: jax.random.PRNGKey,
         learn_opponent: bool = False,

@@ -627,20 +655,13 @@ if __name__ == "__main__":
             jax.tree.map(lambda *x: jnp.concatenate(x), *x)
             for x in [sharded_next_inputs, sharded_init_rstate1, sharded_init_rstate2]
         ]
-        next_main, = [jnp.concatenate(x) for x in [sharded_next_main]]
+        next_main, next_done = [
+            jnp.concatenate(x) for x in [sharded_next_main, sharded_next_done]]

-        # reorder storage of individual players
-        # main first, opponent second
         num_steps, num_envs = storage.rewards.shape
-        T = jnp.arange(num_steps, dtype=jnp.int32)
-        B = jnp.arange(num_envs, dtype=jnp.int32)
-        mains = storage.mains.astype(jnp.int32)
-        indices = jnp.argsort(T[:, None] - mains * num_steps, axis=0)
-        switch_steps = jnp.sum(mains, axis=0)
-        switch = T[:, None] == (switch_steps[None, :] - 1)
-        storage = jax.tree.map(lambda x: x[indices, B[None, :]], storage)

         ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True)

@@ -650,9 +671,7 @@ if __name__ == "__main__":
         next_value = create_agent(args).apply(
            agent_state.params, next_inputs)[2].squeeze(-1)
-        # TODO: check if this is correct
-        sign = jnp.where(switch_steps <= num_steps, 1.0, -1.0)
-        next_value = jnp.where(next_main, -sign * next_value, sign * next_value)
+        next_value = jnp.where(next_main, next_value, -next_value)

         def convert_data(x: jnp.ndarray, num_steps):
             if args.update_epochs > 1:

@@ -666,10 +685,11 @@ if __name__ == "__main__":
                 x = jnp.reshape(x, (N, -1) + x.shape[1:])
             return x

-        shuffled_init_rstate1, shuffled_init_rstate2, shuffled_next_value = jax.tree.map(
-            partial(convert_data, num_steps=1), (init_rstate1, init_rstate2, next_value))
-        shuffled_storage, shuffled_switch = jax.tree.map(
-            partial(convert_data, num_steps=num_steps), (storage, switch))
+        shuffled_init_rstate1, shuffled_init_rstate2, \
+            shuffled_next_value, shuffled_next_done = jax.tree.map(
+            partial(convert_data, num_steps=1), (init_rstate1, init_rstate2, next_value, next_done))
+        shuffled_storage = jax.tree.map(
+            partial(convert_data, num_steps=num_steps), storage)
         shuffled_mask = jnp.ones_like(shuffled_storage.mains)

         def update_minibatch(agent_state, minibatch):

@@ -687,13 +707,13 @@ if __name__ == "__main__":
                     shuffled_init_rstate2,
                     shuffled_storage.obs,
                     shuffled_storage.dones,
-                    shuffled_storage.next_dones,
-                    shuffled_switch,
+                    shuffled_storage.mains,
                     shuffled_storage.actions,
                     shuffled_storage.logits,
                     shuffled_storage.rewards,
                     shuffled_mask,
                     shuffled_next_value,
+                    shuffled_next_done,
                 ),
             )
             return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl)

@@ -712,7 +732,7 @@ if __name__ == "__main__":
         single_device_update,
         axis_name="local_devices",
         devices=global_learner_decices,
-        static_broadcasted_argnums=(7,),
+        static_broadcasted_argnums=(8,),
     )

     params_queues = []

@@ -727,7 +747,9 @@ if __name__ == "__main__":
         for thread_id in range(args.num_actor_threads):
             params_queues.append(queue.Queue(maxsize=1))
             rollout_queues.append(queue.Queue(maxsize=1))
-            params_queues[-1].put(device_params)
+            if eval_params:
+                params_queues[-1].put(
+                    jax.device_put(eval_params, local_devices[d_id]))
             threading.Thread(
                 target=rollout,
                 args=(

@@ -741,6 +763,7 @@ if __name__ == "__main__":
                     d_idx * args.num_actor_threads + thread_id,
                 ),
             ).start()
+            params_queues[-1].put(device_params)

     rollout_queue_get_time = deque(maxlen=10)
     data_transfer_time = deque(maxlen=10)

@@ -790,8 +813,9 @@ if __name__ == "__main__":
             writer.add_scalar("stats/rollout_queue_size", rollout_queues[-1].qsize(), global_step)
             writer.add_scalar("stats/params_queue_size", params_queues[-1].qsize(), global_step)
             print(
-                global_step,
-                f"actor_update={update}, train_time={time.time() - training_time_start:.2f}",
+                f"{global_step} actor_update={update}, "
+                f"train_time={time.time() - training_time_start:.2f}, "
+                f"data_time={rollout_queue_get_time[-1]:.2f}"
             )
             writer.add_scalar("charts/learning_rate", agent_state.opt_state[2][1].hyperparams["learning_rate"][-1].item(), global_step)
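A minimal usage sketch of the refactored two-player V-trace helper that the new impala.py loss calls; this is not part of the commit, the shapes and dummy values are illustrative, and it assumes the ygoai package installed locally matches the diff above.

    import jax.numpy as jnp
    import numpy as np
    from ygoai.rl.jax import vtrace_2p0s

    num_steps, num_envs = 8, 4
    rng = np.random.default_rng(0)

    values = jnp.asarray(rng.normal(size=(num_steps, num_envs)), dtype=jnp.float32)
    rewards = jnp.zeros((num_steps, num_envs), dtype=jnp.float32)
    ratios = jnp.ones((num_steps, num_envs), dtype=jnp.float32)   # on-policy -> ratio 1.0
    next_dones = jnp.zeros((num_steps, num_envs), dtype=jnp.bool_)
    mains = jnp.asarray(rng.integers(0, 2, size=(num_steps, num_envs)), dtype=jnp.bool_)
    next_value = jnp.zeros((num_envs,), dtype=jnp.float32)        # bootstrap value

    # Same argument order as the call site in the new ppo_loss of impala.py
    target_values, advantages = vtrace_2p0s(
        next_value, ratios, values, rewards, next_dones, mains,
        1.0,          # gamma
        0.001, 1.0,   # rho_clip_min / rho_clip_max
        0.001, 1.0)   # c_clip_min / c_clip_max
    print(target_values.shape, advantages.shape)  # (8, 4) (8, 4)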
scripts/jax/ppo.py

@@ -22,11 +22,12 @@ from flax.training.train_state import TrainState
 from rich.pretty import pprint
 from tensorboardX import SummaryWriter

-from ygoai.utils import init_ygopro
+from ygoai.utils import init_ygopro, load_embeddings
 from ygoai.rl.jax.agent2 import PPOLSTMAgent
 from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
 from ygoai.rl.jax.eval import evaluate, battle
-from ygoai.rl.jax import compute_gae_2p0s, upgo_advantage
+from ygoai.rl.jax import clipped_surrogate_pg_loss, mse_loss, entropy_loss, simple_policy_loss
+from ygoai.rl.jax.switch import truncated_gae_2p0s

 os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"

@@ -77,8 +78,6 @@ class Args:
     """the number of actor threads to use"""
     num_steps: int = 128
     """the number of steps to run in each environment per policy rollout"""
-    collect_length: Optional[int] = None
-    """the number of steps to compute the advantages"""
     anneal_lr: bool = False
     """Toggle learning rate annealing for policy and value networks"""
     gamma: float = 1.0

@@ -95,8 +94,10 @@ class Args:
     """Toggles advantages normalization"""
     clip_coef: float = 0.25
     """the surrogate clipping coefficient"""
+    dual_clip_coef: Optional[float] = None
+    """the dual surrogate clipping coefficient, typically 3.0"""
     spo_kld_max: Optional[float] = None
-    """the maximum KLD for the SPO policy"""
+    """the maximum KLD for the SPO policy, typically 0.02"""
     ent_coef: float = 0.01
     """coefficient of the entropy"""
     vf_coef: float = 0.5

@@ -144,9 +145,10 @@ class Args:
     actor_devices: Optional[List[str]] = None
     learner_devices: Optional[List[str]] = None
     num_embeddings: Optional[int] = None
+    freeze_id: bool = False


-def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1):
+def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1, eval=False):
     if not args.thread_affinity:
         thread_affinity_offset = -1
     if thread_affinity_offset >= 0:

@@ -163,7 +165,7 @@ def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_off
         max_options=args.max_options,
         n_history_actions=args.n_history_actions,
         async_reset=False,
-        greedy_reward=args.greedy_reward,
+        greedy_reward=args.greedy_reward if not eval else True,
         play_mode=mode,
     )
     envs.num_envs = num_envs

@@ -189,6 +191,7 @@ def create_agent(args, multi_step=False):
         param_dtype=jnp.float32,
         lstm_channels=args.rnn_channels,
         multi_step=multi_step,
+        freeze_id=args.freeze_id,
     )

@@ -226,7 +229,7 @@ def rollout(
         args,
         args.seed + jax.process_index() + device_thread_id,
         args.local_eval_episodes,
-        args.local_eval_episodes // 4, mode=eval_mode)
+        args.local_eval_episodes // 4, mode=eval_mode, eval=True)
     eval_envs = RecordEpisodeStatistics(eval_envs)

     len_actor_device_ids = len(args.actor_device_ids)

@@ -296,7 +299,6 @@ def rollout(
         np.ones(args.local_num_envs // 2, dtype=np.int64)
     ])
     np.random.shuffle(main_player)
     start_step = 0
     storage = []

     @jax.jit

@@ -327,7 +329,7 @@ def rollout(
         rollout_time_start = time.time()
         init_rstate1, init_rstate2 = jax.tree.map(
             lambda x: x.copy(), (next_rstate1, next_rstate2))
-        for _ in range(start_step, args.collect_length):
+        for _ in range(args.num_steps):
             global_step += args.local_num_envs * n_actors * args.world_size
             cached_next_obs = next_obs

@@ -379,10 +381,8 @@ def rollout(
         rollout_time.append(time.time() - rollout_time_start)
-        start_step = args.collect_length - args.num_steps

         partitioned_storage = prepare_data(storage)
-        storage = storage[args.num_steps:]
+        storage = []
         sharded_storage = []
         for x in partitioned_storage:
             if isinstance(x, dict):

@@ -418,10 +418,13 @@ def rollout(
             SPS_update = int(args.batch_size / (time.time() - update_time_start))
             if device_thread_id == 0:
                 print(
-                    f"global_step={global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}, rollout_time={rollout_time[-1]:.2f}"
+                    f"global_step={global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}"
                 )
                 time_now = datetime.now(timezone(timedelta(hours=8))).strftime("%H:%M:%S")
-                print(f"{time_now} SPS: {SPS}, update: {SPS_update}")
+                print(
+                    f"{time_now} SPS: {SPS}, update: {SPS_update}, "
+                    f"rollout_time={rollout_time[-1]:.2f}, params_time={params_queue_get_time[-1]:.2f}"
+                )
             writer.add_scalar("stats/rollout_time", np.mean(rollout_time), global_step)
             writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step)
             writer.add_scalar("charts/avg_episodic_length", avg_episodic_length, global_step)

@@ -436,14 +439,13 @@ def rollout(
                 _start = time.time()
                 if eval_mode == 'bot':
                     predict_fn = lambda x: get_action(params, x)
-                    eval_stat = evaluate(
-                        eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)[0]
-                    metric_name = "eval_return"
+                    eval_return, eval_ep_len, eval_win_rate = evaluate(
+                        eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)
                 else:
                     predict_fn = lambda *x: get_action_battle(params, eval_params, *x)
-                    eval_stat = battle(
-                        eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)[2]
-                    metric_name = "eval_win_rate"
+                    eval_return, eval_ep_len, eval_win_rate = battle(
+                        eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)
+                eval_stat = np.array([eval_return, eval_win_rate])
                 if device_thread_id != 0:
                     eval_queue.put(eval_stat)
                 else:

@@ -451,12 +453,14 @@ def rollout(
                     eval_stats.append(eval_stat)
                     for _ in range(1, n_actors):
                         eval_stats.append(eval_queue.get())
-                    eval_stats = np.mean(eval_stats)
-                    writer.add_scalar(f"charts/{metric_name}", eval_stats, global_step)
+                    eval_stats = np.stack(eval_stats)
+                    eval_return, eval_win_rate = np.mean(eval_stats, axis=0)
+                    writer.add_scalar(f"charts/eval_return", eval_return, global_step)
+                    writer.add_scalar(f"charts/eval_win_rate", eval_win_rate, global_step)
                 if device_thread_id == 0:
                     eval_time = time.time() - _start
-                    print(f"eval_time={eval_time:.4f}, {metric_name}={eval_stats:.4f}")
                     other_time += eval_time
+                    print(f"eval_time={eval_time:.4f}, eval_return={eval_return:.4f}, eval_win_rate={eval_win_rate:.4f}")


 if __name__ == "__main__":

@@ -485,8 +489,15 @@ if __name__ == "__main__":
     args.minibatch_size = args.local_minibatch_size * args.world_size
     args.num_updates = args.total_timesteps // (args.local_batch_size * args.world_size)
     args.local_env_threads = args.local_env_threads or args.local_num_envs
-    args.collect_length = args.collect_length or args.num_steps
-    assert args.collect_length >= args.num_steps, "collect_length must be greater than or equal to num_steps"
+
+    if args.embedding_file:
+        embeddings = load_embeddings(args.embedding_file, args.code_list_file)
+        embedding_shape = embeddings.shape
+        args.num_embeddings = embedding_shape
+        args.freeze_id = True if args.freeze_id is None else args.freeze_id
+    else:
+        embeddings = None
+        embedding_shape = None

     local_devices = jax.local_devices()
     global_devices = jax.devices()

@@ -541,6 +552,13 @@ if __name__ == "__main__":
     rstate = init_rnn_state(1, args.rnn_channels)
     agent = create_agent(args)
     params = agent.init(agent_key, (rstate, sample_obs))
+    if embeddings is not None:
+        unknown_embed = embeddings.mean(axis=0)
+        embeddings = np.concatenate([unknown_embed[None, :], embeddings], axis=0)
+        params = flax.core.unfreeze(params)
+        params['params']['Encoder_0']['Embed_0']['embedding'] = jax.device_put(embeddings)
+        params = flax.core.freeze(params)
+
     tx = optax.MultiSteps(
         optax.chain(
             optax.clip_by_global_norm(args.max_grad_norm),

@@ -587,66 +605,53 @@ if __name__ == "__main__":
         num_envs = next_value.shape[0]
         num_steps = dones.shape[0] // num_envs
-        mask = mask & (~dones)
+        mask = mask * (1.0 - dones)
         n_valids = jnp.sum(mask)
         real_dones = dones | next_dones
         inputs = (rstate1, rstate2, obs, real_dones, switch)
         new_logits, new_values = get_logits_and_value(params, inputs)

-        values, rewards, next_dones, switch = jax.tree.map(
-            lambda x: jnp.reshape(x, (num_steps, num_envs)),
-            (jax.lax.stop_gradient(new_values), rewards, next_dones, switch),
+        new_values_, rewards, next_dones, switch = jax.tree.map(
+            lambda x: jnp.reshape(x, (num_steps, num_envs) + x.shape[1:]),
+            (new_values, rewards, next_dones, switch),
         )

-        advantages, target_values = compute_gae_2p0s(
-            next_value, values, rewards, next_dones, switch,
-            args.gamma, args.gae_lambda)
-        if args.upgo:
-            advantages = advantages + upgo_advantage(
-                next_value, values, rewards, next_dones, switch, args.gamma)
-        advantages, target_values = jax.tree.map(
-            lambda x: jnp.reshape(x, (-1,)), (advantages, target_values))
-
-        ratio = distrax.importance_sampling_ratios(distrax.Categorical(
+        ratios = distrax.importance_sampling_ratios(distrax.Categorical(
             new_logits), distrax.Categorical(logits), actions)
-        logratio = jnp.log(ratio)
-        approx_kl = (((ratio - 1) - logratio) * mask).sum() / n_valids
+        target_values, advantages = truncated_gae_2p0s(
+            next_value, new_values_, rewards, next_dones, switch,
+            args.gamma, args.gae_lambda, args.upgo)
+        target_values, advantages = jax.tree.map(
+            lambda x: jnp.reshape(x, (-1,)), (target_values, advantages))
+
+        logratio = jnp.log(ratios)
+        approx_kl = (((ratios - 1) - logratio) * mask).sum() / n_valids

         if args.norm_adv:
             advantages = masked_normalize(advantages, mask, eps=1e-8)

         # Policy loss
         if args.spo_kld_max is not None:
-            probs = jax.nn.softmax(logits)
-            new_probs = jax.nn.softmax(new_logits)
-            eps = 1e-8
-            kld = jnp.sum(probs * jnp.log((probs + eps) / (new_probs + eps)), axis=-1)
-            kld_clip = jnp.clip(kld, 0, args.spo_kld_max)
-            d_ratio = kld_clip / (kld + eps)
-            d_ratio = jnp.where(kld < 1e-6, 1.0, d_ratio)
-            sign_a = jnp.sign(advantages)
-            result = (d_ratio + sign_a - 1) * sign_a
-            pg_loss = -advantages * ratio * result
+            pg_loss = simple_policy_loss(
+                ratios, logits, new_logits, advantages, args.spo_kld_max)
         else:
-            pg_loss1 = -advantages * ratio
-            pg_loss2 = -advantages * jnp.clip(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
-            pg_loss = jnp.maximum(pg_loss1, pg_loss2)
+            pg_loss = clipped_surrogate_pg_loss(
+                ratios, advantages, args.clip_coef, args.dual_clip_coef)
         pg_loss = jnp.sum(pg_loss * mask)

-        v_loss = 0.5 * ((new_values - target_values) ** 2)
+        v_loss = mse_loss(new_values, target_values)
         v_loss = jnp.sum(v_loss * mask)

-        entropy_loss = distrax.Softmax(new_logits).entropy()
-        entropy_loss = jnp.sum(entropy_loss * mask)
+        ent_loss = entropy_loss(new_logits)
+        ent_loss = jnp.sum(ent_loss * mask)

         pg_loss = pg_loss / n_valids
         v_loss = v_loss / n_valids
-        entropy_loss = entropy_loss / n_valids
+        ent_loss = ent_loss / n_valids

-        loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
-        return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl))
+        loss = pg_loss - args.ent_coef * ent_loss + v_loss * args.vf_coef
+        return loss, (pg_loss, v_loss, ent_loss, jax.lax.stop_gradient(approx_kl))

     def single_device_update(
         agent_state: TrainState,

@@ -702,7 +707,8 @@ if __name__ == "__main__":
                 x = jnp.reshape(x, (N, -1) + x.shape[1:])
             return x

-        shuffled_init_rstate1, shuffled_init_rstate2, shuffled_next_value = jax.tree.map(
+        shuffled_init_rstate1, shuffled_init_rstate2, \
+            shuffled_next_value = jax.tree.map(
             partial(convert_data, num_steps=1), (init_rstate1, init_rstate2, next_value))
         shuffled_storage, shuffled_switch = jax.tree.map(
             partial(convert_data, num_steps=num_steps), (storage, switch))

@@ -829,8 +835,9 @@ if __name__ == "__main__":
             writer.add_scalar("stats/rollout_queue_size", rollout_queues[-1].qsize(), global_step)
             writer.add_scalar("stats/params_queue_size", params_queues[-1].qsize(), global_step)
             print(
-                global_step,
-                f"actor_update={update}, train_time={time.time() - training_time_start:.2f}",
+                f"{global_step} actor_update={update}, "
+                f"train_time={time.time() - training_time_start:.2f}, "
+                f"data_time={rollout_queue_get_time[-1]:.2f}"
            )
             writer.add_scalar("charts/learning_rate", agent_state.opt_state[2][1].hyperparams["learning_rate"][-1].item(), global_step)
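A small sketch of the simple_policy_loss helper that the new ppo.py calls when spo_kld_max is set; this is not part of the commit, the logits, actions, and advantages are dummy data, and it assumes the ygoai package matches the __init__.py diff shown below.

    import jax
    import jax.numpy as jnp
    from ygoai.rl.jax import simple_policy_loss

    n, n_actions = 256, 8
    logits = jax.random.normal(jax.random.PRNGKey(0), (n, n_actions))      # behaviour policy
    new_logits = logits + 0.1                                              # current policy
    actions = jnp.argmax(logits, axis=-1)
    advantages = jax.random.normal(jax.random.PRNGKey(1), (n,))

    idx = jnp.arange(n)
    ratios = jnp.exp(
        jax.nn.log_softmax(new_logits)[idx, actions]
        - jax.nn.log_softmax(logits)[idx, actions])

    # per-sample loss; the training script applies its own mask and normalization
    pg_loss = simple_policy_loss(ratios, logits, new_logits, advantages, kld_max=0.02)
    print(pg_loss.shape)  # (256,)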
ygoai/rl/jax/__init__.py
View file @
43ca871e
...
...
@@ -2,340 +2,273 @@ from functools import partial
import
jax
import
jax.numpy
as
jnp
from
typing
import
NamedTuple
class
VTraceOutput
(
NamedTuple
):
q_estimate
:
jnp
.
ndarray
errors
:
jnp
.
ndarray
def
vtrace
(
v_tm1
,
v_t
,
r_t
,
discount_t
,
rho_tm1
,
lambda_
=
1.0
,
c_clip_min
:
float
=
0.001
,
c_clip_max
:
float
=
1.007
,
rho_clip_min
:
float
=
0.001
,
rho_clip_max
:
float
=
1.007
,
stop_target_gradients
:
bool
=
True
,
):
"""
Args:
v_tm1: values at time t-1.
v_t: values at time t.
r_t: reward at time t.
discount_t: discount at time t.
rho_tm1: importance sampling ratios at time t-1.
lambda_: mixing parameter; a scalar or a vector for timesteps t.
clip_rho_threshold: clip threshold for importance weights.
stop_target_gradients: whether or not to apply stop gradient to targets.
"""
# Clip importance sampling ratios.
lambda_
=
jnp
.
ones_like
(
discount_t
)
*
lambda_
c_tm1
=
jnp
.
clip
(
rho_tm1
,
c_clip_min
,
c_clip_max
)
*
lambda_
clipped_rhos_tm1
=
jnp
.
clip
(
rho_tm1
,
rho_clip_min
,
rho_clip_max
)
# Compute the temporal difference errors.
td_errors
=
clipped_rhos_tm1
*
(
r_t
+
discount_t
*
v_t
-
v_tm1
)
# Work backwards computing the td-errors.
def
_body
(
acc
,
xs
):
td_error
,
discount
,
c
=
xs
acc
=
td_error
+
discount
*
c
*
acc
return
acc
,
acc
_
,
errors
=
jax
.
lax
.
scan
(
_body
,
0.0
,
(
td_errors
,
discount_t
,
c_tm1
),
reverse
=
True
)
# Return errors, maybe disabling gradient flow through bootstrap targets.
errors
=
jax
.
lax
.
select
(
stop_target_gradients
,
jax
.
lax
.
stop_gradient
(
errors
+
v_tm1
)
-
v_tm1
,
errors
)
targets_tm1
=
errors
+
v_tm1
q_bootstrap
=
jnp
.
concatenate
([
lambda_
[:
-
1
]
*
targets_tm1
[
1
:]
+
(
1
-
lambda_
[:
-
1
])
*
v_tm1
[
1
:],
v_t
[
-
1
:],
],
axis
=
0
)
q_estimate
=
r_t
+
discount_t
*
q_bootstrap
return
VTraceOutput
(
q_estimate
=
q_estimate
,
errors
=
errors
)
def
clipped_surrogate_pg_loss
(
prob_ratios_t
,
adv_t
,
mask
,
epsilon
,
use_stop_gradient
=
True
):
adv_t
=
jax
.
lax
.
select
(
use_stop_gradient
,
jax
.
lax
.
stop_gradient
(
adv_t
),
adv_t
)
clipped_ratios_t
=
jnp
.
clip
(
prob_ratios_t
,
1.
-
epsilon
,
1.
+
epsilon
)
clipped_objective
=
jnp
.
fmin
(
prob_ratios_t
*
adv_t
,
clipped_ratios_t
*
adv_t
)
return
-
jnp
.
mean
(
clipped_objective
*
mask
)
@
partial
(
jax
.
jit
,
static_argnums
=
(
5
,
6
))
def
compute_gae_2p0s
(
next_value
,
values
,
rewards
,
next_dones
,
switch
,
gamma
,
gae_lambda
,
):
def
body_fn
(
carry
,
inp
):
boot_value
,
boot_done
,
next_value
,
lastgaelam
=
carry
next_done
,
cur_value
,
reward
,
switch
=
inp
next_done
=
jnp
.
where
(
switch
,
boot_done
,
next_done
)
next_value
=
jnp
.
where
(
switch
,
-
boot_value
,
next_value
)
lastgaelam
=
jnp
.
where
(
switch
,
0
,
lastgaelam
)
import
chex
import
distrax
gamma_
=
gamma
*
(
1.0
-
next_done
)
delta
=
reward
+
gamma_
*
next_value
-
cur_value
lastgaelam
=
delta
+
gae_lambda
*
gamma_
*
lastgaelam
return
(
boot_value
,
boot_done
,
cur_value
,
lastgaelam
),
lastgaelam
next_done
=
next_dones
[
-
1
]
lastgaelam
=
jnp
.
zeros_like
(
next_value
)
carry
=
next_value
,
next_done
,
next_value
,
lastgaelam
# class VTraceOutput(NamedTuple):
# q_estimate: jnp.ndarray
# errors: jnp.ndarray
_
,
advantages
=
jax
.
lax
.
scan
(
body_fn
,
carry
,
(
next_dones
,
values
,
rewards
,
switch
),
reverse
=
True
)
target_values
=
advantages
+
values
return
advantages
,
target_values
@
partial
(
jax
.
jit
,
static_argnums
=
(
5
,))
def
upgo_advantage
(
next_value
,
values
,
rewards
,
next_dones
,
switch
,
gamma
):
def
body_fn
(
carry
,
inp
):
boot_value
,
boot_done
,
next_value
,
next_q
,
last_return
=
carry
next_done
,
cur_value
,
reward
,
switch
=
inp
next_done
=
jnp
.
where
(
switch
,
boot_done
,
next_done
)
next_value
=
jnp
.
where
(
switch
,
-
boot_value
,
next_value
)
next_q
=
jnp
.
where
(
switch
,
-
boot_value
*
gamma
,
next_q
)
last_return
=
jnp
.
where
(
switch
,
-
boot_value
,
last_return
)
gamma_
=
gamma
*
(
1.0
-
next_done
)
last_return
=
reward
+
gamma_
*
jnp
.
where
(
next_q
>=
next_value
,
last_return
,
next_value
)
next_q
=
reward
+
gamma_
*
next_value
carry
=
boot_value
,
boot_done
,
cur_value
,
next_q
,
last_return
return
carry
,
last_return
next_done
=
next_dones
[
-
1
]
carry
=
next_value
,
next_done
,
next_value
,
next_value
,
next_value
_
,
returns
=
jax
.
lax
.
scan
(
body_fn
,
carry
,
(
next_dones
,
values
,
rewards
,
switch
),
reverse
=
True
)
return
returns
-
values
# def compute_gae_once(carry, inp, gamma, gae_lambda):
# v1, v2, next_values1, next_values2, reward1, reward2, xi1, xi2 = carry
# rho, cur_values, log_ratio, next_done, r_t, corr_r_t, main = inp
# v = jnp.where(main, v1, v2)
# next_values = jnp.where(main, next_values1, next_values2)
# reward = jnp.where(main, reward1, reward2)
# xi = jnp.where(main, xi1, xi2)
# p_t = c_t = jnp.minimum(1.0, rho * xi)
# sig_v = p_t * (r_t + reward * rho + next_values - cur_values)
# reg_r = jnp.log(p / p_reg)
# q = r_t + rho * (reward + v)
# q = -eta * + cur_values
# v = cur_values + sig_v + c_t * (v - next_values)
# v1 = jnp.where(main, v, v1)
# v2 = jnp.where(main, v2, v)
# next_values1 = jnp.where(main, cur_values, next_values1)
# next_values2 = jnp.where(main, next_values2, cur_values)
# reward1 = jnp.where(main, 0, r_t + rho * reward1)
# reward2 = jnp.where(main, r_t + rho * reward2, 0)
# xi1 = jnp.where(main, 1, rho * xi1)
# xi2 = jnp.where(main, rho * xi2, 1)
# learn1 = learn
# learn2 = ~learn
# factor = jnp.where(learn1, jnp.ones_like(reward), -jnp.ones_like(reward))
# reward1 = jnp.where(next_done, reward * factor, jnp.where(learn1 & done_used1, 0, reward1))
# reward2 = jnp.where(next_done, reward * -factor, jnp.where(learn2 & done_used2, 0, reward2))
# real_done1 = next_done | ~done_used1
# nextvalues1 = jnp.where(real_done1, 0, nextvalues1)
# lastgaelam1 = jnp.where(real_done1, 0, lastgaelam1)
# real_done2 = next_done | ~done_used2
# nextvalues2 = jnp.where(real_done2, 0, nextvalues2)
# lastgaelam2 = jnp.where(real_done2, 0, lastgaelam2)
# done_used1 = jnp.where(
# next_done, learn1, jnp.where(learn1 & ~done_used1, True, done_used1))
# done_used2 = jnp.where(
# next_done, learn2, jnp.where(learn2 & ~done_used2, True, done_used2))
# delta1 = reward1 + gamma * nextvalues1 - curvalues
# delta2 = reward2 + gamma * nextvalues2 - curvalues
# lastgaelam1_ = delta1 + gamma * gae_lambda * lastgaelam1
# lastgaelam2_ = delta2 + gamma * gae_lambda * lastgaelam2
# advantages = jnp.where(learn1, lastgaelam1_, lastgaelam2_)
# nextvalues1 = jnp.where(learn1, curvalues, nextvalues1)
# nextvalues2 = jnp.where(learn2, curvalues, nextvalues2)
# lastgaelam1 = jnp.where(learn1, lastgaelam1_, lastgaelam1)
# lastgaelam2 = jnp.where(learn2, lastgaelam2_, lastgaelam2)
# carry = nextvalues1, nextvalues2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
# return carry, advantages
# @partial(jax.jit, static_argnums=(6, 7))
# def vtrace_rnad(
# next_value, next_done, values, rewards, dones, learns,
# gamma, gae_lambda,
# ):
# next_value1 = next_value
# next_value2 = -next_value1
# done_used1 = jnp.ones_like(next_done)
# done_used2 = jnp.ones_like(next_done)
# reward1 = jnp.zeros_like(next_value)
# reward2 = jnp.zeros_like(next_value)
# lastgaelam1 = jnp.zeros_like(next_value)
# lastgaelam2 = jnp.zeros_like(next_value)
# carry = next_value1, next_value2, done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
# dones = jnp.concatenate([dones, next_done[None, :]], axis=0)
# _, advantages = jax.lax.scan(
# partial(compute_gae_once, gamma=gamma, gae_lambda=gae_lambda),
# carry, (dones[1:], values, rewards, learns), reverse=True
# )
# target_values = advantages + values
# return advantages, target_values
def
compute_gae_once
(
carry
,
inp
,
gamma
,
gae_lambda
):
nextvalues1
,
nextvalues2
,
done_used1
,
done_used2
,
reward1
,
reward2
,
lastgaelam1
,
lastgaelam2
=
carry
next_done
,
curvalues
,
reward
,
learn
=
inp
learn1
=
learn
learn2
=
~
learn
factor
=
jnp
.
where
(
learn1
,
jnp
.
ones_like
(
reward
),
-
jnp
.
ones_like
(
reward
))
reward1
=
jnp
.
where
(
next_done
,
reward
*
factor
,
jnp
.
where
(
learn1
&
done_used1
,
0
,
reward1
))
reward2
=
jnp
.
where
(
next_done
,
reward
*
-
factor
,
jnp
.
where
(
learn2
&
done_used2
,
0
,
reward2
))
real_done1
=
next_done
|
~
done_used1
nextvalues1
=
jnp
.
where
(
real_done1
,
0
,
nextvalues1
)
lastgaelam1
=
jnp
.
where
(
real_done1
,
0
,
lastgaelam1
)
real_done2
=
next_done
|
~
done_used2
nextvalues2
=
jnp
.
where
(
real_done2
,
0
,
nextvalues2
)
lastgaelam2
=
jnp
.
where
(
real_done2
,
0
,
lastgaelam2
)
done_used1
=
jnp
.
where
(
next_done
,
learn1
,
jnp
.
where
(
learn1
&
~
done_used1
,
True
,
done_used1
))
done_used2
=
jnp
.
where
(
next_done
,
learn2
,
jnp
.
where
(
learn2
&
~
done_used2
,
True
,
done_used2
))
delta1
=
reward1
+
gamma
*
nextvalues1
-
curvalues
delta2
=
reward2
+
gamma
*
nextvalues2
-
curvalues
lastgaelam1_
=
delta1
+
gamma
*
gae_lambda
*
lastgaelam1
lastgaelam2_
=
delta2
+
gamma
*
gae_lambda
*
lastgaelam2
advantages
=
jnp
.
where
(
learn1
,
lastgaelam1_
,
lastgaelam2_
)
nextvalues1
=
jnp
.
where
(
learn1
,
curvalues
,
nextvalues1
)
nextvalues2
=
jnp
.
where
(
learn2
,
curvalues
,
nextvalues2
)
lastgaelam1
=
jnp
.
where
(
learn1
,
lastgaelam1_
,
lastgaelam1
)
lastgaelam2
=
jnp
.
where
(
learn2
,
lastgaelam2_
,
lastgaelam2
)
carry
=
nextvalues1
,
nextvalues2
,
done_used1
,
done_used2
,
reward1
,
reward2
,
lastgaelam1
,
lastgaelam2
return
carry
,
advantages
@
partial
(
jax
.
jit
,
static_argnums
=
(
7
,
8
))
def
compute_gae
(
next_value
,
next_done
,
next_learn
,
values
,
rewards
,
dones
,
learns
,
gamma
,
gae_lambda
,
# def vtrace(
# v_tm1,
# v_t,
# r_t,
# discount_t,
# rho_tm1,
# lambda_=1.0,
# c_clip_min: float = 0.001,
# c_clip_max: float = 1.007,
# rho_clip_min: float = 0.001,
# rho_clip_max: float = 1.007,
# stop_target_gradients: bool = True,
# ):
# """
# Args:
# v_tm1: values at time t-1.
# v_t: values at time t.
# r_t: reward at time t.
# discount_t: discount at time t.
# rho_tm1: importance sampling ratios at time t-1.
# lambda_: mixing parameter; a scalar or a vector for timesteps t.
# clip_rho_threshold: clip threshold for importance weights.
# stop_target_gradients: whether or not to apply stop gradient to targets.
# """
# # Clip importance sampling ratios.
# lambda_ = jnp.ones_like(discount_t) * lambda_
# c_tm1 = jnp.clip(rho_tm1, c_clip_min, c_clip_max) * lambda_
# clipped_rhos_tm1 = jnp.clip(rho_tm1, rho_clip_min, rho_clip_max)
# # Compute the temporal difference errors.
# td_errors = clipped_rhos_tm1 * (r_t + discount_t * v_t - v_tm1)
# # Work backwards computing the td-errors.
# def _body(acc, xs):
# td_error, discount, c = xs
# acc = td_error + discount * c * acc
# return acc, acc
# _, errors = jax.lax.scan(
# _body, 0.0, (td_errors, discount_t, c_tm1), reverse=True)
# # Return errors, maybe disabling gradient flow through bootstrap targets.
# errors = jax.lax.select(
# stop_target_gradients,
# jax.lax.stop_gradient(errors + v_tm1) - v_tm1,
# errors)
# targets_tm1 = errors + v_tm1
# q_bootstrap = jnp.concatenate([
# lambda_[:-1] * targets_tm1[1:] + (1 - lambda_[:-1]) * v_tm1[1:],
# v_t[-1:],
# ], axis=0)
# q_estimate = r_t + discount_t * q_bootstrap
# return VTraceOutput(q_estimate=q_estimate, errors=errors)
def
entropy_loss
(
logits
):
return
distrax
.
Softmax
(
logits
=
logits
)
.
entropy
()
def
mse_loss
(
y_true
,
y_pred
):
return
0.5
*
((
y_true
-
y_pred
)
**
2
)
def
policy_gradient_loss
(
logits
,
actions
,
advantages
):
chex
.
assert_type
([
logits
,
actions
,
advantages
],
[
float
,
int
,
float
])
advs
=
jax
.
lax
.
stop_gradient
(
advantages
)
log_probs
=
distrax
.
Softmax
(
logits
=
logits
)
.
log_prob
(
actions
)
pg_loss
=
-
log_probs
*
advs
return
pg_loss
def
clipped_surrogate_pg_loss
(
ratios
,
advantages
,
clip_coef
,
dual_clip_coef
=
None
):
# dual clip from JueWu (Mastering Complex Control in MOBA Games with Deep Reinforcement Learning)
advs
=
jax
.
lax
.
stop_gradient
(
advantages
)
clipped_ratios
=
jnp
.
clip
(
ratios
,
1
-
clip_coef
,
1
+
clip_coef
)
clipped_obj
=
jnp
.
fmin
(
ratios
*
advs
,
clipped_ratios
*
advs
)
if
dual_clip_coef
is
not
None
:
clipped_obj
=
jnp
.
where
(
advs
>=
0
,
clipped_obj
,
jnp
.
fmax
(
clipped_obj
,
dual_clip_coef
*
advs
)
)
pg_loss
=
-
clipped_obj
return
pg_loss
def
vtrace_loop
(
carry
,
inp
,
gamma
,
rho_min
,
rho_max
,
c_min
,
c_max
):
v1
,
v2
,
next_values1
,
next_values2
,
reward1
,
reward2
,
xi1
,
xi2
,
\
last_return1
,
last_return2
,
next_q1
,
next_q2
=
carry
ratio
,
cur_values
,
next_done
,
r_t
,
main
=
inp
v1
=
jnp
.
where
(
next_done
,
0
,
v1
)
v2
=
jnp
.
where
(
next_done
,
0
,
v2
)
next_values1
=
jnp
.
where
(
next_done
,
0
,
next_values1
)
next_values2
=
jnp
.
where
(
next_done
,
0
,
next_values2
)
reward1
=
jnp
.
where
(
next_done
,
0
,
reward1
)
reward2
=
jnp
.
where
(
next_done
,
0
,
reward2
)
xi1
=
jnp
.
where
(
next_done
,
1
,
xi1
)
xi2
=
jnp
.
where
(
next_done
,
1
,
xi2
)
discount
=
gamma
*
(
1.0
-
next_done
)
v
=
jnp
.
where
(
main
,
v1
,
v2
)
next_values
=
jnp
.
where
(
main
,
next_values1
,
next_values2
)
reward
=
jnp
.
where
(
main
,
reward1
,
reward2
)
xi
=
jnp
.
where
(
main
,
xi1
,
xi2
)
q_t
=
r_t
+
ratio
*
reward
+
discount
*
v
rho_t
=
jnp
.
clip
(
ratio
*
xi
,
rho_min
,
rho_max
)
c_t
=
jnp
.
clip
(
ratio
*
xi
,
c_min
,
c_max
)
sig_v
=
rho_t
*
(
r_t
+
ratio
*
reward
+
discount
*
next_values
-
cur_values
)
v
=
cur_values
+
sig_v
+
c_t
*
discount
*
(
v
-
next_values
)
# UPGO advantage (not corrected by importance sampling, unlike V-trace)
return_t
=
jnp
.
where
(
main
,
last_return1
,
last_return2
)
next_q
=
jnp
.
where
(
main
,
next_q1
,
next_q2
)
factor
=
jnp
.
where
(
main
,
jnp
.
ones_like
(
r_t
),
-
jnp
.
ones_like
(
r_t
))
return_t
=
r_t
+
discount
*
jnp
.
where
(
next_q
>=
next_values
,
return_t
,
next_values
)
last_return1
=
jnp
.
where
(
next_done
,
r_t
*
factor
,
jnp
.
where
(
main
,
return_t
,
last_return1
))
last_return2
=
jnp
.
where
(
next_done
,
r_t
*
-
factor
,
jnp
.
where
(
main
,
last_return2
,
return_t
))
next_q
=
r_t
+
discount
*
next_values
next_q1
=
jnp
.
where
(
next_done
,
r_t
*
factor
,
jnp
.
where
(
main
,
next_q
,
next_q1
))
next_q2
=
jnp
.
where
(
next_done
,
r_t
*
-
factor
,
jnp
.
where
(
main
,
next_q2
,
next_q
))
v1
=
jnp
.
where
(
main
,
v
,
v1
)
v2
=
jnp
.
where
(
main
,
v2
,
v
)
next_values1
=
jnp
.
where
(
main
,
cur_values
,
next_values1
)
next_values2
=
jnp
.
where
(
main
,
next_values2
,
cur_values
)
reward1
=
jnp
.
where
(
main
,
0
,
-
r_t
+
ratio
*
reward1
)
reward2
=
jnp
.
where
(
main
,
-
r_t
+
ratio
*
reward2
,
0
)
xi1
=
jnp
.
where
(
main
,
1
,
ratio
*
xi1
)
xi2
=
jnp
.
where
(
main
,
ratio
*
xi2
,
1
)
carry
=
v1
,
v2
,
next_values1
,
next_values2
,
reward1
,
reward2
,
xi1
,
xi2
,
\
last_return1
,
last_return2
,
next_q1
,
next_q2
return
carry
,
(
v
,
q_t
,
return_t
)
def
vtrace_2p0s
(
next_value
,
ratios
,
values
,
rewards
,
next_dones
,
mains
,
gamma
,
rho_min
=
0.001
,
rho_max
=
1.0
,
c_min
=
0.001
,
c_max
=
1.0
,
upgo
=
False
,
):
next_value1
=
jnp
.
where
(
next_learn
,
next_value
,
-
next_value
)
next_value1
=
next_value
next_value2
=
-
next_value1
done_used1
=
jnp
.
ones_like
(
next_done
)
done_used2
=
jnp
.
ones_like
(
next_done
)
reward1
=
jnp
.
zeros_like
(
next_value
)
reward2
=
jnp
.
zeros_like
(
next_value
)
lastgaelam1
=
jnp
.
zeros_like
(
next_value
)
lastgaelam2
=
jnp
.
zeros_like
(
next_value
)
carry
=
next_value1
,
next_value2
,
done_used1
,
done_used2
,
reward1
,
reward2
,
lastgaelam1
,
lastgaelam2
dones
=
jnp
.
concatenate
([
dones
,
next_done
[
None
,
:]],
axis
=
0
)
_
,
advantages
=
jax
.
lax
.
scan
(
partial
(
compute_gae_once
,
gamma
=
gamma
,
gae_lambda
=
gae_lambda
),
carry
,
(
dones
[
1
:],
values
,
rewards
,
learns
),
reverse
=
True
v1
=
return1
=
next_q1
=
next_value1
v2
=
return2
=
next_q2
=
next_value2
reward1
=
reward2
=
jnp
.
zeros_like
(
next_value
)
xi1
=
xi2
=
jnp
.
ones_like
(
next_value
)
carry
=
v1
,
v2
,
next_value1
,
next_value2
,
reward1
,
reward2
,
xi1
,
xi2
,
\
return1
,
return2
,
next_q1
,
next_q2
_
,
(
targets
,
q_estimate
,
return_t
)
=
jax
.
lax
.
scan
(
partial
(
vtrace_loop
,
gamma
=
gamma
,
rho_min
=
rho_min
,
rho_max
=
rho_max
,
c_min
=
c_min
,
c_max
=
c_max
),
carry
,
(
ratios
,
values
,
next_dones
,
rewards
,
mains
),
reverse
=
True
)
target_values
=
advantages
+
values
return
advantages
,
target_values
def
compute_gae_once_upgo
(
carry
,
inp
,
gamma
,
gae_lambda
):
next_value1
,
next_value2
,
next_q1
,
next_q2
,
last_return1
,
last_return2
,
\
done_used1
,
done_used2
,
reward1
,
reward2
,
lastgaelam1
,
lastgaelam2
=
carry
next_done
,
curvalues
,
reward
,
learn
=
inp
learn1
=
learn
learn2
=
~
learn
factor
=
jnp
.
where
(
learn1
,
jnp
.
ones_like
(
reward
),
-
jnp
.
ones_like
(
reward
))
reward1
=
jnp
.
where
(
next_done
,
reward
*
factor
,
jnp
.
where
(
learn1
&
done_used1
,
0
,
reward1
))
reward2
=
jnp
.
where
(
next_done
,
reward
*
-
factor
,
jnp
.
where
(
learn2
&
done_used2
,
0
,
reward2
))
advantages
=
q_estimate
-
values
if
upgo
:
advantages
+=
return_t
-
values
targets
=
jax
.
lax
.
stop_gradient
(
targets
)
return
targets
,
advantages
def truncated_gae_upgo_loop(carry, inp, gamma, gae_lambda):
    lastgaelam1, lastgaelam2, next_value1, next_value2, reward1, reward2, \
        done_used1, done_used2, last_return1, last_return2, next_q1, next_q2 = carry
    cur_value, next_done, reward, main = inp
    main1 = main
    main2 = ~main
    factor = jnp.where(main1, jnp.ones_like(reward), -jnp.ones_like(reward))
    reward1 = jnp.where(next_done, reward * factor, jnp.where(main1 & done_used1, 0, reward1))
    reward2 = jnp.where(next_done, reward * -factor, jnp.where(main2 & done_used2, 0, reward2))
    real_done1 = next_done | ~done_used1
    next_value1 = jnp.where(real_done1, 0, next_value1)
    last_return1 = jnp.where(real_done1, 0, last_return1)
    lastgaelam1 = jnp.where(real_done1, 0, lastgaelam1)
    real_done2 = next_done | ~done_used2
    next_value2 = jnp.where(real_done2, 0, next_value2)
    last_return2 = jnp.where(real_done2, 0, last_return2)
    lastgaelam2 = jnp.where(real_done2, 0, lastgaelam2)
    done_used1 = jnp.where(next_done, learn1, jnp.where(learn1 & ~done_used1, True, done_used1))
    done_used1 = jnp.where(next_done, main1, jnp.where(main1 & ~done_used1, True, done_used1))
    done_used2 = jnp.where(next_done, learn2, jnp.where(learn2 & ~done_used2, True, done_used2))
    done_used2 = jnp.where(next_done, main2, jnp.where(main2 & ~done_used2, True, done_used2))

    # UPGO advantage
    last_return1 = jnp.where(real_done1, 0, last_return1)
    last_return2 = jnp.where(real_done2, 0, last_return2)
    last_return1_ = reward1 + gamma * jnp.where(next_q1 >= next_value1, last_return1, next_value1)
    last_return2_ = reward2 + gamma * jnp.where(next_q2 >= next_value2, last_return2, next_value2)
    next_q1_ = reward1 + gamma * next_value1
    next_q2_ = reward2 + gamma * next_value2
    delta1 = next_q1_ - curvalues
    delta2 = next_q2_ - curvalues
    next_q1 = jnp.where(main1, next_q1_, next_q1)
    next_q2 = jnp.where(main2, next_q2_, next_q1)
    last_return1 = jnp.where(main1, last_return1_, last_return1)
    last_return2 = jnp.where(main2, last_return2_, last_return2)
    returns = jnp.where(main1, last_return1_, last_return2_)
    delta1 = next_q1_ - cur_value
    delta2 = next_q2_ - cur_value
    lastgaelam1_ = delta1 + gamma * gae_lambda * lastgaelam1
    lastgaelam2_ = delta2 + gamma * gae_lambda * lastgaelam2
    returns = jnp.where(learn1, last_return1_, last_return2_)
    advantages = jnp.where(learn1, lastgaelam1_, lastgaelam2_)
    next_value1 = jnp.where(learn1, curvalues, next_value1)
    next_value2 = jnp.where(learn2, curvalues, next_value2)
    lastgaelam1 = jnp.where(learn1, lastgaelam1_, lastgaelam1)
    lastgaelam2 = jnp.where(learn2, lastgaelam2_, lastgaelam2)
    next_q1 = jnp.where(learn1, next_q1_, next_q1)
    next_q2 = jnp.where(learn2, next_q2_, next_q1)
    last_return1 = jnp.where(learn1, last_return1_, last_return1)
    last_return2 = jnp.where(learn2, last_return2_, last_return2)
    carry = next_value1, next_value2, next_q1, next_q2, last_return1, last_return2, \
        done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
    advantages = jnp.where(main1, lastgaelam1_, lastgaelam2_)
    next_value1 = jnp.where(main1, cur_value, next_value1)
    next_value2 = jnp.where(main2, cur_value, next_value2)
    lastgaelam1 = jnp.where(main1, lastgaelam1_, lastgaelam1)
    lastgaelam2 = jnp.where(main2, lastgaelam2_, lastgaelam2)
    carry = lastgaelam1, lastgaelam2, next_value1, next_value2, reward1, reward2, \
        done_used1, done_used2, last_return1, last_return2, next_q1, next_q2
    return carry, (advantages, returns)
@partial(jax.jit, static_argnums=(7, 8))
def compute_gae_upgo(
    next_value, next_done, next_learn, values, rewards, dones, learns,
    gamma, gae_lambda,
def truncated_gae_2p0s(
    next_value, values, rewards, next_dones, mains,
    gamma, gae_lambda, upgo,
):
    next_value1 = jnp.where(next_learn, next_value, -next_value)
    next_value1 = next_value
    next_value2 = -next_value1
    last_return1 = next_q1 = next_value1
    last_return2 = next_q2 = next_value2
    done_used1 = jnp.ones_like(next_done)
    done_used2 = jnp.ones_like(next_done)
    reward1 = jnp.zeros_like(next_value)
    reward2 = jnp.zeros_like(next_value)
    lastgaelam1 = jnp.zeros_like(next_value)
    lastgaelam2 = jnp.zeros_like(next_value)
    carry = next_value1, next_value2, next_q1, next_q2, last_return1, last_return2, \
        done_used1, done_used2, reward1, reward2, lastgaelam1, lastgaelam2
    dones = jnp.concatenate([dones, next_done[None, :]], axis=0)
    done_used1 = jnp.ones_like(next_dones[-1])
    done_used2 = jnp.ones_like(next_dones[-1])
    reward1 = reward2 = jnp.zeros_like(next_value)
    lastgaelam1 = lastgaelam2 = jnp.zeros_like(next_value)
    carry = lastgaelam1, lastgaelam2, next_value1, next_value2, reward1, reward2, \
        done_used1, done_used2, last_return1, last_return2, next_q1, next_q2

    _, (advantages, returns) = jax.lax.scan(
        partial(compute_gae_once_upgo, gamma=gamma, gae_lambda=gae_lambda),
        carry, (dones[1:], values, rewards, learns), reverse=True
        partial(truncated_gae_upgo_loop, gamma=gamma, gae_lambda=gae_lambda),
        carry, (values, next_dones, rewards, mains), reverse=True
    )
    return returns - values, advantages + values
    if upgo:
        advantages += returns - values
    targets = values + advantages
    targets = jax.lax.stop_gradient(targets)
    return targets, advantages
def simple_policy_loss(ratios, logits, new_logits, advantages, kld_max, eps=1e-12):
    advs = jax.lax.stop_gradient(advantages)
    probs = jax.nn.softmax(logits)
    new_probs = jax.nn.softmax(new_logits)
    kld = jnp.sum(probs * jnp.log((probs + eps) / (new_probs + eps)), axis=-1)
    kld_clip = jnp.clip(kld, 0, kld_max)
    d_ratio = kld_clip / (kld + eps)
    # e == 1 and t == 1
    d_ratio = jnp.where(kld < 1e-6, 1.0, d_ratio)
    sign_a = jnp.sign(advs)
    result = (d_ratio + sign_a - 1) * sign_a
    pg_loss = -advs * ratios * result
    return pg_loss
\ No newline at end of file
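The clipped-KL policy loss above takes precomputed importance ratios together with the behaviour and current logits. A minimal usage sketch, assuming the usual PPO-style definition of the ratio as exp(logp_new - logp_old) for the taken action; that convention and the shapes are assumptions, not shown in this diff:

import jax
import jax.numpy as jnp

# Hypothetical per-step quantities (batch of 1, 3 discrete actions).
logits = jnp.array([[2.0, 0.5, -1.0]])      # behaviour policy logits
new_logits = jnp.array([[1.5, 0.8, -1.2]])  # current policy logits
actions = jnp.array([0])
advantages = jnp.array([0.7])

logp_old = jax.nn.log_softmax(logits)[jnp.arange(1), actions]
logp_new = jax.nn.log_softmax(new_logits)[jnp.arange(1), actions]
ratios = jnp.exp(logp_new - logp_old)       # assumed ratio definition

loss = simple_policy_loss(ratios, logits, new_logits, advantages, kld_max=0.05).mean()

As written, a large KL divergence between the two policies scales down the positive-advantage term (and scales up the negative-advantage one) instead of hard-clipping the ratio itself.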
ygoai/rl/jax/agent2.py  View file @ 43ca871e

@@ -150,6 +150,7 @@ class Encoder(nn.Module):
    embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
    dtype: Optional[jnp.dtype] = None
    param_dtype: jnp.dtype = jnp.float32
    freeze_id: bool = False

    @nn.compact
    def __call__(self, x):
@@ -168,6 +169,8 @@ class Encoder(nn.Module):
        fc_layer = partial(nn.Dense, use_bias=False, param_dtype=self.param_dtype)

        id_embed = embed(n_embed, embed_dim)
        if self.freeze_id:
            id_embed = lambda x: jax.lax.stop_gradient(id_embed(x))
        action_encoder = ActionEncoder(channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
@@ -337,6 +340,7 @@ class PPOLSTMAgent(nn.Module):
    param_dtype: jnp.dtype = jnp.float32
    multi_step: bool = False
    switch: bool = True
    freeze_id: bool = False

    @nn.compact
    def __call__(self, inputs):
@@ -355,6 +359,7 @@ class PPOLSTMAgent(nn.Module):
            embedding_shape=self.embedding_shape,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            freeze_id=self.freeze_id,
        )
        f_actions, f_state, mask, valid = encoder(x)
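The new freeze_id flag threads from PPOLSTMAgent into Encoder, presumably so a pretrained card-id embedding table can be kept fixed while the rest of the network trains. Below is a toy sketch of the mechanism only, not the repo's Encoder; module and field names here are made up for illustration:

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyIdEncoder(nn.Module):
    # Toy stand-in: when freeze_id is True, gradients do not flow into the
    # embedding table, mirroring the stop_gradient added in the diff above.
    n_embed: int = 100
    embed_dim: int = 8
    freeze_id: bool = False

    @nn.compact
    def __call__(self, ids):
        embed = nn.Embed(self.n_embed, self.embed_dim)
        x = embed(ids)
        if self.freeze_id:
            x = jax.lax.stop_gradient(x)
        return x.sum()

params = TinyIdEncoder(freeze_id=True).init(jax.random.PRNGKey(0), jnp.array([1, 2, 3]))
grads = jax.grad(lambda p: TinyIdEncoder(freeze_id=True).apply(p, jnp.array([1, 2, 3])))(params)
# With freeze_id=True the gradient of the embedding table is all zeros.

The diff wraps the embedding function itself rather than its output, but the effect on the gradients is the same.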
ygoai/rl/jax/switch.py  0 → 100644  View file @ 43ca871e

import jax
import jax.numpy as jnp


def truncated_gae_2p0s(
        next_value, values, rewards, next_dones, switch, gamma, gae_lambda, upgo):

    def body_fn(carry, inp):
        boot_value, boot_done, next_value, lastgaelam, next_q, last_return = carry
        next_done, cur_value, reward, switch = inp

        next_done = jnp.where(switch, boot_done, next_done)
        next_value = jnp.where(switch, -boot_value, next_value)
        lastgaelam = jnp.where(switch, 0, lastgaelam)
        next_q = jnp.where(switch, -boot_value * gamma, next_q)
        last_return = jnp.where(switch, -boot_value, last_return)

        discount = gamma * (1.0 - next_done)
        last_return = reward + discount * jnp.where(
            next_q >= next_value, last_return, next_value)
        next_q = reward + discount * next_value
        delta = next_q - cur_value
        lastgaelam = delta + gae_lambda * discount * lastgaelam
        carry = boot_value, boot_done, cur_value, lastgaelam, next_q, last_return
        return carry, (lastgaelam, last_return)

    next_done = next_dones[-1]
    lastgaelam = jnp.zeros_like(next_value)
    next_q = last_return = next_value
    carry = next_value, next_done, next_value, lastgaelam, next_q, last_return

    _, (advantages, returns) = jax.lax.scan(
        body_fn, carry, (next_dones, values, rewards, switch), reverse=True)

    if upgo:
        advantages += returns - values
    targets = values + advantages
    targets = jax.lax.stop_gradient(targets)
    return targets, advantages
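A short sketch of how the new switch-based helper might be driven. The time-major [T, B] layout and the reading of `switch` as marking steps where control passes to the other player (so the bootstrap is restarted with its sign flipped) are inferred from the body above rather than stated in the commit:

import numpy as np
import jax.numpy as jnp
from ygoai.rl.jax.switch import truncated_gae_2p0s

T, B = 6, 2  # hypothetical rollout length and env count
rng = np.random.default_rng(0)

values = jnp.asarray(rng.normal(size=(T, B)), dtype=jnp.float32)   # V(s_t) per step
rewards = jnp.asarray(rng.normal(size=(T, B)), dtype=jnp.float32)
next_dones = jnp.asarray(rng.integers(0, 2, size=(T, B)) == 1)
switch = jnp.asarray(rng.integers(0, 2, size=(T, B)) == 1)          # player hand-over steps
next_value = jnp.asarray(rng.normal(size=(B,)), dtype=jnp.float32)  # bootstrap value

targets, advantages = truncated_gae_2p0s(
    next_value, values, rewards, next_dones, switch,
    gamma=1.0, gae_lambda=0.95, upgo=True)
assert targets.shape == advantages.shape == (T, B)

Compared with the mains-based version kept in ygoai/rl/jax/__init__.py, this variant carries a single GAE state and simply resets it at every switch step.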
ygoai/rl/utils.py  View file @ 43ca871e

@@ -58,13 +58,3 @@ def masked_normalize(x, valid, eps=1e-8):
def to_tensor(x, device, dtype=None):
    return optree.tree_map(lambda x: torch.from_numpy(x).to(device=device, dtype=dtype, non_blocking=True), x)


def load_embeddings(embedding_file, code_list_file):
    with open(embedding_file, "rb") as f:
        embeddings = pickle.load(f)
    with open(code_list_file, "r") as f:
        code_list = f.readlines()
        code_list = [int(code.strip()) for code in code_list]
    assert len(embeddings) == len(code_list), f"len(embeddings)={len(embeddings)}, len(code_list)={len(code_list)}"
    embeddings = np.array([embeddings[code] for code in code_list], dtype=np.float32)
    return embeddings
ygoai/utils.py  View file @ 43ca871e

import pickle
import numpy as np
from pathlib import Path
@@ -43,4 +45,15 @@ def init_ygopro(env_id, lang, deck, code_list_file, preload_tokens=False):
    elif 'EDOPro' in env_id:
        from ygoenv.edopro import init_module
        init_module(str(db_path), code_list_file, decks)
    return deck_name
\ No newline at end of file
    return deck_name


def load_embeddings(embedding_file, code_list_file):
    with open(embedding_file, "rb") as f:
        embeddings = pickle.load(f)
    with open(code_list_file, "r") as f:
        code_list = f.readlines()
        code_list = [int(code.strip()) for code in code_list]
    assert len(embeddings) == len(code_list), f"len(embeddings)={len(embeddings)}, len(code_list)={len(code_list)}"
    embeddings = np.array([embeddings[code] for code in code_list], dtype=np.float32)
    return embeddings
\ No newline at end of file
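load_embeddings now lives alongside init_ygopro in ygoai.utils. It reads a pickled dict keyed by card code and returns a float32 matrix ordered by the code list, presumably consumed as the initial card-id embedding table. A hypothetical call, with file names that are placeholders rather than paths from this repository:

import numpy as np
from ygoai.utils import load_embeddings

# Both paths are illustrative placeholders.
embeddings = load_embeddings("embeddings_en.pkl", "code_list.txt")

print(embeddings.shape)            # (num_codes, embed_dim)
assert embeddings.dtype == np.float32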