Commit bf54d780 authored by Chunchi Che's avatar Chunchi Che

refactor ai predict

parent e0e5c7a6
Pipeline #28403 passed with stages
in 8 minutes and 46 seconds
export * from "./create";
export * from "./delete";
export * from "./predict";
export * from "./transaction";
This diff is collapsed.
This diff is collapsed.
......@@ -3,31 +3,38 @@ import PhaseType = ygopro.StocGameMessage.MsgNewPhase.PhaseType;
import { CardMeta } from "@/api";
//! 一些Neos中基础的数据结构
// Position
export const FACEUP_ATTACK = 0x1;
export const FACEDOWN_ATTACK = 0x2;
export const FACEUP_DEFENSE = 0x4;
export const FACEDOWN_DEFENSE = 0x8;
// 类型
const TYPE_MONSTER = 0x1; //
const TYPE_SPELL = 0x2; //
const TYPE_TRAP = 0x4; //
const TYPE_NORMAL = 0x10; //
const TYPE_EFFECT = 0x20; //
const TYPE_FUSION = 0x40; //
const TYPE_RITUAL = 0x80; //
const TYPE_TRAPMONSTER = 0x100; //
const TYPE_SPIRIT = 0x200; //
const TYPE_UNION = 0x400; //
const TYPE_DUAL = 0x800; //
const TYPE_TUNER = 0x1000; //
const TYPE_SYNCHRO = 0x2000; //
export const TYPE_MONSTER = 0x1; //
export const TYPE_SPELL = 0x2; //
export const TYPE_TRAP = 0x4; //
export const TYPE_NORMAL = 0x10; //
export const TYPE_EFFECT = 0x20; //
export const TYPE_FUSION = 0x40; //
export const TYPE_RITUAL = 0x80; //
export const TYPE_TRAPMONSTER = 0x100; //
export const TYPE_SPIRIT = 0x200; //
export const TYPE_UNION = 0x400; //
export const TYPE_DUAL = 0x800; //
export const TYPE_TUNER = 0x1000; //
export const TYPE_SYNCHRO = 0x2000; //
export const TYPE_TOKEN = 0x4000; //
const TYPE_QUICKPLAY = 0x10000; //
const TYPE_CONTINUOUS = 0x20000; //
const TYPE_EQUIP = 0x40000; //
const TYPE_FIELD = 0x80000; //
const TYPE_COUNTER = 0x100000; //
const TYPE_FLIP = 0x200000; //
const TYPE_TOON = 0x400000; //
const TYPE_XYZ = 0x800000; //
const TYPE_PENDULUM = 0x1000000; //
const TYPE_SPSUMMON = 0x2000000; //
export const TYPE_QUICKPLAY = 0x10000; //
export const TYPE_CONTINUOUS = 0x20000; //
export const TYPE_EQUIP = 0x40000; //
export const TYPE_FIELD = 0x80000; //
export const TYPE_COUNTER = 0x100000; //
export const TYPE_FLIP = 0x200000; //
export const TYPE_TOON = 0x400000; //
export const TYPE_XYZ = 0x800000; //
export const TYPE_PENDULUM = 0x1000000; //
export const TYPE_SPSUMMON = 0x2000000; //
export const TYPE_LINK = 0x4000000; //
/*
......@@ -147,13 +154,13 @@ export function isPendulumMonster(typeCode: number): boolean {
// 属性
// const ATTRIBUTE_ALL = 0x7f; //
const ATTRIBUTE_EARTH = 0x01; //
const ATTRIBUTE_WATER = 0x02; //
const ATTRIBUTE_FIRE = 0x04; //
const ATTRIBUTE_WIND = 0x08; //
const ATTRIBUTE_LIGHT = 0x10; //
const ATTRIBUTE_DARK = 0x20; //
const ATTRIBUTE_DEVINE = 0x40; //
export const ATTRIBUTE_EARTH = 0x01; //
export const ATTRIBUTE_WATER = 0x02; //
export const ATTRIBUTE_FIRE = 0x04; //
export const ATTRIBUTE_WIND = 0x08; //
export const ATTRIBUTE_LIGHT = 0x10; //
export const ATTRIBUTE_DARK = 0x20; //
export const ATTRIBUTE_DEVINE = 0x40; //
export const Attribute2StringCodeMap: Map<number, number> = new Map([
[ATTRIBUTE_EARTH, 1010],
......@@ -166,31 +173,31 @@ export const Attribute2StringCodeMap: Map<number, number> = new Map([
]);
// 种族
const RACE_WARRIOR = 0x1; //
const RACE_SPELLCASTER = 0x2; //
const RACE_FAIRY = 0x4; //
const RACE_FIEND = 0x8; //
const RACE_ZOMBIE = 0x10; //
const RACE_MACHINE = 0x20; //
const RACE_AQUA = 0x40; //
const RACE_PYRO = 0x80; //
const RACE_ROCK = 0x100; //
const RACE_WINDBEAST = 0x200; //
const RACE_PLANT = 0x400; //
const RACE_INSECT = 0x800; //
const RACE_THUNDER = 0x1000; //
const RACE_DRAGON = 0x2000; //
const RACE_BEAST = 0x4000; //
const RACE_BEASTWARRIOR = 0x8000; //
const RACE_DINOSAUR = 0x10000; //
const RACE_FISH = 0x20000; //
const RACE_SEASERPENT = 0x40000; //
const RACE_REPTILE = 0x80000; //
const RACE_PSYCHO = 0x100000; //
const RACE_DEVINE = 0x200000; //
const RACE_CREATORGOD = 0x400000; //
const RACE_WYRM = 0x800000; //
const RACE_CYBERSE = 0x1000000; //
export const RACE_WARRIOR = 0x1; //
export const RACE_SPELLCASTER = 0x2; //
export const RACE_FAIRY = 0x4; //
export const RACE_FIEND = 0x8; //
export const RACE_ZOMBIE = 0x10; //
export const RACE_MACHINE = 0x20; //
export const RACE_AQUA = 0x40; //
export const RACE_PYRO = 0x80; //
export const RACE_ROCK = 0x100; //
export const RACE_WINDBEAST = 0x200; //
export const RACE_PLANT = 0x400; //
export const RACE_INSECT = 0x800; //
export const RACE_THUNDER = 0x1000; //
export const RACE_DRAGON = 0x2000; //
export const RACE_BEAST = 0x4000; //
export const RACE_BEASTWARRIOR = 0x8000; //
export const RACE_DINOSAUR = 0x10000; //
export const RACE_FISH = 0x20000; //
export const RACE_SEASERPENT = 0x40000; //
export const RACE_REPTILE = 0x80000; //
export const RACE_PSYCHO = 0x100000; //
export const RACE_DEVINE = 0x200000; //
export const RACE_CREATORGOD = 0x400000; //
export const RACE_WYRM = 0x800000; //
export const RACE_CYBERSE = 0x1000000; //
export const Race2StringCodeMap: Map<number, number> = new Map([
[RACE_WARRIOR, 1020],
......
......@@ -12,61 +12,29 @@ import {
} from "@/api";
import { predictDuel } from "@/api/ygoAgent/predict";
import {
convertActionMsg,
convertCard,
convertDeckCard,
convertPhase,
Global,
Input,
MsgSelectSum,
MultiSelectMsg,
parsePlayerFromMsg,
} from "@/api/ygoAgent/schema";
import {
convertActionMsg,
convertCard,
convertDeckCard,
convertPhase,
convertPositionResponse,
parsePlayerFromMsg,
} from "@/api/ygoAgent/transaction";
import { cardStore, matStore } from "@/stores";
function computeSetDifference(a1: number[], a2: number[]): number[] {
const freq1 = new Map<number, number>();
const freq2 = new Map<number, number>();
for (const num of a1) {
freq1.set(num, (freq1.get(num) || 0) + 1);
}
for (const num of a2) {
freq2.set(num, (freq2.get(num) || 0) + 1);
}
for (const [num, count] of freq2) {
if (freq1.has(num)) {
freq1.set(num, freq1.get(num)! - count);
}
}
const difference: number[] = [];
for (const [num, count] of freq1) {
if (count > 0) {
difference.push(...Array(count).fill(num));
}
}
import { argmax, computeSetDifference } from "./util";
return difference;
}
const { DECK, HAND, MZONE, SZONE, GRAVE, REMOVED, EXTRA } = ygopro.CardZone;
export function genInput(msg: ygopro.StocGameMessage): Input {
// 全局信息可以从 `matStore` 里面拿
export function genAgentInput(msg: ygopro.StocGameMessage): Input {
const mat = matStore;
// 卡片信息可以从 `cardStore` 里面拿
// TODO (ygo-agent): TZONE
const zones = [
ygopro.CardZone.DECK,
ygopro.CardZone.HAND,
ygopro.CardZone.MZONE,
ygopro.CardZone.SZONE,
ygopro.CardZone.GRAVE,
ygopro.CardZone.REMOVED,
ygopro.CardZone.EXTRA,
];
// select_xxx msg 从参数 `msg` 里获取
// 这里已经保证 `msg` 是众多 `select_xxx` msg 中的一个
const zones = [DECK, HAND, MZONE, SZONE, GRAVE, REMOVED, EXTRA];
const player = parsePlayerFromMsg(msg);
const opponent = 1 - player;
......@@ -74,10 +42,7 @@ export function genInput(msg: ygopro.StocGameMessage): Input {
.filter(
(card) =>
zones.includes(card.location.zone) &&
!(
card.location.zone === ygopro.CardZone.DECK &&
card.location.controller === player
),
!(card.location.zone === DECK && card.location.controller === player),
)
.map((card) => convertCard(card, player));
......@@ -108,23 +73,25 @@ export function genInput(msg: ygopro.StocGameMessage): Input {
const actionMsg = convertActionMsg(msg);
return {
global: global,
global,
cards: deckCardsMe.concat(cards),
action_msg: actionMsg,
};
}
async function sendRequest(req: PredictReq) {
console.log("Sending predict request:", req);
const duelId = matStore.duelId;
const resp = await predictDuel(duelId, req);
console.log("Got predict response:", resp);
if (resp !== undefined) {
matStore.agentIndex = resp.index;
} else {
throw new Error("Failed to get predict response");
}
// TODO: 下面的逻辑需要封装一下,因为:
// 1. 现在实现的功能是AI托管,UI上不需要感知AI的预测结果;
// 2. 后面如果需要实现AI辅助功能,UI上需要感知AI的预测结果,
// 所以需要单独提供接口能力。
const preds = resp.predict_results.action_preds;
const actionIdx = argmax(preds, (r) => r.prob);
matStore.prevActionIndex = actionIdx;
......@@ -132,8 +99,12 @@ async function sendRequest(req: PredictReq) {
return pred;
}
// TODO:
// 1. 逻辑需要拆分下
// 2. 这个函数在外面被各个 service 模块分散调用,
// 需要改成在`gameMsg.ts`调用,并通过`try..catch`正确处理错误。
export async function sendAIPredictAsResponse(msg: ygopro.StocGameMessage) {
const input = genInput(msg);
const input = genAgentInput(msg);
const msgName = input.action_msg.data.msg_type;
const multiSelectMsgs = ["select_card", "select_tribute", "select_sum"];
......@@ -240,37 +211,3 @@ export async function sendAIPredictAsResponse(msg: ygopro.StocGameMessage) {
}
}
}
function argmax<T>(arr: T[], getValue: (item: T) => number): number {
if (arr.length === 0) {
throw new Error("Array is empty");
}
let maxIndex = 0;
let maxValue = getValue(arr[0]);
for (let i = 1; i < arr.length; i++) {
const currentValue = getValue(arr[i]);
if (currentValue > maxValue) {
maxValue = currentValue;
maxIndex = i;
}
}
return maxIndex;
}
function convertPositionResponse(response: number): ygopro.CardPosition {
switch (response) {
case 0x1:
return ygopro.CardPosition.FACEUP_ATTACK;
case 0x2:
return ygopro.CardPosition.FACEDOWN_ATTACK;
case 0x4:
return ygopro.CardPosition.FACEUP_DEFENSE;
case 0x8:
return ygopro.CardPosition.FACEDOWN_DEFENSE;
default:
throw new Error(`Invalid position response: ${response}`);
}
}
// Ygo Agent with AI-Assisted function on Yu-Gi-Oh! Game
export class YgoAgent {
// TODO
}
......@@ -7,6 +7,7 @@ import { displayAnnounceModal } from "@/ui/Duel/Message/AnnounceModal";
export default async (announce: MsgAnnounce) => {
if (matStore.autoSelect) {
// TODO: 如果是开启 AI 模式,不应该调用这个函数
console.log("intercept announce");
await sendAIPredictAsResponse(
announce as unknown as ygopro.StocGameMessage,
......
......@@ -68,6 +68,7 @@ export default async (selectChain: MsgSelectChain) => {
case 2: // 处理多张
case 3: {
if (matStore.autoSelect) {
// TODO: 确认AI模型是否可以处理其他case的情况
console.log("intercept selectChain");
await sendAIPredictAsResponse(
selectChain as unknown as ygopro.StocGameMessage,
......
......@@ -8,3 +8,48 @@ export function isAllOnField(locations: ygopro.CardLocation[]): boolean {
return locations.find((location) => !isOnField(location)) === undefined;
}
export function computeSetDifference(set1: number[], set2: number[]): number[] {
const freq1 = new Map<number, number>();
const freq2 = new Map<number, number>();
for (const num of set1) {
freq1.set(num, (freq1.get(num) || 0) + 1);
}
for (const num of set2) {
freq2.set(num, (freq2.get(num) || 0) + 1);
}
for (const [num, count] of freq2) {
if (freq1.has(num)) {
freq1.set(num, freq1.get(num)! - count);
}
}
const difference: number[] = [];
for (const [num, count] of freq1) {
if (count > 0) {
difference.push(...Array(count).fill(num));
}
}
return difference;
}
export function argmax<T>(arr: T[], getValue: (item: T) => number): number {
if (arr.length === 0) {
throw new Error("Array is empty");
}
let maxIndex = 0;
let maxValue = getValue(arr[0]);
for (let i = 1; i < arr.length; i++) {
const currentValue = getValue(arr[i]);
if (currentValue > maxValue) {
maxValue = currentValue;
maxIndex = i;
}
}
return maxIndex;
}
......@@ -50,6 +50,7 @@ export interface MatState {
/** 根据自己的先后手判断是否是自己 */
isMe: (player: number) => boolean;
// 下面其中一些貌似可以封装成为`AgentInfo`
turnCount: number;
duelId: string;
agentIndex: number;
......
......@@ -35,7 +35,7 @@ export const NeosModal: React.FC<ModalProps> = (props) => {
maskClosable={true}
onCancel={() => setMini(!mini)}
closeIcon={mini ? <UpOutlined /> : <MinusOutlined />}
bodyStyle={{ padding: "10px 0" }}
style={{ padding: "10px 0" }}
mask={!mini}
wrapClassName={classNames({ [styles.wrap]: mini })}
closable={true}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment