Biluo Shen / ygo-agent / Commits
Commit 81d80f7f, authored Apr 30, 2024 by sbl1996@126.com
Fix done for rstate
parent 3bf0bc91
Showing 5 changed files with 66 additions and 42 deletions (+66 / -42)
scripts/battle.py        +7   -5
scripts/impala.py        +31  -20
scripts/ppo.py           +16  -12
ygoai/rl/jax/agent2.py   +7   -4
ygoai/utils.py           +5   -1
scripts/battle.py
@@ -157,24 +157,26 @@ if __name__ == "__main__":
     params2 = jax.device_put(params2)

     @jax.jit
-    def get_probs(params, rstate, obs, done):
+    def get_probs(params, rstate, obs, done=None):
         agent = create_agent(args)
         next_rstate, logits = agent.apply(params, (rstate, obs))[:2]
         probs = jax.nn.softmax(logits, axis=-1)
-        next_rstate = jax.tree.map(
-            lambda x: jnp.where(done[:, None], 0, x), next_rstate)
+        if done is not None:
+            next_rstate = jnp.where(done[:, None], 0, next_rstate)
         return next_rstate, probs

     if args.num_envs != 1:
         @jax.jit
         def get_probs2(params1, params2, rstate1, rstate2, obs, main, done):
-            next_rstate1, probs1 = get_probs(params1, rstate1, obs, done)
-            next_rstate2, probs2 = get_probs(params2, rstate2, obs, done)
+            next_rstate1, probs1 = get_probs(params1, rstate1, obs)
+            next_rstate2, probs2 = get_probs(params2, rstate2, obs)
             probs = jnp.where(main[:, None], probs1, probs2)
             rstate1 = jax.tree.map(
                 lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
             rstate2 = jax.tree.map(
                 lambda x1, x2: jnp.where(main[:, None], x2, x1), next_rstate2, rstate2)
+            rstate1, rstate2 = jax.tree.map(
+                lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
             return rstate1, rstate2, probs

         def predict_fn(rstate1, rstate2, obs, main, done):
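The common thread of this commit is visible above: instead of zeroing the recurrent state inside every helper, the reset on done is applied once, after the per-player states have been merged. A minimal, self-contained sketch of that reset pattern follows; the (batch, features) shapes and the tuple carry are assumptions for illustration, not the project's actual rstate layout.

import jax
import jax.numpy as jnp

batch = 4
# Toy recurrent state: a (c, h) pair like an LSTM carry (shapes assumed).
rstate = (jnp.ones((batch, 8)), jnp.ones((batch, 8)))
done = jnp.array([False, True, False, True])

# Zero the state only for environments whose episode just ended.
rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
# Rows 1 and 3 of both leaves are now all zeros; rows 0 and 2 are untouched.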
scripts/impala.py
 import os
+import shutil
 import queue
 import random
 import threading
@@ -45,6 +46,8 @@ class Args:
     """the frequency of saving the model (in terms of `updates`)"""
     checkpoint: Optional[str] = None
     """the path to the model checkpoint to load"""
+    debug: bool = False
+    """whether to run the script in debug mode"""
     tb_dir: str = "runs"
     """the directory to save the tensorboard logs"""
@@ -156,7 +159,7 @@ class Args:
     actor_devices: Optional[List[str]] = None
     learner_devices: Optional[List[str]] = None
     num_embeddings: Optional[int] = None
-    freeze_id: bool = False
+    freeze_id: Optional[bool] = None


 def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1, eval=False):
@@ -253,28 +256,27 @@ def rollout(
     @jax.jit
     def get_logits(
-            params: flax.core.FrozenDict, inputs, done):
+            params: flax.core.FrozenDict, inputs):
         rstate, logits = create_agent(args).apply(params, inputs)[:2]
-        rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
         return rstate, logits

     @jax.jit
     def get_action(params: flax.core.FrozenDict, inputs):
-        batch_size = jax.tree.leaves(inputs)[0].shape[0]
-        done = jnp.zeros(batch_size, dtype=jnp.bool_)
-        rstate, logits = get_logits(params, inputs, done)
+        rstate, logits = get_logits(params, inputs)
         return rstate, logits.argmax(axis=1)

     @jax.jit
     def get_action_battle(params1, params2, rstate1, rstate2, obs, main, done):
-        next_rstate1, logits1 = get_logits(params1, (rstate1, obs), done)
-        next_rstate2, logits2 = get_logits(params2, (rstate2, obs), done)
+        next_rstate1, logits1 = get_logits(params1, (rstate1, obs))
+        next_rstate2, logits2 = get_logits(params2, (rstate2, obs))
         logits = jnp.where(main[:, None], logits1, logits2)
         rstate1 = jax.tree.map(
             lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
         rstate2 = jax.tree.map(
             lambda x1, x2: jnp.where(main[:, None], x2, x1), next_rstate2, rstate2)
+        rstate1, rstate2 = jax.tree.map(
+            lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
         return rstate1, rstate2, logits.argmax(axis=1)

     @jax.jit
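For readers unfamiliar with the masking in get_action_battle, here is a hedged toy of how the main mask routes outputs and state between the two players: the acting player's fresh state is written back while the waiting player's state is preserved. All names and shapes below are invented for illustration.

import jax
import jax.numpy as jnp

main = jnp.array([True, False, True])      # per env: does player 1 act this step?
logits1 = jnp.full((3, 5), 1.0)            # player 1 policy head
logits2 = jnp.full((3, 5), -1.0)           # player 2 policy head
logits = jnp.where(main[:, None], logits1, logits2)

rstate1 = (jnp.zeros((3, 4)),)             # player 1 state before the step
next_rstate1 = (jnp.ones((3, 4)),)         # state produced by this step
# Keep the new state only where player 1 actually acted.
rstate1 = jax.tree.map(
    lambda new, old: jnp.where(main[:, None], new, old), next_rstate1, rstate1)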
@@ -284,12 +286,14 @@ def rollout(
         next_obs = jax.tree.map(lambda x: jnp.array(x), next_obs)
         done = jnp.array(done)
         main = jnp.array(main)
         rstate = jax.tree.map(
             lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
-        rstate, logits = get_logits(params, (rstate, next_obs), done)
+        rstate, logits = get_logits(params, (rstate, next_obs))
         rstate1 = jax.tree.map(
             lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
         rstate2 = jax.tree.map(
             lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
+        rstate1, rstate2 = jax.tree.map(
+            lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
         action, key = categorical_sample(logits, key)
         return next_obs, done, main, rstate1, rstate2, action, logits, key
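The step above ends with categorical_sample(logits, key), a project helper whose body is not part of this diff. A plausible minimal implementation is sketched below as an assumption, not the repository's actual code: sample one action per row and thread the PRNG key forward.

import jax
import jax.numpy as jnp

def categorical_sample(logits, key):
    # Split so the returned key can be reused next step without correlation.
    key, subkey = jax.random.split(key)
    action = jax.random.categorical(subkey, logits, axis=-1)
    return action, key

key = jax.random.PRNGKey(0)
logits = jnp.zeros((4, 6))              # uniform over 6 actions for 4 envs
action, key = categorical_sample(logits, key)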
@@ -517,13 +521,18 @@ if __name__ == "__main__":
     timestamp = int(time.time())
     run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{timestamp}"
-    tb_log_dir = f"{args.tb_dir}/{run_name}"
-    writer = SummaryWriter(tb_log_dir)
-    writer.add_text(
-        "hyperparameters",
-        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
-    )
+    dummy_writer = SimpleNamespace()
+    dummy_writer.add_scalar = lambda x, y, z: None
+
+    tb_log_dir = f"{args.tb_dir}/{run_name}"
+    if args.local_rank == 0 and not args.debug:
+        writer = SummaryWriter(tb_log_dir)
+        writer.add_text(
+            "hyperparameters",
+            "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
+        )
+    else:
+        writer = dummy_writer

     def save_fn(obj, path):
         with open(path, "wb") as f:
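The new debug flag works by swapping the TensorBoard writer for a no-op stand-in with the same call shape, so logging calls elsewhere need no guards. A minimal sketch of that pattern (the real script only builds SummaryWriter when local_rank == 0 and debug is off):

from types import SimpleNamespace

dummy_writer = SimpleNamespace()
dummy_writer.add_scalar = lambda tag, value, step: None

writer = dummy_writer                            # a real SummaryWriter otherwise
writer.add_scalar("losses/loss", 0.5, 100)       # silently ignored in debug mode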
@@ -669,6 +678,7 @@ if __name__ == "__main__":
         learn_opponent: bool = False,
     ):
         storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
+        # TODO: rstate will be out-date after the first update, maybe consider R2D2
         next_inputs, init_rstate1, init_rstate2 = [
             jax.tree.map(lambda *x: jnp.concatenate(x), *x)
             for x in [sharded_next_inputs, sharded_init_rstate1, sharded_init_rstate2]
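A toy sketch (shapes invented) of the two pytree merges above: jnp.hstack joins the per-actor storages along their second axis, while jnp.concatenate joins the leading axis of the per-actor inputs and initial recurrent states.

import jax
import jax.numpy as jnp

shard_a = {"obs": jnp.ones((8, 2, 3)), "rew": jnp.zeros((8, 2))}
shard_b = {"obs": jnp.ones((8, 2, 3)), "rew": jnp.zeros((8, 2))}

# hstack: (8, 2, 3) + (8, 2, 3) -> (8, 4, 3) and (8, 2) + (8, 2) -> (8, 4)
storage = jax.tree.map(lambda *x: jnp.hstack(x), shard_a, shard_b)

# concatenate: (2, 16) + (2, 16) -> (4, 16), e.g. per-actor initial states
state_a, state_b = jnp.ones((2, 16)), jnp.zeros((2, 16))
init_rstate = jax.tree.map(lambda *x: jnp.concatenate(x), state_a, state_b)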
@@ -756,8 +766,6 @@ if __name__ == "__main__":
     params_queues = []
     rollout_queues = []
     eval_queue = queue.Queue()
-    dummy_writer = SimpleNamespace()
-    dummy_writer.add_scalar = lambda x, y, z: None
     unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
     for d_idx, d_id in enumerate(args.actor_device_ids):
@@ -844,13 +852,16 @@ if __name__ == "__main__":
                 writer.add_scalar("losses/approx_kl", approx_kl[-1].item(), global_step)
                 writer.add_scalar("losses/loss", loss, global_step)

-            if args.local_rank == 0 and learner_policy_version % args.save_interval == 0:
+            if args.local_rank == 0 and learner_policy_version % args.save_interval == 0 and not args.debug:
                 M_steps = args.batch_size * learner_policy_version // (2 ** 20)
                 ckpt_name = f"{timestamp}_{M_steps}M.flax_model"
                 ckpt_maneger.save(unreplicated_params, ckpt_name)
                 if args.gcs_bucket is not None:
+                    lastest_path = ckpt_maneger.get_latest()
+                    copy_path = lastest_path.with_name("latest" + lastest_path.suffix)
+                    shutil.copyfile(lastest_path, copy_path)
                     zip_file_path = "latest.zip"
-                    zip_files(zip_file_path, [ckpt_maneger.get_latest(), tb_log_dir])
+                    zip_files(zip_file_path, [str(copy_path), tb_log_dir])
                     sync_to_gcs(args.gcs_bucket, zip_file_path)

             if learner_policy_version >= args.num_updates:
scripts/ppo.py
@@ -47,6 +47,8 @@ class Args:
     """the frequency of saving the model (in terms of `updates`)"""
     checkpoint: Optional[str] = None
     """the path to the model checkpoint to load"""
+    debug: bool = False
+    """whether to run the script in debug mode"""
     tb_dir: str = "runs"
     """the directory to save the tensorboard logs"""
@@ -156,7 +158,7 @@ class Args:
     actor_devices: Optional[List[str]] = None
     learner_devices: Optional[List[str]] = None
     num_embeddings: Optional[int] = None
-    freeze_id: bool = False
+    freeze_id: Optional[bool] = None


 def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1, eval=False):
@@ -254,28 +256,27 @@ def rollout(
     @jax.jit
     def get_logits(
-            params: flax.core.FrozenDict, inputs, done):
+            params: flax.core.FrozenDict, inputs):
         rstate, logits = create_agent(args).apply(params, inputs)[:2]
-        rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
         return rstate, logits

     @jax.jit
     def get_action(params: flax.core.FrozenDict, inputs):
-        batch_size = jax.tree.leaves(inputs)[0].shape[0]
-        done = jnp.zeros(batch_size, dtype=jnp.bool_)
-        rstate, logits = get_logits(params, inputs, done)
+        rstate, logits = get_logits(params, inputs)
         return rstate, logits.argmax(axis=1)

     @jax.jit
     def get_action_battle(params1, params2, rstate1, rstate2, obs, main, done):
-        next_rstate1, logits1 = get_logits(params1, (rstate1, obs), done)
-        next_rstate2, logits2 = get_logits(params2, (rstate2, obs), done)
+        next_rstate1, logits1 = get_logits(params1, (rstate1, obs))
+        next_rstate2, logits2 = get_logits(params2, (rstate2, obs))
         logits = jnp.where(main[:, None], logits1, logits2)
         rstate1 = jax.tree.map(
             lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
         rstate2 = jax.tree.map(
             lambda x1, x2: jnp.where(main[:, None], x2, x1), next_rstate2, rstate2)
+        rstate1, rstate2 = jax.tree.map(
+            lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
         return rstate1, rstate2, logits.argmax(axis=1)

     @jax.jit
@@ -285,12 +286,14 @@ def rollout(
         next_obs = jax.tree.map(lambda x: jnp.array(x), next_obs)
         done = jnp.array(done)
         main = jnp.array(main)
         rstate = jax.tree.map(
             lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
-        rstate, logits = get_logits(params, (rstate, next_obs), done)
+        rstate, logits = get_logits(params, (rstate, next_obs))
         rstate1 = jax.tree.map(
             lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
         rstate2 = jax.tree.map(
             lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
+        rstate1, rstate2 = jax.tree.map(
+            lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
         action, key = categorical_sample(logits, key)
         return next_obs, done, main, rstate1, rstate2, action, logits, key
@@ -532,7 +535,7 @@ if __name__ == "__main__":
     dummy_writer.add_scalar = lambda x, y, z: None

     tb_log_dir = f"{args.tb_dir}/{run_name}"
-    if args.local_rank == 0:
+    if args.local_rank == 0 and not args.debug:
         writer = SummaryWriter(tb_log_dir)
         writer.add_text(
             "hyperparameters",
@@ -692,6 +695,7 @@ if __name__ == "__main__":
         learn_opponent: bool = False,
     ):
         storage = jax.tree.map(lambda *x: jnp.hstack(x), *sharded_storages)
+        # TODO: rstate will be out-date after the first update, maybe consider R2D2
         next_inputs, init_rstate1, init_rstate2 = [
             jax.tree.map(lambda *x: jnp.concatenate(x), *x)
             for x in [sharded_next_inputs, sharded_init_rstate1, sharded_init_rstate2]
@@ -874,7 +878,7 @@ if __name__ == "__main__":
                 writer.add_scalar("losses/approx_kl", approx_kl[-1].item(), global_step)
                 writer.add_scalar("losses/loss", loss, global_step)

-            if args.local_rank == 0 and learner_policy_version % args.save_interval == 0:
+            if args.local_rank == 0 and learner_policy_version % args.save_interval == 0 and not args.debug:
                 M_steps = args.batch_size * learner_policy_version // (2 ** 20)
                 ckpt_name = f"{timestamp}_{M_steps}M.flax_model"
                 ckpt_maneger.save(unreplicated_params, ckpt_name)
ygoai/rl/jax/agent2.py
@@ -169,8 +169,6 @@ class Encoder(nn.Module):
fc_layer
=
partial
(
nn
.
Dense
,
use_bias
=
False
,
param_dtype
=
self
.
param_dtype
)
id_embed
=
embed
(
n_embed
,
embed_dim
)
if
self
.
freeze_id
:
id_embed
=
lambda
x
:
jax
.
lax
.
stop_gradient
(
id_embed
(
x
))
action_encoder
=
ActionEncoder
(
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)
...
...
@@ -184,6 +182,8 @@ class Encoder(nn.Module):
         x_id = decode_id(x_cards[:, :, :2].astype(jnp.int32))
         x_id = id_embed(x_id)
+        if self.freeze_id:
+            x_id = jax.lax.stop_gradient(x_id)

         # Cards
         f_cards = CardEncoder(
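The change above drops the self-referencing lambda around id_embed (which would recurse if called) and instead applies jax.lax.stop_gradient to each lookup result, the usual way to freeze an embedding table while still using it in the forward pass. A self-contained toy sketch, with module name and sizes invented:

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyEncoder(nn.Module):
    freeze_id: bool = True

    @nn.compact
    def __call__(self, ids):
        x = nn.Embed(num_embeddings=16, features=8)(ids)
        if self.freeze_id:
            # Values pass through unchanged; gradients to the table are cut.
            x = jax.lax.stop_gradient(x)
        return x.sum()

model = TinyEncoder()
ids = jnp.array([1, 2, 3])
params = model.init(jax.random.PRNGKey(0), ids)
grads = jax.grad(lambda p: model.apply(p, ids))(params)
# With freeze_id=True the embedding-table gradient is all zeros.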
@@ -215,9 +215,12 @@ class Encoder(nn.Module):
         h_mask = h_mask.at[:, 0].set(False)

         x_h_id = decode_id(x_h_actions[..., :2])
+        x_h_id = id_embed(x_h_id)
+        if self.freeze_id:
+            x_h_id = jax.lax.stop_gradient(x_h_id)
         x_h_id = MLP(
             (c, c), dtype=jnp.float32, param_dtype=self.param_dtype,
-            kernel_init=default_fc_init2)(id_embed(x_h_id))
+            kernel_init=default_fc_init2)(x_h_id)

         x_h_a_feats1 = action_encoder(x_h_actions[:, :, 2:13])
@@ -379,9 +382,9 @@ class PPOLSTMAgent(nn.Module):
             rstate1, rstate2 = carry
             rstate = jax.tree.map(
                 lambda x1, x2: jnp.where(main[:, None], x1, x2), rstate1, rstate2)
             rstate, y = cell(rstate, x)
-            rstate = jax.tree.map(lambda x: jnp.where(done[:, None], 0, x), rstate)
             rstate1 = jax.tree.map(
                 lambda x, y: jnp.where(main[:, None], x, y), rstate, rstate1)
             rstate2 = jax.tree.map(
                 lambda x, y: jnp.where(main[:, None], y, x), rstate, rstate2)
+            rstate1, rstate2 = jax.tree.map(
+                lambda x: jnp.where(done[:, None], 0, x), (rstate1, rstate2))
             return (rstate1, rstate2), y

         scan = nn.scan(
             body_fn, variable_broadcast='params',
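To read the scan body in isolation, here is a hedged single-player toy of the same structure: nn.scan drives an LSTM cell over time and the carry is zeroed wherever done is set. The two-player main routing of the real agent is omitted, and the module, shapes, and Flax cell API version are assumptions.

import jax
import jax.numpy as jnp
import flax.linen as nn

class ToyRNN(nn.Module):
    features: int = 8

    @nn.compact
    def __call__(self, rstate, xs):
        cell = nn.OptimizedLSTMCell(self.features)

        def body_fn(cell, rstate, inputs):
            x, done = inputs
            rstate, y = cell(rstate, x)
            # Reset the carry for environments whose episode ended at this step.
            rstate = jax.tree.map(
                lambda s: jnp.where(done[:, None], 0, s), rstate)
            return rstate, y

        scan = nn.scan(body_fn, variable_broadcast='params',
                       split_rngs={'params': False})
        return scan(cell, rstate, xs)

T, B, D = 5, 3, 4
model = ToyRNN()
rstate = nn.OptimizedLSTMCell(model.features).initialize_carry(
    jax.random.PRNGKey(0), (B, D))
xs = (jnp.ones((T, B, D)), jnp.zeros((T, B), dtype=bool))
params = model.init(jax.random.PRNGKey(1), rstate, xs)
rstate, ys = model.apply(params, rstate, xs)   # ys has shape (T, B, 8)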
ygoai/utils.py
@@ -48,7 +48,7 @@ def init_ygopro(env_id, lang, deck, code_list_file, preload_tokens=False):
     return deck_name


-def load_embeddings(embedding_file, code_list_file):
+def load_embeddings(embedding_file, code_list_file, pad_to=999):
     with open(embedding_file, "rb") as f:
         embeddings = pickle.load(f)
     with open(code_list_file, "r") as f:
@@ -56,4 +56,8 @@ def load_embeddings(embedding_file, code_list_file):
     code_list = [int(code.strip()) for code in code_list]
     assert len(embeddings) == len(code_list), f"len(embeddings)={len(embeddings)}, len(code_list)={len(code_list)}"
     embeddings = np.array([embeddings[code] for code in code_list], dtype=np.float32)
+    if pad_to is not None:
+        assert pad_to >= len(embeddings), f"pad_to={pad_to} < len(embeddings)={len(embeddings)}"
+        pad = np.zeros((pad_to - len(embeddings), embeddings.shape[1]), dtype=np.float32)
+        embeddings = np.concatenate([embeddings, pad], axis=0)
     return embeddings
\ No newline at end of file
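A toy illustration of the padding step added above (array sizes invented): the (N, D) embedding matrix is extended with zero rows up to pad_to, presumably so the agent's embedding table can have a fixed size independent of how many codes are in the list.

import numpy as np

embeddings = np.random.rand(5, 4).astype(np.float32)   # pretend 5 known card codes
pad_to = 8
assert pad_to >= len(embeddings)
pad = np.zeros((pad_to - len(embeddings), embeddings.shape[1]), dtype=np.float32)
embeddings = np.concatenate([embeddings, pad], axis=0)  # now shape (8, 4)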