Commit 0dc3ae38 authored by nanahira's avatar nanahira

add menu

parent a4288acb
import asyncio
import struct
from asyncio import StreamReader, StreamWriter
from typing import List, Tuple
import yaml
import ipaddress
import fnmatch
import logging
CTOS_EXTERNAL_ADDRESS = 0x17
# Setup logging to stdout
logger = logging.getLogger("proxy")
logger.setLevel(logging.INFO)
......@@ -21,86 +22,133 @@ with open("config.yaml") as f:
trusted_proxies = [ipaddress.ip_network(net) for net in CONFIG.get("trusted_proxies", [])]
def is_trusted(ip):
addr = ipaddress.ip_address(ip)
return any(addr in net for net in trusted_proxies)
def parse_host(host: str) -> Tuple[str, int]:
"""Parse a host string into (hostname, port)."""
if ":" in host:
hostname, port = host.rsplit(":", 1)
return hostname, int(port)
return host, 7911
def match_route(hostname: str):
for route in CONFIG["routes"]:
if fnmatch.fnmatch(hostname, route["match"]):
to = route["to"]
if ":" in to:
host, port = to.rsplit(":", 1)
return host, int(port)
return None, None
return route
return None
def parse_utf16_hostname(data: bytes):
def parse_utf16_string(data: bytes):
hostname = ""
for i in range(0, len(data), 2):
wchar = struct.unpack("<H", data[i:i+2])[0]
wchar = struct.unpack("<H", data[i:i + 2])[0]
if wchar == 0:
break
hostname += chr(wchar)
return hostname
async def handle_client(reader, writer):
def prepare_utf16_string(s: str) -> bytes:
"""Convert a string to UTF-16-LE encoded bytes with a null terminator."""
encoded = s.encode('utf-16le')
return encoded + struct.pack('<H', 0) # Add null terminator
def prepare_message(proto: int, data: bytes):
"""Prepare a STOC packet with the given data."""
length = len(data) + 1 # +1 for the packet type byte
return struct.pack('<H', length) + bytes([proto]) + data
CTOS_PLAYER_INFO = 0x10
CTOS_JOIN_GAME = 0x12
CTOS_EXTERNAL_ADDRESS = 0x17
CTOS_HS_TODUELIST = 0x20
CTOS_HS_KICK = 0x24
STOC_ERROR_MSG = 0x02
STOC_CHAT = 0x19
STOC_JOIN_GAME = 0x12
STOC_TYPE_CHANGE = 0x13
STOC_HS_PLAYER_ENTER = 0x20
async def handle_client(reader: StreamReader, writer: StreamWriter):
async def stoc_send(proto: int, data: bytes):
"""Send a STOC packet to the client."""
packet = prepare_message(proto, data)
writer.write(packet)
await writer.drain()
async def send_chat(msg: str, player_type: int):
# Truncate and encode msg to UTF-16-LE, with null terminator
encoded = msg.encode('utf-16le')[:510] # max 255 UTF-16 chars
encoded += struct.pack('<H', 0) # null terminator
payload = struct.pack('<H', player_type) + encoded
packet = struct.pack('<H', len(payload) + 1) + bytes([0x19]) + payload
writer.write(packet)
await writer.drain()
await stoc_send(STOC_CHAT, payload)
async def ctos_read(min_length=0, max_length=0xffffffff):
logger.debug("Reading CTOS packet header")
proto_header = await reader.readexactly(3)
length = struct.unpack('<H', proto_header[:2])[0]
packet_id = proto_header[2]
payload_length = length - 1 # Subtract 1 for the packet ID byte
if payload_length < min_length or payload_length > max_length:
raise ValueError(f"Invalid packet length {payload_length} for packet ID {packet_id}")
if payload_length == 0:
return packet_id, b''
logger.debug(f"Reading CTOS packet payload of length {payload_length} for packet ID {packet_id}")
payload = await asyncio.wait_for(reader.readexactly(payload_length), timeout=5.0)
return packet_id, payload
async def ctos_read_filter(accepted_ids: List[int], min_length=1, max_length=0xffffffff):
while True:
logger.debug(f"Waiting for CTOS packet with IDs {accepted_ids}")
packet_id, payload = await ctos_read(min_length, max_length)
if packet_id in accepted_ids:
return packet_id, payload
logger.debug(f"Received unexpected packet ID {packet_id}, expected one of {accepted_ids}. Ignoring.")
async def close_connection():
try:
payload = struct.pack('<BBBBI', 1, 0, 0, 0, 9) # msg=1, code=9
packet = struct.pack('<H', len(payload) + 1) + bytes([0x02]) + payload # STOC_ERROR_MSG = 0x02
writer.write(packet)
await writer.drain()
await stoc_send(STOC_ERROR_MSG, struct.pack('<BBBBI', 1, 0, 0, 0, 9))
except Exception as e:
logger.warning(f"Failed to send error message before closing: {e}")
writer.close()
await writer.wait_closed()
peer_ip = writer.get_extra_info("peername")[0]
is_proxy = is_trusted(peer_ip)
try:
try:
# Read packet header
header = await asyncio.wait_for(reader.readexactly(2), timeout=5.0)
length = struct.unpack("<H", header)[0]
packet_name = await asyncio.wait_for(reader.readexactly(1), timeout=5.0)
packet_id = packet_name[0]
if packet_id != CTOS_EXTERNAL_ADDRESS:
logger.warning(f"First packet is not CTOS_EXTERNAL_ADDRESS from {peer_ip}, closing.")
await send_chat( "400 Bad Request: CTOS_EXTERNAL_ADDRESS not found", player_type=11)
await close_connection()
return
if length < 6 or length > 516:
logger.warning(f"Invalid packet length {length} from {peer_ip}, closing.")
await close_connection()
return
payload = await asyncio.wait_for(reader.readexactly(length - 1), timeout=5.0)
except asyncio.TimeoutError:
logger.warning(f"Timeout while waiting for payload from {peer_ip}, closing.")
await close_connection()
return
external_address_payload = b''
pre_packets: List[Tuple[int, bytes]] = []
first_packet_id, first_payload = await asyncio.wait_for(
ctos_read_filter([CTOS_EXTERNAL_ADDRESS, CTOS_PLAYER_INFO],
min_length=4, max_length=1024), 5.0)
if first_packet_id == CTOS_EXTERNAL_ADDRESS:
external_address_payload = first_payload
_, player_info_payload = await asyncio.wait_for(ctos_read_filter([CTOS_PLAYER_INFO], 40, 40), 5.0)
pre_packets.append((CTOS_PLAYER_INFO, player_info_payload))
else:
pre_packets.append((CTOS_PLAYER_INFO, first_payload))
# make a dummy external_address_payload
external_address_payload = struct.pack("<I", 0)
_, join_game_payload = await asyncio.wait_for(ctos_read_filter([CTOS_JOIN_GAME], 48, 48), 5.0)
pre_packets.append((CTOS_JOIN_GAME, join_game_payload))
real_ip = ipaddress.IPv4Address(payload[0:4])
real_ip = ipaddress.IPv4Address(external_address_payload[0:4])
real_ip_str = str(real_ip)
real_ip_int = int(real_ip)
hostname = parse_utf16_hostname(payload[4:])
hostname = parse_utf16_string(external_address_payload[4:])
if is_proxy and real_ip_int != 0:
client_ip = real_ip_str
......@@ -109,14 +157,68 @@ async def handle_client(reader, writer):
logger.warning(f"Untrusted IP {peer_ip} tried to spoof real_ip={real_ip_str}")
client_ip = peer_ip
target_host, target_port = match_route(hostname)
entry = match_route(hostname)
if not target_host:
if entry is None:
logger.warning(f"No route found for hostname: {hostname} from {client_ip}")
await send_chat(f"404 Not Found: Host [{hostname}] not found", player_type=11)
await close_connection()
return
async def open_menu(menu, menu_chain=[]):
logger.info(f"{client_ip} requested {hostname} → Opening menu: {menu.get('welcome', 'Unknown')}")
use_tag = len(menu['options']) > 2
async def send_info():
await stoc_send(STOC_JOIN_GAME, struct.pack("<IBBBBBBBBIBBH",
0, # lflist
5, # rule
2 if use_tag else 0, # mode
5, # duel_rule
0, # no_check_deck
0, # no_shuffle_deck
0, # paddings
0,
0,
16000 if use_tag else 8000, # start_lp
5, # start_hand
1, # draw_count
240, # time_limit
))
await stoc_send(STOC_TYPE_CHANGE, struct.pack("<B", 0x17)) # is_host | is_spectator
for i, option in enumerate(menu['options'][:4]):
string_payload = option['name'][:20].encode('utf-16le').ljust(40, b'\x00')
player_enter_payload = string_payload + struct.pack("<BB", i, 0)
await stoc_send(STOC_HS_PLAYER_ENTER, player_enter_payload)
await send_info()
if "welcome" in menu:
await send_chat(menu["welcome"], player_type=12)
while True:
ret_id, ret_payload = await asyncio.wait_for(ctos_read_filter([CTOS_HS_TODUELIST, CTOS_HS_KICK], 0, 1),
300)
if ret_id == CTOS_HS_KICK:
break
await send_info()
select_pos = ret_payload[0]
if select_pos >= len(menu['options']):
raise ValueError(f"Invalid selection {select_pos} for menu {menu['name']}")
selected_option = menu['options'][select_pos]
return await parse_entry(selected_option, menu_chain=menu_chain + [menu])
async def parse_entry(entry, menu_chain=[]):
if "return" in entry:
if len(menu_chain) < 2:
raise ValueError("No parent menu to return to")
return await open_menu(menu_chain[-2], menu_chain=menu_chain[:-1])
if "to" in entry:
return parse_host(entry["to"])
if "menu" in entry:
return await open_menu(entry["menu"], menu_chain)
raise ValueError(f"Invalid route entry: {entry}")
target_host, target_port = await parse_entry(entry)
logger.info(f"{client_ip} requested {hostname} → forwarding to {target_host}:{target_port}")
# Connect to target server
......@@ -130,13 +232,13 @@ async def handle_client(reader, writer):
return
# Overwrite real_ip in payload with resolved client_ip
try:
payload = ipaddress.IPv4Address(client_ip).packed + payload[4:]
except Exception as e:
logger.warning(f"Failed to write real_ip for {client_ip}: {e}")
external_address_payload = ipaddress.IPv4Address(client_ip).packed + external_address_payload[4:]
pre_packets.insert(0, (CTOS_EXTERNAL_ADDRESS, external_address_payload))
# Forward first packet
remote_writer.write(header + packet_name + payload)
# Forward pre-packets to the remote server
for packet_id, payload in pre_packets:
message = prepare_message(packet_id, payload)
remote_writer.write(message)
await remote_writer.drain()
async def pipe(src, dst):
......@@ -159,8 +261,14 @@ async def handle_client(reader, writer):
except Exception as e:
logger.error(f"Error handling client {peer_ip}: {e}")
try:
await stoc_send(STOC_ERROR_MSG, struct.pack('<BBBBI', 1, 0, 0, 0, 9))
except Exception as send_error:
pass
finally:
writer.close()
await writer.wait_closed()
async def main():
server = await asyncio.start_server(handle_client, CONFIG["host"], CONFIG["port"])
......@@ -169,5 +277,6 @@ async def main():
async with server:
await server.serve_forever()
if __name__ == "__main__":
asyncio.run(main())
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment