Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
45885f81
Commit
45885f81
authored
Apr 28, 2024
by
Biluo Shen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add apply_if_finite in opt
parent
7f80e9a8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
44 additions
and
28 deletions
+44
-28
scripts/impala.py
scripts/impala.py
+42
-26
scripts/ppo.py
scripts/ppo.py
+2
-2
No files found.
scripts/impala.py
View file @
45885f81
...
...
@@ -23,6 +23,7 @@ from rich.pretty import pprint
from
tensorboardX
import
SummaryWriter
from
ygoai.utils
import
init_ygopro
,
load_embeddings
from
ygoai.rl.ckpt
import
ModelCheckpoint
,
sync_to_gcs
,
zip_files
from
ygoai.rl.jax.agent2
import
PPOLSTMAgent
from
ygoai.rl.jax.utils
import
RecordEpisodeStatistics
,
categorical_sample
from
ygoai.rl.jax.eval
import
evaluate
,
battle
...
...
@@ -45,6 +46,13 @@ class Args:
checkpoint
:
Optional
[
str
]
=
None
"""the path to the model checkpoint to load"""
tb_dir
:
str
=
"runs"
"""the directory to save the tensorboard logs"""
ckpt_dir
:
str
=
"checkpoints"
"""the directory to save the model checkpoints"""
gcs_bucket
:
Optional
[
str
]
=
None
"""the GCS bucket to save the model checkpoints"""
# Algorithm specific arguments
env_id
:
str
=
"YGOPro-v0"
"""the id of the environment"""
...
...
@@ -151,7 +159,7 @@ class Args:
freeze_id
:
bool
=
False
def
make_env
(
args
,
seed
,
num_envs
,
num_threads
,
mode
=
'self'
,
thread_affinity_offset
=-
1
):
def
make_env
(
args
,
seed
,
num_envs
,
num_threads
,
mode
=
'self'
,
thread_affinity_offset
=-
1
,
eval
=
False
):
if
not
args
.
thread_affinity
:
thread_affinity_offset
=
-
1
if
thread_affinity_offset
>=
0
:
...
...
@@ -168,7 +176,7 @@ def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_off
max_options
=
args
.
max_options
,
n_history_actions
=
args
.
n_history_actions
,
async_reset
=
False
,
greedy_reward
=
args
.
greedy_reward
if
mode
==
'self'
else
True
,
greedy_reward
=
args
.
greedy_reward
if
not
eval
else
True
,
play_mode
=
mode
,
)
envs
.
num_envs
=
num_envs
...
...
@@ -231,7 +239,7 @@ def rollout(
args
,
args
.
seed
+
jax
.
process_index
()
+
device_thread_id
,
args
.
local_eval_episodes
,
args
.
local_eval_episodes
//
4
,
mode
=
eval_mode
)
args
.
local_eval_episodes
//
4
,
mode
=
eval_mode
,
eval
=
True
)
eval_envs
=
RecordEpisodeStatistics
(
eval_envs
)
len_actor_device_ids
=
len
(
args
.
actor_device_ids
)
...
...
@@ -431,14 +439,13 @@ def rollout(
_start
=
time
.
time
()
if
eval_mode
==
'bot'
:
predict_fn
=
lambda
x
:
get_action
(
params
,
x
)
eval_stat
=
evaluate
(
eval_envs
,
args
.
local_eval_episodes
,
predict_fn
,
eval_rstate
)[
0
]
metric_name
=
"eval_return"
eval_return
,
eval_ep_len
,
eval_win_rate
=
evaluate
(
eval_envs
,
args
.
local_eval_episodes
,
predict_fn
,
eval_rstate
)
else
:
predict_fn
=
lambda
*
x
:
get_action_battle
(
params
,
eval_params
,
*
x
)
eval_
stat
=
battle
(
eval_envs
,
args
.
local_eval_episodes
,
predict_fn
,
eval_rstate
)
[
2
]
metric_name
=
"eval_win_rate"
eval_
return
,
eval_ep_len
,
eval_win_rate
=
battle
(
eval_envs
,
args
.
local_eval_episodes
,
predict_fn
,
eval_rstate
)
eval_stat
=
np
.
array
([
eval_return
,
eval_win_rate
])
if
device_thread_id
!=
0
:
eval_queue
.
put
(
eval_stat
)
else
:
...
...
@@ -446,12 +453,14 @@ def rollout(
eval_stats
.
append
(
eval_stat
)
for
_
in
range
(
1
,
n_actors
):
eval_stats
.
append
(
eval_queue
.
get
())
eval_stats
=
np
.
mean
(
eval_stats
)
writer
.
add_scalar
(
f
"charts/{metric_name}"
,
eval_stats
,
global_step
)
eval_stats
=
np
.
stack
(
eval_stats
)
eval_return
,
eval_win_rate
=
np
.
mean
(
eval_stats
,
axis
=
0
)
writer
.
add_scalar
(
f
"charts/eval_return"
,
eval_return
,
global_step
)
writer
.
add_scalar
(
f
"charts/eval_win_rate"
,
eval_win_rate
,
global_step
)
if
device_thread_id
==
0
:
eval_time
=
time
.
time
()
-
_start
print
(
f
"eval_time={eval_time:.4f}, {metric_name}={eval_stats:.4f}"
)
other_time
+=
eval_time
print
(
f
"eval_time={eval_time:.4f}, eval_return={eval_return:.4f}, eval_win_rate={eval_win_rate:.4f}"
)
if
__name__
==
"__main__"
:
...
...
@@ -508,12 +517,21 @@ if __name__ == "__main__":
timestamp
=
int
(
time
.
time
())
run_name
=
f
"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
writer
=
SummaryWriter
(
f
"runs/{run_name}"
)
tb_log_dir
=
f
"{args.tb_dir}/{run_name}"
writer
=
SummaryWriter
(
tb_log_dir
)
writer
.
add_text
(
"hyperparameters"
,
"|param|value|
\n
|-|-|
\n
%
s"
%
(
"
\n
"
.
join
([
f
"|{key}|{value}|"
for
key
,
value
in
vars
(
args
)
.
items
()])),
)
def
save_fn
(
obj
,
path
):
with
open
(
path
,
"wb"
)
as
f
:
f
.
write
(
flax
.
serialization
.
to_bytes
(
obj
))
ckpt_maneger
=
ModelCheckpoint
(
args
.
ckpt_dir
,
save_fn
,
n_saved
=
3
)
# seeding
random
.
seed
(
args
.
seed
)
np
.
random
.
seed
(
args
.
seed
)
...
...
@@ -559,12 +577,12 @@ if __name__ == "__main__":
),
every_k_schedule
=
1
,
)
tx
=
optax
.
apply_if_finite
(
tx
,
max_consecutive_errors
=
3
)
agent_state
=
TrainState
.
create
(
apply_fn
=
None
,
params
=
params
,
tx
=
tx
,
)
if
args
.
checkpoint
:
with
open
(
args
.
checkpoint
,
"rb"
)
as
f
:
params
=
flax
.
serialization
.
from_bytes
(
params
,
f
.
read
())
...
...
@@ -589,7 +607,7 @@ if __name__ == "__main__":
args
,
multi_step
=
True
)
.
apply
(
params
,
inputs
)
return
logits
,
value
.
squeeze
(
-
1
)
def
ppo_loss
(
def
loss_fn
(
params
,
rstate1
,
rstate2
,
obs
,
dones
,
mains
,
actions
,
logits
,
rewards
,
mask
,
next_value
,
next_done
):
# (num_steps * local_num_envs // n_mb))
...
...
@@ -663,7 +681,7 @@ if __name__ == "__main__":
# main first, opponent second
num_steps
,
num_envs
=
storage
.
rewards
.
shape
ppo_loss_grad_fn
=
jax
.
value_and_grad
(
ppo_loss
,
has_aux
=
True
)
loss_grad_fn
=
jax
.
value_and_grad
(
loss_fn
,
has_aux
=
True
)
def
update_epoch
(
carry
,
_
):
agent_state
,
key
=
carry
...
...
@@ -693,7 +711,7 @@ if __name__ == "__main__":
shuffled_mask
=
jnp
.
ones_like
(
shuffled_storage
.
mains
)
def
update_minibatch
(
agent_state
,
minibatch
):
(
loss
,
(
pg_loss
,
v_loss
,
entropy_loss
,
approx_kl
)),
grads
=
ppo_
loss_grad_fn
(
(
loss
,
(
pg_loss
,
v_loss
,
entropy_loss
,
approx_kl
)),
grads
=
loss_grad_fn
(
agent_state
.
params
,
*
minibatch
)
grads
=
jax
.
lax
.
pmean
(
grads
,
axis_name
=
"local_devices"
)
agent_state
=
agent_state
.
apply_gradients
(
grads
=
grads
)
...
...
@@ -818,7 +836,7 @@ if __name__ == "__main__":
f
"data_time={rollout_queue_get_time[-1]:.2f}"
)
writer
.
add_scalar
(
"charts/learning_rate"
,
agent_state
.
opt_state
[
2
][
1
]
.
hyperparams
[
"learning_rate"
][
-
1
]
.
item
(),
global_step
"charts/learning_rate"
,
agent_state
.
opt_state
[
3
][
2
][
1
]
.
hyperparams
[
"learning_rate"
][
-
1
]
.
item
(),
global_step
)
writer
.
add_scalar
(
"losses/value_loss"
,
v_loss
[
-
1
]
.
item
(),
global_step
)
writer
.
add_scalar
(
"losses/policy_loss"
,
pg_loss
[
-
1
]
.
item
(),
global_step
)
...
...
@@ -827,15 +845,13 @@ if __name__ == "__main__":
writer
.
add_scalar
(
"losses/loss"
,
loss
,
global_step
)
if
args
.
local_rank
==
0
and
learner_policy_version
%
args
.
save_interval
==
0
:
ckpt_dir
=
f
"checkpoints"
os
.
makedirs
(
ckpt_dir
,
exist_ok
=
True
)
M_steps
=
args
.
batch_size
*
learner_policy_version
//
(
2
**
20
)
model_path
=
os
.
path
.
join
(
ckpt_dir
,
f
"{timestamp}_{M_steps}M.flax_model"
)
with
open
(
model_path
,
"wb"
)
as
f
:
f
.
write
(
flax
.
serialization
.
to_bytes
(
unreplicated_params
)
)
print
(
f
"model saved to {model_path}"
)
ckpt_name
=
f
"{timestamp}_{M_steps}M.flax_model"
ckpt_maneger
.
save
(
unreplicated_params
,
ckpt_name
)
if
args
.
gcs_bucket
is
not
None
:
zip_file_path
=
"latest.zip"
zip_files
(
zip_file_path
,
[
ckpt_maneger
.
get_latest
(),
tb_log_dir
]
)
sync_to_gcs
(
args
.
gcs_bucket
,
zip_file_path
)
if
learner_policy_version
>=
args
.
num_updates
:
break
...
...
scripts/ppo.py
View file @
45885f81
...
...
@@ -587,12 +587,12 @@ if __name__ == "__main__":
),
every_k_schedule
=
1
,
)
tx
=
optax
.
apply_if_finite
(
tx
,
max_consecutive_errors
=
3
)
agent_state
=
TrainState
.
create
(
apply_fn
=
None
,
params
=
params
,
tx
=
tx
,
)
if
args
.
checkpoint
:
with
open
(
args
.
checkpoint
,
"rb"
)
as
f
:
params
=
flax
.
serialization
.
from_bytes
(
params
,
f
.
read
())
...
...
@@ -862,7 +862,7 @@ if __name__ == "__main__":
f
"data_time={rollout_queue_get_time[-1]:.2f}"
)
writer
.
add_scalar
(
"charts/learning_rate"
,
agent_state
.
opt_state
[
2
][
1
]
.
hyperparams
[
"learning_rate"
][
-
1
]
.
item
(),
global_step
"charts/learning_rate"
,
agent_state
.
opt_state
[
3
][
2
][
1
]
.
hyperparams
[
"learning_rate"
][
-
1
]
.
item
(),
global_step
)
writer
.
add_scalar
(
"losses/value_loss"
,
v_loss
[
-
1
]
.
item
(),
global_step
)
writer
.
add_scalar
(
"losses/policy_loss"
,
pg_loss
[
-
1
]
.
item
(),
global_step
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment