Commit 83777a2d authored by sbl1996@126.com

Add num_threads

parent 878f6a9e
...@@ -1144,13 +1144,13 @@ class Predictor: ...@@ -1144,13 +1144,13 @@ class Predictor:
return self.predict_fn(self.loaded, rstate, sample_obs) return self.predict_fn(self.loaded, rstate, sample_obs)
@staticmethod @staticmethod
def load(checkpoint): def load(checkpoint, num_threads):
sample_obs = sample_input() sample_obs = sample_input()
rstate = init_rstate() rstate = init_rstate()
if checkpoint.endswith(".flax_model"): if checkpoint.endswith(".flax_model"):
from .jax_inf import load_model, predict_fn from .jax_inf import load_model, predict_fn
elif checkpoint.endswith(".tflite"): elif checkpoint.endswith(".tflite"):
from .tflite_inf import load_model, predict_fn from .tflite_inf import load_model, predict_fn
predictor = Predictor(load_model(checkpoint, rstate, sample_obs), predict_fn) predictor = Predictor(load_model(checkpoint, rstate, sample_obs, num_threads=num_threads), predict_fn)
predictor.predict(rstate, sample_obs) predictor.predict(rstate, sample_obs)
return predictor return predictor
...@@ -35,7 +35,7 @@ def predict_fn(params, rstate, obs): ...@@ -35,7 +35,7 @@ def predict_fn(params, rstate, obs):
rstate, probs, value = get_probs_and_value(params, rstate, obs) rstate, probs, value = get_probs_and_value(params, rstate, obs)
return rstate, np.array(probs)[0].tolist(), float(np.array(value)[0]) return rstate, np.array(probs)[0].tolist(), float(np.array(value)[0])
def load_model(checkpoint, rstate, sample_obs): def load_model(checkpoint, rstate, sample_obs, **kwargs):
agent = create_agent() agent = create_agent()
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
key, agent_key = jax.random.split(key, 2) key, agent_key = jax.random.split(key, 2)
......
...@@ -29,6 +29,7 @@ class Settings(BaseSettings): ...@@ -29,6 +29,7 @@ class Settings(BaseSettings):
enable_cors: bool = Field(default=True, description="Enable CORS") enable_cors: bool = Field(default=True, description="Enable CORS")
state_expire: int = Field(default=3600, description="Duel state expire time in seconds") state_expire: int = Field(default=3600, description="Duel state expire time in seconds")
test_duel_id: str = Field(default="9654823a-23fd-4850-bb-6fec241740b0", description="Test duel id") test_duel_id: str = Field(default="9654823a-23fd-4850-bb-6fec241740b0", description="Test duel id")
ygo_num_threads: int = Field(default=1, description="Number of threads to use for YGO prediction")
settings = Settings() settings = Settings()
...@@ -55,7 +56,7 @@ async def lifespan(app: FastAPI): ...@@ -55,7 +56,7 @@ async def lifespan(app: FastAPI):
init_code_list(settings.code_list) init_code_list(settings.code_list)
checkpoint = settings.checkpoint checkpoint = settings.checkpoint
predictor = Predictor.load(checkpoint) predictor = Predictor.load(checkpoint, settings.ygo_num_threads)
all_models["default"] = predictor all_models["default"] = predictor
print(f"loaded checkpoint from {checkpoint}") print(f"loaded checkpoint from {checkpoint}")
......
...@@ -23,9 +23,10 @@ def predict_fn(interpreter, rstate, obs): ...@@ -23,9 +23,10 @@ def predict_fn(interpreter, rstate, obs):
value = float(value[0]) value = float(value[0])
return rstate, prob, value return rstate, prob, value
def load_model(checkpoint, *args): def load_model(checkpoint, *args, **kwargs):
with open(checkpoint, "rb") as f: with open(checkpoint, "rb") as f:
tflite_model = f.read() tflite_model = f.read()
interpreter = tf_lite.Interpreter(model_content=tflite_model) interpreter = tf_lite.Interpreter(
model_content=tflite_model, num_threads=kwargs.get("num_threads", 1))
interpreter.allocate_tensors() interpreter.allocate_tensors()
return interpreter return interpreter
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment