Biluo Shen / ygo-agent / Commits

Commit 34f86ae4
Unify Impala and PPO
Authored Apr 14, 2024 by sbl1996@126.com
Parent: 9d8d4386

Showing 6 changed files with 976 additions and 95 deletions (+976, -95)
Changed files:
  scripts/jax/battle.py      +4    -1
  scripts/jax/impala.py      +35   -46
  scripts/jax/impala2.py     +819  -0   (new file)
  scripts/jax/ppo_lstm.py    +11   -23
  scripts/jax/ppo_lstm2.py   +57   -17
  ygoai/rl/jax/eval.py       +50   -8
scripts/jax/battle.py

@@ -165,7 +165,10 @@ if __name__ == "__main__":
     else:
         with open(args.checkpoint2, "rb") as f:
             params2 = flax.serialization.from_bytes(params, f.read())
 
+    params1 = jax.device_put(params1)
+    params2 = jax.device_put(params2)
+
     @jax.jit
     def get_probs(params, rstate, obs, done):
         agent = create_agent(args)
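The two added jax.device_put calls are the substance of this hunk: flax.serialization.from_bytes returns host (NumPy) arrays, and committing them to an accelerator once avoids an implicit host-to-device transfer on every jitted call. A minimal sketch of the idea, with made-up parameter names that are not from battle.py:

import jax
import jax.numpy as jnp
import numpy as np

# Params deserialized from a checkpoint typically live on the host as NumPy arrays.
host_params = {"dense": {"kernel": np.ones((256, 256), dtype=np.float32)}}

# Commit them to the default device once...
device_params = jax.device_put(host_params)

@jax.jit
def forward(params, x):
    return x @ params["dense"]["kernel"]

# ...so repeated jitted calls reuse the on-device buffers instead of
# re-transferring the host copy on every invocation.
y = forward(device_params, jnp.ones((1, 256)))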
scripts/jax/impala.py

@@ -17,6 +17,7 @@ import jax.numpy as jnp
 import numpy as np
 import optax
 import rlax
+import distrax
 import tyro
 from flax.training.train_state import TrainState
 from rich.pretty import pprint

@@ -28,6 +29,7 @@ from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_mean, masked_norm
 from ygoai.rl.jax.eval import evaluate
 from ygoai.rl.jax import vtrace, upgo_return, clipped_surrogate_pg_loss
 
 os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"

@@ -40,7 +42,9 @@ class Args:
     log_frequency: int = 10
     """the logging frequency of the model performance (in terms of `updates`)"""
     save_interval: int = 100
-    """the frequency of saving the model"""
+    """the frequency of saving the model (in terms of `updates`)"""
+    checkpoint: Optional[str] = None
+    """the path to the model checkpoint to load"""
 
     # Algorithm specific arguments
     env_id: str = "YGOPro-v0"

@@ -78,8 +82,6 @@ class Args:
     """the discount factor gamma"""
     num_minibatches: int = 4
     """the number of mini-batches"""
-    gradient_accumulation_steps: int = 1
-    """the number of gradient accumulation steps before performing an optimization step"""
     c_clip_min: float = 0.001
     """the minimum value of the importance sampling clipping"""
     c_clip_max: float = 1.007

@@ -88,8 +90,6 @@ class Args:
     """the minimum value of the importance sampling clipping"""
     rho_clip_max: float = 1.007
     """the maximum value of the importance sampling clipping"""
-    upgo: bool = False
-    """whether to use UPGO for policy update"""
     ppo_clip: bool = True
     """whether to use the PPO clipping to replace V-Trace surrogate clipping"""
     clip_coef: float = 0.25

@@ -127,7 +127,6 @@ class Args:
     # runtime arguments to be filled in
     local_batch_size: int = 0
     local_minibatch_size: int = 0
-    num_updates: int = 0
     world_size: int = 0
     local_rank: int = 0
     num_envs: int = 0
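The hunks above only touch the Args dataclass (the new checkpoint field, the removed gradient_accumulation_steps and upgo fields, and the now-unused num_updates slot). As a reminder of how such a dataclass is consumed, here is a minimal, self-contained sketch assuming the same tyro.cli pattern these scripts use in their __main__ block; the fields and the example flag are illustrative only:

from dataclasses import dataclass
from typing import Optional

import tyro

@dataclass
class Args:
    save_interval: int = 100
    """the frequency of saving the model (in terms of `updates`)"""
    checkpoint: Optional[str] = None
    """the path to the model checkpoint to load"""

if __name__ == "__main__":
    # tyro builds a command-line parser from the dataclass, so the new field is
    # immediately available as a flag, e.g. --checkpoint model.flax_model
    args = tyro.cli(Args)
    print(args.save_interval, args.checkpoint)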
@@ -218,34 +217,28 @@ def rollout(
     avg_win_rates = deque(maxlen=1000)
 
     @jax.jit
-    def apply_fn(
-        params: flax.core.FrozenDict,
-        next_obs,
-    ):
-        logits, value, _valid = create_agent(args).apply(params, next_obs)
-        return logits, value
+    def get_logits(
+        params: flax.core.FrozenDict,
+        inputs):
+        logits, value = create_agent(args).apply(params, inputs)[:2]
+        return logits
 
     def get_action(
-        params: flax.core.FrozenDict,
-        next_obs,
-    ):
-        return apply_fn(params, next_obs)[0].argmax(axis=1)
+        params: flax.core.FrozenDict,
+        inputs):
+        return get_logits(params, inputs).argmax(axis=1)
 
     @jax.jit
     def sample_action(
         params: flax.core.FrozenDict,
         next_obs,
-        key: jax.random.PRNGKey):
+        key: jax.random.PRNGKey,
+    ):
         next_obs = jax.tree.map(lambda x: jnp.array(x), next_obs)
-        logits = apply_fn(params, next_obs)[0]
+        logits = get_logits(params, next_obs)
         # sample action: Gumbel-softmax trick
         # see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution
         key, subkey = jax.random.split(key)
         u = jax.random.uniform(subkey, shape=logits.shape)
         action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=1)
         return next_obs, action, logits, key
 
     # put data in the last index
     envs.async_reset()

@@ -253,13 +246,13 @@ def rollout(
     rollout_time = deque(maxlen=10)
     actor_policy_version = 0
     storage = []
-    ai_player1 = np.concatenate([
+    main_player = np.concatenate([
         np.zeros(args.local_num_envs // 2, dtype=np.int64),
         np.ones(args.local_num_envs // 2, dtype=np.int64)
     ])
-    np.random.shuffle(ai_player1)
+    np.random.shuffle(main_player)
     next_to_play = None
-    learn = np.ones(args.local_num_envs, dtype=np.bool_)
+    main = np.ones(args.local_num_envs, dtype=np.bool_)
 
     @jax.jit
     def prepare_data(storage: List[Transition]) -> Transition:

@@ -274,8 +267,7 @@ def rollout(
         inference_time = 0
         env_time = 0
         num_steps_with_bootstrap = (
-            args.num_steps + int(len(storage) == 0)
-        )
+            args.num_steps + int(len(storage) == 0))  # num_steps + 1 to get the states for value bootstrapping.
         params_queue_get_time_start = time.time()
         if args.concurrency:
             if update != 2:

@@ -295,11 +287,11 @@ def rollout(
             _start = time.time()
             next_obs, next_reward, next_done, info = envs.recv()
-            next_reward = np.where(learn, next_reward, -next_reward)
+            next_reward = np.where(main, next_reward, -next_reward)
             env_time += time.time() - _start
 
             to_play = next_to_play
             next_to_play = info["to_play"]
-            learn = next_to_play == ai_player1
+            main = next_to_play == main_player
 
             inference_time_start = time.time()
             next_obs, action, logits, key = sample_action(params, next_obs, key)

@@ -312,17 +304,17 @@ def rollout(
                 Transition(
                     obs=next_obs,
                     dones=next_done,
-                    rewards=next_reward,
+                    mains=main,
                     actions=action,
                     logitss=logits,
-                    learns=learn,
+                    rewards=next_reward,
                 )
             )
 
             for idx, d in enumerate(next_done):
                 if not d:
                     continue
-                pl = 1 if to_play[idx] == ai_player1[idx] else -1
+                pl = 1 if to_play[idx] == main_player[idx] else -1
                 episode_reward = info['r'][idx] * pl
                 win = 1 if episode_reward > 0 else 0
                 avg_ep_returns.append(episode_reward)

@@ -488,7 +480,7 @@ if __name__ == "__main__":
                 learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5
             ),
         ),
-        every_k_schedule=args.gradient_accumulation_steps,
+        every_k_schedule=1,
     )
     agent_state = TrainState.create(
         apply_fn=None,

@@ -505,13 +497,15 @@ if __name__ == "__main__":
         params: flax.core.FrozenDict,
         obs: np.ndarray,
     ):
-        logits, value, valid = create_agent(args).apply(params, obs)
-        return logits, value.squeeze(-1), valid
+        logits, value = create_agent(args).apply(params, obs)
+        return logits, value.squeeze(-1)
 
     def impala_loss(params, obs, actions, logitss, rewards, dones, learns):
         # (num_steps + 1, local_num_envs // n_mb))
+        num_steps = actions.shape[0] - 1
         discounts = (1.0 - dones) * args.gamma
-        policy_logits, newvalue, valid = jax.vmap(
+        policy_logits, newvalue = jax.vmap(
             get_logits_and_value, in_axes=(None, 0))(params, obs)
         newvalue = jnp.where(learns, newvalue, -newvalue)

@@ -527,19 +521,14 @@ if __name__ == "__main__":
         discounts = discounts[1:]
         mask = mask[:-1]
 
-        rhos = rlax.categorical_importance_sampling_ratios(
-            policy_logits, logitss, actions)
+        rhos = distrax.importance_sampling_ratios(distrax.Categorical(
+            policy_logits), distrax.Categorical(logitss), actions)
 
         vtrace_fn = partial(
             vtrace, c_clip_min=args.c_clip_min, c_clip_max=args.c_clip_max, rho_clip_min=args.rho_clip_min, rho_clip_max=args.rho_clip_max)
         vtrace_returns = jax.vmap(
             vtrace_fn, in_axes=1, out_axes=1)(
             v_tm1, v_t, rewards, discounts, rhos)
-        jax.debug.print("R {}", jnp.where(dones[1:-1, :2], rewards[:-1, :2], 0).T)
-        jax.debug.print("E {}", jnp.where(dones[1:-1, :2], vtrace_returns.errors[:-1, :2] * 100, vtrace_returns.errors[:-1, :2]).T)
-        jax.debug.print("V {}", v_tm1[:-1, :2].T)
-
-        T = v_tm1.shape[0]
         if args.upgo:
             advs = jax.vmap(upgo_return, in_axes=1, out_axes=1)(
                 rewards, v_t, discounts) - v_tm1

@@ -548,13 +537,13 @@ if __name__ == "__main__":
         if args.ppo_clip:
             pg_loss = jax.vmap(
                 partial(clipped_surrogate_pg_loss, epsilon=args.clip_coef), in_axes=1)(
-                rhos, advs, mask) * T
+                rhos, advs, mask) * num_steps
             pg_loss = jnp.sum(pg_loss)
         else:
             pg_advs = jnp.minimum(args.rho_clip_max, rhos) * advs
             pg_loss = jax.vmap(
                 rlax.policy_gradient_loss, in_axes=1)(
-                policy_logits, actions, pg_advs, mask) * T
+                policy_logits, actions, pg_advs, mask) * num_steps
             pg_loss = jnp.sum(pg_loss)
 
         baseline_loss = 0.5 * jnp.sum(jnp.square(vtrace_returns.errors) * mask)
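One functional change in impala_loss above is the switch from rlax.categorical_importance_sampling_ratios to distrax.importance_sampling_ratios. Both compute the per-step importance ratio rho_t = pi(a_t|s_t) / mu(a_t|s_t); a small sketch with illustrative logits (not from the repo) showing that they agree:

import distrax
import jax.numpy as jnp
import rlax

pi_logits = jnp.array([[2.0, 0.5, -1.0]])   # current policy
mu_logits = jnp.array([[0.1, 1.0, 0.3]])    # behaviour policy that sampled the data
actions = jnp.array([1])

r_rlax = rlax.categorical_importance_sampling_ratios(pi_logits, mu_logits, actions)
r_distrax = distrax.importance_sampling_ratios(
    distrax.Categorical(pi_logits), distrax.Categorical(mu_logits), actions)

# Both equal exp(log_softmax(pi)[a] - log_softmax(mu)[a]); the distrax form just
# mirrors the Categorical objects the unified scripts already build elsewhere.
print(r_rlax, r_distrax)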
scripts/jax/impala2.py  (new file, mode 100644)

import os
import queue
import random
import threading
import time
from datetime import datetime, timedelta, timezone
from collections import deque
from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import List, NamedTuple, Optional
from functools import partial

import ygoenv
import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
import rlax
import distrax
import tyro
from flax.training.train_state import TrainState
from rich.pretty import pprint
from tensorboardX import SummaryWriter

from ygoai.utils import init_ygopro
from ygoai.rl.jax.agent2 import PPOLSTMAgent
from ygoai.rl.jax.utils import RecordEpisodeStatistics, categorical_sample
from ygoai.rl.jax.eval import evaluate
from ygoai.rl.jax import upgo_return, vtrace, clipped_surrogate_pg_loss

os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"


@dataclass
class Args:
    exp_name: str = os.path.basename(__file__).rstrip(".py")
    """the name of this experiment"""
    seed: int = 1
    """seed of the experiment"""
    log_frequency: int = 10
    """the logging frequency of the model performance (in terms of `updates`)"""
    save_interval: int = 400
    """the frequency of saving the model (in terms of `updates`)"""
    checkpoint: Optional[str] = None
    """the path to the model checkpoint to load"""

    # Algorithm specific arguments
    env_id: str = "YGOPro-v0"
    """the id of the environment"""
    deck: str = "../assets/deck"
    """the deck file to use"""
    deck1: Optional[str] = None
    """the deck file for the first player"""
    deck2: Optional[str] = None
    """the deck file for the second player"""
    code_list_file: str = "code_list.txt"
    """the code list file for card embeddings"""
    embedding_file: Optional[str] = None
    """the embedding file for card embeddings"""
    max_options: int = 24
    """the maximum number of options"""
    n_history_actions: int = 32
    """the number of history actions to use"""

    total_timesteps: int = 5000000000
    """total timesteps of the experiments"""
    learning_rate: float = 1e-4
    """the learning rate of the optimizer"""
    local_num_envs: int = 128
    """the number of parallel game environments"""
    local_env_threads: Optional[int] = None
    """the number of threads to use for environment"""
    num_actor_threads: int = 2
    """the number of actor threads to use"""
    num_steps: int = 32
    """the number of steps to run in each environment per policy rollout"""
    collect_length: Optional[int] = None
    """the number of steps to compute the advantages"""
    anneal_lr: bool = False
    """Toggle learning rate annealing for policy and value networks"""
    gamma: float = 1.0
    """the discount factor gamma"""
    num_minibatches: int = 4
    """the number of mini-batches"""
    update_epochs: int = 2
    """the K epochs to update the policy"""
    c_clip_min: float = 0.001
    """the minimum value of the importance sampling clipping"""
    c_clip_max: float = 1.007
    """the maximum value of the importance sampling clipping"""
    rho_clip_min: float = 0.001
    """the minimum value of the importance sampling clipping"""
    rho_clip_max: float = 1.007
    """the maximum value of the importance sampling clipping"""
    upgo: bool = False
    """whether to use UPGO for policy update"""
    ppo_clip: bool = True
    """whether to use the PPO clipping to replace V-Trace surrogate clipping"""
    clip_coef: float = 0.25
    """the PPO surrogate clipping coefficient"""
    ent_coef: float = 0.01
    """coefficient of the entropy"""
    vf_coef: float = 0.5
    """coefficient of the value function"""
    max_grad_norm: float = 1.0
    """the maximum norm for the gradient clipping"""

    num_layers: int = 2
    """the number of layers for the agent"""
    num_channels: int = 128
    """the number of channels for the agent"""
    rnn_channels: int = 512
    """the number of channels for the RNN in the agent"""

    actor_device_ids: List[int] = field(default_factory=lambda: [0, 1])
    """the device ids that actor workers will use"""
    learner_device_ids: List[int] = field(default_factory=lambda: [2, 3])
    """the device ids that learner workers will use"""
    distributed: bool = False
    """whether to use `jax.distirbuted`"""
    concurrency: bool = True
    """whether to run the actor and learner concurrently"""
    bfloat16: bool = True
    """whether to use bfloat16 for the agent"""
    thread_affinity: bool = False
    """whether to use thread affinity for the environment"""

    local_eval_episodes: int = 32
    """the number of episodes to evaluate the model"""
    eval_interval: int = 50
    """the number of iterations to evaluate the model"""

    # runtime arguments to be filled in
    local_batch_size: int = 0
    local_minibatch_size: int = 0
    world_size: int = 0
    local_rank: int = 0
    num_envs: int = 0
    batch_size: int = 0
    minibatch_size: int = 0
    num_updates: int = 0
    global_learner_decices: Optional[List[str]] = None
    actor_devices: Optional[List[str]] = None
    learner_devices: Optional[List[str]] = None
    num_embeddings: Optional[int] = None


def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1):
    if not args.thread_affinity:
        thread_affinity_offset = -1
    if thread_affinity_offset >= 0:
        print("Binding to thread offset", thread_affinity_offset)
    envs = ygoenv.make(
        task_id=args.env_id,
        env_type="gymnasium",
        num_envs=num_envs,
        num_threads=num_threads,
        thread_affinity_offset=thread_affinity_offset,
        seed=seed,
        deck1=args.deck1,
        deck2=args.deck2,
        max_options=args.max_options,
        n_history_actions=args.n_history_actions,
        async_reset=False,
        play_mode=mode,
    )
    envs.num_envs = num_envs
    return envs


class Transition(NamedTuple):
    obs: list
    dones: list
    actions: list
    logits: list
    rewards: list
    mains: list
    next_dones: list


def create_agent(args, multi_step=False):
    return PPOLSTMAgent(
        channels=args.num_channels,
        num_layers=args.num_layers,
        embedding_shape=args.num_embeddings,
        dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
        param_dtype=jnp.float32,
        lstm_channels=args.rnn_channels,
        multi_step=multi_step,
    )


def init_rnn_state(num_envs, rnn_channels):
    return (
        np.zeros((num_envs, rnn_channels)),
        np.zeros((num_envs, rnn_channels)),
    )


def rollout(
    key: jax.random.PRNGKey,
    args: Args,
    rollout_queue,
    params_queue: queue.Queue,
    stats_queue,
    writer,
    learner_devices,
    device_thread_id,
):
    envs = make_env(
        args,
        args.seed + jax.process_index() + device_thread_id,
        args.local_num_envs,
        args.local_env_threads,
        thread_affinity_offset=device_thread_id * args.local_env_threads,
    )
    envs = RecordEpisodeStatistics(envs)

    eval_envs = make_env(
        args,
        args.seed + jax.process_index() + device_thread_id,
        args.local_eval_episodes,
        args.local_eval_episodes // 4, mode='bot')
    eval_envs = RecordEpisodeStatistics(eval_envs)

    len_actor_device_ids = len(args.actor_device_ids)
    n_actors = args.num_actor_threads * len_actor_device_ids
    global_step = 0
    start_time = time.time()
    warmup_step = 0
    other_time = 0
    avg_ep_returns = deque(maxlen=1000)
    avg_win_rates = deque(maxlen=1000)

    @jax.jit
    def get_logits(
        params: flax.core.FrozenDict, inputs, done):
        rstate, logits = create_agent(args).apply(params, inputs)[:2]
        rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
        return rstate, logits

    @jax.jit
    def get_action(
        params: flax.core.FrozenDict, inputs):
        batch_size = jax.tree.leaves(inputs)[0].shape[0]
        done = jnp.zeros(batch_size, dtype=jnp.bool_)
        rstate, logits = get_logits(params, inputs, done)
        return rstate, logits.argmax(axis=1)

    @jax.jit
    def sample_action(
        params: flax.core.FrozenDict,
        next_obs, rstate1, rstate2, main, done, key):
        next_obs = jax.tree.map(lambda x: jnp.array(x), next_obs)
        main = jnp.array(main)
        rstate = jax.tree.map(
            lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
        rstate, logits = get_logits(params, (rstate, next_obs), done)
        rstate1 = jax.tree.map(
            lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
        rstate2 = jax.tree.map(
            lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
        action, key = categorical_sample(logits, key)
        return next_obs, rstate1, rstate2, action, logits, key

    # put data in the last index
    params_queue_get_time = deque(maxlen=10)
    rollout_time = deque(maxlen=10)
    actor_policy_version = 0
    next_obs, info = envs.reset()
    next_to_play = info["to_play"]
    next_done = np.zeros(args.local_num_envs, dtype=np.bool_)
    next_rstate1 = next_rstate2 = init_rnn_state(
        args.local_num_envs, args.rnn_channels)
    eval_rstate = init_rnn_state(
        args.local_eval_episodes, args.rnn_channels)
    main_player = np.concatenate([
        np.zeros(args.local_num_envs // 2, dtype=np.int64),
        np.ones(args.local_num_envs // 2, dtype=np.int64)
    ])
    np.random.shuffle(main_player)
    start_step = 0
    storage = []

    @jax.jit
    def prepare_data(storage: List[Transition]) -> Transition:
        return jax.tree.map(lambda *xs: jnp.split(jnp.stack(xs), len(learner_devices), axis=1), *storage)

    for update in range(1, args.num_updates + 2):
        if update == 10:
            start_time = time.time()
            warmup_step = global_step

        update_time_start = time.time()
        inference_time = 0
        env_time = 0
        params_queue_get_time_start = time.time()
        if args.concurrency:
            if update != 2:
                params = params_queue.get()
                # params["params"]["Encoder_0"]['Embed_0'][
                #     "embedding"
                # ].block_until_ready()
                actor_policy_version += 1
        else:
            params = params_queue.get()
            actor_policy_version += 1
        params_queue_get_time.append(time.time() - params_queue_get_time_start)

        rollout_time_start = time.time()
        init_rstate1, init_rstate2 = jax.tree.map(
            lambda x: x.copy(), (next_rstate1, next_rstate2))
        for _ in range(start_step, args.collect_length):
            global_step += args.local_num_envs * n_actors * args.world_size

            cached_next_obs = next_obs
            cached_next_done = next_done
            main = next_to_play == main_player

            inference_time_start = time.time()
            cached_next_obs, next_rstate1, next_rstate2, action, logits, key = sample_action(
                params, cached_next_obs, next_rstate1, next_rstate2, main, cached_next_done, key)

            cpu_action = np.array(action)
            inference_time += time.time() - inference_time_start

            _start = time.time()
            next_obs, next_reward, next_done, info = envs.step(cpu_action)
            next_to_play = info["to_play"]
            env_time += time.time() - _start

            storage.append(
                Transition(
                    obs=cached_next_obs,
                    dones=cached_next_done,
                    mains=main,
                    actions=action,
                    logits=logits,
                    rewards=next_reward,
                    next_dones=next_done,
                )
            )

            for idx, d in enumerate(next_done):
                if not d:
                    continue
                cur_main = main[idx]
                for j in reversed(range(len(storage) - 1)):
                    t = storage[j]
                    if t.next_dones[idx]:
                        # For OTK where player may not switch
                        break
                    if t.mains[idx] != cur_main:
                        t.next_dones[idx] = True
                        t.rewards[idx] = -next_reward[idx]
                        break
                episode_reward = info['r'][idx] * (1 if cur_main else -1)
                win = 1 if episode_reward > 0 else 0
                avg_ep_returns.append(episode_reward)
                avg_win_rates.append(win)

        rollout_time.append(time.time() - rollout_time_start)

        start_step = args.collect_length - args.num_steps

        partitioned_storage = prepare_data(storage)
        storage = storage[args.num_steps:]
        sharded_storage = []
        for x in partitioned_storage:
            if isinstance(x, dict):
                x = {
                    k: jax.device_put_sharded(v, devices=learner_devices)
                    for k, v in x.items()
                }
            else:
                x = jax.device_put_sharded(x, devices=learner_devices)
            sharded_storage.append(x)
        sharded_storage = Transition(*sharded_storage)
        next_main = main_player == next_to_play
        next_rstate = jax.tree.map(
            lambda x1, x2: jnp.where(next_main[:, None], x1, x2), next_rstate1, next_rstate2)
        sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
            np.split(x, len(learner_devices)), devices=learner_devices),
            (init_rstate1, init_rstate2, (next_rstate, next_obs), next_main))
        learn_opponent = False
        payload = (
            global_step,
            update,
            sharded_storage,
            *sharded_data,
            np.mean(params_queue_get_time),
            learn_opponent,
        )
        rollout_queue.put(payload)

        if update % args.log_frequency == 0:
            avg_episodic_return = np.mean(avg_ep_returns)
            avg_episodic_length = np.mean(envs.returned_episode_lengths)
            SPS = int((global_step - warmup_step) / (time.time() - start_time - other_time))
            SPS_update = int(args.batch_size / (time.time() - update_time_start))
            if device_thread_id == 0:
                print(
                    f"global_step={global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}, rollout_time={rollout_time[-1]:.2f}"
                )
                time_now = datetime.now(timezone(timedelta(hours=8))).strftime("%H:%M:%S")
                print(f"{time_now} SPS: {SPS}, update: {SPS_update}")
            writer.add_scalar("stats/rollout_time", np.mean(rollout_time), global_step)
            writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step)
            writer.add_scalar("charts/avg_episodic_length", avg_episodic_length, global_step)
            writer.add_scalar("stats/params_queue_get_time", np.mean(params_queue_get_time), global_step)
            writer.add_scalar("stats/inference_time", inference_time, global_step)
            writer.add_scalar("stats/env_time", env_time, global_step)
            writer.add_scalar("charts/SPS", SPS, global_step)
            writer.add_scalar("charts/SPS_update", SPS_update, global_step)

        if args.eval_interval and update % args.eval_interval == 0:
            # Eval with rule-based policy
            _start = time.time()
            eval_return = evaluate(eval_envs, get_action, params, eval_rstate)[0]
            if device_thread_id != 0:
                stats_queue.put(eval_return)
            else:
                eval_stats = []
                eval_stats.append(eval_return)
                for _ in range(1, n_actors):
                    eval_stats.append(stats_queue.get())
                eval_stats = np.mean(eval_stats)
                writer.add_scalar("charts/eval_return", eval_stats, global_step)
                if device_thread_id == 0:
                    eval_time = time.time() - _start
                    print(f"eval_time={eval_time:.4f}, eval_ep_return={eval_stats:.4f}")
                    other_time += eval_time


if __name__ == "__main__":
    args = tyro.cli(Args)
    args.local_batch_size = int(args.local_num_envs * args.num_steps * args.num_actor_threads * len(args.actor_device_ids))
    args.local_minibatch_size = int(args.local_batch_size // args.num_minibatches)
    assert (
        args.local_num_envs % len(args.learner_device_ids) == 0
    ), "local_num_envs must be divisible by len(learner_device_ids)"
    assert (
        int(args.local_num_envs / len(args.learner_device_ids)) * args.num_actor_threads % args.num_minibatches == 0
    ), "int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches"
    if args.distributed:
        jax.distributed.initialize(
            local_device_ids=range(len(args.learner_device_ids) + len(args.actor_device_ids)),
        )
        print(list(range(len(args.learner_device_ids) + len(args.actor_device_ids))))

    from jax.experimental.compilation_cache import compilation_cache as cc
    cc.set_cache_dir(os.path.expanduser("~/.cache/jax"))

    args.world_size = jax.process_count()
    args.local_rank = jax.process_index()
    args.num_envs = args.local_num_envs * args.world_size * args.num_actor_threads * len(args.actor_device_ids)
    args.batch_size = args.local_batch_size * args.world_size
    args.minibatch_size = args.local_minibatch_size * args.world_size
    args.num_updates = args.total_timesteps // (args.local_batch_size * args.world_size)
    args.local_env_threads = args.local_env_threads or args.local_num_envs
    args.collect_length = args.collect_length or args.num_steps
    assert args.collect_length >= args.num_steps, "collect_length must be greater than or equal to num_steps"

    local_devices = jax.local_devices()
    global_devices = jax.devices()
    learner_devices = [local_devices[d_id] for d_id in args.learner_device_ids]
    actor_devices = [local_devices[d_id] for d_id in args.actor_device_ids]
    global_learner_decices = [
        global_devices[d_id + process_index * len(local_devices)]
        for process_index in range(args.world_size)
        for d_id in args.learner_device_ids
    ]
    print("global_learner_decices", global_learner_decices)
    args.global_learner_decices = [str(item) for item in global_learner_decices]
    args.actor_devices = [str(item) for item in actor_devices]
    args.learner_devices = [str(item) for item in learner_devices]
    pprint(args)

    timestamp = int(time.time())
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    key = jax.random.PRNGKey(args.seed)
    key, agent_key = jax.random.split(key, 2)
    learner_keys = jax.device_put_replicated(key, learner_devices)

    deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file)
    args.deck1 = args.deck1 or deck
    args.deck2 = args.deck2 or deck

    # env setup
    envs = make_env(args, args.seed, 8, 1)
    obs_space = envs.observation_space
    action_shape = envs.action_space.shape
    print(f"obs_space={obs_space}, action_shape={action_shape}")
    sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
    envs.close()
    del envs

    def linear_schedule(count):
        # anneal learning rate linearly after one training iteration which contains
        # (args.num_minibatches) gradient updates
        frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates
        return args.learning_rate * frac

    rstate = init_rnn_state(1, args.rnn_channels)
    agent = create_agent(args)
    params = agent.init(agent_key, (rstate, sample_obs))
    tx = optax.MultiSteps(
        optax.chain(
            optax.clip_by_global_norm(args.max_grad_norm),
            optax.inject_hyperparams(optax.adam)(
                learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5
            ),
        ),
        every_k_schedule=1,
    )
    agent_state = TrainState.create(
        apply_fn=None,
        params=params,
        tx=tx,
    )
    if args.checkpoint:
        with open(args.checkpoint, "rb") as f:
            params = flax.serialization.from_bytes(params, f.read())
            agent_state = agent_state.replace(params=params)
        print(f"loaded checkpoint from {args.checkpoint}")

    agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices)
    # print(agent.tabulate(agent_key, sample_obs))

    @jax.jit
    def get_logits_and_value(
        params: flax.core.FrozenDict, inputs,
    ):
        rstate, logits, value, valid = create_agent(
            args, multi_step=True).apply(params, inputs)
        return logits, value.squeeze(-1)

    def ppo_loss(
        params, rstate1, rstate2, obs, dones, next_dones,
        switch, actions, logits, rewards, mask, next_value):
        # (num_steps * local_num_envs // n_mb))
        num_envs = next_value.shape[0]
        num_steps = dones.shape[0] // num_envs
        mask = mask & (~dones)
        n_valids = jnp.sum(mask)
        real_dones = dones | next_dones
        inputs = (rstate1, rstate2, obs, real_dones, switch)

        new_logits, new_values = get_logits_and_value(params, inputs)

        new_logits, v_tm1, logits, actions, rewards, next_dones, switch, mask = jax.tree.map(
            lambda x: jnp.reshape(x, (num_steps, num_envs) + x.shape[1:]),
            (new_logits, new_values, logits, actions, rewards, next_dones, switch, mask),
        )

        v_t = jnp.concatenate([v_tm1[1:], next_value[None, :]], axis=0)
        discounts = (1.0 - next_dones) * args.gamma

        ratio = distrax.importance_sampling_ratios(distrax.Categorical(
            new_logits), distrax.Categorical(logits), actions)
        logratio = jnp.log(ratio)
        approx_kl = (((ratio - 1) - logratio) * mask).sum() / n_valids

        # TODO: use switch to calculate the correct value
        vtrace_fn = partial(
            vtrace, c_clip_min=args.c_clip_min, c_clip_max=args.c_clip_max, rho_clip_min=args.rho_clip_min, rho_clip_max=args.rho_clip_max)
        vtrace_returns = jax.vmap(vtrace_fn, in_axes=1, out_axes=1)(
            v_tm1, v_t, rewards, discounts, ratio)

        if args.upgo:
            advs = jax.vmap(upgo_return, in_axes=1, out_axes=1)(
                rewards, v_t, discounts) - v_tm1
        else:
            advs = vtrace_returns.q_estimate - v_tm1
        if args.ppo_clip:
            pg_loss = jax.vmap(
                partial(clipped_surrogate_pg_loss, epsilon=args.clip_coef), in_axes=1)(
                ratio, advs, mask) * num_steps
            pg_loss = jnp.sum(pg_loss)
        else:
            pg_advs = jnp.minimum(args.rho_clip_max, ratio) * advs
            pg_loss = jax.vmap(
                rlax.policy_gradient_loss, in_axes=1)(
                new_logits, actions, pg_advs, mask) * num_steps
            pg_loss = jnp.sum(pg_loss)

        v_loss = 0.5 * (vtrace_returns.errors ** 2)
        v_loss = jnp.sum(v_loss * mask)

        entropy_loss = distrax.Softmax(new_logits).entropy()
        entropy_loss = jnp.sum(entropy_loss * mask)

        pg_loss = pg_loss / n_valids
        v_loss = v_loss / n_valids
        entropy_loss = entropy_loss / n_valids

        loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
        return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl))

    def single_device_update(
        agent_state: TrainState,
        sharded_storages: List,
        sharded_init_rstate1: List,
        sharded_init_rstate2: List,
        sharded_next_inputs: List,
        sharded_next_main: List,
        key: jax.random.PRNGKey,
        learn_opponent: bool = False,
    ):
        storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
        next_inputs, init_rstate1, init_rstate2 = [
            jax.tree.map(lambda *x: jnp.concatenate(x), *x)
            for x in [sharded_next_inputs, sharded_init_rstate1, sharded_init_rstate2]
        ]
        next_main, = [
            jnp.concatenate(x) for x in [sharded_next_main]
        ]

        # reorder storage of individual players
        # main first, opponent second
        num_steps, num_envs = storage.rewards.shape
        T = jnp.arange(num_steps, dtype=jnp.int32)
        B = jnp.arange(num_envs, dtype=jnp.int32)
        mains = storage.mains.astype(jnp.int32)
        indices = jnp.argsort(T[:, None] - mains * num_steps, axis=0)
        switch_steps = jnp.sum(mains, axis=0)
        switch = T[:, None] == (switch_steps[None, :] - 1)
        storage = jax.tree.map(lambda x: x[indices, B[None, :]], storage)

        ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True)

        def update_epoch(carry, _):
            agent_state, key = carry
            key, subkey = jax.random.split(key)

            next_value = create_agent(args).apply(
                agent_state.params, next_inputs)[2].squeeze(-1)
            # TODO: check if this is correct
            sign = jnp.where(switch_steps <= num_steps, 1.0, -1.0)
            next_value = jnp.where(next_main, -sign * next_value, sign * next_value)

            def convert_data(x: jnp.ndarray, num_steps):
                if args.update_epochs > 1:
                    x = jax.random.permutation(subkey, x, axis=1 if num_steps > 1 else 0)
                N = args.num_minibatches
                if num_steps > 1:
                    x = jnp.reshape(x, (num_steps, N, -1) + x.shape[2:])
                    x = x.transpose(1, 0, *range(2, x.ndim))
                    x = x.reshape(N, -1, *x.shape[3:])
                else:
                    x = jnp.reshape(x, (N, -1) + x.shape[1:])
                return x

            shuffled_init_rstate1, shuffled_init_rstate2, shuffled_next_value = jax.tree.map(
                partial(convert_data, num_steps=1), (init_rstate1, init_rstate2, next_value))
            shuffled_storage, shuffled_switch = jax.tree.map(
                partial(convert_data, num_steps=num_steps), (storage, switch))
            shuffled_mask = jnp.ones_like(shuffled_storage.mains)

            def update_minibatch(agent_state, minibatch):
                (loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn(
                    agent_state.params,
                    *minibatch)
                grads = jax.lax.pmean(grads, axis_name="local_devices")
                agent_state = agent_state.apply_gradients(grads=grads)
                return agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl)

            agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl) = jax.lax.scan(
                update_minibatch,
                agent_state,
                (
                    shuffled_init_rstate1,
                    shuffled_init_rstate2,
                    shuffled_storage.obs,
                    shuffled_storage.dones,
                    shuffled_storage.next_dones,
                    shuffled_switch,
                    shuffled_storage.actions,
                    shuffled_storage.logits,
                    shuffled_storage.rewards,
                    shuffled_mask,
                    shuffled_next_value,
                ),
            )
            return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl)

        (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl) = jax.lax.scan(
            update_epoch, (agent_state, key), (), length=args.update_epochs)

        loss = jax.lax.pmean(loss, axis_name="local_devices").mean()
        pg_loss = jax.lax.pmean(pg_loss, axis_name="local_devices").mean()
        v_loss = jax.lax.pmean(v_loss, axis_name="local_devices").mean()
        entropy_loss = jax.lax.pmean(entropy_loss, axis_name="local_devices").mean()
        approx_kl = jax.lax.pmean(approx_kl, axis_name="local_devices").mean()
        return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key

    multi_device_update = jax.pmap(
        single_device_update,
        axis_name="local_devices",
        devices=global_learner_decices,
        static_broadcasted_argnums=(7,),
    )

    params_queues = []
    rollout_queues = []
    stats_queues = queue.Queue()
    dummy_writer = SimpleNamespace()
    dummy_writer.add_scalar = lambda x, y, z: None

    unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
    for d_idx, d_id in enumerate(args.actor_device_ids):
        device_params = jax.device_put(unreplicated_params, local_devices[d_id])
        for thread_id in range(args.num_actor_threads):
            params_queues.append(queue.Queue(maxsize=1))
            rollout_queues.append(queue.Queue(maxsize=1))
            params_queues[-1].put(device_params)
            threading.Thread(
                target=rollout,
                args=(
                    jax.device_put(key, local_devices[d_id]),
                    args,
                    rollout_queues[-1],
                    params_queues[-1],
                    stats_queues,
                    writer if d_idx == 0 and thread_id == 0 else dummy_writer,
                    learner_devices,
                    d_idx * args.num_actor_threads + thread_id,
                ),
            ).start()

    rollout_queue_get_time = deque(maxlen=10)
    data_transfer_time = deque(maxlen=10)
    learner_policy_version = 0
    while True:
        learner_policy_version += 1
        rollout_queue_get_time_start = time.time()
        sharded_data_list = []
        for d_idx, d_id in enumerate(args.actor_device_ids):
            for thread_id in range(args.num_actor_threads):
                (
                    global_step,
                    update,
                    *sharded_data,
                    avg_params_queue_get_time,
                    learn_opponent,
                ) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get()
                sharded_data_list.append(sharded_data)
        rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
        training_time_start = time.time()
        (agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, learner_keys) = multi_device_update(
            agent_state,
            *list(zip(*sharded_data_list)),
            learner_keys,
            learn_opponent,
        )
        unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
        for d_idx, d_id in enumerate(args.actor_device_ids):
            device_params = jax.device_put(unreplicated_params, local_devices[d_id])
            device_params["params"]["Encoder_0"]['Embed_0']["embedding"].block_until_ready()
            for thread_id in range(args.num_actor_threads):
                params_queues[d_idx * args.num_actor_threads + thread_id].put(device_params)

        loss = loss[-1].item()
        if np.isnan(loss) or np.isinf(loss):
            raise ValueError(f"loss is {loss}")

        # record rewards for plotting purposes
        if learner_policy_version % args.log_frequency == 0:
            writer.add_scalar("stats/rollout_queue_get_time", np.mean(rollout_queue_get_time), global_step)
            writer.add_scalar(
                "stats/rollout_params_queue_get_time_diff",
                np.mean(rollout_queue_get_time) - avg_params_queue_get_time,
                global_step,
            )
            writer.add_scalar("stats/training_time", time.time() - training_time_start, global_step)
            writer.add_scalar("stats/rollout_queue_size", rollout_queues[-1].qsize(), global_step)
            writer.add_scalar("stats/params_queue_size", params_queues[-1].qsize(), global_step)
            print(
                global_step,
                f"actor_update={update}, train_time={time.time() - training_time_start:.2f}",
            )
            writer.add_scalar("charts/learning_rate", agent_state.opt_state[2][1].hyperparams["learning_rate"][-1].item(), global_step)
            writer.add_scalar("losses/value_loss", v_loss[-1].item(), global_step)
            writer.add_scalar("losses/policy_loss", pg_loss[-1].item(), global_step)
            writer.add_scalar("losses/entropy", entropy_loss[-1].item(), global_step)
            writer.add_scalar("losses/approx_kl", approx_kl[-1].item(), global_step)
            writer.add_scalar("losses/loss", loss, global_step)

        if args.local_rank == 0 and learner_policy_version % args.save_interval == 0:
            ckpt_dir = f"checkpoints"
            os.makedirs(ckpt_dir, exist_ok=True)
            M_steps = args.batch_size * learner_policy_version // (2 ** 20)
            model_path = os.path.join(ckpt_dir, f"{timestamp}_{M_steps}M.flax_model")
            with open(model_path, "wb") as f:
                f.write(flax.serialization.to_bytes(unreplicated_params))
            print(f"model saved to {model_path}")

        if learner_policy_version >= args.num_updates:
            break

    if args.distributed:
        jax.distributed.shutdown()

    writer.close()
\ No newline at end of file
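impala2.py anneals the learning rate with the linear_schedule defined above, counting optax steps rather than updates. A short worked example of that arithmetic, using illustrative values (num_minibatches=4, update_epochs=2, num_updates=1000, learning_rate=1e-4) rather than the training defaults:

def linear_schedule(count, num_minibatches=4, update_epochs=2,
                    num_updates=1000, learning_rate=1e-4):
    # one training iteration performs num_minibatches * update_epochs optimizer
    # steps, so integer-dividing the step count recovers the current update index
    frac = 1.0 - (count // (num_minibatches * update_epochs)) / num_updates
    return learning_rate * frac

print(linear_schedule(0))     # 1e-4: start of training
print(linear_schedule(8))     # 9.99e-5: after the first full update (4 * 2 steps)
print(linear_schedule(4000))  # 5e-5: halfway through the 1000 updates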
scripts/jax/ppo_lstm.py

@@ -23,7 +23,7 @@ from tensorboardX import SummaryWriter
 from ygoai.utils import init_ygopro
 from ygoai.rl.jax.agent2 import PPOLSTMAgent
-from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_mean, masked_normalize
+from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_mean, masked_normalize, categorical_sample
 from ygoai.rl.jax.eval import evaluate
 from ygoai.rl.jax import compute_gae_upgo_2p0s, compute_gae_2p0s

@@ -255,11 +255,7 @@ def rollout(
         rstate1 = jax.tree.map(
             lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
         rstate2 = jax.tree.map(
             lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
-        # sample action: Gumbel-softmax trick
-        # see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution
-        key, subkey = jax.random.split(key)
-        u = jax.random.uniform(subkey, shape=logits.shape)
-        action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=1)
+        action, key = categorical_sample(logits, key)
         logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action]
         logits = logits - jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)

@@ -329,7 +325,6 @@ def rollout(
             inference_time += time.time() - inference_time_start
 
             _start = time.time()
-            to_play = next_to_play
             next_obs, next_reward, next_done, info = envs.step(cpu_action)
             next_to_play = info["to_play"]
             env_time += time.time() - _start

@@ -338,11 +333,11 @@ def rollout(
                 Transition(
                     obs=cached_next_obs,
                     dones=cached_next_done,
+                    mains=main,
                     actions=action,
                     logprobs=logprob,
-                    rewards=next_reward,
-                    mains=main,
                     probs=probs,
+                    rewards=next_reward,
                 )
             )

@@ -359,8 +354,7 @@ def rollout(
                         t.dones[idx] = True
                         t.rewards[idx] = -next_reward[idx]
                         break
-                pl = 1 if to_play[idx] == main_player[idx] else -1
-                episode_reward = info['r'][idx] * pl
+                episode_reward = info['r'][idx] * (1 if cur_main else -1)
                 win = 1 if episode_reward > 0 else 0
                 avg_ep_returns.append(episode_reward)
                 avg_win_rates.append(win)

@@ -387,16 +381,14 @@ def rollout(
             lambda x1, x2: jnp.where(next_main[:, None], x1, x2), next_rstate1, next_rstate2)
         sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
             np.split(x, len(learner_devices)), devices=learner_devices),
-            (next_obs, next_rstate, init_rstate1, init_rstate2, next_done, next_main))
+            (init_rstate1, init_rstate2, (next_rstate, next_obs), next_done, next_main))
         learn_opponent = False
         payload = (
             global_step,
-            actor_policy_version,
             update,
             sharded_storage,
             *sharded_data,
             np.mean(params_queue_get_time),
-            device_thread_id,
             learn_opponent,
         )
         rollout_queue.put(payload)

@@ -589,7 +581,6 @@ if __name__ == "__main__":
         pg_loss = jnp.maximum(pg_loss1, pg_loss2)
         pg_loss = masked_mean(pg_loss, valid)
 
-        # Value loss
         v_loss = 0.5 * ((newvalue - target_values) ** 2)
         v_loss = masked_mean(v_loss, valid)

@@ -600,10 +591,9 @@ if __name__ == "__main__":
     def single_device_update(
         agent_state: TrainState,
         sharded_storages: List,
-        sharded_next_obs: List,
-        sharded_next_rstate: List,
         sharded_init_rstate1: List,
         sharded_init_rstate2: List,
+        sharded_next_inputs: List,
         sharded_next_done: List,
         sharded_next_main: List,
         key: jax.random.PRNGKey,

@@ -620,9 +610,9 @@ if __name__ == "__main__":
             return x
 
         storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
-        next_obs, next_rstate, init_rstate1, init_rstate2 = [
+        next_inputs, init_rstate1, init_rstate2 = [
            jax.tree.map(lambda *x: jnp.concatenate(x), *x)
-            for x in [sharded_next_obs, sharded_next_rstate, sharded_init_rstate1, sharded_init_rstate2]
+            for x in [sharded_next_inputs, sharded_init_rstate1, sharded_init_rstate2]
         ]
         next_done, next_main = [
             jnp.concatenate(x) for x in [sharded_next_done, sharded_next_main]

@@ -680,7 +670,7 @@ if __name__ == "__main__":
         values = values.reshape(storage.rewards.shape)
         next_value = create_agent(args).apply(
-            agent_state.params, (next_rstate, next_obs))[2].squeeze(-1)
+            agent_state.params, next_inputs)[2].squeeze(-1)
         # TODO: check if this is correct
         sign = jnp.where(switch_steps <= num_steps, 1.0, -1.0)
         next_value = jnp.where(next_main, -sign * next_value, sign * next_value)

@@ -745,7 +735,7 @@ if __name__ == "__main__":
         single_device_update,
         axis_name="local_devices",
         devices=global_learner_decices,
-        static_broadcasted_argnums=(9,),
+        static_broadcasted_argnums=(8,),
     )
 
     params_queues = []

@@ -786,11 +776,9 @@ if __name__ == "__main__":
             for thread_id in range(args.num_actor_threads):
                 (
                     global_step,
-                    actor_policy_version,
                     update,
                     *sharded_data,
                     avg_params_queue_get_time,
-                    device_thread_id,
                     learn_opponent,
                 ) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get()
                 sharded_data_list.append(sharded_data)
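The rollout change above drops the inline Gumbel-softmax sampling in favour of the categorical_sample helper imported from ygoai.rl.jax.utils. A hedged sketch of what such a helper can look like (the actual implementation in the repo may differ): it is the same Gumbel-max trick the removed lines performed, just wrapped so every script samples actions the same way.

import jax
import jax.numpy as jnp

def categorical_sample(logits, key):
    # Gumbel-max trick: argmax(logits + Gumbel noise) is a sample from the
    # categorical distribution defined by the logits.
    key, subkey = jax.random.split(key)
    u = jax.random.uniform(subkey, shape=logits.shape)
    action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=-1)
    return action, key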
scripts/jax/ppo_lstm2.py
View file @
34f86ae4
...
@@ -25,7 +25,7 @@ from tensorboardX import SummaryWriter
...
@@ -25,7 +25,7 @@ from tensorboardX import SummaryWriter
from
ygoai.utils
import
init_ygopro
from
ygoai.utils
import
init_ygopro
from
ygoai.rl.jax.agent2
import
PPOLSTMAgent
from
ygoai.rl.jax.agent2
import
PPOLSTMAgent
from
ygoai.rl.jax.utils
import
RecordEpisodeStatistics
,
masked_normalize
,
categorical_sample
from
ygoai.rl.jax.utils
import
RecordEpisodeStatistics
,
masked_normalize
,
categorical_sample
from
ygoai.rl.jax.eval
import
evaluate
from
ygoai.rl.jax.eval
import
evaluate
,
battle
from
ygoai.rl.jax
import
compute_gae_upgo_2p0s
,
compute_gae_2p0s
from
ygoai.rl.jax
import
compute_gae_upgo_2p0s
,
compute_gae_2p0s
...
@@ -122,6 +122,8 @@ class Args:
...
@@ -122,6 +122,8 @@ class Args:
thread_affinity
:
bool
=
False
thread_affinity
:
bool
=
False
"""whether to use thread affinity for the environment"""
"""whether to use thread affinity for the environment"""
eval_checkpoint
:
Optional
[
str
]
=
None
"""the path to the model checkpoint to evaluate"""
local_eval_episodes
:
int
=
32
local_eval_episodes
:
int
=
32
"""the number of episodes to evaluate the model"""
"""the number of episodes to evaluate the model"""
eval_interval
:
int
=
50
eval_interval
:
int
=
50
...
@@ -198,12 +200,16 @@ def rollout(
...
@@ -198,12 +200,16 @@ def rollout(
key
:
jax
.
random
.
PRNGKey
,
key
:
jax
.
random
.
PRNGKey
,
args
:
Args
,
args
:
Args
,
rollout_queue
,
rollout_queue
,
params_queue
:
queue
.
Queue
,
params_queue
,
stats
_queue
,
eval
_queue
,
writer
,
writer
,
learner_devices
,
learner_devices
,
device_thread_id
,
device_thread_id
,
):
):
eval_mode
=
'self'
if
args
.
eval_checkpoint
else
'bot'
if
eval_mode
!=
'bot'
:
eval_params
=
params_queue
.
get
()
envs
=
make_env
(
envs
=
make_env
(
args
,
args
,
args
.
seed
+
jax
.
process_index
()
+
device_thread_id
,
args
.
seed
+
jax
.
process_index
()
+
device_thread_id
,
...
@@ -217,7 +223,7 @@ def rollout(
...
@@ -217,7 +223,7 @@ def rollout(
args
,
args
,
args
.
seed
+
jax
.
process_index
()
+
device_thread_id
,
args
.
seed
+
jax
.
process_index
()
+
device_thread_id
,
args
.
local_eval_episodes
,
args
.
local_eval_episodes
,
args
.
local_eval_episodes
//
4
,
mode
=
'bot'
)
args
.
local_eval_episodes
//
4
,
mode
=
eval_mode
)
eval_envs
=
RecordEpisodeStatistics
(
eval_envs
)
eval_envs
=
RecordEpisodeStatistics
(
eval_envs
)
len_actor_device_ids
=
len
(
args
.
actor_device_ids
)
len_actor_device_ids
=
len
(
args
.
actor_device_ids
)
...
@@ -244,11 +250,23 @@ def rollout(
         rstate, logits = get_logits(params, inputs, done)
         return rstate, logits.argmax(axis=1)

+    @jax.jit
+    def get_action_battle(params1, params2, rstate1, rstate2, obs, main, done):
+        next_rstate1, logits1 = get_logits(params1, (rstate1, obs), done)
+        next_rstate2, logits2 = get_logits(params2, (rstate2, obs), done)
+        logits = jnp.where(main[:, None], logits1, logits2)
+        rstate1 = jax.tree.map(
+            lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
+        rstate2 = jax.tree.map(
+            lambda x1, x2: jnp.where(main[:, None], x2, x1), next_rstate2, rstate2)
+        return rstate1, rstate2, logits.argmax(axis=1)
+
     @jax.jit
     def sample_action(
         params: flax.core.FrozenDict,
         next_obs, rstate1, rstate2, main, done, key):
         next_obs = jax.tree.map(lambda x: jnp.array(x), next_obs)
+        done = jnp.array(done)
         main = jnp.array(main)
         rstate = jax.tree.map(
             lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
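get_action_battle evaluates both parameter sets on the same observation batch and then uses the boolean main mask per environment: logits come from params1 where the main player is to act and from params2 otherwise, and each recurrent state is only advanced in the environments where its own player acted. A small self-contained sketch of that jnp.where masking pattern, using toy shapes rather than the repo's actual LSTM state:

import jax
import jax.numpy as jnp

batch, hidden = 4, 3
main = jnp.array([True, False, True, False])     # envs where this player acts
old_state = (jnp.zeros((batch, hidden)), jnp.zeros((batch, hidden)))
new_state = (jnp.ones((batch, hidden)), jnp.ones((batch, hidden)))

# Keep the freshly computed state only where this player acted; elsewhere keep the old one.
merged = jax.tree.map(
    lambda new, old: jnp.where(main[:, None], new, old), new_state, old_state)
print(merged[0])   # rows 0 and 2 come from new_state, rows 1 and 3 from old_state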
...
@@ -257,7 +275,7 @@ def rollout(
         rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
         action, key = categorical_sample(logits, key)
-        return next_obs, rstate1, rstate2, action, logits, key
+        return next_obs, done, main, rstate1, rstate2, action, logits, key

     # put data in the last index
     params_queue_get_time = deque(maxlen=10)
...
@@ -314,7 +332,8 @@ def rollout(
             main = next_to_play == main_player

             inference_time_start = time.time()
-            cached_next_obs, next_rstate1, next_rstate2, action, logits, key = sample_action(
+            cached_next_obs, cached_next_done, cached_main, \
+                next_rstate1, next_rstate2, action, logits, key = sample_action(
                 params, cached_next_obs, next_rstate1, next_rstate2, main, cached_next_done, key)
             cpu_action = np.array(action)
...
@@ -329,7 +348,7 @@ def rollout(
                 Transition(
                     obs=cached_next_obs,
                     dones=cached_next_done,
-                    mains=main,
+                    mains=cached_main,
                     actions=action,
                     logits=logits,
                     rewards=next_reward,
...
@@ -412,19 +431,28 @@ def rollout(
         if args.eval_interval and update % args.eval_interval == 0:
             # Eval with rule-based policy
             _start = time.time()
-            eval_return = evaluate(eval_envs, get_action, params, eval_rstate)[0]
+            if eval_mode == 'bot':
+                predict_fn = lambda x: get_action(params, x)
+                eval_stat = evaluate(eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)[0]
+                metric_name = "eval_return"
+            else:
+                predict_fn = lambda *x: get_action_battle(params, eval_params, *x)
+                eval_stat = battle(eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)[2]
+                metric_name = "eval_win_rate"
             if device_thread_id != 0:
-                stats_queue.put(eval_return)
+                eval_queue.put(eval_stat)
             else:
                 eval_stats = []
-                eval_stats.append(eval_return)
+                eval_stats.append(eval_stat)
                 for _ in range(1, n_actors):
-                    eval_stats.append(stats_queue.get())
+                    eval_stats.append(eval_queue.get())
                 eval_stats = np.mean(eval_stats)
-                writer.add_scalar("charts/eval_return", eval_stats, global_step)
+                writer.add_scalar(f"charts/{metric_name}", eval_stats, global_step)
             if device_thread_id == 0:
                 eval_time = time.time() - _start
-                print(f"eval_time={eval_time:.4f}, eval_ep_return={eval_stats:.4f}")
+                print(f"eval_time={eval_time:.4f}, {metric_name}={eval_stats:.4f}")
                 other_time += eval_time
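Only actor thread 0 owns the real SummaryWriter (the others receive dummy_writer), so every other actor pushes its per-thread statistic into the shared eval_queue and thread 0 averages n_actors values before logging one scalar. A hypothetical stand-in for that aggregation, with made-up statistics in place of real eval results:

import queue
import numpy as np

n_actors = 4
eval_queue = queue.Queue()
for stat in (0.52, 0.47, 0.55):       # pretend actors 1..3 already pushed their eval_stat
    eval_queue.put(stat)

eval_stats = [0.50]                   # actor 0's own eval_stat
for _ in range(1, n_actors):
    eval_stats.append(eval_queue.get())
print(np.mean(eval_stats))            # the single scalar written to TensorBoard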
...
@@ -524,15 +552,24 @@ if __name__ == "__main__":
         params=params,
         tx=tx,
     )
     if args.checkpoint:
         with open(args.checkpoint, "rb") as f:
             params = flax.serialization.from_bytes(params, f.read())
             agent_state = agent_state.replace(params=params)
         print(f"loaded checkpoint from {args.checkpoint}")

     agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices)
     # print(agent.tabulate(agent_key, sample_obs))

+    if args.eval_checkpoint:
+        with open(args.eval_checkpoint, "rb") as f:
+            eval_params = flax.serialization.from_bytes(params, f.read())
+        print(f"loaded eval checkpoint from {args.eval_checkpoint}")
+    else:
+        eval_params = None
+
     @jax.jit
     def get_logits_and_value(
         params: flax.core.FrozenDict, inputs,
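Note that from_bytes restores the opponent checkpoint against the current params pytree, which acts as the structural template, so the eval checkpoint must come from a model with the same architecture and embedding sizes. A minimal round-trip sketch with a toy pytree (not the repo's agent):

import flax
import jax.numpy as jnp

params = {'dense': {'kernel': jnp.ones((2, 2)), 'bias': jnp.zeros((2,))}}
blob = flax.serialization.to_bytes(params)               # what a checkpoint file stores
restored = flax.serialization.from_bytes(params, blob)   # first argument is the target structure
assert jnp.allclose(restored['dense']['kernel'], params['dense']['kernel'])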
...
@@ -711,7 +748,7 @@ if __name__ == "__main__":
     params_queues = []
     rollout_queues = []
-    stats_queues = queue.Queue()
+    eval_queues = queue.Queue()
     dummy_writer = SimpleNamespace()
     dummy_writer.add_scalar = lambda x, y, z: None
...
@@ -721,7 +758,9 @@ if __name__ == "__main__":
         for thread_id in range(args.num_actor_threads):
             params_queues.append(queue.Queue(maxsize=1))
             rollout_queues.append(queue.Queue(maxsize=1))
-            params_queues[-1].put(device_params)
+            if eval_params:
+                params_queues[-1].put(jax.device_put(eval_params, local_devices[d_id]))
             threading.Thread(
                 target=rollout,
                 args=(
...
@@ -729,12 +768,13 @@ if __name__ == "__main__":
                     args,
                     rollout_queues[-1],
                     params_queues[-1],
-                    stats_queues,
+                    eval_queues,
                     writer if d_idx == 0 and thread_id == 0 else dummy_writer,
                     learner_devices,
                     d_idx * args.num_actor_threads + thread_id,
                 ),
             ).start()
+            params_queues[-1].put(device_params)

     rollout_queue_get_time = deque(maxlen=10)
     data_transfer_time = deque(maxlen=10)
...
ygoai/rl/jax/eval.py
View file @ 34f86ae4
 import numpy as np

-def evaluate(envs, act_fn, params, rnn_state=None):
-    num_episodes = envs.num_envs
+def evaluate(envs, num_episodes, predict_fn, rnn_state=None):
     episode_lengths = []
     episode_rewards = []
-    eval_win_rates = []
+    win_rates = []
     obs = envs.reset()[0]
     collected = np.zeros((num_episodes,), dtype=np.bool_)
     while True:
         if rnn_state is None:
-            actions = act_fn(params, obs)
+            actions = predict_fn(obs)
         else:
-            rnn_state, actions = act_fn(params, (rnn_state, obs))
+            rnn_state, actions = predict_fn((rnn_state, obs))
         actions = np.array(actions)
         obs, rewards, dones, info = envs.step(actions)
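The signature change moves both the episode budget and the parameter binding out of evaluate: callers now pass num_episodes explicitly (instead of implying envs.num_envs) and a predict_fn closure that already carries the parameters, so the helper never touches the parameter pytree. A hedged sketch of adapting a call site, with a dummy policy standing in for the real jitted get_action:

import numpy as np

params = {'w': np.zeros(4)}                       # placeholder parameter pytree

def get_action(params, obs):                      # placeholder for the jitted policy
    return np.zeros(len(obs), dtype=np.int64)

# Old style: evaluate(envs, get_action, params), with num_episodes taken from envs.num_envs.
# New style: bind the params into a closure and pass the episode budget explicitly.
predict_fn = lambda obs: get_action(params, obs)
# eval_return, eval_ep_len, eval_win_rate = evaluate(envs, 32, predict_fn)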
...
@@ -27,11 +26,54 @@ def evaluate(envs, act_fn, params, rnn_state=None):
             episode_lengths.append(episode_length)
             episode_rewards.append(episode_reward)
-            eval_win_rates.append(win)
+            win_rates.append(win)
         if len(episode_lengths) >= num_episodes:
             break
     eval_return = np.mean(episode_rewards[:num_episodes])
     eval_ep_len = np.mean(episode_lengths[:num_episodes])
-    eval_win_rate = np.mean(eval_win_rates[:num_episodes])
+    eval_win_rate = np.mean(win_rates[:num_episodes])
-    return eval_return, eval_ep_len, eval_win_rate
\ No newline at end of file
+    return eval_return, eval_ep_len, eval_win_rate
+
+
+def battle(envs, num_episodes, predict_fn, init_rnn_state=None):
+    num_envs = envs.num_envs
+    episode_rewards = []
+    episode_lengths = []
+    win_rates = []
+    obs, infos = envs.reset()
+    next_to_play = infos['to_play']
+    dones = np.zeros(num_envs, dtype=np.bool_)
+    main_player = np.concatenate([
+        np.zeros(num_envs // 2, dtype=np.int64),
+        np.ones(num_envs - num_envs // 2, dtype=np.int64)
+    ])
+    rstate1 = rstate2 = init_rnn_state
+    while True:
+        main = next_to_play == main_player
+        rstate1, rstate2, actions = predict_fn(rstate1, rstate2, obs, main, dones)
+        actions = np.array(actions)
+        obs, rewards, dones, infos = envs.step(actions)
+        next_to_play = infos['to_play']
+        for idx, d in enumerate(dones):
+            if not d:
+                continue
+            episode_length = infos['l'][idx]
+            episode_reward = infos['r'][idx] * (1 if main[idx] else -1)
+            win = 1 if episode_reward > 0 else 0
+            episode_lengths.append(episode_length)
+            episode_rewards.append(episode_reward)
+            win_rates.append(win)
+        if len(episode_lengths) >= num_episodes:
+            break
+    eval_return = np.mean(episode_rewards[:num_episodes])
+    eval_ep_len = np.mean(episode_lengths[:num_episodes])
+    eval_win_rate = np.mean(win_rates[:num_episodes])
+    return eval_return, eval_ep_len, eval_win_rate
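battle assigns the evaluated parameters to player 0 in half of the environments and to player 1 in the other half via main_player, which removes first-move bias from the reported win rate. As written above, infos['r'] is read as the reward of the player who made the final move, so its sign is flipped whenever that player was not the main one, and a win is counted when the resulting reward is positive. A small numpy illustration of that convention (made-up rewards, not env output):

import numpy as np

# Made-up terminal rewards, each from the perspective of the player who moved last.
final_rewards = np.array([1.0, -1.0, 1.0, 1.0, -1.0, 1.0, -1.0, 1.0])
# Whether the evaluated (main) player was the one who made that final move.
main_moved_last = np.array([True, True, False, True, False, True, True, False])

# Same convention as battle(): flip the sign when the opponent moved last, then count wins.
episode_rewards = np.where(main_moved_last, final_rewards, -final_rewards)
win_rate = np.mean(episode_rewards > 0)
print(win_rate)   # fraction of episodes the evaluated parameters won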