Commit f6ad8563 authored by sbl1996@126.com's avatar sbl1996@126.com

Add expire for duel states in server

parent 7888000e
...@@ -3,7 +3,7 @@ from setuptools import setup, find_packages ...@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
__version__ = "0.0.1" __version__ = "0.0.1"
INSTALL_REQUIRES = [ INSTALL_REQUIRES = [
"numpy", "numpy==1.26.4",
"optree", "optree",
"fastapi", "fastapi",
"uvicorn[standard]", "uvicorn[standard]",
......
from enum import Enum from enum import Enum
from itertools import combinations from itertools import combinations
import time
from typing import List from typing import List
...@@ -1061,6 +1062,8 @@ class PredictState: ...@@ -1061,6 +1062,8 @@ class PredictState:
self.reset() self.reset()
self._timestamp = time.time()
def reset(self): def reset(self):
self._probs = None self._probs = None
self._actions = None self._actions = None
...@@ -1080,6 +1083,7 @@ class PredictState: ...@@ -1080,6 +1083,7 @@ class PredictState:
self._action_msg = input.action_msg self._action_msg = input.action_msg
self._turn = input.global_.turn self._turn = input.global_.turn
self._phase = input.global_.phase self._phase = input.global_.phase
self._timestamp = time.time()
def revert_pad_truncate(probs, n_actions): def revert_pad_truncate(probs, n_actions):
if len(probs) < n_actions: if len(probs) < n_actions:
......
...@@ -3,6 +3,7 @@ os.environ.setdefault("JAX_PLATFORMS", "cpu") ...@@ -3,6 +3,7 @@ os.environ.setdefault("JAX_PLATFORMS", "cpu")
from typing import Union, Dict from typing import Union, Dict
import time import time
import threading
import uuid import uuid
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
...@@ -25,13 +26,28 @@ class Settings(BaseSettings): ...@@ -25,13 +26,28 @@ class Settings(BaseSettings):
code_list: str = "code_list.txt" code_list: str = "code_list.txt"
checkpoint: str = "latest.flax_model" checkpoint: str = "latest.flax_model"
enable_cors: bool = Field(default=True, description="Enable CORS") enable_cors: bool = Field(default=True, description="Enable CORS")
state_expire: int = Field(default=3600, description="Duel state expire time in seconds")
test_duel_id: str = Field(default="9654823a-23fd-4850-bb-6fec241740b0", description="Test duel id")
settings = Settings() settings = Settings()
all_models = {} all_models = {}
duel_states: Dict[str, PredictState] = {} duel_states: Dict[str, PredictState] = {}
def delete_outdated_states():
while True:
current_time = time.time()
for k, v in list(duel_states.items()):
if k == settings.test_duel_id:
continue
if current_time - v._timestamp > settings.state_expire:
del duel_states[k]
time.sleep(600)
# Start the thread to delete outdated states
thread = threading.Thread(target=delete_outdated_states)
thread.daemon = True
thread.start()
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
...@@ -43,7 +59,7 @@ async def lifespan(app: FastAPI): ...@@ -43,7 +59,7 @@ async def lifespan(app: FastAPI):
print(f"loaded checkpoint from {checkpoint}") print(f"loaded checkpoint from {checkpoint}")
state = new_state() state = new_state()
test_duel_id = "9654823a-23fd-4850-bb-6fec241740b0" test_duel_id = settings.test_duel_id
duel_states[test_duel_id] = state duel_states[test_duel_id] = state
yield yield
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment