Biluo Shen / ygo-agent / Commits

Commit 330ee6af
authored May 30, 2024 by sbl1996@126.com
parent cd59a6e9

    Add truncated LSTM
Showing 4 changed files with 396 additions and 1073 deletions
scripts/battle.py    +158  -78
scripts/cleanba.py   +198  -49
scripts/eval.py      +40   -30
scripts/ppo.py       +0    -916
scripts/battle.py
@@ -26,8 +26,10 @@ class Args:
     seed: int = 1
     """the random seed"""
-    env_id: str = "YGOPro-v1"
-    """the id of the environment"""
+    env_id1: str = "YGOPro-v1"
+    """the id of the environment1"""
+    env_id2: Optional[str] = None
+    """the id of the environment2, defaults to `env_id1`"""
     deck: str = "../assets/deck"
     """the deck file to use"""
     deck1: Optional[str] = None
@@ -40,10 +42,16 @@ class Args:
     """the language to use"""
     max_options: int = 24
     """the maximum number of options"""
-    n_history_actions: int = 32
-    """the number of history actions to use"""
+    n_history_actions1: int = 32
+    """the number of history actions to use for the environment1"""
+    n_history_actions2: Optional[int] = None
+    """the number of history actions to use for the environment2, defaults to `n_history_actions1`"""
     num_embeddings: Optional[int] = None
     """the number of embeddings of the agent"""
+    accurate: bool = True
+    """whether to do accurate evaluation. If not, there will be more short games"""
+    reverse: bool = False
+    """whether to reverse the order of the agents"""
     verbose: bool = False
     """whether to print debug information"""
@@ -101,56 +109,94 @@ if __name__ == "__main__":
     args.env_threads = min(args.env_threads or args.num_envs, args.num_envs)

-    deck = init_ygopro(args.env_id, args.lang, args.deck, args.code_list_file)
-    args.deck1 = args.deck1 or deck
-    args.deck2 = args.deck2 or deck
-
-    seed = args.seed
+    if args.env_id2 is None:
+        args.env_id2 = args.env_id1
+    if args.n_history_actions2 is None:
+        args.n_history_actions2 = args.n_history_actions1
+    cross_env = args.env_id1 != args.env_id2
+    env_id1 = args.env_id1
+    env_id2 = args.env_id2
+
+    deck1 = init_ygopro(env_id1, args.lang, args.deck, args.code_list_file)
+    if not cross_env:
+        deck2 = deck1
+    else:
+        deck2 = init_ygopro(env_id2, args.lang, args.deck, args.code_list_file)
+    args.deck1 = args.deck1 or deck1
+    args.deck2 = args.deck2 or deck2
+
+    seed = args.seed + 100000
+    random.seed(seed)
+    seed = random.randint(0, 1e8)
     random.seed(seed)
     np.random.seed(seed)

     if args.xla_device is not None:
         os.environ.setdefault("JAX_PLATFORMS", args.xla_device)

+    if args.accurate:
+        if args.num_envs != args.num_episodes:
+            args.num_envs = args.num_episodes
+            print("Set num_envs to num_episodes for accurate evaluation")
     num_envs = args.num_envs

-    envs = ygoenv.make(
-        task_id=args.env_id,
+    env_option = dict(
         env_type="gymnasium",
         num_envs=num_envs,
         num_threads=args.env_threads,
         seed=seed,
-        deck1=args.deck1,
-        deck2=args.deck2,
         player=-1,
         max_options=args.max_options,
-        n_history_actions=args.n_history_actions,
         play_mode='self',
         async_reset=False,
         verbose=args.verbose,
         record=args.record,
     )
-
-    obs_space = envs.observation_space
-    envs.num_envs = num_envs
-    envs = RecordEpisodeStatistics(envs)
-
-    key = jax.random.PRNGKey(args.seed)
-    key, agent_key = jax.random.split(key, 2)
-    sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
+    envs1 = ygoenv.make(
+        task_id=env_id1,
+        n_history_actions=args.n_history_actions1,
+        deck1=args.deck1,
+        deck2=args.deck2,
+        **env_option,
+    )
+    if cross_env:
+        envs2 = ygoenv.make(
+            task_id=env_id2,
+            n_history_actions=args.n_history_actions2,
+            deck1=deck2,
+            deck2=deck2,
+            **env_option,
+        )
+
+    key = jax.random.PRNGKey(seed)
+    obs_space1 = envs1.observation_space
+    envs1.num_envs = num_envs
+    envs1 = RecordEpisodeStatistics(envs1)
+    sample_obs1 = jax.tree.map(lambda x: jnp.array([x]), obs_space1.sample())

     agent1 = create_agent1(args)
-    rstate = agent1.init_rnn_state(1)
-    params1 = jax.jit(agent1.init)(agent_key, sample_obs, rstate)
+    rstate1 = agent1.init_rnn_state(1)
+    params1 = jax.jit(agent1.init)(key, sample_obs1, rstate1)
     with open(args.checkpoint1, "rb") as f:
         params1 = flax.serialization.from_bytes(params1, f.read())

+    if cross_env:
+        obs_space2 = envs2.observation_space
+        envs2.num_envs = num_envs
+        envs2 = RecordEpisodeStatistics(envs2)
+        sample_obs2 = jax.tree.map(lambda x: jnp.array([x]), obs_space2.sample())
+    else:
+        sample_obs2 = sample_obs1
     if args.checkpoint1 == args.checkpoint2:
         params2 = params1
     else:
         agent2 = create_agent2(args)
-        rstate = agent2.init_rnn_state(1)
-        params2 = jax.jit(agent2.init)(agent_key, sample_obs, rstate)
+        rstate2 = agent2.init_rnn_state(1)
+        params2 = jax.jit(agent2.init)(key, sample_obs2, rstate2)
         with open(args.checkpoint2, "rb") as f:
             params2 = flax.serialization.from_bytes(params2, f.read())
@@ -169,11 +215,11 @@ if __name__ == "__main__":
         next_rstate = jnp.where(done[:, None], 0, next_rstate)
         return next_rstate, probs

-    if args.num_envs != 1:
+    if num_envs != 1:
         @jax.jit
-        def get_probs2(params1, params2, rstate1, rstate2, obs, main, done):
-            next_rstate1, probs1 = get_probs(params1, rstate1, obs, None, 1)
-            next_rstate2, probs2 = get_probs(params2, rstate2, obs, None, 2)
+        def get_probs2(params1, params2, rstate1, rstate2, obs1, obs2, main, done):
+            next_rstate1, probs1 = get_probs(params1, rstate1, obs1, None, 1)
+            next_rstate2, probs2 = get_probs(params2, rstate2, obs2, None, 2)
             probs = jnp.where(main[:, None], probs1, probs2)
             rstate1 = jax.tree.map(
                 lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
@@ -183,19 +229,26 @@ if __name__ == "__main__":
                 lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
             return rstate1, rstate2, probs

-        def predict_fn(rstate1, rstate2, obs, main, done):
-            rstate1, rstate2, probs = get_probs2(params1, params2, rstate1, rstate2, obs, main, done)
+        def predict_fn(rstate1, rstate2, obs1, obs2, main, done):
+            rstate1, rstate2, probs = get_probs2(params1, params2, rstate1, rstate2, obs1, obs2, main, done)
             return rstate1, rstate2, np.array(probs)
     else:
-        def predict_fn(rstate1, rstate2, obs, main, done):
+        def predict_fn(rstate1, rstate2, obs1, obs2, main, done):
             if main[0]:
-                rstate1, probs = get_probs(params1, rstate1, obs, done, 1)
+                rstate1, probs = get_probs(params1, rstate1, obs1, done, 1)
             else:
-                rstate2, probs = get_probs(params2, rstate2, obs, done, 2)
+                rstate2, probs = get_probs(params2, rstate2, obs2, done, 2)
             return rstate1, rstate2, np.array(probs)

-    obs, infos = envs.reset()
-    next_to_play = infos['to_play']
+    obs1, infos1 = envs1.reset()
+    next_to_play1 = infos1['to_play']
+    if cross_env:
+        obs2, infos2 = envs2.reset()
+        next_to_play2 = infos2['to_play']
+        np.testing.assert_array_equal(next_to_play1, next_to_play2)
+    else:
+        obs2 = obs1

     dones = np.zeros(num_envs, dtype=np.bool_)
     episode_rewards = []
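The obs1/obs2 split above and the jnp.where(main[:, None], ...) selection in get_probs2 are the core of the cross-environment battle mode: every environment evaluates both agents, and the boolean main mask picks, per environment, whose action probabilities and whose updated recurrent state to keep. A standalone sketch of that masking pattern (the values are made up, not taken from the repo):

import jax.numpy as jnp

main = jnp.array([True, False, True])              # whose turn it is, per environment
probs1 = jnp.full((3, 4), 0.1)                     # agent 1's action probs, (num_envs, num_actions)
probs2 = jnp.full((3, 4), 0.9)                     # agent 2's action probs
probs = jnp.where(main[:, None], probs1, probs2)   # rows 0 and 2 from agent 1, row 1 from agent 2

rstate_new = jnp.ones((3, 8))                      # agent 1's freshly computed hidden state
rstate_old = jnp.zeros((3, 8))                     # agent 1's previous hidden state
# only the environments where agent 1 actually acted keep its updated hidden state
rstate1 = jnp.where(main[:, None], rstate_new, rstate_old)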
@@ -209,12 +262,17 @@ if __name__ == "__main__":
     start = time.time()
     start_step = step

-    main_player = np.concatenate([
-        np.zeros(num_envs // 2, dtype=np.int64),
-        np.ones(num_envs - num_envs // 2, dtype=np.int64)
-    ])
+    first_player = np.zeros(num_envs // 2, dtype=np.int64)
+    second_player = np.ones(num_envs - num_envs // 2, dtype=np.int64)
+    if args.reverse:
+        main_player = np.concatenate([second_player, first_player])
+    else:
+        main_player = np.concatenate([first_player, second_player])
+    # main_player = np.zeros(num_envs, dtype=np.int64)
+    # main_player = np.ones(num_envs, dtype=np.int64)
     rstate1 = agent1.init_rnn_state(num_envs)
     rstate2 = agent2.init_rnn_state(num_envs)
+    collected = np.zeros((args.num_episodes,), dtype=np.bool_)
     if not args.verbose:
         pbar = tqdm(total=args.num_episodes)
@@ -227,35 +285,45 @@ if __name__ == "__main__":
         model_time = env_time = 0
         _start = time.time()
-        main = next_to_play == main_player
-        rstate1, rstate2, probs = predict_fn(rstate1, rstate2, obs, main, dones)
+        main = next_to_play1 == main_player
+        rstate1, rstate2, probs = predict_fn(rstate1, rstate2, obs1, obs2, main, dones)
         actions = probs.argmax(axis=1)
         model_time += time.time() - _start

-        to_play = next_to_play
+        to_play1 = next_to_play1
         _start = time.time()
-        obs, rewards, dones, infos = envs.step(actions)
-        next_to_play = infos['to_play']
+        obs1, rewards1, dones1, infos1 = envs1.step(actions)
+        next_to_play1 = infos1['to_play']
+        if cross_env:
+            obs2, rewards2, dones2, infos2 = envs2.step(actions)
+            next_to_play2 = infos2['to_play']
+            np.testing.assert_array_equal(next_to_play1, next_to_play2)
+            np.testing.assert_array_equal(dones1, dones2)
+        else:
+            obs2 = obs1
         env_time += time.time() - _start

         step += 1

-        for idx, d in enumerate(dones):
-            if d:
-                win_reason = infos['win_reason'][idx]
-                pl = 1 if to_play[idx] == main_player[idx] else -1
-                episode_length = infos['l'][idx]
-                episode_reward = infos['r'][idx]
+        for idx, d in enumerate(dones1):
+            if not d or (args.accurate and collected[idx]):
+                continue
+            collected[idx] = True
+            win_reason = infos1['win_reason'][idx]
+            pl = 1 if main[idx] else -1
+            episode_length = infos1['l'][idx]
+            episode_reward = infos1['r'][idx]
             main_reward = episode_reward * pl
             win = int(main_reward > 0)
-            win_player = 0 if (to_play[idx] == 0 and episode_reward > 0) or (to_play[idx] == 1 and episode_reward < 0) else 1
+            win_player = 0 if (to_play1[idx] == 0 and episode_reward > 0) or (to_play1[idx] == 1 and episode_reward < 0) else 1
             win_players.append(win_player)
             win_agent = 1 if main_reward > 0 else 2
             win_agents.append(win_agent)
+            # print(f"{len(episode_lengths)}: {episode_length}, {main_reward}")
             episode_lengths.append(episode_length)
             episode_rewards.append(main_reward)
             win_rates.append(win)
@@ -269,6 +337,8 @@ if __name__ == "__main__":
             # Only when num_envs=1, we switch the player here
             if args.verbose:
                 main_player = 1 - main_player
+            else:
+                main_player[idx] = 1 - main_player[idx]

         if len(episode_lengths) >= args.num_episodes:
             break
@@ -277,14 +347,17 @@ if __name__ == "__main__":
         pbar.close()
     print(f"len={np.mean(episode_lengths)}, reward={np.mean(episode_rewards)}, win_rate={np.mean(win_rates)}, win_reason={np.mean(win_reasons)}")

+    episode_lengths = np.array(episode_lengths)
     win_players = np.array(win_players)
     win_agents = np.array(win_agents)
     N = len(win_players)
-    N1 = np.sum((win_players == 0) & (win_agents == 1))
-    N2 = np.sum((win_players == 0) & (win_agents == 2))
-    N3 = np.sum((win_players == 1) & (win_agents == 1))
-    N4 = np.sum((win_players == 1) & (win_agents == 2))
+    mask1 = (win_players == 0) & (win_agents == 1)
+    mask2 = (win_players == 0) & (win_agents == 2)
+    mask3 = (win_players == 1) & (win_agents == 1)
+    mask4 = (win_players == 1) & (win_agents == 2)
+    N1, N2, N3, N4 = [np.sum(m) for m in [mask1, mask2, mask3, mask4]]
     print(f"Payoff matrix:")
     w1 = N1 / N
@@ -304,6 +377,13 @@ if __name__ == "__main__":
     print(f"0 {w1:.4f} {w2:.4f}")
     print(f"1 {w3:.4f} {w4:.4f}")

+    print(f"Length matrix, length of games of agentX as playerY")
+    l1 = np.mean(episode_lengths[mask1 | mask4])
+    l2 = np.mean(episode_lengths[mask2 | mask3])
+    print(f" agent1 agent2")
+    print(f"0 {l1:3.2f} {l2:3.2f}")
+    print(f"1 {l2:3.2f} {l1:3.2f}")
+
     total_time = time.time() - start
     total_steps = (step - start_step) * num_envs
     print(f"SPS: {total_steps / total_time:.0f}, total_steps: {total_steps}")
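For reference, mask1..mask4 in the last hunks partition finished games by winning seat (player 0/1) and winning agent (1/2); the payoff and length matrices printed above are just counts and means over those partitions. A toy rendering with fabricated results, not repo code:

import numpy as np

win_players = np.array([0, 0, 1, 1, 0, 1])   # which seat (player 0 or 1) won each game
win_agents = np.array([1, 2, 1, 2, 1, 1])    # which agent (1 or 2) won each game
episode_lengths = np.array([20, 35, 28, 40, 22, 30])

mask1 = (win_players == 0) & (win_agents == 1)
mask2 = (win_players == 0) & (win_agents == 2)
mask3 = (win_players == 1) & (win_agents == 1)
mask4 = (win_players == 1) & (win_agents == 2)
N = len(win_players)
w1, w2, w3, w4 = [m.sum() / N for m in (mask1, mask2, mask3, mask4)]   # payoff matrix cells
# mask1 | mask4 are the games where agent 1 sat as player 0
l1 = episode_lengths[mask1 | mask4].mean()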
scripts/cleanba.py
@@ -95,6 +95,8 @@ class Args:
     """the number of actor threads to use"""
     num_steps: int = 128
     """the number of steps to run in each environment per policy rollout"""
+    segment_length: Optional[int] = None
+    """the length of the segment for training"""
     anneal_lr: bool = False
     """Toggle learning rate annealing for policy and value networks"""
     gamma: float = 1.0
@@ -247,6 +249,53 @@ def init_rnn_state(num_envs, rnn_channels):
     )


+def reshape_minibatch(x, multi_step, num_minibatches, num_steps, segment_length=None, key=None):
+    # if segment_length is None,
+    #   n_mb = num_minibatches
+    #   if multi_step, from (num_steps, num_envs, ...)) to
+    #     (n_mb, num_steps * (num_envs // n_mb), ...)
+    #   else, from (num_envs, ...) to
+    #     (n_mb, num_envs // n_mb, ...)
+    # else,
+    #   n_mb_t = num_steps // segment_length
+    #   n_mb_e = num_minibatches // num_minibatches1
+    #   if multi_step, from (num_steps, num_envs, ...)) to
+    #     (n_mb_e, n_mb_t, segment_length * (num_envs // n_mb_e), ...)
+    #   else, from (num_envs, ...) to
+    #     (n_mb_e, num_envs // n_mb_e, ...)
+    if key is not None:
+        x = jax.random.permutation(key, x, axis=1 if multi_step else 0)
+    N = num_minibatches
+    if segment_length is None:
+        if multi_step:
+            x = jnp.reshape(x, (num_steps, N, -1) + x.shape[2:])
+            x = x.transpose(1, 0, *range(2, x.ndim))
+            x = x.reshape(N, -1, *x.shape[3:])
+        else:
+            x = jnp.reshape(x, (N, -1) + x.shape[1:])
+    else:
+        M = segment_length
+        Nt = num_steps // M
+        Ne = N // Nt
+        if multi_step:
+            x = jnp.reshape(x, (Nt, M, Ne, -1) + x.shape[2:])
+            x = x.transpose(2, 0, 1, *range(3, x.ndim))
+            x = jnp.reshape(x, (Ne, Nt, -1) + x.shape[4:])
+        else:
+            x = jnp.reshape(x, (Ne, -1) + x.shape[1:])
+    return x
+
+
+def reshape_batch(x, num_minibatches, num_steps, segment_length=None):
+    N = num_minibatches
+    x = jnp.reshape(x, (N, num_steps, -1) + x.shape[2:])
+    x = x.transpose(1, 0, *range(2, x.ndim))
+    x = jnp.reshape(x, (num_steps, -1) + x.shape[3:])
+    return x
+
+
 def rollout(
     key: jax.random.PRNGKey,
     args: Args,
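The comments in reshape_minibatch are easiest to follow with concrete numbers. The sketch below walks the segment_length branch by hand; the sizes are illustrative only and are not the repo's defaults:

import jax.numpy as jnp

num_steps, num_envs = 128, 16
num_minibatches, segment_length = 8, 32
x = jnp.zeros((num_steps, num_envs, 5))              # e.g. per-step logits

M = segment_length
Nt = num_steps // M                                  # 4 time chunks (n_mb_t)
Ne = num_minibatches // Nt                           # 2 environment minibatches (n_mb_e)
x = jnp.reshape(x, (Nt, M, Ne, -1) + x.shape[2:])    # (4, 32, 2, 8, 5)
x = x.transpose(2, 0, 1, *range(3, x.ndim))          # (2, 4, 32, 8, 5)
x = jnp.reshape(x, (Ne, Nt, -1) + x.shape[4:])       # (2, 4, 256, 5)
print(x.shape)  # each env minibatch carries Nt time chunks of segment_length steps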
@@ -539,6 +588,8 @@ def main():
     args.minibatch_size = args.local_minibatch_size * args.world_size
     args.num_updates = args.total_timesteps // (args.local_batch_size * args.world_size)
     args.local_env_threads = args.local_env_threads or args.local_num_envs
+    if args.segment_length is not None:
+        assert args.num_steps % args.segment_length == 0, "num_steps must be divisible by segment_length"

     if args.embedding_file:
         embeddings = load_embeddings(args.embedding_file, args.code_list_file)
@@ -675,29 +726,17 @@ def main():
     else:
         eval_params = None

-    def loss_fn(
-        params, rstate1, rstate2, obs, dones, next_dones,
-        switch_or_mains, actions, logits, rewards, mask, next_value):
+    def advantage_fn(
+        new_logits, new_values, next_dones,
+        switch_or_mains, actions, logits, rewards, next_value):
+        # (num_steps * local_num_envs // n_mb))
         num_envs = next_value.shape[0]
-        num_steps = dones.shape[0] // num_envs
+        num_steps = next_dones.shape[0] // num_envs

         def reshape_time_series(x):
             return jnp.reshape(x, (num_steps, num_envs) + x.shape[1:])

-        mask = mask * (1.0 - dones)
-        if args.switch:
-            dones = dones | next_dones
-
-        new_logits, new_values = create_agent(args).apply(
-            params, obs, (rstate1, rstate2), dones, switch_or_mains)[1:3]
-        new_values = new_values.squeeze(-1)
-
         ratios = distrax.importance_sampling_ratios(distrax.Categorical(
             new_logits), distrax.Categorical(logits), actions)
-        logratio = jnp.log(ratios)
-        approx_kl = (ratios - 1) - logratio

         new_values_, rewards, next_dones, switch_or_mains = jax.tree.map(
             reshape_time_series, (new_values, rewards, next_dones, switch_or_mains),
@@ -717,6 +756,15 @@ def main():
         target_values, advantages = jax.tree.map(
             lambda x: jnp.reshape(x, (-1,)), (target_values, advantages))
+        return target_values, advantages
+
+    def loss_fn(
+        new_logits, new_values, actions, logits,
+        target_values, advantages, mask, num_steps=None):
+        ratios = distrax.importance_sampling_ratios(distrax.Categorical(
+            new_logits), distrax.Categorical(logits), actions)
+        logratio = jnp.log(ratios)
+        approx_kl = (ratios - 1) - logratio

         if args.norm_adv:
             advantages = masked_normalize(advantages, mask, eps=1e-8)
@@ -743,7 +791,7 @@ def main():
         if args.burn_in_steps:
             mask = jax.tree.map(
-                lambda x: x.reshape(num_steps, num_envs), mask)
+                lambda x: x.reshape(num_steps, -1), mask)
             burn_in_mask = jnp.arange(num_steps) < args.burn_in_steps
             mask = jnp.where(burn_in_mask[:, None], 0.0, mask)
             mask = jnp.reshape(mask, (-1,))
@@ -754,7 +802,57 @@ def main():
         loss = pg_loss - args.ent_coef * ent_loss + v_loss * args.vf_coef
         loss = jnp.where(jnp.isnan(loss) | jnp.isinf(loss), 0.0, loss)
-        return loss, (pg_loss, v_loss, ent_loss, jax.lax.stop_gradient(approx_kl))
+        return loss, pg_loss, v_loss, ent_loss, approx_kl
+
+    def apply_fn(params, obs, rstate1, rstate2, dones, next_dones, switch_or_mains):
+        if args.switch:
+            dones = dones | next_dones
+
+        (rstate1, rstate2), new_logits, new_values = create_agent(args).apply(
+            params, obs, (rstate1, rstate2), dones, switch_or_mains)[:3]
+        new_values = new_values.squeeze(-1)
+        return (rstate1, rstate2), new_logits, new_values
+
+    def compute_advantage(
+        params, rstate1, rstate2, obs, dones, next_dones,
+        switch_or_mains, actions, logits, rewards, next_value):
+        new_logits, new_values = apply_fn(
+            params, obs, rstate1, rstate2, dones, next_dones, switch_or_mains)[1:3]
+        target_values, advantages = advantage_fn(
+            new_logits, new_values, next_dones, switch_or_mains,
+            actions, logits, rewards, next_value)
+        return target_values, advantages
+
+    def compute_loss(
+        params, rstate1, rstate2, obs, dones, next_dones,
+        switch_or_mains, actions, logits, target_values, advantages, mask):
+        (rstate1, rstate2), new_logits, new_values = apply_fn(
+            params, obs, rstate1, rstate2, dones, next_dones, switch_or_mains)
+        loss, pg_loss, v_loss, ent_loss, approx_kl = loss_fn(
+            new_logits, new_values, actions, logits,
+            target_values, advantages, mask, num_steps=None)
+        approx_kl, rstate1, rstate2 = jax.tree.map(
+            jax.lax.stop_gradient, (approx_kl, rstate1, rstate2))
+        return loss, (pg_loss, v_loss, ent_loss, approx_kl, rstate1, rstate2)
+
+    def compute_advantage_loss(
+        params, rstate1, rstate2, obs, dones, next_dones,
+        switch_or_mains, actions, logits, rewards, next_value, mask):
+        new_logits, new_values = apply_fn(
+            params, obs, rstate1, rstate2, dones, next_dones, switch_or_mains)[1:3]
+        target_values, advantages = advantage_fn(
+            new_logits, new_values, next_dones, switch_or_mains,
+            actions, logits, rewards, next_value)
+        loss, pg_loss, v_loss, ent_loss, approx_kl = loss_fn(
+            new_logits, new_values, actions, logits, target_values, advantages, mask,
+            num_steps=dones.shape[0] // next_value.shape[0])
+        approx_kl = jax.lax.stop_gradient(approx_kl)
+        return loss, (pg_loss, v_loss, ent_loss, approx_kl)

     def single_device_update(
         agent_state: TrainState,
@@ -785,7 +883,45 @@ def main():
             switch = T[:, None] == (switch_steps[None, :] - 1)
             storage = jax.tree.map(lambda x: x[indices, B[None, :]], storage)

-        loss_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
+        if args.segment_length is None:
+            loss_grad_fn = jax.value_and_grad(compute_advantage_loss, has_aux=True)
+        else:
+            loss_grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
+
+        def compute_advantage_t(next_value):
+            N = args.num_minibatches // 4
+
+            def convert_data1(x: jnp.ndarray, multi_step=True):
+                return reshape_minibatch(x, multi_step, N, num_steps)
+
+            b_init_rstate1, b_init_rstate2, b_next_value = jax.tree.map(
+                partial(convert_data1, multi_step=False), (init_rstate1, init_rstate2, next_value))
+            b_storage = jax.tree.map(convert_data1, storage)
+            if args.switch:
+                b_switch_or_mains = convert_data1(switch)
+            else:
+                b_switch_or_mains = b_storage.mains
+
+            target_values, advantages = jax.lax.scan(
+                lambda x, y: (x, compute_advantage(x, *y)), agent_state.params,
+                (
+                    b_init_rstate1,
+                    b_init_rstate2,
+                    b_storage.obs,
+                    b_storage.dones,
+                    b_storage.next_dones,
+                    b_switch_or_mains,
+                    b_storage.actions,
+                    b_storage.logits,
+                    b_storage.rewards,
+                    b_next_value,
+                ))[1]
+            print(jax.tree.map(lambda x: x.shape, (b_storage.dones, target_values, advantages)))
+            target_values, advantages = jax.tree.map(
+                partial(reshape_batch, num_minibatches=N, num_steps=num_steps),
+                (target_values, advantages))
+            return target_values, advantages

         def update_epoch(carry, _):
             agent_state, key = carry
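compute_advantage_t uses jax.lax.scan purely as a map with a constant carry: the parameters are threaded through unchanged and only the stacked per-minibatch outputs are kept via [1]. A minimal standalone illustration of that idiom, with toy numbers rather than the repo's data:

import jax
import jax.numpy as jnp

def f(params, y):
    return params * y.sum()

xs = jnp.arange(12.0).reshape(3, 4)                  # 3 "minibatches"
_, outs = jax.lax.scan(lambda c, y: (c, f(c, y)), 2.0, xs)
print(outs)  # [12. 44. 76.]: the carry (2.0) is passed through unchanged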
@@ -798,35 +934,50 @@ def main():
             else:
                 next_value = jnp.where(next_main, next_value, -next_value)

-            def convert_data(x: jnp.ndarray, num_steps):
-                if args.update_epochs > 1:
-                    x = jax.random.permutation(subkey, x, axis=1 if num_steps > 1 else 0)
-                N = args.num_minibatches
-                if num_steps > 1:
-                    x = jnp.reshape(x, (num_steps, N, -1) + x.shape[2:])
-                    x = x.transpose(1, 0, *range(2, x.ndim))
-                    x = x.reshape(N, -1, *x.shape[3:])
-                else:
-                    x = jnp.reshape(x, (N, -1) + x.shape[1:])
-                return x
+            def convert_data(x: jnp.ndarray, multi_step=True):
+                key = subkey if args.update_epochs > 1 else None
+                return reshape_minibatch(
+                    x, multi_step, args.num_minibatches, num_steps, args.segment_length, key=key)

-            shuffled_init_rstate1, shuffled_init_rstate2, \
-                shuffled_next_value = jax.tree.map(
-                    partial(convert_data, num_steps=1), (init_rstate1, init_rstate2, next_value))
-            shuffled_storage = jax.tree.map(partial(convert_data, num_steps=num_steps), storage)
+            shuffled_init_rstate1, shuffled_init_rstate2 = jax.tree.map(
+                partial(convert_data, multi_step=False), (init_rstate1, init_rstate2))
+            shuffled_storage = jax.tree.map(convert_data, storage)
             if args.switch:
-                switch_or_mains = convert_data(switch, num_steps)
+                switch_or_mains = convert_data(switch)
             else:
                 switch_or_mains = shuffled_storage.mains
-            shuffled_mask = jnp.ones_like(shuffled_storage.mains)
+            shuffled_mask = ~shuffled_storage.dones

-            def update_minibatch(agent_state, minibatch):
-                (loss, (pg_loss, v_loss, ent_loss, approx_kl)), grads = loss_grad_fn(
-                    agent_state.params, *minibatch)
-                grads = jax.lax.pmean(grads, axis_name="local_devices")
-                agent_state = agent_state.apply_gradients(grads=grads)
-                return agent_state, (loss, pg_loss, v_loss, ent_loss, approx_kl)
+            if args.segment_length is None:
+                shuffled_next_value = convert_data(next_value, multi_step=False)
+                others = shuffled_storage.rewards, shuffled_next_value, shuffled_mask
+
+                def update_minibatch(agent_state, minibatch):
+                    (loss, (pg_loss, v_loss, ent_loss, approx_kl)), grads = loss_grad_fn(
+                        agent_state.params, *minibatch)
+                    grads = jax.lax.pmean(grads, axis_name="local_devices")
+                    agent_state = agent_state.apply_gradients(grads=grads)
+                    return agent_state, (loss, pg_loss, v_loss, ent_loss, approx_kl)
+            else:
+                target_values, advantages = compute_advantage_t(next_value)
+                shuffled_target_values, shuffled_advantages = jax.tree.map(
+                    convert_data, (target_values, advantages))
+                others = shuffled_target_values, shuffled_advantages, shuffled_mask
+
+                def update_minibatch(agent_state, minibatch):
+                    def update_minibatch_t(carry, minibatch_t):
+                        agent_state, rstate1, rstate2 = carry
+                        minibatch_t = rstate1, rstate2, *minibatch_t
+                        (loss, (pg_loss, v_loss, ent_loss, approx_kl, rstate1, rstate2)), \
+                            grads = loss_grad_fn(agent_state.params, *minibatch_t)
+                        grads = jax.lax.pmean(grads, axis_name="local_devices")
+                        agent_state = agent_state.apply_gradients(grads=grads)
+                        return (agent_state, rstate1, rstate2), (loss, pg_loss, v_loss, ent_loss, approx_kl)
+
+                    rstate1, rstate2, *minibatch_t = minibatch
+                    (agent_state, _rstate1, _rstate2), \
+                        (loss, pg_loss, v_loss, ent_loss, approx_kl) = jax.lax.scan(
+                            update_minibatch_t, (agent_state, rstate1, rstate2), minibatch_t)
+                    return agent_state, (loss, pg_loss, v_loss, ent_loss, approx_kl)

             agent_state, (loss, pg_loss, v_loss, ent_loss, approx_kl) = jax.lax.scan(
                 update_minibatch,
@@ -840,9 +991,7 @@ def main():
                     switch_or_mains,
                     shuffled_storage.actions,
                     shuffled_storage.logits,
-                    shuffled_storage.rewards,
-                    shuffled_mask,
-                    shuffled_next_value,
+                    *others,
                 ),
             )
             return (agent_state, key), (loss, pg_loss, v_loss, ent_loss, approx_kl)
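The segment branch above is the truncated-LSTM training path added by this commit: jax.lax.scan walks the time chunks of one environment minibatch, and compute_loss returns stop_gradient-ed recurrent states, so each chunk starts from the carried hidden state without backpropagating into earlier chunks. A self-contained toy version of that pattern (the cell, loss, and SGD step below are stand-ins, not the repo's API):

import jax
import jax.numpy as jnp

def cell(h, x):
    # toy recurrent cell standing in for the LSTM
    h = jnp.tanh(h + x)
    return h, h

def chunk_loss(params, h0, chunk):
    # loss over one segment; `params` is just a scalar gain here
    def step(h, x):
        h, y = cell(h, params * x)
        return h, y
    h_last, ys = jax.lax.scan(step, h0, chunk)
    # hand the final state to the next segment, but cut the gradient there
    return jnp.mean(ys ** 2), jax.lax.stop_gradient(h_last)

grad_fn = jax.value_and_grad(chunk_loss, has_aux=True)

def train_segments(params, h0, chunks):
    # chunks has shape (n_chunks, segment_length); gradients never cross chunk boundaries
    def update(carry, chunk):
        params, h = carry
        (loss, h_next), g = grad_fn(params, h, chunk)
        params = params - 0.1 * g        # plain SGD step for the sketch
        return (params, h_next), loss
    (params, _), losses = jax.lax.scan(update, (params, h0), chunks)
    return params, losses

params, losses = train_segments(
    jnp.array(1.0, jnp.float32), jnp.array(0.0, jnp.float32), jnp.ones((4, 32), jnp.float32))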
scripts/eval.py
@@ -3,7 +3,7 @@ import time
 import os
 import random
 from typing import Optional, Literal
-from dataclasses import dataclass
+from dataclasses import dataclass, field, asdict

 import ygoenv
 import numpy as np
@@ -12,6 +12,7 @@ import tyro
 from ygoai.utils import init_ygopro
 from ygoai.rl.utils import RecordEpisodeStatistics
+from ygoai.rl.jax.agent import RNNAgent, ModelArgs


 @dataclass
@@ -57,14 +58,8 @@ class Args:
     strategy: Literal["random", "greedy"] = "greedy"
     """the strategy to use if agent is not used"""
-    num_layers: int = 2
-    """the number of layers for the agent"""
-    num_channels: int = 128
-    """the number of channels for the agent"""
-    rnn_channels: Optional[int] = 512
-    """the number of rnn channels for the agent"""
-    rnn_type: Optional[str] = "lstm"
-    """the type of RNN to use for agent, None for no RNN"""
+    m: ModelArgs = field(default_factory=lambda: ModelArgs())
+    """the model arguments for the agent1"""
     checkpoint: Optional[str] = None
     """the checkpoint to load, must be a `flax_model` file"""
@@ -78,11 +73,8 @@ class Args:
 def create_agent(args):
     return RNNAgent(
-        channels=args.num_channels,
-        num_layers=args.num_layers,
-        rnn_channels=args.rnn_channels,
+        **asdict(args.m),
         embedding_shape=args.num_embeddings,
-        rnn_type=args.rnn_type,
     )
@@ -97,12 +89,14 @@ if __name__ == "__main__":
     args.env_threads = min(args.env_threads or args.num_envs, args.num_envs)
-    deck = init_ygopro(args.env_id, args.lang, args.deck, args.code_list_file)
+    deck, deck_names = init_ygopro(args.env_id, args.lang, args.deck, args.code_list_file, return_deck_names=True)

     args.deck1 = args.deck1 or deck
     args.deck2 = args.deck2 or deck

-    seed = args.seed
+    seed = args.seed + 100000
+    random.seed(seed)
+    seed = random.randint(0, 1e8)
     random.seed(seed)
     np.random.seed(seed)
@@ -135,22 +129,21 @@ if __name__ == "__main__":
     import jax
     import jax.numpy as jnp
     import flax
-    from ygoai.rl.jax.agent import RNNAgent
     from jax.experimental.compilation_cache import compilation_cache as cc
     cc.set_cache_dir(os.path.expanduser("~/.cache/jax"))

     agent = create_agent(args)
-    key = jax.random.PRNGKey(args.seed)
-    key, agent_key = jax.random.split(key, 2)
+    key = jax.random.PRNGKey(seed)
     sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
     rstate = agent.init_rnn_state(1)
-    params = jax.jit(agent.init)(agent_key, sample_obs, rstate)
+    params = jax.jit(agent.init)(key, sample_obs, rstate)
     with open(args.checkpoint, "rb") as f:
         params = flax.serialization.from_bytes(params, f.read())
     params = jax.device_put(params)
+    rstate = agent.init_rnn_state(num_envs)

     @jax.jit
     def get_probs_and_value(params, rstate, obs, done):
@@ -180,6 +173,10 @@ if __name__ == "__main__":
     start = time.time()
     start_step = step

+    deck_names = sorted(deck_names)
+    deck_times = {name: 0 for name in deck_names}
+    deck_time_count = {name: 0 for name in deck_names}
+
     model_time = env_time = 0
     while True:
         if start_step == 0 and len(episode_lengths) > int(args.num_episodes * 0.1):
@@ -211,7 +208,20 @@ if __name__ == "__main__":
         step += 1

         for idx, d in enumerate(dones):
-            if d:
+            if not d:
+                continue
+            for i in range(2):
+                deck_time = infos['step_time'][idx][i]
+                deck_name = deck_names[infos['deck'][idx][i]]
+                time_count = deck_time_count[deck_name]
+                avg_time = deck_times[deck_name]
+                avg_time = avg_time * (time_count / (time_count + 1)) + deck_time / (time_count + 1)
+                deck_times[deck_name] = avg_time
+                deck_time_count[deck_name] += 1
+                if deck_time_count[deck_name] % 100 == 0:
+                    print(f"Deck {deck_name}: {avg_time:.4f}")
             win_reason = infos['win_reason'][idx]
             episode_length = infos['l'][idx]
             episode_reward = infos['r'][idx]
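The per-deck timing added above maintains a running mean without storing samples: avg_time * (count / (count + 1)) + deck_time / (count + 1) is the standard incremental-mean update. A standalone toy check, not repo code:

def running_mean(avg, count, new_value):
    # after this call, the result is the mean of count + 1 samples
    return avg * (count / (count + 1)) + new_value / (count + 1)

avg, n = 0.0, 0
for t in [0.12, 0.18, 0.15]:
    avg = running_mean(avg, n, t)
    n += 1
print(round(avg, 4))  # 0.15, the mean of the three samples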
scripts/ppo.py  deleted 100644 → 0
import
os
import
shutil
import
queue
import
random
import
threading
import
time
from
datetime
import
datetime
,
timedelta
,
timezone
from
collections
import
deque
from
dataclasses
import
dataclass
,
field
from
types
import
SimpleNamespace
from
typing
import
List
,
NamedTuple
,
Optional
from
functools
import
partial
import
ygoenv
import
flax
import
jax
import
jax.numpy
as
jnp
import
numpy
as
np
import
optax
import
distrax
import
tyro
from
flax.training.train_state
import
TrainState
from
rich.pretty
import
pprint
from
tensorboardX
import
SummaryWriter
from
ygoai.utils
import
init_ygopro
,
load_embeddings
from
ygoai.rl.ckpt
import
ModelCheckpoint
,
sync_to_gcs
,
zip_files
from
ygoai.rl.jax.agent2
import
PPOLSTMAgent
from
ygoai.rl.jax.utils
import
RecordEpisodeStatistics
,
masked_normalize
,
categorical_sample
from
ygoai.rl.jax.eval
import
evaluate
,
battle
from
ygoai.rl.jax
import
clipped_surrogate_pg_loss
,
mse_loss
,
entropy_loss
,
simple_policy_loss
,
ach_loss
from
ygoai.rl.jax.switch
import
truncated_gae_2p0s
os
.
environ
[
"XLA_FLAGS"
]
=
"--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
@
dataclass
class
Args
:
exp_name
:
str
=
os
.
path
.
basename
(
__file__
)
.
rstrip
(
".py"
)
"""the name of this experiment"""
seed
:
int
=
1
"""seed of the experiment"""
log_frequency
:
int
=
10
"""the logging frequency of the model performance (in terms of `updates`)"""
save_interval
:
int
=
400
"""the frequency of saving the model (in terms of `updates`)"""
checkpoint
:
Optional
[
str
]
=
None
"""the path to the model checkpoint to load"""
debug
:
bool
=
False
"""whether to run the script in debug mode"""
tb_dir
:
str
=
"runs"
"""the directory to save the tensorboard logs"""
ckpt_dir
:
str
=
"checkpoints"
"""the directory to save the model checkpoints"""
gcs_bucket
:
Optional
[
str
]
=
None
"""the GCS bucket to save the model checkpoints"""
# Algorithm specific arguments
env_id
:
str
=
"YGOPro-v0"
"""the id of the environment"""
deck
:
str
=
"../assets/deck"
"""the deck file to use"""
deck1
:
Optional
[
str
]
=
None
"""the deck file for the first player"""
deck2
:
Optional
[
str
]
=
None
"""the deck file for the second player"""
code_list_file
:
str
=
"code_list.txt"
"""the code list file for card embeddings"""
embedding_file
:
Optional
[
str
]
=
None
"""the embedding file for card embeddings"""
max_options
:
int
=
24
"""the maximum number of options"""
n_history_actions
:
int
=
32
"""the number of history actions to use"""
greedy_reward
:
bool
=
False
"""whether to use greedy reward (faster kill higher reward)"""
total_timesteps
:
int
=
50000000000
"""total timesteps of the experiments"""
learning_rate
:
float
=
3e-4
"""the learning rate of the optimizer"""
local_num_envs
:
int
=
128
"""the number of parallel game environments"""
local_env_threads
:
Optional
[
int
]
=
None
"""the number of threads to use for environment"""
num_actor_threads
:
int
=
2
"""the number of actor threads to use"""
num_steps
:
int
=
128
"""the number of steps to run in each environment per policy rollout"""
anneal_lr
:
bool
=
False
"""Toggle learning rate annealing for policy and value networks"""
gamma
:
float
=
1.0
"""the discount factor gamma"""
num_minibatches
:
int
=
64
"""the number of mini-batches"""
update_epochs
:
int
=
2
"""the K epochs to update the policy"""
norm_adv
:
bool
=
False
"""Toggles advantages normalization"""
upgo
:
bool
=
True
"""Toggle the use of UPGO for advantages"""
gae_lambda
:
float
=
0.95
"""the lambda for the general advantage estimation"""
clip_coef
:
float
=
0.25
"""the surrogate clipping coefficient"""
dual_clip_coef
:
Optional
[
float
]
=
3.0
"""the dual surrogate clipping coefficient, typically 3.0"""
spo_kld_max
:
Optional
[
float
]
=
None
"""the maximum KLD for the SPO policy, typically 0.02"""
logits_threshold
:
Optional
[
float
]
=
None
"""the logits threshold for NeuRD and ACH, typically 2.0-6.0"""
ent_coef
:
float
=
0.01
"""coefficient of the entropy"""
vf_coef
:
float
=
1.0
"""coefficient of the value function"""
max_grad_norm
:
float
=
1.0
"""the maximum norm for the gradient clipping"""
num_layers
:
int
=
2
"""the number of layers for the agent"""
num_channels
:
int
=
128
"""the number of channels for the agent"""
rnn_channels
:
int
=
512
"""the number of channels for the RNN in the agent"""
actor_device_ids
:
List
[
int
]
=
field
(
default_factory
=
lambda
:
[
0
,
1
])
"""the device ids that actor workers will use"""
learner_device_ids
:
List
[
int
]
=
field
(
default_factory
=
lambda
:
[
2
,
3
])
"""the device ids that learner workers will use"""
distributed
:
bool
=
False
"""whether to use `jax.distirbuted`"""
concurrency
:
bool
=
True
"""whether to run the actor and learner concurrently"""
bfloat16
:
bool
=
False
"""whether to use bfloat16 for the agent"""
thread_affinity
:
bool
=
False
"""whether to use thread affinity for the environment"""
eval_checkpoint
:
Optional
[
str
]
=
None
"""the path to the model checkpoint to evaluate"""
local_eval_episodes
:
int
=
128
"""the number of episodes to evaluate the model"""
eval_interval
:
int
=
100
"""the number of iterations to evaluate the model"""
# runtime arguments to be filled in
local_batch_size
:
int
=
0
local_minibatch_size
:
int
=
0
world_size
:
int
=
0
local_rank
:
int
=
0
num_envs
:
int
=
0
batch_size
:
int
=
0
minibatch_size
:
int
=
0
num_updates
:
int
=
0
global_learner_decices
:
Optional
[
List
[
str
]]
=
None
actor_devices
:
Optional
[
List
[
str
]]
=
None
learner_devices
:
Optional
[
List
[
str
]]
=
None
num_embeddings
:
Optional
[
int
]
=
None
freeze_id
:
Optional
[
bool
]
=
None
def
make_env
(
args
,
seed
,
num_envs
,
num_threads
,
mode
=
'self'
,
thread_affinity_offset
=-
1
,
eval
=
False
):
if
not
args
.
thread_affinity
:
thread_affinity_offset
=
-
1
if
thread_affinity_offset
>=
0
:
print
(
"Binding to thread offset"
,
thread_affinity_offset
)
envs
=
ygoenv
.
make
(
task_id
=
args
.
env_id
,
env_type
=
"gymnasium"
,
num_envs
=
num_envs
,
num_threads
=
num_threads
,
thread_affinity_offset
=
thread_affinity_offset
,
seed
=
seed
,
deck1
=
args
.
deck1
,
deck2
=
args
.
deck2
,
max_options
=
args
.
max_options
,
n_history_actions
=
args
.
n_history_actions
,
async_reset
=
False
,
greedy_reward
=
args
.
greedy_reward
if
not
eval
else
True
,
play_mode
=
mode
,
)
envs
.
num_envs
=
num_envs
return
envs
class
Transition
(
NamedTuple
):
obs
:
list
dones
:
list
actions
:
list
logits
:
list
rewards
:
list
mains
:
list
next_dones
:
list
def
create_agent
(
args
,
multi_step
=
False
):
return
PPOLSTMAgent
(
channels
=
args
.
num_channels
,
num_layers
=
args
.
num_layers
,
embedding_shape
=
args
.
num_embeddings
,
dtype
=
jnp
.
bfloat16
if
args
.
bfloat16
else
jnp
.
float32
,
param_dtype
=
jnp
.
float32
,
lstm_channels
=
args
.
rnn_channels
,
switch
=
True
,
multi_step
=
multi_step
,
freeze_id
=
args
.
freeze_id
,
)
def
init_rnn_state
(
num_envs
,
rnn_channels
):
return
(
np
.
zeros
((
num_envs
,
rnn_channels
)),
np
.
zeros
((
num_envs
,
rnn_channels
)),
)
def
rollout
(
key
:
jax
.
random
.
PRNGKey
,
args
:
Args
,
rollout_queue
,
params_queue
,
writer
,
learner_devices
,
device_thread_id
,
):
eval_mode
=
'self'
if
args
.
eval_checkpoint
else
'bot'
if
eval_mode
!=
'bot'
:
eval_params
=
params_queue
.
get
()
local_seed
=
args
.
seed
+
device_thread_id
np
.
random
.
seed
(
local_seed
)
envs
=
make_env
(
args
,
local_seed
,
args
.
local_num_envs
,
args
.
local_env_threads
,
thread_affinity_offset
=
device_thread_id
*
args
.
local_env_threads
,
)
envs
=
RecordEpisodeStatistics
(
envs
)
eval_envs
=
make_env
(
args
,
local_seed
,
args
.
local_eval_episodes
,
args
.
local_eval_episodes
//
4
,
mode
=
eval_mode
,
eval
=
True
)
eval_envs
=
RecordEpisodeStatistics
(
eval_envs
)
len_actor_device_ids
=
len
(
args
.
actor_device_ids
)
n_actors
=
args
.
num_actor_threads
*
len_actor_device_ids
global_step
=
0
start_time
=
time
.
time
()
warmup_step
=
0
other_time
=
0
avg_ep_returns
=
deque
(
maxlen
=
1000
)
avg_win_rates
=
deque
(
maxlen
=
1000
)
@
jax
.
jit
def
get_logits
(
params
:
flax
.
core
.
FrozenDict
,
inputs
):
rstate
,
logits
=
create_agent
(
args
)
.
apply
(
params
,
inputs
)[:
2
]
return
rstate
,
logits
@
jax
.
jit
def
get_action
(
params
:
flax
.
core
.
FrozenDict
,
inputs
):
rstate
,
logits
=
get_logits
(
params
,
inputs
)
return
rstate
,
logits
.
argmax
(
axis
=
1
)
@
jax
.
jit
def
get_action_battle
(
params1
,
params2
,
rstate1
,
rstate2
,
obs
,
main
,
done
):
next_rstate1
,
logits1
=
get_logits
(
params1
,
(
rstate1
,
obs
))
next_rstate2
,
logits2
=
get_logits
(
params2
,
(
rstate2
,
obs
))
logits
=
jnp
.
where
(
main
[:,
None
],
logits1
,
logits2
)
rstate1
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x1
,
x2
),
next_rstate1
,
rstate1
)
rstate2
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x2
,
x1
),
next_rstate2
,
rstate2
)
rstate1
,
rstate2
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
where
(
done
[:,
None
],
0
,
x
),
(
rstate1
,
rstate2
))
return
rstate1
,
rstate2
,
logits
.
argmax
(
axis
=
1
)
@
jax
.
jit
def
sample_action
(
params
:
flax
.
core
.
FrozenDict
,
next_obs
,
rstate1
,
rstate2
,
main
,
done
,
key
):
next_obs
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
array
(
x
),
next_obs
)
done
=
jnp
.
array
(
done
)
main
=
jnp
.
array
(
main
)
rstate
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x1
,
x2
),
rstate1
,
rstate2
)
rstate
,
logits
=
get_logits
(
params
,
(
rstate
,
next_obs
))
rstate1
=
jax
.
tree
.
map
(
lambda
x
,
y
:
jnp
.
where
(
main
[:,
None
],
x
,
y
),
rstate
,
rstate1
)
rstate2
=
jax
.
tree
.
map
(
lambda
x
,
y
:
jnp
.
where
(
main
[:,
None
],
y
,
x
),
rstate
,
rstate2
)
rstate1
,
rstate2
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
where
(
done
[:,
None
],
0
,
x
),
(
rstate1
,
rstate2
))
action
,
key
=
categorical_sample
(
logits
,
key
)
return
next_obs
,
done
,
main
,
rstate1
,
rstate2
,
action
,
logits
,
key
# put data in the last index
params_queue_get_time
=
deque
(
maxlen
=
10
)
rollout_time
=
deque
(
maxlen
=
10
)
actor_policy_version
=
0
next_obs
,
info
=
envs
.
reset
()
next_to_play
=
info
[
"to_play"
]
next_done
=
np
.
zeros
(
args
.
local_num_envs
,
dtype
=
np
.
bool_
)
next_rstate1
=
next_rstate2
=
init_rnn_state
(
args
.
local_num_envs
,
args
.
rnn_channels
)
eval_rstate
=
init_rnn_state
(
args
.
local_eval_episodes
,
args
.
rnn_channels
)
main_player
=
np
.
concatenate
([
np
.
zeros
(
args
.
local_num_envs
//
2
,
dtype
=
np
.
int64
),
np
.
ones
(
args
.
local_num_envs
//
2
,
dtype
=
np
.
int64
)
])
np
.
random
.
shuffle
(
main_player
)
storage
=
[]
@
jax
.
jit
def
prepare_data
(
storage
:
List
[
Transition
])
->
Transition
:
return
jax
.
tree
.
map
(
lambda
*
xs
:
jnp
.
split
(
jnp
.
stack
(
xs
),
len
(
learner_devices
),
axis
=
1
),
*
storage
)
for
update
in
range
(
1
,
args
.
num_updates
+
2
):
if
update
==
10
:
start_time
=
time
.
time
()
warmup_step
=
global_step
update_time_start
=
time
.
time
()
inference_time
=
0
env_time
=
0
params_queue_get_time_start
=
time
.
time
()
if
args
.
concurrency
:
if
update
!=
2
:
params
=
params_queue
.
get
()
# params["params"]["Encoder_0"]['Embed_0'][
# "embedding"
# ].block_until_ready()
actor_policy_version
+=
1
else
:
params
=
params_queue
.
get
()
actor_policy_version
+=
1
params_queue_get_time
.
append
(
time
.
time
()
-
params_queue_get_time_start
)
rollout_time_start
=
time
.
time
()
init_rstate1
,
init_rstate2
=
jax
.
tree
.
map
(
lambda
x
:
x
.
copy
(),
(
next_rstate1
,
next_rstate2
))
for
_
in
range
(
args
.
num_steps
):
global_step
+=
args
.
local_num_envs
*
n_actors
*
args
.
world_size
cached_next_obs
=
next_obs
cached_next_done
=
next_done
main
=
next_to_play
==
main_player
inference_time_start
=
time
.
time
()
cached_next_obs
,
cached_next_done
,
cached_main
,
\
next_rstate1
,
next_rstate2
,
action
,
logits
,
key
=
sample_action
(
params
,
cached_next_obs
,
next_rstate1
,
next_rstate2
,
main
,
cached_next_done
,
key
)
cpu_action
=
np
.
array
(
action
)
inference_time
+=
time
.
time
()
-
inference_time_start
_start
=
time
.
time
()
next_obs
,
next_reward
,
next_done
,
info
=
envs
.
step
(
cpu_action
)
next_to_play
=
info
[
"to_play"
]
env_time
+=
time
.
time
()
-
_start
storage
.
append
(
Transition
(
obs
=
cached_next_obs
,
dones
=
cached_next_done
,
mains
=
cached_main
,
actions
=
action
,
logits
=
logits
,
rewards
=
next_reward
,
next_dones
=
next_done
,
)
)
for
idx
,
d
in
enumerate
(
next_done
):
if
not
d
:
continue
cur_main
=
main
[
idx
]
for
j
in
reversed
(
range
(
len
(
storage
)
-
1
)):
t
=
storage
[
j
]
if
t
.
next_dones
[
idx
]:
# For OTK where player may not switch
break
if
t
.
mains
[
idx
]
!=
cur_main
:
t
.
next_dones
[
idx
]
=
True
t
.
rewards
[
idx
]
=
-
next_reward
[
idx
]
break
episode_reward
=
info
[
'r'
][
idx
]
*
(
1
if
cur_main
else
-
1
)
win
=
1
if
episode_reward
>
0
else
0
avg_ep_returns
.
append
(
episode_reward
)
avg_win_rates
.
append
(
win
)
rollout_time
.
append
(
time
.
time
()
-
rollout_time_start
)
partitioned_storage
=
prepare_data
(
storage
)
storage
=
[]
sharded_storage
=
[]
for
x
in
partitioned_storage
:
if
isinstance
(
x
,
dict
):
x
=
{
k
:
jax
.
device_put_sharded
(
v
,
devices
=
learner_devices
)
for
k
,
v
in
x
.
items
()
}
else
:
x
=
jax
.
device_put_sharded
(
x
,
devices
=
learner_devices
)
sharded_storage
.
append
(
x
)
sharded_storage
=
Transition
(
*
sharded_storage
)
next_main
=
main_player
==
next_to_play
next_rstate
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
next_main
[:,
None
],
x1
,
x2
),
next_rstate1
,
next_rstate2
)
sharded_data
=
jax
.
tree
.
map
(
lambda
x
:
jax
.
device_put_sharded
(
np
.
split
(
x
,
len
(
learner_devices
)),
devices
=
learner_devices
),
(
init_rstate1
,
init_rstate2
,
(
next_rstate
,
next_obs
),
next_main
))
if
args
.
eval_interval
and
update
%
args
.
eval_interval
==
0
:
_start
=
time
.
time
()
if
eval_mode
==
'bot'
:
predict_fn
=
lambda
x
:
get_action
(
params
,
x
)
eval_return
,
eval_ep_len
,
eval_win_rate
=
evaluate
(
eval_envs
,
args
.
local_eval_episodes
,
predict_fn
,
eval_rstate
)
else
:
predict_fn
=
lambda
*
x
:
get_action_battle
(
params
,
eval_params
,
*
x
)
eval_return
,
eval_ep_len
,
eval_win_rate
=
battle
(
eval_envs
,
args
.
local_eval_episodes
,
predict_fn
,
eval_rstate
)
eval_time
=
time
.
time
()
-
_start
other_time
+=
eval_time
eval_stats
=
np
.
array
([
eval_time
,
eval_return
,
eval_win_rate
],
dtype
=
np
.
float32
)
else
:
eval_stats
=
None
learn_opponent
=
False
payload
=
(
global_step
,
update
,
sharded_storage
,
*
sharded_data
,
np
.
mean
(
params_queue_get_time
),
learn_opponent
,
eval_stats
,
)
rollout_queue
.
put
(
payload
)
if
update
%
args
.
log_frequency
==
0
:
avg_episodic_return
=
np
.
mean
(
avg_ep_returns
)
avg_episodic_length
=
np
.
mean
(
envs
.
returned_episode_lengths
)
SPS
=
int
((
global_step
-
warmup_step
)
/
(
time
.
time
()
-
start_time
-
other_time
))
SPS_update
=
int
(
args
.
batch_size
/
(
time
.
time
()
-
update_time_start
))
if
device_thread_id
==
0
:
print
(
f
"global_step={global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}"
)
time_now
=
datetime
.
now
(
timezone
(
timedelta
(
hours
=
8
)))
.
strftime
(
"
%
H:
%
M:
%
S"
)
print
(
f
"{time_now} SPS: {SPS}, update: {SPS_update}, "
f
"rollout_time={rollout_time[-1]:.2f}, params_time={params_queue_get_time[-1]:.2f}"
)
writer
.
add_scalar
(
"stats/rollout_time"
,
np
.
mean
(
rollout_time
),
global_step
)
writer
.
add_scalar
(
"charts/avg_episodic_return"
,
avg_episodic_return
,
global_step
)
writer
.
add_scalar
(
"charts/avg_episodic_length"
,
avg_episodic_length
,
global_step
)
writer
.
add_scalar
(
"stats/params_queue_get_time"
,
np
.
mean
(
params_queue_get_time
),
global_step
)
writer
.
add_scalar
(
"stats/inference_time"
,
inference_time
,
global_step
)
writer
.
add_scalar
(
"stats/env_time"
,
env_time
,
global_step
)
writer
.
add_scalar
(
"charts/SPS"
,
SPS
,
global_step
)
writer
.
add_scalar
(
"charts/SPS_update"
,
SPS_update
,
global_step
)
if
__name__
==
"__main__"
:
args
=
tyro
.
cli
(
Args
)
args
.
local_batch_size
=
int
(
args
.
local_num_envs
*
args
.
num_steps
*
args
.
num_actor_threads
*
len
(
args
.
actor_device_ids
))
args
.
local_minibatch_size
=
int
(
args
.
local_batch_size
//
args
.
num_minibatches
)
assert
(
args
.
local_num_envs
%
len
(
args
.
learner_device_ids
)
==
0
),
"local_num_envs must be divisible by len(learner_device_ids)"
assert
(
int
(
args
.
local_num_envs
/
len
(
args
.
learner_device_ids
))
*
args
.
num_actor_threads
%
args
.
num_minibatches
==
0
),
"int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches"
if
args
.
distributed
:
jax
.
distributed
.
initialize
(
local_device_ids
=
range
(
len
(
args
.
learner_device_ids
)
+
len
(
args
.
actor_device_ids
)),
)
print
(
list
(
range
(
len
(
args
.
learner_device_ids
)
+
len
(
args
.
actor_device_ids
))))
from
jax.experimental.compilation_cache
import
compilation_cache
as
cc
cc
.
set_cache_dir
(
os
.
path
.
expanduser
(
"~/.cache/jax"
))
args
.
world_size
=
jax
.
process_count
()
args
.
local_rank
=
jax
.
process_index
()
args
.
num_envs
=
args
.
local_num_envs
*
args
.
world_size
*
args
.
num_actor_threads
*
len
(
args
.
actor_device_ids
)
args
.
batch_size
=
args
.
local_batch_size
*
args
.
world_size
args
.
minibatch_size
=
args
.
local_minibatch_size
*
args
.
world_size
args
.
num_updates
=
args
.
total_timesteps
//
(
args
.
local_batch_size
*
args
.
world_size
)
args
.
local_env_threads
=
args
.
local_env_threads
or
args
.
local_num_envs
if
args
.
embedding_file
:
embeddings
=
load_embeddings
(
args
.
embedding_file
,
args
.
code_list_file
)
embedding_shape
=
embeddings
.
shape
args
.
num_embeddings
=
embedding_shape
args
.
freeze_id
=
True
if
args
.
freeze_id
is
None
else
args
.
freeze_id
else
:
embeddings
=
None
embedding_shape
=
None

    local_devices = jax.local_devices()
    global_devices = jax.devices()
    learner_devices = [local_devices[d_id] for d_id in args.learner_device_ids]
    actor_devices = [local_devices[d_id] for d_id in args.actor_device_ids]
    global_learner_decices = [
        global_devices[d_id + process_index * len(local_devices)]
        for process_index in range(args.world_size)
        for d_id in args.learner_device_ids
    ]
    global_main_devices = [
        global_devices[process_index * len(local_devices)]
        for process_index in range(args.world_size)
    ]
    print("global_learner_decices", global_learner_decices)
    args.global_learner_decices = [str(item) for item in global_learner_decices]
    args.actor_devices = [str(item) for item in actor_devices]
    args.learner_devices = [str(item) for item in learner_devices]
    pprint(args)

    timestamp = int(time.time())
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"

    dummy_writer = SimpleNamespace()
    dummy_writer.add_scalar = lambda x, y, z: None

    tb_log_dir = f"{args.tb_dir}/{run_name}"
    if args.local_rank == 0 and not args.debug:
        writer = SummaryWriter(tb_log_dir)
        writer.add_text(
            "hyperparameters",
            "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
        )
    else:
        writer = dummy_writer

    def save_fn(obj, path):
        with open(path, "wb") as f:
            f.write(flax.serialization.to_bytes(obj))

    ckpt_maneger = ModelCheckpoint(args.ckpt_dir, save_fn, n_saved=2)

    # seeding
    seed_offset = args.local_rank * 10000
    args.seed += seed_offset
    random.seed(args.seed)
    init_key = jax.random.PRNGKey(args.seed - seed_offset)
    key = jax.random.PRNGKey(args.seed)
    key, *learner_keys = jax.random.split(key, len(learner_devices) + 1)
    learner_keys = jax.device_put_sharded(learner_keys, devices=learner_devices)
    actor_keys = jax.random.split(key, len(actor_devices) * args.num_actor_threads)

    deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file)
    args.deck1 = args.deck1 or deck
    args.deck2 = args.deck2 or deck

    # env setup
    envs = make_env(args, args.seed, 8, 1)
    obs_space = envs.observation_space
    action_shape = envs.action_space.shape
    print(f"obs_space={obs_space}, action_shape={action_shape}")
    sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
    envs.close()
    del envs
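    # The short-lived env above is only used to read the observation/action spaces
    # and build one sample observation for parameter initialization; it is closed
    # right away, before the actor threads are started.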

    def linear_schedule(count):
        # anneal learning rate linearly after one training iteration which contains
        # (args.num_minibatches) gradient updates
        frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates
        return args.learning_rate * frac

    rstate = init_rnn_state(1, args.rnn_channels)
    agent = create_agent(args)
    params = agent.init(init_key, (rstate, sample_obs))
    if embeddings is not None:
        unknown_embed = embeddings.mean(axis=0)
        embeddings = np.concatenate([unknown_embed[None, :], embeddings], axis=0)
        params = flax.core.unfreeze(params)
        params['params']['Encoder_0']['Embed_0']['embedding'] = jax.device_put(embeddings)
        params = flax.core.freeze(params)
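
    # Optimizer: global-norm gradient clipping chained with Adam (learning rate
    # optionally annealed by linear_schedule); apply_if_finite skips updates whose
    # gradients are not finite, giving up after 10 consecutive bad steps.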
    tx = optax.MultiSteps(
        optax.chain(
            optax.clip_by_global_norm(args.max_grad_norm),
            optax.inject_hyperparams(optax.adam)(
                learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5),
        ),
        every_k_schedule=1,
    )
    tx = optax.apply_if_finite(tx, max_consecutive_errors=10)
    agent_state = TrainState.create(
        apply_fn=None,
        params=params,
        tx=tx,
    )

    if args.checkpoint:
        with open(args.checkpoint, "rb") as f:
            params = flax.serialization.from_bytes(params, f.read())
            agent_state = agent_state.replace(params=params)
        print(f"loaded checkpoint from {args.checkpoint}")

    agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices)
    # print(agent.tabulate(agent_key, sample_obs))

    if args.eval_checkpoint:
        with open(args.eval_checkpoint, "rb") as f:
            eval_params = flax.serialization.from_bytes(params, f.read())
        print(f"loaded eval checkpoint from {args.eval_checkpoint}")
    else:
        eval_params = None
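
    # get_logits_and_value re-runs the recurrent agent over whole stored segments
    # (multi_step=True), so only the initial RNN states of each segment need to be
    # carried over from rollout time.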
    @jax.jit
    def get_logits_and_value(
        params: flax.core.FrozenDict, inputs,
    ):
        rstate, logits, value, valid = create_agent(args, multi_step=True).apply(
            params, inputs)
        return logits, value.squeeze(-1)
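
    # loss_fn operates on one minibatch of flattened (num_steps * num_envs) steps:
    # mask drops finished steps, targets/advantages come from truncated_gae_2p0s on
    # the time-major view, and the auxiliaries are (pg_loss, v_loss, ent_loss, approx_kl).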
    def loss_fn(
        params, rstate1, rstate2, obs, dones, next_dones,
        switch, actions, logits, rewards, mask, next_value):
        # (num_steps * local_num_envs // n_mb))
        num_envs = next_value.shape[0]
        num_steps = dones.shape[0] // num_envs

        def reshape_time_series(x):
            return jnp.reshape(x, (num_steps, num_envs) + x.shape[1:])

        mask = mask * (1.0 - dones)
        n_valids = jnp.sum(mask)
        dones = dones | next_dones

        inputs = (rstate1, rstate2, obs, dones, switch)
        new_logits, new_values = get_logits_and_value(params, inputs)

        ratios = distrax.importance_sampling_ratios(
            distrax.Categorical(new_logits), distrax.Categorical(logits), actions)
        logratio = jnp.log(ratios)
        approx_kl = (((ratios - 1) - logratio) * mask).sum() / n_valids

        new_values_, rewards, next_dones, switch = jax.tree.map(
            reshape_time_series, (new_values, rewards, next_dones, switch),
        )

        target_values, advantages = truncated_gae_2p0s(
            next_value, new_values_, rewards, next_dones, switch,
            args.gamma, args.gae_lambda, args.upgo)
        target_values, advantages = jax.tree.map(
            lambda x: jnp.reshape(x, (-1,)), (target_values, advantages))

        if args.norm_adv:
            advantages = masked_normalize(advantages, mask, eps=1e-8)

        # Policy loss
        if args.spo_kld_max is not None:
            pg_loss = simple_policy_loss(
                ratios, logits, new_logits, advantages, args.spo_kld_max)
        elif args.logits_threshold is not None:
            pg_loss = ach_loss(
                actions, logits, new_logits, advantages,
                args.logits_threshold, args.clip_coef, args.dual_clip_coef)
        else:
            pg_loss = clipped_surrogate_pg_loss(
                ratios, advantages, args.clip_coef, args.dual_clip_coef)
        pg_loss = jnp.sum(pg_loss * mask)

        v_loss = mse_loss(new_values, target_values)
        v_loss = jnp.sum(v_loss * mask)

        ent_loss = entropy_loss(new_logits)
        ent_loss = jnp.sum(ent_loss * mask)

        pg_loss = pg_loss / n_valids
        v_loss = v_loss / n_valids
        ent_loss = ent_loss / n_valids

        loss = pg_loss - args.ent_coef * ent_loss + v_loss * args.vf_coef
        return loss, (pg_loss, v_loss, ent_loss, jax.lax.stop_gradient(approx_kl))
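
    # single_device_update merges the shards from all actor threads, reorders each
    # env's steps so the main player's transitions precede the opponent's, then runs
    # update_epochs x num_minibatches updates with gradients averaged across learner
    # devices via pmean.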
    def single_device_update(
        agent_state: TrainState,
        sharded_storages: List,
        sharded_init_rstate1: List,
        sharded_init_rstate2: List,
        sharded_next_inputs: List,
        sharded_next_main: List,
        key: jax.random.PRNGKey,
        learn_opponent: bool = False,
    ):
        storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
        # TODO: rstate will be out-of-date after the first update, maybe consider R2D2
        next_inputs, init_rstate1, init_rstate2 = [
            jax.tree.map(lambda *x: jnp.concatenate(x), *x)
            for x in [sharded_next_inputs, sharded_init_rstate1, sharded_init_rstate2]
        ]
        next_main = jnp.concatenate(sharded_next_main)

        # reorder storage of individual players
        # main first, opponent second
        num_steps, num_envs = storage.rewards.shape
        T = jnp.arange(num_steps, dtype=jnp.int32)
        B = jnp.arange(num_envs, dtype=jnp.int32)
        mains = storage.mains.astype(jnp.int32)
        indices = jnp.argsort(T[:, None] - mains * num_steps, axis=0)
        switch_steps = jnp.sum(mains, axis=0)
        switch = T[:, None] == (switch_steps[None, :] - 1)
        storage = jax.tree.map(lambda x: x[indices, B[None, :]], storage)

        loss_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

        def update_epoch(carry, _):
            agent_state, key = carry
            key, subkey = jax.random.split(key)

            next_value = create_agent(args).apply(
                agent_state.params, next_inputs)[2].squeeze(-1)
            next_value = jnp.where(next_main, -next_value, next_value)

            def convert_data(x: jnp.ndarray, num_steps):
                if args.update_epochs > 1:
                    x = jax.random.permutation(subkey, x, axis=1 if num_steps > 1 else 0)
                N = args.num_minibatches
                if num_steps > 1:
                    x = jnp.reshape(x, (num_steps, N, -1) + x.shape[2:])
                    x = x.transpose(1, 0, *range(2, x.ndim))
                    x = x.reshape(N, -1, *x.shape[3:])
                else:
                    x = jnp.reshape(x, (N, -1) + x.shape[1:])
                return x

            shuffled_init_rstate1, shuffled_init_rstate2, \
                shuffled_next_value = jax.tree.map(
                    partial(convert_data, num_steps=1),
                    (init_rstate1, init_rstate2, next_value))
            shuffled_storage, shuffled_switch = jax.tree.map(
                partial(convert_data, num_steps=num_steps), (storage, switch))
            shuffled_mask = jnp.ones_like(shuffled_storage.mains)

            def update_minibatch(agent_state, minibatch):
                (loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = loss_grad_fn(
                    agent_state.params, *minibatch)
                grads = jax.lax.pmean(grads, axis_name="local_devices")
                agent_state = agent_state.apply_gradients(grads=grads)
                return agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl)

            agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl) = jax.lax.scan(
                update_minibatch,
                agent_state,
                (
                    shuffled_init_rstate1,
                    shuffled_init_rstate2,
                    shuffled_storage.obs,
                    shuffled_storage.dones,
                    shuffled_storage.next_dones,
                    shuffled_switch,
                    shuffled_storage.actions,
                    shuffled_storage.logits,
                    shuffled_storage.rewards,
                    shuffled_mask,
                    shuffled_next_value,
                ),
            )
            return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl)

        (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl) = jax.lax.scan(
            update_epoch, (agent_state, key), (), length=args.update_epochs)

        loss = jax.lax.pmean(loss, axis_name="local_devices").mean()
        pg_loss = jax.lax.pmean(pg_loss, axis_name="local_devices").mean()
        v_loss = jax.lax.pmean(v_loss, axis_name="local_devices").mean()
        entropy_loss = jax.lax.pmean(entropy_loss, axis_name="local_devices").mean()
        approx_kl = jax.lax.pmean(approx_kl, axis_name="local_devices").mean()
        return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key
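
    # all_reduce_value averages values (e.g. eval stats) across the main device of
    # every process; multi_device_update pmaps the learner step over all learner
    # devices, treating learn_opponent (positional argument 7) as static.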
    all_reduce_value = jax.pmap(
        lambda x: jax.lax.pmean(x, axis_name="main_devices"),
        axis_name="main_devices",
        devices=global_main_devices,
    )

    multi_device_update = jax.pmap(
        single_device_update,
        axis_name="local_devices",
        devices=global_learner_decices,
        static_broadcasted_argnums=(7,),
    )
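
    # One rollout thread per (actor device, actor thread) pair. Each thread gets a
    # pair of size-1 queues: parameters flow learner -> actor, trajectories flow
    # actor -> learner, so actors never run more than one policy version ahead.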
    params_queues = []
    rollout_queues = []

    unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
    for d_idx, d_id in enumerate(args.actor_device_ids):
        device_params = jax.device_put(unreplicated_params, local_devices[d_id])
        for thread_id in range(args.num_actor_threads):
            params_queues.append(queue.Queue(maxsize=1))
            rollout_queues.append(queue.Queue(maxsize=1))
            if eval_params:
                params_queues[-1].put(
                    jax.device_put(eval_params, local_devices[d_id]))
            actor_thread_id = d_idx * args.num_actor_threads + thread_id
            threading.Thread(
                target=rollout,
                args=(
                    jax.device_put(actor_keys[actor_thread_id], local_devices[d_id]),
                    args,
                    rollout_queues[-1],
                    params_queues[-1],
                    writer if d_idx == 0 and thread_id == 0 else dummy_writer,
                    learner_devices,
                    actor_thread_id,
                ),
            ).start()
            params_queues[-1].put(device_params)

    rollout_queue_get_time = deque(maxlen=10)
    data_transfer_time = deque(maxlen=10)
    learner_policy_version = 0
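
    # Learner loop: block until every actor thread has delivered a rollout, run one
    # multi-device update, then push the fresh parameters back to all actor queues.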
    while True:
        learner_policy_version += 1
        rollout_queue_get_time_start = time.time()
        sharded_data_list = []
        eval_stat_list = []
        for d_idx, d_id in enumerate(args.actor_device_ids):
            for thread_id in range(args.num_actor_threads):
                (
                    global_step,
                    update,
                    *sharded_data,
                    avg_params_queue_get_time,
                    learn_opponent,
                    eval_stats,
                ) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get()
                sharded_data_list.append(sharded_data)
                if eval_stats is not None:
                    eval_stat_list.append(eval_stats)

        if update % args.eval_interval == 0:
            eval_stats = np.mean(eval_stat_list, axis=0)
            eval_stats = jax.device_put(eval_stats, local_devices[0])
            eval_stats = np.array(all_reduce_value(eval_stats[None])[0])
            eval_time, eval_return, eval_win_rate = eval_stats
            writer.add_scalar("charts/eval_return", eval_return, global_step)
            writer.add_scalar("charts/eval_win_rate", eval_win_rate, global_step)
            print(f"eval_time={eval_time:.4f}, eval_return={eval_return:.4f}, eval_win_rate={eval_win_rate:.4f}")

        rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
        training_time_start = time.time()
        (agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, learner_keys) = multi_device_update(
            agent_state,
            *list(zip(*sharded_data_list)),
            learner_keys,
            learn_opponent,
        )
        unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
        for d_idx, d_id in enumerate(args.actor_device_ids):
            device_params = jax.device_put(unreplicated_params, local_devices[d_id])
            device_params["params"]["Encoder_0"]["Embed_0"]["embedding"].block_until_ready()
            for thread_id in range(args.num_actor_threads):
                params_queues[d_idx * args.num_actor_threads + thread_id].put(device_params)

        loss = loss[-1].item()
        if np.isnan(loss) or np.isinf(loss):
            raise ValueError(f"loss is {loss}")

        # record rewards for plotting purposes
        if learner_policy_version % args.log_frequency == 0:
            writer.add_scalar("stats/rollout_queue_get_time", np.mean(rollout_queue_get_time), global_step)
            writer.add_scalar(
                "stats/rollout_params_queue_get_time_diff",
                np.mean(rollout_queue_get_time) - avg_params_queue_get_time,
                global_step,
            )
            writer.add_scalar("stats/training_time", time.time() - training_time_start, global_step)
            writer.add_scalar("stats/rollout_queue_size", rollout_queues[-1].qsize(), global_step)
            writer.add_scalar("stats/params_queue_size", params_queues[-1].qsize(), global_step)
            print(
                f"{global_step} actor_update={update}, "
                f"train_time={time.time() - training_time_start:.2f}, "
                f"data_time={rollout_queue_get_time[-1]:.2f}"
            )
            writer.add_scalar(
                "charts/learning_rate",
                agent_state.opt_state[3][2][1].hyperparams["learning_rate"][-1].item(),
                global_step,
            )
            writer.add_scalar("losses/value_loss", v_loss[-1].item(), global_step)
            writer.add_scalar("losses/policy_loss", pg_loss[-1].item(), global_step)
            writer.add_scalar("losses/entropy", entropy_loss[-1].item(), global_step)
            writer.add_scalar("losses/approx_kl", approx_kl[-1].item(), global_step)
            writer.add_scalar("losses/loss", loss, global_step)

        if args.local_rank == 0 and learner_policy_version % args.save_interval == 0 and not args.debug:
            M_steps = args.batch_size * learner_policy_version // 2**20
            ckpt_name = f"{timestamp}_{M_steps}M.flax_model"
            ckpt_maneger.save(unreplicated_params, ckpt_name)
            if args.gcs_bucket is not None:
                lastest_path = ckpt_maneger.get_latest()
                copy_path = lastest_path.with_name("latest" + lastest_path.suffix)
                shutil.copyfile(lastest_path, copy_path)
                zip_file_path = "latest.zip"
                zip_files(zip_file_path, [str(copy_path), tb_log_dir])
                sync_to_gcs(args.gcs_bucket, zip_file_path)

        if learner_policy_version >= args.num_updates:
            break

    if args.distributed:
        jax.distributed.shutdown()

    writer.close()