Commit e5e5402a authored by biluo.shen

Add more global features

parent 4b934828
@@ -19,19 +19,21 @@
 - lp: 2, max 65535 to 2 bytes
 - oppo_lp: 2, max 65535 to 2 bytes
 - n_my_decks: 1, int
-- n_my_extras:
 - n_my_hands:
-- n_my_graves:
-- n_my_removes:
 - n_my_monsters:
 - n_my_spell_traps:
+- n_my_graves:
+- n_my_removes:
+- n_my_extras:
 - n_op_decks:
-- n_op_extras:
 - n_op_hands:
-- n_op_graves:
-- n_op_removes:
 - n_op_monsters:
 - n_op_spell_traps:
+- n_op_graves:
+- n_op_removes:
+- n_op_extras:
+- n_my_hands: (another embed, to enhance)
+- n_op_hands: (another embed, to enhance)
 - turn: 1, int, trunc to 8
 - phase: 1, int, one-hot (10)
 - is_first: 1, int, 0: False, 1: True
...
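Read alongside the env change further below (obs:global_ grows from 9 to 23 bytes), the new global feature vector can be summarized as follows. This is only a sketch, not part of the commit: the index ranges are inferred from the feature list above, from encode_global (x[:, 4:8] and x[:, 8:22]) and from _set_obs_global; the per-player location order is an assumption based on the documented list order.

# Assumed layout of the 23-byte obs:global_ vector after this commit.
# The per-player count order (deck, hand, monster, spell_trap, grave,
# remove, extra) is inferred from the list above, with hand at offset 1
# of each block (matching x_global_3[:, 1] and x_global_3[:, 8]).
GLOBAL_LAYOUT = {
    (0, 2): "my lp, transformed to 2 bytes",
    (2, 4): "opponent lp, transformed to 2 bytes",
    (4, 5): "turn, truncated to 8",
    (5, 6): "phase id",
    (6, 7): "is_first",
    (7, 8): "is_my_turn",
    (8, 15): "my card counts: deck, hand, monster, spell_trap, grave, remove, extra",
    (15, 22): "opponent card counts, same order",
    (22, 23): "terminal flag, set when n_options == 0",
}

def split_global(x_global):
    """Slice a (batch, 23) uint8 global feature array by the assumed layout."""
    return {name: x_global[:, lo:hi] for (lo, hi), name in GLOBAL_LAYOUT.items()}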
import sys
import time
import os
import random
from typing import Optional, Literal
from dataclasses import dataclass
import ygoenv
import numpy as np
import optree
import tyro
from ygoai.utils import init_ygopro
from ygoai.rl.utils import RecordEpisodeStatistics
from ygoai.rl.agent import PPOAgent as Agent
from ygoai.rl.buffer import create_obs
@dataclass
class Args:
seed: int = 1
"""the random seed"""
torch_deterministic: bool = True
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda: bool = True
"""if toggled, cuda will be enabled by default"""
env_id: str = "YGOPro-v0"
"""the id of the environment"""
deck: str = "../assets/deck/OldSchool.ydk"
"""the deck file to use"""
deck1: Optional[str] = None
"""the deck file for the first player"""
deck2: Optional[str] = None
"""the deck file for the second player"""
code_list_file: str = "code_list.txt"
"""the code list file for card embeddings"""
lang: str = "english"
"""the language to use"""
max_options: int = 24
"""the maximum number of options"""
n_history_actions: int = 16
"""the number of history actions to use"""
num_embeddings: Optional[int] = None
"""the number of embeddings of the agent"""
record: bool = False
"""whether to record the game as YGOPro replays"""
num_episodes: int = 1024
"""the number of episodes to run"""
num_envs: int = 64
"""the number of parallel game environments"""
verbose: bool = False
"""whether to print debug information"""
num_layers: int = 2
"""the number of layers for the agent"""
num_channels: int = 128
"""the number of channels for the agent"""
checkpoint1: Optional[str] = "checkpoints/agent.pt"
"""the checkpoint to load for the first agent"""
checkpoint2: Optional[str] = "checkpoints/agent.pt"
"""the checkpoint to load for the second agent"""
compile: bool = True
"""if toggled, the model will be compiled"""
optimize: bool = False
"""if toggled, the model will be optimized"""
torch_threads: Optional[int] = None
"""the number of threads to use for torch, defaults to ($OMP_NUM_THREADS or 2) * world_size"""
env_threads: Optional[int] = 16
"""the number of threads to use for envpool, defaults to `num_envs`"""
def predict_step(agent, obs):
with torch.no_grad():
logits, values, _valid = agent(obs)
probs = torch.softmax(logits, dim=-1)
return probs
if __name__ == "__main__":
args = tyro.cli(Args)
if args.record:
assert args.num_envs == 1, "Recording only works with a single environment"
assert args.verbose, "Recording only works with verbose mode"
args.env_threads = min(args.env_threads or args.num_envs, args.num_envs)
args.torch_threads = args.torch_threads or int(os.getenv("OMP_NUM_THREADS", "4"))
deck = init_ygopro(args.env_id, args.lang, args.deck, args.code_list_file)
args.deck1 = args.deck1 or deck
args.deck2 = args.deck2 or deck
seed = args.seed
random.seed(seed)
np.random.seed(seed)
import torch
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic
torch.set_num_threads(args.torch_threads)
torch.set_float32_matmul_precision('high')
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
num_envs = args.num_envs
envs = ygoenv.make(
task_id=args.env_id,
env_type="gymnasium",
num_envs=num_envs,
num_threads=args.env_threads,
seed=seed,
deck1=args.deck1,
deck2=args.deck2,
player=-1,
max_options=args.max_options,
n_history_actions=args.n_history_actions,
play_mode='self',
verbose=args.verbose,
record=args.record,
)
envs.num_envs = num_envs
envs = RecordEpisodeStatistics(envs)
embedding_shape = args.num_embeddings
if embedding_shape is None:
with open(args.code_list_file, "r") as f:
code_list = f.readlines()
embedding_shape = len(code_list)
L = args.num_layers
agent1 = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
agent2 = Agent(args.num_channels, L, L, 1, embedding_shape).to(device)
for agent, ckpt in zip([agent1, agent2], [args.checkpoint1, args.checkpoint2]):
state_dict = torch.load(ckpt, map_location=device)
if not args.compile:
prefix = "_orig_mod."
state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
print(agent.load_state_dict(state_dict))
if args.compile:
predict_step = torch.compile(predict_step, mode='reduce-overhead')
else:
if args.optimize:
obs = create_obs(envs.observation_space, (num_envs,), device=device)
def optimize_for_inference(agent):
with torch.no_grad():
traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False)
return torch.jit.optimize_for_inference(traced_model)
agent1 = optimize_for_inference(agent1)
agent2 = optimize_for_inference(agent2)
obs, infos = envs.reset()
next_to_play_ = infos['to_play']
episode_rewards = []
episode_lengths = []
win_rates = []
win_reasons = []
step = 0
start = time.time()
start_step = step
player1_ = np.concatenate([
np.zeros(num_envs // 2, dtype=np.int64),
np.ones(num_envs // 2, dtype=np.int64)
])
player1 = torch.from_numpy(player1_).to(device=device)
model_time = env_time = 0
while True:
if start_step == 0 and len(episode_lengths) > int(args.num_episodes * 0.1):
start = time.time()
start_step = step
model_time = env_time = 0
_start = time.time()
next_to_play = torch.from_numpy(next_to_play_).to(device=device)
obs = optree.tree_map(lambda x: torch.from_numpy(x).to(device=device), obs)
probs1 = predict_step(agent1, obs).clone()
probs2 = predict_step(agent2, obs).clone()
probs = torch.where((next_to_play == player1)[:, None], probs1, probs2)
probs = probs.cpu().numpy()
actions = probs.argmax(axis=1)
model_time += time.time() - _start
to_play = next_to_play_
_start = time.time()
obs, rewards, dones, infos = envs.step(actions)
next_to_play_ = infos['to_play']
env_time += time.time() - _start
step += 1
for idx, d in enumerate(dones):
if d:
win_reason = infos['win_reason'][idx]
pl = 1 if to_play[idx] == player1_[idx] else -1
episode_length = infos['l'][idx]
episode_reward = infos['r'][idx] * pl
win = 1 if episode_reward > 0 else 0
episode_lengths.append(episode_length)
episode_rewards.append(episode_reward)
win_rates.append(win)
win_reasons.append(1 if win_reason == 1 else 0)
sys.stderr.write(f"Episode {len(episode_lengths)}: length={episode_length}, reward={episode_reward}, win={win}, win_reason={win_reason}\n")
if len(episode_lengths) >= args.num_episodes:
break
print(f"len={np.mean(episode_lengths)}, reward={np.mean(episode_rewards)}, win_rate={np.mean(win_rates)}, win_reason={np.mean(win_reasons)}")
total_time = time.time() - start
total_steps = (step - start_step) * num_envs
print(f"SPS: {total_steps / total_time:.0f}, total_steps: {total_steps}")
print(f"total: {total_time:.4f}, model: {model_time:.4f}, env: {env_time:.4f}")
\ No newline at end of file
@@ -150,24 +150,20 @@ if __name__ == "__main__":
     # agent = agent.eval()

     if args.checkpoint:
         state_dict = torch.load(args.checkpoint, map_location=device)
-    else:
-        state_dict = None
+        if not args.compile:
+            prefix = "_orig_mod."
+            state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
+        print(agent.load_state_dict(state_dict))

     if args.compile:
-        if state_dict:
-            print(agent.load_state_dict(state_dict))
         agent = torch.compile(agent, mode='reduce-overhead')
-    else:
-        prefix = "_orig_mod."
-        if state_dict:
-            state_dict = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
-            print(agent.load_state_dict(state_dict))
-        if args.optimize:
-            obs = create_obs(envs.observation_space, (num_envs,), device=device)
+    elif args.optimize:
+        obs = create_obs(envs.observation_space, (num_envs,), device=device)
+        def optimize_for_inference(agent):
             with torch.no_grad():
                 traced_model = torch.jit.trace(agent, (obs,), check_tolerance=False, check_trace=False)
-            agent = torch.jit.optimize_for_inference(traced_model)
+            return torch.jit.optimize_for_inference(traced_model)
+        agent = optimize_for_inference(agent)

     obs, infos = envs.reset()
     next_to_play = infos['to_play']
...
@@ -626,7 +626,8 @@ def run(local_rank, world_size):
         eval_stats = torch.tensor([eval_return, eval_ep_len, eval_win_rate], dtype=torch.float32, device=device)
         # sync the statistics
-        dist.all_reduce(eval_stats, op=dist.ReduceOp.AVG)
+        if args.world_size > 1:
+            dist.all_reduce(eval_stats, op=dist.ReduceOp.AVG)
         if local_rank == 0:
             eval_return, eval_ep_len, eval_win_rate = eval_stats.cpu().numpy()
             writer.add_scalar("charts/eval_return", eval_return, global_step)
...
@@ -633,7 +633,8 @@ def main():
     eval_stats = torch.tensor([eval_return, eval_ep_len, eval_win_rate], dtype=torch.float32, device=device)
     # sync the statistics
-    dist.all_reduce(eval_stats, op=dist.ReduceOp.AVG)
+    if args.world_size > 1:
+        dist.all_reduce(eval_stats, op=dist.ReduceOp.AVG)
     eval_return, eval_ep_len, eval_win_rate = eval_stats.cpu().numpy()
     if rank == 0:
         writer.add_scalar("charts/eval_return", eval_return, global_step)
...
@@ -45,6 +45,9 @@ class Encoder(nn.Module):
         self.bin_points = nn.Parameter(bin_points, requires_grad=False)
         self.bin_intervals = nn.Parameter(bin_intervals, requires_grad=False)

+        self.count_embed = nn.Embedding(100, c // 16)
+        self.hand_count_embed = nn.Embedding(100, c // 16)
+
         if embedding_shape is None:
             n_embed, embed_dim = 999, 1024
         elif isinstance(embedding_shape, int):
@@ -88,12 +91,15 @@ class Encoder(nn.Module):
         self.if_first_embed = nn.Embedding(2, c // 8)
         self.is_my_turn_embed = nn.Embedding(2, c // 8)

-        self.global_norm_pre = nn.LayerNorm(c, elementwise_affine=affine)
+        self.my_deck_fc_emb = linear(1024, c // 4)
+
+        self.global_norm_pre = nn.LayerNorm(c * 2, elementwise_affine=affine)
         self.global_net = nn.Sequential(
-            nn.Linear(c, c),
+            nn.Linear(c * 2, c * 2),
             nn.ReLU(),
-            nn.Linear(c, c),
+            nn.Linear(c * 2, c * 2),
         )
+        self.global_proj = nn.Linear(c * 2, c)
         self.global_norm = nn.LayerNorm(c, elementwise_affine=False)

         divisor = 8
@@ -235,13 +241,20 @@
         x_g_lp = self.lp_fc_emb(self.num_transform(x_global_1[:, 0:2]))
         x_g_oppo_lp = self.oppo_lp_fc_emb(self.num_transform(x_global_1[:, 2:4]))

-        x_global_2 = x[:, 4:-1].long()
+        x_global_2 = x[:, 4:8].long()
         x_g_turn = self.turn_embed(x_global_2[:, 0])
         x_g_phase = self.phase_embed(x_global_2[:, 1])
         x_g_if_first = self.if_first_embed(x_global_2[:, 2])
         x_g_is_my_turn = self.is_my_turn_embed(x_global_2[:, 3])

-        x_global = torch.cat([x_g_lp, x_g_oppo_lp, x_g_turn, x_g_phase, x_g_if_first, x_g_is_my_turn], dim=-1)
+        x_global_3 = x[:, 8:22].long()
+        x_g_cs = self.count_embed(x_global_3).flatten(1)
+        x_g_my_hand_c = self.hand_count_embed(x_global_3[:, 1])
+        x_g_op_hand_c = self.hand_count_embed(x_global_3[:, 8])
+
+        x_global = torch.cat([
+            x_g_lp, x_g_oppo_lp, x_g_turn, x_g_phase, x_g_if_first, x_g_is_my_turn,
+            x_g_cs, x_g_my_hand_c, x_g_op_hand_c], dim=-1)
         return x_global

     def forward(self, x):
@@ -278,6 +291,7 @@
         x_global = self.encode_global(x_global)
         x_global = self.global_norm_pre(x_global)
         f_global = x_global + self.global_net(x_global)
+        f_global = self.global_proj(f_global)
         f_global = self.global_norm(f_global)

         f_cards = f_cards + f_global.unsqueeze(1)
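The widening of the global MLP follows directly from the new embeddings: the fourteen per-location count embeddings (14 x c/16) plus the two extra hand-count embeddings (2 x c/16) contribute exactly c additional features on top of the original c-wide concatenation, and the new global_proj maps the 2c features back to c before they are added to the per-card features. A rough accounting sketch (the lp/oppo_lp embedding widths of c//4 each are an assumption that makes the old concatenation sum to c):

def global_feature_width(c: int) -> int:
    """Sketch of the concatenated global feature width after this commit.

    Assumes lp/oppo_lp embeddings are c//4 each; turn/phase/is_first/is_my_turn
    are c//8 each; every count embedding is c//16.
    """
    old = 2 * (c // 4) + 4 * (c // 8)   # lp, oppo_lp, turn, phase, is_first, is_my_turn -> c
    counts = 14 * (c // 16)             # per-location card counts for both players
    hands = 2 * (c // 16)               # extra hand-count embeddings (mine / opponent's)
    return old + counts + hands         # = 2 * c, hence LayerNorm(c * 2) and global_proj(c * 2 -> c)

assert global_feature_width(128) == 2 * 128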
@@ -320,53 +334,6 @@
         f_state = torch.cat([f_s_cards_global, f_s_actions_ha], dim=-1)
         return f_actions, f_state, mask, valid

-# class PPOCritic(nn.Module):
-#     def __init__(self, channels):
-#         super(PPOCritic, self).__init__()
-#         c = channels
-#         self.net = nn.Sequential(
-#             nn.Linear(c * 2, c // 2),
-#             nn.ReLU(),
-#             nn.Linear(c // 2, 1),
-#         )
-
-#     def forward(self, f_state):
-#         return self.net(f_state)
-
-# class PPOActor(nn.Module):
-#     def __init__(self, channels):
-#         super(PPOActor, self).__init__()
-#         c = channels
-#         self.trans = nn.TransformerEncoderLayer(
-#             c, 4, c * 4, dropout=0.0, batch_first=True, norm_first=True, bias=False)
-#         self.head = nn.Sequential(
-#             nn.Linear(c, c // 4),
-#             nn.ReLU(),
-#             nn.Linear(c // 4, 1),
-#         )
-
-#     def forward(self, f_actions, mask, action):
-#         f_actions = self.trans(f_actions, src_key_padding_mask=mask)
-#         logits = self.head(f_actions)[..., 0]
-#         logits = logits.float()
-#         logits = logits.masked_fill(mask, float("-inf"))
-#         probs = Categorical(logits=logits)
-#         return probs.log_prob(action), probs.entropy()
-
-#     def predict(self, f_actions, mask):
-#         f_actions = self.trans(f_actions, src_key_padding_mask=mask)
-#         logits = self.head(f_actions)[..., 0]
-#         logits = logits.float()
-#         logits = logits.masked_fill(mask, float("-inf"))
-#         return logits
-
 class Actor(nn.Module):
     def __init__(self, channels, use_transformer=False):
...
@@ -17,6 +17,7 @@
 #include <SQLiteCpp/SQLiteCpp.h>
 #include <SQLiteCpp/VariadicBind.h>
 #include <ankerl/unordered_dense.h>
+#include <unordered_map>

 #include "ygoenv/core/async_envpool.h"
 #include "ygoenv/core/env.h"
@@ -1236,7 +1237,7 @@ public:
     int n_action_feats = 10 + conf["max_multi_select"_] * 2;
     return MakeDict(
         "obs:cards_"_.Bind(Spec<uint8_t>({conf["max_cards"_] * 2, 40})),
-        "obs:global_"_.Bind(Spec<uint8_t>({9})),
+        "obs:global_"_.Bind(Spec<uint8_t>({23})),
         "obs:actions_"_.Bind(
             Spec<uint8_t>({conf["max_options"_], n_action_feats})),
         "obs:h_actions_"_.Bind(
@@ -1650,17 +1651,32 @@ public:
     float reward = 0;
     int reason = 0;
     if (done_) {
-      float base_reward = 1.0;
-      int win_turn = turn_count_ - winner_;
-      if (win_turn <= 1) {
-        base_reward = 8.0;
-      } else if (win_turn <= 3) {
-        base_reward = 4.0;
-      } else if (win_turn <= 5) {
-        base_reward = 2.0;
+      float base_reward;
+      if (winner_ == 0) {
+        if (turn_count_ <= 1) {
+          // FTK
+          base_reward = 16.0;
+        } else if (turn_count_ <= 3) {
+          base_reward = 8.0;
+        } else if (turn_count_ <= 5) {
+          base_reward = 4.0;
+        } else if (turn_count_ <= 7) {
+          base_reward = 2.0;
+        } else {
+          base_reward = 0.5 + 1.0 / (turn_count_ - 7);
+        }
       } else {
-        base_reward = 0.5 + 1.0 / (win_turn - 5);
+        if (turn_count_ <= 1) {
+          base_reward = 8.0;
+        } else if (turn_count_ <= 3) {
+          base_reward = 4.0;
+        } else if (turn_count_ <= 5) {
+          base_reward = 2.0;
+        } else {
+          base_reward = 0.5 + 1.0 / (turn_count_ - 5);
+        }
       }
     if (play_mode_ == kSelfPlay) {
       // to_play_ is the previous player
       reward = winner_ == to_play_ ? base_reward : -base_reward;
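For readability, the new shaping can be summarized with the sketch below. It mirrors the C++ branch above; treating winner == 0 as the player who acts first is an assumption based on the "FTK" (first-turn kill) comment, and the harsher schedule for that case penalizes/boosts fast wins by the going-first player more strongly.

def base_reward(turn_count: int, winner: int) -> float:
    """Sketch of the turn-based reward magnitude introduced in this commit."""
    if winner == 0:
        # assumed to be the first-acting player (the FTK case)
        thresholds = [(1, 16.0), (3, 8.0), (5, 4.0), (7, 2.0)]
        tail_offset = 7
    else:
        thresholds = [(1, 8.0), (3, 4.0), (5, 2.0)]
        tail_offset = 5
    for limit, value in thresholds:
        if turn_count <= limit:
            return value
    return 0.5 + 1.0 / (turn_count - tail_offset)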
@@ -1698,8 +1714,9 @@
 private:
   using SpecIndex = ankerl::unordered_dense::map<std::string, uint16_t>;

-  void _set_obs_cards(TArray<uint8_t> &f_cards, SpecIndex &spec2index,
-                      PlayerId to_play) {
+  std::tuple<SpecIndex, std::vector<int>> _set_obs_cards(TArray<uint8_t> &f_cards, PlayerId to_play) {
+    SpecIndex spec2index;
+    std::vector<int> loc_n_cards;
     for (auto pi = 0; pi < 2; pi++) {
       const PlayerId player = (to_play + pi) % 2;
       const bool opponent = pi == 1;
@@ -1718,6 +1735,7 @@
         }
         if (opponent && hidden_for_opponent) {
           auto n_cards = YGO_QueryFieldCount(pduel_, player, location);
+          loc_n_cards.push_back(n_cards);
           for (auto i = 0; i < n_cards; i++) {
             f_cards(offset, 2) = location2id.at(location);
             f_cards(offset, 4) = 1;
@@ -1725,7 +1743,9 @@
           }
         } else {
           std::vector<Card> cards = get_cards_in_location(player, location);
-          for (int i = 0; i < cards.size(); ++i) {
+          int n_cards = cards.size();
+          loc_n_cards.push_back(n_cards);
+          for (int i = 0; i < n_cards; ++i) {
             const auto &c = cards[i];
             auto spec = c.get_spec(opponent);
             bool hide = false;
@@ -1744,6 +1764,7 @@
           }
         }
       }
+    return {spec2index, loc_n_cards};
   }

   void _set_obs_card_(TArray<uint8_t> &f_cards, int offset, const Card &c,
@@ -1797,7 +1818,7 @@
     }
   }

-  void _set_obs_global(TArray<uint8_t> &feat, PlayerId player) {
+  void _set_obs_global(TArray<uint8_t> &feat, PlayerId player, const std::vector<int> &loc_n_cards) {
     uint8_t me = player;
     uint8_t op = 1 - player;
@@ -1813,6 +1834,10 @@
     feat(5) = phase2id.at(current_phase_);
     feat(6) = (me == 0) ? 1 : 0;
     feat(7) = (me == tp_) ? 1 : 0;
+
+    for (int i = 0; i < loc_n_cards.size(); i++) {
+      feat(8 + i) = static_cast<uint8_t>(loc_n_cards[i]);
+    }
   }

   void _set_obs_action_spec(TArray<uint8_t> &feat, int i, int j,
@@ -2148,14 +2173,13 @@
     if (n_options == 0) {
       state["info:num_options"_] = 1;
-      state["obs:global_"_][8] = uint8_t(1);
+      state["obs:global_"_][22] = uint8_t(1);
       return;
     }

-    SpecIndex spec2index;
-    _set_obs_cards(state["obs:cards_"_], spec2index, to_play_);
+    auto [spec2index, loc_n_cards] = _set_obs_cards(state["obs:cards_"_], to_play_);

-    _set_obs_global(state["obs:global_"_], to_play_);
+    _set_obs_global(state["obs:global_"_], to_play_, loc_n_cards);

     // we can't shuffle because idx must be stable in callback
     if (n_options > max_options()) {
...