Biluo Shen / ygo-agent / Commits

Commit b0d45e40, authored Jun 23, 2024 by sbl1996@126.com
Parent: a1329e4f

    More bfloat16

Showing 2 changed files with 40 additions and 34 deletions:
- ygoai/rl/env.py (+1, -0)
- ygoai/rl/jax/agent.py (+39, -34)
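In short, this commit pushes bfloat16 further through the network: CardEncoder, GlobalEncoder, and the action encoder are now constructed with dtype=self.dtype instead of hard-coded jnp.float32; the shared LayerNorm and Dense partials gain dtype=self.dtype, which makes the per-call dtype= arguments redundant (they are dropped throughout); raw float inputs are cast to self.dtype rather than jnp.float32; the two critic heads move the other way and are pinned to jnp.float32; and a few temporary print(..., dtype) calls are added to trace where activations change precision. The env.py change appears to be unrelated housekeeping: EnvPreprocess now records num_envs from the wrapped env.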
ygoai/rl/env.py (+1, -0)

```diff
@@ -78,6 +78,7 @@ class EnvPreprocess(gym.Wrapper):
     def __init__(self, env, skip_mask):
         super().__init__(env)
+        self.num_envs = env.num_envs
         self.skip_mask = skip_mask

     def reset(self, **kwargs):
```
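A hedged guess at why the one-line env.py change matters, with a minimal runnable sketch (CountingWrapper and the gymnasium import are illustrative, not from the repo): newer gym/gymnasium releases no longer silently proxy arbitrary attributes through gym.Wrapper, so code that reads envs.num_envs off a wrapped vectorized env is safer when the wrapper stores the attribute explicitly.

```python
import gymnasium as gym

# Illustrative wrapper (not from the repo): store num_envs explicitly so
# downstream RL code can keep reading `envs.num_envs` even on versions
# that do not forward unknown attributes to the inner env.
class CountingWrapper(gym.Wrapper):
    def __init__(self, env):
        super().__init__(env)
        self.num_envs = getattr(env, "num_envs", 1)

envs = CountingWrapper(gym.make("CartPole-v1"))
print(envs.num_envs)  # 1 for a plain, non-vectorized env
```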
ygoai/rl/jax/agent.py (+39, -34)

```diff
@@ -90,7 +90,7 @@ class CardEncoder(nn.Module):
         assert self.version > 0
         c = self.channels
         mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
-        layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
+        layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True, dtype=self.dtype)
         embed = partial(
             nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype, embedding_init=default_embed_init)
         fc_embed = partial(nn.Dense, use_bias=False, dtype=self.dtype, param_dtype=self.param_dtype)
```
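For context on the layer_norm change above: Flax linen modules separate the computation dtype (`dtype`) from the parameter storage dtype (`param_dtype`). When `dtype` is left unset, the computation dtype is inferred by promoting the input and parameter dtypes, so a bfloat16 activation passing through a float32-parameter LayerNorm is silently upcast back to float32. Passing dtype=self.dtype keeps the output in bfloat16. A minimal sketch of the behavior:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((2, 8), dtype=jnp.bfloat16)

# Without an explicit dtype, LayerNorm promotes the bf16 input with its
# float32 scale/bias parameters and returns float32.
ln_default = nn.LayerNorm(use_scale=True, use_bias=True)
y = ln_default.apply(ln_default.init(jax.random.PRNGKey(0), x), x)
print(y.dtype)  # float32

# With dtype=bfloat16, params are still stored in float32 (param_dtype)
# but are cast down for the computation, so the output stays bfloat16.
ln_bf16 = nn.LayerNorm(use_scale=True, use_bias=True, dtype=jnp.bfloat16)
y = ln_bf16.apply(ln_bf16.init(jax.random.PRNGKey(0), x), x)
print(y.dtype)  # bfloat16
```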
```diff
@@ -100,7 +100,7 @@ class CardEncoder(nn.Module):
         num_transform = lambda x: num_fc(bytes_to_bin(x, bin_points, bin_intervals))
         x1 = x[:, :, :10].astype(jnp.int32)
-        x2 = x[:, :, 10:].astype(jnp.float32)
+        x2 = x[:, :, 10:].astype(self.dtype)
         x_loc = x1[:, :, 0]
         x_seq = x1[:, :, 1]
@@ -158,12 +158,16 @@ class CardEncoder(nn.Module):
         x_cards = jnp.concatenate(feats[1:], axis=-1)
         x_cards = mlp((c,), kernel_init=default_fc_init2)(x_cards)
         x_cards = x_cards * feats[0]
+        print("before", x_cards.dtype)
         f_cards = layer_norm()(x_cards)
+        # f_cards = f_cards.astype(self.dtype)
+        print("norm", f_cards.dtype)
         if self.oppo_info:
             x_cards_g = jnp.concatenate(feats_g[1:], axis=-1)
             x_cards_g = mlp((c,), kernel_init=default_fc_init2)(x_cards_g)
             x_cards_g = x_cards_g * feats_g[0]
             f_cards_g = layer_norm()(x_cards_g)
+            # f_cards_g = f_cards_g.astype(self.dtype)
         else:
             f_cards_g = None
         return f_cards_g, f_cards, c_mask
```
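The print("before", ...) and print("norm", ...) calls added above look like temporary debugging, but they are a cheap way to audit precision: under jax.jit, a Python print executes once at trace time, and dtype is static tracer metadata, so it reliably shows whether LayerNorm preserved bfloat16. Runtime values, by contrast, would need jax.debug.print. A small sketch of the distinction:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Runs once, while tracing: fine for static metadata like dtype/shape.
    print("traced:", x.dtype, x.shape)
    # Runs on every call, at runtime: needed for actual array values.
    jax.debug.print("first element: {}", x[0])
    return x * 2

f(jnp.ones(4, dtype=jnp.bfloat16))
```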
```diff
@@ -180,7 +184,7 @@ class GlobalEncoder(nn.Module):
         batch_size = x.shape[0]
         c = self.channels
         mlp = partial(MLP, dtype=self.dtype, param_dtype=self.param_dtype)
-        layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
+        layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True, dtype=self.dtype)
         embed = partial(
             nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype, embedding_init=default_embed_init)
         fc_embed = partial(nn.Dense, use_bias=False, dtype=self.dtype, param_dtype=self.param_dtype)
@@ -192,7 +196,7 @@ class GlobalEncoder(nn.Module):
         bin_points, bin_intervals = make_bin_params(n_bins=32)
         num_transform = lambda x: num_fc(bytes_to_bin(x, bin_points, bin_intervals))
-        x1 = x[:, :4].astype(jnp.float32)
+        x1 = x[:, :4].astype(self.dtype)
         x2 = x[:, 4:8].astype(jnp.int32)
         x3 = x[:, 8:22].astype(jnp.int32)
@@ -241,18 +245,18 @@ class Encoder(nn.Module):
         n_embed, embed_dim = self.embedding_shape
         n_embed = 1 + n_embed  # 1 (index 0) for unknown
-        layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True)
+        layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True, dtype=self.dtype)
         embed = partial(
-            nn.Embed, dtype=jnp.float32, param_dtype=self.param_dtype, embedding_init=default_embed_init)
+            nn.Embed, dtype=self.dtype, param_dtype=self.param_dtype, embedding_init=default_embed_init)
-        fc_layer = partial(nn.Dense, use_bias=False, param_dtype=self.param_dtype)
+        fc_layer = partial(nn.Dense, use_bias=False, param_dtype=self.param_dtype, dtype=self.dtype)

         id_embed = embed(n_embed, embed_dim)
         card_encoder = CardEncoder(
-            channels=c, dtype=jnp.float32, param_dtype=self.param_dtype,
+            channels=c, dtype=self.dtype, param_dtype=self.param_dtype,
             version=self.version, oppo_info=self.oppo_info)
         ActionEncoderCls = ActionEncoder if self.version == 0 else ActionEncoderV1
         action_encoder = ActionEncoderCls(
-            channels=c, dtype=jnp.float32, param_dtype=self.param_dtype)
+            channels=c, dtype=self.dtype, param_dtype=self.param_dtype)

         x_cards = x['cards_']
         x_global = x['global_']
```
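The Encoder changes above all follow one pattern: bind dtype/param_dtype once with functools.partial and let every fc_layer(...), layer_norm(...), and embed(...) call site inherit the precision policy, which is why the per-call dtype= arguments disappear in the later hunks. A self-contained sketch of the pattern (Block is illustrative, not a repo module):

```python
from functools import partial
from typing import Any

import jax
import jax.numpy as jnp
import flax.linen as nn

class Block(nn.Module):
    dtype: Any = jnp.bfloat16       # computation dtype
    param_dtype: Any = jnp.float32  # parameter storage dtype

    @nn.compact
    def __call__(self, x):
        # Bind the precision policy once; call sites stay dtype-free.
        fc_layer = partial(nn.Dense, use_bias=False,
                           dtype=self.dtype, param_dtype=self.param_dtype)
        layer_norm = partial(nn.LayerNorm, use_scale=True, use_bias=True,
                             dtype=self.dtype)
        x = fc_layer(16)(x)
        return layer_norm()(x)  # stays bfloat16 instead of upcasting

x = jnp.ones((2, 8), dtype=jnp.bfloat16)
block = Block()
y = block.apply(block.init(jax.random.PRNGKey(0), x), x)
print(y.dtype)  # bfloat16
```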
```diff
@@ -288,27 +292,26 @@ class Encoder(nn.Module):
         c_mask = None
         num_heads = max(2, c // 128)
-        for _ in range(self.num_layers):
+        for i in range(self.num_layers):
             f_cards = get_encoder_layer_cls(
                 self.noam, num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
                 f_cards, src_key_padding_mask=c_mask)
-        f_cards = layer_norm(dtype=self.dtype)(f_cards)
+        f_cards = layer_norm()(f_cards)
         f_g_card = f_cards[:, 0]
         fs_g_card.append(f_g_card)
         f_g_g_card, f_g_card = fs_g_card

         # Global
         x_global = GlobalEncoder(
-            channels=c, dtype=jnp.float32, param_dtype=self.param_dtype, version=self.version)(x_global)
+            channels=c, dtype=self.dtype, param_dtype=self.param_dtype, version=self.version)(x_global)
-        x_global = x_global.astype(self.dtype)
         if self.version == 2:
-            x_global = fc_layer(c, dtype=jnp.float32)(x_global)
+            x_global = fc_layer(c)(x_global)
             f_global = x_global + GLUMlp(c * 2, dtype=self.dtype, param_dtype=self.param_dtype)(
-                layer_norm(dtype=self.dtype)(x_global))
+                layer_norm()(x_global))
         else:
             f_global = x_global + MLP((c * 2, c * 2), dtype=self.dtype, param_dtype=self.param_dtype)(x_global)
-        f_global = fc_layer(c, dtype=self.dtype)(f_global)
+        f_global = fc_layer(c)(f_global)
-        f_global = layer_norm(dtype=self.dtype)(f_global)
+        f_global = layer_norm()(f_global)

         # History actions
         x_h_actions = x_h_actions.astype(jnp.int32)
@@ -321,7 +324,7 @@ class Encoder(nn.Module):
             if self.freeze_id:
                 x_h_id = jax.lax.stop_gradient(x_h_id)
             x_h_id = MLP(
-                (c, c), dtype=jnp.float32, param_dtype=self.param_dtype,
+                (c, c), dtype=self.dtype, param_dtype=self.param_dtype,
                 kernel_init=default_fc_init2)(x_h_id)
             x_h_a_feats1 = action_encoder(x_h_actions[:, :, 2:13])
@@ -331,13 +334,13 @@ class Encoder(nn.Module):
             x_h_a_feats = jnp.concatenate([
                 *x_h_a_feats1, x_h_a_player, x_h_a_turn], axis=-1)
-            f_h_actions = layer_norm()(x_h_id) + layer_norm()(fc_layer(c, dtype=jnp.float32)(x_h_a_feats))
+            f_h_actions = layer_norm()(x_h_id) + layer_norm()(fc_layer(c)(x_h_a_feats))
             f_h_actions = PositionalEncoding()(f_h_actions)
             for _ in range(self.num_layers):
                 f_h_actions = EncoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
                     f_h_actions, src_key_padding_mask=h_mask)
-            f_g_h_actions = layer_norm(dtype=self.dtype)(f_h_actions[:, 0])
+            f_g_h_actions = layer_norm()(f_h_actions[:, 0])
         else:
             h_mask = x_h_actions[:, :, 3] == 0  # msg == 0
             h_mask = h_mask.at[:, 0].set(False)
@@ -347,7 +350,7 @@ class Encoder(nn.Module):
             if self.freeze_id:
                 x_h_id = jax.lax.stop_gradient(x_h_id)
-            x_h_id = fc_layer(c, dtype=jnp.float32)(x_h_id)
+            x_h_id = fc_layer(c)(x_h_id)
             x_h_a_feats = action_encoder(x_h_actions[:, :, 3:12])
             x_h_a_turn = embed(20, c // 2)(x_h_actions[:, :, 12])
@@ -355,7 +358,7 @@ class Encoder(nn.Module):
             x_h_a_feats.extend([x_h_id, x_h_a_turn, x_h_a_phase])
             x_h_a_feats = jnp.concatenate(x_h_a_feats, axis=-1)
             x_h_a_feats = layer_norm()(x_h_a_feats)
-            x_h_a_feats = fc_layer(c, dtype=self.dtype)(x_h_a_feats)
+            x_h_a_feats = fc_layer(c)(x_h_a_feats)
             if self.noam:
                 f_h_actions = LlamaEncoderLayer(
@@ -365,7 +368,7 @@ class Encoder(nn.Module):
                 x_h_a_feats = PositionalEncoding()(x_h_a_feats)
                 f_h_actions = EncoderLayer(num_heads, dtype=self.dtype, param_dtype=self.param_dtype)(
                     x_h_a_feats, src_key_padding_mask=h_mask)
-            f_g_h_actions = layer_norm(dtype=self.dtype)(f_h_actions[:, 0])
+            f_g_h_actions = layer_norm()(f_h_actions[:, 0])

         # Actions
@@ -382,12 +385,12 @@ class Encoder(nn.Module):
             spec_index = decode_id(x_actions[..., :2])
             B = jnp.arange(batch_size)
             f_a_cards = f_cards[B[:, None], spec_index]
-            f_a_cards = fc_layer(c, dtype=self.dtype)(f_a_cards)
+            f_a_cards = fc_layer(c)(f_a_cards)
             x_a_feats = jnp.concatenate(action_encoder(x_actions[..., 2:]), axis=-1)
-            x_a_feats = fc_layer(c, dtype=self.dtype)(x_a_feats)
+            x_a_feats = fc_layer(c)(x_a_feats)
             f_actions = jnp.concatenate([f_a_cards, x_a_feats], axis=-1)
-            f_actions = fc_layer(c, dtype=self.dtype)(nn.leaky_relu(f_actions, negative_slope=0.1))
+            f_actions = fc_layer(c)(nn.leaky_relu(f_actions, negative_slope=0.1))
             f_actions = layer_norm(dtype=self.dtype)(f_actions)
             a_mask = x_actions[:, :, 2] == 0
@@ -408,16 +411,16 @@ class Encoder(nn.Module):
             x_a_id = id_embed(x_a_id)
             if self.freeze_id:
                 x_a_id = jax.lax.stop_gradient(x_a_id)
-            x_a_id = fc_layer(c, dtype=jnp.float32)(x_a_id)
+            x_a_id = fc_layer(c)(x_a_id)
             x_a_feats = action_encoder(x_actions[..., 3:])
             x_a_feats.append(x_a_id)
             x_a_feats = jnp.concatenate(x_a_feats, axis=-1)
             x_a_feats = layer_norm()(x_a_feats)
-            x_a_feats = fc_layer(c, dtype=self.dtype)(x_a_feats)
+            x_a_feats = fc_layer(c)(x_a_feats)
-            f_a_cards = fc_layer(c, dtype=self.dtype)(f_a_cards)
+            f_a_cards = fc_layer(c)(f_a_cards)
             f_actions = jax.nn.silu(f_a_cards) * x_a_feats
-            f_actions = fc_layer(c, dtype=self.dtype)(f_actions)
+            f_actions = fc_layer(c)(f_actions)
             f_actions = x_a_feats + f_actions
             a_mask = x_actions[:, :, 3] == 0
@@ -428,11 +431,12 @@ class Encoder(nn.Module):
             g_feats.append(f_g_h_actions)
         if self.action_feats:
-            f_actions_g = fc_layer(c, dtype=self.dtype)(f_actions)
+            f_actions_g = fc_layer(c)(f_actions)
             a_mask_ = (1 - a_mask.astype(f_actions.dtype))
             f_g_actions = (f_actions_g * a_mask_[:, :, None]).sum(axis=1)
             f_g_actions = f_g_actions / a_mask_.sum(axis=1, keepdims=True)
             g_feats.append(f_g_actions)
+            print("f_g_actions", f_g_actions.dtype)
         f_state = jnp.concatenate(g_feats, axis=-1)
         oc = self.out_channels or c
@@ -442,7 +446,8 @@ class Encoder(nn.Module):
                 dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
         else:
             f_state = MLP((c * 2, oc), dtype=self.dtype, param_dtype=self.param_dtype)(f_state)
-        f_state = layer_norm(dtype=self.dtype)(f_state)
+        print("f_state", f_state.dtype)
+        f_state = layer_norm()(f_state)
         return f_actions, f_state, f_g_g_card, a_mask, valid
```
```diff
@@ -732,7 +737,7 @@ class RNNAgent(nn.Module):
         CriticCls = CrossCritic if self.batch_norm else Critic
         cs = [self.critic_width] * self.critic_depth
         critic = CriticCls(
-            channels=cs, dtype=self.dtype, param_dtype=self.param_dtype)
+            channels=cs, dtype=jnp.float32, param_dtype=self.param_dtype)

         if self.oppo_info:
             if not multi_step:
                 if isinstance(rstate[0], tuple):
```
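Note the direction of this change: while the encoders move to self.dtype (bfloat16), the critic is pinned back to jnp.float32. That matches common mixed-precision practice: value regression is sensitive to bfloat16's short mantissa, so the value head (and usually the loss) is kept in full precision. A hypothetical sketch of such a head (ValueHead is illustrative; the repo's actual Critic/CrossCritic are only referenced here, not shown):

```python
from typing import Any, Sequence

import flax.linen as nn
import jax.numpy as jnp

class ValueHead(nn.Module):
    """Illustrative float32 critic head for a bfloat16 backbone."""
    channels: Sequence[int]
    dtype: Any = jnp.float32
    param_dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, f_state):
        x = f_state  # bfloat16 features from the encoder
        for c in self.channels:
            # With dtype=float32, Flax promotes the bf16 input to float32
            # for the matmul, so the head computes in full precision.
            x = nn.relu(nn.Dense(c, dtype=self.dtype,
                                 param_dtype=self.param_dtype)(x))
        return nn.Dense(1, dtype=self.dtype, param_dtype=self.param_dtype)(x)
```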
```diff
@@ -754,7 +759,7 @@ class RNNAgent(nn.Module):
         if self.int_head:
             cs = [self.critic_width] * self.critic_depth
             critic_int = Critic(
-                channels=cs, dtype=self.dtype, param_dtype=self.param_dtype)
+                channels=cs, dtype=jnp.float32, param_dtype=self.param_dtype)
             value_int = critic_int(f_state_r)
             value = (value, value_int)
         return rstate, logits, value, valid
```
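A quick way to verify a change like this without reading every layer: jax.tree_util maps over the parameter tree to confirm storage stayed float32, while jax.eval_shape reports the output dtype without running any computation. Sketch with a stand-in module (not the repo's RNNAgent):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

model = nn.Dense(4, dtype=jnp.bfloat16, param_dtype=jnp.float32)
x = jnp.ones((1, 3), dtype=jnp.bfloat16)
variables = model.init(jax.random.PRNGKey(0), x)

# Parameters should still be stored in float32 ...
print(jax.tree_util.tree_map(lambda p: str(p.dtype), variables))

# ... while the forward pass stays bfloat16 end to end.
out = jax.eval_shape(lambda v, a: model.apply(v, a), variables, x)
print(out.dtype)  # bfloat16
```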