Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
ec9f3e0c
Commit
ec9f3e0c
authored
May 12, 2024
by
sbl1996@126.com
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
support different rnn
parent
b7d52f29
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
150 additions
and
108 deletions
+150
-108
scripts/battle.py
scripts/battle.py
+44
-22
scripts/cleanba.py
scripts/cleanba.py
+31
-33
scripts/eval.py
scripts/eval.py
+9
-12
ygoai/rl/jax/agent2.py
ygoai/rl/jax/agent2.py
+65
-39
ygoai/rl/jax/eval.py
ygoai/rl/jax/eval.py
+1
-2
No files found.
scripts/battle.py
View file @
ec9f3e0c
...
@@ -2,9 +2,10 @@ import sys
...
@@ -2,9 +2,10 @@ import sys
import
time
import
time
import
os
import
os
import
random
import
random
from
typing
import
Optional
,
Literal
from
typing
import
Optional
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
functools
import
partial
import
ygoenv
import
ygoenv
import
numpy
as
np
import
numpy
as
np
...
@@ -17,7 +18,7 @@ import flax
...
@@ -17,7 +18,7 @@ import flax
from
ygoai.utils
import
init_ygopro
from
ygoai.utils
import
init_ygopro
from
ygoai.rl.utils
import
RecordEpisodeStatistics
from
ygoai.rl.utils
import
RecordEpisodeStatistics
from
ygoai.rl.jax.agent2
import
PPOLSTM
Agent
from
ygoai.rl.jax.agent2
import
RNN
Agent
@
dataclass
@
dataclass
...
@@ -43,6 +44,10 @@ class Args:
...
@@ -43,6 +44,10 @@ class Args:
"""the number of history actions to use"""
"""the number of history actions to use"""
num_embeddings
:
Optional
[
int
]
=
None
num_embeddings
:
Optional
[
int
]
=
None
"""the number of embeddings of the agent"""
"""the number of embeddings of the agent"""
use_history1
:
bool
=
True
"""whether to use history actions as input for agent1"""
use_history2
:
bool
=
True
"""whether to use history actions as input for agent2"""
verbose
:
bool
=
False
verbose
:
bool
=
False
"""whether to print debug information"""
"""whether to print debug information"""
...
@@ -60,6 +65,10 @@ class Args:
...
@@ -60,6 +65,10 @@ class Args:
"""the number of channels for the agent"""
"""the number of channels for the agent"""
rnn_channels
:
Optional
[
int
]
=
512
rnn_channels
:
Optional
[
int
]
=
512
"""the number of rnn channels for the agent"""
"""the number of rnn channels for the agent"""
rnn_type1
:
Optional
[
str
]
=
"lstm"
"""the type of RNN to use for agent1, None for no RNN"""
rnn_type2
:
Optional
[
str
]
=
"lstm"
"""the type of RNN to use for agent2, None for no RNN"""
checkpoint1
:
str
=
"checkpoints/agent.pt"
checkpoint1
:
str
=
"checkpoints/agent.pt"
"""the checkpoint to load for the first agent, must be a `flax_model` file"""
"""the checkpoint to load for the first agent, must be a `flax_model` file"""
checkpoint2
:
str
=
"checkpoints/agent.pt"
checkpoint2
:
str
=
"checkpoints/agent.pt"
...
@@ -72,19 +81,25 @@ class Args:
...
@@ -72,19 +81,25 @@ class Args:
"""the number of threads to use for envpool, defaults to `num_envs`"""
"""the number of threads to use for envpool, defaults to `num_envs`"""
def
create_agent
(
args
):
def
create_agent
1
(
args
):
return
PPOLSTM
Agent
(
return
RNN
Agent
(
channels
=
args
.
num_channels
,
channels
=
args
.
num_channels
,
num_layers
=
args
.
num_layers
,
num_layers
=
args
.
num_layers
,
lstm
_channels
=
args
.
rnn_channels
,
rnn
_channels
=
args
.
rnn_channels
,
embedding_shape
=
args
.
num_embeddings
,
embedding_shape
=
args
.
num_embeddings
,
use_history
=
args
.
use_history1
,
rnn_type
=
args
.
rnn_type1
,
)
)
def
init_rnn_state
(
num_envs
,
rnn_channels
):
def
create_agent2
(
args
):
return
(
return
RNNAgent
(
np
.
zeros
((
num_envs
,
rnn_channels
)),
channels
=
args
.
num_channels
,
np
.
zeros
((
num_envs
,
rnn_channels
)),
num_layers
=
args
.
num_layers
,
rnn_channels
=
args
.
rnn_channels
,
embedding_shape
=
args
.
num_embeddings
,
use_history
=
args
.
use_history2
,
rnn_type
=
args
.
rnn_type2
,
)
)
...
@@ -137,28 +152,34 @@ if __name__ == "__main__":
...
@@ -137,28 +152,34 @@ if __name__ == "__main__":
envs
.
num_envs
=
num_envs
envs
.
num_envs
=
num_envs
envs
=
RecordEpisodeStatistics
(
envs
)
envs
=
RecordEpisodeStatistics
(
envs
)
agent
=
create_agent
(
args
)
key
=
jax
.
random
.
PRNGKey
(
args
.
seed
)
key
=
jax
.
random
.
PRNGKey
(
args
.
seed
)
key
,
agent_key
=
jax
.
random
.
split
(
key
,
2
)
key
,
agent_key
=
jax
.
random
.
split
(
key
,
2
)
sample_obs
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
array
([
x
]),
obs_space
.
sample
())
sample_obs
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
array
([
x
]),
obs_space
.
sample
())
rstate
=
init_rnn_state
(
1
,
args
.
rnn_channels
)
agent1
=
create_agent1
(
args
)
params
=
jax
.
jit
(
agent
.
init
)(
agent_key
,
(
rstate
,
sample_obs
))
rstate
=
agent1
.
init_rnn_state
(
1
)
params1
=
jax
.
jit
(
agent1
.
init
)(
agent_key
,
(
rstate
,
sample_obs
))
with
open
(
args
.
checkpoint1
,
"rb"
)
as
f
:
with
open
(
args
.
checkpoint1
,
"rb"
)
as
f
:
params1
=
flax
.
serialization
.
from_bytes
(
params
,
f
.
read
())
params1
=
flax
.
serialization
.
from_bytes
(
params
1
,
f
.
read
())
if
args
.
checkpoint1
==
args
.
checkpoint2
:
if
args
.
checkpoint1
==
args
.
checkpoint2
:
params2
=
params1
params2
=
params1
else
:
else
:
agent2
=
create_agent2
(
args
)
rstate
=
agent2
.
init_rnn_state
(
1
)
params2
=
jax
.
jit
(
agent2
.
init
)(
agent_key
,
(
rstate
,
sample_obs
))
with
open
(
args
.
checkpoint2
,
"rb"
)
as
f
:
with
open
(
args
.
checkpoint2
,
"rb"
)
as
f
:
params2
=
flax
.
serialization
.
from_bytes
(
params
,
f
.
read
())
params2
=
flax
.
serialization
.
from_bytes
(
params
2
,
f
.
read
())
params1
=
jax
.
device_put
(
params1
)
params1
=
jax
.
device_put
(
params1
)
params2
=
jax
.
device_put
(
params2
)
params2
=
jax
.
device_put
(
params2
)
@
jax
.
jit
@
partial
(
jax
.
jit
,
static_argnums
=
(
4
,))
def
get_probs
(
params
,
rstate
,
obs
,
done
=
None
):
def
get_probs
(
params
,
rstate
,
obs
,
done
=
None
,
model_id
=
1
):
agent
=
create_agent
(
args
)
if
model_id
==
1
:
agent
=
create_agent1
(
args
)
else
:
agent
=
create_agent2
(
args
)
next_rstate
,
logits
=
agent
.
apply
(
params
,
(
rstate
,
obs
))[:
2
]
next_rstate
,
logits
=
agent
.
apply
(
params
,
(
rstate
,
obs
))[:
2
]
probs
=
jax
.
nn
.
softmax
(
logits
,
axis
=-
1
)
probs
=
jax
.
nn
.
softmax
(
logits
,
axis
=-
1
)
if
done
is
not
None
:
if
done
is
not
None
:
...
@@ -168,8 +189,8 @@ if __name__ == "__main__":
...
@@ -168,8 +189,8 @@ if __name__ == "__main__":
if
args
.
num_envs
!=
1
:
if
args
.
num_envs
!=
1
:
@
jax
.
jit
@
jax
.
jit
def
get_probs2
(
params1
,
params2
,
rstate1
,
rstate2
,
obs
,
main
,
done
):
def
get_probs2
(
params1
,
params2
,
rstate1
,
rstate2
,
obs
,
main
,
done
):
next_rstate1
,
probs1
=
get_probs
(
params1
,
rstate1
,
obs
)
next_rstate1
,
probs1
=
get_probs
(
params1
,
rstate1
,
obs
,
None
,
1
)
next_rstate2
,
probs2
=
get_probs
(
params2
,
rstate2
,
obs
)
next_rstate2
,
probs2
=
get_probs
(
params2
,
rstate2
,
obs
,
None
,
2
)
probs
=
jnp
.
where
(
main
[:,
None
],
probs1
,
probs2
)
probs
=
jnp
.
where
(
main
[:,
None
],
probs1
,
probs2
)
rstate1
=
jax
.
tree
.
map
(
rstate1
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x1
,
x2
),
next_rstate1
,
rstate1
)
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x1
,
x2
),
next_rstate1
,
rstate1
)
...
@@ -185,9 +206,9 @@ if __name__ == "__main__":
...
@@ -185,9 +206,9 @@ if __name__ == "__main__":
else
:
else
:
def
predict_fn
(
rstate1
,
rstate2
,
obs
,
main
,
done
):
def
predict_fn
(
rstate1
,
rstate2
,
obs
,
main
,
done
):
if
main
[
0
]:
if
main
[
0
]:
rstate1
,
probs
=
get_probs
(
params1
,
rstate1
,
obs
,
done
)
rstate1
,
probs
=
get_probs
(
params1
,
rstate1
,
obs
,
done
,
1
)
else
:
else
:
rstate2
,
probs
=
get_probs
(
params2
,
rstate2
,
obs
,
done
)
rstate2
,
probs
=
get_probs
(
params2
,
rstate2
,
obs
,
done
,
2
)
return
rstate1
,
rstate2
,
np
.
array
(
probs
)
return
rstate1
,
rstate2
,
np
.
array
(
probs
)
obs
,
infos
=
envs
.
reset
()
obs
,
infos
=
envs
.
reset
()
...
@@ -209,7 +230,8 @@ if __name__ == "__main__":
...
@@ -209,7 +230,8 @@ if __name__ == "__main__":
np
.
zeros
(
num_envs
//
2
,
dtype
=
np
.
int64
),
np
.
zeros
(
num_envs
//
2
,
dtype
=
np
.
int64
),
np
.
ones
(
num_envs
-
num_envs
//
2
,
dtype
=
np
.
int64
)
np
.
ones
(
num_envs
-
num_envs
//
2
,
dtype
=
np
.
int64
)
])
])
rstate1
=
rstate2
=
init_rnn_state
(
num_envs
,
args
.
rnn_channels
)
rstate1
=
agent1
.
init_rnn_state
(
num_envs
)
rstate2
=
agent2
.
init_rnn_state
(
num_envs
)
if
not
args
.
verbose
:
if
not
args
.
verbose
:
pbar
=
tqdm
(
total
=
args
.
num_episodes
)
pbar
=
tqdm
(
total
=
args
.
num_episodes
)
...
...
scripts/cleanba.py
View file @
ec9f3e0c
...
@@ -25,7 +25,7 @@ from tensorboardX import SummaryWriter
...
@@ -25,7 +25,7 @@ from tensorboardX import SummaryWriter
from
ygoai.utils
import
init_ygopro
,
load_embeddings
from
ygoai.utils
import
init_ygopro
,
load_embeddings
from
ygoai.rl.ckpt
import
ModelCheckpoint
,
sync_to_gcs
,
zip_files
from
ygoai.rl.ckpt
import
ModelCheckpoint
,
sync_to_gcs
,
zip_files
from
ygoai.rl.jax.agent2
import
LSTM
Agent
from
ygoai.rl.jax.agent2
import
RNN
Agent
from
ygoai.rl.jax.utils
import
RecordEpisodeStatistics
,
masked_normalize
,
categorical_sample
from
ygoai.rl.jax.utils
import
RecordEpisodeStatistics
,
masked_normalize
,
categorical_sample
from
ygoai.rl.jax.eval
import
evaluate
,
battle
from
ygoai.rl.jax.eval
import
evaluate
,
battle
from
ygoai.rl.jax
import
clipped_surrogate_pg_loss
,
vtrace_2p0s
,
mse_loss
,
entropy_loss
,
simple_policy_loss
,
ach_loss
,
policy_gradient_loss
from
ygoai.rl.jax
import
clipped_surrogate_pg_loss
,
vtrace_2p0s
,
mse_loss
,
entropy_loss
,
simple_policy_loss
,
ach_loss
,
policy_gradient_loss
...
@@ -80,8 +80,6 @@ class Args:
...
@@ -80,8 +80,6 @@ class Args:
"""whether to use history actions as input for agent"""
"""whether to use history actions as input for agent"""
eval_use_history
:
bool
=
True
eval_use_history
:
bool
=
True
"""whether to use history actions as input for eval agent"""
"""whether to use history actions as input for eval agent"""
use_rnn
:
bool
=
True
"""whether to use RNN for the agent"""
total_timesteps
:
int
=
50000000000
total_timesteps
:
int
=
50000000000
"""total timesteps of the experiments"""
"""total timesteps of the experiments"""
...
@@ -150,6 +148,10 @@ class Args:
...
@@ -150,6 +148,10 @@ class Args:
"""the number of channels for the agent"""
"""the number of channels for the agent"""
rnn_channels
:
int
=
512
rnn_channels
:
int
=
512
"""the number of channels for the RNN in the agent"""
"""the number of channels for the RNN in the agent"""
rnn_type
:
Optional
[
str
]
=
"lstm"
"""the type of RNN to use, None for no RNN"""
eval_rnn_type
:
Optional
[
str
]
=
"lstm"
"""the type of RNN to use for evaluation, None for no RNN"""
actor_device_ids
:
List
[
int
]
=
field
(
default_factory
=
lambda
:
[
0
,
1
])
actor_device_ids
:
List
[
int
]
=
field
(
default_factory
=
lambda
:
[
0
,
1
])
"""the device ids that actor workers will use"""
"""the device ids that actor workers will use"""
...
@@ -222,18 +224,18 @@ class Transition(NamedTuple):
...
@@ -222,18 +224,18 @@ class Transition(NamedTuple):
def
create_agent
(
args
,
multi_step
=
False
,
eval
=
False
):
def
create_agent
(
args
,
multi_step
=
False
,
eval
=
False
):
return
LSTM
Agent
(
return
RNN
Agent
(
channels
=
args
.
num_channels
,
channels
=
args
.
num_channels
,
num_layers
=
args
.
num_layers
,
num_layers
=
args
.
num_layers
,
embedding_shape
=
args
.
num_embeddings
,
embedding_shape
=
args
.
num_embeddings
,
dtype
=
jnp
.
bfloat16
if
args
.
bfloat16
else
jnp
.
float32
,
dtype
=
jnp
.
bfloat16
if
args
.
bfloat16
else
jnp
.
float32
,
param_dtype
=
jnp
.
float32
,
param_dtype
=
jnp
.
float32
,
lstm
_channels
=
args
.
rnn_channels
,
rnn
_channels
=
args
.
rnn_channels
,
switch
=
args
.
switch
,
switch
=
args
.
switch
,
multi_step
=
multi_step
,
multi_step
=
multi_step
,
freeze_id
=
args
.
freeze_id
,
freeze_id
=
args
.
freeze_id
,
use_history
=
args
.
use_history
if
not
eval
else
args
.
eval_use_history
,
use_history
=
args
.
use_history
if
not
eval
else
args
.
eval_use_history
,
no_rnn
=
(
not
args
.
use_rnn
)
if
not
eval
else
False
rnn_type
=
args
.
rnn_type
if
not
eval
else
args
.
eval_rnn_type
,
)
)
...
@@ -284,23 +286,20 @@ def rollout(
...
@@ -284,23 +286,20 @@ def rollout(
other_time
=
0
other_time
=
0
avg_ep_returns
=
deque
(
maxlen
=
1000
)
avg_ep_returns
=
deque
(
maxlen
=
1000
)
avg_win_rates
=
deque
(
maxlen
=
1000
)
avg_win_rates
=
deque
(
maxlen
=
1000
)
@
partial
(
jax
.
jit
,
static_argnums
=
(
2
,))
agent
=
create_agent
(
args
)
def
get_logits
(
eval_agent
=
create_agent
(
args
,
eval
=
True
)
params
:
flax
.
core
.
FrozenDict
,
inputs
,
eval
=
False
):
rstate
,
logits
=
create_agent
(
args
,
eval
=
eval
)
.
apply
(
params
,
inputs
)[:
2
]
return
rstate
,
logits
@
jax
.
jit
@
jax
.
jit
def
get_action
(
def
get_action
(
params
:
flax
.
core
.
FrozenDict
,
inputs
):
params
:
flax
.
core
.
FrozenDict
,
inputs
):
rstate
,
logits
=
get_logits
(
params
,
inputs
)
rstate
,
logits
=
eval_agent
.
apply
(
params
,
inputs
)[:
2
]
return
rstate
,
logits
.
argmax
(
axis
=
1
)
return
rstate
,
logits
.
argmax
(
axis
=
1
)
@
jax
.
jit
@
jax
.
jit
def
get_action_battle
(
params1
,
params2
,
rstate1
,
rstate2
,
obs
,
main
,
done
):
def
get_action_battle
(
params1
,
params2
,
rstate1
,
rstate2
,
obs
,
main
,
done
):
next_rstate1
,
logits1
=
get_logits
(
params1
,
(
rstate1
,
obs
))
next_rstate1
,
logits1
=
agent
.
apply
(
params1
,
(
rstate1
,
obs
))[:
2
]
next_rstate2
,
logits2
=
get_logits
(
params2
,
(
rstate2
,
obs
),
True
)
next_rstate2
,
logits2
=
eval_agent
.
apply
(
params2
,
(
rstate2
,
obs
))[:
2
]
logits
=
jnp
.
where
(
main
[:,
None
],
logits1
,
logits2
)
logits
=
jnp
.
where
(
main
[:,
None
],
logits1
,
logits2
)
rstate1
=
jax
.
tree
.
map
(
rstate1
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x1
,
x2
),
next_rstate1
,
rstate1
)
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x1
,
x2
),
next_rstate1
,
rstate1
)
...
@@ -320,7 +319,7 @@ def rollout(
...
@@ -320,7 +319,7 @@ def rollout(
rstate
=
jax
.
tree
.
map
(
rstate
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x1
,
x2
),
rstate1
,
rstate2
)
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x1
,
x2
),
rstate1
,
rstate2
)
rstate
,
logits
=
get_logits
(
params
,
(
rstate
,
next_obs
))
rstate
,
logits
=
agent
.
apply
(
params
,
(
rstate
,
next_obs
))[:
2
]
rstate1
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x1
,
x2
),
rstate
,
rstate1
)
rstate1
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x1
,
x2
),
rstate
,
rstate1
)
rstate2
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x2
,
x1
),
rstate
,
rstate2
)
rstate2
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x2
,
x1
),
rstate
,
rstate2
)
rstate1
,
rstate2
=
jax
.
tree
.
map
(
rstate1
,
rstate2
=
jax
.
tree
.
map
(
...
@@ -335,10 +334,11 @@ def rollout(
...
@@ -335,10 +334,11 @@ def rollout(
next_obs
,
info
=
envs
.
reset
()
next_obs
,
info
=
envs
.
reset
()
next_to_play
=
info
[
"to_play"
]
next_to_play
=
info
[
"to_play"
]
next_done
=
np
.
zeros
(
args
.
local_num_envs
,
dtype
=
np
.
bool_
)
next_done
=
np
.
zeros
(
args
.
local_num_envs
,
dtype
=
np
.
bool_
)
next_rstate1
=
next_rstate2
=
init_rnn_state
(
next_rstate1
=
next_rstate2
=
agent
.
init_rnn_state
(
args
.
local_num_envs
)
args
.
local_num_envs
,
args
.
rnn_channels
)
eval_rstate
=
init_rnn_state
(
eval_rstate1
=
agent
.
init_rnn_state
(
args
.
local_eval_episodes
)
args
.
local_eval_episodes
,
args
.
rnn_channels
)
eval_rstate2
=
eval_agent
.
init_rnn_state
(
args
.
local_eval_episodes
)
main_player
=
np
.
concatenate
([
main_player
=
np
.
concatenate
([
np
.
zeros
(
args
.
local_num_envs
//
2
,
dtype
=
np
.
int64
),
np
.
zeros
(
args
.
local_num_envs
//
2
,
dtype
=
np
.
int64
),
np
.
ones
(
args
.
local_num_envs
//
2
,
dtype
=
np
.
int64
)
np
.
ones
(
args
.
local_num_envs
//
2
,
dtype
=
np
.
int64
)
...
@@ -452,11 +452,11 @@ def rollout(
...
@@ -452,11 +452,11 @@ def rollout(
if
eval_mode
==
'bot'
:
if
eval_mode
==
'bot'
:
predict_fn
=
lambda
x
:
get_action
(
params
,
x
)
predict_fn
=
lambda
x
:
get_action
(
params
,
x
)
eval_return
,
eval_ep_len
,
eval_win_rate
=
evaluate
(
eval_return
,
eval_ep_len
,
eval_win_rate
=
evaluate
(
eval_envs
,
args
.
local_eval_episodes
,
predict_fn
,
eval_rstate
)
eval_envs
,
args
.
local_eval_episodes
,
predict_fn
,
eval_rstate
2
)
else
:
else
:
predict_fn
=
lambda
*
x
:
get_action_battle
(
params
,
eval_params
,
*
x
)
predict_fn
=
lambda
*
x
:
get_action_battle
(
params
,
eval_params
,
*
x
)
eval_return
,
eval_ep_len
,
eval_win_rate
=
battle
(
eval_return
,
eval_ep_len
,
eval_win_rate
=
battle
(
eval_envs
,
args
.
local_eval_episodes
,
predict_fn
,
eval_rstate
)
eval_envs
,
args
.
local_eval_episodes
,
predict_fn
,
eval_rstate
1
,
eval_rstate2
)
eval_time
=
time
.
time
()
-
_start
eval_time
=
time
.
time
()
-
_start
other_time
+=
eval_time
other_time
+=
eval_time
eval_stats
=
np
.
array
([
eval_time
,
eval_return
,
eval_win_rate
],
dtype
=
np
.
float32
)
eval_stats
=
np
.
array
([
eval_time
,
eval_return
,
eval_win_rate
],
dtype
=
np
.
float32
)
...
@@ -606,8 +606,9 @@ if __name__ == "__main__":
...
@@ -606,8 +606,9 @@ if __name__ == "__main__":
frac
=
1.0
-
(
count
//
(
args
.
num_minibatches
*
args
.
update_epochs
))
/
args
.
num_updates
frac
=
1.0
-
(
count
//
(
args
.
num_minibatches
*
args
.
update_epochs
))
/
args
.
num_updates
return
args
.
learning_rate
*
frac
return
args
.
learning_rate
*
frac
rstate
=
init_rnn_state
(
1
,
args
.
rnn_channels
)
#
rstate = init_rnn_state(1, args.rnn_channels)
agent
=
create_agent
(
args
)
agent
=
create_agent
(
args
)
rstate
=
agent
.
init_rnn_state
(
1
)
params
=
agent
.
init
(
init_key
,
(
rstate
,
sample_obs
))
params
=
agent
.
init
(
init_key
,
(
rstate
,
sample_obs
))
if
embeddings
is
not
None
:
if
embeddings
is
not
None
:
unknown_embed
=
embeddings
.
mean
(
axis
=
0
)
unknown_embed
=
embeddings
.
mean
(
axis
=
0
)
...
@@ -641,20 +642,15 @@ if __name__ == "__main__":
...
@@ -641,20 +642,15 @@ if __name__ == "__main__":
# print(agent.tabulate(agent_key, sample_obs))
# print(agent.tabulate(agent_key, sample_obs))
if
args
.
eval_checkpoint
:
if
args
.
eval_checkpoint
:
eval_agent
=
create_agent
(
args
,
eval
=
True
)
eval_rstate
=
eval_agent
.
init_rnn_state
(
1
)
eval_params
=
eval_agent
.
init
(
init_key
,
(
eval_rstate
,
sample_obs
))
with
open
(
args
.
eval_checkpoint
,
"rb"
)
as
f
:
with
open
(
args
.
eval_checkpoint
,
"rb"
)
as
f
:
eval_params
=
flax
.
serialization
.
from_bytes
(
params
,
f
.
read
())
eval_params
=
flax
.
serialization
.
from_bytes
(
eval_
params
,
f
.
read
())
print
(
f
"loaded eval checkpoint from {args.eval_checkpoint}"
)
print
(
f
"loaded eval checkpoint from {args.eval_checkpoint}"
)
else
:
else
:
eval_params
=
None
eval_params
=
None
@
jax
.
jit
def
get_logits_and_value
(
params
:
flax
.
core
.
FrozenDict
,
inputs
,
):
rstate
,
logits
,
value
,
valid
=
create_agent
(
args
,
multi_step
=
True
)
.
apply
(
params
,
inputs
)
return
logits
,
value
.
squeeze
(
-
1
)
def
loss_fn
(
def
loss_fn
(
params
,
rstate1
,
rstate2
,
obs
,
dones
,
next_dones
,
params
,
rstate1
,
rstate2
,
obs
,
dones
,
next_dones
,
switch_or_mains
,
actions
,
logits
,
rewards
,
mask
,
next_value
):
switch_or_mains
,
actions
,
logits
,
rewards
,
mask
,
next_value
):
...
@@ -671,7 +667,9 @@ if __name__ == "__main__":
...
@@ -671,7 +667,9 @@ if __name__ == "__main__":
dones
=
dones
|
next_dones
dones
=
dones
|
next_dones
inputs
=
(
rstate1
,
rstate2
,
obs
,
dones
,
switch_or_mains
)
inputs
=
(
rstate1
,
rstate2
,
obs
,
dones
,
switch_or_mains
)
new_logits
,
new_values
=
get_logits_and_value
(
params
,
inputs
)
_rstate
,
new_logits
,
new_values
,
_valid
=
create_agent
(
args
,
multi_step
=
True
)
.
apply
(
params
,
inputs
)
new_values
=
new_values
.
squeeze
(
-
1
)
ratios
=
distrax
.
importance_sampling_ratios
(
distrax
.
Categorical
(
ratios
=
distrax
.
importance_sampling_ratios
(
distrax
.
Categorical
(
new_logits
),
distrax
.
Categorical
(
logits
),
actions
)
new_logits
),
distrax
.
Categorical
(
logits
),
actions
)
...
...
scripts/eval.py
View file @
ec9f3e0c
...
@@ -63,6 +63,8 @@ class Args:
...
@@ -63,6 +63,8 @@ class Args:
"""the number of channels for the agent"""
"""the number of channels for the agent"""
rnn_channels
:
Optional
[
int
]
=
512
rnn_channels
:
Optional
[
int
]
=
512
"""the number of rnn channels for the agent"""
"""the number of rnn channels for the agent"""
rnn_type
:
Optional
[
str
]
=
"lstm"
"""the type of RNN to use for agent, None for no RNN"""
checkpoint
:
Optional
[
str
]
=
None
checkpoint
:
Optional
[
str
]
=
None
"""the checkpoint to load, must be a `flax_model` file"""
"""the checkpoint to load, must be a `flax_model` file"""
...
@@ -75,18 +77,12 @@ class Args:
...
@@ -75,18 +77,12 @@ class Args:
def
create_agent
(
args
):
def
create_agent
(
args
):
return
PPOLSTM
Agent
(
return
RNN
Agent
(
channels
=
args
.
num_channels
,
channels
=
args
.
num_channels
,
num_layers
=
args
.
num_layers
,
num_layers
=
args
.
num_layers
,
lstm
_channels
=
args
.
rnn_channels
,
rnn
_channels
=
args
.
rnn_channels
,
embedding_shape
=
args
.
num_embeddings
,
embedding_shape
=
args
.
num_embeddings
,
)
rnn_type
=
args
.
rnn_type
,
def
init_rnn_state
(
num_envs
,
rnn_channels
):
return
(
np
.
zeros
((
num_envs
,
rnn_channels
)),
np
.
zeros
((
num_envs
,
rnn_channels
)),
)
)
...
@@ -139,7 +135,7 @@ if __name__ == "__main__":
...
@@ -139,7 +135,7 @@ if __name__ == "__main__":
import
jax
import
jax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
import
flax
import
flax
from
ygoai.rl.jax.agent2
import
PPOLSTM
Agent
from
ygoai.rl.jax.agent2
import
RNN
Agent
from
jax.experimental.compilation_cache
import
compilation_cache
as
cc
from
jax.experimental.compilation_cache
import
compilation_cache
as
cc
cc
.
set_cache_dir
(
os
.
path
.
expanduser
(
"~/.cache/jax"
))
cc
.
set_cache_dir
(
os
.
path
.
expanduser
(
"~/.cache/jax"
))
...
@@ -148,7 +144,7 @@ if __name__ == "__main__":
...
@@ -148,7 +144,7 @@ if __name__ == "__main__":
key
,
agent_key
=
jax
.
random
.
split
(
key
,
2
)
key
,
agent_key
=
jax
.
random
.
split
(
key
,
2
)
sample_obs
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
array
([
x
]),
obs_space
.
sample
())
sample_obs
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
array
([
x
]),
obs_space
.
sample
())
rstate
=
init_rnn_state
(
1
,
args
.
rnn_channels
)
rstate
=
agent
.
init_rnn_state
(
1
)
params
=
jax
.
jit
(
agent
.
init
)(
agent_key
,
(
rstate
,
sample_obs
))
params
=
jax
.
jit
(
agent
.
init
)(
agent_key
,
(
rstate
,
sample_obs
))
with
open
(
args
.
checkpoint
,
"rb"
)
as
f
:
with
open
(
args
.
checkpoint
,
"rb"
)
as
f
:
...
@@ -158,7 +154,7 @@ if __name__ == "__main__":
...
@@ -158,7 +154,7 @@ if __name__ == "__main__":
@
jax
.
jit
@
jax
.
jit
def
get_probs_and_value
(
params
,
rstate
,
obs
,
done
):
def
get_probs_and_value
(
params
,
rstate
,
obs
,
done
):
agent
=
create_agent
(
args
)
agent
=
agent
next_rstate
,
logits
,
value
=
agent
.
apply
(
params
,
(
rstate
,
obs
))[:
3
]
next_rstate
,
logits
,
value
=
agent
.
apply
(
params
,
(
rstate
,
obs
))[:
3
]
probs
=
jax
.
nn
.
softmax
(
logits
,
axis
=-
1
)
probs
=
jax
.
nn
.
softmax
(
logits
,
axis
=-
1
)
next_rstate
=
jax
.
tree
.
map
(
next_rstate
=
jax
.
tree
.
map
(
...
@@ -173,6 +169,7 @@ if __name__ == "__main__":
...
@@ -173,6 +169,7 @@ if __name__ == "__main__":
obs
,
infos
=
envs
.
reset
()
obs
,
infos
=
envs
.
reset
()
print
(
obs
)
next_to_play
=
infos
[
'to_play'
]
next_to_play
=
infos
[
'to_play'
]
dones
=
np
.
zeros
(
num_envs
,
dtype
=
np
.
bool_
)
dones
=
np
.
zeros
(
num_envs
,
dtype
=
np
.
bool_
)
...
...
ygoai/rl/jax/agent2.py
View file @
ec9f3e0c
from
typing
import
Tuple
,
Union
,
Optional
,
Sequence
from
typing
import
Tuple
,
Union
,
Optional
,
Sequence
from
functools
import
partial
from
functools
import
partial
import
numpy
as
np
import
jax
import
jax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
import
flax.linen
as
nn
import
flax.linen
as
nn
...
@@ -272,8 +273,6 @@ class Encoder(nn.Module):
...
@@ -272,8 +273,6 @@ class Encoder(nn.Module):
f_state
=
jnp
.
concatenate
([
f_g_card
,
f_global
,
f_g_h_actions
,
f_g_actions
],
axis
=-
1
)
f_state
=
jnp
.
concatenate
([
f_g_card
,
f_global
,
f_g_h_actions
,
f_g_actions
],
axis
=-
1
)
f_state
=
MLP
((
c
*
2
,
c
),
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)(
f_state
)
f_state
=
MLP
((
c
*
2
,
c
),
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)(
f_state
)
f_state
=
layer_norm
(
dtype
=
self
.
dtype
)(
f_state
)
f_state
=
layer_norm
(
dtype
=
self
.
dtype
)(
f_state
)
# TODO: LSTM
return
f_actions
,
f_state
,
a_mask
,
valid
return
f_actions
,
f_state
,
a_mask
,
valid
...
@@ -309,10 +308,34 @@ class Critic(nn.Module):
...
@@ -309,10 +308,34 @@ class Critic(nn.Module):
return
x
return
x
class
LSTMAgent
(
nn
.
Module
):
def
rnn_forward_2p
(
rnn_layer
,
rstate1
,
rstate2
,
f_state
,
done
,
switch_or_main
,
switch
=
True
):
if
switch
:
def
body_fn
(
cell
,
carry
,
x
,
done
,
switch
):
rstate
,
init_rstate2
=
carry
rstate
,
y
=
cell
(
rstate
,
x
)
rstate
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
where
(
done
[:,
None
],
0
,
x
),
rstate
)
rstate
=
jax
.
tree
.
map
(
lambda
x
,
y
:
jnp
.
where
(
switch
[:,
None
],
x
,
y
),
init_rstate2
,
rstate
)
return
(
rstate
,
init_rstate2
),
y
else
:
def
body_fn
(
cell
,
carry
,
x
,
done
,
main
):
rstate1
,
rstate2
=
carry
rstate
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x1
,
x2
),
rstate1
,
rstate2
)
rstate
,
y
=
cell
(
rstate
,
x
)
rstate1
=
jax
.
tree
.
map
(
lambda
x
,
y
:
jnp
.
where
(
main
[:,
None
],
x
,
y
),
rstate
,
rstate1
)
rstate2
=
jax
.
tree
.
map
(
lambda
x
,
y
:
jnp
.
where
(
main
[:,
None
],
y
,
x
),
rstate
,
rstate2
)
rstate1
,
rstate2
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
where
(
done
[:,
None
],
0
,
x
),
(
rstate1
,
rstate2
))
return
(
rstate1
,
rstate2
),
y
scan
=
nn
.
scan
(
body_fn
,
variable_broadcast
=
'params'
,
split_rngs
=
{
'params'
:
False
})
rstate
,
f_state
=
scan
(
rnn_layer
,
(
rstate1
,
rstate2
),
f_state
,
done
,
switch_or_main
)
return
rstate
,
f_state
class
RNNAgent
(
nn
.
Module
):
channels
:
int
=
128
channels
:
int
=
128
num_layers
:
int
=
2
num_layers
:
int
=
2
lstm
_channels
:
int
=
512
rnn
_channels
:
int
=
512
embedding_shape
:
Optional
[
Union
[
int
,
Tuple
[
int
,
int
]]]
=
None
embedding_shape
:
Optional
[
Union
[
int
,
Tuple
[
int
,
int
]]]
=
None
dtype
:
jnp
.
dtype
=
jnp
.
float32
dtype
:
jnp
.
dtype
=
jnp
.
float32
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
param_dtype
:
jnp
.
dtype
=
jnp
.
float32
...
@@ -320,15 +343,13 @@ class LSTMAgent(nn.Module):
...
@@ -320,15 +343,13 @@ class LSTMAgent(nn.Module):
switch
:
bool
=
True
switch
:
bool
=
True
freeze_id
:
bool
=
False
freeze_id
:
bool
=
False
use_history
:
bool
=
True
use_history
:
bool
=
True
no_rnn
:
bool
=
False
rnn_type
:
str
=
'lstm'
@
nn
.
compact
@
nn
.
compact
def
__call__
(
self
,
inputs
):
def
__call__
(
self
,
inputs
):
if
self
.
multi_step
:
if
self
.
multi_step
:
# (num_steps * batch_size, ...)
# (num_steps * batch_size, ...)
rstate1
,
rstate2
,
x
,
done
,
switch_or_main
=
inputs
*
rstate
,
x
,
done
,
switch_or_main
=
inputs
batch_size
=
rstate1
[
0
]
.
shape
[
0
]
num_steps
=
done
.
shape
[
0
]
//
batch_size
else
:
else
:
rstate
,
x
=
inputs
rstate
,
x
=
inputs
...
@@ -345,43 +366,48 @@ class LSTMAgent(nn.Module):
...
@@ -345,43 +366,48 @@ class LSTMAgent(nn.Module):
f_actions
,
f_state
,
mask
,
valid
=
encoder
(
x
)
f_actions
,
f_state
,
mask
,
valid
=
encoder
(
x
)
lstm_layer
=
nn
.
OptimizedLSTMCell
(
if
self
.
rnn_type
in
[
'lstm'
,
'none'
]:
self
.
lstm_channels
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
,
kernel_init
=
nn
.
initializers
.
orthogonal
(
1.0
))
rnn_layer
=
nn
.
OptimizedLSTMCell
(
if
self
.
multi_step
:
self
.
rnn_channels
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
,
kernel_init
=
nn
.
initializers
.
orthogonal
(
1.0
))
if
self
.
switch
:
elif
self
.
rnn_type
==
'gru'
:
def
body_fn
(
cell
,
carry
,
x
,
done
,
switch
):
rnn_layer
=
nn
.
GRUCell
(
rstate
,
init_rstate2
=
carry
self
.
rnn_channels
,
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
,
kernel_init
=
nn
.
initializers
.
orthogonal
(
1.0
))
rstate
,
y
=
cell
(
rstate
,
x
)
elif
self
.
rnn_type
is
None
:
rstate
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
where
(
done
[:,
None
],
0
,
x
),
rstate
)
rnn_layer
=
None
rstate
=
jax
.
tree
.
map
(
lambda
x
,
y
:
jnp
.
where
(
switch
[:,
None
],
x
,
y
),
init_rstate2
,
rstate
)
return
(
rstate
,
init_rstate2
),
y
if
rnn_layer
is
None
:
else
:
f_state_r
=
f_state
def
body_fn
(
cell
,
carry
,
x
,
done
,
main
):
elif
self
.
rnn_type
==
'none'
:
rstate1
,
rstate2
=
carry
f_state_r
=
jnp
.
concatenate
([
f_state
for
i
in
range
(
self
.
rnn_channels
//
c
)],
axis
=-
1
)
rstate
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x1
,
x2
),
rstate1
,
rstate2
)
rstate
,
y
=
cell
(
rstate
,
x
)
rstate1
=
jax
.
tree
.
map
(
lambda
x
,
y
:
jnp
.
where
(
main
[:,
None
],
x
,
y
),
rstate
,
rstate1
)
rstate2
=
jax
.
tree
.
map
(
lambda
x
,
y
:
jnp
.
where
(
main
[:,
None
],
y
,
x
),
rstate
,
rstate2
)
rstate1
,
rstate2
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
where
(
done
[:,
None
],
0
,
x
),
(
rstate1
,
rstate2
))
return
(
rstate1
,
rstate2
),
y
scan
=
nn
.
scan
(
body_fn
,
variable_broadcast
=
'params'
,
split_rngs
=
{
'params'
:
False
})
f_state_r
,
done
,
switch_or_main
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
reshape
(
x
,
(
num_steps
,
batch_size
)
+
x
.
shape
[
1
:]),
(
f_state
,
done
,
switch_or_main
))
rstate
,
f_state_r
=
scan
(
lstm_layer
,
(
rstate1
,
rstate2
),
f_state_r
,
done
,
switch_or_main
)
f_state_r
=
f_state_r
.
reshape
((
-
1
,
f_state_r
.
shape
[
-
1
]))
else
:
else
:
rstate
,
f_state_r
=
lstm_layer
(
rstate
,
f_state
)
if
self
.
multi_step
:
rstate1
,
rstate2
=
rstate
batch_size
=
jax
.
tree
.
leaves
(
rstate1
)[
0
]
.
shape
[
0
]
num_steps
=
done
.
shape
[
0
]
//
batch_size
f_state_r
,
done
,
switch_or_main
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
reshape
(
x
,
(
num_steps
,
batch_size
)
+
x
.
shape
[
1
:]),
(
f_state
,
done
,
switch_or_main
))
rstate
,
f_state_r
=
rnn_forward_2p
(
rnn_layer
,
rstate1
,
rstate2
,
f_state_r
,
done
,
switch_or_main
,
self
.
switch
)
f_state_r
=
f_state_r
.
reshape
((
-
1
,
f_state_r
.
shape
[
-
1
]))
else
:
rstate
,
f_state_r
=
rnn_layer
(
rstate
,
f_state
)
actor
=
Actor
(
actor
=
Actor
(
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)
critic
=
Critic
(
critic
=
Critic
(
channels
=
[
c
,
c
,
c
],
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
channels
=
[
c
,
c
,
c
],
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
if
self
.
no_rnn
:
f_state_r
=
jnp
.
concatenate
([
f_state
for
i
in
range
(
self
.
lstm_channels
//
c
)],
axis
=-
1
)
logits
=
actor
(
f_state_r
,
f_actions
,
mask
)
logits
=
actor
(
f_state_r
,
f_actions
,
mask
)
value
=
critic
(
f_state_r
)
value
=
critic
(
f_state_r
)
return
rstate
,
logits
,
value
,
valid
return
rstate
,
logits
,
value
,
valid
def
init_rnn_state
(
self
,
batch_size
):
if
self
.
rnn_type
in
[
'lstm'
,
'none'
]:
return
(
np
.
zeros
((
batch_size
,
self
.
rnn_channels
)),
np
.
zeros
((
batch_size
,
self
.
rnn_channels
)),
)
elif
self
.
rnn_type
==
'gru'
:
return
np
.
zeros
((
batch_size
,
self
.
rnn_channels
))
else
:
return
None
\ No newline at end of file
ygoai/rl/jax/eval.py
View file @
ec9f3e0c
...
@@ -36,7 +36,7 @@ def evaluate(envs, num_episodes, predict_fn, rnn_state=None):
...
@@ -36,7 +36,7 @@ def evaluate(envs, num_episodes, predict_fn, rnn_state=None):
return
eval_return
,
eval_ep_len
,
eval_win_rate
return
eval_return
,
eval_ep_len
,
eval_win_rate
def
battle
(
envs
,
num_episodes
,
predict_fn
,
init_rnn_state
=
None
):
def
battle
(
envs
,
num_episodes
,
predict_fn
,
rstate1
=
None
,
rstate2
=
None
):
num_envs
=
envs
.
num_envs
num_envs
=
envs
.
num_envs
episode_rewards
=
[]
episode_rewards
=
[]
episode_lengths
=
[]
episode_lengths
=
[]
...
@@ -50,7 +50,6 @@ def battle(envs, num_episodes, predict_fn, init_rnn_state=None):
...
@@ -50,7 +50,6 @@ def battle(envs, num_episodes, predict_fn, init_rnn_state=None):
np
.
zeros
(
num_envs
//
2
,
dtype
=
np
.
int64
),
np
.
zeros
(
num_envs
//
2
,
dtype
=
np
.
int64
),
np
.
ones
(
num_envs
-
num_envs
//
2
,
dtype
=
np
.
int64
)
np
.
ones
(
num_envs
-
num_envs
//
2
,
dtype
=
np
.
int64
)
])
])
rstate1
=
rstate2
=
init_rnn_state
while
True
:
while
True
:
main
=
next_to_play
==
main_player
main
=
next_to_play
==
main_player
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment