Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
6c7ac38a
Commit
6c7ac38a
authored
May 01, 2024
by
sbl1996@126.com
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Count global eval_stats in impala
parent
6c08e7d2
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
42 additions
and
34 deletions
+42
-34
scripts/impala.py
scripts/impala.py
+42
-31
scripts/ppo.py
scripts/ppo.py
+0
-3
No files found.
scripts/impala.py
View file @
6c7ac38a
...
...
@@ -220,7 +220,6 @@ def rollout(
args
:
Args
,
rollout_queue
,
params_queue
,
eval_queue
,
writer
,
learner_devices
,
device_thread_id
,
...
...
@@ -407,6 +406,23 @@ def rollout(
sharded_data
=
jax
.
tree
.
map
(
lambda
x
:
jax
.
device_put_sharded
(
np
.
split
(
x
,
len
(
learner_devices
)),
devices
=
learner_devices
),
(
init_rstate1
,
init_rstate2
,
(
next_rstate
,
next_obs
),
next_done
,
next_main
))
if
args
.
eval_interval
and
update
%
args
.
eval_interval
==
0
:
_start
=
time
.
time
()
if
eval_mode
==
'bot'
:
predict_fn
=
lambda
x
:
get_action
(
params
,
x
)
eval_return
,
eval_ep_len
,
eval_win_rate
=
evaluate
(
eval_envs
,
args
.
local_eval_episodes
,
predict_fn
,
eval_rstate
)
else
:
predict_fn
=
lambda
*
x
:
get_action_battle
(
params
,
eval_params
,
*
x
)
eval_return
,
eval_ep_len
,
eval_win_rate
=
battle
(
eval_envs
,
args
.
local_eval_episodes
,
predict_fn
,
eval_rstate
)
eval_time
=
time
.
time
()
-
_start
other_time
+=
eval_time
eval_stats
=
np
.
array
([
eval_time
,
eval_return
,
eval_win_rate
],
dtype
=
np
.
float32
)
else
:
eval_stats
=
None
learn_opponent
=
False
payload
=
(
global_step
,
...
...
@@ -415,6 +431,7 @@ def rollout(
*
sharded_data
,
np
.
mean
(
params_queue_get_time
),
learn_opponent
,
eval_stats
,
)
rollout_queue
.
put
(
payload
)
...
...
@@ -441,34 +458,6 @@ def rollout(
writer
.
add_scalar
(
"charts/SPS"
,
SPS
,
global_step
)
writer
.
add_scalar
(
"charts/SPS_update"
,
SPS_update
,
global_step
)
if
args
.
eval_interval
and
update
%
args
.
eval_interval
==
0
:
# Eval with rule-based policy
_start
=
time
.
time
()
if
eval_mode
==
'bot'
:
predict_fn
=
lambda
x
:
get_action
(
params
,
x
)
eval_return
,
eval_ep_len
,
eval_win_rate
=
evaluate
(
eval_envs
,
args
.
local_eval_episodes
,
predict_fn
,
eval_rstate
)
else
:
predict_fn
=
lambda
*
x
:
get_action_battle
(
params
,
eval_params
,
*
x
)
eval_return
,
eval_ep_len
,
eval_win_rate
=
battle
(
eval_envs
,
args
.
local_eval_episodes
,
predict_fn
,
eval_rstate
)
eval_stat
=
np
.
array
([
eval_return
,
eval_win_rate
])
if
device_thread_id
!=
0
:
eval_queue
.
put
(
eval_stat
)
else
:
eval_stats
=
[]
eval_stats
.
append
(
eval_stat
)
for
_
in
range
(
1
,
n_actors
):
eval_stats
.
append
(
eval_queue
.
get
())
eval_stats
=
np
.
stack
(
eval_stats
)
eval_return
,
eval_win_rate
=
np
.
mean
(
eval_stats
,
axis
=
0
)
writer
.
add_scalar
(
f
"charts/eval_return"
,
eval_return
,
global_step
)
writer
.
add_scalar
(
f
"charts/eval_win_rate"
,
eval_win_rate
,
global_step
)
if
device_thread_id
==
0
:
eval_time
=
time
.
time
()
-
_start
other_time
+=
eval_time
print
(
f
"eval_time={eval_time:.4f}, eval_return={eval_return:.4f}, eval_win_rate={eval_win_rate:.4f}"
)
if
__name__
==
"__main__"
:
args
=
tyro
.
cli
(
Args
)
...
...
@@ -515,6 +504,10 @@ if __name__ == "__main__":
for
process_index
in
range
(
args
.
world_size
)
for
d_id
in
args
.
learner_device_ids
]
global_main_devices
=
[
global_devices
[
process_index
*
len
(
local_devices
)]
for
process_index
in
range
(
args
.
world_size
)
]
print
(
"global_learner_decices"
,
global_learner_decices
)
args
.
global_learner_decices
=
[
str
(
item
)
for
item
in
global_learner_decices
]
args
.
actor_devices
=
[
str
(
item
)
for
item
in
actor_devices
]
...
...
@@ -762,6 +755,12 @@ if __name__ == "__main__":
approx_kl
=
jax
.
lax
.
pmean
(
approx_kl
,
axis_name
=
"local_devices"
)
.
mean
()
return
agent_state
,
loss
,
pg_loss
,
v_loss
,
entropy_loss
,
approx_kl
,
key
all_reduce_value
=
jax
.
pmap
(
lambda
x
:
jax
.
lax
.
pmean
(
x
,
axis_name
=
"main_devices"
),
axis_name
=
"main_devices"
,
devices
=
global_main_devices
,
)
multi_device_update
=
jax
.
pmap
(
single_device_update
,
axis_name
=
"local_devices"
,
...
...
@@ -771,7 +770,6 @@ if __name__ == "__main__":
params_queues
=
[]
rollout_queues
=
[]
eval_queue
=
queue
.
Queue
()
unreplicated_params
=
flax
.
jax_utils
.
unreplicate
(
agent_state
.
params
)
for
d_idx
,
d_id
in
enumerate
(
args
.
actor_device_ids
):
...
...
@@ -790,7 +788,6 @@ if __name__ == "__main__":
args
,
rollout_queues
[
-
1
],
params_queues
[
-
1
],
eval_queue
,
writer
if
d_idx
==
0
and
thread_id
==
0
else
dummy_writer
,
learner_devices
,
actor_thread_id
,
...
...
@@ -805,6 +802,7 @@ if __name__ == "__main__":
learner_policy_version
+=
1
rollout_queue_get_time_start
=
time
.
time
()
sharded_data_list
=
[]
eval_stat_list
=
[]
for
d_idx
,
d_id
in
enumerate
(
args
.
actor_device_ids
):
for
thread_id
in
range
(
args
.
num_actor_threads
):
(
...
...
@@ -813,8 +811,21 @@ if __name__ == "__main__":
*
sharded_data
,
avg_params_queue_get_time
,
learn_opponent
,
eval_stats
,
)
=
rollout_queues
[
d_idx
*
args
.
num_actor_threads
+
thread_id
]
.
get
()
sharded_data_list
.
append
(
sharded_data
)
if
eval_stats
is
not
None
:
eval_stat_list
.
append
(
eval_stats
)
if
update
%
args
.
eval_interval
==
0
:
eval_stats
=
np
.
mean
(
eval_stat_list
,
axis
=
0
)
eval_stats
=
jax
.
device_put
(
eval_stats
,
local_devices
[
0
])
eval_stats
=
np
.
array
(
all_reduce_value
(
eval_stats
[
None
])[
0
])
eval_time
,
eval_return
,
eval_win_rate
=
eval_stats
writer
.
add_scalar
(
f
"charts/eval_return"
,
eval_return
,
global_step
)
writer
.
add_scalar
(
f
"charts/eval_win_rate"
,
eval_win_rate
,
global_step
)
print
(
f
"eval_time={eval_time:.4f}, eval_return={eval_return:.4f}, eval_win_rate={eval_win_rate:.4f}"
)
rollout_queue_get_time
.
append
(
time
.
time
()
-
rollout_queue_get_time_start
)
training_time_start
=
time
.
time
()
(
agent_state
,
loss
,
pg_loss
,
v_loss
,
entropy_loss
,
approx_kl
,
learner_keys
)
=
multi_device_update
(
...
...
scripts/ppo.py
View file @
6c7ac38a
...
...
@@ -430,7 +430,6 @@ def rollout(
eval_time
=
time
.
time
()
-
_start
other_time
+=
eval_time
eval_stats
=
np
.
array
([
eval_time
,
eval_return
,
eval_win_rate
],
dtype
=
np
.
float32
)
print
(
eval_stats
)
else
:
eval_stats
=
None
...
...
@@ -846,7 +845,6 @@ if __name__ == "__main__":
if
update
%
args
.
eval_interval
==
0
:
eval_stats
=
np
.
mean
(
eval_stat_list
,
axis
=
0
)
print
(
eval_stats
)
eval_stats
=
jax
.
device_put
(
eval_stats
,
local_devices
[
0
])
eval_stats
=
np
.
array
(
all_reduce_value
(
eval_stats
[
None
])[
0
])
eval_time
,
eval_return
,
eval_win_rate
=
eval_stats
...
...
@@ -854,7 +852,6 @@ if __name__ == "__main__":
writer
.
add_scalar
(
f
"charts/eval_win_rate"
,
eval_win_rate
,
global_step
)
print
(
f
"eval_time={eval_time:.4f}, eval_return={eval_return:.4f}, eval_win_rate={eval_win_rate:.4f}"
)
rollout_queue_get_time
.
append
(
time
.
time
()
-
rollout_queue_get_time_start
)
training_time_start
=
time
.
time
()
(
agent_state
,
loss
,
pg_loss
,
v_loss
,
entropy_loss
,
approx_kl
,
learner_keys
)
=
multi_device_update
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment