ygo-agent (Biluo Shen) · Commits

Commit 34f86ae4, authored Apr 14, 2024 by sbl1996@126.com

Unify Impala and PPO

parent 9d8d4386

Showing 6 changed files with 976 additions and 95 deletions (+976 -95)
scripts/jax/battle.py      +4    -1
scripts/jax/impala.py      +35   -46
scripts/jax/impala2.py     +819  -0
scripts/jax/ppo_lstm.py    +11   -23
scripts/jax/ppo_lstm2.py   +57   -17
ygoai/rl/jax/eval.py       +50   -8
scripts/jax/battle.py

@@ -165,7 +165,10 @@ if __name__ == "__main__":
    else:
        with open(args.checkpoint2, "rb") as f:
            params2 = flax.serialization.from_bytes(params, f.read())
    params1 = jax.device_put(params1)
    params2 = jax.device_put(params2)

    @jax.jit
    def get_probs(params, rstate, obs, done):
        agent = create_agent(args)
scripts/jax/impala.py

@@ -17,6 +17,7 @@ import jax.numpy as jnp
import numpy as np
import optax
import rlax
import distrax
import tyro
from flax.training.train_state import TrainState
from rich.pretty import pprint

@@ -28,6 +29,7 @@ from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_mean, masked_norm
from ygoai.rl.jax.eval import evaluate
from ygoai.rl.jax import vtrace, upgo_return, clipped_surrogate_pg_loss

os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
@@ -40,7 +42,9 @@ class Args:
    log_frequency: int = 10
    """the logging frequency of the model performance (in terms of `updates`)"""
    save_interval: int = 100
-    """the frequency of saving the model"""
+    """the frequency of saving the model (in terms of `updates`)"""
+    checkpoint: Optional[str] = None
+    """the path to the model checkpoint to load"""

    # Algorithm specific arguments
    env_id: str = "YGOPro-v0"

@@ -78,8 +82,6 @@ class Args:
    """the discount factor gamma"""
    num_minibatches: int = 4
    """the number of mini-batches"""
-    gradient_accumulation_steps: int = 1
-    """the number of gradient accumulation steps before performing an optimization step"""
    c_clip_min: float = 0.001
    """the minimum value of the importance sampling clipping"""
    c_clip_max: float = 1.007

@@ -88,8 +90,6 @@ class Args:
    """the minimum value of the importance sampling clipping"""
    rho_clip_max: float = 1.007
    """the maximum value of the importance sampling clipping"""
    upgo: bool = False
    """whether to use UPGO for policy update"""
    ppo_clip: bool = True
    """whether to use the PPO clipping to replace V-Trace surrogate clipping"""
    clip_coef: float = 0.25
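The c_clip and rho_clip arguments above bound the importance weights used by the V-trace off-policy correction. For orientation only, here is a minimal sketch of the standard V-trace recursion (Espeholt et al., 2018); it applies only the maximum clips, whereas the project's `ygoai.rl.jax.vtrace` also takes `c_clip_min`/`rho_clip_min`, so treat this as an illustration of what these arguments control, not as the repository's implementation.

    import jax
    import jax.numpy as jnp

    def vtrace_sketch(v_tm1, v_t, rewards, discounts, rhos,
                      c_clip_max=1.0, rho_clip_max=1.0):
        # delta_t = min(rho_max, rho_t) * (r_t + gamma_t * V(x_{t+1}) - V(x_t))
        clipped_rhos = jnp.minimum(rho_clip_max, rhos)
        clipped_cs = jnp.minimum(c_clip_max, rhos)
        deltas = clipped_rhos * (rewards + discounts * v_t - v_tm1)

        # errors_t = delta_t + gamma_t * c_t * errors_{t+1}, computed backwards in time
        def step(err_tp1, x):
            delta_t, discount_t, c_t = x
            err_t = delta_t + discount_t * c_t * err_tp1
            return err_t, err_t

        _, errors = jax.lax.scan(step, jnp.zeros_like(v_t[-1]),
                                 (deltas, discounts, clipped_cs), reverse=True)
        # vs_t = V(x_t) + errors_t are the V-trace value targets
        return v_tm1 + errors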
@@ -127,7 +127,6 @@ class Args:
    # runtime arguments to be filled in
    local_batch_size: int = 0
    local_minibatch_size: int = 0
    num_updates: int = 0
    world_size: int = 0
    local_rank: int = 0
    num_envs: int = 0
@@ -218,34 +217,28 @@ def rollout(
    avg_win_rates = deque(maxlen=1000)

    @jax.jit
-    def apply_fn(params: flax.core.FrozenDict, next_obs):
-        logits, value, _valid = create_agent(args).apply(params, next_obs)
-        return logits, value
+    def get_logits(params: flax.core.FrozenDict, inputs):
+        logits, value = create_agent(args).apply(params, inputs)[:2]
+        return logits

-    def get_action(params: flax.core.FrozenDict, next_obs):
-        return apply_fn(params, next_obs)[0].argmax(axis=1)
+    def get_action(params: flax.core.FrozenDict, inputs):
+        return get_logits(params, inputs).argmax(axis=1)

    @jax.jit
    def sample_action(
        params: flax.core.FrozenDict,
-        next_obs, key: jax.random.PRNGKey,
-    ):
+        next_obs, key: jax.random.PRNGKey):
        next_obs = jax.tree.map(lambda x: jnp.array(x), next_obs)
-        logits = apply_fn(params, next_obs)[0]
+        logits = get_logits(params, next_obs)
        # sample action: Gumbel-softmax trick
        # see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution
        key, subkey = jax.random.split(key)
        u = jax.random.uniform(subkey, shape=logits.shape)
        action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=1)
        return next_obs, action, logits, key

    # put data in the last index
    envs.async_reset()
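The `sample_action` body relies on the Gumbel-max trick referenced in its comment: adding independent Gumbel noise `-log(-log(u))` to the logits and taking the argmax draws exact samples from `softmax(logits)`. A small self-contained check of that identity (illustrative only, not part of the repository):

    import jax
    import jax.numpy as jnp

    logits = jnp.array([2.0, 0.5, -1.0])
    key = jax.random.PRNGKey(0)

    u = jax.random.uniform(key, shape=(100_000,) + logits.shape)
    samples = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=-1)

    empirical = jnp.bincount(samples, length=logits.shape[0]) / samples.shape[0]
    print(empirical)               # empirical frequencies...
    print(jax.nn.softmax(logits))  # ...approximately match the softmax probabilities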
@@ -253,13 +246,13 @@ def rollout(
    rollout_time = deque(maxlen=10)
    actor_policy_version = 0
    storage = []
-    ai_player1 = np.concatenate([
+    main_player = np.concatenate([
        np.zeros(args.local_num_envs // 2, dtype=np.int64),
        np.ones(args.local_num_envs // 2, dtype=np.int64)
    ])
-    np.random.shuffle(ai_player1)
+    np.random.shuffle(main_player)
    next_to_play = None
-    learn = np.ones(args.local_num_envs, dtype=np.bool_)
+    main = np.ones(args.local_num_envs, dtype=np.bool_)

    @jax.jit
    def prepare_data(storage: List[Transition]) -> Transition:
@@ -274,8 +267,7 @@ def rollout(
        inference_time = 0
        env_time = 0
        num_steps_with_bootstrap = (
            # num_steps + 1 to get the states for value bootstrapping.
            args.num_steps + int(len(storage) == 0))

        params_queue_get_time_start = time.time()
        if args.concurrency:
            if update != 2:
@@ -295,11 +287,11 @@ def rollout(
                _start = time.time()
                next_obs, next_reward, next_done, info = envs.recv()
-                next_reward = np.where(learn, next_reward, -next_reward)
+                next_reward = np.where(main, next_reward, -next_reward)
                env_time += time.time() - _start

            to_play = next_to_play
            next_to_play = info["to_play"]
-            learn = next_to_play == ai_player1
+            main = next_to_play == main_player

            inference_time_start = time.time()
            next_obs, action, logits, key = sample_action(params, next_obs, key)
@@ -312,17 +304,17 @@ def rollout(
                Transition(
                    obs=next_obs,
                    dones=next_done,
+                    mains=main,
                    actions=action,
                    logitss=logits,
                    rewards=next_reward,
-                    learns=learn,
                )
            )

            for idx, d in enumerate(next_done):
                if not d:
                    continue
-                pl = 1 if to_play[idx] == ai_player1[idx] else -1
+                pl = 1 if to_play[idx] == main_player[idx] else -1
                episode_reward = info['r'][idx] * pl
                win = 1 if episode_reward > 0 else 0
                avg_ep_returns.append(episode_reward)
@@ -488,7 +480,7 @@ if __name__ == "__main__":
                learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5),
        ),
-        every_k_schedule=args.gradient_accumulation_steps,
+        every_k_schedule=1,
    )
    agent_state = TrainState.create(
        apply_fn=None,
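Setting `every_k_schedule=1` effectively disables gradient accumulation in the `optax.MultiSteps` wrapper, consistent with dropping the `gradient_accumulation_steps` argument earlier in this file. A minimal sketch of how that wrapper behaves, independent of this repository:

    import optax

    base_tx = optax.adam(1e-3)

    # k = 1: every mini-batch gradient is applied immediately.
    tx = optax.MultiSteps(base_tx, every_k_schedule=1)

    # k = 4: gradients are averaged over 4 update calls before a real
    # parameter update is emitted (the wrapper emits zero updates in between).
    tx_accum = optax.MultiSteps(base_tx, every_k_schedule=4)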
@@ -505,13 +497,15 @@ if __name__ == "__main__":
    def get_logits_and_value(
        params: flax.core.FrozenDict, obs: np.ndarray,
    ):
-        logits, value, valid = create_agent(args).apply(params, obs)
-        return logits, value.squeeze(-1), valid
+        logits, value = create_agent(args).apply(params, obs)
+        return logits, value.squeeze(-1)

    def impala_loss(params, obs, actions, logitss, rewards, dones, learns):
+        # (num_steps + 1, local_num_envs // n_mb))
+        num_steps = actions.shape[0] - 1
        discounts = (1.0 - dones) * args.gamma
-        policy_logits, newvalue, valid = jax.vmap(
+        policy_logits, newvalue = jax.vmap(
            get_logits_and_value, in_axes=(None, 0))(params, obs)
        newvalue = jnp.where(learns, newvalue, -newvalue)
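A toy illustration of the sign convention in the last line above (an interpretation, not code from the commit): `learns` marks the steps where the learning player acted, and on the opponent's steps the predicted value is negated, so every value is expressed from the learning player's perspective in this zero-sum game.

    import jax.numpy as jnp

    learns = jnp.array([True, False, True])        # who acted at each step (illustrative)
    newvalue = jnp.array([0.3, 0.5, -0.2])         # value predicted for the acting player
    print(jnp.where(learns, newvalue, -newvalue))  # [ 0.3 -0.5 -0.2]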
@@ -527,19 +521,14 @@ if __name__ == "__main__":
        discounts = discounts[1:]
        mask = mask[:-1]
-        rhos = rlax.categorical_importance_sampling_ratios(policy_logits, logitss, actions)
+        rhos = distrax.importance_sampling_ratios(
+            distrax.Categorical(policy_logits), distrax.Categorical(logitss), actions)
        vtrace_fn = partial(
            vtrace, c_clip_min=args.c_clip_min, c_clip_max=args.c_clip_max,
            rho_clip_min=args.rho_clip_min, rho_clip_max=args.rho_clip_max)
        vtrace_returns = jax.vmap(vtrace_fn, in_axes=1, out_axes=1)(
            v_tm1, v_t, rewards, discounts, rhos)
-        jax.debug.print("R {}", jnp.where(dones[1:-1, :2], rewards[:-1, :2], 0).T)
-        jax.debug.print("E {}", jnp.where(dones[1:-1, :2], vtrace_returns.errors[:-1, :2] * 100, vtrace_returns.errors[:-1, :2]).T)
-        jax.debug.print("V {}", v_tm1[:-1, :2].T)
-        T = v_tm1.shape[0]

        if args.upgo:
            advs = jax.vmap(upgo_return, in_axes=1, out_axes=1)(rewards, v_t, discounts) - v_tm1
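The removed `rlax` call and the new `distrax` call compute the same per-step importance ratio ρ_t = π(a_t|x_t) / μ(a_t|x_t) from the current-policy and behaviour logits; only the API changes. A small sketch of the quantity being computed, for a [T, A] logits array and [T] integer actions (illustrative, not the project's code):

    import jax
    import jax.numpy as jnp

    def importance_ratios(policy_logits, behaviour_logits, actions):
        # rho_t = pi(a_t | x_t) / mu(a_t | x_t), computed in log space for stability
        log_pi = jax.nn.log_softmax(policy_logits)
        log_mu = jax.nn.log_softmax(behaviour_logits)
        idx = jnp.arange(actions.shape[0])
        return jnp.exp(log_pi[idx, actions] - log_mu[idx, actions])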
@@ -548,13 +537,13 @@ if __name__ == "__main__":
        if args.ppo_clip:
            pg_loss = jax.vmap(
                partial(clipped_surrogate_pg_loss, epsilon=args.clip_coef), in_axes=1)(
-                rhos, advs, mask) * T
+                rhos, advs, mask) * num_steps
            pg_loss = jnp.sum(pg_loss)
        else:
            pg_advs = jnp.minimum(args.rho_clip_max, rhos) * advs
            pg_loss = jax.vmap(rlax.policy_gradient_loss, in_axes=1)(
-                policy_logits, actions, pg_advs, mask) * T
+                policy_logits, actions, pg_advs, mask) * num_steps
            pg_loss = jnp.sum(pg_loss)

        baseline_loss = 0.5 * jnp.sum(jnp.square(vtrace_returns.errors) * mask)
scripts/jax/impala2.py

new file (0 → 100644); diff collapsed (+819 lines, not shown)
scripts/jax/ppo_lstm.py

@@ -23,7 +23,7 @@ from tensorboardX import SummaryWriter
from ygoai.utils import init_ygopro
from ygoai.rl.jax.agent2 import PPOLSTMAgent
-from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_mean, masked_normalize
+from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_mean, masked_normalize, categorical_sample
from ygoai.rl.jax.eval import evaluate
from ygoai.rl.jax import compute_gae_upgo_2p0s, compute_gae_2p0s
@@ -255,11 +255,7 @@ def rollout(
        rstate1 = jax.tree.map(
            lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
        rstate2 = jax.tree.map(
            lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
-        # sample action: Gumbel-softmax trick
-        # see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution
-        key, subkey = jax.random.split(key)
-        u = jax.random.uniform(subkey, shape=logits.shape)
-        action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=1)
+        action, key = categorical_sample(logits, key)
        logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action]
        logits = logits - jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
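The inline Gumbel-max sampling above is factored out into `ygoai.rl.jax.utils.categorical_sample`. Its implementation is not shown in this diff, but given the code it replaces and its call sites it presumably does something like the following sketch (not the repository's definition):

    import jax
    import jax.numpy as jnp

    def categorical_sample(logits, key):
        # Gumbel-max sampling from softmax(logits); returns the action and the next key
        key, subkey = jax.random.split(key)
        u = jax.random.uniform(subkey, shape=logits.shape)
        action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=-1)
        return action, key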
@@ -329,7 +325,6 @@ def rollout(
            inference_time += time.time() - inference_time_start

            _start = time.time()
-            to_play = next_to_play
            next_obs, next_reward, next_done, info = envs.step(cpu_action)
            next_to_play = info["to_play"]
            env_time += time.time() - _start
@@ -338,11 +333,11 @@ def rollout(
                Transition(
                    obs=cached_next_obs,
                    dones=cached_next_done,
                    mains=main,
                    actions=action,
                    logprobs=logprob,
                    probs=probs,
                    rewards=next_reward,
                )
            )
@@ -359,8 +354,7 @@ def rollout(
                        t.dones[idx] = True
                        t.rewards[idx] = -next_reward[idx]
                        break
-                pl = 1 if to_play[idx] == main_player[idx] else -1
-                episode_reward = info['r'][idx] * pl
+                episode_reward = info['r'][idx] * (1 if cur_main else -1)
                win = 1 if episode_reward > 0 else 0
                avg_ep_returns.append(episode_reward)
                avg_win_rates.append(win)
@@ -387,16 +381,14 @@ def rollout(
            lambda x1, x2: jnp.where(next_main[:, None], x1, x2), next_rstate1, next_rstate2)
        sharded_data = jax.tree.map(
            lambda x: jax.device_put_sharded(np.split(x, len(learner_devices)), devices=learner_devices),
-            (next_obs, next_rstate, init_rstate1, init_rstate2, next_done, next_main))
+            (init_rstate1, init_rstate2, (next_rstate, next_obs), next_done, next_main))
        learn_opponent = False
        payload = (
            global_step,
            actor_policy_version,
            update,
            sharded_storage,
            *sharded_data,
            np.mean(params_queue_get_time),
            device_thread_id,
            learn_opponent,
        )
        rollout_queue.put(payload)
@@ -589,7 +581,6 @@ if __name__ == "__main__":
            pg_loss = jnp.maximum(pg_loss1, pg_loss2)
            pg_loss = masked_mean(pg_loss, valid)

            # Value loss
            v_loss = 0.5 * ((newvalue - target_values) ** 2)
            v_loss = masked_mean(v_loss, valid)
@@ -600,10 +591,9 @@ if __name__ == "__main__":
    def single_device_update(
        agent_state: TrainState,
        sharded_storages: List,
-        sharded_next_obs: List,
-        sharded_next_rstate: List,
        sharded_init_rstate1: List,
        sharded_init_rstate2: List,
+        sharded_next_inputs: List,
        sharded_next_done: List,
        sharded_next_main: List,
        key: jax.random.PRNGKey,
@@ -620,9 +610,9 @@ if __name__ == "__main__":
            return x

        storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
-        next_obs, next_rstate, init_rstate1, init_rstate2 = [
+        next_inputs, init_rstate1, init_rstate2 = [
            jax.tree.map(lambda *x: jnp.concatenate(x), *x)
-            for x in [sharded_next_obs, sharded_next_rstate, sharded_init_rstate1, sharded_init_rstate2]
+            for x in [sharded_next_inputs, sharded_init_rstate1, sharded_init_rstate2]
        ]
        next_done, next_main = [jnp.concatenate(x) for x in [sharded_next_done, sharded_next_main]]
@@ -680,7 +670,7 @@ if __name__ == "__main__":
        values = values.reshape(storage.rewards.shape)
-        next_value = create_agent(args).apply(
-            agent_state.params, (next_rstate, next_obs))[2].squeeze(-1)
+        next_value = create_agent(args).apply(
+            agent_state.params, next_inputs)[2].squeeze(-1)
        # TODO: check if this is correct
        sign = jnp.where(switch_steps <= num_steps, 1.0, -1.0)
        next_value = jnp.where(next_main, -sign * next_value, sign * next_value)
@@ -745,7 +735,7 @@ if __name__ == "__main__":
        single_device_update,
        axis_name="local_devices",
        devices=global_learner_decices,
-        static_broadcasted_argnums=(9,),
+        static_broadcasted_argnums=(8,),
    )

    params_queues = []
@@ -786,11 +776,9 @@ if __name__ == "__main__":
            for thread_id in range(args.num_actor_threads):
                (
                    global_step,
                    actor_policy_version,
                    update,
                    *sharded_data,
                    avg_params_queue_get_time,
                    device_thread_id,
                    learn_opponent,
                ) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get()
                sharded_data_list.append(sharded_data)
scripts/jax/ppo_lstm2.py

@@ -25,7 +25,7 @@ from tensorboardX import SummaryWriter
from ygoai.utils import init_ygopro
from ygoai.rl.jax.agent2 import PPOLSTMAgent
from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
-from ygoai.rl.jax.eval import evaluate
+from ygoai.rl.jax.eval import evaluate, battle
from ygoai.rl.jax import compute_gae_upgo_2p0s, compute_gae_2p0s
@@ -122,6 +122,8 @@ class Args:
    thread_affinity: bool = False
    """whether to use thread affinity for the environment"""
+    eval_checkpoint: Optional[str] = None
+    """the path to the model checkpoint to evaluate"""

    local_eval_episodes: int = 32
    """the number of episodes to evaluate the model"""
    eval_interval: int = 50
@@ -198,12 +200,16 @@ def rollout(
    key: jax.random.PRNGKey,
    args: Args,
    rollout_queue,
-    params_queue: queue.Queue,
-    stats_queue,
+    params_queue,
+    eval_queue,
    writer,
    learner_devices,
    device_thread_id,
):
+    eval_mode = 'self' if args.eval_checkpoint else 'bot'
+    if eval_mode != 'bot':
+        eval_params = params_queue.get()

    envs = make_env(
        args,
        args.seed + jax.process_index() + device_thread_id,
@@ -217,7 +223,7 @@ def rollout(
        args,
        args.seed + jax.process_index() + device_thread_id,
        args.local_eval_episodes,
-        args.local_eval_episodes // 4, mode='bot')
+        args.local_eval_episodes // 4, mode=eval_mode)
    eval_envs = RecordEpisodeStatistics(eval_envs)

    len_actor_device_ids = len(args.actor_device_ids)
@@ -244,11 +250,23 @@ def rollout(
        rstate, logits = get_logits(params, inputs, done)
        return rstate, logits.argmax(axis=1)

+    @jax.jit
+    def get_action_battle(params1, params2, rstate1, rstate2, obs, main, done):
+        next_rstate1, logits1 = get_logits(params1, (rstate1, obs), done)
+        next_rstate2, logits2 = get_logits(params2, (rstate2, obs), done)
+        logits = jnp.where(main[:, None], logits1, logits2)
+        rstate1 = jax.tree.map(
+            lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
+        rstate2 = jax.tree.map(
+            lambda x1, x2: jnp.where(main[:, None], x2, x1), next_rstate2, rstate2)
+        return rstate1, rstate2, logits.argmax(axis=1)

    @jax.jit
    def sample_action(
        params: flax.core.FrozenDict,
        next_obs, rstate1, rstate2, main, done, key):
        next_obs = jax.tree.map(lambda x: jnp.array(x), next_obs)
+        done = jnp.array(done)
+        main = jnp.array(main)
        rstate = jax.tree.map(
            lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
@@ -257,7 +275,7 @@ def rollout(
        rstate2 = jax.tree.map(
            lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
        action, key = categorical_sample(logits, key)
-        return next_obs, rstate1, rstate2, action, logits, key
+        return next_obs, done, main, rstate1, rstate2, action, logits, key

    # put data in the last index
    params_queue_get_time = deque(maxlen=10)
@@ -314,7 +332,8 @@ def rollout(
            main = next_to_play == main_player

            inference_time_start = time.time()
-            cached_next_obs, next_rstate1, next_rstate2, action, logits, key = sample_action(
+            cached_next_obs, cached_next_done, cached_main, \
+                next_rstate1, next_rstate2, action, logits, key = sample_action(
                params, cached_next_obs, next_rstate1, next_rstate2, main, cached_next_done, key)
            cpu_action = np.array(action)
@@ -329,7 +348,7 @@ def rollout(
                Transition(
                    obs=cached_next_obs,
                    dones=cached_next_done,
-                    mains=main,
+                    mains=cached_main,
                    actions=action,
                    logits=logits,
                    rewards=next_reward,
@@ -412,19 +431,28 @@ def rollout(
        if args.eval_interval and update % args.eval_interval == 0:
            # Eval with rule-based policy
            _start = time.time()
-            eval_return = evaluate(eval_envs, get_action, params, eval_rstate)[0]
+            if eval_mode == 'bot':
+                predict_fn = lambda x: get_action(params, x)
+                eval_stat = evaluate(eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)[0]
+                metric_name = "eval_return"
+            else:
+                predict_fn = lambda *x: get_action_battle(params, eval_params, *x)
+                eval_stat = battle(eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)[2]
+                metric_name = "eval_win_rate"
            if device_thread_id != 0:
-                stats_queue.put(eval_return)
+                eval_queue.put(eval_stat)
            else:
                eval_stats = []
-                eval_stats.append(eval_return)
+                eval_stats.append(eval_stat)
                for _ in range(1, n_actors):
-                    eval_stats.append(stats_queue.get())
+                    eval_stats.append(eval_queue.get())
                eval_stats = np.mean(eval_stats)
-                writer.add_scalar("charts/eval_return", eval_stats, global_step)
+                writer.add_scalar(f"charts/{metric_name}", eval_stats, global_step)

            if device_thread_id == 0:
                eval_time = time.time() - _start
-                print(f"eval_time={eval_time:.4f}, eval_ep_return={eval_stats:.4f}")
+                print(f"eval_time={eval_time:.4f}, {metric_name}={eval_stats:.4f}")
                other_time += eval_time
@@ -524,15 +552,24 @@ if __name__ == "__main__":
        params=params,
        tx=tx,
    )
    if args.checkpoint:
        with open(args.checkpoint, "rb") as f:
            params = flax.serialization.from_bytes(params, f.read())
            agent_state = agent_state.replace(params=params)
        print(f"loaded checkpoint from {args.checkpoint}")

    agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices)
    # print(agent.tabulate(agent_key, sample_obs))

+    if args.eval_checkpoint:
+        with open(args.eval_checkpoint, "rb") as f:
+            eval_params = flax.serialization.from_bytes(params, f.read())
+        print(f"loaded eval checkpoint from {args.eval_checkpoint}")
+    else:
+        eval_params = None

    @jax.jit
    def get_logits_and_value(
        params: flax.core.FrozenDict, inputs,
@@ -711,7 +748,7 @@ if __name__ == "__main__":
    params_queues = []
    rollout_queues = []
-    stats_queues = queue.Queue()
+    eval_queues = queue.Queue()
    dummy_writer = SimpleNamespace()
    dummy_writer.add_scalar = lambda x, y, z: None
@@ -721,7 +758,9 @@ if __name__ == "__main__":
        for thread_id in range(args.num_actor_threads):
            params_queues.append(queue.Queue(maxsize=1))
            rollout_queues.append(queue.Queue(maxsize=1))
-            params_queues[-1].put(device_params)
+            if eval_params:
+                params_queues[-1].put(jax.device_put(eval_params, local_devices[d_id]))
            threading.Thread(
                target=rollout,
                args=(
@@ -729,12 +768,13 @@ if __name__ == "__main__":
                    args,
                    rollout_queues[-1],
                    params_queues[-1],
-                    stats_queues,
+                    eval_queues,
                    writer if d_idx == 0 and thread_id == 0 else dummy_writer,
                    learner_devices,
                    d_idx * args.num_actor_threads + thread_id,
                ),
            ).start()
+            params_queues[-1].put(device_params)

    rollout_queue_get_time = deque(maxlen=10)
    data_transfer_time = deque(maxlen=10)
ygoai/rl/jax/eval.py

import numpy as np


-def evaluate(envs, act_fn, params, rnn_state=None):
-    num_episodes = envs.num_envs
+def evaluate(envs, num_episodes, predict_fn, rnn_state=None):
    episode_lengths = []
    episode_rewards = []
-    eval_win_rates = []
+    win_rates = []
    obs = envs.reset()[0]
    collected = np.zeros((num_episodes,), dtype=np.bool_)
    while True:
        if rnn_state is None:
-            actions = act_fn(params, obs)
+            actions = predict_fn(obs)
        else:
-            rnn_state, actions = act_fn(params, (rnn_state, obs))
+            rnn_state, actions = predict_fn((rnn_state, obs))
        actions = np.array(actions)
        obs, rewards, dones, info = envs.step(actions)

@@ -27,11 +26,54 @@ def evaluate(envs, act_fn, params, rnn_state=None):
            episode_lengths.append(episode_length)
            episode_rewards.append(episode_reward)
-            eval_win_rates.append(win)
+            win_rates.append(win)
        if len(episode_lengths) >= num_episodes:
            break

    eval_return = np.mean(episode_rewards[:num_episodes])
    eval_ep_len = np.mean(episode_lengths[:num_episodes])
-    eval_win_rate = np.mean(eval_win_rates[:num_episodes])
+    eval_win_rate = np.mean(win_rates[:num_episodes])
    return eval_return, eval_ep_len, eval_win_rate


+def battle(envs, num_episodes, predict_fn, init_rnn_state=None):
+    num_envs = envs.num_envs
+    episode_rewards = []
+    episode_lengths = []
+    win_rates = []
+    obs, infos = envs.reset()
+    next_to_play = infos['to_play']
+    dones = np.zeros(num_envs, dtype=np.bool_)
+    main_player = np.concatenate([
+        np.zeros(num_envs // 2, dtype=np.int64),
+        np.ones(num_envs - num_envs // 2, dtype=np.int64)
+    ])
+    rstate1 = rstate2 = init_rnn_state
+    while True:
+        main = next_to_play == main_player
+        rstate1, rstate2, actions = predict_fn(rstate1, rstate2, obs, main, dones)
+        actions = np.array(actions)
+        obs, rewards, dones, infos = envs.step(actions)
+        next_to_play = infos['to_play']
+        for idx, d in enumerate(dones):
+            if not d:
+                continue
+            episode_length = infos['l'][idx]
+            episode_reward = infos['r'][idx] * (1 if main[idx] else -1)
+            win = 1 if episode_reward > 0 else 0
+            episode_lengths.append(episode_length)
+            episode_rewards.append(episode_reward)
+            win_rates.append(win)
+        if len(episode_lengths) >= num_episodes:
+            break
+
+    eval_return = np.mean(episode_rewards[:num_episodes])
+    eval_ep_len = np.mean(episode_lengths[:num_episodes])
+    eval_win_rate = np.mean(win_rates[:num_episodes])
+    return eval_return, eval_ep_len, eval_win_rate
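As wired up in scripts/jax/ppo_lstm2.py above, `battle` is driven by a predict function that routes each environment to one of two parameter sets depending on whose turn it is. A condensed usage sketch, reusing the names from that script (argument values are illustrative):

    # battle() calls predict_fn with (rstate1, rstate2, obs, main, dones) each step
    predict_fn = lambda *x: get_action_battle(params, eval_params, *x)

    # eval_rstate is the initial recurrent state for both players (None for feed-forward agents)
    eval_return, eval_ep_len, eval_win_rate = battle(
        eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)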