ygo-agent · Biluo Shen
Commit 9670ed68, authored Apr 10, 2024 by sbl1996@126.com
Refactor to main_player

Parent: 6b23ca2d
Showing 7 changed files with 383 additions and 139 deletions (+383 -139)
Files changed:
  scripts/battle.py        +37   -29
  scripts/jax/battle.py    +270  -0
  scripts/jax/ppo_lstm.py  +46   -24
  scripts/ppo.py           +16   -16
  scripts/ppo_osfp.py      +14   -11
  ygoai/rl/env.py          +0    -12
  ygoai/rl/ppo.py          +0    -47
scripts/battle.py  (view file @ 9670ed68)

@@ -180,14 +180,19 @@ if __name__ == "__main__":
         agent1 = optimize_for_inference(agent1)
         agent2 = optimize_for_inference(agent2)

-        def predict_fn(agent, obs):
+        def predict_fn(obs, main):
             obs = optree.tree_map(lambda x: torch.from_numpy(x).to(device=device), obs)
-            probs = get_probs(agent, obs)
+            if num_envs != 1:
+                probs1 = get_probs(agent1, obs)
+                probs2 = get_probs(agent2, obs)
+                probs = torch.where(main[:, None], probs1, probs2)
+            else:
+                if main[0]:
+                    probs = get_probs(agent1, obs)
+                else:
+                    probs = get_probs(agent2, obs)
             probs = probs.cpu().numpy()
             return probs
-
-        predict_fn1 = lambda obs: predict_fn(agent1, obs)
-        predict_fn2 = lambda obs: predict_fn(agent2, obs)
     else:
         import jax
         import jax.numpy as jnp

@@ -214,21 +219,30 @@ if __name__ == "__main__":
                 params2 = flax.serialization.from_bytes(params, f.read())

         @jax.jit
-        def get_probs(
-            params: flax.core.FrozenDict,
-            next_obs,
-        ):
-            logits = create_agent(args).apply(params, next_obs)[0]
+        def get_probs(params, obs):
+            agent = create_agent(args)
+            logits = agent.apply(params, obs)[0]
             return jax.nn.softmax(logits, axis=-1)

-        def predict_fn(params, obs):
-            probs = get_probs(params, obs)
-            return np.array(probs)
-
-        predict_fn1 = lambda obs: predict_fn(params1, obs)
-        predict_fn2 = lambda obs: predict_fn(params2, obs)
+        if args.num_envs != 1:
+            @jax.jit
+            def get_probs2(params1, params2, obs, main):
+                probs1 = get_probs(params1, obs)
+                probs2 = get_probs(params2, obs)
+                probs = jnp.where(main[:, None], probs1, probs2)
+                return probs
+
+            def predict_fn(obs, main):
+                probs = get_probs2(params1, params2, obs, main)
+                return np.array(probs)
+        else:
+            def predict_fn(obs, main):
+                if main[0]:
+                    probs = get_probs(params1, obs)
+                else:
+                    probs = get_probs(params2, obs)
+                return np.array(probs)

     obs, infos = envs.reset()
     next_to_play = infos['to_play']

@@ -241,7 +255,7 @@ if __name__ == "__main__":
     start = time.time()
     start_step = step

-    player1 = np.concatenate([
+    main_player = np.concatenate([
        np.zeros(num_envs // 2, dtype=np.int64),
        np.ones(num_envs - num_envs // 2, dtype=np.int64)
    ])

@@ -254,15 +268,9 @@ if __name__ == "__main__":
            model_time = env_time = 0

        _start = time.time()
-       if args.num_envs != 1:
-           probs1 = predict_fn1(obs)
-           probs2 = predict_fn2(obs)
-           probs = np.where((next_to_play == player1)[:, None], probs1, probs2)
-       else:
-           if (next_to_play == player1).all():
-               probs = predict_fn1(obs)
-           else:
-               probs = predict_fn2(obs)
+       main = next_to_play == main_player
+       probs = predict_fn(obs, main)
        actions = probs.argmax(axis=1)
        model_time += time.time() - _start

@@ -279,7 +287,7 @@ if __name__ == "__main__":
        for idx, d in enumerate(dones):
            if d:
                win_reason = infos['win_reason'][idx]
-               pl = 1 if to_play[idx] == player1[idx] else -1
+               pl = 1 if to_play[idx] == main_player[idx] else -1
                episode_length = infos['l'][idx]
                episode_reward = infos['r'][idx] * pl
                win = int(episode_reward > 0)

@@ -292,7 +300,7 @@ if __name__ == "__main__":
                # Only when num_envs=1, we switch the player here
                if args.verbose:
-                   player1 = 1 - player1
+                   main_player = 1 - main_player

        if len(episode_lengths) >= args.num_episodes:
            break
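The point of the battle.py changes above is that the two evaluated agents are no longer driven through separate `predict_fn1`/`predict_fn2` wrappers keyed on `player1`; each environment instead gets a fixed `main_player` assignment, and a per-step boolean mask `main = next_to_play == main_player` selects whose policy is used. A minimal NumPy sketch of that selection logic (the helper name `select_probs` and the dummy tensors are illustrative, not from the repository):

import numpy as np

num_envs = 4
# Half of the environments treat player 0 as the "main" agent, the other half player 1,
# mirroring the np.concatenate([zeros, ones]) assignment in battle.py.
main_player = np.concatenate([
    np.zeros(num_envs // 2, dtype=np.int64),
    np.ones(num_envs - num_envs // 2, dtype=np.int64),
])

def select_probs(probs1, probs2, next_to_play):
    # main[i] is True when it is the main agent's turn in environment i.
    main = next_to_play == main_player
    # Use agent 1's policy where the main agent is to move, agent 2's otherwise.
    return np.where(main[:, None], probs1, probs2)

# Dummy policies over 3 actions for each environment.
rng = np.random.default_rng(0)
probs1 = rng.random((num_envs, 3))
probs2 = rng.random((num_envs, 3))
next_to_play = np.array([0, 1, 0, 1])
print(select_probs(probs1, probs2, next_to_play).shape)  # (4, 3)

The single-environment branch avoids evaluating both agents by checking `main[0]` directly, which is why `num_envs == 1` keeps its own code path in the script.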
scripts/jax/battle.py  (new file, mode 100644, view file @ 9670ed68)

import sys
import time
import os
import random
from typing import Optional, Literal
from dataclasses import dataclass

import ygoenv
import numpy as np
import tyro

import jax
import jax.numpy as jnp
import flax

from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics
from ygoai.rl.jax.agent2 import PPOLSTMAgent


@dataclass
class Args:
    seed: int = 1
    """the random seed"""
    env_id: str = "YGOPro-v0"
    """the id of the environment"""
    deck: str = "../assets/deck"
    """the deck file to use"""
    deck1: Optional[str] = None
    """the deck file for the first player"""
    deck2: Optional[str] = None
    """the deck file for the second player"""
    code_list_file: str = "code_list.txt"
    """the code list file for card embeddings"""
    lang: str = "english"
    """the language to use"""
    max_options: int = 24
    """the maximum number of options"""
    n_history_actions: int = 32
    """the number of history actions to use"""
    num_embeddings: Optional[int] = None
    """the number of embeddings of the agent"""

    record: bool = False
    """whether to record the game as YGOPro replays"""

    num_episodes: int = 1024
    """the number of episodes to run"""
    num_envs: int = 64
    """the number of parallel game environments"""
    verbose: bool = False
    """whether to print debug information"""

    num_layers: int = 2
    """the number of layers for the agent"""
    num_channels: int = 128
    """the number of channels for the agent"""
    rnn_channels: Optional[int] = 512
    """the number of rnn channels for the agent"""

    checkpoint1: str = "checkpoints/agent.pt"
    """the checkpoint to load for the first agent, `pt` or `flax_model` file"""
    checkpoint2: str = "checkpoints/agent.pt"
    """the checkpoint to load for the second agent, `pt` or `flax_model` file"""

    # Jax specific
    xla_device: Optional[str] = None
    """the XLA device to use, defaults to `None`"""

    # PyTorch specific
    torch_deterministic: bool = True
    """if toggled, `torch.backends.cudnn.deterministic=False`"""
    cuda: bool = True
    """if toggled, cuda will be enabled by default"""
    compile: bool = False
    """if toggled, the model will be compiled"""
    optimize: bool = False
    """if toggled, the model will be optimized"""
    torch_threads: Optional[int] = None
    """the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""

    env_threads: Optional[int] = 16
    """the number of threads to use for envpool, defaults to `num_envs`"""

    framework: Optional[Literal["torch", "jax"]] = None


def create_agent(args):
    return PPOLSTMAgent(
        channels=args.num_channels,
        num_layers=args.num_layers,
        lstm_channels=args.rnn_channels,
        embedding_shape=args.num_embeddings,
    )


def init_rnn_state(num_envs, rnn_channels):
    return (
        np.zeros((num_envs, rnn_channels)),
        np.zeros((num_envs, rnn_channels)),
    )


if __name__ == "__main__":
    from jax.experimental.compilation_cache import compilation_cache as cc
    cc.set_cache_dir(os.path.expanduser("~/.cache/jax"))

    args = tyro.cli(Args)

    if args.record:
        assert args.num_envs == 1, "Recording only works with a single environment"
        assert args.verbose, "Recording only works with verbose mode"
        if not os.path.exists("replay"):
            os.makedirs("replay")

    args.env_threads = min(args.env_threads or args.num_envs, args.num_envs)

    deck = init_ygopro(args.env_id, args.lang, args.deck, args.code_list_file)
    args.deck1 = args.deck1 or deck
    args.deck2 = args.deck2 or deck

    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)

    if args.xla_device is not None:
        os.environ.setdefault("JAX_PLATFORMS", args.xla_device)

    num_envs = args.num_envs

    envs = ygoenv.make(
        task_id=args.env_id,
        env_type="gymnasium",
        num_envs=num_envs,
        num_threads=args.env_threads,
        seed=seed,
        deck1=args.deck1,
        deck2=args.deck2,
        player=-1,
        max_options=args.max_options,
        n_history_actions=args.n_history_actions,
        play_mode='self',
        async_reset=False,
        verbose=args.verbose,
        record=args.record,
    )
    obs_space = envs.observation_space
    envs.num_envs = num_envs
    envs = RecordEpisodeStatistics(envs)

    agent = create_agent(args)
    key = jax.random.PRNGKey(args.seed)
    key, agent_key = jax.random.split(key, 2)
    sample_obs = jax.tree_map(lambda x: jnp.array([x]), obs_space.sample())
    rstate = init_rnn_state(1, args.rnn_channels)
    params = jax.jit(agent.init)(agent_key, (rstate, sample_obs))

    with open(args.checkpoint1, "rb") as f:
        params1 = flax.serialization.from_bytes(params, f.read())
    if args.checkpoint1 == args.checkpoint2:
        params2 = params1
    else:
        with open(args.checkpoint2, "rb") as f:
            params2 = flax.serialization.from_bytes(params, f.read())

    @jax.jit
    def get_probs(params, rstate, obs, done):
        agent = create_agent(args)
        next_rstate, logits = agent.apply(params, (rstate, obs))[:2]
        probs = jax.nn.softmax(logits, axis=-1)
        next_rstate = jax.tree_map(
            lambda x: jnp.where(done[:, None], 0, x), next_rstate)
        return next_rstate, probs

    if args.num_envs != 1:
        @jax.jit
        def get_probs2(params1, params2, rstate1, rstate2, obs, main, done):
            next_rstate1, probs1 = get_probs(params1, rstate1, obs, done)
            next_rstate2, probs2 = get_probs(params2, rstate2, obs, done)
            probs = jnp.where(main[:, None], probs1, probs2)
            rstate1 = jax.tree.map(
                lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
            rstate2 = jax.tree.map(
                lambda x1, x2: jnp.where(main[:, None], x2, x1), next_rstate2, rstate2)
            return rstate1, rstate2, probs

        def predict_fn(rstate1, rstate2, obs, main, done):
            rstate1, rstate2, probs = get_probs2(params1, params2, rstate1, rstate2, obs, main, done)
            return rstate1, rstate2, np.array(probs)
    else:
        def predict_fn(rstate1, rstate2, obs, main, done):
            if main[0]:
                rstate1, probs = get_probs(params1, rstate1, obs, done)
            else:
                rstate2, probs = get_probs(params2, rstate2, obs, done)
            return rstate1, rstate2, np.array(probs)

    obs, infos = envs.reset()
    next_to_play = infos['to_play']
    dones = np.zeros(num_envs, dtype=np.bool_)

    episode_rewards = []
    episode_lengths = []
    win_rates = []
    win_reasons = []

    step = 0
    start = time.time()
    start_step = step

    main_player = np.concatenate([
        np.zeros(num_envs // 2, dtype=np.int64),
        np.ones(num_envs - num_envs // 2, dtype=np.int64)
    ])
    rstate1 = rstate2 = init_rnn_state(num_envs, args.rnn_channels)

    model_time = env_time = 0
    while True:
        if start_step == 0 and len(episode_lengths) > int(args.num_episodes * 0.1):
            start = time.time()
            start_step = step
            model_time = env_time = 0

        _start = time.time()
        main = next_to_play == main_player
        rstate1, rstate2, probs = predict_fn(rstate1, rstate2, obs, main, dones)
        actions = probs.argmax(axis=1)
        model_time += time.time() - _start

        to_play = next_to_play

        _start = time.time()
        obs, rewards, dones, infos = envs.step(actions)
        next_to_play = infos['to_play']
        env_time += time.time() - _start

        step += 1

        for idx, d in enumerate(dones):
            if d:
                win_reason = infos['win_reason'][idx]
                pl = 1 if to_play[idx] == main_player[idx] else -1
                episode_length = infos['l'][idx]
                episode_reward = infos['r'][idx] * pl
                win = int(episode_reward > 0)

                episode_lengths.append(episode_length)
                episode_rewards.append(episode_reward)
                win_rates.append(win)
                win_reasons.append(1 if win_reason == 1 else 0)
                sys.stderr.write(f"Episode {len(episode_lengths)}: length={episode_length}, reward={episode_reward}, win={win}, win_reason={win_reason}\n")

                # Only when num_envs=1, we switch the player here
                if args.verbose:
                    main_player = 1 - main_player

        if len(episode_lengths) >= args.num_episodes:
            break

    print(f"len={np.mean(episode_lengths)}, reward={np.mean(episode_rewards)}, win_rate={np.mean(win_rates)}, win_reason={np.mean(win_reasons)}")

    total_time = time.time() - start
    total_steps = (step - start_step) * num_envs
    print(f"SPS: {total_steps / total_time:.0f}, total_steps: {total_steps}")
    print(f"total: {total_time:.4f}, model: {model_time:.4f}, env: {env_time:.4f}")
(no newline at end of file)
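In the new LSTM battle script, each side carries its own recurrent state, and `get_probs2` advances only the state of the agent whose turn it is while freezing the other. A standalone sketch of that update rule, with made-up shapes and values, using `jax.tree_util.tree_map` (the script itself uses `jax.tree.map`/`jax.tree_map`):

import jax
import jax.numpy as jnp

def advance_states(next_rstate1, next_rstate2, rstate1, rstate2, main):
    # Where the main agent acted, keep its freshly computed state and freeze the opponent's;
    # where the opponent acted, do the opposite. This mirrors the update in get_probs2 above.
    rstate1 = jax.tree_util.tree_map(
        lambda new, old: jnp.where(main[:, None], new, old), next_rstate1, rstate1)
    rstate2 = jax.tree_util.tree_map(
        lambda new, old: jnp.where(main[:, None], old, new), next_rstate2, rstate2)
    return rstate1, rstate2

num_envs, rnn_channels = 4, 8
zeros = lambda: (jnp.zeros((num_envs, rnn_channels)), jnp.zeros((num_envs, rnn_channels)))
rstate1, rstate2 = zeros(), zeros()
next_rstate1 = tuple(x + 1 for x in rstate1)   # pretend the networks produced new states
next_rstate2 = tuple(x + 2 for x in rstate2)
main = jnp.array([True, False, True, False])
rstate1, rstate2 = advance_states(next_rstate1, next_rstate2, rstate1, rstate2, main)
print(rstate1[0][:, 0])  # [1. 0. 1. 0.] -- agent 1's state advanced only where it acted

In the script this selection runs inside a single jitted `get_probs2`, so both networks are evaluated every step and the masks merely pick which results are kept; the `num_envs == 1` branch instead calls only the network whose turn it is.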
scripts/jax/ppo_lstm.py  (view file @ 9670ed68)

@@ -312,7 +312,7 @@ def rollout(
        params_queue_get_time.append(time.time() - params_queue_get_time_start)
        rollout_time_start = time.time()
-       initial_rstate1, initial_rstate2 = jax.tree.map(
+       init_rstate1, init_rstate2 = jax.tree.map(
            lambda x: x.copy(), (next_rstate1, next_rstate2))
        for _ in range(start_step, args.collect_length):
            global_step += args.local_num_envs * n_actors * args.world_size

@@ -385,15 +385,10 @@ def rollout(
            next_main = main_player == next_to_play
            next_rstate = jax.tree.map(
                lambda x1, x2: jnp.where(next_main[:, None], x1, x2), next_rstate1, next_rstate2)
-           # initial_rstate1: main, initial_rstate2: opponent
-           # init rstate1: == next_main, init rstate2: != next_main
-           init_rstate1 = jax.tree.map(
-               lambda x, y: jnp.where(next_main[:, None], x, y), initial_rstate1, initial_rstate2)
-           init_rstate2 = jax.tree.map(
-               lambda x, y: jnp.where(next_main[:, None], y, x), initial_rstate1, initial_rstate2)
            sharded_data = jax.tree.map(lambda x: jax.device_put_sharded(
                    np.split(x, len(learner_devices)), devices=learner_devices),
                (next_obs, next_rstate, init_rstate1, init_rstate2, next_done, next_main))
+           learn_opponent = False
            payload = (
                global_step,
                actor_policy_version,

@@ -402,6 +397,7 @@ def rollout(
                *sharded_data,
                np.mean(params_queue_get_time),
                device_thread_id,
+               learn_opponent,
            )
            rollout_queue.put(payload)
@@ -565,9 +561,10 @@ if __name__ == "__main__":
        return logprob, probs, entropy, value.squeeze(), valid

    def ppo_loss(
-       params, inputs, actions, logprobs, probs, advantages, target_values):
+       params, inputs, actions, logprobs, probs, advantages, target_values, mask):
        newlogprob, newprobs, entropy, newvalue, valid = \
            get_logprob_entropy_value(params, inputs, actions)
+       valid = valid & mask
        logratio = newlogprob - logprobs
        ratio = jnp.exp(logratio)
        approx_kl = ((ratio - 1) - logratio).mean()

@@ -600,7 +597,6 @@ if __name__ == "__main__":
        loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
        return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl))

-   @jax.jit
    def single_device_update(
        agent_state: TrainState,
        sharded_storages: List,

@@ -611,11 +607,12 @@ if __name__ == "__main__":
        sharded_next_done: List,
        sharded_next_main: List,
        key: jax.random.PRNGKey,
+       learn_opponent: bool = False,
    ):
-       def reshape_minibatch(x, num_minibatches, multi_step=False):
+       def reshape_minibatch(x, num_minibatches, num_steps=1):
            N = num_minibatches
-           if multi_step:
-               x = jnp.reshape(x, (args.num_steps, N, -1) + x.shape[2:])
+           if num_steps > 1:
+               x = jnp.reshape(x, (num_steps, N, -1) + x.shape[2:])
                x = x.transpose(1, 0, *range(2, x.ndim))
                x = x.reshape(N, -1, *x.shape[3:])
            else:
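The `reshape_minibatch` change above replaces the boolean `multi_step` flag with an explicit `num_steps`, which appears to be needed because the rollout length used downstream is truncated (see `num_steps = int(num_steps * 0.75)` in the next hunk) and no longer always equals `args.num_steps`. A rough, self-contained sketch of the shape handling; the `else` branch here is an assumption about the non-sequence case, not the repository's exact code:

import jax.numpy as jnp

def reshape_minibatch_sketch(x, num_minibatches, num_steps=1):
    # Split the environment axis into `num_minibatches` chunks; when num_steps > 1,
    # every chunk keeps all time steps of its environments (a sketch of the idea).
    N = num_minibatches
    if num_steps > 1:
        x = jnp.reshape(x, (num_steps, N, -1) + x.shape[2:])
        x = x.transpose(1, 0, *range(2, x.ndim))
        x = x.reshape(N, -1, *x.shape[3:])
    else:
        x = x.reshape(N, -1, *x.shape[1:])  # assumed handling of per-env (non-sequence) data
    return x

obs = jnp.zeros((96, 32, 5))           # (num_steps, num_envs, feature)
mb = reshape_minibatch_sketch(obs, num_minibatches=8, num_steps=96)
print(mb.shape)                        # (8, 384, 5): 8 minibatches of 4 envs x 96 steps

Each minibatch ends up holding a subset of environments together with all of their time steps, which is what the recurrent value recomputation needs.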
@@ -632,21 +629,37 @@ if __name__ == "__main__":
        ]
        # reorder storage of individual players
+       # main first, opponent second
        num_steps, num_envs = storage.rewards.shape
        T = jnp.arange(num_steps, dtype=jnp.int32)
        B = jnp.arange(num_envs, dtype=jnp.int32)
-       mains = (storage.mains == next_main).astype(jnp.int32)
-       indices = jnp.argsort(T[:, None] + mains * num_steps, axis=0)
-       switch = T[:, None] == (num_steps - 1 - jnp.sum(mains, axis=0))
+       mains = storage.mains.astype(jnp.int32)
+       indices = jnp.argsort(T[:, None] - mains * num_steps, axis=0)
+       switch_steps = jnp.sum(mains, axis=0)
+       switch = T[:, None] == (switch_steps[None, :] - 1)
+       if not learn_opponent:
+           num_steps = int(num_steps * 0.75)
+           indices = indices[:num_steps + 1]
+           switch = switch[:num_steps]
        storage = jax.tree.map(lambda x: x[indices, B[None, :]], storage)
+       if not learn_opponent:
+           next_obs = jax.tree.map(lambda x: x[num_steps], storage.obs)
+           next_done = storage.dones[num_steps]
+           next_main = storage.mains[num_steps]
+           storage = jax.tree.map(lambda x: x[:num_steps], storage)

        # split minibatches for recompute values
-       n_mbs = args.num_minibatches // 4
+       num_minibatches = args.num_minibatches
+       if not learn_opponent:
+           num_minibatches = num_minibatches // 2
+       n_mbs = num_minibatches // 4
        split_init_rstate = jax.tree.map(
            partial(reshape_minibatch, num_minibatches=n_mbs),
            (init_rstate1, init_rstate2))
        split_inputs = jax.tree.map(
-           partial(reshape_minibatch, num_minibatches=n_mbs, multi_step=True),
+           partial(reshape_minibatch, num_minibatches=n_mbs, num_steps=num_steps),
            (storage.obs, storage.dones, switch))
        split_inputs = split_init_rstate + split_inputs
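The reordering above relies on a small argsort trick: with `mains` as a 0/1 matrix over (step, env), subtracting `mains * num_steps` from the step index sorts each environment's main-player steps ahead of its opponent steps while preserving chronological order inside each group, matching the new "main first, opponent second" comment. A self-contained example of the same trick (the `mains` matrix and the `rewards` stand-in are invented for illustration):

import jax.numpy as jnp

num_steps, num_envs = 6, 2
T = jnp.arange(num_steps, dtype=jnp.int32)
B = jnp.arange(num_envs, dtype=jnp.int32)

# 1 where the main player acted at (step, env), 0 where the opponent acted.
mains = jnp.array([[1, 0],
                   [0, 1],
                   [1, 0],
                   [0, 1],
                   [1, 0],
                   [0, 1]], dtype=jnp.int32)

# Sorting key: main-player steps are pushed in front of opponent steps,
# and chronological order is preserved inside each group.
indices = jnp.argsort(T[:, None] - mains * num_steps, axis=0)

rewards = T[:, None] * jnp.ones((1, num_envs))      # stand-in for storage.rewards
reordered = rewards[indices, B[None, :]]
print(indices[:, 0])    # [0 2 4 1 3 5] -> env 0: main-player steps first, opponent steps after
print(reordered[:, 0])  # [0. 2. 4. 1. 3. 5.]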
@@ -663,27 +676,32 @@ if __name__ == "__main__":
            _, values = jax.lax.scan(
                get_value_minibatch, agent_state, split_inputs)
-           values = values.reshape((n_mbs, args.num_steps, -1)).transpose(1, 0, 2)
+           values = values.reshape((n_mbs, num_steps, -1)).transpose(1, 0, 2)
            values = values.reshape(storage.rewards.shape)

        next_value = create_agent(args).apply(
            agent_state.params, (next_rstate, next_obs))[2].squeeze(-1)
+       # TODO: check if this is correct
+       sign = jnp.where(switch_steps <= num_steps, 1.0, -1.0)
+       next_value = jnp.where(next_main, -sign * next_value, sign * next_value)

        compute_gae_fn = compute_gae_upgo_2p0s if args.upgo else compute_gae_2p0s
        advantages, target_values = compute_gae_fn(
            next_value, next_done, values, storage.rewards, storage.dones, switch,
            args.gamma, args.gae_lambda)
+       advantages = advantages[:args.num_steps]
+       target_values = target_values[:args.num_steps]

-       def convert_data(x: jnp.ndarray, multi_step):
+       def convert_data(x: jnp.ndarray, num_steps):
            x = jax.random.permutation(subkey, x, axis=1)
-           return reshape_minibatch(x, args.num_minibatches, multi_step)
+           return reshape_minibatch(x, num_minibatches, num_steps)

        shuffled_init_rstate1, shuffled_init_rstate2 = jax.tree.map(
-           partial(convert_data, multi_step=False), (init_rstate1, init_rstate2))
+           partial(convert_data, num_steps=1), (init_rstate1, init_rstate2))
        shuffled_storage, shuffled_switch, shuffled_advantages, shuffled_target_values = jax.tree.map(
-           partial(convert_data, multi_step=True), (storage, switch, advantages, target_values))
+           partial(convert_data, num_steps=num_steps), (storage, switch, advantages, target_values))
+       if learn_opponent:
+           shuffled_mask = jnp.ones_like(shuffled_storage.mains)
+       else:
+           shuffled_mask = shuffled_storage.mains

        def update_minibatch(agent_state, minibatch):
            (loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn(

@@ -708,6 +726,7 @@ if __name__ == "__main__":
                    shuffled_storage.probs,
                    shuffled_advantages,
                    shuffled_target_values,
+                   shuffled_mask,
                ),
            )
            return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl)

@@ -726,6 +745,7 @@ if __name__ == "__main__":
        single_device_update,
        axis_name="local_devices",
        devices=global_learner_decices,
+       static_broadcasted_argnums=(9,),
    )

    params_queues = []

@@ -771,6 +791,7 @@ if __name__ == "__main__":
                    *sharded_data,
                    avg_params_queue_get_time,
                    device_thread_id,
+                   learn_opponent,
                ) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get()
                sharded_data_list.append(sharded_data)
            rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)

@@ -779,6 +800,7 @@ if __name__ == "__main__":
                agent_state,
                *list(zip(*sharded_data_list)),
                learner_keys,
+               learn_opponent,
            )
            unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
            for d_idx, d_id in enumerate(args.actor_device_ids):
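Threading `learn_opponent` and `shuffled_mask` through to `ppo_loss` means that, when the opponent is not being learned, only the main player's transitions should contribute to the loss (`valid = valid & mask`). A minimal sketch of that kind of masked reduction; the `masked_mean` helper here is illustrative, not the repository's implementation:

import jax.numpy as jnp

def masked_mean(x, valid):
    # Average only over entries where valid is True; a sketch of how a mask like
    # shuffled_storage.mains can restrict the PPO loss to the main player's moves.
    valid = valid.astype(x.dtype)
    return (x * valid).sum() / jnp.maximum(valid.sum(), 1)

per_step_loss = jnp.array([0.5, 1.5, 2.0, 4.0])
mains = jnp.array([True, False, True, False])     # True where the main player acted
print(masked_mean(per_step_loss, mains))                      # 1.25 -- opponent steps ignored
print(masked_mean(per_step_loss, jnp.ones(4, dtype=bool)))    # 2.0  -- learn_opponent=True case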
scripts/ppo.py  (view file @ 9670ed68)

@@ -22,7 +22,7 @@ from ygoai.rl.utils import RecordEpisodeStatistics, to_tensor, load_embeddings
 from ygoai.rl.agent import PPOAgent as Agent
 from ygoai.rl.dist import reduce_gradidents, torchrun_setup, fprint
 from ygoai.rl.buffer import create_obs
-from ygoai.rl.ppo import bootstrap_value_selfplay
+from ygoai.rl.ppo import bootstrap_value_selfplay, train_step as train_step_
 from ygoai.rl.eval import evaluate

@@ -242,6 +242,7 @@ def main():
        embedding_shape = None
    L = args.num_layers
    agent = Agent(args.num_channels, L, L, embedding_shape).to(device)
+   torch.manual_seed(args.seed)
    agent.eval()
    if args.checkpoint:

@@ -271,7 +272,6 @@ def main():
            logits, value, valid = agent(next_obs)
            return logits, value

-   from ygoai.rl.ppo import train_step
    if args.compile:
        # It seems that using torch.compile twice cause segfault at start, so we use torch.jit.trace here
        # predict_step = torch.compile(predict_step, mode=args.compile)

@@ -284,10 +284,11 @@ def main():
        else:
            traced_model_t = traced_model
-       train_step = torch.compile(train_step, mode=args.compile)
+       train_step = torch.compile(train_step_, mode=args.compile)
    else:
        traced_model = agent
        traced_model_t = agent_t
+       train_step = train_step_

    # ALGO Logic: Storage setup
    obs = create_obs(obs_space, (args.collect_length, args.local_num_envs), device)

@@ -310,12 +311,12 @@ def main():
    next_to_play_ = info["to_play"]
    next_to_play = to_tensor(next_to_play_, device)
    next_done = torch.zeros(args.local_num_envs, device=device, dtype=torch.bool)
-   ai_player1_ = np.concatenate([
+   main_player_ = np.concatenate([
        np.zeros(args.local_num_envs // 2, dtype=np.int64),
        np.ones(args.local_num_envs // 2, dtype=np.int64)
    ])
-   np.random.shuffle(ai_player1_)
-   ai_player1 = to_tensor(ai_player1_, device, dtype=next_to_play.dtype)
+   np.random.shuffle(main_player_)
+   main_player = to_tensor(main_player_, device, dtype=next_to_play.dtype)
    step = 0

    for iteration in range(args.num_iterations):

@@ -334,7 +335,7 @@ def main():
            for key in obs:
                obs[key][step] = next_obs[key]
            dones[step] = next_done
-           learn = next_to_play == ai_player1
+           learn = next_to_play == main_player
            learns[step] = learn

            _start = time.time()

@@ -369,7 +370,7 @@ def main():
            for idx, d in enumerate(next_done_):
                if d:
-                   pl = 1 if to_play[idx] == ai_player1_[idx] else -1
+                   pl = 1 if to_play[idx] == main_player_[idx] else -1
                    episode_length = info['l'][idx]
                    episode_reward = info['r'][idx] * pl
                    win = 1 if episode_reward > 0 else 0

@@ -395,14 +396,13 @@ def main():
        _start = time.time()
        # bootstrap value if not done
-       with torch.no_grad():
-           value = predict_step(traced_model, next_obs)[1].reshape(-1)
-           nextvalues1 = torch.where(next_to_play == ai_player1, value, -value)
-           if args.fix_target:
-               value_t = predict_step(traced_model_t, next_obs)[1].reshape(-1)
-               nextvalues2 = torch.where(next_to_play != ai_player1, value_t, -value_t)
-           else:
-               nextvalues2 = -nextvalues1
+       value = predict_step(traced_model, next_obs)[1].reshape(-1)
+       nextvalues1 = torch.where(next_to_play == main_player, value, -value)
+       if args.fix_target:
+           value_t = predict_step(traced_model_t, next_obs)[1].reshape(-1)
+           nextvalues2 = torch.where(next_to_play != main_player, value_t, -value_t)
+       else:
+           nextvalues2 = -nextvalues1
        if step > 0 and iteration != 0:
            # recalculate the values for the first few steps
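The `nextvalues1`/`nextvalues2` lines above encode the zero-sum self-play convention: the critic is evaluated from the perspective of the player to move, so from the main player's point of view the bootstrap value keeps its sign on the main player's turn and is negated on the opponent's turn. A small PyTorch illustration with made-up tensors:

import torch

value = torch.tensor([0.8, -0.2, 0.5, 0.1])        # critic output for the player to move
next_to_play = torch.tensor([0, 1, 1, 0])
main_player = torch.tensor([0, 0, 1, 1])

# Value from the main player's perspective: keep it on the main player's turn,
# negate it on the opponent's turn (zero-sum assumption).
nextvalues1 = torch.where(next_to_play == main_player, value, -value)
print(nextvalues1)   # tensor([ 0.8000,  0.2000,  0.5000, -0.1000])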
scripts/ppo_osfp.py  (view file @ 9670ed68)

@@ -21,7 +21,7 @@ from ygoai.rl.utils import RecordEpisodeStatistics, to_tensor, load_embeddings
 from ygoai.rl.agent import PPOAgent as Agent
 from ygoai.rl.dist import reduce_gradidents, torchrun_setup, fprint
 from ygoai.rl.buffer import create_obs
-from ygoai.rl.ppo import bootstrap_value_selfplay
+from ygoai.rl.ppo import bootstrap_value_selfplay, train_step as train_step_
 from ygoai.rl.eval import evaluate

@@ -261,6 +261,7 @@ def main():
        embedding_shape = None
    L = args.num_layers
    agent = Agent(args.num_channels, L, L, embedding_shape).to(device)
+   torch.manual_seed(args.seed)
    agent.eval()
    if args.checkpoint:

@@ -289,7 +290,6 @@ def main():
    history = []

-   from ygoai.rl.ppo import train_step
    if args.compile:
        # It seems that using torch.compile twice cause segfault at start, so we use torch.jit.trace here
        # predict_step = torch.compile(predict_step, mode=args.compile)

@@ -302,7 +302,10 @@ def main():
            traced_model_t = torch.jit.optimize_for_inference(traced_model_t)
            history.append(traced_model_t)

-       train_step = torch.compile(train_step, mode=args.compile)
+       train_step = torch.compile(train_step_, mode=args.compile)
+   else:
+       train_step = train_step_

    def sample_target(history):
        ts = []

@@ -331,12 +334,12 @@ def main():
    warmup_steps = 0
    start_time = time.time()
    next_done = torch.zeros(args.local_num_envs, device=device, dtype=torch.bool)
-   ai_player1_ = np.concatenate([
+   main_player_ = np.concatenate([
        np.zeros(args.local_num_envs // 2, dtype=np.int64),
        np.ones(args.local_num_envs // 2, dtype=np.int64)
    ])
-   np.random.shuffle(ai_player1_)
-   ai_player1 = to_tensor(ai_player1_, device)
+   np.random.shuffle(main_player_)
+   main_player = to_tensor(main_player_, device)
    next_value1 = next_value2 = 0
    step = 0
    ts = []

@@ -374,7 +377,7 @@ def main():
            for key in obs:
                obs[key][step] = next_obs[key]
            dones[step] = next_done
-           learn = next_to_play == ai_player1
+           learn = next_to_play == main_player
            learns[step] = learn

            _start = time.time()

@@ -410,7 +413,7 @@ def main():
            for idx, d in enumerate(next_done_):
                if d:
-                   pl = 1 if to_play[idx] == ai_player1_[idx] else -1
+                   pl = 1 if to_play[idx] == main_player_[idx] else -1
                    episode_length = info['l'][idx]
                    episode_reward = info['r'][idx] * pl
                    win = 1 if episode_reward > 0 else 0

@@ -442,9 +445,9 @@ def main():
        value = predict_step(traced_model, next_obs)[1].reshape(-1)
        if not selfplay:
            value_t = predict_step(traced_model_t, next_obs)[1].reshape(-1)
-           value = torch.where(next_to_play == ai_player1, value, value_t)
-       nextvalues1 = torch.where(next_to_play == ai_player1, value, next_value1)
-       nextvalues2 = torch.where(next_to_play != ai_player1, value, next_value2)
+           value = torch.where(next_to_play == main_player, value, value_t)
+       nextvalues1 = torch.where(next_to_play == main_player, value, next_value1)
+       nextvalues2 = torch.where(next_to_play != main_player, value, next_value2)
        if step > 0 and iteration != 0:
            # recalculate the values for the first few steps
ygoai/rl/env.py  (view file @ 9670ed68)

@@ -35,18 +35,6 @@ class RecordEpisodeStatistics(gym.Wrapper):
        self.episode_lengths *= 1 - dones
        infos["r"] = self.returned_episode_returns
        infos["l"] = self.returned_episode_lengths
-       # env_id = infos["env_id"]
-       # self.env_id = env_id
-       # self.episode_returns[env_id] += infos["reward"]
-       # self.returned_episode_returns[env_id] = np.where(
-       #     infos["terminated"] + truncated, self.episode_returns[env_id], self.returned_episode_returns[env_id]
-       # )
-       # self.episode_returns[env_id] *= (1 - infos["terminated"]) * (1 - truncated)
-       # self.episode_lengths[env_id] += 1
-       # self.returned_episode_lengths[env_id] = np.where(
-       #     infos["terminated"] + truncated, self.episode_lengths[env_id], self.returned_episode_lengths[env_id]
-       # )
-       # self.episode_lengths[env_id] *= (1 - infos["terminated"]) * (1 - truncated)
        return (
            observations,
ygoai/rl/ppo.py  (view file @ 9670ed68)

@@ -121,53 +121,6 @@ def train_step_t(agent, optimizer, b_obs, b_actions, b_logprobs, b_advantages, b
    return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss

-# def train_step_t(agent, optimizer, mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns, mb_values, mb_learns, args):
-#     logits, newvalue, valid = agent(mb_obs)
-#     logits = logits - logits.logsumexp(dim=-1, keepdim=True)
-#     newlogprob = logits.gather(-1, mb_actions[:, None]).squeeze(-1)
-#     entropy = entropy_from_logits(logits)
-#     valid = torch.logical_and(valid, mb_learns)
-#     logratio = newlogprob - mb_logprobs
-#     ratio = logratio.exp()
-#     with torch.no_grad():
-#         # calculate approx_kl http://joschu.net/blog/kl-approx.html
-#         old_approx_kl = (-logratio).mean()
-#         approx_kl = ((ratio - 1) - logratio).mean()
-#         clipfrac = ((ratio - 1.0).abs() > args.clip_coef).float().mean()
-#     if args.norm_adv:
-#         mb_advantages = masked_normalize(mb_advantages, valid, eps=1e-8)
-#     # Policy loss
-#     pg_loss1 = -mb_advantages * ratio
-#     pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
-#     pg_loss = torch.max(pg_loss1, pg_loss2)
-#     pg_loss = masked_mean(pg_loss, valid)
-#     # Value loss
-#     newvalue = newvalue.view(-1)
-#     if args.clip_vloss:
-#         v_loss_unclipped = (newvalue - mb_returns) ** 2
-#         v_clipped = mb_values + torch.clamp(
-#             newvalue - mb_values,
-#             -args.clip_coef,
-#             args.clip_coef,
-#         )
-#         v_loss_clipped = (v_clipped - mb_returns) ** 2
-#         v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
-#         v_loss = 0.5 * v_loss_max
-#     else:
-#         v_loss = 0.5 * ((newvalue - mb_returns) ** 2)
-#     v_loss = masked_mean(v_loss, valid)
-#     entropy_loss = masked_mean(entropy, valid)
-#     loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
-#     loss.backward()
-#     optimizer.step()
-#     return old_approx_kl, approx_kl, clipfrac, pg_loss, v_loss, entropy_loss

 def bootstrap_value(values, rewards, dones, nextvalues, next_done, gamma, gae_lambda):
    num_steps = rewards.size(0)
    advantages = torch.zeros_like(rewards)