Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
0559e98c
Commit
0559e98c
authored
Feb 25, 2024
by
biluo.shen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Update agent
parent
b1054104
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
6 additions
and
7 deletions
+6
-7
scripts/ppo.py
scripts/ppo.py
+1
-1
ygoai/rl/agent.py
ygoai/rl/agent.py
+5
-5
ygoai/rl/dist.py
ygoai/rl/dist.py
+0
-1
No files found.
scripts/ppo.py
View file @
0559e98c
...
@@ -30,7 +30,7 @@ class Args:
...
@@ -30,7 +30,7 @@ class Args:
"""the name of this experiment"""
"""the name of this experiment"""
seed
:
int
=
1
seed
:
int
=
1
"""seed of the experiment"""
"""seed of the experiment"""
torch_deterministic
:
bool
=
Tru
e
torch_deterministic
:
bool
=
Fals
e
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda
:
bool
=
True
cuda
:
bool
=
True
"""if toggled, cuda will be enabled by default"""
"""if toggled, cuda will be enabled by default"""
...
...
ygoai/rl/agent.py
View file @
0559e98c
...
@@ -29,7 +29,7 @@ class Encoder(nn.Module):
...
@@ -29,7 +29,7 @@ class Encoder(nn.Module):
c
=
channels
c
=
channels
self
.
loc_embed
=
nn
.
Embedding
(
9
,
c
)
self
.
loc_embed
=
nn
.
Embedding
(
9
,
c
)
self
.
loc_norm
=
nn
.
LayerNorm
(
c
,
elementwise_affine
=
affine
)
self
.
loc_norm
=
nn
.
LayerNorm
(
c
,
elementwise_affine
=
affine
)
self
.
seq_embed
=
nn
.
Embedding
(
61
,
c
)
self
.
seq_embed
=
nn
.
Embedding
(
76
,
c
)
self
.
seq_norm
=
nn
.
LayerNorm
(
c
,
elementwise_affine
=
affine
)
self
.
seq_norm
=
nn
.
LayerNorm
(
c
,
elementwise_affine
=
affine
)
linear
=
lambda
in_features
,
out_features
:
nn
.
Linear
(
in_features
,
out_features
,
bias
=
bias
)
linear
=
lambda
in_features
,
out_features
:
nn
.
Linear
(
in_features
,
out_features
,
bias
=
bias
)
...
@@ -83,7 +83,7 @@ class Encoder(nn.Module):
...
@@ -83,7 +83,7 @@ class Encoder(nn.Module):
self
.
lp_fc_emb
=
linear
(
c_num
,
c
//
4
)
self
.
lp_fc_emb
=
linear
(
c_num
,
c
//
4
)
self
.
oppo_lp_fc_emb
=
linear
(
c_num
,
c
//
4
)
self
.
oppo_lp_fc_emb
=
linear
(
c_num
,
c
//
4
)
self
.
turn_embed
=
nn
.
Embedding
(
20
,
c
//
8
)
self
.
turn_embed
=
nn
.
Embedding
(
20
,
c
//
8
)
self
.
phase_embed
=
nn
.
Embedding
(
1
0
,
c
//
8
)
self
.
phase_embed
=
nn
.
Embedding
(
1
1
,
c
//
8
)
self
.
if_first_embed
=
nn
.
Embedding
(
2
,
c
//
8
)
self
.
if_first_embed
=
nn
.
Embedding
(
2
,
c
//
8
)
self
.
is_my_turn_embed
=
nn
.
Embedding
(
2
,
c
//
8
)
self
.
is_my_turn_embed
=
nn
.
Embedding
(
2
,
c
//
8
)
...
@@ -97,15 +97,15 @@ class Encoder(nn.Module):
...
@@ -97,15 +97,15 @@ class Encoder(nn.Module):
divisor
=
8
divisor
=
8
self
.
a_msg_embed
=
nn
.
Embedding
(
30
,
c
//
divisor
)
self
.
a_msg_embed
=
nn
.
Embedding
(
30
,
c
//
divisor
)
self
.
a_act_embed
=
nn
.
Embedding
(
1
1
,
c
//
divisor
)
self
.
a_act_embed
=
nn
.
Embedding
(
1
3
,
c
//
divisor
)
self
.
a_yesno_embed
=
nn
.
Embedding
(
3
,
c
//
divisor
)
self
.
a_yesno_embed
=
nn
.
Embedding
(
3
,
c
//
divisor
)
self
.
a_phase_embed
=
nn
.
Embedding
(
4
,
c
//
divisor
)
self
.
a_phase_embed
=
nn
.
Embedding
(
4
,
c
//
divisor
)
self
.
a_cancel_finish_embed
=
nn
.
Embedding
(
3
,
c
//
divisor
)
self
.
a_cancel_finish_embed
=
nn
.
Embedding
(
3
,
c
//
divisor
)
self
.
a_position_embed
=
nn
.
Embedding
(
9
,
c
//
divisor
)
self
.
a_position_embed
=
nn
.
Embedding
(
9
,
c
//
divisor
)
self
.
a_option_embed
=
nn
.
Embedding
(
4
,
c
//
divisor
//
2
)
self
.
a_option_embed
=
nn
.
Embedding
(
6
,
c
//
divisor
//
2
)
self
.
a_number_embed
=
nn
.
Embedding
(
13
,
c
//
divisor
//
2
)
self
.
a_number_embed
=
nn
.
Embedding
(
13
,
c
//
divisor
//
2
)
self
.
a_place_embed
=
nn
.
Embedding
(
31
,
c
//
divisor
//
2
)
self
.
a_place_embed
=
nn
.
Embedding
(
31
,
c
//
divisor
//
2
)
self
.
a_attrib_embed
=
nn
.
Embedding
(
31
,
c
//
divisor
//
2
)
self
.
a_attrib_embed
=
nn
.
Embedding
(
10
,
c
//
divisor
//
2
)
self
.
a_feat_norm
=
nn
.
LayerNorm
(
c
,
elementwise_affine
=
affine
)
self
.
a_feat_norm
=
nn
.
LayerNorm
(
c
,
elementwise_affine
=
affine
)
self
.
a_card_norm
=
nn
.
LayerNorm
(
c
,
elementwise_affine
=
False
)
self
.
a_card_norm
=
nn
.
LayerNorm
(
c
,
elementwise_affine
=
False
)
...
...
ygoai/rl/dist.py
View file @
0559e98c
...
@@ -36,7 +36,6 @@ def setup(backend, rank, world_size, port):
...
@@ -36,7 +36,6 @@ def setup(backend, rank, world_size, port):
def
mp_start
(
run
):
def
mp_start
(
run
):
mp
.
set_start_method
(
'forkserver'
)
world_size
=
int
(
os
.
getenv
(
"WORLD_SIZE"
,
"1"
))
world_size
=
int
(
os
.
getenv
(
"WORLD_SIZE"
,
"1"
))
if
world_size
==
1
:
if
world_size
==
1
:
run
(
local_rank
=
0
,
world_size
=
world_size
)
run
(
local_rank
=
0
,
world_size
=
world_size
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment