Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
cd95bf43
Commit
cd95bf43
authored
Apr 20, 2024
by
sbl1996@126.com
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Print value in eval
parent
14d32bc7
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
7 deletions
+8
-7
scripts/eval.py
scripts/eval.py
+8
-7
No files found.
scripts/eval.py
View file @
cd95bf43
...
...
@@ -157,17 +157,17 @@ if __name__ == "__main__":
params
=
jax
.
device_put
(
params
)
@
jax
.
jit
def
get_probs
(
params
,
rstate
,
obs
,
done
):
def
get_probs
_and_value
(
params
,
rstate
,
obs
,
done
):
agent
=
create_agent
(
args
)
next_rstate
,
logits
=
agent
.
apply
(
params
,
(
rstate
,
obs
))[:
2
]
next_rstate
,
logits
,
value
=
agent
.
apply
(
params
,
(
rstate
,
obs
))[:
3
]
probs
=
jax
.
nn
.
softmax
(
logits
,
axis
=-
1
)
next_rstate
=
jax
.
tree
.
map
(
lambda
x
:
jnp
.
where
(
done
[:,
None
],
0
,
x
),
next_rstate
)
return
next_rstate
,
probs
return
next_rstate
,
probs
,
value
def
predict_fn
(
rstate
,
obs
,
done
):
rstate
,
probs
=
get_probs
(
params
,
rstate
,
obs
,
done
)
return
rstate
,
np
.
array
(
probs
)
rstate
,
probs
,
value
=
get_probs_and_value
(
params
,
rstate
,
obs
,
done
)
return
rstate
,
np
.
array
(
probs
)
,
np
.
array
(
value
)
print
(
f
"loaded checkpoint from {args.checkpoint}"
)
...
...
@@ -194,9 +194,10 @@ if __name__ == "__main__":
if
args
.
checkpoint
:
_start
=
time
.
time
()
rstate
,
probs
=
predict_fn
(
rstate
,
obs
,
dones
)
rstate
,
probs
,
value
=
predict_fn
(
rstate
,
obs
,
dones
)
if
args
.
verbose
:
print
([
f
"{p:.4f}"
for
p
in
probs
[
probs
!=
0
]
.
tolist
()])
print
(
f
"probs: {[f'{p:.4f}' for p in probs[probs != 0].tolist()]}"
)
print
(
f
"value: {value[0][0]}"
)
actions
=
probs
.
argmax
(
axis
=
1
)
model_time
+=
time
.
time
()
-
_start
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment