Biluo Shen / ygo-agent / Commits / ec9f3e0c

Commit ec9f3e0c, authored May 12, 2024 by sbl1996@126.com
support different rnn
Parent: b7d52f29
Showing 5 changed files with 150 additions and 108 deletions:

  scripts/battle.py        +44  -22
  scripts/cleanba.py       +31  -33
  scripts/eval.py          +9   -12
  ygoai/rl/jax/agent2.py   +65  -39
  ygoai/rl/jax/eval.py     +1   -2
scripts/battle.py
@@ -2,9 +2,10 @@ import sys
 import time
 import os
 import random
-from typing import Optional, Literal
+from typing import Optional
 from dataclasses import dataclass
 from tqdm import tqdm
+from functools import partial

 import ygoenv
 import numpy as np
@@ -17,7 +18,7 @@ import flax
 from ygoai.utils import init_ygopro
 from ygoai.rl.utils import RecordEpisodeStatistics
-from ygoai.rl.jax.agent2 import PPOLSTMAgent
+from ygoai.rl.jax.agent2 import RNNAgent


 @dataclass
@@ -43,6 +44,10 @@ class Args:
     """the number of history actions to use"""
     num_embeddings: Optional[int] = None
     """the number of embeddings of the agent"""
+    use_history1: bool = True
+    """whether to use history actions as input for agent1"""
+    use_history2: bool = True
+    """whether to use history actions as input for agent2"""

     verbose: bool = False
     """whether to print debug information"""
@@ -60,6 +65,10 @@ class Args:
     """the number of channels for the agent"""
     rnn_channels: Optional[int] = 512
     """the number of rnn channels for the agent"""
+    rnn_type1: Optional[str] = "lstm"
+    """the type of RNN to use for agent1, None for no RNN"""
+    rnn_type2: Optional[str] = "lstm"
+    """the type of RNN to use for agent2, None for no RNN"""

     checkpoint1: str = "checkpoints/agent.pt"
     """the checkpoint to load for the first agent, must be a `flax_model` file"""
     checkpoint2: str = "checkpoints/agent.pt"
@@ -72,19 +81,25 @@ class Args:
     """the number of threads to use for envpool, defaults to `num_envs`"""


-def create_agent(args):
-    return PPOLSTMAgent(
+def create_agent1(args):
+    return RNNAgent(
         channels=args.num_channels,
         num_layers=args.num_layers,
-        lstm_channels=args.rnn_channels,
+        rnn_channels=args.rnn_channels,
         embedding_shape=args.num_embeddings,
+        use_history=args.use_history1,
+        rnn_type=args.rnn_type1,
     )


-def init_rnn_state(num_envs, rnn_channels):
-    return (
-        np.zeros((num_envs, rnn_channels)),
-        np.zeros((num_envs, rnn_channels)),
-    )
+def create_agent2(args):
+    return RNNAgent(
+        channels=args.num_channels,
+        num_layers=args.num_layers,
+        rnn_channels=args.rnn_channels,
+        embedding_shape=args.num_embeddings,
+        use_history=args.use_history2,
+        rnn_type=args.rnn_type2,
+    )
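The two builders only differ in which per-agent fields they read (use_history1/rnn_type1 vs. use_history2/rnn_type2), so a cross-architecture match such as LSTM vs. GRU is just a configuration change. A minimal sketch of what they construct, assuming the ygoai package is importable; the concrete numbers are illustrative, not values taken from the commit:

    from ygoai.rl.jax.agent2 import RNNAgent

    # Mirrors create_agent1 / create_agent2 above with hand-picked settings.
    agent1 = RNNAgent(channels=128, num_layers=2, rnn_channels=512,
                      embedding_shape=None, use_history=True, rnn_type="lstm")
    agent2 = RNNAgent(channels=128, num_layers=2, rnn_channels=512,
                      embedding_shape=None, use_history=True, rnn_type="gru")

    rstate1 = agent1.init_rnn_state(1)  # LSTM: a (cell, hidden) tuple of zeros
    rstate2 = agent2.init_rnn_state(1)  # GRU: a single zeros array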
@@ -137,28 +152,34 @@ if __name__ == "__main__":
     envs.num_envs = num_envs
     envs = RecordEpisodeStatistics(envs)

-    agent = create_agent(args)
     key = jax.random.PRNGKey(args.seed)
     key, agent_key = jax.random.split(key, 2)
     sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
-    rstate = init_rnn_state(1, args.rnn_channels)
-    params = jax.jit(agent.init)(agent_key, (rstate, sample_obs))
+
+    agent1 = create_agent1(args)
+    rstate = agent1.init_rnn_state(1)
+    params1 = jax.jit(agent1.init)(agent_key, (rstate, sample_obs))
     with open(args.checkpoint1, "rb") as f:
-        params1 = flax.serialization.from_bytes(params, f.read())
+        params1 = flax.serialization.from_bytes(params1, f.read())

     if args.checkpoint1 == args.checkpoint2:
         params2 = params1
     else:
+        agent2 = create_agent2(args)
+        rstate = agent2.init_rnn_state(1)
+        params2 = jax.jit(agent2.init)(agent_key, (rstate, sample_obs))
         with open(args.checkpoint2, "rb") as f:
-            params2 = flax.serialization.from_bytes(params, f.read())
+            params2 = flax.serialization.from_bytes(params2, f.read())

     params1 = jax.device_put(params1)
     params2 = jax.device_put(params2)

-    @jax.jit
-    def get_probs(params, rstate, obs, done=None):
-        agent = create_agent(args)
+    @partial(jax.jit, static_argnums=(4,))
+    def get_probs(params, rstate, obs, done=None, model_id=1):
+        if model_id == 1:
+            agent = create_agent1(args)
+        else:
+            agent = create_agent2(args)
         next_rstate, logits = agent.apply(params, (rstate, obs))[:2]
         probs = jax.nn.softmax(logits, axis=-1)
         if done is not None:
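model_id drives Python-level branching between two possibly different architectures, so it is marked static: jax.jit traces and caches a separate computation for each concrete value. A small self-contained illustration of that mechanism (the function here is hypothetical, not from the repo):

    from functools import partial

    import jax
    import jax.numpy as jnp

    @partial(jax.jit, static_argnums=(1,))
    def scale(x, model_id=1):
        # A plain Python `if` is only allowed because model_id is static.
        if model_id == 1:
            return x * 2.0
        return x * 3.0

    x = jnp.ones(4)
    print(scale(x, 1))  # traces and caches a version specialized to model_id=1
    print(scale(x, 2))  # retraces once for model_id=2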
@@ -168,8 +189,8 @@ if __name__ == "__main__":
     if args.num_envs != 1:
         @jax.jit
         def get_probs2(params1, params2, rstate1, rstate2, obs, main, done):
-            next_rstate1, probs1 = get_probs(params1, rstate1, obs)
-            next_rstate2, probs2 = get_probs(params2, rstate2, obs)
+            next_rstate1, probs1 = get_probs(params1, rstate1, obs, None, 1)
+            next_rstate2, probs2 = get_probs(params2, rstate2, obs, None, 2)
             probs = jnp.where(main[:, None], probs1, probs2)
             rstate1 = jax.tree.map(
                 lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
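Here, and inside the agent itself, the per-environment choice between the two players is a masked select rather than a branch: jnp.where picks row-wise between the two candidates, and jax.tree.map applies the same select to every leaf of a recurrent state. A small self-contained sketch:

    import jax
    import jax.numpy as jnp

    main = jnp.array([True, False, True])             # which envs player 1 acts in
    probs1 = jnp.full((3, 4), 0.1)
    probs2 = jnp.full((3, 4), 0.9)
    probs = jnp.where(main[:, None], probs1, probs2)  # row-wise select

    # The same select applied leaf-by-leaf to an LSTM state (a tuple of arrays).
    rstate1 = (jnp.zeros((3, 2)), jnp.zeros((3, 2)))
    rstate2 = (jnp.ones((3, 2)), jnp.ones((3, 2)))
    merged = jax.tree.map(lambda a, b: jnp.where(main[:, None], a, b), rstate1, rstate2)
    print(probs[:, 0], jax.tree.map(jnp.shape, merged))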
@@ -185,9 +206,9 @@ if __name__ == "__main__":
     else:
         def predict_fn(rstate1, rstate2, obs, main, done):
             if main[0]:
-                rstate1, probs = get_probs(params1, rstate1, obs, done)
+                rstate1, probs = get_probs(params1, rstate1, obs, done, 1)
             else:
-                rstate2, probs = get_probs(params2, rstate2, obs, done)
+                rstate2, probs = get_probs(params2, rstate2, obs, done, 2)
             return rstate1, rstate2, np.array(probs)

     obs, infos = envs.reset()
@@ -209,7 +230,8 @@ if __name__ == "__main__":
         np.zeros(num_envs // 2, dtype=np.int64),
         np.ones(num_envs - num_envs // 2, dtype=np.int64)
     ])
-    rstate1 = rstate2 = init_rnn_state(num_envs, args.rnn_channels)
+    rstate1 = agent1.init_rnn_state(num_envs)
+    rstate2 = agent2.init_rnn_state(num_envs)

     if not args.verbose:
         pbar = tqdm(total=args.num_episodes)
scripts/cleanba.py
@@ -25,7 +25,7 @@ from tensorboardX import SummaryWriter
 from ygoai.utils import init_ygopro, load_embeddings
 from ygoai.rl.ckpt import ModelCheckpoint, sync_to_gcs, zip_files
-from ygoai.rl.jax.agent2 import LSTMAgent
+from ygoai.rl.jax.agent2 import RNNAgent
 from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
 from ygoai.rl.jax.eval import evaluate, battle
 from ygoai.rl.jax import clipped_surrogate_pg_loss, vtrace_2p0s, mse_loss, entropy_loss, simple_policy_loss, ach_loss, policy_gradient_loss
@@ -80,8 +80,6 @@ class Args:
     """whether to use history actions as input for agent"""
     eval_use_history: bool = True
     """whether to use history actions as input for eval agent"""
-    use_rnn: bool = True
-    """whether to use RNN for the agent"""

     total_timesteps: int = 50000000000
     """total timesteps of the experiments"""
@@ -150,6 +148,10 @@ class Args:
     """the number of channels for the agent"""
     rnn_channels: int = 512
     """the number of channels for the RNN in the agent"""
+    rnn_type: Optional[str] = "lstm"
+    """the type of RNN to use, None for no RNN"""
+    eval_rnn_type: Optional[str] = "lstm"
+    """the type of RNN to use for evaluation, None for no RNN"""

     actor_device_ids: List[int] = field(default_factory=lambda: [0, 1])
     """the device ids that actor workers will use"""
@@ -222,18 +224,18 @@ class Transition(NamedTuple):

 def create_agent(args, multi_step=False, eval=False):
-    return LSTMAgent(
+    return RNNAgent(
         channels=args.num_channels,
         num_layers=args.num_layers,
         embedding_shape=args.num_embeddings,
         dtype=jnp.bfloat16 if args.bfloat16 else jnp.float32,
         param_dtype=jnp.float32,
-        lstm_channels=args.rnn_channels,
+        rnn_channels=args.rnn_channels,
         switch=args.switch,
         multi_step=multi_step,
         freeze_id=args.freeze_id,
         use_history=args.use_history if not eval else args.eval_use_history,
-        no_rnn=(not args.use_rnn) if not eval else False,
+        rnn_type=args.rnn_type if not eval else args.eval_rnn_type,
     )
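With this change the old use_rnn / no_rnn switches are folded into the single rnn_type option (plus eval_rnn_type for the frozen evaluation opponent). A hedged quick reference, inferred from the RNNAgent changes in ygoai/rl/jax/agent2.py further down:

    # Accepted rnn_type values and what they select inside RNNAgent:
    RNN_CORES = {
        "lstm": "nn.OptimizedLSTMCell over the encoder state",
        "gru": "nn.GRUCell over the encoder state",
        "none": "no recurrence, but the state/width stays LSTM-compatible",
        None: "no recurrence and no recurrent state (replaces use_rnn=False)",
    }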
@@ -284,23 +286,20 @@ def rollout(
     other_time = 0
     avg_ep_returns = deque(maxlen=1000)
     avg_win_rates = deque(maxlen=1000)

-    @partial(jax.jit, static_argnums=(2,))
-    def get_logits(
-        params: flax.core.FrozenDict, inputs, eval=False):
-        rstate, logits = create_agent(args, eval=eval).apply(params, inputs)[:2]
-        return rstate, logits
+    agent = create_agent(args)
+    eval_agent = create_agent(args, eval=True)

     @jax.jit
     def get_action(
         params: flax.core.FrozenDict, inputs):
-        rstate, logits = get_logits(params, inputs)
+        rstate, logits = eval_agent.apply(params, inputs)[:2]
         return rstate, logits.argmax(axis=1)

     @jax.jit
     def get_action_battle(params1, params2, rstate1, rstate2, obs, main, done):
-        next_rstate1, logits1 = get_logits(params1, (rstate1, obs))
-        next_rstate2, logits2 = get_logits(params2, (rstate2, obs), True)
+        next_rstate1, logits1 = agent.apply(params1, (rstate1, obs))[:2]
+        next_rstate2, logits2 = eval_agent.apply(params2, (rstate2, obs))[:2]
         logits = jnp.where(main[:, None], logits1, logits2)
         rstate1 = jax.tree.map(
             lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
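Building agent and eval_agent once and applying them inside the jitted functions works because a Flax module is a stateless description: the parameters live in the pytree passed to apply, so the module can safely be closed over. A minimal self-contained sketch of the pattern:

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    model = nn.Dense(3)                       # built once, like `agent` / `eval_agent`
    params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 2)))

    @jax.jit
    def forward(params, x):
        return model.apply(params, x)         # module closed over; params passed in

    print(forward(params, jnp.ones((1, 2))))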
@@ -320,7 +319,7 @@ def rollout(
         rstate = jax.tree.map(
             lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
-        rstate, logits = get_logits(params, (rstate, next_obs))
+        rstate, logits = agent.apply(params, (rstate, next_obs))[:2]
         rstate1 = jax.tree.map(
             lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate, rstate1)
         rstate2 = jax.tree.map(
             lambda x1, x2: jnp.where(main[:, None], x2, x1), rstate, rstate2)
         rstate1, rstate2 = jax.tree.map(
@@ -335,10 +334,11 @@ def rollout(
     next_obs, info = envs.reset()
     next_to_play = info["to_play"]
     next_done = np.zeros(args.local_num_envs, dtype=np.bool_)
-    next_rstate1 = next_rstate2 = init_rnn_state(args.local_num_envs, args.rnn_channels)
-    eval_rstate = init_rnn_state(args.local_eval_episodes, args.rnn_channels)
+    next_rstate1 = next_rstate2 = agent.init_rnn_state(args.local_num_envs)
+    eval_rstate1 = agent.init_rnn_state(args.local_eval_episodes)
+    eval_rstate2 = eval_agent.init_rnn_state(args.local_eval_episodes)
     main_player = np.concatenate([
         np.zeros(args.local_num_envs // 2, dtype=np.int64),
         np.ones(args.local_num_envs // 2, dtype=np.int64)
@@ -452,11 +452,11 @@ def rollout(
                 if eval_mode == 'bot':
                     predict_fn = lambda x: get_action(params, x)
                     eval_return, eval_ep_len, eval_win_rate = evaluate(
-                        eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)
+                        eval_envs, args.local_eval_episodes, predict_fn, eval_rstate2)
                 else:
                     predict_fn = lambda *x: get_action_battle(params, eval_params, *x)
                     eval_return, eval_ep_len, eval_win_rate = battle(
-                        eval_envs, args.local_eval_episodes, predict_fn, eval_rstate)
+                        eval_envs, args.local_eval_episodes, predict_fn, eval_rstate1, eval_rstate2)

                 eval_time = time.time() - _start
                 other_time += eval_time
                 eval_stats = np.array([eval_time, eval_return, eval_win_rate], dtype=np.float32)
@@ -606,8 +606,9 @@ if __name__ == "__main__":
        frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates
        return args.learning_rate * frac

-    rstate = init_rnn_state(1, args.rnn_channels)
+    # rstate = init_rnn_state(1, args.rnn_channels)
     agent = create_agent(args)
+    rstate = agent.init_rnn_state(1)
     params = agent.init(init_key, (rstate, sample_obs))
     if embeddings is not None:
         unknown_embed = embeddings.mean(axis=0)
@@ -641,20 +642,15 @@ if __name__ == "__main__":
     # print(agent.tabulate(agent_key, sample_obs))

     if args.eval_checkpoint:
+        eval_agent = create_agent(args, eval=True)
+        eval_rstate = eval_agent.init_rnn_state(1)
+        eval_params = eval_agent.init(init_key, (eval_rstate, sample_obs))
         with open(args.eval_checkpoint, "rb") as f:
-            eval_params = flax.serialization.from_bytes(params, f.read())
+            eval_params = flax.serialization.from_bytes(eval_params, f.read())
         print(f"loaded eval checkpoint from {args.eval_checkpoint}")
     else:
         eval_params = None

-    @jax.jit
-    def get_logits_and_value(
-        params: flax.core.FrozenDict, inputs,
-    ):
-        rstate, logits, value, valid = create_agent(
-            args, multi_step=True).apply(params, inputs)
-        return logits, value.squeeze(-1)
-
     def loss_fn(
         params, rstate1, rstate2, obs, dones, next_dones,
         switch_or_mains, actions, logits, rewards, mask, next_value):
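flax.serialization.from_bytes restores into whatever target pytree it is given, which is why the target is switched from the shared params to eval_params here (and to params1 / params2 in scripts/battle.py): each checkpoint must be restored into parameters initialized by the matching agent, or the structures will not line up. A small self-contained round trip showing the mechanism:

    import jax
    import jax.numpy as jnp
    import flax
    import flax.linen as nn

    model = nn.Dense(3)
    params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 2)))

    blob = flax.serialization.to_bytes(params)              # what a checkpoint file holds
    restored = flax.serialization.from_bytes(params, blob)  # the target defines the structure
    print(jax.tree.map(jnp.shape, restored))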
@@ -671,7 +667,9 @@ if __name__ == "__main__":
        dones = dones | next_dones

        inputs = (rstate1, rstate2, obs, dones, switch_or_mains)
-        new_logits, new_values = get_logits_and_value(params, inputs)
+        _rstate, new_logits, new_values, _valid = create_agent(
+            args, multi_step=True).apply(params, inputs)
+        new_values = new_values.squeeze(-1)

        ratios = distrax.importance_sampling_ratios(distrax.Categorical(
            new_logits), distrax.Categorical(logits), actions)
scripts/eval.py
@@ -63,6 +63,8 @@ class Args:
     """the number of channels for the agent"""
     rnn_channels: Optional[int] = 512
     """the number of rnn channels for the agent"""
+    rnn_type: Optional[str] = "lstm"
+    """the type of RNN to use for agent, None for no RNN"""

     checkpoint: Optional[str] = None
     """the checkpoint to load, must be a `flax_model` file"""
@@ -75,18 +77,12 @@ class Args:

 def create_agent(args):
-    return PPOLSTMAgent(
+    return RNNAgent(
         channels=args.num_channels,
         num_layers=args.num_layers,
-        lstm_channels=args.rnn_channels,
+        rnn_channels=args.rnn_channels,
         embedding_shape=args.num_embeddings,
+        rnn_type=args.rnn_type,
     )


-def init_rnn_state(num_envs, rnn_channels):
-    return (
-        np.zeros((num_envs, rnn_channels)),
-        np.zeros((num_envs, rnn_channels)),
-    )
@@ -139,7 +135,7 @@ if __name__ == "__main__":
     import jax
     import jax.numpy as jnp
     import flax
-    from ygoai.rl.jax.agent2 import PPOLSTMAgent
+    from ygoai.rl.jax.agent2 import RNNAgent
     from jax.experimental.compilation_cache import compilation_cache as cc
     cc.set_cache_dir(os.path.expanduser("~/.cache/jax"))
@@ -148,7 +144,7 @@ if __name__ == "__main__":
     key, agent_key = jax.random.split(key, 2)
     sample_obs = jax.tree.map(lambda x: jnp.array([x]), obs_space.sample())
-    rstate = init_rnn_state(1, args.rnn_channels)
+    rstate = agent.init_rnn_state(1)
     params = jax.jit(agent.init)(agent_key, (rstate, sample_obs))

     with open(args.checkpoint, "rb") as f:
@@ -158,7 +154,7 @@ if __name__ == "__main__":
     @jax.jit
     def get_probs_and_value(params, rstate, obs, done):
-        agent = create_agent(args)
+        agent = agent
         next_rstate, logits, value = agent.apply(params, (rstate, obs))[:3]
         probs = jax.nn.softmax(logits, axis=-1)
         next_rstate = jax.tree.map(
@@ -173,6 +169,7 @@ if __name__ == "__main__":
     obs, infos = envs.reset()
+    print(obs)
     next_to_play = infos['to_play']
     dones = np.zeros(num_envs, dtype=np.bool_)
ygoai/rl/jax/agent2.py
 from typing import Tuple, Union, Optional, Sequence
 from functools import partial

 import numpy as np
 import jax
 import jax.numpy as jnp
 import flax.linen as nn
@@ -272,8 +273,6 @@ class Encoder(nn.Module):
         f_state = jnp.concatenate([f_g_card, f_global, f_g_h_actions, f_g_actions], axis=-1)
         f_state = MLP((c * 2, c), dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
         f_state = layer_norm(dtype=self.dtype)(f_state)
-
-        # TODO: LSTM
         return f_actions, f_state, a_mask, valid
@@ -309,10 +308,34 @@ class Critic(nn.Module):
         return x


-class LSTMAgent(nn.Module):
+def rnn_forward_2p(rnn_layer, rstate1, rstate2, f_state, done, switch_or_main, switch=True):
+    if switch:
+        def body_fn(cell, carry, x, done, switch):
+            rstate, init_rstate2 = carry
+            rstate, y = cell(rstate, x)
+            rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
+            rstate = jax.tree.map(lambda x, y: jnp.where(switch[:, None], x, y), init_rstate2, rstate)
+            return (rstate, init_rstate2), y
+    else:
+        def body_fn(cell, carry, x, done, main):
+            rstate1, rstate2 = carry
+            rstate = jax.tree.map(lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
+            rstate, y = cell(rstate, x)
+            rstate1 = jax.tree.map(lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
+            rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
+            rstate1, rstate2 = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
+            return (rstate1, rstate2), y
+
+    scan = nn.scan(
+        body_fn, variable_broadcast='params',
+        split_rngs={'params': False})
+    rstate, f_state = scan(rnn_layer, (rstate1, rstate2), f_state, done, switch_or_main)
+    return rstate, f_state
+
+
+class RNNAgent(nn.Module):
     channels: int = 128
     num_layers: int = 2

-    lstm_channels: int = 512
+    rnn_channels: int = 512
     embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
     dtype: jnp.dtype = jnp.float32
     param_dtype: jnp.dtype = jnp.float32
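The new rnn_forward_2p helper is the previously in-lined scan factored out: nn.scan runs the cell over a whole rollout segment, zeroing the carry at episode boundaries and routing it between the two players via switch or main. A stripped-down, self-contained sketch of the same nn.scan pattern (single player, GRU cell, reset on done only; all names here are illustrative):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class ScanRNN(nn.Module):
        features: int = 8

        @nn.compact
        def __call__(self, carry, xs, dones):
            cell = nn.GRUCell(self.features)

            def body_fn(cell, carry, x, done):
                carry, y = cell(carry, x)
                # Zero the carry for environments whose episode just ended.
                carry = jax.tree.map(lambda c: jnp.where(done[:, None], 0, c), carry)
                return carry, y

            scan = nn.scan(body_fn, variable_broadcast='params',
                           split_rngs={'params': False})
            return scan(cell, carry, xs, dones)

    num_steps, batch, features = 5, 3, 8
    model = ScanRNN(features)
    carry = jnp.zeros((batch, features))
    xs = jnp.ones((num_steps, batch, features))
    dones = jnp.zeros((num_steps, batch), dtype=bool)
    params = model.init(jax.random.PRNGKey(0), carry, xs, dones)
    carry, ys = model.apply(params, carry, xs, dones)
    print(ys.shape)  # (5, 3, 8)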
@@ -320,15 +343,13 @@ class LSTMAgent(nn.Module):
     switch: bool = True
     freeze_id: bool = False
     use_history: bool = True
-    no_rnn: bool = False
+    rnn_type: str = 'lstm'

     @nn.compact
     def __call__(self, inputs):
         if self.multi_step:
             # (num_steps * batch_size, ...)
-            rstate1, rstate2, x, done, switch_or_main = inputs
-            batch_size = rstate1[0].shape[0]
-            num_steps = done.shape[0] // batch_size
+            *rstate, x, done, switch_or_main = inputs
         else:
             rstate, x = inputs
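Note the distinction the hunk below draws between rnn_type='none' and rnn_type=None: 'none' keeps an LSTM-shaped dummy state and widens the encoder output to rnn_channels so the actor/critic interface is unchanged, while None drops recurrence and recurrent state entirely. A minimal sketch of the widening step used in the 'none' case, with illustrative sizes:

    import jax.numpy as jnp

    c, rnn_channels, batch = 128, 512, 4
    f_state = jnp.ones((batch, c))
    # Tile the per-step encoder state so its width matches rnn_channels.
    f_state_r = jnp.concatenate([f_state for _ in range(rnn_channels // c)], axis=-1)
    print(f_state_r.shape)  # (4, 512)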
@@ -345,43 +366,48 @@ class LSTMAgent(nn.Module):

         f_actions, f_state, mask, valid = encoder(x)

-        lstm_layer = nn.OptimizedLSTMCell(
-            self.lstm_channels, dtype=self.dtype, param_dtype=self.param_dtype,
-            kernel_init=nn.initializers.orthogonal(1.0))
-        if self.multi_step:
-            if self.switch:
-                def body_fn(cell, carry, x, done, switch):
-                    rstate, init_rstate2 = carry
-                    rstate, y = cell(rstate, x)
-                    rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
-                    rstate = jax.tree.map(lambda x, y: jnp.where(switch[:, None], x, y), init_rstate2, rstate)
-                    return (rstate, init_rstate2), y
-            else:
-                def body_fn(cell, carry, x, done, main):
-                    rstate1, rstate2 = carry
-                    rstate = jax.tree.map(lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
-                    rstate, y = cell(rstate, x)
-                    rstate1 = jax.tree.map(lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
-                    rstate2 = jax.tree.map(lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
-                    rstate1, rstate2 = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
-                    return (rstate1, rstate2), y
-            scan = nn.scan(
-                body_fn, variable_broadcast='params',
-                split_rngs={'params': False})
-            f_state_r, done, switch_or_main = jax.tree.map(
-                lambda x: jnp.reshape(x, (num_steps, batch_size) + x.shape[1:]),
-                (f_state, done, switch_or_main))
-            rstate, f_state_r = scan(lstm_layer, (rstate1, rstate2), f_state_r, done, switch_or_main)
-            f_state_r = f_state_r.reshape((-1, f_state_r.shape[-1]))
+        if self.rnn_type in ['lstm', 'none']:
+            rnn_layer = nn.OptimizedLSTMCell(
+                self.rnn_channels, dtype=self.dtype, param_dtype=self.param_dtype,
+                kernel_init=nn.initializers.orthogonal(1.0))
+        elif self.rnn_type == 'gru':
+            rnn_layer = nn.GRUCell(
+                self.rnn_channels, dtype=self.dtype, param_dtype=self.param_dtype,
+                kernel_init=nn.initializers.orthogonal(1.0))
+        elif self.rnn_type is None:
+            rnn_layer = None
+
+        if rnn_layer is None:
+            f_state_r = f_state
+        elif self.rnn_type == 'none':
+            f_state_r = jnp.concatenate([f_state for i in range(self.rnn_channels // c)], axis=-1)
         else:
-            rstate, f_state_r = lstm_layer(rstate, f_state)
+            if self.multi_step:
+                rstate1, rstate2 = rstate
+                batch_size = jax.tree.leaves(rstate1)[0].shape[0]
+                num_steps = done.shape[0] // batch_size
+                f_state_r, done, switch_or_main = jax.tree.map(
+                    lambda x: jnp.reshape(x, (num_steps, batch_size) + x.shape[1:]),
+                    (f_state, done, switch_or_main))
+                rstate, f_state_r = rnn_forward_2p(
+                    rnn_layer, rstate1, rstate2, f_state_r, done, switch_or_main, self.switch)
+                f_state_r = f_state_r.reshape((-1, f_state_r.shape[-1]))
+            else:
+                rstate, f_state_r = rnn_layer(rstate, f_state)

         actor = Actor(channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
         critic = Critic(channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype)

-        if self.no_rnn:
-            f_state_r = jnp.concatenate([f_state for i in range(self.lstm_channels // c)], axis=-1)
         logits = actor(f_state_r, f_actions, mask)
         value = critic(f_state_r)
         return rstate, logits, value, valid
+
+    def init_rnn_state(self, batch_size):
+        if self.rnn_type in ['lstm', 'none']:
+            return (
+                np.zeros((batch_size, self.rnn_channels)),
+                np.zeros((batch_size, self.rnn_channels)),
+            )
+        elif self.rnn_type == 'gru':
+            return np.zeros((batch_size, self.rnn_channels))
+        else:
+            return None
\ No newline at end of file
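init_rnn_state builds the zero carry by hand with np.zeros, and its shape has to match what the chosen Flax cell expects. A self-contained sketch using Flax's own initialize_carry just to show the difference between the LSTM and GRU carries (sizes are illustrative):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    batch_size, features = 4, 8
    key = jax.random.PRNGKey(0)

    lstm_carry = nn.OptimizedLSTMCell(features).initialize_carry(key, (batch_size, features))
    gru_carry = nn.GRUCell(features).initialize_carry(key, (batch_size, features))

    print(jax.tree.map(jnp.shape, lstm_carry))  # ((4, 8), (4, 8)) -- a (cell, hidden) pair
    print(gru_carry.shape)                      # (4, 8)           -- a single hidden array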
ygoai/rl/jax/eval.py
@@ -36,7 +36,7 @@ def evaluate(envs, num_episodes, predict_fn, rnn_state=None):
     return eval_return, eval_ep_len, eval_win_rate


-def battle(envs, num_episodes, predict_fn, init_rnn_state=None):
+def battle(envs, num_episodes, predict_fn, rstate1=None, rstate2=None):
     num_envs = envs.num_envs
     episode_rewards = []
     episode_lengths = []
@@ -50,7 +50,6 @@ def battle(envs, num_episodes, predict_fn, init_rnn_state=None):
         np.zeros(num_envs // 2, dtype=np.int64),
         np.ones(num_envs - num_envs // 2, dtype=np.int64)
     ])
-    rstate1 = rstate2 = init_rnn_state
     while True:
         main = next_to_play == main_player