ygo-agent (Biluo Shen)

Commit 03ffe718
Authored Feb 21, 2024 by biluo.shen
Parent: e7a19464

Improve critic

Showing 3 changed files with 20 additions and 50 deletions (+20 / -50)
scripts/ppo.py                     +4   -4
ygoai/rl/agent.py                  +14  -44
ygoenv/ygoenv/ygopro/ygopro.h      +2   -2
scripts/ppo.py  (+4 / -4)

@@ -50,7 +50,7 @@ class Args:
     """the embedding file for card embeddings"""
     max_options: int = 24
     """the maximum number of options"""
-    n_history_actions: int = 8
+    n_history_actions: int = 16
     """the number of history actions to use"""
     play_mode: str = "self"
     """the play mode, can be combination of 'self', 'bot', 'random', like 'self+bot'"""

@@ -60,7 +60,7 @@ class Args:
     num_channels: int = 128
     """the number of channels for the agent"""
-    total_timesteps: int = 100000000
+    total_timesteps: int = 1000000000
     """total timesteps of the experiments"""
     learning_rate: float = 2.5e-4
     """the learning rate of the optimizer"""

@@ -76,7 +76,7 @@ class Args:
     """the lambda for the general advantage estimation"""
     minibatch_size: int = 256
     """the mini-batch size"""
-    update_epochs: int = 4
+    update_epochs: int = 2
     """the K epochs to update the policy"""
     norm_adv: bool = True
     """Toggles advantages normalization"""

@@ -219,7 +219,7 @@ def run(local_rank, world_size):
     # agent.get_action_and_value = torch.compile(agent.get_action_and_value, mode=args.compile_mode)
     optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
-    scaler = GradScaler(enabled=args.fp16_train)
+    scaler = GradScaler(enabled=args.fp16_train, init_scale=2 ** 8)

     def masked_mean(x, valid):
         x = x.masked_fill(~valid, 0)
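In scripts/ppo.py the commit doubles the action history (n_history_actions 8 -> 16), raises the training budget (total_timesteps 1e8 -> 1e9), halves update_epochs (4 -> 2), and starts the mixed-precision GradScaler at init_scale=2 ** 8 instead of PyTorch's default of 2 ** 16, so fewer early optimizer steps are skipped while the loss scale settles. Below is a minimal, self-contained sketch of how such a scaler is typically wired around a backward pass; the model, data, and loss here are hypothetical stand-ins (and a CUDA device is assumed), not the repository's agent or training loop:

import torch
from torch import nn, optim
from torch.cuda.amp import GradScaler, autocast

# Hypothetical stand-ins; only the GradScaler wiring is the point.
model = nn.Linear(8, 1).cuda()
optimizer = optim.Adam(model.parameters(), lr=2.5e-4, eps=1e-5)
scaler = GradScaler(enabled=True, init_scale=2 ** 8)  # smaller starting scale than the 2 ** 16 default

x = torch.randn(32, 8, device="cuda")
target = torch.randn(32, 1, device="cuda")

with autocast():
    loss = nn.functional.mse_loss(model(x), target)

optimizer.zero_grad()
scaler.scale(loss).backward()    # backprop through the scaled loss
scaler.unscale_(optimizer)       # unscale so gradient clipping sees true magnitudes
nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
scaler.step(optimizer)           # skipped automatically if inf/NaN gradients are found
scaler.update()                  # grows or shrinks the scale for the next iteration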
ygoai/rl/agent.py  (+14 / -44)

@@ -304,7 +304,12 @@ class Encoder(nn.Module):
             f_actions = layer(f_actions, f_h_actions)

         f_actions = self.action_norm(f_actions)
-        return f_actions, mask, valid
+
+        f_s_cards_global = f_cards.mean(dim=1)
+        c_mask = 1 - mask.unsqueeze(-1).float()
+        f_s_actions_ha = (f_actions * c_mask).sum(dim=1) / c_mask.sum(dim=1)
+        f_state = torch.cat([f_s_cards_global, f_s_actions_ha], dim=-1)
+        return f_actions, f_state, mask, valid


 class PPOAgent(nn.Module):

@@ -324,67 +329,32 @@ class PPOAgent(nn.Module):
         )

         self.critic = nn.Sequential(
-            nn.Linear(c, c // 4),
+            nn.Linear(c * 2, c // 2),
             nn.ReLU(),
-            nn.Linear(c // 4, 1),
+            nn.Linear(c // 2, 1),
         )

     def load_embeddings(self, embeddings, freeze=True):
         self.encoder.load_embeddings(embeddings, freeze)

     def get_value(self, x):
-        f_actions, mask, valid = self.encoder(x)
-        c_mask = 1 - mask.unsqueeze(-1).float()
-        f = (f_actions * c_mask).sum(dim=1) / c_mask.sum(dim=1)
-        return self.critic(f)
+        f_actions, f_state, mask, valid = self.encoder(x)
+        return self.critic(f_state)

     def get_action_and_value(self, x, action):
-        f_actions, mask, valid = self.encoder(x)
-        c_mask = 1 - mask.unsqueeze(-1).float()
-        f = (f_actions * c_mask).sum(dim=1) / c_mask.sum(dim=1)
+        f_actions, f_state, mask, valid = self.encoder(x)
         logits = self.actor(f_actions)[..., 0]
         logits = logits.float()
         logits = logits.masked_fill(mask, float("-inf"))
         probs = Categorical(logits=logits)
-        return action, probs.log_prob(action), probs.entropy(), self.critic(f), valid
+        return action, probs.log_prob(action), probs.entropy(), self.critic(f_state), valid

     def forward(self, x):
-        f_actions, mask, valid = self.encoder(x)
-        c_mask = 1 - mask.unsqueeze(-1).float()
-        f = (f_actions * c_mask).sum(dim=1) / c_mask.sum(dim=1)
+        f_actions, f_state, mask, valid = self.encoder(x)
         logits = self.actor(f_actions)[..., 0]
         logits = logits.float()
         logits = logits.masked_fill(mask, float("-inf"))
-        return logits, self.critic(f)
+        return logits, self.critic(f_state)
-
-
-class DMCAgent(nn.Module):
-    def __init__(self, channels=128, num_card_layers=2, num_action_layers=2,
-                 num_history_action_layers=2, embedding_shape=None, bias=False, affine=True):
-        super(DMCAgent, self).__init__()
-        self.encoder = Encoder(channels, num_card_layers, num_action_layers,
-                               num_history_action_layers, embedding_shape, bias, affine)
-        c = channels
-        self.value_head = nn.Sequential(
-            nn.Linear(c, c // 4),
-            nn.ReLU(),
-            nn.Linear(c // 4, 1),
-        )
-
-    def load_embeddings(self, embeddings, freeze=True):
-        self.encoder.load_embeddings(embeddings, freeze)
-
-    def forward(self, x):
-        f_actions, mask, valid = self.encoder(f_actions)
-        values = self.value_head(f_actions)[..., 0]
-        # values = torch.tanh(values)
-        values = torch.where(mask, torch.full_like(values, -10), values)
-        return values, valid
\ No newline at end of file
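The substance of the agent.py change is where the critic gets its input: the encoder now builds a single state feature by concatenating the mean card embedding with a masked mean over the candidate-action embeddings, and the critic's first layer widens from c to c * 2 to accept it, so get_value, get_action_and_value, and forward no longer re-derive the pooled feature themselves. A standalone sketch of that pooling, using the variable names from the diff but made-up tensor shapes (batch of 2, 4 cards, 3 candidate actions, 128 channels):

import torch

batch, n_cards, n_actions, c = 2, 4, 3, 128
f_cards = torch.randn(batch, n_cards, c)      # per-card features from the encoder
f_actions = torch.randn(batch, n_actions, c)  # per-candidate-action features
mask = torch.tensor([[False, False, True],    # True marks padded / invalid action slots
                     [False, True, True]])

f_s_cards_global = f_cards.mean(dim=1)        # (batch, c): global card summary

c_mask = 1 - mask.unsqueeze(-1).float()       # (batch, n_actions, 1): 1.0 for valid actions
f_s_actions_ha = (f_actions * c_mask).sum(dim=1) / c_mask.sum(dim=1)  # masked mean over valid actions

f_state = torch.cat([f_s_cards_global, f_s_actions_ha], dim=-1)
print(f_state.shape)  # torch.Size([2, 256]), i.e. (batch, 2 * c), matching nn.Linear(c * 2, c // 2)

Computing f_state once inside the encoder also lets the actor and critic share the same pooled representation; removing the duplicated pooling code together with the DMCAgent class accounts for most of the 44 deleted lines.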
ygoenv/ygoenv/ygopro/ygopro.h  (+2 / -2)

@@ -1516,9 +1516,9 @@ public:
     if (win_turn <= 5) {
       base_reward = 2.0;
     } else if (win_turn <= 3) {
-      base_reward = 4.0;
+      base_reward = 3.0;
     } else if (win_turn <= 1) {
-      base_reward = 8.0;
+      base_reward = 4.0;
     }
     if (play_mode_ == kSelfPlay) {
       // to_play_ is the previous player
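The shaped terminal reward now grows more gently with faster wins: 2.0 for a win by turn 5, 3.0 (down from 4.0) by turn 3, and 4.0 (down from 8.0) by turn 1. Note that as ordered in the C++ snippet, win_turn <= 5 is tested first and therefore also matches turns 3 and 1, so the larger bonuses in the later branches never fire; the sketch below reorders the checks from the tightest threshold up, which is an assumption about intent rather than what the file does, and it adds an assumed fallback of 1.0 for slower wins that does not appear in the diff:

def shaped_win_reward(win_turn: int) -> float:
    # Hypothetical helper mirroring the values in this commit; the ordering and
    # the 1.0 fallback are assumptions, not the repository's exact code.
    if win_turn <= 1:
        return 4.0   # was 8.0 before this commit
    elif win_turn <= 3:
        return 3.0   # was 4.0 before this commit
    elif win_turn <= 5:
        return 2.0
    return 1.0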