Biluo Shen / ygo-agent · Commits

Commit 4a0590bd authored May 21, 2024 by sbl1996@126.com
Rename entropy_loss to ent_loss

Parent: 04e61b91
Showing 1 changed file with 17 additions and 14 deletions.
scripts/cleanba.py  (+17, −14)
@@ -25,7 +25,7 @@ from tensorboardX import SummaryWriter
 from ygoai.utils import init_ygopro, load_embeddings
 from ygoai.rl.ckpt import ModelCheckpoint, sync_to_gcs, zip_files
-from ygoai.rl.jax.agent2 import RNNAgent, ModelArgs
+from ygoai.rl.jax.agent import RNNAgent, ModelArgs
 from ygoai.rl.jax.utils import RecordEpisodeStatistics, masked_normalize, categorical_sample
 from ygoai.rl.jax.eval import evaluate, battle
 from ygoai.rl.jax import clipped_surrogate_pg_loss, vtrace_2p0s, mse_loss, entropy_loss, simple_policy_loss, ach_loss, policy_gradient_loss
@@ -285,7 +285,7 @@ def rollout(
     avg_win_rates = deque(maxlen=1000)

     agent = create_agent(args)
-    eval_agent = create_agent(args, eval=True)
+    eval_agent = create_agent(args, eval=eval_mode != 'bot')

     @jax.jit
     def get_action(params, obs, rstate):
@@ -492,7 +492,7 @@ def rollout(
         writer.add_scalar("charts/SPS_update", SPS_update, tb_global_step)


-if __name__ == "__main__":
+def main():
     args = tyro.cli(Args)
     args.local_batch_size = int(args.local_num_envs * args.num_steps * args.num_actor_threads * len(args.actor_device_ids))
     args.local_minibatch_size = int(args.local_batch_size // args.num_minibatches)
@@ -796,13 +796,13 @@ if __name__ == "__main__":
             shuffled_mask = jnp.ones_like(shuffled_storage.mains)

             def update_minibatch(agent_state, minibatch):
-                (loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = loss_grad_fn(
+                (loss, (pg_loss, v_loss, ent_loss, approx_kl)), grads = loss_grad_fn(
                     agent_state.params, *minibatch)
                 grads = jax.lax.pmean(grads, axis_name="local_devices")
                 agent_state = agent_state.apply_gradients(grads=grads)
-                return agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl)
+                return agent_state, (loss, pg_loss, v_loss, ent_loss, approx_kl)

-            agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl) = jax.lax.scan(
+            agent_state, (loss, pg_loss, v_loss, ent_loss, approx_kl) = jax.lax.scan(
                 update_minibatch, agent_state, (
@@ -819,17 +819,17 @@ if __name__ == "__main__":
                     shuffled_next_value,
                 ),
             )
-            return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl)
+            return (agent_state, key), (loss, pg_loss, v_loss, ent_loss, approx_kl)

-        (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl) = jax.lax.scan(
+        (agent_state, key), (loss, pg_loss, v_loss, ent_loss, approx_kl) = jax.lax.scan(
             update_epoch, (agent_state, key), (), length=args.update_epochs)

         loss = jax.lax.pmean(loss, axis_name="local_devices").mean()
         pg_loss = jax.lax.pmean(pg_loss, axis_name="local_devices").mean()
         v_loss = jax.lax.pmean(v_loss, axis_name="local_devices").mean()
-        entropy_loss = jax.lax.pmean(entropy_loss, axis_name="local_devices").mean()
+        ent_loss = jax.lax.pmean(ent_loss, axis_name="local_devices").mean()
         approx_kl = jax.lax.pmean(approx_kl, axis_name="local_devices").mean()
-        return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key
+        return agent_state, loss, pg_loss, v_loss, ent_loss, approx_kl, key

     all_reduce_value = jax.pmap(lambda x: jax.lax.pmean(x, axis_name="main_devices"),
@@ -872,7 +872,6 @@ if __name__ == "__main__":
             params_queues[-1].put(device_params)

     rollout_queue_get_time = deque(maxlen=10)
-    data_transfer_time = deque(maxlen=10)
     learner_policy_version = 0
     while True:
         learner_policy_version += 1
@@ -905,7 +904,7 @@ if __name__ == "__main__":
         rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
         training_time_start = time.time()
-        (agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, learner_keys) = multi_device_update(
+        (agent_state, loss, pg_loss, v_loss, ent_loss, approx_kl, learner_keys) = multi_device_update(
             agent_state,
             *list(zip(*sharded_data_list)),
             learner_keys,
@@ -943,7 +942,7 @@ if __name__ == "__main__":
             )
             writer.add_scalar("losses/value_loss", v_loss[-1].item(), tb_global_step)
             writer.add_scalar("losses/policy_loss", pg_loss[-1].item(), tb_global_step)
-            writer.add_scalar("losses/entropy", entropy_loss[-1].item(), tb_global_step)
+            writer.add_scalar("losses/entropy", ent_loss[-1].item(), tb_global_step)
             writer.add_scalar("losses/approx_kl", approx_kl[-1].item(), tb_global_step)
             writer.add_scalar("losses/loss", loss, tb_global_step)
@@ -965,4 +964,8 @@ if __name__ == "__main__":
     if args.distributed:
         jax.distributed.shutdown()

-    writer.close()
\ No newline at end of file
+    writer.close()
+
+
+if __name__ == "__main__":
+    main()
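
Note on the change: the import line `from ygoai.rl.jax import ..., entropy_loss, ...` is kept as-is, so the rename appears to be about no longer shadowing the imported `entropy_loss` helper with the local loss value, which is now called `ent_loss` throughout the training loop. The renamed value is carried out of `jax.lax.scan`, which stacks one entry per scanned minibatch; that is why the logging code reads `ent_loss[-1].item()`. Below is a minimal, self-contained sketch of that carry-and-stack pattern; the stand-in losses and the toy update are hypothetical and only illustrate the shape of the data, not the repository's `loss_grad_fn` or `apply_gradients` logic.

# Minimal sketch (not the repository's code): per-minibatch losses carried
# through jax.lax.scan, with the entropy term named ent_loss as in this commit.
import jax
import jax.numpy as jnp

def update_minibatch(params, minibatch):
    # Hypothetical stand-ins for the values loss_grad_fn returns in cleanba.py.
    pg_loss = jnp.mean(minibatch)       # pretend policy-gradient loss
    ent_loss = jnp.var(minibatch)       # pretend entropy term (renamed from entropy_loss)
    params = params - 0.01 * pg_loss    # pretend parameter update
    return params, (pg_loss, ent_loss)

params = jnp.array(1.0)
minibatches = jnp.arange(12.0).reshape(4, 3)    # 4 minibatches of 3 values each
params, (pg_loss, ent_loss) = jax.lax.scan(update_minibatch, params, minibatches)
print(pg_loss.shape, ent_loss.shape)    # (4,), (4,): one entry per scanned minibatch
print(ent_loss[-1].item())              # last minibatch's value, as logged to TensorBoard

Keeping the variable name distinct from the imported function also leaves `entropy_loss(...)` callable later in the same scope, which is presumably the motivation for the rename.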