Commit 14bceecd authored by sbl1996@126.com's avatar sbl1996@126.com

Add statistics of average step time

parent c2417798
......@@ -43,6 +43,8 @@ class Args:
"""seed of the experiment"""
log_frequency: int = 10
"""the logging frequency of the model performance (in terms of `updates`)"""
time_log_freq: int = 1000
"""the logging frequency of the deck time statistics"""
save_interval: int = 400
"""the frequency of saving the model (in terms of `updates`)"""
checkpoint: Optional[str] = None
......@@ -181,6 +183,7 @@ class Args:
learner_devices: Optional[List[str]] = None
num_embeddings: Optional[int] = None
freeze_id: Optional[bool] = None
deck_names: Optional[List[str]] = None
def make_env(args, seed, num_envs, num_threads, mode='self', thread_affinity_offset=-1, eval=False):
......@@ -317,6 +320,11 @@ def rollout(
action, key = categorical_sample(logits, key)
return next_obs, done, main, rstate1, rstate2, action, logits, key
deck_names = args.deck_names
deck_avg_times = {name: 0 for name in deck_names}
deck_max_times = {name: 0 for name in deck_names}
deck_time_count = {name: 0 for name in deck_names}
# put data in the last index
params_queue_get_time = deque(maxlen=10)
rollout_time = deque(maxlen=10)
......@@ -410,6 +418,21 @@ def rollout(
t.next_dones[idx] = True
t.rewards[idx] = -next_reward[idx]
break
for i in range(2):
deck_time = info['step_time'][idx][i]
deck_name = deck_names[info['deck'][idx][i]]
time_count = deck_time_count[deck_name]
avg_time = deck_avg_times[deck_name]
avg_time = avg_time * (time_count / (time_count + 1)) + deck_time / (time_count + 1)
max_time = max(deck_time, deck_max_times[deck_name])
deck_avg_times[deck_name] = avg_time
deck_max_times[deck_name] = max_time
deck_time_count[deck_name] += 1
if deck_time_count[deck_name] % args.time_log_freq == 0:
print(f"Deck {deck_name}, avg: {avg_time * 1000:.2f}, max: {max_time * 1000:.2f}")
episode_reward = info['r'][idx] * (1 if cur_main else -1)
win = 1 if episode_reward > 0 else 0
avg_ep_returns.append(episode_reward)
......@@ -584,7 +607,8 @@ def main():
learner_keys = jax.device_put_sharded(learner_keys, devices=learner_devices)
actor_keys = jax.random.split(key, len(actor_devices) * args.num_actor_threads)
deck = init_ygopro(args.env_id, "english", args.deck, args.code_list_file)
deck, deck_names = init_ygopro(args.env_id, "english", args.deck, args.code_list_file, return_deck_names=True)
args.deck_names = sorted(deck_names)
args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck
......@@ -911,11 +935,14 @@ def main():
learn_opponent,
)
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
params_queue_put_time = 0
for d_idx, d_id in enumerate(args.actor_device_ids):
device_params = jax.device_put(unreplicated_params, local_devices[d_id])
device_params["params"]["Encoder_0"]['Embed_0']["embedding"].block_until_ready()
params_queue_put_start = time.time()
for thread_id in range(args.num_actor_threads):
params_queues[d_idx * args.num_actor_threads + thread_id].put(device_params)
params_queue_put_time += time.time() - params_queue_put_start
loss = loss[-1].item()
if np.isnan(loss) or np.isinf(loss):
......@@ -935,7 +962,8 @@ def main():
print(
f"{tb_global_step} actor_update={update}, "
f"train_time={time.time() - training_time_start:.2f}, "
f"data_time={rollout_queue_get_time[-1]:.2f}"
f"data_time={rollout_queue_get_time[-1]:.2f}, "
f"put_time={params_queue_put_time:.2f}"
)
writer.add_scalar(
"charts/learning_rate", agent_state.opt_state[3][2][1].hyperparams["learning_rate"][-1].item(), tb_global_step
......
......@@ -23,7 +23,7 @@ _languages = {
"chinese": "zh",
}
def init_ygopro(env_id, lang, deck, code_list_file, preload_tokens=False):
def init_ygopro(env_id, lang, deck, code_list_file, preload_tokens=False, return_deck_names=False):
short = _languages[lang]
db_path = Path(get_root_directory(), 'assets', 'locale', short, 'cards.cdb')
deck_fp = Path(deck)
......@@ -50,6 +50,10 @@ def init_ygopro(env_id, lang, deck, code_list_file, preload_tokens=False):
elif 'EDOPro' in env_id:
from ygoenv.edopro import init_module
init_module(str(db_path), code_list_file, decks)
if return_deck_names:
if "_tokens" in decks:
del decks["_tokens"]
return deck_name, list(decks.keys())
return deck_name
......
......@@ -1245,6 +1245,7 @@ static ankerl::unordered_dense::map<std::string, std::vector<CardCode>>
static ankerl::unordered_dense::map<std::string, std::vector<CardCode>>
extra_decks_;
static std::vector<std::string> deck_names_;
static ankerl::unordered_dense::map<std::string, int> deck_names_ids_;
inline const Card &c_get_card(CardCode code) {
auto it = cards_.find(code);
......@@ -1388,6 +1389,7 @@ static void init_module(const std::string &db_path,
extra_decks_[name] = extra_deck;
if (name[0] != '_') {
deck_names_.push_back(name);
deck_names_ids_[name] = deck_names_.size() - 1;
}
preload_deck(db, main_deck);
......@@ -1532,7 +1534,10 @@ public:
"info:num_options"_.Bind(Spec<int>({}, {0, conf["max_options"_] - 1})),
"info:to_play"_.Bind(Spec<int>({}, {0, 1})),
"info:is_selfplay"_.Bind(Spec<int>({}, {0, 1})),
"info:win_reason"_.Bind(Spec<int>({}, {-1, 1})));
"info:win_reason"_.Bind(Spec<int>({}, {-1, 1})),
"info:step_time"_.Bind(Spec<double>({2})),
"info:deck"_.Bind(Spec<int>({2}))
);
}
template <typename Config>
static decltype(auto) ActionSpec(const Config &conf) {
......@@ -1655,6 +1660,10 @@ protected:
double reset_time_3_ = 0;
uint64_t reset_time_count_ = 0;
// average time for decks
ankerl::unordered_dense::map<std::string, double> deck_time_;
ankerl::unordered_dense::map<std::string, uint64_t> deck_time_count_;
const int n_history_actions_;
// circular buffer for history actions
......@@ -1756,6 +1765,14 @@ public:
(time_count + 1)) + seconds / (time_count + 1);
}
// void update_time_stat(const std::string& deck, double seconds) {
// uint64_t& time_count = deck_time_count_[deck];
// double& time_stat = deck_time_[deck];
// time_stat = time_stat * (static_cast<double>(time_count) /
// (time_count + 1)) + seconds / (time_count + 1);
// time_count++;
// }
MDuel new_duel(uint32_t seed) {
auto pduel = YGO_CreateDuel(seed);
MDuel mduel{pduel, seed};
......@@ -2177,7 +2194,7 @@ public:
}
void Step(const Action &action) override {
// clock_t start = clock();
clock_t start = clock();
int idx = action["action"_];
callback_(idx);
......@@ -2250,10 +2267,25 @@ public:
}
}
WriteState(reward, win_reason_);
// update_time_stat(start, step_time_count_, step_time_);
// step_time_count_++;
update_time_stat(start, step_time_count_, step_time_);
step_time_count_++;
double step_time = 0;
if (done_) {
step_time = step_time_;
step_time_ = 0;
step_time_count_ = 0;
}
WriteState(reward, win_reason_, step_time);
// if (done_) {
// update_time_stat(deck_name_[0], step_time_);
// update_time_stat(deck_name_[1], step_time_);
// step_time_ = 0;
// step_time_count_ = 0;
// }
// if (step_time_count_ % 3000 == 0) {
// fmt::println("Step time: {:.3f}", step_time_ * 1000);
// }
......@@ -2662,7 +2694,7 @@ private:
// ygopro-core API
void WriteState(float reward, int win_reason = 0) {
void WriteState(float reward, int win_reason = 0, double step_time = 0.0) {
State state = Allocate();
int n_options = legal_actions_.size();
......@@ -2670,6 +2702,12 @@ private:
state["info:to_play"_] = int(to_play_);
state["info:is_selfplay"_] = int(play_mode_ == kSelfPlay);
state["info:win_reason"_] = win_reason;
if (reward != 0.0) {
state["info:step_time"_][0] = step_time;
state["info:step_time"_][1] = step_time;
state["info:deck"_][0] = deck_names_ids_[deck_name_[0]];
state["info:deck"_][1] = deck_names_ids_[deck_name_[1]];
}
if (n_options == 0) {
state["info:num_options"_] = 1;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment