Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
78a3bc47
Commit
78a3bc47
authored
May 01, 2024
by
sbl1996@126.com
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Count global eval_stats
parent
81f63996
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
46 additions
and
28 deletions
+46
-28
scripts/ppo.py
scripts/ppo.py
+46
-28
No files found.
scripts/ppo.py
View file @
78a3bc47
...
@@ -417,6 +417,25 @@ def rollout(
...
@@ -417,6 +417,25 @@ def rollout(
sharded_data
=
jax
.
tree
.
map
(
lambda
x
:
jax
.
device_put_sharded
(
sharded_data
=
jax
.
tree
.
map
(
lambda
x
:
jax
.
device_put_sharded
(
np
.
split
(
x
,
len
(
learner_devices
)),
devices
=
learner_devices
),
np
.
split
(
x
,
len
(
learner_devices
)),
devices
=
learner_devices
),
(
init_rstate1
,
init_rstate2
,
(
next_rstate
,
next_obs
),
next_main
))
(
init_rstate1
,
init_rstate2
,
(
next_rstate
,
next_obs
),
next_main
))
if
args
.
eval_interval
and
update
%
args
.
eval_interval
==
0
:
_start
=
time
.
time
()
if
eval_mode
==
'bot'
:
predict_fn
=
lambda
x
:
get_action
(
params
,
x
)
eval_return
,
eval_ep_len
,
eval_win_rate
=
evaluate
(
eval_envs
,
args
.
local_eval_episodes
,
predict_fn
,
eval_rstate
)
else
:
predict_fn
=
lambda
*
x
:
get_action_battle
(
params
,
eval_params
,
*
x
)
eval_return
,
eval_ep_len
,
eval_win_rate
=
battle
(
eval_envs
,
args
.
local_eval_episodes
,
predict_fn
,
eval_rstate
)
if
device_thread_id
==
0
:
eval_time
=
time
.
time
()
-
_start
other_time
+=
eval_time
eval_stats
=
np
.
array
([
eval_time
,
eval_return
,
eval_win_rate
],
dtype
=
np
.
float32
)
print
(
eval_stats
)
else
:
eval_stats
=
None
learn_opponent
=
False
learn_opponent
=
False
payload
=
(
payload
=
(
global_step
,
global_step
,
...
@@ -425,6 +444,7 @@ def rollout(
...
@@ -425,6 +444,7 @@ def rollout(
*
sharded_data
,
*
sharded_data
,
np
.
mean
(
params_queue_get_time
),
np
.
mean
(
params_queue_get_time
),
learn_opponent
,
learn_opponent
,
eval_stats
,
)
)
rollout_queue
.
put
(
payload
)
rollout_queue
.
put
(
payload
)
...
@@ -451,34 +471,6 @@ def rollout(
...
@@ -451,34 +471,6 @@ def rollout(
writer
.
add_scalar
(
"charts/SPS"
,
SPS
,
global_step
)
writer
.
add_scalar
(
"charts/SPS"
,
SPS
,
global_step
)
writer
.
add_scalar
(
"charts/SPS_update"
,
SPS_update
,
global_step
)
writer
.
add_scalar
(
"charts/SPS_update"
,
SPS_update
,
global_step
)
if
args
.
eval_interval
and
update
%
args
.
eval_interval
==
0
:
# Eval with rule-based policy
_start
=
time
.
time
()
if
eval_mode
==
'bot'
:
predict_fn
=
lambda
x
:
get_action
(
params
,
x
)
eval_return
,
eval_ep_len
,
eval_win_rate
=
evaluate
(
eval_envs
,
args
.
local_eval_episodes
,
predict_fn
,
eval_rstate
)
else
:
predict_fn
=
lambda
*
x
:
get_action_battle
(
params
,
eval_params
,
*
x
)
eval_return
,
eval_ep_len
,
eval_win_rate
=
battle
(
eval_envs
,
args
.
local_eval_episodes
,
predict_fn
,
eval_rstate
)
eval_stat
=
np
.
array
([
eval_return
,
eval_win_rate
])
if
device_thread_id
!=
0
:
eval_queue
.
put
(
eval_stat
)
else
:
eval_stats
=
[]
eval_stats
.
append
(
eval_stat
)
for
_
in
range
(
1
,
n_actors
):
eval_stats
.
append
(
eval_queue
.
get
())
eval_stats
=
np
.
stack
(
eval_stats
)
eval_return
,
eval_win_rate
=
np
.
mean
(
eval_stats
,
axis
=
0
)
writer
.
add_scalar
(
f
"charts/eval_return"
,
eval_return
,
global_step
)
writer
.
add_scalar
(
f
"charts/eval_win_rate"
,
eval_win_rate
,
global_step
)
if
device_thread_id
==
0
:
eval_time
=
time
.
time
()
-
_start
other_time
+=
eval_time
print
(
f
"eval_time={eval_time:.4f}, eval_return={eval_return:.4f}, eval_win_rate={eval_win_rate:.4f}"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
args
=
tyro
.
cli
(
Args
)
args
=
tyro
.
cli
(
Args
)
...
@@ -525,6 +517,10 @@ if __name__ == "__main__":
...
@@ -525,6 +517,10 @@ if __name__ == "__main__":
for
process_index
in
range
(
args
.
world_size
)
for
process_index
in
range
(
args
.
world_size
)
for
d_id
in
args
.
learner_device_ids
for
d_id
in
args
.
learner_device_ids
]
]
global_main_devices
=
[
global_devices
[
process_index
*
len
(
local_devices
)]
for
process_index
in
range
(
args
.
world_size
)
]
print
(
"global_learner_decices"
,
global_learner_decices
)
print
(
"global_learner_decices"
,
global_learner_decices
)
args
.
global_learner_decices
=
[
str
(
item
)
for
item
in
global_learner_decices
]
args
.
global_learner_decices
=
[
str
(
item
)
for
item
in
global_learner_decices
]
args
.
actor_devices
=
[
str
(
item
)
for
item
in
actor_devices
]
args
.
actor_devices
=
[
str
(
item
)
for
item
in
actor_devices
]
...
@@ -788,6 +784,12 @@ if __name__ == "__main__":
...
@@ -788,6 +784,12 @@ if __name__ == "__main__":
approx_kl
=
jax
.
lax
.
pmean
(
approx_kl
,
axis_name
=
"local_devices"
)
.
mean
()
approx_kl
=
jax
.
lax
.
pmean
(
approx_kl
,
axis_name
=
"local_devices"
)
.
mean
()
return
agent_state
,
loss
,
pg_loss
,
v_loss
,
entropy_loss
,
approx_kl
,
key
return
agent_state
,
loss
,
pg_loss
,
v_loss
,
entropy_loss
,
approx_kl
,
key
all_reduce_value
=
jax
.
pmap
(
lambda
x
:
jax
.
lax
.
pmean
(
x
,
axis_name
=
"main_devices"
),
axis_name
=
"main_devices"
,
devices
=
global_main_devices
,
)
multi_device_update
=
jax
.
pmap
(
multi_device_update
=
jax
.
pmap
(
single_device_update
,
single_device_update
,
axis_name
=
"local_devices"
,
axis_name
=
"local_devices"
,
...
@@ -831,6 +833,7 @@ if __name__ == "__main__":
...
@@ -831,6 +833,7 @@ if __name__ == "__main__":
learner_policy_version
+=
1
learner_policy_version
+=
1
rollout_queue_get_time_start
=
time
.
time
()
rollout_queue_get_time_start
=
time
.
time
()
sharded_data_list
=
[]
sharded_data_list
=
[]
eval_stat_list
=
[]
for
d_idx
,
d_id
in
enumerate
(
args
.
actor_device_ids
):
for
d_idx
,
d_id
in
enumerate
(
args
.
actor_device_ids
):
for
thread_id
in
range
(
args
.
num_actor_threads
):
for
thread_id
in
range
(
args
.
num_actor_threads
):
(
(
...
@@ -839,8 +842,23 @@ if __name__ == "__main__":
...
@@ -839,8 +842,23 @@ if __name__ == "__main__":
*
sharded_data
,
*
sharded_data
,
avg_params_queue_get_time
,
avg_params_queue_get_time
,
learn_opponent
,
learn_opponent
,
eval_stats
,
)
=
rollout_queues
[
d_idx
*
args
.
num_actor_threads
+
thread_id
]
.
get
()
)
=
rollout_queues
[
d_idx
*
args
.
num_actor_threads
+
thread_id
]
.
get
()
sharded_data_list
.
append
(
sharded_data
)
sharded_data_list
.
append
(
sharded_data
)
if
eval_stats
is
not
None
:
eval_stat_list
.
append
(
eval_stats
)
if
update
%
args
.
eval_interval
==
0
:
eval_stats
=
np
.
mean
(
eval_stat_list
,
axis
=
0
)
print
(
eval_stats
)
eval_stats
=
jax
.
device_put
(
eval_stats
,
local_devices
[
0
])
eval_stats
=
np
.
array
(
all_reduce_value
(
eval_stats
[
None
])[
0
])
eval_time
,
eval_return
,
eval_win_rate
=
eval_stats
writer
.
add_scalar
(
f
"charts/eval_return"
,
eval_return
,
global_step
)
writer
.
add_scalar
(
f
"charts/eval_win_rate"
,
eval_win_rate
,
global_step
)
print
(
f
"eval_time={eval_time:.4f}, eval_return={eval_return:.4f}, eval_win_rate={eval_win_rate:.4f}"
)
rollout_queue_get_time
.
append
(
time
.
time
()
-
rollout_queue_get_time_start
)
rollout_queue_get_time
.
append
(
time
.
time
()
-
rollout_queue_get_time_start
)
training_time_start
=
time
.
time
()
training_time_start
=
time
.
time
()
(
agent_state
,
loss
,
pg_loss
,
v_loss
,
entropy_loss
,
approx_kl
,
learner_keys
)
=
multi_device_update
(
(
agent_state
,
loss
,
pg_loss
,
v_loss
,
entropy_loss
,
approx_kl
,
learner_keys
)
=
multi_device_update
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment