Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
b7d52f29
Commit
b7d52f29
authored
May 12, 2024
by
sbl1996@126.com
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add option for no_rnn
parent
a1e6193c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
8 deletions
+15
-8
scripts/cleanba.py
scripts/cleanba.py
+5
-2
ygoai/rl/jax/agent2.py
ygoai/rl/jax/agent2.py
+10
-6
No files found.
scripts/cleanba.py
View file @
b7d52f29
...
@@ -80,6 +80,8 @@ class Args:
...
@@ -80,6 +80,8 @@ class Args:
"""whether to use history actions as input for agent"""
"""whether to use history actions as input for agent"""
eval_use_history
:
bool
=
True
eval_use_history
:
bool
=
True
"""whether to use history actions as input for eval agent"""
"""whether to use history actions as input for eval agent"""
use_rnn
:
bool
=
True
"""whether to use RNN for the agent"""
total_timesteps
:
int
=
50000000000
total_timesteps
:
int
=
50000000000
"""total timesteps of the experiments"""
"""total timesteps of the experiments"""
...
@@ -231,6 +233,7 @@ def create_agent(args, multi_step=False, eval=False):
...
@@ -231,6 +233,7 @@ def create_agent(args, multi_step=False, eval=False):
multi_step
=
multi_step
,
multi_step
=
multi_step
,
freeze_id
=
args
.
freeze_id
,
freeze_id
=
args
.
freeze_id
,
use_history
=
args
.
use_history
if
not
eval
else
args
.
eval_use_history
,
use_history
=
args
.
use_history
if
not
eval
else
args
.
eval_use_history
,
no_rnn
=
(
not
args
.
use_rnn
)
if
not
eval
else
False
)
)
...
@@ -318,8 +321,8 @@ def rollout(
...
@@ -318,8 +321,8 @@ def rollout(
rstate
=
jax
.
tree
.
map
(
rstate
=
jax
.
tree
.
map
(
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x1
,
x2
),
rstate1
,
rstate2
)
lambda
x1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x1
,
x2
),
rstate1
,
rstate2
)
rstate
,
logits
=
get_logits
(
params
,
(
rstate
,
next_obs
))
rstate
,
logits
=
get_logits
(
params
,
(
rstate
,
next_obs
))
rstate1
=
jax
.
tree
.
map
(
lambda
x
,
y
:
jnp
.
where
(
main
[:,
None
],
x
,
y
),
rstate
,
rstate1
)
rstate1
=
jax
.
tree
.
map
(
lambda
x
1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x1
,
x2
),
rstate
,
rstate1
)
rstate2
=
jax
.
tree
.
map
(
lambda
x
,
y
:
jnp
.
where
(
main
[:,
None
],
y
,
x
),
rstate
,
rstate2
)
rstate2
=
jax
.
tree
.
map
(
lambda
x
1
,
x2
:
jnp
.
where
(
main
[:,
None
],
x2
,
x1
),
rstate
,
rstate2
)
rstate1
,
rstate2
=
jax
.
tree
.
map
(
rstate1
,
rstate2
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
where
(
done
[:,
None
],
0
,
x
),
(
rstate1
,
rstate2
))
lambda
x
:
jnp
.
where
(
done
[:,
None
],
0
,
x
),
(
rstate1
,
rstate2
))
action
,
key
=
categorical_sample
(
logits
,
key
)
action
,
key
=
categorical_sample
(
logits
,
key
)
...
...
ygoai/rl/jax/agent2.py
View file @
b7d52f29
...
@@ -320,6 +320,7 @@ class LSTMAgent(nn.Module):
...
@@ -320,6 +320,7 @@ class LSTMAgent(nn.Module):
switch
:
bool
=
True
switch
:
bool
=
True
freeze_id
:
bool
=
False
freeze_id
:
bool
=
False
use_history
:
bool
=
True
use_history
:
bool
=
True
no_rnn
:
bool
=
False
@
nn
.
compact
@
nn
.
compact
def
__call__
(
self
,
inputs
):
def
__call__
(
self
,
inputs
):
...
@@ -366,18 +367,21 @@ class LSTMAgent(nn.Module):
...
@@ -366,18 +367,21 @@ class LSTMAgent(nn.Module):
scan
=
nn
.
scan
(
scan
=
nn
.
scan
(
body_fn
,
variable_broadcast
=
'params'
,
body_fn
,
variable_broadcast
=
'params'
,
split_rngs
=
{
'params'
:
False
})
split_rngs
=
{
'params'
:
False
})
f_state
,
done
,
switch_or_main
=
jax
.
tree
.
map
(
f_state
_r
,
done
,
switch_or_main
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
reshape
(
x
,
(
num_steps
,
batch_size
)
+
x
.
shape
[
1
:]),
(
f_state
,
done
,
switch_or_main
))
lambda
x
:
jnp
.
reshape
(
x
,
(
num_steps
,
batch_size
)
+
x
.
shape
[
1
:]),
(
f_state
,
done
,
switch_or_main
))
rstate
,
f_state
=
scan
(
lstm_layer
,
(
rstate1
,
rstate2
),
f_state
,
done
,
switch_or_main
)
rstate
,
f_state
_r
=
scan
(
lstm_layer
,
(
rstate1
,
rstate2
),
f_state_r
,
done
,
switch_or_main
)
f_state
=
f_state
.
reshape
((
-
1
,
f_state
.
shape
[
-
1
]))
f_state
_r
=
f_state_r
.
reshape
((
-
1
,
f_state_r
.
shape
[
-
1
]))
else
:
else
:
rstate
,
f_state
=
lstm_layer
(
rstate
,
f_state
)
rstate
,
f_state
_r
=
lstm_layer
(
rstate
,
f_state
)
actor
=
Actor
(
actor
=
Actor
(
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)
channels
=
c
,
dtype
=
jnp
.
float32
,
param_dtype
=
self
.
param_dtype
)
critic
=
Critic
(
critic
=
Critic
(
channels
=
[
c
,
c
,
c
],
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
channels
=
[
c
,
c
,
c
],
dtype
=
self
.
dtype
,
param_dtype
=
self
.
param_dtype
)
logits
=
actor
(
f_state
,
f_actions
,
mask
)
if
self
.
no_rnn
:
value
=
critic
(
f_state
)
f_state_r
=
jnp
.
concatenate
([
f_state
for
i
in
range
(
self
.
lstm_channels
//
c
)],
axis
=-
1
)
logits
=
actor
(
f_state_r
,
f_actions
,
mask
)
value
=
critic
(
f_state_r
)
return
rstate
,
logits
,
value
,
valid
return
rstate
,
logits
,
value
,
valid
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment