Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
bbd36d86
Commit
bbd36d86
authored
Apr 18, 2024
by
sbl1996@126.com
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
replace torch with jax as default
parent
43ca871e
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
1482 additions
and
1845 deletions
+1482
-1845
scripts/battle.py
scripts/battle.py
+126
-137
scripts/eval.py
scripts/eval.py
+44
-113
scripts/impala.py
scripts/impala.py
+0
-0
scripts/jax/battle.py
scripts/jax/battle.py
+0
-283
scripts/jax/ppo.py
scripts/jax/ppo.py
+0
-868
scripts/ppo.py
scripts/ppo.py
+760
-444
scripts/torch/ppo.py
scripts/torch/ppo.py
+552
-0
scripts/torch/ppo_c.py
scripts/torch/ppo_c.py
+0
-0
scripts/torch/ppo_osfp.py
scripts/torch/ppo_osfp.py
+0
-0
scripts/torch/ppo_xla.py
scripts/torch/ppo_xla.py
+0
-0
No files found.
scripts/battle.py
View file @
bbd36d86
...
@@ -4,16 +4,20 @@ import os
...
@@ -4,16 +4,20 @@ import os
import
random
import
random
from
typing
import
Optional
,
Literal
from
typing
import
Optional
,
Literal
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
tqdm
import
tqdm
import
ygoenv
import
ygoenv
import
numpy
as
np
import
numpy
as
np
import
optree
import
tyro
import
tyro
import
jax
import
jax.numpy
as
jnp
import
flax
from
ygoai.utils
import
init_ygopro
from
ygoai.utils
import
init_ygopro
from
ygoai.rl.utils
import
RecordEpisodeStatistics
from
ygoai.rl.utils
import
RecordEpisodeStatistics
from
ygoai.rl.jax.agent2
import
PPOLSTMAgent
@
dataclass
@
dataclass
...
@@ -54,6 +58,8 @@ class Args:
...
@@ -54,6 +58,8 @@ class Args:
"""the number of layers for the agent"""
"""the number of layers for the agent"""
num_channels
:
int
=
128
num_channels
:
int
=
128
"""the number of channels for the agent"""
"""the number of channels for the agent"""
rnn_channels
:
Optional
[
int
]
=
512
"""the number of rnn channels for the agent"""
checkpoint1
:
str
=
"checkpoints/agent.pt"
checkpoint1
:
str
=
"checkpoints/agent.pt"
"""the checkpoint to load for the first agent, `pt` or `flax_model` file"""
"""the checkpoint to load for the first agent, `pt` or `flax_model` file"""
checkpoint2
:
str
=
"checkpoints/agent.pt"
checkpoint2
:
str
=
"checkpoints/agent.pt"
...
@@ -63,25 +69,30 @@ class Args:
...
@@ -63,25 +69,30 @@ class Args:
xla_device
:
Optional
[
str
]
=
None
xla_device
:
Optional
[
str
]
=
None
"""the XLA device to use, defaults to `None`"""
"""the XLA device to use, defaults to `None`"""
# PyTorch specific
torch_deterministic
:
bool
=
True
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda
:
bool
=
True
"""if toggled, cuda will be enabled by default"""
compile
:
bool
=
False
"""if toggled, the model will be compiled"""
optimize
:
bool
=
False
"""if toggled, the model will be optimized"""
torch_threads
:
Optional
[
int
]
=
None
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
env_threads
:
Optional
[
int
]
=
16
env_threads
:
Optional
[
int
]
=
16
"""the number of threads to use for envpool, defaults to `num_envs`"""
"""the number of threads to use for envpool, defaults to `num_envs`"""
framework
:
Optional
[
Literal
[
"torch"
,
"jax"
]]
=
None
def
create_agent
(
args
):
return
PPOLSTMAgent
(
channels
=
args
.
num_channels
,
num_layers
=
args
.
num_layers
,
lstm_channels
=
args
.
rnn_channels
,
embedding_shape
=
args
.
num_embeddings
,
)
def
init_rnn_state
(
num_envs
,
rnn_channels
):
return
(
np
.
zeros
((
num_envs
,
rnn_channels
)),
np
.
zeros
((
num_envs
,
rnn_channels
)),
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
from
jax.experimental.compilation_cache
import
compilation_cache
as
cc
cc
.
set_cache_dir
(
os
.
path
.
expanduser
(
"~/.cache/jax"
))
args
=
tyro
.
cli
(
Args
)
args
=
tyro
.
cli
(
Args
)
if
args
.
record
:
if
args
.
record
:
...
@@ -101,18 +112,6 @@ if __name__ == "__main__":
...
@@ -101,18 +112,6 @@ if __name__ == "__main__":
random
.
seed
(
seed
)
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
if
args
.
framework
is
None
:
args
.
framework
=
"jax"
if
"flax_model"
in
args
.
checkpoint1
else
"torch"
if
args
.
framework
==
"torch"
:
import
torch
torch
.
manual_seed
(
args
.
seed
)
torch
.
backends
.
cudnn
.
deterministic
=
args
.
torch_deterministic
args
.
torch_threads
=
args
.
torch_threads
or
int
(
os
.
getenv
(
"OMP_NUM_THREADS"
,
"4"
))
torch
.
set_num_threads
(
args
.
torch_threads
)
torch
.
set_float32_matmul_precision
(
'high'
)
else
:
if
args
.
xla_device
is
not
None
:
if
args
.
xla_device
is
not
None
:
os
.
environ
.
setdefault
(
"JAX_PLATFORMS"
,
args
.
xla_device
)
os
.
environ
.
setdefault
(
"JAX_PLATFORMS"
,
args
.
xla_device
)
...
@@ -138,78 +137,14 @@ if __name__ == "__main__":
...
@@ -138,78 +137,14 @@ if __name__ == "__main__":
envs
.
num_envs
=
num_envs
envs
.
num_envs
=
num_envs
envs
=
RecordEpisodeStatistics
(
envs
)
envs
=
RecordEpisodeStatistics
(
envs
)
if
args
.
framework
==
'torch'
:
from
ygoai.rl.agent
import
PPOAgent
as
Agent
from
ygoai.rl.buffer
import
create_obs
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
and
args
.
cuda
else
"cpu"
)
if
args
.
checkpoint1
.
endswith
(
".ptj"
):
agent1
=
torch
.
jit
.
load
(
args
.
checkpoint1
)
agent2
=
torch
.
jit
.
load
(
args
.
checkpoint2
)
else
:
# count lines of code_list
embedding_shape
=
args
.
num_embeddings
if
embedding_shape
is
None
:
with
open
(
args
.
code_list_file
,
"r"
)
as
f
:
code_list
=
f
.
readlines
()
embedding_shape
=
len
(
code_list
)
L
=
args
.
num_layers
agent1
=
Agent
(
args
.
num_channels
,
L
,
L
,
embedding_shape
)
.
to
(
device
)
agent2
=
Agent
(
args
.
num_channels
,
L
,
L
,
embedding_shape
)
.
to
(
device
)
for
agent
,
ckpt
in
zip
([
agent1
,
agent2
],
[
args
.
checkpoint1
,
args
.
checkpoint2
]):
state_dict
=
torch
.
load
(
ckpt
,
map_location
=
device
)
if
not
args
.
compile
:
prefix
=
"_orig_mod."
state_dict
=
{
k
[
len
(
prefix
):]
if
k
.
startswith
(
prefix
)
else
k
:
v
for
k
,
v
in
state_dict
.
items
()}
print
(
agent
.
load_state_dict
(
state_dict
))
def
get_probs
(
agent
,
obs
):
with
torch
.
no_grad
():
return
torch
.
softmax
(
agent
(
obs
)[
0
],
dim
=-
1
)
if
args
.
compile
:
get_probs
=
torch
.
compile
(
get_probs
,
mode
=
'reduce-overhead'
)
elif
args
.
optimize
:
obs
=
create_obs
(
envs
.
observation_space
,
(
num_envs
,),
device
=
device
)
def
optimize_for_inference
(
agent
):
with
torch
.
no_grad
():
traced_model
=
torch
.
jit
.
trace
(
agent
,
(
obs
,),
check_tolerance
=
False
,
check_trace
=
False
)
return
torch
.
jit
.
optimize_for_inference
(
traced_model
)
agent1
=
optimize_for_inference
(
agent1
)
agent2
=
optimize_for_inference
(
agent2
)
def
predict_fn
(
obs
,
main
):
obs
=
optree
.
tree_map
(
lambda
x
:
torch
.
from_numpy
(
x
)
.
to
(
device
=
device
),
obs
)
if
num_envs
!=
1
:
probs1
=
get_probs
(
agent
,
obs
)
probs2
=
get_probs
(
agent
,
obs
)
probs
=
torch
.
where
(
main
[:,
None
],
probs1
,
probs2
)
else
:
if
main
[
0
]:
probs
=
get_probs
(
agent1
,
obs
)
else
:
probs
=
get_probs
(
agent2
,
obs
)
probs
=
probs
.
cpu
()
.
numpy
()
return
probs
else
:
import
jax
import
jax.numpy
as
jnp
import
flax
from
ygoai.rl.jax.agent2
import
PPOAgent
def
create_agent
(
args
):
return
PPOAgent
(
channels
=
128
,
num_layers
=
2
,
embedding_shape
=
args
.
num_embeddings
,
)
agent
=
create_agent
(
args
)
agent
=
create_agent
(
args
)
key
=
jax
.
random
.
PRNGKey
(
args
.
seed
)
key
=
jax
.
random
.
PRNGKey
(
args
.
seed
)
key
,
agent_key
=
jax
.
random
.
split
(
key
,
2
)
key
,
agent_key
=
jax
.
random
.
split
(
key
,
2
)
sample_obs
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
array
([
x
]),
obs_space
.
sample
())
sample_obs
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
array
([
x
]),
obs_space
.
sample
())
params
=
agent
.
init
(
agent_key
,
sample_obs
)
print
(
jax
.
tree
.
leaves
(
params
)[
0
]
.
devices
())
rstate
=
init_rnn_state
(
1
,
args
.
rnn_channels
)
params
=
jax
.
jit
(
agent
.
init
)(
agent_key
,
(
rstate
,
sample_obs
))
with
open
(
args
.
checkpoint1
,
"rb"
)
as
f
:
with
open
(
args
.
checkpoint1
,
"rb"
)
as
f
:
params1
=
flax
.
serialization
.
from_bytes
(
params
,
f
.
read
())
params1
=
flax
.
serialization
.
from_bytes
(
params
,
f
.
read
())
if
args
.
checkpoint1
==
args
.
checkpoint2
:
if
args
.
checkpoint1
==
args
.
checkpoint2
:
...
@@ -218,38 +153,51 @@ if __name__ == "__main__":
...
@@ -218,38 +153,51 @@ if __name__ == "__main__":
with
open
(
args
.
checkpoint2
,
"rb"
)
as
f
:
with
open
(
args
.
checkpoint2
,
"rb"
)
as
f
:
params2
=
flax
.
serialization
.
from_bytes
(
params
,
f
.
read
())
params2
=
flax
.
serialization
.
from_bytes
(
params
,
f
.
read
())
params1
=
jax
.
device_put
(
params1
)
params2
=
jax
.
device_put
(
params2
)
@
jax
.
jit
@
jax
.
jit
def
get_probs
(
params
,
obs
):
def
get_probs
(
params
,
rstate
,
obs
,
done
):
agent
=
create_agent
(
args
)
agent
=
create_agent
(
args
)
logits
=
agent
.
apply
(
params
,
obs
)[
0
]
next_rstate
,
logits
=
agent
.
apply
(
params
,
(
rstate
,
obs
))[:
2
]
return
jax
.
nn
.
softmax
(
logits
,
axis
=-
1
)
probs
=
jax
.
nn
.
softmax
(
logits
,
axis
=-
1
)
next_rstate
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
where
(
done
[:,
None
],
0
,
x
),
next_rstate
)
return
next_rstate
,
probs
if
args
.
num_envs
!=
1
:
if
args
.
num_envs
!=
1
:
@
jax
.
jit
@
jax
.
jit
def
get_probs2
(
params1
,
params2
,
obs
,
main
):
def
get_probs2
(
params1
,
params2
,
rstate1
,
rstate2
,
obs
,
main
,
done
):
probs1
=
get_probs
(
params1
,
obs
)
next_rstate1
,
probs1
=
get_probs
(
params1
,
rstate1
,
obs
,
done
)
probs2
=
get_probs
(
params2
,
obs
)
next_rstate2
,
probs2
=
get_probs
(
params2
,
rstate2
,
obs
,
done
)
probs
=
jnp
.
where
(
main
[:,
None
],
probs1
,
probs2
)
probs
=
jnp
.
where
(
main
[:,
None
],
probs1
,
probs2
)
return
probs
rstate1
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x1
,
x2
),
next_rstate1
,
rstate1
)
def
predict_fn
(
obs
,
main
):
rstate2
=
jax
.
tree
.
map
(
probs
=
get_probs2
(
params1
,
params2
,
obs
,
main
)
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x2
,
x1
),
next_rstate2
,
rstate2
)
return
np
.
array
(
probs
)
return
rstate1
,
rstate2
,
probs
def
predict_fn
(
rstate1
,
rstate2
,
obs
,
main
,
done
):
rstate1
,
rstate2
,
probs
=
get_probs2
(
params1
,
params2
,
rstate1
,
rstate2
,
obs
,
main
,
done
)
return
rstate1
,
rstate2
,
np
.
array
(
probs
)
else
:
else
:
def
predict_fn
(
obs
,
main
):
def
predict_fn
(
rstate1
,
rstate2
,
obs
,
main
,
done
):
if
main
[
0
]:
if
main
[
0
]:
probs
=
get_probs
(
params1
,
obs
)
rstate1
,
probs
=
get_probs
(
params1
,
rstate1
,
obs
,
done
)
else
:
else
:
probs
=
get_probs
(
params2
,
obs
)
rstate2
,
probs
=
get_probs
(
params2
,
rstate2
,
obs
,
done
)
return
np
.
array
(
probs
)
return
rstate1
,
rstate2
,
np
.
array
(
probs
)
obs
,
infos
=
envs
.
reset
()
obs
,
infos
=
envs
.
reset
()
next_to_play
=
infos
[
'to_play'
]
next_to_play
=
infos
[
'to_play'
]
dones
=
np
.
zeros
(
num_envs
,
dtype
=
np
.
bool_
)
episode_rewards
=
[]
episode_rewards
=
[]
episode_lengths
=
[]
episode_lengths
=
[]
win_rates
=
[]
win_rates
=
[]
win_reasons
=
[]
win_reasons
=
[]
win_players
=
[]
win_agents
=
[]
step
=
0
step
=
0
start
=
time
.
time
()
start
=
time
.
time
()
...
@@ -259,6 +207,10 @@ if __name__ == "__main__":
...
@@ -259,6 +207,10 @@ if __name__ == "__main__":
np
.
zeros
(
num_envs
//
2
,
dtype
=
np
.
int64
),
np
.
zeros
(
num_envs
//
2
,
dtype
=
np
.
int64
),
np
.
ones
(
num_envs
-
num_envs
//
2
,
dtype
=
np
.
int64
)
np
.
ones
(
num_envs
-
num_envs
//
2
,
dtype
=
np
.
int64
)
])
])
rstate1
=
rstate2
=
init_rnn_state
(
num_envs
,
args
.
rnn_channels
)
if
not
args
.
verbose
:
pbar
=
tqdm
(
total
=
args
.
num_episodes
)
model_time
=
env_time
=
0
model_time
=
env_time
=
0
while
True
:
while
True
:
...
@@ -268,9 +220,8 @@ if __name__ == "__main__":
...
@@ -268,9 +220,8 @@ if __name__ == "__main__":
model_time
=
env_time
=
0
model_time
=
env_time
=
0
_start
=
time
.
time
()
_start
=
time
.
time
()
main
=
next_to_play
==
main_player
main
=
next_to_play
==
main_player
probs
=
predict_fn
(
obs
,
main
)
rstate1
,
rstate2
,
probs
=
predict_fn
(
rstate1
,
rstate2
,
obs
,
main
,
dones
)
actions
=
probs
.
argmax
(
axis
=
1
)
actions
=
probs
.
argmax
(
axis
=
1
)
model_time
+=
time
.
time
()
-
_start
model_time
+=
time
.
time
()
-
_start
...
@@ -289,14 +240,24 @@ if __name__ == "__main__":
...
@@ -289,14 +240,24 @@ if __name__ == "__main__":
win_reason
=
infos
[
'win_reason'
][
idx
]
win_reason
=
infos
[
'win_reason'
][
idx
]
pl
=
1
if
to_play
[
idx
]
==
main_player
[
idx
]
else
-
1
pl
=
1
if
to_play
[
idx
]
==
main_player
[
idx
]
else
-
1
episode_length
=
infos
[
'l'
][
idx
]
episode_length
=
infos
[
'l'
][
idx
]
episode_reward
=
infos
[
'r'
][
idx
]
*
pl
episode_reward
=
infos
[
'r'
][
idx
]
win
=
int
(
episode_reward
>
0
)
main_reward
=
episode_reward
*
pl
win
=
int
(
main_reward
>
0
)
win_player
=
0
if
(
to_play
[
idx
]
==
0
and
episode_reward
>
0
)
or
(
to_play
[
idx
]
==
1
and
episode_reward
<
0
)
else
1
win_players
.
append
(
win_player
)
win_agent
=
1
if
main_reward
>
0
else
2
win_agents
.
append
(
win_agent
)
episode_lengths
.
append
(
episode_length
)
episode_lengths
.
append
(
episode_length
)
episode_rewards
.
append
(
episode
_reward
)
episode_rewards
.
append
(
main
_reward
)
win_rates
.
append
(
win
)
win_rates
.
append
(
win
)
win_reasons
.
append
(
1
if
win_reason
==
1
else
0
)
win_reasons
.
append
(
1
if
win_reason
==
1
else
0
)
sys
.
stderr
.
write
(
f
"Episode {len(episode_lengths)}: length={episode_length}, reward={episode_reward}, win={win}, win_reason={win_reason}
\n
"
)
if
args
.
verbose
:
print
(
f
"Episode {len(episode_lengths)}: length={episode_length}, reward={main_reward}, win={win}, win_reason={win_reason}
\n
"
)
else
:
pbar
.
set_postfix
(
len
=
np
.
mean
(
episode_lengths
),
reward
=
np
.
mean
(
episode_rewards
),
win_rate
=
np
.
mean
(
win_rates
))
pbar
.
update
(
1
)
# Only when num_envs=1, we switch the player here
# Only when num_envs=1, we switch the player here
if
args
.
verbose
:
if
args
.
verbose
:
...
@@ -305,10 +266,38 @@ if __name__ == "__main__":
...
@@ -305,10 +266,38 @@ if __name__ == "__main__":
if
len
(
episode_lengths
)
>=
args
.
num_episodes
:
if
len
(
episode_lengths
)
>=
args
.
num_episodes
:
break
break
if
not
args
.
verbose
:
pbar
.
close
()
print
(
f
"len={np.mean(episode_lengths)}, reward={np.mean(episode_rewards)}, win_rate={np.mean(win_rates)}, win_reason={np.mean(win_reasons)}"
)
print
(
f
"len={np.mean(episode_lengths)}, reward={np.mean(episode_rewards)}, win_rate={np.mean(win_rates)}, win_reason={np.mean(win_reasons)}"
)
win_players
=
np
.
array
(
win_players
)
win_agents
=
np
.
array
(
win_agents
)
N
=
len
(
win_players
)
N1
=
np
.
sum
((
win_players
==
0
)
&
(
win_agents
==
1
))
N2
=
np
.
sum
((
win_players
==
0
)
&
(
win_agents
==
2
))
N3
=
np
.
sum
((
win_players
==
1
)
&
(
win_agents
==
1
))
N4
=
np
.
sum
((
win_players
==
1
)
&
(
win_agents
==
2
))
print
(
f
"Payoff matrix:"
)
w1
=
N1
/
N
w2
=
N2
/
N
w3
=
N3
/
N
w4
=
N4
/
N
print
(
f
" agent1 agent2"
)
print
(
f
"0 {w1:.4f} {w2:.4f}"
)
print
(
f
"1 {w3:.4f} {w4:.4f}"
)
print
(
f
"0/1 matrix, win rates of agentX as playerY"
)
w1
=
N1
/
(
N1
+
N4
)
w2
=
N2
/
(
N2
+
N3
)
w3
=
1
-
w2
w4
=
1
-
w1
print
(
f
" agent1 agent2"
)
print
(
f
"0 {w1:.4f} {w2:.4f}"
)
print
(
f
"1 {w3:.4f} {w4:.4f}"
)
total_time
=
time
.
time
()
-
start
total_time
=
time
.
time
()
-
start
total_steps
=
(
step
-
start_step
)
*
num_envs
total_steps
=
(
step
-
start_step
)
*
num_envs
print
(
f
"SPS: {total_steps / total_time:.0f}, total_steps: {total_steps}"
)
print
(
f
"SPS: {total_steps / total_time:.0f}, total_steps: {total_steps}"
)
print
(
f
"total: {total_time:.4f}, model: {model_time:.4f}, env: {env_time:.4f}"
)
print
(
f
"total: {total_time:.4f}, model: {model_time:.4f}, env: {env_time:.4f}"
)
\ No newline at end of file
scripts/eval.py
View file @
bbd36d86
...
@@ -8,8 +8,6 @@ from dataclasses import dataclass
...
@@ -8,8 +8,6 @@ from dataclasses import dataclass
import
ygoenv
import
ygoenv
import
numpy
as
np
import
numpy
as
np
import
optree
import
tyro
import
tyro
from
ygoai.utils
import
init_ygopro
from
ygoai.utils
import
init_ygopro
...
@@ -63,6 +61,8 @@ class Args:
...
@@ -63,6 +61,8 @@ class Args:
"""the number of layers for the agent"""
"""the number of layers for the agent"""
num_channels
:
int
=
128
num_channels
:
int
=
128
"""the number of channels for the agent"""
"""the number of channels for the agent"""
rnn_channels
:
Optional
[
int
]
=
512
"""the number of rnn channels for the agent"""
checkpoint
:
Optional
[
str
]
=
None
checkpoint
:
Optional
[
str
]
=
None
"""the checkpoint to load, `pt` or `flax_model` file"""
"""the checkpoint to load, `pt` or `flax_model` file"""
...
@@ -70,24 +70,24 @@ class Args:
...
@@ -70,24 +70,24 @@ class Args:
xla_device
:
Optional
[
str
]
=
None
xla_device
:
Optional
[
str
]
=
None
"""the XLA device to use, defaults to `None`"""
"""the XLA device to use, defaults to `None`"""
# PyTorch specific
torch_deterministic
:
bool
=
True
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda
:
bool
=
True
"""if toggled, cuda will be enabled by default"""
compile
:
bool
=
False
"""if toggled, the model will be compiled"""
optimize
:
bool
=
True
"""if toggled, the model will be optimized"""
convert
:
bool
=
False
"""if toggled, the model will be converted to a jit model and the program will exit"""
torch_threads
:
Optional
[
int
]
=
None
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
env_threads
:
Optional
[
int
]
=
16
env_threads
:
Optional
[
int
]
=
16
"""the number of threads to use for envpool, defaults to `num_envs`"""
"""the number of threads to use for envpool, defaults to `num_envs`"""
framework
:
Optional
[
Literal
[
"torch"
,
"jax"
]]
=
None
def
create_agent
(
args
):
return
PPOLSTMAgent
(
channels
=
args
.
num_channels
,
num_layers
=
args
.
num_layers
,
lstm_channels
=
args
.
rnn_channels
,
embedding_shape
=
args
.
num_embeddings
,
)
def
init_rnn_state
(
num_envs
,
rnn_channels
):
return
(
np
.
zeros
((
num_envs
,
rnn_channels
)),
np
.
zeros
((
num_envs
,
rnn_channels
)),
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
@@ -113,18 +113,6 @@ if __name__ == "__main__":
...
@@ -113,18 +113,6 @@ if __name__ == "__main__":
random
.
seed
(
seed
)
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
if
args
.
checkpoint
and
args
.
framework
is
None
:
args
.
framework
=
"jax"
if
"flax_model"
in
args
.
checkpoint
else
"torch"
if
args
.
framework
==
"torch"
:
import
torch
torch
.
manual_seed
(
args
.
seed
)
torch
.
backends
.
cudnn
.
deterministic
=
args
.
torch_deterministic
args
.
torch_threads
=
args
.
torch_threads
or
int
(
os
.
getenv
(
"OMP_NUM_THREADS"
,
"4"
))
torch
.
set_num_threads
(
args
.
torch_threads
)
torch
.
set_float32_matmul_precision
(
'high'
)
else
:
if
args
.
xla_device
is
not
None
:
if
args
.
xla_device
is
not
None
:
os
.
environ
.
setdefault
(
"JAX_PLATFORMS"
,
args
.
xla_device
)
os
.
environ
.
setdefault
(
"JAX_PLATFORMS"
,
args
.
xla_device
)
...
@@ -150,102 +138,46 @@ if __name__ == "__main__":
...
@@ -150,102 +138,46 @@ if __name__ == "__main__":
envs
.
num_envs
=
num_envs
envs
.
num_envs
=
num_envs
envs
=
RecordEpisodeStatistics
(
envs
)
envs
=
RecordEpisodeStatistics
(
envs
)
if
args
.
framework
==
'torch'
:
if
args
.
checkpoint
:
from
ygoai.rl.agent
import
PPOAgent
as
Agent
from
ygoai.rl.buffer
import
create_obs
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
and
args
.
cuda
else
"cpu"
)
if
args
.
checkpoint
.
endswith
(
".ptj"
):
agent
=
torch
.
jit
.
load
(
args
.
checkpoint
)
else
:
# count lines of code_list
embedding_shape
=
args
.
num_embeddings
if
embedding_shape
is
None
:
with
open
(
args
.
code_list_file
,
"r"
)
as
f
:
code_list
=
f
.
readlines
()
embedding_shape
=
len
(
code_list
)
L
=
args
.
num_layers
agent
=
Agent
(
args
.
num_channels
,
L
,
L
,
embedding_shape
)
.
to
(
device
)
state_dict
=
torch
.
load
(
args
.
checkpoint
,
map_location
=
device
)
if
not
args
.
compile
:
prefix
=
"_orig_mod."
state_dict
=
{
k
[
len
(
prefix
):]
if
k
.
startswith
(
prefix
)
else
k
:
v
for
k
,
v
in
state_dict
.
items
()}
print
(
agent
.
load_state_dict
(
state_dict
))
if
args
.
compile
:
if
args
.
convert
:
# Don't support dynamic shapes and very slow inference
raise
NotImplementedError
# obs = create_obs(envs.observation_space, (num_envs,), device=device)
# dynamic_shapes = {"x": {}}
# # batch_dim = torch.export.Dim("batch", min=1, max=64)
# batch_dim = None
# for k, v in obs.items():
# dynamic_shapes["x"][k] = {0: batch_dim}
# program = torch.export.export(
# agent, (obs,),
# dynamic_shapes=dynamic_shapes,
# )
# torch.export.save(program, args.checkpoint + "2")
# exit(0)
agent
=
torch
.
compile
(
agent
,
mode
=
'reduce-overhead'
)
elif
args
.
optimize
:
obs
=
create_obs
(
envs
.
observation_space
,
(
num_envs
,),
device
=
device
)
def
optimize_for_inference
(
agent
):
with
torch
.
no_grad
():
traced_model
=
torch
.
jit
.
trace
(
agent
,
(
obs
,),
check_tolerance
=
False
,
check_trace
=
False
)
return
torch
.
jit
.
optimize_for_inference
(
traced_model
)
agent
=
optimize_for_inference
(
agent
)
if
args
.
convert
:
torch
.
jit
.
save
(
agent
,
args
.
checkpoint
+
"j"
)
print
(
f
"Optimized model saved to {args.checkpoint}j"
)
exit
(
0
)
def
predict_fn
(
obs
):
obs
=
optree
.
tree_map
(
lambda
x
:
torch
.
from_numpy
(
x
)
.
to
(
device
=
device
),
obs
)
with
torch
.
no_grad
():
logits
=
agent
(
obs
)[
0
]
probs
=
torch
.
softmax
(
logits
,
dim
=-
1
)
probs
=
probs
.
cpu
()
.
numpy
()
return
probs
else
:
import
jax
import
jax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
import
flax
import
flax
from
ygoai.rl.jax.agent2
import
PPOAgent
from
ygoai.rl.jax.agent2
import
PPOLSTMAgent
def
create_agent
(
args
):
from
jax.experimental.compilation_cache
import
compilation_cache
as
cc
return
PPOAgent
(
cc
.
set_cache_dir
(
os
.
path
.
expanduser
(
"~/.cache/jax"
))
channels
=
128
,
num_layers
=
2
,
embedding_shape
=
args
.
num_embeddings
,
)
agent
=
create_agent
(
args
)
agent
=
create_agent
(
args
)
key
=
jax
.
random
.
PRNGKey
(
args
.
seed
)
key
=
jax
.
random
.
PRNGKey
(
args
.
seed
)
key
,
agent_key
=
jax
.
random
.
split
(
key
,
2
)
key
,
agent_key
=
jax
.
random
.
split
(
key
,
2
)
sample_obs
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
array
([
x
]),
obs_space
.
sample
())
sample_obs
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
array
([
x
]),
obs_space
.
sample
())
params
=
agent
.
init
(
agent_key
,
sample_obs
)
rstate
=
init_rnn_state
(
1
,
args
.
rnn_channels
)
params
=
jax
.
jit
(
agent
.
init
)(
agent_key
,
(
rstate
,
sample_obs
))
with
open
(
args
.
checkpoint
,
"rb"
)
as
f
:
with
open
(
args
.
checkpoint
,
"rb"
)
as
f
:
params
=
flax
.
serialization
.
from_bytes
(
params
,
f
.
read
())
params
=
flax
.
serialization
.
from_bytes
(
params
,
f
.
read
())
params
=
jax
.
device_put
(
params
)
@
jax
.
jit
@
jax
.
jit
def
get_probs
(
def
get_probs
(
params
,
rstate
,
obs
,
done
):
params
:
flax
.
core
.
FrozenDict
,
agent
=
create_agent
(
args
)
next_obs
,
next_rstate
,
logits
=
agent
.
apply
(
params
,
(
rstate
,
obs
))[:
2
]
):
probs
=
jax
.
nn
.
softmax
(
logits
,
axis
=-
1
)
logits
=
create_agent
(
args
)
.
apply
(
params
,
next_obs
)[
0
]
next_rstate
=
jax
.
tree
.
map
(
return
jax
.
nn
.
softmax
(
logits
,
axis
=-
1
)
lambda
x
:
jnp
.
where
(
done
[:,
None
],
0
,
x
),
next_rstate
)
return
next_rstate
,
probs
def
predict_fn
(
obs
):
def
predict_fn
(
rstate
,
obs
,
done
):
probs
=
get_probs
(
params
,
obs
)
rstate
,
probs
=
get_probs
(
params
,
rstate
,
obs
,
done
)
return
np
.
array
(
probs
)
return
rstate
,
np
.
array
(
probs
)
print
(
f
"loaded checkpoint from {args.checkpoint}"
)
print
(
f
"loaded checkpoint from {args.checkpoint}"
)
obs
,
infos
=
envs
.
reset
()
obs
,
infos
=
envs
.
reset
()
next_to_play
=
infos
[
'to_play'
]
next_to_play
=
infos
[
'to_play'
]
dones
=
np
.
zeros
(
num_envs
,
dtype
=
np
.
bool_
)
episode_rewards
=
[]
episode_rewards
=
[]
episode_lengths
=
[]
episode_lengths
=
[]
...
@@ -263,9 +195,9 @@ if __name__ == "__main__":
...
@@ -263,9 +195,9 @@ if __name__ == "__main__":
start_step
=
step
start_step
=
step
model_time
=
env_time
=
0
model_time
=
env_time
=
0
if
args
.
framework
:
if
args
.
checkpoint
:
_start
=
time
.
time
()
_start
=
time
.
time
()
probs
=
predict_fn
(
ob
s
)
rstate
,
probs
=
predict_fn
(
rstate
,
obs
,
done
s
)
if
args
.
verbose
:
if
args
.
verbose
:
print
([
f
"{p:.4f}"
for
p
in
probs
[
probs
!=
0
]
.
tolist
()])
print
([
f
"{p:.4f}"
for
p
in
probs
[
probs
!=
0
]
.
tolist
()])
actions
=
probs
.
argmax
(
axis
=
1
)
actions
=
probs
.
argmax
(
axis
=
1
)
...
@@ -306,4 +238,3 @@ if __name__ == "__main__":
...
@@ -306,4 +238,3 @@ if __name__ == "__main__":
total_steps
=
(
step
-
start_step
)
*
num_envs
total_steps
=
(
step
-
start_step
)
*
num_envs
print
(
f
"SPS: {total_steps / total_time:.0f}, total_steps: {total_steps}"
)
print
(
f
"SPS: {total_steps / total_time:.0f}, total_steps: {total_steps}"
)
print
(
f
"total: {total_time:.4f}, model: {model_time:.4f}, env: {env_time:.4f}"
)
print
(
f
"total: {total_time:.4f}, model: {model_time:.4f}, env: {env_time:.4f}"
)
\ No newline at end of file
scripts/
jax/
impala.py
→
scripts/impala.py
View file @
bbd36d86
File moved
scripts/jax/battle.py
deleted
100644 → 0
View file @
43ca871e
import
sys
import
time
import
os
import
random
from
typing
import
Optional
,
Literal
from
dataclasses
import
dataclass
from
tqdm
import
tqdm
import
ygoenv
import
numpy
as
np
import
tyro
import
jax
import
jax.numpy
as
jnp
import
flax
from
ygoai.utils
import
init_ygopro
from
ygoai.rl.utils
import
RecordEpisodeStatistics
from
ygoai.rl.jax.agent2
import
PPOLSTMAgent
@
dataclass
class
Args
:
seed
:
int
=
1
"""the random seed"""
env_id
:
str
=
"YGOPro-v0"
"""the id of the environment"""
deck
:
str
=
"../assets/deck"
"""the deck file to use"""
deck1
:
Optional
[
str
]
=
None
"""the deck file for the first player"""
deck2
:
Optional
[
str
]
=
None
"""the deck file for the second player"""
code_list_file
:
str
=
"code_list.txt"
"""the code list file for card embeddings"""
lang
:
str
=
"english"
"""the language to use"""
max_options
:
int
=
24
"""the maximum number of options"""
n_history_actions
:
int
=
32
"""the number of history actions to use"""
num_embeddings
:
Optional
[
int
]
=
None
"""the number of embeddings of the agent"""
record
:
bool
=
False
"""whether to record the game as YGOPro replays"""
num_episodes
:
int
=
1024
"""the number of episodes to run"""
num_envs
:
int
=
64
"""the number of parallel game environments"""
verbose
:
bool
=
False
"""whether to print debug information"""
num_layers
:
int
=
2
"""the number of layers for the agent"""
num_channels
:
int
=
128
"""the number of channels for the agent"""
rnn_channels
:
Optional
[
int
]
=
512
"""the number of rnn channels for the agent"""
checkpoint1
:
str
=
"checkpoints/agent.pt"
"""the checkpoint to load for the first agent, `pt` or `flax_model` file"""
checkpoint2
:
str
=
"checkpoints/agent.pt"
"""the checkpoint to load for the second agent, `pt` or `flax_model` file"""
# Jax specific
xla_device
:
Optional
[
str
]
=
None
"""the XLA device to use, defaults to `None`"""
# PyTorch specific
torch_deterministic
:
bool
=
True
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda
:
bool
=
True
"""if toggled, cuda will be enabled by default"""
compile
:
bool
=
False
"""if toggled, the model will be compiled"""
optimize
:
bool
=
False
"""if toggled, the model will be optimized"""
torch_threads
:
Optional
[
int
]
=
None
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
env_threads
:
Optional
[
int
]
=
16
"""the number of threads to use for envpool, defaults to `num_envs`"""
framework
:
Optional
[
Literal
[
"torch"
,
"jax"
]]
=
None
def
create_agent
(
args
):
return
PPOLSTMAgent
(
channels
=
args
.
num_channels
,
num_layers
=
args
.
num_layers
,
lstm_channels
=
args
.
rnn_channels
,
embedding_shape
=
args
.
num_embeddings
,
)
def
init_rnn_state
(
num_envs
,
rnn_channels
):
return
(
np
.
zeros
((
num_envs
,
rnn_channels
)),
np
.
zeros
((
num_envs
,
rnn_channels
)),
)
if
__name__
==
"__main__"
:
from
jax.experimental.compilation_cache
import
compilation_cache
as
cc
cc
.
set_cache_dir
(
os
.
path
.
expanduser
(
"~/.cache/jax"
))
args
=
tyro
.
cli
(
Args
)
if
args
.
record
:
assert
args
.
num_envs
==
1
,
"Recording only works with a single environment"
assert
args
.
verbose
,
"Recording only works with verbose mode"
if
not
os
.
path
.
exists
(
"replay"
):
os
.
makedirs
(
"replay"
)
args
.
env_threads
=
min
(
args
.
env_threads
or
args
.
num_envs
,
args
.
num_envs
)
deck
=
init_ygopro
(
args
.
env_id
,
args
.
lang
,
args
.
deck
,
args
.
code_list_file
)
args
.
deck1
=
args
.
deck1
or
deck
args
.
deck2
=
args
.
deck2
or
deck
seed
=
args
.
seed
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
if
args
.
xla_device
is
not
None
:
os
.
environ
.
setdefault
(
"JAX_PLATFORMS"
,
args
.
xla_device
)
num_envs
=
args
.
num_envs
envs
=
ygoenv
.
make
(
task_id
=
args
.
env_id
,
env_type
=
"gymnasium"
,
num_envs
=
num_envs
,
num_threads
=
args
.
env_threads
,
seed
=
seed
,
deck1
=
args
.
deck1
,
deck2
=
args
.
deck2
,
player
=-
1
,
max_options
=
args
.
max_options
,
n_history_actions
=
args
.
n_history_actions
,
play_mode
=
'self'
,
async_reset
=
False
,
verbose
=
args
.
verbose
,
record
=
args
.
record
,
)
obs_space
=
envs
.
observation_space
envs
.
num_envs
=
num_envs
envs
=
RecordEpisodeStatistics
(
envs
)
agent
=
create_agent
(
args
)
key
=
jax
.
random
.
PRNGKey
(
args
.
seed
)
key
,
agent_key
=
jax
.
random
.
split
(
key
,
2
)
sample_obs
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
array
([
x
]),
obs_space
.
sample
())
rstate
=
init_rnn_state
(
1
,
args
.
rnn_channels
)
params
=
jax
.
jit
(
agent
.
init
)(
agent_key
,
(
rstate
,
sample_obs
))
with
open
(
args
.
checkpoint1
,
"rb"
)
as
f
:
params1
=
flax
.
serialization
.
from_bytes
(
params
,
f
.
read
())
if
args
.
checkpoint1
==
args
.
checkpoint2
:
params2
=
params1
else
:
with
open
(
args
.
checkpoint2
,
"rb"
)
as
f
:
params2
=
flax
.
serialization
.
from_bytes
(
params
,
f
.
read
())
params1
=
jax
.
device_put
(
params1
)
params2
=
jax
.
device_put
(
params2
)
@
jax
.
jit
def
get_probs
(
params
,
rstate
,
obs
,
done
):
agent
=
create_agent
(
args
)
next_rstate
,
logits
=
agent
.
apply
(
params
,
(
rstate
,
obs
))[:
2
]
probs
=
jax
.
nn
.
softmax
(
logits
,
axis
=-
1
)
next_rstate
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
where
(
done
[:,
None
],
0
,
x
),
next_rstate
)
return
next_rstate
,
probs
if
args
.
num_envs
!=
1
:
@
jax
.
jit
def
get_probs2
(
params1
,
params2
,
rstate1
,
rstate2
,
obs
,
main
,
done
):
next_rstate1
,
probs1
=
get_probs
(
params1
,
rstate1
,
obs
,
done
)
next_rstate2
,
probs2
=
get_probs
(
params2
,
rstate2
,
obs
,
done
)
probs
=
jnp
.
where
(
main
[:,
None
],
probs1
,
probs2
)
rstate1
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x1
,
x2
),
next_rstate1
,
rstate1
)
rstate2
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x2
,
x1
),
next_rstate2
,
rstate2
)
return
rstate1
,
rstate2
,
probs
def
predict_fn
(
rstate1
,
rstate2
,
obs
,
main
,
done
):
rstate1
,
rstate2
,
probs
=
get_probs2
(
params1
,
params2
,
rstate1
,
rstate2
,
obs
,
main
,
done
)
return
rstate1
,
rstate2
,
np
.
array
(
probs
)
else
:
def
predict_fn
(
rstate1
,
rstate2
,
obs
,
main
,
done
):
if
main
[
0
]:
rstate1
,
probs
=
get_probs
(
params1
,
rstate1
,
obs
,
done
)
else
:
rstate2
,
probs
=
get_probs
(
params2
,
rstate2
,
obs
,
done
)
return
rstate1
,
rstate2
,
np
.
array
(
probs
)
obs
,
infos
=
envs
.
reset
()
next_to_play
=
infos
[
'to_play'
]
dones
=
np
.
zeros
(
num_envs
,
dtype
=
np
.
bool_
)
episode_rewards
=
[]
episode_lengths
=
[]
win_rates
=
[]
win_reasons
=
[]
step
=
0
start
=
time
.
time
()
start_step
=
step
main_player
=
np
.
concatenate
([
np
.
zeros
(
num_envs
//
2
,
dtype
=
np
.
int64
),
np
.
ones
(
num_envs
-
num_envs
//
2
,
dtype
=
np
.
int64
)
])
rstate1
=
rstate2
=
init_rnn_state
(
num_envs
,
args
.
rnn_channels
)
if
not
args
.
verbose
:
pbar
=
tqdm
(
total
=
args
.
num_episodes
)
model_time
=
env_time
=
0
while
True
:
if
start_step
==
0
and
len
(
episode_lengths
)
>
int
(
args
.
num_episodes
*
0.1
):
start
=
time
.
time
()
start_step
=
step
model_time
=
env_time
=
0
_start
=
time
.
time
()
main
=
next_to_play
==
main_player
rstate1
,
rstate2
,
probs
=
predict_fn
(
rstate1
,
rstate2
,
obs
,
main
,
dones
)
actions
=
probs
.
argmax
(
axis
=
1
)
model_time
+=
time
.
time
()
-
_start
to_play
=
next_to_play
_start
=
time
.
time
()
obs
,
rewards
,
dones
,
infos
=
envs
.
step
(
actions
)
next_to_play
=
infos
[
'to_play'
]
env_time
+=
time
.
time
()
-
_start
step
+=
1
for
idx
,
d
in
enumerate
(
dones
):
if
d
:
win_reason
=
infos
[
'win_reason'
][
idx
]
pl
=
1
if
to_play
[
idx
]
==
main_player
[
idx
]
else
-
1
episode_length
=
infos
[
'l'
][
idx
]
episode_reward
=
infos
[
'r'
][
idx
]
*
pl
win
=
int
(
episode_reward
>
0
)
episode_lengths
.
append
(
episode_length
)
episode_rewards
.
append
(
episode_reward
)
win_rates
.
append
(
win
)
win_reasons
.
append
(
1
if
win_reason
==
1
else
0
)
if
args
.
verbose
:
print
(
f
"Episode {len(episode_lengths)}: length={episode_length}, reward={episode_reward}, win={win}, win_reason={win_reason}
\n
"
)
else
:
pbar
.
set_postfix
(
len
=
np
.
mean
(
episode_lengths
),
reward
=
np
.
mean
(
episode_rewards
),
win_rate
=
np
.
mean
(
win_rates
))
pbar
.
update
(
1
)
# Only when num_envs=1, we switch the player here
if
args
.
verbose
:
main_player
=
1
-
main_player
if
len
(
episode_lengths
)
>=
args
.
num_episodes
:
break
if
not
args
.
verbose
:
pbar
.
close
()
print
(
f
"len={np.mean(episode_lengths)}, reward={np.mean(episode_rewards)}, win_rate={np.mean(win_rates)}, win_reason={np.mean(win_reasons)}"
)
total_time
=
time
.
time
()
-
start
total_steps
=
(
step
-
start_step
)
*
num_envs
print
(
f
"SPS: {total_steps / total_time:.0f}, total_steps: {total_steps}"
)
print
(
f
"total: {total_time:.4f}, model: {model_time:.4f}, env: {env_time:.4f}"
)
\ No newline at end of file
scripts/jax/ppo.py
deleted
100644 → 0
View file @
43ca871e
import
os
import
queue
import
random
import
threading
import
time
from
datetime
import
datetime
,
timedelta
,
timezone
from
collections
import
deque
from
dataclasses
import
dataclass
,
field
from
types
import
SimpleNamespace
from
typing
import
List
,
NamedTuple
,
Optional
from
functools
import
partial
import
ygoenv
import
flax
import
jax
import
jax.numpy
as
jnp
import
numpy
as
np
import
optax
import
distrax
import
tyro
from
flax.training.train_state
import
TrainState
from
rich.pretty
import
pprint
from
tensorboardX
import
SummaryWriter
from
ygoai.utils
import
init_ygopro
,
load_embeddings
from
ygoai.rl.jax.agent2
import
PPOLSTMAgent
from
ygoai.rl.jax.utils
import
RecordEpisodeStatistics
,
masked_normalize
,
categorical_sample
from
ygoai.rl.jax.eval
import
evaluate
,
battle
from
ygoai.rl.jax
import
clipped_surrogate_pg_loss
,
mse_loss
,
entropy_loss
,
simple_policy_loss
from
ygoai.rl.jax.switch
import
truncated_gae_2p0s
os
.
environ
[
"XLA_FLAGS"
]
=
"--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
@
dataclass
class
Args
:
exp_name
:
str
=
os
.
path
.
basename
(
__file__
)
.
rstrip
(
".py"
)
"""the name of this experiment"""
seed
:
int
=
1
"""seed of the experiment"""
log_frequency
:
int
=
10
"""the logging frequency of the model performance (in terms of `updates`)"""
save_interval
:
int
=
400
"""the frequency of saving the model (in terms of `updates`)"""
checkpoint
:
Optional
[
str
]
=
None
"""the path to the model checkpoint to load"""
# Algorithm specific arguments
env_id
:
str
=
"YGOPro-v0"
"""the id of the environment"""
deck
:
str
=
"../assets/deck"
"""the deck file to use"""
deck1
:
Optional
[
str
]
=
None
"""the deck file for the first player"""
deck2
:
Optional
[
str
]
=
None
"""the deck file for the second player"""
code_list_file
:
str
=
"code_list.txt"
"""the code list file for card embeddings"""
embedding_file
:
Optional
[
str
]
=
None
"""the embedding file for card embeddings"""
max_options
:
int
=
24
"""the maximum number of options"""
n_history_actions
:
int
=
32
"""the number of history actions to use"""
greedy_reward
:
bool
=
True
"""whether to use greedy reward (faster kill higher reward)"""
total_timesteps
:
int
=
5000000000
"""total timesteps of the experiments"""
learning_rate
:
float
=
1e-3
"""the learning rate of the optimizer"""
local_num_envs
:
int
=
128
"""the number of parallel game environments"""
local_env_threads
:
Optional
[
int
]
=
None
"""the number of threads to use for environment"""
num_actor_threads
:
int
=
2
"""the number of actor threads to use"""
num_steps
:
int
=
128
"""the number of steps to run in each environment per policy rollout"""
anneal_lr
:
bool
=
False
"""Toggle learning rate annealing for policy and value networks"""
gamma
:
float
=
1.0
"""the discount factor gamma"""
gae_lambda
:
float
=
0.95
"""the lambda for the general advantage estimation"""
upgo
:
bool
=
False
"""Toggle the use of UPGO for advantages"""
num_minibatches
:
int
=
8
"""the number of mini-batches"""
update_epochs
:
int
=
2
"""the K epochs to update the policy"""
norm_adv
:
bool
=
False
"""Toggles advantages normalization"""
clip_coef
:
float
=
0.25
"""the surrogate clipping coefficient"""
dual_clip_coef
:
Optional
[
float
]
=
None
"""the dual surrogate clipping coefficient, typically 3.0"""
spo_kld_max
:
Optional
[
float
]
=
None
"""the maximum KLD for the SPO policy, typically 0.02"""
ent_coef
:
float
=
0.01
"""coefficient of the entropy"""
vf_coef
:
float
=
0.5
"""coefficient of the value function"""
max_grad_norm
:
float
=
1.0
"""the maximum norm for the gradient clipping"""
num_layers
:
int
=
2
"""the number of layers for the agent"""
num_channels
:
int
=
128
"""the number of channels for the agent"""
rnn_channels
:
int
=
512
"""the number of channels for the RNN in the agent"""
actor_device_ids
:
List
[
int
]
=
field
(
default_factory
=
lambda
:
[
0
,
1
])
"""the device ids that actor workers will use"""
learner_device_ids
:
List
[
int
]
=
field
(
default_factory
=
lambda
:
[
2
,
3
])
"""the device ids that learner workers will use"""
distributed
:
bool
=
False
"""whether to use `jax.distirbuted`"""
concurrency
:
bool
=
True
"""whether to run the actor and learner concurrently"""
bfloat16
:
bool
=
False
"""whether to use bfloat16 for the agent"""
thread_affinity
:
bool
=
False
"""whether to use thread affinity for the environment"""
eval_checkpoint
:
Optional
[
str
]
=
None
"""the path to the model checkpoint to evaluate"""
local_eval_episodes
:
int
=
32
"""the number of episodes to evaluate the model"""
eval_interval
:
int
=
50
"""the number of iterations to evaluate the model"""
# runtime arguments to be filled in
local_batch_size
:
int
=
0
local_minibatch_size
:
int
=
0
world_size
:
int
=
0
local_rank
:
int
=
0
num_envs
:
int
=
0
batch_size
:
int
=
0
minibatch_size
:
int
=
0
num_updates
:
int
=
0
global_learner_decices
:
Optional
[
List
[
str
]]
=
None
actor_devices
:
Optional
[
List
[
str
]]
=
None
learner_devices
:
Optional
[
List
[
str
]]
=
None
num_embeddings
:
Optional
[
int
]
=
None
freeze_id
:
bool
=
False
def
make_env
(
args
,
seed
,
num_envs
,
num_threads
,
mode
=
'self'
,
thread_affinity_offset
=-
1
,
eval
=
False
):
if
not
args
.
thread_affinity
:
thread_affinity_offset
=
-
1
if
thread_affinity_offset
>=
0
:
print
(
"Binding to thread offset"
,
thread_affinity_offset
)
envs
=
ygoenv
.
make
(
task_id
=
args
.
env_id
,
env_type
=
"gymnasium"
,
num_envs
=
num_envs
,
num_threads
=
num_threads
,
thread_affinity_offset
=
thread_affinity_offset
,
seed
=
seed
,
deck1
=
args
.
deck1
,
deck2
=
args
.
deck2
,
max_options
=
args
.
max_options
,
n_history_actions
=
args
.
n_history_actions
,
async_reset
=
False
,
greedy_reward
=
args
.
greedy_reward
if
not
eval
else
True
,
play_mode
=
mode
,
)
envs
.
num_envs
=
num_envs
return
envs
class
Transition
(
NamedTuple
):
obs
:
list
dones
:
list
actions
:
list
logits
:
list
rewards
:
list
mains
:
list
next_dones
:
list
def
create_agent
(
args
,
multi_step
=
False
):
return
PPOLSTMAgent
(
channels
=
args
.
num_channels
,
num_layers
=
args
.
num_layers
,
embedding_shape
=
args
.
num_embeddings
,
dtype
=
jnp
.
bfloat16
if
args
.
bfloat16
else
jnp
.
float32
,
param_dtype
=
jnp
.
float32
,
lstm_channels
=
args
.
rnn_channels
,
multi_step
=
multi_step
,
freeze_id
=
args
.
freeze_id
,
)
def
init_rnn_state
(
num_envs
,
rnn_channels
):
return
(
np
.
zeros
((
num_envs
,
rnn_channels
)),
np
.
zeros
((
num_envs
,
rnn_channels
)),
)
def
rollout
(
key
:
jax
.
random
.
PRNGKey
,
args
:
Args
,
rollout_queue
,
params_queue
,
eval_queue
,
writer
,
learner_devices
,
device_thread_id
,
):
eval_mode
=
'self'
if
args
.
eval_checkpoint
else
'bot'
if
eval_mode
!=
'bot'
:
eval_params
=
params_queue
.
get
()
envs
=
make_env
(
args
,
args
.
seed
+
jax
.
process_index
()
+
device_thread_id
,
args
.
local_num_envs
,
args
.
local_env_threads
,
thread_affinity_offset
=
device_thread_id
*
args
.
local_env_threads
,
)
envs
=
RecordEpisodeStatistics
(
envs
)
eval_envs
=
make_env
(
args
,
args
.
seed
+
jax
.
process_index
()
+
device_thread_id
,
args
.
local_eval_episodes
,
args
.
local_eval_episodes
//
4
,
mode
=
eval_mode
,
eval
=
True
)
eval_envs
=
RecordEpisodeStatistics
(
eval_envs
)
len_actor_device_ids
=
len
(
args
.
actor_device_ids
)
n_actors
=
args
.
num_actor_threads
*
len_actor_device_ids
global_step
=
0
start_time
=
time
.
time
()
warmup_step
=
0
other_time
=
0
avg_ep_returns
=
deque
(
maxlen
=
1000
)
avg_win_rates
=
deque
(
maxlen
=
1000
)
@
jax
.
jit
def
get_logits
(
params
:
flax
.
core
.
FrozenDict
,
inputs
,
done
):
rstate
,
logits
=
create_agent
(
args
)
.
apply
(
params
,
inputs
)[:
2
]
rstate
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
where
(
done
[:,
None
],
0
,
x
),
rstate
)
return
rstate
,
logits
@
jax
.
jit
def
get_action
(
params
:
flax
.
core
.
FrozenDict
,
inputs
):
batch_size
=
jax
.
tree
.
leaves
(
inputs
)[
0
]
.
shape
[
0
]
done
=
jnp
.
zeros
(
batch_size
,
dtype
=
jnp
.
bool_
)
rstate
,
logits
=
get_logits
(
params
,
inputs
,
done
)
return
rstate
,
logits
.
argmax
(
axis
=
1
)
@
jax
.
jit
def
get_action_battle
(
params1
,
params2
,
rstate1
,
rstate2
,
obs
,
main
,
done
):
next_rstate1
,
logits1
=
get_logits
(
params1
,
(
rstate1
,
obs
),
done
)
next_rstate2
,
logits2
=
get_logits
(
params2
,
(
rstate2
,
obs
),
done
)
logits
=
jnp
.
where
(
main
[:,
None
],
logits1
,
logits2
)
rstate1
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x1
,
x2
),
next_rstate1
,
rstate1
)
rstate2
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x2
,
x1
),
next_rstate2
,
rstate2
)
return
rstate1
,
rstate2
,
logits
.
argmax
(
axis
=
1
)
@
jax
.
jit
def
sample_action
(
params
:
flax
.
core
.
FrozenDict
,
next_obs
,
rstate1
,
rstate2
,
main
,
done
,
key
):
next_obs
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
array
(
x
),
next_obs
)
done
=
jnp
.
array
(
done
)
main
=
jnp
.
array
(
main
)
rstate
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x1
,
x2
),
rstate1
,
rstate2
)
rstate
,
logits
=
get_logits
(
params
,
(
rstate
,
next_obs
),
done
)
rstate1
=
jax
.
tree
.
map
(
lambda
x
,
y
:
jnp
.
where
(
main
[:,
None
],
x
,
y
),
rstate
,
rstate1
)
rstate2
=
jax
.
tree
.
map
(
lambda
x
,
y
:
jnp
.
where
(
main
[:,
None
],
y
,
x
),
rstate
,
rstate2
)
action
,
key
=
categorical_sample
(
logits
,
key
)
return
next_obs
,
done
,
main
,
rstate1
,
rstate2
,
action
,
logits
,
key
# put data in the last index
params_queue_get_time
=
deque
(
maxlen
=
10
)
rollout_time
=
deque
(
maxlen
=
10
)
actor_policy_version
=
0
next_obs
,
info
=
envs
.
reset
()
next_to_play
=
info
[
"to_play"
]
next_done
=
np
.
zeros
(
args
.
local_num_envs
,
dtype
=
np
.
bool_
)
next_rstate1
=
next_rstate2
=
init_rnn_state
(
args
.
local_num_envs
,
args
.
rnn_channels
)
eval_rstate
=
init_rnn_state
(
args
.
local_eval_episodes
,
args
.
rnn_channels
)
main_player
=
np
.
concatenate
([
np
.
zeros
(
args
.
local_num_envs
//
2
,
dtype
=
np
.
int64
),
np
.
ones
(
args
.
local_num_envs
//
2
,
dtype
=
np
.
int64
)
])
np
.
random
.
shuffle
(
main_player
)
storage
=
[]
@
jax
.
jit
def
prepare_data
(
storage
:
List
[
Transition
])
->
Transition
:
return
jax
.
tree
.
map
(
lambda
*
xs
:
jnp
.
split
(
jnp
.
stack
(
xs
),
len
(
learner_devices
),
axis
=
1
),
*
storage
)
for
update
in
range
(
1
,
args
.
num_updates
+
2
):
if
update
==
10
:
start_time
=
time
.
time
()
warmup_step
=
global_step
update_time_start
=
time
.
time
()
inference_time
=
0
env_time
=
0
params_queue_get_time_start
=
time
.
time
()
if
args
.
concurrency
:
if
update
!=
2
:
params
=
params_queue
.
get
()
# params["params"]["Encoder_0"]['Embed_0'][
# "embedding"
# ].block_until_ready()
actor_policy_version
+=
1
else
:
params
=
params_queue
.
get
()
actor_policy_version
+=
1
params_queue_get_time
.
append
(
time
.
time
()
-
params_queue_get_time_start
)
rollout_time_start
=
time
.
time
()
init_rstate1
,
init_rstate2
=
jax
.
tree
.
map
(
lambda
x
:
x
.
copy
(),
(
next_rstate1
,
next_rstate2
))
for
_
in
range
(
args
.
num_steps
):
global_step
+=
args
.
local_num_envs
*
n_actors
*
args
.
world_size
cached_next_obs
=
next_obs
cached_next_done
=
next_done
main
=
next_to_play
==
main_player
inference_time_start
=
time
.
time
()
cached_next_obs
,
cached_next_done
,
cached_main
,
\
next_rstate1
,
next_rstate2
,
action
,
logits
,
key
=
sample_action
(
params
,
cached_next_obs
,
next_rstate1
,
next_rstate2
,
main
,
cached_next_done
,
key
)
cpu_action
=
np
.
array
(
action
)
inference_time
+=
time
.
time
()
-
inference_time_start
_start
=
time
.
time
()
next_obs
,
next_reward
,
next_done
,
info
=
envs
.
step
(
cpu_action
)
next_to_play
=
info
[
"to_play"
]
env_time
+=
time
.
time
()
-
_start
storage
.
append
(
Transition
(
obs
=
cached_next_obs
,
dones
=
cached_next_done
,
mains
=
cached_main
,
actions
=
action
,
logits
=
logits
,
rewards
=
next_reward
,
next_dones
=
next_done
,
)
)
for
idx
,
d
in
enumerate
(
next_done
):
if
not
d
:
continue
cur_main
=
main
[
idx
]
for
j
in
reversed
(
range
(
len
(
storage
)
-
1
)):
t
=
storage
[
j
]
if
t
.
next_dones
[
idx
]:
# For OTK where player may not switch
break
if
t
.
mains
[
idx
]
!=
cur_main
:
t
.
next_dones
[
idx
]
=
True
t
.
rewards
[
idx
]
=
-
next_reward
[
idx
]
break
episode_reward
=
info
[
'r'
][
idx
]
*
(
1
if
cur_main
else
-
1
)
win
=
1
if
episode_reward
>
0
else
0
avg_ep_returns
.
append
(
episode_reward
)
avg_win_rates
.
append
(
win
)
rollout_time
.
append
(
time
.
time
()
-
rollout_time_start
)
partitioned_storage
=
prepare_data
(
storage
)
storage
=
[]
sharded_storage
=
[]
for
x
in
partitioned_storage
:
if
isinstance
(
x
,
dict
):
x
=
{
k
:
jax
.
device_put_sharded
(
v
,
devices
=
learner_devices
)
for
k
,
v
in
x
.
items
()
}
else
:
x
=
jax
.
device_put_sharded
(
x
,
devices
=
learner_devices
)
sharded_storage
.
append
(
x
)
sharded_storage
=
Transition
(
*
sharded_storage
)
next_main
=
main_player
==
next_to_play
next_rstate
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
next_main
[:,
None
],
x1
,
x2
),
next_rstate1
,
next_rstate2
)
sharded_data
=
jax
.
tree
.
map
(
lambda
x
:
jax
.
device_put_sharded
(
np
.
split
(
x
,
len
(
learner_devices
)),
devices
=
learner_devices
),
(
init_rstate1
,
init_rstate2
,
(
next_rstate
,
next_obs
),
next_main
))
learn_opponent
=
False
payload
=
(
global_step
,
update
,
sharded_storage
,
*
sharded_data
,
np
.
mean
(
params_queue_get_time
),
learn_opponent
,
)
rollout_queue
.
put
(
payload
)
if
update
%
args
.
log_frequency
==
0
:
avg_episodic_return
=
np
.
mean
(
avg_ep_returns
)
avg_episodic_length
=
np
.
mean
(
envs
.
returned_episode_lengths
)
SPS
=
int
((
global_step
-
warmup_step
)
/
(
time
.
time
()
-
start_time
-
other_time
))
SPS_update
=
int
(
args
.
batch_size
/
(
time
.
time
()
-
update_time_start
))
if
device_thread_id
==
0
:
print
(
f
"global_step={global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}"
)
time_now
=
datetime
.
now
(
timezone
(
timedelta
(
hours
=
8
)))
.
strftime
(
"
%
H:
%
M:
%
S"
)
print
(
f
"{time_now} SPS: {SPS}, update: {SPS_update}, "
f
"rollout_time={rollout_time[-1]:.2f}, params_time={params_queue_get_time[-1]:.2f}"
)
writer
.
add_scalar
(
"stats/rollout_time"
,
np
.
mean
(
rollout_time
),
global_step
)
writer
.
add_scalar
(
"charts/avg_episodic_return"
,
avg_episodic_return
,
global_step
)
writer
.
add_scalar
(
"charts/avg_episodic_length"
,
avg_episodic_length
,
global_step
)
writer
.
add_scalar
(
"stats/params_queue_get_time"
,
np
.
mean
(
params_queue_get_time
),
global_step
)
writer
.
add_scalar
(
"stats/inference_time"
,
inference_time
,
global_step
)
writer
.
add_scalar
(
"stats/env_time"
,
env_time
,
global_step
)
writer
.
add_scalar
(
"charts/SPS"
,
SPS
,
global_step
)
writer
.
add_scalar
(
"charts/SPS_update"
,
SPS_update
,
global_step
)
if
args
.
eval_interval
and
update
%
args
.
eval_interval
==
0
:
# Eval with rule-based policy
_start
=
time
.
time
()
if
eval_mode
==
'bot'
:
predict_fn
=
lambda
x
:
get_action
(
params
,
x
)
eval_return
,
eval_ep_len
,
eval_win_rate
=
evaluate
(
eval_envs
,
args
.
local_eval_episodes
,
predict_fn
,
eval_rstate
)
else
:
predict_fn
=
lambda
*
x
:
get_action_battle
(
params
,
eval_params
,
*
x
)
eval_return
,
eval_ep_len
,
eval_win_rate
=
battle
(
eval_envs
,
args
.
local_eval_episodes
,
predict_fn
,
eval_rstate
)
eval_stat
=
np
.
array
([
eval_return
,
eval_win_rate
])
if
device_thread_id
!=
0
:
eval_queue
.
put
(
eval_stat
)
else
:
eval_stats
=
[]
eval_stats
.
append
(
eval_stat
)
for
_
in
range
(
1
,
n_actors
):
eval_stats
.
append
(
eval_queue
.
get
())
eval_stats
=
np
.
stack
(
eval_stats
)
eval_return
,
eval_win_rate
=
np
.
mean
(
eval_stats
,
axis
=
0
)
writer
.
add_scalar
(
f
"charts/eval_return"
,
eval_return
,
global_step
)
writer
.
add_scalar
(
f
"charts/eval_win_rate"
,
eval_win_rate
,
global_step
)
if
device_thread_id
==
0
:
eval_time
=
time
.
time
()
-
_start
other_time
+=
eval_time
print
(
f
"eval_time={eval_time:.4f}, eval_return={eval_return:.4f}, eval_win_rate={eval_win_rate:.4f}"
)
if
__name__
==
"__main__"
:
args
=
tyro
.
cli
(
Args
)
args
.
local_batch_size
=
int
(
args
.
local_num_envs
*
args
.
num_steps
*
args
.
num_actor_threads
*
len
(
args
.
actor_device_ids
))
args
.
local_minibatch_size
=
int
(
args
.
local_batch_size
//
args
.
num_minibatches
)
assert
(
args
.
local_num_envs
%
len
(
args
.
learner_device_ids
)
==
0
),
"local_num_envs must be divisible by len(learner_device_ids)"
assert
(
int
(
args
.
local_num_envs
/
len
(
args
.
learner_device_ids
))
*
args
.
num_actor_threads
%
args
.
num_minibatches
==
0
),
"int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches"
if
args
.
distributed
:
jax
.
distributed
.
initialize
(
local_device_ids
=
range
(
len
(
args
.
learner_device_ids
)
+
len
(
args
.
actor_device_ids
)),
)
print
(
list
(
range
(
len
(
args
.
learner_device_ids
)
+
len
(
args
.
actor_device_ids
))))
from
jax.experimental.compilation_cache
import
compilation_cache
as
cc
cc
.
set_cache_dir
(
os
.
path
.
expanduser
(
"~/.cache/jax"
))
args
.
world_size
=
jax
.
process_count
()
args
.
local_rank
=
jax
.
process_index
()
args
.
num_envs
=
args
.
local_num_envs
*
args
.
world_size
*
args
.
num_actor_threads
*
len
(
args
.
actor_device_ids
)
args
.
batch_size
=
args
.
local_batch_size
*
args
.
world_size
args
.
minibatch_size
=
args
.
local_minibatch_size
*
args
.
world_size
args
.
num_updates
=
args
.
total_timesteps
//
(
args
.
local_batch_size
*
args
.
world_size
)
args
.
local_env_threads
=
args
.
local_env_threads
or
args
.
local_num_envs
if
args
.
embedding_file
:
embeddings
=
load_embeddings
(
args
.
embedding_file
,
args
.
code_list_file
)
embedding_shape
=
embeddings
.
shape
args
.
num_embeddings
=
embedding_shape
args
.
freeze_id
=
True
if
args
.
freeze_id
is
None
else
args
.
freeze_id
else
:
embeddings
=
None
embedding_shape
=
None
local_devices
=
jax
.
local_devices
()
global_devices
=
jax
.
devices
()
learner_devices
=
[
local_devices
[
d_id
]
for
d_id
in
args
.
learner_device_ids
]
actor_devices
=
[
local_devices
[
d_id
]
for
d_id
in
args
.
actor_device_ids
]
global_learner_decices
=
[
global_devices
[
d_id
+
process_index
*
len
(
local_devices
)]
for
process_index
in
range
(
args
.
world_size
)
for
d_id
in
args
.
learner_device_ids
]
print
(
"global_learner_decices"
,
global_learner_decices
)
args
.
global_learner_decices
=
[
str
(
item
)
for
item
in
global_learner_decices
]
args
.
actor_devices
=
[
str
(
item
)
for
item
in
actor_devices
]
args
.
learner_devices
=
[
str
(
item
)
for
item
in
learner_devices
]
pprint
(
args
)
timestamp
=
int
(
time
.
time
())
run_name
=
f
"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
writer
=
SummaryWriter
(
f
"runs/{run_name}"
)
writer
.
add_text
(
"hyperparameters"
,
"|param|value|
\n
|-|-|
\n
%
s"
%
(
"
\n
"
.
join
([
f
"|{key}|{value}|"
for
key
,
value
in
vars
(
args
)
.
items
()])),
)
# seeding
random
.
seed
(
args
.
seed
)
np
.
random
.
seed
(
args
.
seed
)
key
=
jax
.
random
.
PRNGKey
(
args
.
seed
)
key
,
agent_key
=
jax
.
random
.
split
(
key
,
2
)
learner_keys
=
jax
.
device_put_replicated
(
key
,
learner_devices
)
deck
=
init_ygopro
(
args
.
env_id
,
"english"
,
args
.
deck
,
args
.
code_list_file
)
args
.
deck1
=
args
.
deck1
or
deck
args
.
deck2
=
args
.
deck2
or
deck
# env setup
envs
=
make_env
(
args
,
args
.
seed
,
8
,
1
)
obs_space
=
envs
.
observation_space
action_shape
=
envs
.
action_space
.
shape
print
(
f
"obs_space={obs_space}, action_shape={action_shape}"
)
sample_obs
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
array
([
x
]),
obs_space
.
sample
())
envs
.
close
()
del
envs
def
linear_schedule
(
count
):
# anneal learning rate linearly after one training iteration which contains
# (args.num_minibatches) gradient updates
frac
=
1.0
-
(
count
//
(
args
.
num_minibatches
*
args
.
update_epochs
))
/
args
.
num_updates
return
args
.
learning_rate
*
frac
rstate
=
init_rnn_state
(
1
,
args
.
rnn_channels
)
agent
=
create_agent
(
args
)
params
=
agent
.
init
(
agent_key
,
(
rstate
,
sample_obs
))
if
embeddings
is
not
None
:
unknown_embed
=
embeddings
.
mean
(
axis
=
0
)
embeddings
=
np
.
concatenate
([
unknown_embed
[
None
,
:],
embeddings
],
axis
=
0
)
params
=
flax
.
core
.
unfreeze
(
params
)
params
[
'params'
][
'Encoder_0'
][
'Embed_0'
][
'embedding'
]
=
jax
.
device_put
(
embeddings
)
params
=
flax
.
core
.
freeze
(
params
)
tx
=
optax
.
MultiSteps
(
optax
.
chain
(
optax
.
clip_by_global_norm
(
args
.
max_grad_norm
),
optax
.
inject_hyperparams
(
optax
.
adam
)(
learning_rate
=
linear_schedule
if
args
.
anneal_lr
else
args
.
learning_rate
,
eps
=
1e-5
),
),
every_k_schedule
=
1
,
)
agent_state
=
TrainState
.
create
(
apply_fn
=
None
,
params
=
params
,
tx
=
tx
,
)
if
args
.
checkpoint
:
with
open
(
args
.
checkpoint
,
"rb"
)
as
f
:
params
=
flax
.
serialization
.
from_bytes
(
params
,
f
.
read
())
agent_state
=
agent_state
.
replace
(
params
=
params
)
print
(
f
"loaded checkpoint from {args.checkpoint}"
)
agent_state
=
flax
.
jax_utils
.
replicate
(
agent_state
,
devices
=
learner_devices
)
# print(agent.tabulate(agent_key, sample_obs))
if
args
.
eval_checkpoint
:
with
open
(
args
.
eval_checkpoint
,
"rb"
)
as
f
:
eval_params
=
flax
.
serialization
.
from_bytes
(
params
,
f
.
read
())
print
(
f
"loaded eval checkpoint from {args.eval_checkpoint}"
)
else
:
eval_params
=
None
@
jax
.
jit
def
get_logits_and_value
(
params
:
flax
.
core
.
FrozenDict
,
inputs
,
):
rstate
,
logits
,
value
,
valid
=
create_agent
(
args
,
multi_step
=
True
)
.
apply
(
params
,
inputs
)
return
logits
,
value
.
squeeze
(
-
1
)
def
ppo_loss
(
params
,
rstate1
,
rstate2
,
obs
,
dones
,
next_dones
,
switch
,
actions
,
logits
,
rewards
,
mask
,
next_value
):
# (num_steps * local_num_envs // n_mb))
num_envs
=
next_value
.
shape
[
0
]
num_steps
=
dones
.
shape
[
0
]
//
num_envs
mask
=
mask
*
(
1.0
-
dones
)
n_valids
=
jnp
.
sum
(
mask
)
real_dones
=
dones
|
next_dones
inputs
=
(
rstate1
,
rstate2
,
obs
,
real_dones
,
switch
)
new_logits
,
new_values
=
get_logits_and_value
(
params
,
inputs
)
new_values_
,
rewards
,
next_dones
,
switch
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
reshape
(
x
,
(
num_steps
,
num_envs
)
+
x
.
shape
[
1
:]),
(
new_values
,
rewards
,
next_dones
,
switch
),
)
ratios
=
distrax
.
importance_sampling_ratios
(
distrax
.
Categorical
(
new_logits
),
distrax
.
Categorical
(
logits
),
actions
)
target_values
,
advantages
=
truncated_gae_2p0s
(
next_value
,
new_values_
,
rewards
,
next_dones
,
switch
,
args
.
gamma
,
args
.
gae_lambda
,
args
.
upgo
)
target_values
,
advantages
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
reshape
(
x
,
(
-
1
,)),
(
target_values
,
advantages
))
logratio
=
jnp
.
log
(
ratios
)
approx_kl
=
(((
ratios
-
1
)
-
logratio
)
*
mask
)
.
sum
()
/
n_valids
if
args
.
norm_adv
:
advantages
=
masked_normalize
(
advantages
,
mask
,
eps
=
1e-8
)
# Policy loss
if
args
.
spo_kld_max
is
not
None
:
pg_loss
=
simple_policy_loss
(
ratios
,
logits
,
new_logits
,
advantages
,
args
.
spo_kld_max
)
else
:
pg_loss
=
clipped_surrogate_pg_loss
(
ratios
,
advantages
,
args
.
clip_coef
,
args
.
dual_clip_coef
)
pg_loss
=
jnp
.
sum
(
pg_loss
*
mask
)
v_loss
=
mse_loss
(
new_values
,
target_values
)
v_loss
=
jnp
.
sum
(
v_loss
*
mask
)
ent_loss
=
entropy_loss
(
new_logits
)
ent_loss
=
jnp
.
sum
(
ent_loss
*
mask
)
pg_loss
=
pg_loss
/
n_valids
v_loss
=
v_loss
/
n_valids
ent_loss
=
ent_loss
/
n_valids
loss
=
pg_loss
-
args
.
ent_coef
*
ent_loss
+
v_loss
*
args
.
vf_coef
return
loss
,
(
pg_loss
,
v_loss
,
ent_loss
,
jax
.
lax
.
stop_gradient
(
approx_kl
))
def
single_device_update
(
agent_state
:
TrainState
,
sharded_storages
:
List
,
sharded_init_rstate1
:
List
,
sharded_init_rstate2
:
List
,
sharded_next_inputs
:
List
,
sharded_next_main
:
List
,
key
:
jax
.
random
.
PRNGKey
,
learn_opponent
:
bool
=
False
,
):
storage
=
jax
.
tree
.
map
(
lambda
*
x
:
jnp
.
hstack
(
x
),
*
sharded_storages
)
next_inputs
,
init_rstate1
,
init_rstate2
=
[
jax
.
tree
.
map
(
lambda
*
x
:
jnp
.
concatenate
(
x
),
*
x
)
for
x
in
[
sharded_next_inputs
,
sharded_init_rstate1
,
sharded_init_rstate2
]
]
next_main
,
=
[
jnp
.
concatenate
(
x
)
for
x
in
[
sharded_next_main
]
]
# reorder storage of individual players
# main first, opponent second
num_steps
,
num_envs
=
storage
.
rewards
.
shape
T
=
jnp
.
arange
(
num_steps
,
dtype
=
jnp
.
int32
)
B
=
jnp
.
arange
(
num_envs
,
dtype
=
jnp
.
int32
)
mains
=
storage
.
mains
.
astype
(
jnp
.
int32
)
indices
=
jnp
.
argsort
(
T
[:,
None
]
-
mains
*
num_steps
,
axis
=
0
)
switch_steps
=
jnp
.
sum
(
mains
,
axis
=
0
)
switch
=
T
[:,
None
]
==
(
switch_steps
[
None
,
:]
-
1
)
storage
=
jax
.
tree
.
map
(
lambda
x
:
x
[
indices
,
B
[
None
,
:]],
storage
)
ppo_loss_grad_fn
=
jax
.
value_and_grad
(
ppo_loss
,
has_aux
=
True
)
def
update_epoch
(
carry
,
_
):
agent_state
,
key
=
carry
key
,
subkey
=
jax
.
random
.
split
(
key
)
next_value
=
create_agent
(
args
)
.
apply
(
agent_state
.
params
,
next_inputs
)[
2
]
.
squeeze
(
-
1
)
# TODO: check if this is correct
sign
=
jnp
.
where
(
switch_steps
<=
num_steps
,
1.0
,
-
1.0
)
next_value
=
jnp
.
where
(
next_main
,
-
sign
*
next_value
,
sign
*
next_value
)
def
convert_data
(
x
:
jnp
.
ndarray
,
num_steps
):
if
args
.
update_epochs
>
1
:
x
=
jax
.
random
.
permutation
(
subkey
,
x
,
axis
=
1
if
num_steps
>
1
else
0
)
N
=
args
.
num_minibatches
if
num_steps
>
1
:
x
=
jnp
.
reshape
(
x
,
(
num_steps
,
N
,
-
1
)
+
x
.
shape
[
2
:])
x
=
x
.
transpose
(
1
,
0
,
*
range
(
2
,
x
.
ndim
))
x
=
x
.
reshape
(
N
,
-
1
,
*
x
.
shape
[
3
:])
else
:
x
=
jnp
.
reshape
(
x
,
(
N
,
-
1
)
+
x
.
shape
[
1
:])
return
x
shuffled_init_rstate1
,
shuffled_init_rstate2
,
\
shuffled_next_value
=
jax
.
tree
.
map
(
partial
(
convert_data
,
num_steps
=
1
),
(
init_rstate1
,
init_rstate2
,
next_value
))
shuffled_storage
,
shuffled_switch
=
jax
.
tree
.
map
(
partial
(
convert_data
,
num_steps
=
num_steps
),
(
storage
,
switch
))
shuffled_mask
=
jnp
.
ones_like
(
shuffled_storage
.
mains
)
def
update_minibatch
(
agent_state
,
minibatch
):
(
loss
,
(
pg_loss
,
v_loss
,
entropy_loss
,
approx_kl
)),
grads
=
ppo_loss_grad_fn
(
agent_state
.
params
,
*
minibatch
)
grads
=
jax
.
lax
.
pmean
(
grads
,
axis_name
=
"local_devices"
)
agent_state
=
agent_state
.
apply_gradients
(
grads
=
grads
)
return
agent_state
,
(
loss
,
pg_loss
,
v_loss
,
entropy_loss
,
approx_kl
)
agent_state
,
(
loss
,
pg_loss
,
v_loss
,
entropy_loss
,
approx_kl
)
=
jax
.
lax
.
scan
(
update_minibatch
,
agent_state
,
(
shuffled_init_rstate1
,
shuffled_init_rstate2
,
shuffled_storage
.
obs
,
shuffled_storage
.
dones
,
shuffled_storage
.
next_dones
,
shuffled_switch
,
shuffled_storage
.
actions
,
shuffled_storage
.
logits
,
shuffled_storage
.
rewards
,
shuffled_mask
,
shuffled_next_value
,
),
)
return
(
agent_state
,
key
),
(
loss
,
pg_loss
,
v_loss
,
entropy_loss
,
approx_kl
)
(
agent_state
,
key
),
(
loss
,
pg_loss
,
v_loss
,
entropy_loss
,
approx_kl
)
=
jax
.
lax
.
scan
(
update_epoch
,
(
agent_state
,
key
),
(),
length
=
args
.
update_epochs
)
loss
=
jax
.
lax
.
pmean
(
loss
,
axis_name
=
"local_devices"
)
.
mean
()
pg_loss
=
jax
.
lax
.
pmean
(
pg_loss
,
axis_name
=
"local_devices"
)
.
mean
()
v_loss
=
jax
.
lax
.
pmean
(
v_loss
,
axis_name
=
"local_devices"
)
.
mean
()
entropy_loss
=
jax
.
lax
.
pmean
(
entropy_loss
,
axis_name
=
"local_devices"
)
.
mean
()
approx_kl
=
jax
.
lax
.
pmean
(
approx_kl
,
axis_name
=
"local_devices"
)
.
mean
()
return
agent_state
,
loss
,
pg_loss
,
v_loss
,
entropy_loss
,
approx_kl
,
key
multi_device_update
=
jax
.
pmap
(
single_device_update
,
axis_name
=
"local_devices"
,
devices
=
global_learner_decices
,
static_broadcasted_argnums
=
(
7
,),
)
params_queues
=
[]
rollout_queues
=
[]
eval_queue
=
queue
.
Queue
()
dummy_writer
=
SimpleNamespace
()
dummy_writer
.
add_scalar
=
lambda
x
,
y
,
z
:
None
unreplicated_params
=
flax
.
jax_utils
.
unreplicate
(
agent_state
.
params
)
for
d_idx
,
d_id
in
enumerate
(
args
.
actor_device_ids
):
device_params
=
jax
.
device_put
(
unreplicated_params
,
local_devices
[
d_id
])
for
thread_id
in
range
(
args
.
num_actor_threads
):
params_queues
.
append
(
queue
.
Queue
(
maxsize
=
1
))
rollout_queues
.
append
(
queue
.
Queue
(
maxsize
=
1
))
if
eval_params
:
params_queues
[
-
1
]
.
put
(
jax
.
device_put
(
eval_params
,
local_devices
[
d_id
]))
threading
.
Thread
(
target
=
rollout
,
args
=
(
jax
.
device_put
(
key
,
local_devices
[
d_id
]),
args
,
rollout_queues
[
-
1
],
params_queues
[
-
1
],
eval_queue
,
writer
if
d_idx
==
0
and
thread_id
==
0
else
dummy_writer
,
learner_devices
,
d_idx
*
args
.
num_actor_threads
+
thread_id
,
),
)
.
start
()
params_queues
[
-
1
]
.
put
(
device_params
)
rollout_queue_get_time
=
deque
(
maxlen
=
10
)
data_transfer_time
=
deque
(
maxlen
=
10
)
learner_policy_version
=
0
while
True
:
learner_policy_version
+=
1
rollout_queue_get_time_start
=
time
.
time
()
sharded_data_list
=
[]
for
d_idx
,
d_id
in
enumerate
(
args
.
actor_device_ids
):
for
thread_id
in
range
(
args
.
num_actor_threads
):
(
global_step
,
update
,
*
sharded_data
,
avg_params_queue_get_time
,
learn_opponent
,
)
=
rollout_queues
[
d_idx
*
args
.
num_actor_threads
+
thread_id
]
.
get
()
sharded_data_list
.
append
(
sharded_data
)
rollout_queue_get_time
.
append
(
time
.
time
()
-
rollout_queue_get_time_start
)
training_time_start
=
time
.
time
()
(
agent_state
,
loss
,
pg_loss
,
v_loss
,
entropy_loss
,
approx_kl
,
learner_keys
)
=
multi_device_update
(
agent_state
,
*
list
(
zip
(
*
sharded_data_list
)),
learner_keys
,
learn_opponent
,
)
unreplicated_params
=
flax
.
jax_utils
.
unreplicate
(
agent_state
.
params
)
for
d_idx
,
d_id
in
enumerate
(
args
.
actor_device_ids
):
device_params
=
jax
.
device_put
(
unreplicated_params
,
local_devices
[
d_id
])
device_params
[
"params"
][
"Encoder_0"
][
'Embed_0'
][
"embedding"
]
.
block_until_ready
()
for
thread_id
in
range
(
args
.
num_actor_threads
):
params_queues
[
d_idx
*
args
.
num_actor_threads
+
thread_id
]
.
put
(
device_params
)
loss
=
loss
[
-
1
]
.
item
()
if
np
.
isnan
(
loss
)
or
np
.
isinf
(
loss
):
raise
ValueError
(
f
"loss is {loss}"
)
# record rewards for plotting purposes
if
learner_policy_version
%
args
.
log_frequency
==
0
:
writer
.
add_scalar
(
"stats/rollout_queue_get_time"
,
np
.
mean
(
rollout_queue_get_time
),
global_step
)
writer
.
add_scalar
(
"stats/rollout_params_queue_get_time_diff"
,
np
.
mean
(
rollout_queue_get_time
)
-
avg_params_queue_get_time
,
global_step
,
)
writer
.
add_scalar
(
"stats/training_time"
,
time
.
time
()
-
training_time_start
,
global_step
)
writer
.
add_scalar
(
"stats/rollout_queue_size"
,
rollout_queues
[
-
1
]
.
qsize
(),
global_step
)
writer
.
add_scalar
(
"stats/params_queue_size"
,
params_queues
[
-
1
]
.
qsize
(),
global_step
)
print
(
f
"{global_step} actor_update={update}, "
f
"train_time={time.time() - training_time_start:.2f}, "
f
"data_time={rollout_queue_get_time[-1]:.2f}"
)
writer
.
add_scalar
(
"charts/learning_rate"
,
agent_state
.
opt_state
[
2
][
1
]
.
hyperparams
[
"learning_rate"
][
-
1
]
.
item
(),
global_step
)
writer
.
add_scalar
(
"losses/value_loss"
,
v_loss
[
-
1
]
.
item
(),
global_step
)
writer
.
add_scalar
(
"losses/policy_loss"
,
pg_loss
[
-
1
]
.
item
(),
global_step
)
writer
.
add_scalar
(
"losses/entropy"
,
entropy_loss
[
-
1
]
.
item
(),
global_step
)
writer
.
add_scalar
(
"losses/approx_kl"
,
approx_kl
[
-
1
]
.
item
(),
global_step
)
writer
.
add_scalar
(
"losses/loss"
,
loss
,
global_step
)
if
args
.
local_rank
==
0
and
learner_policy_version
%
args
.
save_interval
==
0
:
ckpt_dir
=
f
"checkpoints"
os
.
makedirs
(
ckpt_dir
,
exist_ok
=
True
)
M_steps
=
args
.
batch_size
*
learner_policy_version
//
(
2
**
20
)
model_path
=
os
.
path
.
join
(
ckpt_dir
,
f
"{timestamp}_{M_steps}M.flax_model"
)
with
open
(
model_path
,
"wb"
)
as
f
:
f
.
write
(
flax
.
serialization
.
to_bytes
(
unreplicated_params
)
)
print
(
f
"model saved to {model_path}"
)
if
learner_policy_version
>=
args
.
num_updates
:
break
if
args
.
distributed
:
jax
.
distributed
.
shutdown
()
writer
.
close
()
\ No newline at end of file
scripts/ppo.py
View file @
bbd36d86
import
os
import
os
import
queue
import
random
import
random
import
threading
import
time
import
time
from
datetime
import
datetime
,
timedelta
,
timezone
from
collections
import
deque
from
collections
import
deque
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
,
field
from
typing
import
Optional
from
types
import
SimpleNamespace
from
typing
import
List
,
NamedTuple
,
Optional
from
functools
import
partial
import
ygoenv
import
ygoenv
import
flax
import
jax
import
jax.numpy
as
jnp
import
numpy
as
np
import
numpy
as
np
import
optax
import
distrax
import
tyro
import
tyro
from
flax.training.train_state
import
TrainState
from
rich.pretty
import
pprint
from
tensorboardX
import
SummaryWriter
from
ygoai.utils
import
init_ygopro
,
load_embeddings
from
ygoai.rl.jax.agent2
import
PPOLSTMAgent
from
ygoai.rl.jax.utils
import
RecordEpisodeStatistics
,
masked_normalize
,
categorical_sample
from
ygoai.rl.jax.eval
import
evaluate
,
battle
from
ygoai.rl.jax
import
clipped_surrogate_pg_loss
,
mse_loss
,
entropy_loss
,
simple_policy_loss
from
ygoai.rl.jax.switch
import
truncated_gae_2p0s
import
torch
import
torch.nn
as
nn
import
torch.optim
as
optim
from
torch.distributions
import
Categorical
import
torch.distributed
as
dist
from
torch.cuda.amp
import
GradScaler
,
autocast
from
ygoai.utils
import
init_ygopro
os
.
environ
[
"XLA_FLAGS"
]
=
"--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
from
ygoai.rl.utils
import
RecordEpisodeStatistics
,
to_tensor
,
load_embeddings
from
ygoai.rl.agent
import
PPOAgent
as
Agent
from
ygoai.rl.dist
import
reduce_gradidents
,
torchrun_setup
,
fprint
from
ygoai.rl.buffer
import
create_obs
from
ygoai.rl.ppo
import
bootstrap_value_selfplay
,
train_step
as
train_step_
from
ygoai.rl.eval
import
evaluate
@
dataclass
@
dataclass
class
Args
:
class
Args
:
exp_name
:
str
=
os
.
path
.
basename
(
__file__
)
[:
-
len
(
".py"
)]
exp_name
:
str
=
os
.
path
.
basename
(
__file__
)
.
rstrip
(
".py"
)
"""the name of this experiment"""
"""the name of this experiment"""
seed
:
int
=
1
seed
:
int
=
1
"""seed of the experiment"""
"""seed of the experiment"""
torch_deterministic
:
bool
=
False
log_frequency
:
int
=
10
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
"""the logging frequency of the model performance (in terms of `updates`)"""
cuda
:
bool
=
True
save_interval
:
int
=
400
"""if toggled, cuda will be enabled by default"""
"""the frequency of saving the model (in terms of `updates`)"""
checkpoint
:
Optional
[
str
]
=
None
"""the path to the model checkpoint to load"""
# Algorithm specific arguments
# Algorithm specific arguments
env_id
:
str
=
"YGOPro-v0"
env_id
:
str
=
"YGOPro-v0"
...
@@ -54,499 +63,806 @@ class Args:
...
@@ -54,499 +63,806 @@ class Args:
"""the maximum number of options"""
"""the maximum number of options"""
n_history_actions
:
int
=
32
n_history_actions
:
int
=
32
"""the number of history actions to use"""
"""the number of history actions to use"""
greedy_reward
:
bool
=
True
"""whether to use greedy reward (faster kill higher reward)"""
num_layers
:
int
=
2
total_timesteps
:
int
=
5000000000
"""the number of layers for the agent"""
num_channels
:
int
=
128
"""the number of channels for the agent"""
checkpoint
:
Optional
[
str
]
=
None
"""the checkpoint to load the model from"""
total_timesteps
:
int
=
2000000000
"""total timesteps of the experiments"""
"""total timesteps of the experiments"""
learning_rate
:
float
=
2.5e-4
learning_rate
:
float
=
1e-3
"""the learning rate of the optimizer"""
"""the learning rate of the optimizer"""
num_envs
:
int
=
8
local_num_envs
:
int
=
12
8
"""the number of parallel game environments"""
"""the number of parallel game environments"""
local_env_threads
:
Optional
[
int
]
=
None
"""the number of threads to use for environment"""
num_actor_threads
:
int
=
2
"""the number of actor threads to use"""
num_steps
:
int
=
128
num_steps
:
int
=
128
"""the number of steps to run in each environment per policy rollout"""
"""the number of steps to run in each environment per policy rollout"""
anneal_lr
:
bool
=
Tru
e
anneal_lr
:
bool
=
Fals
e
"""Toggle learning rate annealing for policy and value networks"""
"""Toggle learning rate annealing for policy and value networks"""
gamma
:
float
=
1.0
gamma
:
float
=
1.0
"""the discount factor gamma"""
"""the discount factor gamma"""
gae_lambda
:
float
=
0.95
gae_lambda
:
float
=
0.95
"""the lambda for the general advantage estimation"""
"""the lambda for the general advantage estimation"""
upgo
:
bool
=
False
fix_target
:
bool
=
False
"""Toggle the use of UPGO for advantages"""
"""if toggled, the target network will be fixed"""
num_minibatches
:
int
=
8
update_win_rate
:
float
=
0.55
"""the number of mini-batches"""
"""the required win rate to update the agent"""
update_return
:
float
=
0.1
"""the required return to update the agent"""
minibatch_size
:
int
=
256
"""the mini-batch size"""
update_epochs
:
int
=
2
update_epochs
:
int
=
2
"""the K epochs to update the policy"""
"""the K epochs to update the policy"""
norm_adv
:
bool
=
Tru
e
norm_adv
:
bool
=
Fals
e
"""Toggles advantages normalization"""
"""Toggles advantages normalization"""
clip_coef
:
float
=
0.2
clip_coef
:
float
=
0.2
5
"""the surrogate clipping coefficient"""
"""the surrogate clipping coefficient"""
clip_vloss
:
bool
=
True
dual_clip_coef
:
Optional
[
float
]
=
None
"""Toggles whether or not to use a clipped loss for the value function, as per the paper."""
"""the dual surrogate clipping coefficient, typically 3.0"""
spo_kld_max
:
Optional
[
float
]
=
None
"""the maximum KLD for the SPO policy, typically 0.02"""
ent_coef
:
float
=
0.01
ent_coef
:
float
=
0.01
"""coefficient of the entropy"""
"""coefficient of the entropy"""
vf_coef
:
float
=
0.5
vf_coef
:
float
=
0.5
"""coefficient of the value function"""
"""coefficient of the value function"""
max_grad_norm
:
float
=
1.0
max_grad_norm
:
float
=
1.0
"""the maximum norm for the gradient clipping"""
"""the maximum norm for the gradient clipping"""
collect_length
:
Optional
[
int
]
=
None
"""the length of the buffer, only the first `num_steps` will be used for training (partial GAE)"""
num_layers
:
int
=
2
"""the number of layers for the agent"""
compile
:
Optional
[
str
]
=
None
num_channels
:
int
=
128
"""Compile mode of torch.compile, None for no compilation"""
"""the number of channels for the agent"""
torch_threads
:
Optional
[
int
]
=
None
rnn_channels
:
int
=
512
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
"""the number of channels for the RNN in the agent"""
env_threads
:
Optional
[
int
]
=
None
"""the number of threads to use for envpool, defaults to `num_envs`"""
actor_device_ids
:
List
[
int
]
=
field
(
default_factory
=
lambda
:
[
0
,
1
])
fp16_train
:
bool
=
False
"""the device ids that actor workers will use"""
"""if toggled, training will be done in fp16 precision"""
learner_device_ids
:
List
[
int
]
=
field
(
default_factory
=
lambda
:
[
2
,
3
])
fp16_eval
:
bool
=
False
"""the device ids that learner workers will use"""
"""if toggled, evaluation will be done in fp16 precision"""
distributed
:
bool
=
False
"""whether to use `jax.distirbuted`"""
tb_dir
:
str
=
"./runs"
concurrency
:
bool
=
True
"""tensorboard log directory"""
"""whether to run the actor and learner concurrently"""
ckpt_dir
:
str
=
"./checkpoints"
bfloat16
:
bool
=
False
"""checkpoint directory"""
"""whether to use bfloat16 for the agent"""
save_interval
:
int
=
500
thread_affinity
:
bool
=
False
"""the number of iterations to save the model"""
"""whether to use thread affinity for the environment"""
log_p
:
float
=
1.0
"""the probability of logging"""
eval_checkpoint
:
Optional
[
str
]
=
None
eval_episodes
:
int
=
128
"""the path to the model checkpoint to evaluate"""
local_eval_episodes
:
int
=
32
"""the number of episodes to evaluate the model"""
"""the number of episodes to evaluate the model"""
eval_interval
:
int
=
50
eval_interval
:
int
=
50
"""the number of iterations to evaluate the model"""
"""the number of iterations to evaluate the model"""
#
to be filled in runtime
#
runtime arguments to be filled in
local_batch_size
:
int
=
0
local_batch_size
:
int
=
0
"""the local batch size in the local rank (computed in runtime)"""
local_minibatch_size
:
int
=
0
local_minibatch_size
:
int
=
0
"""the local mini-batch size in the local rank (computed in runtime)"""
local_num_envs
:
int
=
0
"""the number of parallel game environments (in the local rank, computed in runtime)"""
batch_size
:
int
=
0
"""the batch size (computed in runtime)"""
num_iterations
:
int
=
0
"""the number of iterations (computed in runtime)"""
world_size
:
int
=
0
world_size
:
int
=
0
"""the number of processes (computed in runtime)"""
local_rank
:
int
=
0
num_envs
:
int
=
0
batch_size
:
int
=
0
minibatch_size
:
int
=
0
num_updates
:
int
=
0
global_learner_decices
:
Optional
[
List
[
str
]]
=
None
actor_devices
:
Optional
[
List
[
str
]]
=
None
learner_devices
:
Optional
[
List
[
str
]]
=
None
num_embeddings
:
Optional
[
int
]
=
None
num_embeddings
:
Optional
[
int
]
=
None
"""the number of embeddings (computed in runtime)"""
freeze_id
:
bool
=
False
def
make_env
(
args
,
num_envs
,
num_threads
,
mode
=
'self'
):
def
make_env
(
args
,
seed
,
num_envs
,
num_threads
,
mode
=
'self'
,
thread_affinity_offset
=-
1
,
eval
=
False
):
if
not
args
.
thread_affinity
:
thread_affinity_offset
=
-
1
if
thread_affinity_offset
>=
0
:
print
(
"Binding to thread offset"
,
thread_affinity_offset
)
envs
=
ygoenv
.
make
(
envs
=
ygoenv
.
make
(
task_id
=
args
.
env_id
,
task_id
=
args
.
env_id
,
env_type
=
"gymnasium"
,
env_type
=
"gymnasium"
,
num_envs
=
num_envs
,
num_envs
=
num_envs
,
num_threads
=
num_threads
,
num_threads
=
num_threads
,
seed
=
args
.
seed
,
thread_affinity_offset
=
thread_affinity_offset
,
seed
=
seed
,
deck1
=
args
.
deck1
,
deck1
=
args
.
deck1
,
deck2
=
args
.
deck2
,
deck2
=
args
.
deck2
,
max_options
=
args
.
max_options
,
max_options
=
args
.
max_options
,
n_history_actions
=
args
.
n_history_actions
,
n_history_actions
=
args
.
n_history_actions
,
async_reset
=
False
,
greedy_reward
=
args
.
greedy_reward
if
not
eval
else
True
,
play_mode
=
mode
,
play_mode
=
mode
,
)
)
envs
.
num_envs
=
num_envs
envs
.
num_envs
=
num_envs
envs
=
RecordEpisodeStatistics
(
envs
)
return
envs
return
envs
def
main
():
rank
=
int
(
os
.
environ
.
get
(
"RANK"
,
0
))
local_rank
=
int
(
os
.
environ
.
get
(
"LOCAL_RANK"
,
0
))
world_size
=
int
(
os
.
environ
.
get
(
"WORLD_SIZE"
,
1
))
print
(
f
"rank={rank}, local_rank={local_rank}, world_size={world_size}"
)
args
=
tyro
.
cli
(
Args
)
args
.
world_size
=
world_size
args
.
local_num_envs
=
args
.
num_envs
//
args
.
world_size
args
.
local_batch_size
=
int
(
args
.
local_num_envs
*
args
.
num_steps
)
args
.
local_minibatch_size
=
int
(
args
.
minibatch_size
//
args
.
world_size
)
args
.
batch_size
=
int
(
args
.
num_envs
*
args
.
num_steps
)
args
.
num_iterations
=
args
.
total_timesteps
//
args
.
batch_size
args
.
num_minibatches
=
args
.
local_batch_size
//
args
.
local_minibatch_size
args
.
env_threads
=
args
.
env_threads
or
args
.
num_envs
args
.
torch_threads
=
args
.
torch_threads
or
(
int
(
os
.
getenv
(
"OMP_NUM_THREADS"
,
"2"
))
*
args
.
world_size
)
args
.
collect_length
=
args
.
collect_length
or
args
.
num_steps
assert
args
.
collect_length
>=
args
.
num_steps
,
"collect_length must be greater than or equal to num_steps"
local_torch_threads
=
args
.
torch_threads
//
args
.
world_size
local_env_threads
=
args
.
env_threads
//
args
.
world_size
torch
.
set_num_threads
(
local_torch_threads
)
class
Transition
(
NamedTuple
):
torch
.
set_float32_matmul_precision
(
'high'
)
obs
:
list
dones
:
list
if
args
.
world_size
>
1
:
actions
:
list
torchrun_setup
(
'nccl'
,
local_rank
)
logits
:
list
rewards
:
list
timestamp
=
int
(
time
.
time
())
mains
:
list
run_name
=
f
"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
next_dones
:
list
writer
=
None
if
rank
==
0
:
from
torch.utils.tensorboard
import
SummaryWriter
def
create_agent
(
args
,
multi_step
=
False
):
writer
=
SummaryWriter
(
os
.
path
.
join
(
args
.
tb_dir
,
run_name
))
return
PPOLSTMAgent
(
writer
.
add_text
(
channels
=
args
.
num_channels
,
"hyperparameters"
,
num_layers
=
args
.
num_layers
,
"|param|value|
\n
|-|-|
\n
%
s"
%
(
"
\n
"
.
join
([
f
"|{key}|{value}|"
for
key
,
value
in
vars
(
args
)
.
items
()])),
embedding_shape
=
args
.
num_embeddings
,
dtype
=
jnp
.
bfloat16
if
args
.
bfloat16
else
jnp
.
float32
,
param_dtype
=
jnp
.
float32
,
lstm_channels
=
args
.
rnn_channels
,
multi_step
=
multi_step
,
freeze_id
=
args
.
freeze_id
,
)
)
ckpt_dir
=
os
.
path
.
join
(
args
.
ckpt_dir
,
run_name
)
os
.
makedirs
(
ckpt_dir
,
exist_ok
=
True
)
# TRY NOT TO MODIFY: seeding
# CRUCIAL: note that we needed to pass a different seed for each data parallelism worker
args
.
seed
+=
rank
random
.
seed
(
args
.
seed
)
np
.
random
.
seed
(
args
.
seed
)
torch
.
manual_seed
(
args
.
seed
-
rank
)
if
args
.
torch_deterministic
:
torch
.
backends
.
cudnn
.
deterministic
=
True
else
:
torch
.
backends
.
cudnn
.
benchmark
=
True
device
=
torch
.
device
(
f
"cuda:{local_rank}"
if
torch
.
cuda
.
is_available
()
and
args
.
cuda
else
"cpu"
)
def
init_rnn_state
(
num_envs
,
rnn_channels
):
return
(
deck
=
init_ygopro
(
args
.
env_id
,
"english"
,
args
.
deck
,
args
.
code_list_file
)
np
.
zeros
((
num_envs
,
rnn_channels
)),
args
.
deck1
=
args
.
deck1
or
deck
np
.
zeros
((
num_envs
,
rnn_channels
)),
args
.
deck2
=
args
.
deck2
or
deck
)
# env setup
envs
=
make_env
(
args
,
args
.
local_num_envs
,
local_env_threads
)
obs_space
=
envs
.
env
.
observation_space
action_shape
=
envs
.
env
.
action_space
.
shape
if
local_rank
==
0
:
fprint
(
f
"obs_space={obs_space}, action_shape={action_shape}"
)
envs_per_thread
=
args
.
local_num_envs
//
local_env_threads
local_eval_episodes
=
args
.
eval_episodes
//
args
.
world_size
local_eval_num_envs
=
local_eval_episodes
local_eval_num_threads
=
max
(
1
,
local_eval_num_envs
//
envs_per_thread
)
eval_envs
=
make_env
(
args
,
local_eval_num_envs
,
local_eval_num_threads
,
mode
=
'bot'
)
if
args
.
embedding_file
:
embeddings
=
load_embeddings
(
args
.
embedding_file
,
args
.
code_list_file
)
embedding_shape
=
embeddings
.
shape
else
:
embedding_shape
=
None
L
=
args
.
num_layers
agent
=
Agent
(
args
.
num_channels
,
L
,
L
,
embedding_shape
)
.
to
(
device
)
torch
.
manual_seed
(
args
.
seed
)
agent
.
eval
()
if
args
.
checkpoint
:
def
rollout
(
agent
.
load_state_dict
(
torch
.
load
(
args
.
checkpoint
,
map_location
=
device
))
key
:
jax
.
random
.
PRNGKey
,
fprint
(
f
"Loaded checkpoint from {args.checkpoint}"
)
args
:
Args
,
elif
args
.
embedding_file
:
rollout_queue
,
agent
.
load_embeddings
(
embeddings
)
params_queue
,
fprint
(
f
"Loaded embeddings from {args.embedding_file}"
)
eval_queue
,
if
args
.
embedding_file
:
writer
,
agent
.
freeze_embeddings
()
learner_devices
,
device_thread_id
,
):
eval_mode
=
'self'
if
args
.
eval_checkpoint
else
'bot'
if
eval_mode
!=
'bot'
:
eval_params
=
params_queue
.
get
()
envs
=
make_env
(
args
,
args
.
seed
+
jax
.
process_index
()
+
device_thread_id
,
args
.
local_num_envs
,
args
.
local_env_threads
,
thread_affinity_offset
=
device_thread_id
*
args
.
local_env_threads
,
)
envs
=
RecordEpisodeStatistics
(
envs
)
if
args
.
fix_target
:
eval_envs
=
make_env
(
agent_t
=
Agent
(
args
.
num_channels
,
L
,
L
,
embedding_shape
)
.
to
(
device
)
args
,
agent_t
.
eval
()
args
.
seed
+
jax
.
process_index
()
+
device_thread_id
,
agent_t
.
load_state_dict
(
agent
.
state_dict
())
args
.
local_eval_episodes
,
else
:
args
.
local_eval_episodes
//
4
,
mode
=
eval_mode
,
eval
=
True
)
agent_t
=
agent
eval_envs
=
RecordEpisodeStatistics
(
eval_envs
)
optim_params
=
list
(
agent
.
parameters
())
optimizer
=
optim
.
Adam
(
optim_params
,
lr
=
args
.
learning_rate
,
eps
=
1e-5
)
scaler
=
GradScaler
(
enabled
=
args
.
fp16_train
,
init_scale
=
2
**
8
)
def
predict_step
(
agent
:
Agent
,
next_obs
):
with
torch
.
no_grad
():
with
autocast
(
enabled
=
args
.
fp16_eval
):
logits
,
value
,
valid
=
agent
(
next_obs
)
return
logits
,
value
if
args
.
compile
:
# It seems that using torch.compile twice cause segfault at start, so we use torch.jit.trace here
# predict_step = torch.compile(predict_step, mode=args.compile)
example_obs
=
create_obs
(
envs
.
observation_space
,
(
args
.
local_num_envs
,),
device
=
device
)
with
torch
.
no_grad
():
traced_model
=
torch
.
jit
.
trace
(
agent
,
(
example_obs
,),
check_tolerance
=
False
,
check_trace
=
False
)
if
args
.
fix_target
:
traced_model_t
=
torch
.
jit
.
trace
(
agent_t
,
(
example_obs
,),
check_tolerance
=
False
,
check_trace
=
False
)
traced_model_t
=
torch
.
jit
.
optimize_for_inference
(
traced_model_t
)
else
:
traced_model_t
=
traced_model
train_step
=
torch
.
compile
(
train_step_
,
mode
=
args
.
compile
)
len_actor_device_ids
=
len
(
args
.
actor_device_ids
)
else
:
n_actors
=
args
.
num_actor_threads
*
len_actor_device_ids
traced_model
=
agent
global_step
=
0
traced_model_t
=
agent_t
start_time
=
time
.
time
()
train_step
=
train_step_
warmup_step
=
0
other_time
=
0
# ALGO Logic: Storage setup
obs
=
create_obs
(
obs_space
,
(
args
.
collect_length
,
args
.
local_num_envs
),
device
)
actions
=
torch
.
zeros
((
args
.
collect_length
,
args
.
local_num_envs
)
+
action_shape
)
.
to
(
device
)
logprobs
=
torch
.
zeros
((
args
.
collect_length
,
args
.
local_num_envs
))
.
to
(
device
)
rewards
=
torch
.
zeros
((
args
.
collect_length
,
args
.
local_num_envs
))
.
to
(
device
)
dones
=
torch
.
zeros
((
args
.
collect_length
,
args
.
local_num_envs
),
dtype
=
torch
.
bool
)
.
to
(
device
)
values
=
torch
.
zeros
((
args
.
collect_length
,
args
.
local_num_envs
))
.
to
(
device
)
learns
=
torch
.
zeros
((
args
.
collect_length
,
args
.
local_num_envs
),
dtype
=
torch
.
bool
)
.
to
(
device
)
avg_ep_returns
=
deque
(
maxlen
=
1000
)
avg_ep_returns
=
deque
(
maxlen
=
1000
)
avg_win_rates
=
deque
(
maxlen
=
1000
)
avg_win_rates
=
deque
(
maxlen
=
1000
)
version
=
0
# TRY NOT TO MODIFY: start the game
@
jax
.
jit
global_step
=
0
def
get_logits
(
warmup_steps
=
0
params
:
flax
.
core
.
FrozenDict
,
inputs
,
done
):
start_time
=
time
.
time
()
rstate
,
logits
=
create_agent
(
args
)
.
apply
(
params
,
inputs
)[:
2
]
rstate
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
where
(
done
[:,
None
],
0
,
x
),
rstate
)
return
rstate
,
logits
@
jax
.
jit
def
get_action
(
params
:
flax
.
core
.
FrozenDict
,
inputs
):
batch_size
=
jax
.
tree
.
leaves
(
inputs
)[
0
]
.
shape
[
0
]
done
=
jnp
.
zeros
(
batch_size
,
dtype
=
jnp
.
bool_
)
rstate
,
logits
=
get_logits
(
params
,
inputs
,
done
)
return
rstate
,
logits
.
argmax
(
axis
=
1
)
@
jax
.
jit
def
get_action_battle
(
params1
,
params2
,
rstate1
,
rstate2
,
obs
,
main
,
done
):
next_rstate1
,
logits1
=
get_logits
(
params1
,
(
rstate1
,
obs
),
done
)
next_rstate2
,
logits2
=
get_logits
(
params2
,
(
rstate2
,
obs
),
done
)
logits
=
jnp
.
where
(
main
[:,
None
],
logits1
,
logits2
)
rstate1
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x1
,
x2
),
next_rstate1
,
rstate1
)
rstate2
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x2
,
x1
),
next_rstate2
,
rstate2
)
return
rstate1
,
rstate2
,
logits
.
argmax
(
axis
=
1
)
@
jax
.
jit
def
sample_action
(
params
:
flax
.
core
.
FrozenDict
,
next_obs
,
rstate1
,
rstate2
,
main
,
done
,
key
):
next_obs
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
array
(
x
),
next_obs
)
done
=
jnp
.
array
(
done
)
main
=
jnp
.
array
(
main
)
rstate
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x1
,
x2
),
rstate1
,
rstate2
)
rstate
,
logits
=
get_logits
(
params
,
(
rstate
,
next_obs
),
done
)
rstate1
=
jax
.
tree
.
map
(
lambda
x
,
y
:
jnp
.
where
(
main
[:,
None
],
x
,
y
),
rstate
,
rstate1
)
rstate2
=
jax
.
tree
.
map
(
lambda
x
,
y
:
jnp
.
where
(
main
[:,
None
],
y
,
x
),
rstate
,
rstate2
)
action
,
key
=
categorical_sample
(
logits
,
key
)
return
next_obs
,
done
,
main
,
rstate1
,
rstate2
,
action
,
logits
,
key
# put data in the last index
params_queue_get_time
=
deque
(
maxlen
=
10
)
rollout_time
=
deque
(
maxlen
=
10
)
actor_policy_version
=
0
next_obs
,
info
=
envs
.
reset
()
next_obs
,
info
=
envs
.
reset
()
next_obs
=
to_tensor
(
next_obs
,
device
,
dtype
=
torch
.
uint8
)
next_to_play
=
info
[
"to_play"
]
next_to_play_
=
info
[
"to_play"
]
next_done
=
np
.
zeros
(
args
.
local_num_envs
,
dtype
=
np
.
bool_
)
next_to_play
=
to_tensor
(
next_to_play_
,
device
)
next_rstate1
=
next_rstate2
=
init_rnn_state
(
next_done
=
torch
.
zeros
(
args
.
local_num_envs
,
device
=
device
,
dtype
=
torch
.
bool
)
args
.
local_num_envs
,
args
.
rnn_channels
)
main_player_
=
np
.
concatenate
([
eval_rstate
=
init_rnn_state
(
args
.
local_eval_episodes
,
args
.
rnn_channels
)
main_player
=
np
.
concatenate
([
np
.
zeros
(
args
.
local_num_envs
//
2
,
dtype
=
np
.
int64
),
np
.
zeros
(
args
.
local_num_envs
//
2
,
dtype
=
np
.
int64
),
np
.
ones
(
args
.
local_num_envs
//
2
,
dtype
=
np
.
int64
)
np
.
ones
(
args
.
local_num_envs
//
2
,
dtype
=
np
.
int64
)
])
])
np
.
random
.
shuffle
(
main_player_
)
np
.
random
.
shuffle
(
main_player
)
main_player
=
to_tensor
(
main_player_
,
device
,
dtype
=
next_to_play
.
dtype
)
storage
=
[]
step
=
0
@
jax
.
jit
for
iteration
in
range
(
args
.
num_iterations
):
def
prepare_data
(
storage
:
List
[
Transition
])
->
Transition
:
# Annealing the rate if instructed to do so.
return
jax
.
tree
.
map
(
lambda
*
xs
:
jnp
.
split
(
jnp
.
stack
(
xs
),
len
(
learner_devices
),
axis
=
1
),
*
storage
)
if
args
.
anneal_lr
:
frac
=
1.0
-
iteration
/
args
.
num_iterations
for
update
in
range
(
1
,
args
.
num_updates
+
2
):
lrnow
=
frac
*
args
.
learning_rate
if
update
==
10
:
optimizer
.
param_groups
[
0
][
"lr"
]
=
lrnow
start_time
=
time
.
time
()
warmup_step
=
global_step
model_time
=
0
update_time_start
=
time
.
time
()
inference_time
=
0
env_time
=
0
env_time
=
0
collect_start
=
time
.
time
()
params_queue_get_time_start
=
time
.
time
()
while
step
<
args
.
collect_length
:
if
args
.
concurrency
:
global_step
+=
args
.
num_envs
if
update
!=
2
:
params
=
params_queue
.
get
()
# params["params"]["Encoder_0"]['Embed_0'][
# "embedding"
# ].block_until_ready()
actor_policy_version
+=
1
else
:
params
=
params_queue
.
get
()
actor_policy_version
+=
1
params_queue_get_time
.
append
(
time
.
time
()
-
params_queue_get_time_start
)
for
key
in
obs
:
rollout_time_start
=
time
.
time
()
obs
[
key
][
step
]
=
next_obs
[
key
]
init_rstate1
,
init_rstate2
=
jax
.
tree
.
map
(
dones
[
step
]
=
next_done
lambda
x
:
x
.
copy
(),
(
next_rstate1
,
next_rstate2
))
learn
=
next_to_play
==
main_player
for
_
in
range
(
args
.
num_steps
):
learns
[
step
]
=
learn
global_step
+=
args
.
local_num_envs
*
n_actors
*
args
.
world_size
_start
=
time
.
time
()
cached_next_obs
=
next_obs
logits
,
value
=
predict_step
(
traced_model
,
next_obs
)
cached_next_done
=
next_done
if
args
.
fix_target
:
main
=
next_to_play
==
main_player
logits_t
,
value_t
=
predict_step
(
traced_model_t
,
next_obs
)
logits
=
torch
.
where
(
learn
[:,
None
],
logits
,
logits_t
)
inference_time_start
=
time
.
time
()
value
=
torch
.
where
(
learn
[:,
None
],
value
,
value_t
)
cached_next_obs
,
cached_next_done
,
cached_main
,
\
value
=
value
.
flatten
()
next_rstate1
,
next_rstate2
,
action
,
logits
,
key
=
sample_action
(
probs
=
Categorical
(
logits
=
logits
)
params
,
cached_next_obs
,
next_rstate1
,
next_rstate2
,
main
,
cached_next_done
,
key
)
action
=
probs
.
sample
()
logprob
=
probs
.
log_prob
(
action
)
cpu_action
=
np
.
array
(
action
)
inference_time
+=
time
.
time
()
-
inference_time_start
values
[
step
]
=
value
actions
[
step
]
=
action
logprobs
[
step
]
=
logprob
action
=
action
.
cpu
()
.
numpy
()
model_time
+=
time
.
time
()
-
_start
_start
=
time
.
time
()
_start
=
time
.
time
()
to_play
=
next_to_play_
next_obs
,
next_reward
,
next_done
,
info
=
envs
.
step
(
cpu_action
)
next_obs
,
reward
,
next_done_
,
info
=
envs
.
step
(
action
)
next_to_play
=
info
[
"to_play"
]
next_to_play_
=
info
[
"to_play"
]
next_to_play
=
to_tensor
(
next_to_play_
,
device
)
env_time
+=
time
.
time
()
-
_start
env_time
+=
time
.
time
()
-
_start
rewards
[
step
]
=
to_tensor
(
reward
,
device
)
next_obs
,
next_done
=
to_tensor
(
next_obs
,
device
,
torch
.
uint8
),
to_tensor
(
next_done_
,
device
,
torch
.
bool
)
step
+=
1
if
not
writer
:
storage
.
append
(
continue
Transition
(
obs
=
cached_next_obs
,
dones
=
cached_next_done
,
mains
=
cached_main
,
actions
=
action
,
logits
=
logits
,
rewards
=
next_reward
,
next_dones
=
next_done
,
)
)
for
idx
,
d
in
enumerate
(
next_done_
):
for
idx
,
d
in
enumerate
(
next_done
):
if
d
:
if
not
d
:
pl
=
1
if
to_play
[
idx
]
==
main_player_
[
idx
]
else
-
1
continue
episode_length
=
info
[
'l'
][
idx
]
cur_main
=
main
[
idx
]
episode_reward
=
info
[
'r'
][
idx
]
*
pl
for
j
in
reversed
(
range
(
len
(
storage
)
-
1
)):
t
=
storage
[
j
]
if
t
.
next_dones
[
idx
]:
# For OTK where player may not switch
break
if
t
.
mains
[
idx
]
!=
cur_main
:
t
.
next_dones
[
idx
]
=
True
t
.
rewards
[
idx
]
=
-
next_reward
[
idx
]
break
episode_reward
=
info
[
'r'
][
idx
]
*
(
1
if
cur_main
else
-
1
)
win
=
1
if
episode_reward
>
0
else
0
win
=
1
if
episode_reward
>
0
else
0
avg_ep_returns
.
append
(
episode_reward
)
avg_ep_returns
.
append
(
episode_reward
)
avg_win_rates
.
append
(
win
)
avg_win_rates
.
append
(
win
)
if
random
.
random
()
<
args
.
log_p
:
rollout_time
.
append
(
time
.
time
()
-
rollout_time_start
)
n
=
100
if
random
.
random
()
<
10
/
n
or
iteration
<=
1
:
writer
.
add_scalar
(
"charts/episodic_return"
,
info
[
"r"
][
idx
],
global_step
)
writer
.
add_scalar
(
"charts/episodic_length"
,
info
[
"l"
][
idx
],
global_step
)
fprint
(
f
"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}"
)
if
random
.
random
()
<
1
/
n
:
partitioned_storage
=
prepare_data
(
storage
)
writer
.
add_scalar
(
"charts/avg_ep_return"
,
np
.
mean
(
avg_ep_returns
),
global_step
)
storage
=
[]
writer
.
add_scalar
(
"charts/avg_win_rate"
,
np
.
mean
(
avg_win_rates
),
global_step
)
sharded_storage
=
[]
for
x
in
partitioned_storage
:
if
isinstance
(
x
,
dict
):
x
=
{
k
:
jax
.
device_put_sharded
(
v
,
devices
=
learner_devices
)
for
k
,
v
in
x
.
items
()
}
else
:
x
=
jax
.
device_put_sharded
(
x
,
devices
=
learner_devices
)
sharded_storage
.
append
(
x
)
sharded_storage
=
Transition
(
*
sharded_storage
)
next_main
=
main_player
==
next_to_play
next_rstate
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
next_main
[:,
None
],
x1
,
x2
),
next_rstate1
,
next_rstate2
)
sharded_data
=
jax
.
tree
.
map
(
lambda
x
:
jax
.
device_put_sharded
(
np
.
split
(
x
,
len
(
learner_devices
)),
devices
=
learner_devices
),
(
init_rstate1
,
init_rstate2
,
(
next_rstate
,
next_obs
),
next_main
))
learn_opponent
=
False
payload
=
(
global_step
,
update
,
sharded_storage
,
*
sharded_data
,
np
.
mean
(
params_queue_get_time
),
learn_opponent
,
)
rollout_queue
.
put
(
payload
)
if
update
%
args
.
log_frequency
==
0
:
avg_episodic_return
=
np
.
mean
(
avg_ep_returns
)
avg_episodic_length
=
np
.
mean
(
envs
.
returned_episode_lengths
)
SPS
=
int
((
global_step
-
warmup_step
)
/
(
time
.
time
()
-
start_time
-
other_time
))
SPS_update
=
int
(
args
.
batch_size
/
(
time
.
time
()
-
update_time_start
))
if
device_thread_id
==
0
:
print
(
f
"global_step={global_step}, avg_return={avg_episodic_return:.4f}, avg_length={avg_episodic_length:.0f}"
)
time_now
=
datetime
.
now
(
timezone
(
timedelta
(
hours
=
8
)))
.
strftime
(
"
%
H:
%
M:
%
S"
)
print
(
f
"{time_now} SPS: {SPS}, update: {SPS_update}, "
f
"rollout_time={rollout_time[-1]:.2f}, params_time={params_queue_get_time[-1]:.2f}"
)
writer
.
add_scalar
(
"stats/rollout_time"
,
np
.
mean
(
rollout_time
),
global_step
)
writer
.
add_scalar
(
"charts/avg_episodic_return"
,
avg_episodic_return
,
global_step
)
writer
.
add_scalar
(
"charts/avg_episodic_length"
,
avg_episodic_length
,
global_step
)
writer
.
add_scalar
(
"stats/params_queue_get_time"
,
np
.
mean
(
params_queue_get_time
),
global_step
)
writer
.
add_scalar
(
"stats/inference_time"
,
inference_time
,
global_step
)
writer
.
add_scalar
(
"stats/env_time"
,
env_time
,
global_step
)
writer
.
add_scalar
(
"charts/SPS"
,
SPS
,
global_step
)
writer
.
add_scalar
(
"charts/SPS_update"
,
SPS_update
,
global_step
)
if
args
.
eval_interval
and
update
%
args
.
eval_interval
==
0
:
# Eval with rule-based policy
_start
=
time
.
time
()
if
eval_mode
==
'bot'
:
predict_fn
=
lambda
x
:
get_action
(
params
,
x
)
eval_return
,
eval_ep_len
,
eval_win_rate
=
evaluate
(
eval_envs
,
args
.
local_eval_episodes
,
predict_fn
,
eval_rstate
)
else
:
predict_fn
=
lambda
*
x
:
get_action_battle
(
params
,
eval_params
,
*
x
)
eval_return
,
eval_ep_len
,
eval_win_rate
=
battle
(
eval_envs
,
args
.
local_eval_episodes
,
predict_fn
,
eval_rstate
)
eval_stat
=
np
.
array
([
eval_return
,
eval_win_rate
])
if
device_thread_id
!=
0
:
eval_queue
.
put
(
eval_stat
)
else
:
eval_stats
=
[]
eval_stats
.
append
(
eval_stat
)
for
_
in
range
(
1
,
n_actors
):
eval_stats
.
append
(
eval_queue
.
get
())
eval_stats
=
np
.
stack
(
eval_stats
)
eval_return
,
eval_win_rate
=
np
.
mean
(
eval_stats
,
axis
=
0
)
writer
.
add_scalar
(
f
"charts/eval_return"
,
eval_return
,
global_step
)
writer
.
add_scalar
(
f
"charts/eval_win_rate"
,
eval_win_rate
,
global_step
)
if
device_thread_id
==
0
:
eval_time
=
time
.
time
()
-
_start
other_time
+=
eval_time
print
(
f
"eval_time={eval_time:.4f}, eval_return={eval_return:.4f}, eval_win_rate={eval_win_rate:.4f}"
)
collect_time
=
time
.
time
()
-
collect_start
if
local_rank
==
0
:
fprint
(
f
"collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}"
)
step
=
args
.
collect_length
-
args
.
num_steps
if
__name__
==
"__main__"
:
args
=
tyro
.
cli
(
Args
)
args
.
local_batch_size
=
int
(
args
.
local_num_envs
*
args
.
num_steps
*
args
.
num_actor_threads
*
len
(
args
.
actor_device_ids
))
args
.
local_minibatch_size
=
int
(
args
.
local_batch_size
//
args
.
num_minibatches
)
assert
(
args
.
local_num_envs
%
len
(
args
.
learner_device_ids
)
==
0
),
"local_num_envs must be divisible by len(learner_device_ids)"
assert
(
int
(
args
.
local_num_envs
/
len
(
args
.
learner_device_ids
))
*
args
.
num_actor_threads
%
args
.
num_minibatches
==
0
),
"int(local_num_envs / len(learner_device_ids)) must be divisible by num_minibatches"
if
args
.
distributed
:
jax
.
distributed
.
initialize
(
local_device_ids
=
range
(
len
(
args
.
learner_device_ids
)
+
len
(
args
.
actor_device_ids
)),
)
print
(
list
(
range
(
len
(
args
.
learner_device_ids
)
+
len
(
args
.
actor_device_ids
))))
_start
=
time
.
time
()
from
jax.experimental.compilation_cache
import
compilation_cache
as
cc
# bootstrap value if not done
cc
.
set_cache_dir
(
os
.
path
.
expanduser
(
"~/.cache/jax"
))
value
=
predict_step
(
traced_model
,
next_obs
)[
1
]
.
reshape
(
-
1
)
nextvalues1
=
torch
.
where
(
next_to_play
==
main_player
,
value
,
-
value
)
args
.
world_size
=
jax
.
process_count
()
if
args
.
fix_target
:
args
.
local_rank
=
jax
.
process_index
()
value_t
=
predict_step
(
traced_model_t
,
next_obs
)[
1
]
.
reshape
(
-
1
)
args
.
num_envs
=
args
.
local_num_envs
*
args
.
world_size
*
args
.
num_actor_threads
*
len
(
args
.
actor_device_ids
)
nextvalues2
=
torch
.
where
(
next_to_play
!=
main_player
,
value_t
,
-
value_t
)
args
.
batch_size
=
args
.
local_batch_size
*
args
.
world_size
args
.
minibatch_size
=
args
.
local_minibatch_size
*
args
.
world_size
args
.
num_updates
=
args
.
total_timesteps
//
(
args
.
local_batch_size
*
args
.
world_size
)
args
.
local_env_threads
=
args
.
local_env_threads
or
args
.
local_num_envs
if
args
.
embedding_file
:
embeddings
=
load_embeddings
(
args
.
embedding_file
,
args
.
code_list_file
)
embedding_shape
=
embeddings
.
shape
args
.
num_embeddings
=
embedding_shape
args
.
freeze_id
=
True
if
args
.
freeze_id
is
None
else
args
.
freeze_id
else
:
else
:
nextvalues2
=
-
nextvalues1
embeddings
=
None
embedding_shape
=
None
if
step
>
0
and
iteration
!=
0
:
# recalculate the values for the first few steps
v_steps
=
args
.
local_minibatch_size
*
4
//
args
.
local_num_envs
for
v_start
in
range
(
0
,
step
,
v_steps
):
v_end
=
min
(
v_start
+
v_steps
,
step
)
v_obs
=
{
k
:
v
[
v_start
:
v_end
]
.
flatten
(
0
,
1
)
for
k
,
v
in
obs
.
items
()
}
with
torch
.
no_grad
():
# value = traced_get_value(v_obs).reshape(v_end - v_start, -1)
value
=
predict_step
(
traced_model
,
v_obs
)[
1
]
.
reshape
(
v_end
-
v_start
,
-
1
)
values
[
v_start
:
v_end
]
=
value
advantages
=
bootstrap_value_selfplay
(
local_devices
=
jax
.
local_devices
()
values
,
rewards
,
dones
,
learns
,
nextvalues1
,
nextvalues2
,
next_done
,
args
.
gamma
,
args
.
gae_lambda
)
global_devices
=
jax
.
devices
()
bootstrap_time
=
time
.
time
()
-
_start
learner_devices
=
[
local_devices
[
d_id
]
for
d_id
in
args
.
learner_device_ids
]
actor_devices
=
[
local_devices
[
d_id
]
for
d_id
in
args
.
actor_device_ids
]
global_learner_decices
=
[
global_devices
[
d_id
+
process_index
*
len
(
local_devices
)]
for
process_index
in
range
(
args
.
world_size
)
for
d_id
in
args
.
learner_device_ids
]
print
(
"global_learner_decices"
,
global_learner_decices
)
args
.
global_learner_decices
=
[
str
(
item
)
for
item
in
global_learner_decices
]
args
.
actor_devices
=
[
str
(
item
)
for
item
in
actor_devices
]
args
.
learner_devices
=
[
str
(
item
)
for
item
in
learner_devices
]
pprint
(
args
)
_start
=
time
.
time
()
timestamp
=
int
(
time
.
time
())
# flatten the batch
run_name
=
f
"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
b_obs
=
{
k
:
v
[:
args
.
num_steps
]
.
reshape
((
-
1
,)
+
v
.
shape
[
2
:])
writer
=
SummaryWriter
(
f
"runs/{run_name}"
)
for
k
,
v
in
obs
.
items
()
writer
.
add_text
(
}
"hyperparameters"
,
b_actions
=
actions
[:
args
.
num_steps
]
.
reshape
((
-
1
,)
+
action_shape
)
"|param|value|
\n
|-|-|
\n
%
s"
%
(
"
\n
"
.
join
([
f
"|{key}|{value}|"
for
key
,
value
in
vars
(
args
)
.
items
()])),
b_logprobs
=
logprobs
[:
args
.
num_steps
]
.
reshape
(
-
1
)
)
b_advantages
=
advantages
[:
args
.
num_steps
]
.
reshape
(
-
1
)
b_values
=
values
[:
args
.
num_steps
]
.
reshape
(
-
1
)
# seeding
b_returns
=
b_advantages
+
b_values
random
.
seed
(
args
.
seed
)
if
args
.
fix_target
:
np
.
random
.
seed
(
args
.
seed
)
b_learns
=
learns
[:
args
.
num_steps
]
.
reshape
(
-
1
)
key
=
jax
.
random
.
PRNGKey
(
args
.
seed
)
key
,
agent_key
=
jax
.
random
.
split
(
key
,
2
)
learner_keys
=
jax
.
device_put_replicated
(
key
,
learner_devices
)
deck
=
init_ygopro
(
args
.
env_id
,
"english"
,
args
.
deck
,
args
.
code_list_file
)
args
.
deck1
=
args
.
deck1
or
deck
args
.
deck2
=
args
.
deck2
or
deck
# env setup
envs
=
make_env
(
args
,
args
.
seed
,
8
,
1
)
obs_space
=
envs
.
observation_space
action_shape
=
envs
.
action_space
.
shape
print
(
f
"obs_space={obs_space}, action_shape={action_shape}"
)
sample_obs
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
array
([
x
]),
obs_space
.
sample
())
envs
.
close
()
del
envs
def
linear_schedule
(
count
):
# anneal learning rate linearly after one training iteration which contains
# (args.num_minibatches) gradient updates
frac
=
1.0
-
(
count
//
(
args
.
num_minibatches
*
args
.
update_epochs
))
/
args
.
num_updates
return
args
.
learning_rate
*
frac
rstate
=
init_rnn_state
(
1
,
args
.
rnn_channels
)
agent
=
create_agent
(
args
)
params
=
agent
.
init
(
agent_key
,
(
rstate
,
sample_obs
))
if
embeddings
is
not
None
:
unknown_embed
=
embeddings
.
mean
(
axis
=
0
)
embeddings
=
np
.
concatenate
([
unknown_embed
[
None
,
:],
embeddings
],
axis
=
0
)
params
=
flax
.
core
.
unfreeze
(
params
)
params
[
'params'
][
'Encoder_0'
][
'Embed_0'
][
'embedding'
]
=
jax
.
device_put
(
embeddings
)
params
=
flax
.
core
.
freeze
(
params
)
tx
=
optax
.
MultiSteps
(
optax
.
chain
(
optax
.
clip_by_global_norm
(
args
.
max_grad_norm
),
optax
.
inject_hyperparams
(
optax
.
adam
)(
learning_rate
=
linear_schedule
if
args
.
anneal_lr
else
args
.
learning_rate
,
eps
=
1e-5
),
),
every_k_schedule
=
1
,
)
agent_state
=
TrainState
.
create
(
apply_fn
=
None
,
params
=
params
,
tx
=
tx
,
)
if
args
.
checkpoint
:
with
open
(
args
.
checkpoint
,
"rb"
)
as
f
:
params
=
flax
.
serialization
.
from_bytes
(
params
,
f
.
read
())
agent_state
=
agent_state
.
replace
(
params
=
params
)
print
(
f
"loaded checkpoint from {args.checkpoint}"
)
agent_state
=
flax
.
jax_utils
.
replicate
(
agent_state
,
devices
=
learner_devices
)
# print(agent.tabulate(agent_key, sample_obs))
if
args
.
eval_checkpoint
:
with
open
(
args
.
eval_checkpoint
,
"rb"
)
as
f
:
eval_params
=
flax
.
serialization
.
from_bytes
(
params
,
f
.
read
())
print
(
f
"loaded eval checkpoint from {args.eval_checkpoint}"
)
else
:
else
:
b_learns
=
torch
.
ones_like
(
b_values
,
dtype
=
torch
.
bool
)
eval_params
=
None
# Optimizing the policy and value network
@
jax
.
jit
b_inds
=
np
.
arange
(
args
.
local_batch_size
)
def
get_logits_and_value
(
clipfracs
=
[]
params
:
flax
.
core
.
FrozenDict
,
inputs
,
for
epoch
in
range
(
args
.
update_epochs
):
):
np
.
random
.
shuffle
(
b_inds
)
rstate
,
logits
,
value
,
valid
=
create_agent
(
for
start
in
range
(
0
,
args
.
local_batch_size
,
args
.
local_minibatch_size
):
args
,
multi_step
=
True
)
.
apply
(
params
,
inputs
)
end
=
start
+
args
.
local_minibatch_size
return
logits
,
value
.
squeeze
(
-
1
)
mb_inds
=
b_inds
[
start
:
end
]
mb_obs
=
{
def
ppo_loss
(
k
:
v
[
mb_inds
]
for
k
,
v
in
b_obs
.
items
()
params
,
rstate1
,
rstate2
,
obs
,
dones
,
next_dones
,
}
switch
,
actions
,
logits
,
rewards
,
mask
,
next_value
):
old_approx_kl
,
approx_kl
,
clipfrac
,
pg_loss
,
v_loss
,
entropy_loss
=
\
# (num_steps * local_num_envs // n_mb))
train_step
(
agent
,
optimizer
,
scaler
,
mb_obs
,
b_actions
[
mb_inds
],
b_logprobs
[
mb_inds
],
b_advantages
[
mb_inds
],
num_envs
=
next_value
.
shape
[
0
]
b_returns
[
mb_inds
],
b_values
[
mb_inds
],
b_learns
[
mb_inds
],
args
)
num_steps
=
dones
.
shape
[
0
]
//
num_envs
reduce_gradidents
(
optim_params
,
args
.
world_size
)
nn
.
utils
.
clip_grad_norm_
(
optim_params
,
args
.
max_grad_norm
)
mask
=
mask
*
(
1.0
-
dones
)
scaler
.
step
(
optimizer
)
n_valids
=
jnp
.
sum
(
mask
)
scaler
.
update
()
real_dones
=
dones
|
next_dones
clipfracs
.
append
(
clipfrac
.
item
())
inputs
=
(
rstate1
,
rstate2
,
obs
,
real_dones
,
switch
)
if
step
>
0
:
new_logits
,
new_values
=
get_logits_and_value
(
params
,
inputs
)
# TODO: use cyclic buffer to avoid copying
for
v
in
obs
.
values
():
new_values_
,
rewards
,
next_dones
,
switch
=
jax
.
tree
.
map
(
v
[:
step
]
=
v
[
args
.
num_steps
:]
.
clone
()
lambda
x
:
jnp
.
reshape
(
x
,
(
num_steps
,
num_envs
)
+
x
.
shape
[
1
:]),
for
v
in
[
actions
,
logprobs
,
rewards
,
dones
,
values
,
learns
]:
(
new_values
,
rewards
,
next_dones
,
switch
),
v
[:
step
]
=
v
[
args
.
num_steps
:]
.
clone
()
)
train_time
=
time
.
time
()
-
_start
if
local_rank
==
0
:
fprint
(
f
"train_time={train_time:.4f}, collect_time={collect_time:.4f}, bootstrap_time={bootstrap_time:.4f}"
)
y_pred
,
y_true
=
b_values
.
cpu
()
.
numpy
(),
b_returns
.
cpu
()
.
numpy
()
var_y
=
np
.
var
(
y_true
)
explained_var
=
np
.
nan
if
var_y
==
0
else
1
-
np
.
var
(
y_true
-
y_pred
)
/
var_y
if
rank
==
0
:
if
iteration
%
args
.
save_interval
==
0
:
torch
.
save
(
agent
.
state_dict
(),
os
.
path
.
join
(
ckpt_dir
,
f
"agent.pt"
))
writer
.
add_scalar
(
"charts/learning_rate"
,
optimizer
.
param_groups
[
0
][
"lr"
],
global_step
)
writer
.
add_scalar
(
"losses/value_loss"
,
v_loss
.
item
(),
global_step
)
writer
.
add_scalar
(
"losses/policy_loss"
,
pg_loss
.
item
(),
global_step
)
writer
.
add_scalar
(
"losses/entropy"
,
entropy_loss
.
item
(),
global_step
)
writer
.
add_scalar
(
"losses/old_approx_kl"
,
old_approx_kl
.
item
(),
global_step
)
writer
.
add_scalar
(
"losses/approx_kl"
,
approx_kl
.
item
(),
global_step
)
writer
.
add_scalar
(
"losses/clipfrac"
,
np
.
mean
(
clipfracs
),
global_step
)
writer
.
add_scalar
(
"losses/explained_variance"
,
explained_var
,
global_step
)
SPS
=
int
((
global_step
-
warmup_steps
)
/
(
time
.
time
()
-
start_time
))
# Warmup at first few iterations for accurate SPS measurement
SPS_warmup_iters
=
10
if
iteration
==
SPS_warmup_iters
:
start_time
=
time
.
time
()
warmup_steps
=
global_step
if
iteration
>
SPS_warmup_iters
:
if
local_rank
==
0
:
fprint
(
f
"SPS: {SPS}"
)
if
rank
==
0
:
writer
.
add_scalar
(
"charts/SPS"
,
SPS
,
global_step
)
if
args
.
fix_target
:
ratios
=
distrax
.
importance_sampling_ratios
(
distrax
.
Categorical
(
if
rank
==
0
:
new_logits
),
distrax
.
Categorical
(
logits
),
actions
)
should_update
=
len
(
avg_win_rates
)
==
1000
and
np
.
mean
(
avg_win_rates
)
>
args
.
update_win_rate
and
np
.
mean
(
avg_ep_returns
)
>
args
.
update_return
should_update
=
torch
.
tensor
(
int
(
should_update
),
dtype
=
torch
.
int64
,
device
=
device
)
target_values
,
advantages
=
truncated_gae_2p0s
(
next_value
,
new_values_
,
rewards
,
next_dones
,
switch
,
args
.
gamma
,
args
.
gae_lambda
,
args
.
upgo
)
target_values
,
advantages
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
reshape
(
x
,
(
-
1
,)),
(
target_values
,
advantages
))
logratio
=
jnp
.
log
(
ratios
)
approx_kl
=
(((
ratios
-
1
)
-
logratio
)
*
mask
)
.
sum
()
/
n_valids
if
args
.
norm_adv
:
advantages
=
masked_normalize
(
advantages
,
mask
,
eps
=
1e-8
)
# Policy loss
if
args
.
spo_kld_max
is
not
None
:
pg_loss
=
simple_policy_loss
(
ratios
,
logits
,
new_logits
,
advantages
,
args
.
spo_kld_max
)
else
:
else
:
should_update
=
torch
.
zeros
((),
dtype
=
torch
.
int64
,
device
=
device
)
pg_loss
=
clipped_surrogate_pg_loss
(
if
args
.
world_size
>
1
:
ratios
,
advantages
,
args
.
clip_coef
,
args
.
dual_clip_coef
)
dist
.
all_reduce
(
should_update
,
op
=
dist
.
ReduceOp
.
SUM
)
pg_loss
=
jnp
.
sum
(
pg_loss
*
mask
)
should_update
=
should_update
.
item
()
>
0
if
should_update
:
v_loss
=
mse_loss
(
new_values
,
target_values
)
agent_t
.
load_state_dict
(
agent
.
state_dict
())
v_loss
=
jnp
.
sum
(
v_loss
*
mask
)
with
torch
.
no_grad
():
traced_model_t
=
torch
.
jit
.
trace
(
agent_t
,
(
example_obs
,),
check_tolerance
=
False
,
check_trace
=
False
)
ent_loss
=
entropy_loss
(
new_logits
)
traced_model_t
=
torch
.
jit
.
optimize_for_inference
(
traced_model_t
)
ent_loss
=
jnp
.
sum
(
ent_loss
*
mask
)
version
+=
1
pg_loss
=
pg_loss
/
n_valids
if
rank
==
0
:
v_loss
=
v_loss
/
n_valids
torch
.
save
(
agent
.
state_dict
(),
os
.
path
.
join
(
ckpt_dir
,
f
"agent_v{version}.pt"
))
ent_loss
=
ent_loss
/
n_valids
print
(
f
"Updating agent at global_step={global_step} with win_rate={np.mean(avg_win_rates)}"
)
avg_win_rates
.
clear
()
loss
=
pg_loss
-
args
.
ent_coef
*
ent_loss
+
v_loss
*
args
.
vf_coef
avg_ep_returns
.
clear
()
return
loss
,
(
pg_loss
,
v_loss
,
ent_loss
,
jax
.
lax
.
stop_gradient
(
approx_kl
))
if
args
.
eval_interval
and
iteration
%
args
.
eval_interval
==
0
:
def
single_device_update
(
# Eval with rule-based policy
agent_state
:
TrainState
,
_start
=
time
.
time
()
sharded_storages
:
List
,
eval_return
=
evaluate
(
sharded_init_rstate1
:
List
,
eval_envs
,
traced_model
,
local_eval_episodes
,
device
,
args
.
fp16_eval
)[
0
]
sharded_init_rstate2
:
List
,
eval_stats
=
torch
.
tensor
(
eval_return
,
dtype
=
torch
.
float32
,
device
=
device
)
sharded_next_inputs
:
List
,
sharded_next_main
:
List
,
# sync the statistics
key
:
jax
.
random
.
PRNGKey
,
if
args
.
world_size
>
1
:
learn_opponent
:
bool
=
False
,
dist
.
all_reduce
(
eval_stats
,
op
=
dist
.
ReduceOp
.
AVG
)
):
eval_return
=
eval_stats
.
cpu
()
.
numpy
()
storage
=
jax
.
tree
.
map
(
lambda
*
x
:
jnp
.
hstack
(
x
),
*
sharded_storages
)
if
rank
==
0
:
next_inputs
,
init_rstate1
,
init_rstate2
=
[
writer
.
add_scalar
(
"charts/eval_return"
,
eval_return
,
global_step
)
jax
.
tree
.
map
(
lambda
*
x
:
jnp
.
concatenate
(
x
),
*
x
)
if
local_rank
==
0
:
for
x
in
[
sharded_next_inputs
,
sharded_init_rstate1
,
sharded_init_rstate2
]
eval_time
=
time
.
time
()
-
_start
]
fprint
(
f
"eval_time={eval_time:.4f}, eval_ep_return={eval_return:.4f}"
)
next_main
,
=
[
jnp
.
concatenate
(
x
)
for
x
in
[
sharded_next_main
]
]
# reorder storage of individual players
# main first, opponent second
num_steps
,
num_envs
=
storage
.
rewards
.
shape
T
=
jnp
.
arange
(
num_steps
,
dtype
=
jnp
.
int32
)
B
=
jnp
.
arange
(
num_envs
,
dtype
=
jnp
.
int32
)
mains
=
storage
.
mains
.
astype
(
jnp
.
int32
)
indices
=
jnp
.
argsort
(
T
[:,
None
]
-
mains
*
num_steps
,
axis
=
0
)
switch_steps
=
jnp
.
sum
(
mains
,
axis
=
0
)
switch
=
T
[:,
None
]
==
(
switch_steps
[
None
,
:]
-
1
)
storage
=
jax
.
tree
.
map
(
lambda
x
:
x
[
indices
,
B
[
None
,
:]],
storage
)
ppo_loss_grad_fn
=
jax
.
value_and_grad
(
ppo_loss
,
has_aux
=
True
)
def
update_epoch
(
carry
,
_
):
agent_state
,
key
=
carry
key
,
subkey
=
jax
.
random
.
split
(
key
)
next_value
=
create_agent
(
args
)
.
apply
(
agent_state
.
params
,
next_inputs
)[
2
]
.
squeeze
(
-
1
)
# TODO: check if this is correct
sign
=
jnp
.
where
(
switch_steps
<=
num_steps
,
1.0
,
-
1.0
)
next_value
=
jnp
.
where
(
next_main
,
-
sign
*
next_value
,
sign
*
next_value
)
def
convert_data
(
x
:
jnp
.
ndarray
,
num_steps
):
if
args
.
update_epochs
>
1
:
x
=
jax
.
random
.
permutation
(
subkey
,
x
,
axis
=
1
if
num_steps
>
1
else
0
)
N
=
args
.
num_minibatches
if
num_steps
>
1
:
x
=
jnp
.
reshape
(
x
,
(
num_steps
,
N
,
-
1
)
+
x
.
shape
[
2
:])
x
=
x
.
transpose
(
1
,
0
,
*
range
(
2
,
x
.
ndim
))
x
=
x
.
reshape
(
N
,
-
1
,
*
x
.
shape
[
3
:])
else
:
x
=
jnp
.
reshape
(
x
,
(
N
,
-
1
)
+
x
.
shape
[
1
:])
return
x
shuffled_init_rstate1
,
shuffled_init_rstate2
,
\
shuffled_next_value
=
jax
.
tree
.
map
(
partial
(
convert_data
,
num_steps
=
1
),
(
init_rstate1
,
init_rstate2
,
next_value
))
shuffled_storage
,
shuffled_switch
=
jax
.
tree
.
map
(
partial
(
convert_data
,
num_steps
=
num_steps
),
(
storage
,
switch
))
shuffled_mask
=
jnp
.
ones_like
(
shuffled_storage
.
mains
)
def
update_minibatch
(
agent_state
,
minibatch
):
(
loss
,
(
pg_loss
,
v_loss
,
entropy_loss
,
approx_kl
)),
grads
=
ppo_loss_grad_fn
(
agent_state
.
params
,
*
minibatch
)
grads
=
jax
.
lax
.
pmean
(
grads
,
axis_name
=
"local_devices"
)
agent_state
=
agent_state
.
apply_gradients
(
grads
=
grads
)
return
agent_state
,
(
loss
,
pg_loss
,
v_loss
,
entropy_loss
,
approx_kl
)
agent_state
,
(
loss
,
pg_loss
,
v_loss
,
entropy_loss
,
approx_kl
)
=
jax
.
lax
.
scan
(
update_minibatch
,
agent_state
,
(
shuffled_init_rstate1
,
shuffled_init_rstate2
,
shuffled_storage
.
obs
,
shuffled_storage
.
dones
,
shuffled_storage
.
next_dones
,
shuffled_switch
,
shuffled_storage
.
actions
,
shuffled_storage
.
logits
,
shuffled_storage
.
rewards
,
shuffled_mask
,
shuffled_next_value
,
),
)
return
(
agent_state
,
key
),
(
loss
,
pg_loss
,
v_loss
,
entropy_loss
,
approx_kl
)
# Eval with old model
(
agent_state
,
key
),
(
loss
,
pg_loss
,
v_loss
,
entropy_loss
,
approx_kl
)
=
jax
.
lax
.
scan
(
update_epoch
,
(
agent_state
,
key
),
(),
length
=
args
.
update_epochs
)
loss
=
jax
.
lax
.
pmean
(
loss
,
axis_name
=
"local_devices"
)
.
mean
()
pg_loss
=
jax
.
lax
.
pmean
(
pg_loss
,
axis_name
=
"local_devices"
)
.
mean
()
v_loss
=
jax
.
lax
.
pmean
(
v_loss
,
axis_name
=
"local_devices"
)
.
mean
()
entropy_loss
=
jax
.
lax
.
pmean
(
entropy_loss
,
axis_name
=
"local_devices"
)
.
mean
()
approx_kl
=
jax
.
lax
.
pmean
(
approx_kl
,
axis_name
=
"local_devices"
)
.
mean
()
return
agent_state
,
loss
,
pg_loss
,
v_loss
,
entropy_loss
,
approx_kl
,
key
multi_device_update
=
jax
.
pmap
(
single_device_update
,
axis_name
=
"local_devices"
,
devices
=
global_learner_decices
,
static_broadcasted_argnums
=
(
7
,),
)
if
args
.
world_size
>
1
:
params_queues
=
[]
dist
.
destroy_process_group
()
rollout_queues
=
[]
envs
.
close
()
eval_queue
=
queue
.
Queue
()
if
rank
==
0
:
dummy_writer
=
SimpleNamespace
()
torch
.
save
(
agent
.
state_dict
(),
os
.
path
.
join
(
ckpt_dir
,
f
"agent_final.pt"
))
dummy_writer
.
add_scalar
=
lambda
x
,
y
,
z
:
None
writer
.
close
()
unreplicated_params
=
flax
.
jax_utils
.
unreplicate
(
agent_state
.
params
)
for
d_idx
,
d_id
in
enumerate
(
args
.
actor_device_ids
):
device_params
=
jax
.
device_put
(
unreplicated_params
,
local_devices
[
d_id
])
for
thread_id
in
range
(
args
.
num_actor_threads
):
params_queues
.
append
(
queue
.
Queue
(
maxsize
=
1
))
rollout_queues
.
append
(
queue
.
Queue
(
maxsize
=
1
))
if
eval_params
:
params_queues
[
-
1
]
.
put
(
jax
.
device_put
(
eval_params
,
local_devices
[
d_id
]))
threading
.
Thread
(
target
=
rollout
,
args
=
(
jax
.
device_put
(
key
,
local_devices
[
d_id
]),
args
,
rollout_queues
[
-
1
],
params_queues
[
-
1
],
eval_queue
,
writer
if
d_idx
==
0
and
thread_id
==
0
else
dummy_writer
,
learner_devices
,
d_idx
*
args
.
num_actor_threads
+
thread_id
,
),
)
.
start
()
params_queues
[
-
1
]
.
put
(
device_params
)
rollout_queue_get_time
=
deque
(
maxlen
=
10
)
data_transfer_time
=
deque
(
maxlen
=
10
)
learner_policy_version
=
0
while
True
:
learner_policy_version
+=
1
rollout_queue_get_time_start
=
time
.
time
()
sharded_data_list
=
[]
for
d_idx
,
d_id
in
enumerate
(
args
.
actor_device_ids
):
for
thread_id
in
range
(
args
.
num_actor_threads
):
(
global_step
,
update
,
*
sharded_data
,
avg_params_queue_get_time
,
learn_opponent
,
)
=
rollout_queues
[
d_idx
*
args
.
num_actor_threads
+
thread_id
]
.
get
()
sharded_data_list
.
append
(
sharded_data
)
rollout_queue_get_time
.
append
(
time
.
time
()
-
rollout_queue_get_time_start
)
training_time_start
=
time
.
time
()
(
agent_state
,
loss
,
pg_loss
,
v_loss
,
entropy_loss
,
approx_kl
,
learner_keys
)
=
multi_device_update
(
agent_state
,
*
list
(
zip
(
*
sharded_data_list
)),
learner_keys
,
learn_opponent
,
)
unreplicated_params
=
flax
.
jax_utils
.
unreplicate
(
agent_state
.
params
)
for
d_idx
,
d_id
in
enumerate
(
args
.
actor_device_ids
):
device_params
=
jax
.
device_put
(
unreplicated_params
,
local_devices
[
d_id
])
device_params
[
"params"
][
"Encoder_0"
][
'Embed_0'
][
"embedding"
]
.
block_until_ready
()
for
thread_id
in
range
(
args
.
num_actor_threads
):
params_queues
[
d_idx
*
args
.
num_actor_threads
+
thread_id
]
.
put
(
device_params
)
loss
=
loss
[
-
1
]
.
item
()
if
np
.
isnan
(
loss
)
or
np
.
isinf
(
loss
):
raise
ValueError
(
f
"loss is {loss}"
)
# record rewards for plotting purposes
if
learner_policy_version
%
args
.
log_frequency
==
0
:
writer
.
add_scalar
(
"stats/rollout_queue_get_time"
,
np
.
mean
(
rollout_queue_get_time
),
global_step
)
writer
.
add_scalar
(
"stats/rollout_params_queue_get_time_diff"
,
np
.
mean
(
rollout_queue_get_time
)
-
avg_params_queue_get_time
,
global_step
,
)
writer
.
add_scalar
(
"stats/training_time"
,
time
.
time
()
-
training_time_start
,
global_step
)
writer
.
add_scalar
(
"stats/rollout_queue_size"
,
rollout_queues
[
-
1
]
.
qsize
(),
global_step
)
writer
.
add_scalar
(
"stats/params_queue_size"
,
params_queues
[
-
1
]
.
qsize
(),
global_step
)
print
(
f
"{global_step} actor_update={update}, "
f
"train_time={time.time() - training_time_start:.2f}, "
f
"data_time={rollout_queue_get_time[-1]:.2f}"
)
writer
.
add_scalar
(
"charts/learning_rate"
,
agent_state
.
opt_state
[
2
][
1
]
.
hyperparams
[
"learning_rate"
][
-
1
]
.
item
(),
global_step
)
writer
.
add_scalar
(
"losses/value_loss"
,
v_loss
[
-
1
]
.
item
(),
global_step
)
writer
.
add_scalar
(
"losses/policy_loss"
,
pg_loss
[
-
1
]
.
item
(),
global_step
)
writer
.
add_scalar
(
"losses/entropy"
,
entropy_loss
[
-
1
]
.
item
(),
global_step
)
writer
.
add_scalar
(
"losses/approx_kl"
,
approx_kl
[
-
1
]
.
item
(),
global_step
)
writer
.
add_scalar
(
"losses/loss"
,
loss
,
global_step
)
if
args
.
local_rank
==
0
and
learner_policy_version
%
args
.
save_interval
==
0
:
ckpt_dir
=
f
"checkpoints"
os
.
makedirs
(
ckpt_dir
,
exist_ok
=
True
)
M_steps
=
args
.
batch_size
*
learner_policy_version
//
(
2
**
20
)
model_path
=
os
.
path
.
join
(
ckpt_dir
,
f
"{timestamp}_{M_steps}M.flax_model"
)
with
open
(
model_path
,
"wb"
)
as
f
:
f
.
write
(
flax
.
serialization
.
to_bytes
(
unreplicated_params
)
)
print
(
f
"model saved to {model_path}"
)
if
learner_policy_version
>=
args
.
num_updates
:
break
if
__name__
==
"__main__"
:
if
args
.
distributed
:
main
()
jax
.
distributed
.
shutdown
()
writer
.
close
()
\ No newline at end of file
scripts/torch/ppo.py
0 → 100644
View file @
bbd36d86
import
os
import
random
import
time
from
collections
import
deque
from
dataclasses
import
dataclass
from
typing
import
Optional
import
ygoenv
import
numpy
as
np
import
tyro
import
torch
import
torch.nn
as
nn
import
torch.optim
as
optim
from
torch.distributions
import
Categorical
import
torch.distributed
as
dist
from
torch.cuda.amp
import
GradScaler
,
autocast
from
ygoai.utils
import
init_ygopro
from
ygoai.rl.utils
import
RecordEpisodeStatistics
,
to_tensor
,
load_embeddings
from
ygoai.rl.agent
import
PPOAgent
as
Agent
from
ygoai.rl.dist
import
reduce_gradidents
,
torchrun_setup
,
fprint
from
ygoai.rl.buffer
import
create_obs
from
ygoai.rl.ppo
import
bootstrap_value_selfplay
,
train_step
as
train_step_
from
ygoai.rl.eval
import
evaluate
@
dataclass
class
Args
:
exp_name
:
str
=
os
.
path
.
basename
(
__file__
)[:
-
len
(
".py"
)]
"""the name of this experiment"""
seed
:
int
=
1
"""seed of the experiment"""
torch_deterministic
:
bool
=
False
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda
:
bool
=
True
"""if toggled, cuda will be enabled by default"""
# Algorithm specific arguments
env_id
:
str
=
"YGOPro-v0"
"""the id of the environment"""
deck
:
str
=
"../assets/deck"
"""the deck file to use"""
deck1
:
Optional
[
str
]
=
None
"""the deck file for the first player"""
deck2
:
Optional
[
str
]
=
None
"""the deck file for the second player"""
code_list_file
:
str
=
"code_list.txt"
"""the code list file for card embeddings"""
embedding_file
:
Optional
[
str
]
=
None
"""the embedding file for card embeddings"""
max_options
:
int
=
24
"""the maximum number of options"""
n_history_actions
:
int
=
32
"""the number of history actions to use"""
num_layers
:
int
=
2
"""the number of layers for the agent"""
num_channels
:
int
=
128
"""the number of channels for the agent"""
checkpoint
:
Optional
[
str
]
=
None
"""the checkpoint to load the model from"""
total_timesteps
:
int
=
2000000000
"""total timesteps of the experiments"""
learning_rate
:
float
=
2.5e-4
"""the learning rate of the optimizer"""
num_envs
:
int
=
8
"""the number of parallel game environments"""
num_steps
:
int
=
128
"""the number of steps to run in each environment per policy rollout"""
anneal_lr
:
bool
=
True
"""Toggle learning rate annealing for policy and value networks"""
gamma
:
float
=
1.0
"""the discount factor gamma"""
gae_lambda
:
float
=
0.95
"""the lambda for the general advantage estimation"""
fix_target
:
bool
=
False
"""if toggled, the target network will be fixed"""
update_win_rate
:
float
=
0.55
"""the required win rate to update the agent"""
update_return
:
float
=
0.1
"""the required return to update the agent"""
minibatch_size
:
int
=
256
"""the mini-batch size"""
update_epochs
:
int
=
2
"""the K epochs to update the policy"""
norm_adv
:
bool
=
True
"""Toggles advantages normalization"""
clip_coef
:
float
=
0.2
"""the surrogate clipping coefficient"""
clip_vloss
:
bool
=
True
"""Toggles whether or not to use a clipped loss for the value function, as per the paper."""
ent_coef
:
float
=
0.01
"""coefficient of the entropy"""
vf_coef
:
float
=
0.5
"""coefficient of the value function"""
max_grad_norm
:
float
=
1.0
"""the maximum norm for the gradient clipping"""
collect_length
:
Optional
[
int
]
=
None
"""the length of the buffer, only the first `num_steps` will be used for training (partial GAE)"""
compile
:
Optional
[
str
]
=
None
"""Compile mode of torch.compile, None for no compilation"""
torch_threads
:
Optional
[
int
]
=
None
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
env_threads
:
Optional
[
int
]
=
None
"""the number of threads to use for envpool, defaults to `num_envs`"""
fp16_train
:
bool
=
False
"""if toggled, training will be done in fp16 precision"""
fp16_eval
:
bool
=
False
"""if toggled, evaluation will be done in fp16 precision"""
tb_dir
:
str
=
"./runs"
"""tensorboard log directory"""
ckpt_dir
:
str
=
"./checkpoints"
"""checkpoint directory"""
save_interval
:
int
=
500
"""the number of iterations to save the model"""
log_p
:
float
=
1.0
"""the probability of logging"""
eval_episodes
:
int
=
128
"""the number of episodes to evaluate the model"""
eval_interval
:
int
=
50
"""the number of iterations to evaluate the model"""
# to be filled in runtime
local_batch_size
:
int
=
0
"""the local batch size in the local rank (computed in runtime)"""
local_minibatch_size
:
int
=
0
"""the local mini-batch size in the local rank (computed in runtime)"""
local_num_envs
:
int
=
0
"""the number of parallel game environments (in the local rank, computed in runtime)"""
batch_size
:
int
=
0
"""the batch size (computed in runtime)"""
num_iterations
:
int
=
0
"""the number of iterations (computed in runtime)"""
world_size
:
int
=
0
"""the number of processes (computed in runtime)"""
num_embeddings
:
Optional
[
int
]
=
None
"""the number of embeddings (computed in runtime)"""
def
make_env
(
args
,
num_envs
,
num_threads
,
mode
=
'self'
):
envs
=
ygoenv
.
make
(
task_id
=
args
.
env_id
,
env_type
=
"gymnasium"
,
num_envs
=
num_envs
,
num_threads
=
num_threads
,
seed
=
args
.
seed
,
deck1
=
args
.
deck1
,
deck2
=
args
.
deck2
,
max_options
=
args
.
max_options
,
n_history_actions
=
args
.
n_history_actions
,
play_mode
=
mode
,
)
envs
.
num_envs
=
num_envs
envs
=
RecordEpisodeStatistics
(
envs
)
return
envs
def
main
():
rank
=
int
(
os
.
environ
.
get
(
"RANK"
,
0
))
local_rank
=
int
(
os
.
environ
.
get
(
"LOCAL_RANK"
,
0
))
world_size
=
int
(
os
.
environ
.
get
(
"WORLD_SIZE"
,
1
))
print
(
f
"rank={rank}, local_rank={local_rank}, world_size={world_size}"
)
args
=
tyro
.
cli
(
Args
)
args
.
world_size
=
world_size
args
.
local_num_envs
=
args
.
num_envs
//
args
.
world_size
args
.
local_batch_size
=
int
(
args
.
local_num_envs
*
args
.
num_steps
)
args
.
local_minibatch_size
=
int
(
args
.
minibatch_size
//
args
.
world_size
)
args
.
batch_size
=
int
(
args
.
num_envs
*
args
.
num_steps
)
args
.
num_iterations
=
args
.
total_timesteps
//
args
.
batch_size
args
.
num_minibatches
=
args
.
local_batch_size
//
args
.
local_minibatch_size
args
.
env_threads
=
args
.
env_threads
or
args
.
num_envs
args
.
torch_threads
=
args
.
torch_threads
or
(
int
(
os
.
getenv
(
"OMP_NUM_THREADS"
,
"2"
))
*
args
.
world_size
)
args
.
collect_length
=
args
.
collect_length
or
args
.
num_steps
assert
args
.
collect_length
>=
args
.
num_steps
,
"collect_length must be greater than or equal to num_steps"
local_torch_threads
=
args
.
torch_threads
//
args
.
world_size
local_env_threads
=
args
.
env_threads
//
args
.
world_size
torch
.
set_num_threads
(
local_torch_threads
)
torch
.
set_float32_matmul_precision
(
'high'
)
if
args
.
world_size
>
1
:
torchrun_setup
(
'nccl'
,
local_rank
)
timestamp
=
int
(
time
.
time
())
run_name
=
f
"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
writer
=
None
if
rank
==
0
:
from
torch.utils.tensorboard
import
SummaryWriter
writer
=
SummaryWriter
(
os
.
path
.
join
(
args
.
tb_dir
,
run_name
))
writer
.
add_text
(
"hyperparameters"
,
"|param|value|
\n
|-|-|
\n
%
s"
%
(
"
\n
"
.
join
([
f
"|{key}|{value}|"
for
key
,
value
in
vars
(
args
)
.
items
()])),
)
ckpt_dir
=
os
.
path
.
join
(
args
.
ckpt_dir
,
run_name
)
os
.
makedirs
(
ckpt_dir
,
exist_ok
=
True
)
# TRY NOT TO MODIFY: seeding
# CRUCIAL: note that we needed to pass a different seed for each data parallelism worker
args
.
seed
+=
rank
random
.
seed
(
args
.
seed
)
np
.
random
.
seed
(
args
.
seed
)
torch
.
manual_seed
(
args
.
seed
-
rank
)
if
args
.
torch_deterministic
:
torch
.
backends
.
cudnn
.
deterministic
=
True
else
:
torch
.
backends
.
cudnn
.
benchmark
=
True
device
=
torch
.
device
(
f
"cuda:{local_rank}"
if
torch
.
cuda
.
is_available
()
and
args
.
cuda
else
"cpu"
)
deck
=
init_ygopro
(
args
.
env_id
,
"english"
,
args
.
deck
,
args
.
code_list_file
)
args
.
deck1
=
args
.
deck1
or
deck
args
.
deck2
=
args
.
deck2
or
deck
# env setup
envs
=
make_env
(
args
,
args
.
local_num_envs
,
local_env_threads
)
obs_space
=
envs
.
env
.
observation_space
action_shape
=
envs
.
env
.
action_space
.
shape
if
local_rank
==
0
:
fprint
(
f
"obs_space={obs_space}, action_shape={action_shape}"
)
envs_per_thread
=
args
.
local_num_envs
//
local_env_threads
local_eval_episodes
=
args
.
eval_episodes
//
args
.
world_size
local_eval_num_envs
=
local_eval_episodes
local_eval_num_threads
=
max
(
1
,
local_eval_num_envs
//
envs_per_thread
)
eval_envs
=
make_env
(
args
,
local_eval_num_envs
,
local_eval_num_threads
,
mode
=
'bot'
)
if
args
.
embedding_file
:
embeddings
=
load_embeddings
(
args
.
embedding_file
,
args
.
code_list_file
)
embedding_shape
=
embeddings
.
shape
else
:
embedding_shape
=
None
L
=
args
.
num_layers
agent
=
Agent
(
args
.
num_channels
,
L
,
L
,
embedding_shape
)
.
to
(
device
)
torch
.
manual_seed
(
args
.
seed
)
agent
.
eval
()
if
args
.
checkpoint
:
agent
.
load_state_dict
(
torch
.
load
(
args
.
checkpoint
,
map_location
=
device
))
fprint
(
f
"Loaded checkpoint from {args.checkpoint}"
)
elif
args
.
embedding_file
:
agent
.
load_embeddings
(
embeddings
)
fprint
(
f
"Loaded embeddings from {args.embedding_file}"
)
if
args
.
embedding_file
:
agent
.
freeze_embeddings
()
if
args
.
fix_target
:
agent_t
=
Agent
(
args
.
num_channels
,
L
,
L
,
embedding_shape
)
.
to
(
device
)
agent_t
.
eval
()
agent_t
.
load_state_dict
(
agent
.
state_dict
())
else
:
agent_t
=
agent
optim_params
=
list
(
agent
.
parameters
())
optimizer
=
optim
.
Adam
(
optim_params
,
lr
=
args
.
learning_rate
,
eps
=
1e-5
)
scaler
=
GradScaler
(
enabled
=
args
.
fp16_train
,
init_scale
=
2
**
8
)
def
predict_step
(
agent
:
Agent
,
next_obs
):
with
torch
.
no_grad
():
with
autocast
(
enabled
=
args
.
fp16_eval
):
logits
,
value
,
valid
=
agent
(
next_obs
)
return
logits
,
value
if
args
.
compile
:
# It seems that using torch.compile twice cause segfault at start, so we use torch.jit.trace here
# predict_step = torch.compile(predict_step, mode=args.compile)
example_obs
=
create_obs
(
envs
.
observation_space
,
(
args
.
local_num_envs
,),
device
=
device
)
with
torch
.
no_grad
():
traced_model
=
torch
.
jit
.
trace
(
agent
,
(
example_obs
,),
check_tolerance
=
False
,
check_trace
=
False
)
if
args
.
fix_target
:
traced_model_t
=
torch
.
jit
.
trace
(
agent_t
,
(
example_obs
,),
check_tolerance
=
False
,
check_trace
=
False
)
traced_model_t
=
torch
.
jit
.
optimize_for_inference
(
traced_model_t
)
else
:
traced_model_t
=
traced_model
train_step
=
torch
.
compile
(
train_step_
,
mode
=
args
.
compile
)
else
:
traced_model
=
agent
traced_model_t
=
agent_t
train_step
=
train_step_
# ALGO Logic: Storage setup
obs
=
create_obs
(
obs_space
,
(
args
.
collect_length
,
args
.
local_num_envs
),
device
)
actions
=
torch
.
zeros
((
args
.
collect_length
,
args
.
local_num_envs
)
+
action_shape
)
.
to
(
device
)
logprobs
=
torch
.
zeros
((
args
.
collect_length
,
args
.
local_num_envs
))
.
to
(
device
)
rewards
=
torch
.
zeros
((
args
.
collect_length
,
args
.
local_num_envs
))
.
to
(
device
)
dones
=
torch
.
zeros
((
args
.
collect_length
,
args
.
local_num_envs
),
dtype
=
torch
.
bool
)
.
to
(
device
)
values
=
torch
.
zeros
((
args
.
collect_length
,
args
.
local_num_envs
))
.
to
(
device
)
learns
=
torch
.
zeros
((
args
.
collect_length
,
args
.
local_num_envs
),
dtype
=
torch
.
bool
)
.
to
(
device
)
avg_ep_returns
=
deque
(
maxlen
=
1000
)
avg_win_rates
=
deque
(
maxlen
=
1000
)
version
=
0
# TRY NOT TO MODIFY: start the game
global_step
=
0
warmup_steps
=
0
start_time
=
time
.
time
()
next_obs
,
info
=
envs
.
reset
()
next_obs
=
to_tensor
(
next_obs
,
device
,
dtype
=
torch
.
uint8
)
next_to_play_
=
info
[
"to_play"
]
next_to_play
=
to_tensor
(
next_to_play_
,
device
)
next_done
=
torch
.
zeros
(
args
.
local_num_envs
,
device
=
device
,
dtype
=
torch
.
bool
)
main_player_
=
np
.
concatenate
([
np
.
zeros
(
args
.
local_num_envs
//
2
,
dtype
=
np
.
int64
),
np
.
ones
(
args
.
local_num_envs
//
2
,
dtype
=
np
.
int64
)
])
np
.
random
.
shuffle
(
main_player_
)
main_player
=
to_tensor
(
main_player_
,
device
,
dtype
=
next_to_play
.
dtype
)
step
=
0
for
iteration
in
range
(
args
.
num_iterations
):
# Annealing the rate if instructed to do so.
if
args
.
anneal_lr
:
frac
=
1.0
-
iteration
/
args
.
num_iterations
lrnow
=
frac
*
args
.
learning_rate
optimizer
.
param_groups
[
0
][
"lr"
]
=
lrnow
model_time
=
0
env_time
=
0
collect_start
=
time
.
time
()
while
step
<
args
.
collect_length
:
global_step
+=
args
.
num_envs
for
key
in
obs
:
obs
[
key
][
step
]
=
next_obs
[
key
]
dones
[
step
]
=
next_done
learn
=
next_to_play
==
main_player
learns
[
step
]
=
learn
_start
=
time
.
time
()
logits
,
value
=
predict_step
(
traced_model
,
next_obs
)
if
args
.
fix_target
:
logits_t
,
value_t
=
predict_step
(
traced_model_t
,
next_obs
)
logits
=
torch
.
where
(
learn
[:,
None
],
logits
,
logits_t
)
value
=
torch
.
where
(
learn
[:,
None
],
value
,
value_t
)
value
=
value
.
flatten
()
probs
=
Categorical
(
logits
=
logits
)
action
=
probs
.
sample
()
logprob
=
probs
.
log_prob
(
action
)
values
[
step
]
=
value
actions
[
step
]
=
action
logprobs
[
step
]
=
logprob
action
=
action
.
cpu
()
.
numpy
()
model_time
+=
time
.
time
()
-
_start
_start
=
time
.
time
()
to_play
=
next_to_play_
next_obs
,
reward
,
next_done_
,
info
=
envs
.
step
(
action
)
next_to_play_
=
info
[
"to_play"
]
next_to_play
=
to_tensor
(
next_to_play_
,
device
)
env_time
+=
time
.
time
()
-
_start
rewards
[
step
]
=
to_tensor
(
reward
,
device
)
next_obs
,
next_done
=
to_tensor
(
next_obs
,
device
,
torch
.
uint8
),
to_tensor
(
next_done_
,
device
,
torch
.
bool
)
step
+=
1
if
not
writer
:
continue
for
idx
,
d
in
enumerate
(
next_done_
):
if
d
:
pl
=
1
if
to_play
[
idx
]
==
main_player_
[
idx
]
else
-
1
episode_length
=
info
[
'l'
][
idx
]
episode_reward
=
info
[
'r'
][
idx
]
*
pl
win
=
1
if
episode_reward
>
0
else
0
avg_ep_returns
.
append
(
episode_reward
)
avg_win_rates
.
append
(
win
)
if
random
.
random
()
<
args
.
log_p
:
n
=
100
if
random
.
random
()
<
10
/
n
or
iteration
<=
1
:
writer
.
add_scalar
(
"charts/episodic_return"
,
info
[
"r"
][
idx
],
global_step
)
writer
.
add_scalar
(
"charts/episodic_length"
,
info
[
"l"
][
idx
],
global_step
)
fprint
(
f
"global_step={global_step}, e_ret={episode_reward}, e_len={episode_length}"
)
if
random
.
random
()
<
1
/
n
:
writer
.
add_scalar
(
"charts/avg_ep_return"
,
np
.
mean
(
avg_ep_returns
),
global_step
)
writer
.
add_scalar
(
"charts/avg_win_rate"
,
np
.
mean
(
avg_win_rates
),
global_step
)
collect_time
=
time
.
time
()
-
collect_start
if
local_rank
==
0
:
fprint
(
f
"collect_time={collect_time:.4f}, model_time={model_time:.4f}, env_time={env_time:.4f}"
)
step
=
args
.
collect_length
-
args
.
num_steps
_start
=
time
.
time
()
# bootstrap value if not done
value
=
predict_step
(
traced_model
,
next_obs
)[
1
]
.
reshape
(
-
1
)
nextvalues1
=
torch
.
where
(
next_to_play
==
main_player
,
value
,
-
value
)
if
args
.
fix_target
:
value_t
=
predict_step
(
traced_model_t
,
next_obs
)[
1
]
.
reshape
(
-
1
)
nextvalues2
=
torch
.
where
(
next_to_play
!=
main_player
,
value_t
,
-
value_t
)
else
:
nextvalues2
=
-
nextvalues1
if
step
>
0
and
iteration
!=
0
:
# recalculate the values for the first few steps
v_steps
=
args
.
local_minibatch_size
*
4
//
args
.
local_num_envs
for
v_start
in
range
(
0
,
step
,
v_steps
):
v_end
=
min
(
v_start
+
v_steps
,
step
)
v_obs
=
{
k
:
v
[
v_start
:
v_end
]
.
flatten
(
0
,
1
)
for
k
,
v
in
obs
.
items
()
}
with
torch
.
no_grad
():
# value = traced_get_value(v_obs).reshape(v_end - v_start, -1)
value
=
predict_step
(
traced_model
,
v_obs
)[
1
]
.
reshape
(
v_end
-
v_start
,
-
1
)
values
[
v_start
:
v_end
]
=
value
advantages
=
bootstrap_value_selfplay
(
values
,
rewards
,
dones
,
learns
,
nextvalues1
,
nextvalues2
,
next_done
,
args
.
gamma
,
args
.
gae_lambda
)
bootstrap_time
=
time
.
time
()
-
_start
_start
=
time
.
time
()
# flatten the batch
b_obs
=
{
k
:
v
[:
args
.
num_steps
]
.
reshape
((
-
1
,)
+
v
.
shape
[
2
:])
for
k
,
v
in
obs
.
items
()
}
b_actions
=
actions
[:
args
.
num_steps
]
.
reshape
((
-
1
,)
+
action_shape
)
b_logprobs
=
logprobs
[:
args
.
num_steps
]
.
reshape
(
-
1
)
b_advantages
=
advantages
[:
args
.
num_steps
]
.
reshape
(
-
1
)
b_values
=
values
[:
args
.
num_steps
]
.
reshape
(
-
1
)
b_returns
=
b_advantages
+
b_values
if
args
.
fix_target
:
b_learns
=
learns
[:
args
.
num_steps
]
.
reshape
(
-
1
)
else
:
b_learns
=
torch
.
ones_like
(
b_values
,
dtype
=
torch
.
bool
)
# Optimizing the policy and value network
b_inds
=
np
.
arange
(
args
.
local_batch_size
)
clipfracs
=
[]
for
epoch
in
range
(
args
.
update_epochs
):
np
.
random
.
shuffle
(
b_inds
)
for
start
in
range
(
0
,
args
.
local_batch_size
,
args
.
local_minibatch_size
):
end
=
start
+
args
.
local_minibatch_size
mb_inds
=
b_inds
[
start
:
end
]
mb_obs
=
{
k
:
v
[
mb_inds
]
for
k
,
v
in
b_obs
.
items
()
}
old_approx_kl
,
approx_kl
,
clipfrac
,
pg_loss
,
v_loss
,
entropy_loss
=
\
train_step
(
agent
,
optimizer
,
scaler
,
mb_obs
,
b_actions
[
mb_inds
],
b_logprobs
[
mb_inds
],
b_advantages
[
mb_inds
],
b_returns
[
mb_inds
],
b_values
[
mb_inds
],
b_learns
[
mb_inds
],
args
)
reduce_gradidents
(
optim_params
,
args
.
world_size
)
nn
.
utils
.
clip_grad_norm_
(
optim_params
,
args
.
max_grad_norm
)
scaler
.
step
(
optimizer
)
scaler
.
update
()
clipfracs
.
append
(
clipfrac
.
item
())
if
step
>
0
:
# TODO: use cyclic buffer to avoid copying
for
v
in
obs
.
values
():
v
[:
step
]
=
v
[
args
.
num_steps
:]
.
clone
()
for
v
in
[
actions
,
logprobs
,
rewards
,
dones
,
values
,
learns
]:
v
[:
step
]
=
v
[
args
.
num_steps
:]
.
clone
()
train_time
=
time
.
time
()
-
_start
if
local_rank
==
0
:
fprint
(
f
"train_time={train_time:.4f}, collect_time={collect_time:.4f}, bootstrap_time={bootstrap_time:.4f}"
)
y_pred
,
y_true
=
b_values
.
cpu
()
.
numpy
(),
b_returns
.
cpu
()
.
numpy
()
var_y
=
np
.
var
(
y_true
)
explained_var
=
np
.
nan
if
var_y
==
0
else
1
-
np
.
var
(
y_true
-
y_pred
)
/
var_y
if
rank
==
0
:
if
iteration
%
args
.
save_interval
==
0
:
torch
.
save
(
agent
.
state_dict
(),
os
.
path
.
join
(
ckpt_dir
,
f
"agent.pt"
))
writer
.
add_scalar
(
"charts/learning_rate"
,
optimizer
.
param_groups
[
0
][
"lr"
],
global_step
)
writer
.
add_scalar
(
"losses/value_loss"
,
v_loss
.
item
(),
global_step
)
writer
.
add_scalar
(
"losses/policy_loss"
,
pg_loss
.
item
(),
global_step
)
writer
.
add_scalar
(
"losses/entropy"
,
entropy_loss
.
item
(),
global_step
)
writer
.
add_scalar
(
"losses/old_approx_kl"
,
old_approx_kl
.
item
(),
global_step
)
writer
.
add_scalar
(
"losses/approx_kl"
,
approx_kl
.
item
(),
global_step
)
writer
.
add_scalar
(
"losses/clipfrac"
,
np
.
mean
(
clipfracs
),
global_step
)
writer
.
add_scalar
(
"losses/explained_variance"
,
explained_var
,
global_step
)
SPS
=
int
((
global_step
-
warmup_steps
)
/
(
time
.
time
()
-
start_time
))
# Warmup at first few iterations for accurate SPS measurement
SPS_warmup_iters
=
10
if
iteration
==
SPS_warmup_iters
:
start_time
=
time
.
time
()
warmup_steps
=
global_step
if
iteration
>
SPS_warmup_iters
:
if
local_rank
==
0
:
fprint
(
f
"SPS: {SPS}"
)
if
rank
==
0
:
writer
.
add_scalar
(
"charts/SPS"
,
SPS
,
global_step
)
if
args
.
fix_target
:
if
rank
==
0
:
should_update
=
len
(
avg_win_rates
)
==
1000
and
np
.
mean
(
avg_win_rates
)
>
args
.
update_win_rate
and
np
.
mean
(
avg_ep_returns
)
>
args
.
update_return
should_update
=
torch
.
tensor
(
int
(
should_update
),
dtype
=
torch
.
int64
,
device
=
device
)
else
:
should_update
=
torch
.
zeros
((),
dtype
=
torch
.
int64
,
device
=
device
)
if
args
.
world_size
>
1
:
dist
.
all_reduce
(
should_update
,
op
=
dist
.
ReduceOp
.
SUM
)
should_update
=
should_update
.
item
()
>
0
if
should_update
:
agent_t
.
load_state_dict
(
agent
.
state_dict
())
with
torch
.
no_grad
():
traced_model_t
=
torch
.
jit
.
trace
(
agent_t
,
(
example_obs
,),
check_tolerance
=
False
,
check_trace
=
False
)
traced_model_t
=
torch
.
jit
.
optimize_for_inference
(
traced_model_t
)
version
+=
1
if
rank
==
0
:
torch
.
save
(
agent
.
state_dict
(),
os
.
path
.
join
(
ckpt_dir
,
f
"agent_v{version}.pt"
))
print
(
f
"Updating agent at global_step={global_step} with win_rate={np.mean(avg_win_rates)}"
)
avg_win_rates
.
clear
()
avg_ep_returns
.
clear
()
if
args
.
eval_interval
and
iteration
%
args
.
eval_interval
==
0
:
# Eval with rule-based policy
_start
=
time
.
time
()
eval_return
=
evaluate
(
eval_envs
,
traced_model
,
local_eval_episodes
,
device
,
args
.
fp16_eval
)[
0
]
eval_stats
=
torch
.
tensor
(
eval_return
,
dtype
=
torch
.
float32
,
device
=
device
)
# sync the statistics
if
args
.
world_size
>
1
:
dist
.
all_reduce
(
eval_stats
,
op
=
dist
.
ReduceOp
.
AVG
)
eval_return
=
eval_stats
.
cpu
()
.
numpy
()
if
rank
==
0
:
writer
.
add_scalar
(
"charts/eval_return"
,
eval_return
,
global_step
)
if
local_rank
==
0
:
eval_time
=
time
.
time
()
-
_start
fprint
(
f
"eval_time={eval_time:.4f}, eval_ep_return={eval_return:.4f}"
)
# Eval with old model
if
args
.
world_size
>
1
:
dist
.
destroy_process_group
()
envs
.
close
()
if
rank
==
0
:
torch
.
save
(
agent
.
state_dict
(),
os
.
path
.
join
(
ckpt_dir
,
f
"agent_final.pt"
))
writer
.
close
()
if
__name__
==
"__main__"
:
main
()
scripts/ppo_c.py
→
scripts/
torch/
ppo_c.py
View file @
bbd36d86
File moved
scripts/ppo_osfp.py
→
scripts/
torch/
ppo_osfp.py
View file @
bbd36d86
File moved
scripts/ppo_xla.py
→
scripts/
torch/
ppo_xla.py
View file @
bbd36d86
File moved
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment