Biluo Shen / ygo-agent · Commits
Commit c02fbd19, authored May 08, 2024 by sbl1996@126.com
Add option to use history
parent 0e9969c5
Showing 2 changed files with 18 additions and 36 deletions
scripts/cleanba.py (+12 -7)
ygoai/rl/jax/agent2.py (+6 -29)
scripts/cleanba.py
@@ -25,7 +25,7 @@ from tensorboardX import SummaryWriter
 from ygoai.utils import init_ygopro, load_embeddings
 from ygoai.rl.ckpt import ModelCheckpoint, sync_to_gcs, zip_files
-from ygoai.rl.jax.agent2 import PPOLSTMAgent
+from ygoai.rl.jax.agent2 import LSTMAgent
 from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
 from ygoai.rl.jax.eval import evaluate, battle
 from ygoai.rl.jax import clipped_surrogate_pg_loss, vtrace_2p0s, mse_loss, entropy_loss, simple_policy_loss, ach_loss, policy_gradient_loss
@@ -76,6 +76,10 @@ class Args:
     """the number of history actions to use"""
     greedy_reward: bool = False
     """whether to use greedy reward (faster kill higher reward)"""
+    use_history: bool = True
+    """whether to use history actions as input for agent"""
+    eval_use_history: bool = True
+    """whether to use history actions as input for eval agent"""
 
     total_timesteps: int = 50000000000
     """total timesteps of the experiments"""
@@ -210,8 +214,8 @@ class Transition(NamedTuple):
     next_dones: list
 
 
-def create_agent(args, multi_step=False):
-    return PPOLSTMAgent(
+def create_agent(args, multi_step=False, eval=False):
+    return LSTMAgent(
         channels=args.num_channels,
         num_layers=args.num_layers,
         embedding_shape=args.num_embeddings,
@@ -221,6 +225,7 @@ def create_agent(args, multi_step=False):
         switch=args.switch,
         multi_step=multi_step,
         freeze_id=args.freeze_id,
+        use_history=args.use_history if not eval else args.eval_use_history,
     )
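
With the new eval flag, the same Args drives both agents: the training agent follows use_history and the evaluation agent follows eval_use_history. A minimal sketch of the selection logic (illustrative only; pick_history and the SimpleNamespace stand-in for Args are not part of the repo):

from types import SimpleNamespace

# Stand-in for the parsed Args dataclass with the two new fields.
args = SimpleNamespace(use_history=True, eval_use_history=False)

def pick_history(args, eval=False):
    # Mirrors the use_history value passed to LSTMAgent in create_agent.
    return args.use_history if not eval else args.eval_use_history

assert pick_history(args) is True              # training agent sees history actions
assert pick_history(args, eval=True) is False  # eval agent ignores them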
@@ -272,10 +277,10 @@ def rollout(
     avg_ep_returns = deque(maxlen=1000)
     avg_win_rates = deque(maxlen=1000)
 
-    @jax.jit
+    @partial(jax.jit, static_argnums=(2,))
     def get_logits(
-            params: flax.core.FrozenDict, inputs):
-        rstate, logits = create_agent(args).apply(params, inputs)[:2]
+            params: flax.core.FrozenDict, inputs, eval=False):
+        rstate, logits = create_agent(args, eval=eval).apply(params, inputs)[:2]
         return rstate, logits
 
     @jax.jit
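
Because eval now selects a different agent configuration inside get_logits, it has to be a compile-time constant: partial(jax.jit, static_argnums=(2,)) marks the third positional argument as static, so JAX retraces the function for each distinct value rather than branching on a traced value. A standalone sketch of the pattern (not repo code):

from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=(2,))
def scaled(x, y, add_one):
    # add_one is static, so an ordinary Python branch is fine here;
    # each value of add_one gets its own compiled version.
    return x * y + (1.0 if add_one else 0.0)

a = scaled(jnp.ones(3), 2.0, False)  # compiled once for add_one=False
b = scaled(jnp.ones(3), 2.0, True)   # recompiled for add_one=True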
@@ -287,7 +292,7 @@ def rollout(
     @jax.jit
     def get_action_battle(params1, params2, rstate1, rstate2, obs, main, done):
         next_rstate1, logits1 = get_logits(params1, (rstate1, obs))
-        next_rstate2, logits2 = get_logits(params2, (rstate2, obs))
+        next_rstate2, logits2 = get_logits(params2, (rstate2, obs), True)
         logits = jnp.where(main[:, None], logits1, logits2)
         rstate1 = jax.tree.map(
             lambda x1, x2: jnp.where(main[:, None], x1, x2), next_rstate1, rstate1)
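
In get_action_battle, the opponent agent is now queried with eval=True, and jnp.where picks per environment which agent's logits and recurrent state to keep. A small illustration of that masked selection (dummy shapes, not repo code):

import jax.numpy as jnp

main = jnp.array([True, False, True])   # envs where agent 1 is acting
logits1 = jnp.zeros((3, 4))             # agent 1 outputs
logits2 = jnp.ones((3, 4))              # agent 2 outputs
# Broadcasting main[:, None] over the action dimension keeps rows 0 and 2
# from agent 1 and row 1 from agent 2.
logits = jnp.where(main[:, None], logits1, logits2)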
ygoai/rl/jax/agent2.py
@@ -151,6 +151,7 @@ class Encoder(nn.Module):
     dtype: Optional[jnp.dtype] = None
     param_dtype: jnp.dtype = jnp.float32
     freeze_id: bool = False
+    use_history: bool = True
 
     @nn.compact
     def __call__(self, x):
@@ -266,6 +267,8 @@ class Encoder(nn.Module):
         f_g_actions = f_g_actions / a_mask_.sum(axis=1, keepdims=True)
 
         # State
+        if not self.use_history:
+            f_g_h_actions = jnp.zeros_like(f_g_h_actions)
         f_state = jnp.concatenate([f_g_card, f_global, f_g_h_actions, f_g_actions], axis=-1)
         f_state = MLP((c * 2, c), dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
         f_state = layer_norm(dtype=self.dtype)(f_state)
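
Disabling use_history does not change any shapes: the pooled history-action features are zeroed before the concatenation, so the state MLP keeps the same parameter layout whether or not history is used. A toy illustration (dummy feature sizes, not repo code):

import jax.numpy as jnp

use_history = False
f_g_h_actions = jnp.ones((2, 8))   # pooled history-action features (dummy values)
f_g_actions = jnp.ones((2, 8))     # pooled current-action features (dummy values)
if not use_history:
    # Zeroing removes the history contribution but keeps the concatenated width.
    f_g_h_actions = jnp.zeros_like(f_g_h_actions)
f_state = jnp.concatenate([f_g_h_actions, f_g_actions], axis=-1)  # shape (2, 16) either way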
@@ -306,35 +309,7 @@ class Critic(nn.Module):
         return x
 
 
-class PPOAgent(nn.Module):
-    channels: int = 128
-    num_layers: int = 2
-    embedding_shape: Optional[Union[int, Tuple[int, int]]] = None
-    dtype: jnp.dtype = jnp.float32
-    param_dtype: jnp.dtype = jnp.float32
-
-    @nn.compact
-    def __call__(self, x):
-        c = self.channels
-        encoder = Encoder(
-            channels=c,
-            num_layers=self.num_layers,
-            embedding_shape=self.embedding_shape,
-            dtype=self.dtype,
-            param_dtype=self.param_dtype,
-        )
-        actor = Actor(
-            channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
-        critic = Critic(
-            channels=[c, c, c], dtype=self.dtype, param_dtype=self.param_dtype)
-
-        f_actions, f_state, mask, valid = encoder(x)
-        logits = actor(f_state, f_actions, mask)
-        value = critic(f_state)
-        return logits, value, valid
-
-
-class PPOLSTMAgent(nn.Module):
+class LSTMAgent(nn.Module):
     channels: int = 128
     num_layers: int = 2
     lstm_channels: int = 512
@@ -344,6 +319,7 @@ class PPOLSTMAgent(nn.Module):
     multi_step: bool = False
     switch: bool = True
     freeze_id: bool = False
+    use_history: bool = True
 
     @nn.compact
     def __call__(self, inputs):
@@ -363,6 +339,7 @@ class PPOLSTMAgent(nn.Module):
             dtype=self.dtype,
             param_dtype=self.param_dtype,
             freeze_id=self.freeze_id,
+            use_history=self.use_history,
         )
 
         f_actions, f_state, mask, valid = encoder(x)