Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Y
ygo-agent
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Biluo Shen
ygo-agent
Commits
83777a2d
Commit
83777a2d
authored
Jul 19, 2024
by
sbl1996@126.com
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add num_threads
parent
878f6a9e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
8 additions
and
6 deletions
+8
-6
ygoinf/ygoinf/features.py
ygoinf/ygoinf/features.py
+2
-2
ygoinf/ygoinf/jax_inf.py
ygoinf/ygoinf/jax_inf.py
+1
-1
ygoinf/ygoinf/server.py
ygoinf/ygoinf/server.py
+2
-1
ygoinf/ygoinf/tflite_inf.py
ygoinf/ygoinf/tflite_inf.py
+3
-2
No files found.
ygoinf/ygoinf/features.py
View file @
83777a2d
...
...
@@ -1144,13 +1144,13 @@ class Predictor:
return
self
.
predict_fn
(
self
.
loaded
,
rstate
,
sample_obs
)
@
staticmethod
def
load
(
checkpoint
):
def
load
(
checkpoint
,
num_threads
):
sample_obs
=
sample_input
()
rstate
=
init_rstate
()
if
checkpoint
.
endswith
(
".flax_model"
):
from
.jax_inf
import
load_model
,
predict_fn
elif
checkpoint
.
endswith
(
".tflite"
):
from
.tflite_inf
import
load_model
,
predict_fn
predictor
=
Predictor
(
load_model
(
checkpoint
,
rstate
,
sample_obs
),
predict_fn
)
predictor
=
Predictor
(
load_model
(
checkpoint
,
rstate
,
sample_obs
,
num_threads
=
num_threads
),
predict_fn
)
predictor
.
predict
(
rstate
,
sample_obs
)
return
predictor
ygoinf/ygoinf/jax_inf.py
View file @
83777a2d
...
...
@@ -35,7 +35,7 @@ def predict_fn(params, rstate, obs):
rstate
,
probs
,
value
=
get_probs_and_value
(
params
,
rstate
,
obs
)
return
rstate
,
np
.
array
(
probs
)[
0
]
.
tolist
(),
float
(
np
.
array
(
value
)[
0
])
def
load_model
(
checkpoint
,
rstate
,
sample_obs
):
def
load_model
(
checkpoint
,
rstate
,
sample_obs
,
**
kwargs
):
agent
=
create_agent
()
key
=
jax
.
random
.
PRNGKey
(
0
)
key
,
agent_key
=
jax
.
random
.
split
(
key
,
2
)
...
...
ygoinf/ygoinf/server.py
View file @
83777a2d
...
...
@@ -29,6 +29,7 @@ class Settings(BaseSettings):
enable_cors
:
bool
=
Field
(
default
=
True
,
description
=
"Enable CORS"
)
state_expire
:
int
=
Field
(
default
=
3600
,
description
=
"Duel state expire time in seconds"
)
test_duel_id
:
str
=
Field
(
default
=
"9654823a-23fd-4850-bb-6fec241740b0"
,
description
=
"Test duel id"
)
ygo_num_threads
:
int
=
Field
(
default
=
1
,
description
=
"Number of threads to use for YGO prediction"
)
settings
=
Settings
()
...
...
@@ -55,7 +56,7 @@ async def lifespan(app: FastAPI):
init_code_list
(
settings
.
code_list
)
checkpoint
=
settings
.
checkpoint
predictor
=
Predictor
.
load
(
checkpoint
)
predictor
=
Predictor
.
load
(
checkpoint
,
settings
.
ygo_num_threads
)
all_models
[
"default"
]
=
predictor
print
(
f
"loaded checkpoint from {checkpoint}"
)
...
...
ygoinf/ygoinf/tflite_inf.py
View file @
83777a2d
...
...
@@ -23,9 +23,10 @@ def predict_fn(interpreter, rstate, obs):
value
=
float
(
value
[
0
])
return
rstate
,
prob
,
value
def
load_model
(
checkpoint
,
*
args
):
def
load_model
(
checkpoint
,
*
args
,
**
kwargs
):
with
open
(
checkpoint
,
"rb"
)
as
f
:
tflite_model
=
f
.
read
()
interpreter
=
tf_lite
.
Interpreter
(
model_content
=
tflite_model
)
interpreter
=
tf_lite
.
Interpreter
(
model_content
=
tflite_model
,
num_threads
=
kwargs
.
get
(
"num_threads"
,
1
))
interpreter
.
allocate_tensors
()
return
interpreter
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment