Commit ad9c3c34 authored by biluo.shen

Add turn feature

parent 0c7cfc92
@@ -17,7 +17,7 @@
 ## Global
 - lp: 2, max 65535 to 2 bytes
 - oppo_lp: 2, max 65535 to 2 bytes
-<!-- - turn: 8, int, trunc to 8 -->
+- turn: 1, int, trunc to 8
 - phase: 1, int, one-hot (10)
 - is_first: 1, int, 0: False, 1: True
 - is_my_turn: 1, int, 0: False, 1: True
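The new `turn` byte makes the global feature vector 9 bytes wide. Below is a minimal sketch of how the documented fields line up, mirroring the `_set_obs_global` change further down; only the field order and the `min(turn, 8)` truncation come from this commit, while the helper name and the exact high/low byte split used for the LP values are assumptions for illustration.

```python
import numpy as np

def encode_global_sketch(lp, oppo_lp, turn, phase_id, is_first, is_my_turn):
    # Hypothetical helper, not part of the repo: packs the documented fields
    # into the 9-byte uint8 vector consumed by the agent's global encoder.
    # Index 8 is left for the "no legal options" flag set by the env.
    feat = np.zeros(9, dtype=np.uint8)
    lp = min(int(lp), 65535)
    oppo_lp = min(int(oppo_lp), 65535)
    feat[0], feat[1] = lp >> 8, lp & 0xFF            # lp: 2 bytes (byte split assumed)
    feat[2], feat[3] = oppo_lp >> 8, oppo_lp & 0xFF  # oppo_lp: 2 bytes
    feat[4] = min(int(turn), 8)                      # turn: truncated to 8
    feat[5] = phase_id                               # phase: id in [0, 10)
    feat[6] = 1 if is_first else 0
    feat[7] = 1 if is_my_turn else 0
    return feat
```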
@@ -15,6 +15,7 @@ import tyro
 from ygoai.utils import init_ygopro
 from ygoai.rl.utils import RecordEpisodeStatistics
 from ygoai.rl.agent import Agent
+from ygoai.rl.buffer import create_obs
 
 
 @dataclass
@@ -143,7 +144,7 @@ if __name__ == "__main__":
         state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
         agent.load_state_dict(state_dict)
 
-    obs = optree.tree_map(lambda x: torch.from_numpy(x).to(device=device), envs.reset()[0])
+    obs = create_obs(envs.observation_space, (num_envs,), device=device)
    with torch.no_grad():
         traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False)
         agent = torch.jit.optimize_for_inference(traced_model)
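Since the traced graph is now built from the zero-filled observations returned by `create_obs` rather than a real reset, a quick check that the dummy batch matches the env output in keys, shapes, and dtypes can catch spec mismatches (such as the 8-to-9 byte change to `global_` below) before they surface as silent tracing errors. A rough sketch using the same `envs` and `obs` names as the script; it is not part of this commit.

```python
import numpy as np

real_obs = envs.reset()[0]                       # dict of numpy arrays from the vector env
for key, dummy in obs.items():                   # obs comes from create_obs(...)
    real = real_obs[key]
    assert tuple(dummy.shape) == tuple(real.shape), (key, dummy.shape, real.shape)
    assert str(dummy.dtype).replace("torch.", "") == str(real.dtype), (key, dummy.dtype, real.dtype)
```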
@@ -78,7 +78,8 @@ class Agent(nn.Module):
         self.lp_fc_emb = linear(c_num, c // 4)
         self.oppo_lp_fc_emb = linear(c_num, c // 4)
-        self.phase_embed = nn.Embedding(10, c // 4)
+        self.turn_embed = nn.Embedding(20, c // 8)
+        self.phase_embed = nn.Embedding(10, c // 8)
         self.if_first_embed = nn.Embedding(2, c // 8)
         self.is_my_turn_embed = nn.Embedding(2, c // 8)
@@ -231,11 +232,12 @@ class Agent(nn.Module):
         x_g_oppo_lp = self.oppo_lp_fc_emb(self.num_transform(x_global_1[:, 2:4]))
 
         x_global_2 = x[:, 4:-1].long()
-        x_g_phase = self.phase_embed(x_global_2[:, 0])
-        x_g_if_first = self.if_first_embed(x_global_2[:, 1])
-        x_g_is_my_turn = self.is_my_turn_embed(x_global_2[:, 2])
+        x_g_turn = self.turn_embed(x_global_2[:, 0])
+        x_g_phase = self.phase_embed(x_global_2[:, 1])
+        x_g_if_first = self.if_first_embed(x_global_2[:, 2])
+        x_g_is_my_turn = self.is_my_turn_embed(x_global_2[:, 3])
 
-        x_global = torch.cat([x_g_lp, x_g_oppo_lp, x_g_phase, x_g_if_first, x_g_is_my_turn], dim=-1)
+        x_global = torch.cat([x_g_lp, x_g_oppo_lp, x_g_turn, x_g_phase, x_g_if_first, x_g_is_my_turn], dim=-1)
         return x_global
 
     def forward(self, x):
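Note the channel budget of the concatenated global embedding: shrinking the phase embedding from `c // 4` to `c // 8` exactly pays for the new `c // 8` turn embedding, so `x_global` keeps width `c` and downstream layers are unchanged. The turn index is clamped to 8 on the env side, so the 20-row `turn_embed` table has headroom. A quick check of the arithmetic; the value of `c` here is only illustrative.

```python
c = 128  # illustrative; any multiple of 8 works
old_width = c // 4 + c // 4 + c // 4 + c // 8 + c // 8              # lp, oppo_lp, phase, if_first, is_my_turn
new_width = c // 4 + c // 4 + c // 8 + c // 8 + c // 8 + c // 8     # lp, oppo_lp, turn, phase, if_first, is_my_turn
assert old_width == new_width == c
```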
@@ -308,6 +310,6 @@ class Agent(nn.Module):
         f_actions = self.action_norm(f_actions)
         values = self.value_head(f_actions)[..., 0]
-        values = torch.tanh(values)
-        values = torch.where(mask, torch.full_like(values, -1.01), values)
+        # values = torch.tanh(values)
+        values = torch.where(mask, torch.full_like(values, -10), values)
         return values, valid
\ No newline at end of file
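With the `tanh` removed, action values are no longer squashed into (-1, 1), so the old sentinel of -1.01 for masked (padding) actions is no longer guaranteed to sit below every real value; -10 presumably serves the same role with a larger margin. A small standalone illustration of how the fill interacts with greedy action selection, not the repo's own selection code.

```python
import torch

values = torch.tensor([[0.3, -0.8, 1.7]])
mask = torch.tensor([[False, True, False]])             # True marks a padded / invalid action
values = torch.where(mask, torch.full_like(values, -10), values)
best = values.argmax(dim=-1)                            # tensor([2]); the masked index 1 can never win
```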
@@ -510,6 +510,16 @@ class DMCBuffer:
         return data
 
 
+def create_obs(observation_space: spaces.Dict, shape: Tuple[int, ...], device: Union[th.device, str] = "cpu"):
+    obs_shape = get_obs_shape(observation_space)
+    obs = {
+        key: th.zeros(
+            (*shape, *_obs_shape),
+            dtype=dtype_dict[observation_space[key].dtype.type], device=device)
+        for key, _obs_shape in obs_shape.items()
+    }
+    return obs
+
+
 class DMCDictBuffer:
     observation_space: spaces.Dict
     obs_shape: Dict[str, Tuple[int, ...]]
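`create_obs` builds a batch of zero-filled tensors keyed like the `Dict` observation space, which is exactly what `torch.jit.trace` needs in the eval script above. A toy usage sketch: the space below is made up and much smaller than the real YGO observation space, and it assumes gymnasium-style spaces and that the module's `dtype_dict` covers `uint8`.

```python
import numpy as np
import torch as th
from gymnasium import spaces

from ygoai.rl.buffer import create_obs

space = spaces.Dict({
    "global_": spaces.Box(low=0, high=255, shape=(9,), dtype=np.uint8),
    "cards_": spaces.Box(low=0, high=255, shape=(160, 39), dtype=np.uint8),
})
obs = create_obs(space, (4,), device="cpu")      # leading batch shape of 4
print({k: (tuple(v.shape), v.dtype) for k, v in obs.items()})
# e.g. {'global_': ((4, 9), torch.uint8), 'cards_': ((4, 160, 39), torch.uint8)}
```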
@@ -1188,7 +1188,7 @@
     int n_action_feats = 9 + conf["max_multi_select"_] * 2;
     return MakeDict(
         "obs:cards_"_.Bind(Spec<uint8_t>({conf["max_cards"_] * 2, 39})),
-        "obs:global_"_.Bind(Spec<uint8_t>({8})),
+        "obs:global_"_.Bind(Spec<uint8_t>({9})),
         "obs:actions_"_.Bind(
             Spec<uint8_t>({conf["max_options"_], n_action_feats})),
         "obs:h_actions_"_.Bind(
@@ -1659,9 +1659,10 @@
     feat(2) = op_lp_1;
     feat(3) = op_lp_2;
-    feat(4) = phase2id.at(current_phase_);
-    feat(5) = (me == 0) ? 1 : 0;
-    feat(6) = (me == tp_) ? 1 : 0;
+    feat(4) = std::min(turn_count_, 8);
+    feat(5) = phase2id.at(current_phase_);
+    feat(6) = (me == 0) ? 1 : 0;
+    feat(7) = (me == tp_) ? 1 : 0;
   }
 
   void _set_obs_action_spec(TArray<uint8_t> &feat, int i, int j,
@@ -1883,7 +1884,7 @@
     if (n_options == 0) {
       state["info:num_options"_] = 1;
-      state["obs:global_"_][7] = uint8_t(1);
+      state["obs:global_"_][8] = uint8_t(1);
       return;
     }