Commit 8ef0f9c9 authored by nanahira's avatar nanahira

fix STOC_ERROR_MSG handling

parent 36e5a990
Pipeline #37878 passed with stages
in 2 minutes and 35 seconds
import asyncio import asyncio
import struct import struct
from asyncio import StreamReader, StreamWriter from asyncio import StreamReader, StreamWriter
from typing import List, Tuple from typing import List, Tuple, Union
import yaml import yaml
import ipaddress import ipaddress
...@@ -80,6 +80,8 @@ STOC_HS_PLAYER_ENTER = 0x20 ...@@ -80,6 +80,8 @@ STOC_HS_PLAYER_ENTER = 0x20
async def handle_client(reader: StreamReader, writer: StreamWriter): async def handle_client(reader: StreamReader, writer: StreamWriter):
remote_reader: StreamReader = None
remote_writer: StreamWriter = None
async def stoc_send(proto: int, data: bytes): async def stoc_send(proto: int, data: bytes):
"""Send a STOC packet to the client.""" """Send a STOC packet to the client."""
packet = prepare_message(proto, data) packet = prepare_message(proto, data)
...@@ -116,15 +118,27 @@ async def handle_client(reader: StreamReader, writer: StreamWriter): ...@@ -116,15 +118,27 @@ async def handle_client(reader: StreamReader, writer: StreamWriter):
return packet_id, payload return packet_id, payload
logger.debug(f"Received unexpected packet ID {packet_id}, expected one of {accepted_ids}. Ignoring.") logger.debug(f"Received unexpected packet ID {packet_id}, expected one of {accepted_ids}. Ignoring.")
class State:
established = False
state = State()
async def close_connection(): async def close_connection():
try: try:
await stoc_send(STOC_ERROR_MSG, struct.pack('<BBBBI', 1, 0, 0, 0, 9)) if not state.established:
await stoc_send(STOC_ERROR_MSG, struct.pack('<BBBBI', 1, 0, 0, 0, 9))
except Exception as e: except Exception as e:
logger.warning(f"Failed to send error message before closing: {e}") logger.warning(f"Failed to send error message before closing: {e}")
writer.close() writer.close()
await writer.wait_closed() if remote_writer:
remote_writer.close()
try:
await asyncio.gather(writer.wait_closed(), remote_writer.wait_closed() if remote_writer else asyncio.sleep(0))
except Exception as e:
logger.warning(f"Error closing connections for {player_descriptor}: {e}")
peer_ip = writer.get_extra_info("peername")[0] peer_ip = writer.get_extra_info("peername")[0]
player_descriptor = f"Unknown Player ({peer_ip})"
is_proxy = is_trusted(peer_ip) is_proxy = is_trusted(peer_ip)
try: try:
...@@ -138,10 +152,15 @@ async def handle_client(reader: StreamReader, writer: StreamWriter): ...@@ -138,10 +152,15 @@ async def handle_client(reader: StreamReader, writer: StreamWriter):
external_address_payload = first_payload external_address_payload = first_payload
_, player_info_payload = await asyncio.wait_for(ctos_read_filter([CTOS_PLAYER_INFO], 40, 40), 5.0) _, player_info_payload = await asyncio.wait_for(ctos_read_filter([CTOS_PLAYER_INFO], 40, 40), 5.0)
pre_packets.append((CTOS_PLAYER_INFO, player_info_payload)) pre_packets.append((CTOS_PLAYER_INFO, player_info_payload))
else: elif first_packet_id == CTOS_PLAYER_INFO:
player_info_payload = first_payload
pre_packets.append((CTOS_PLAYER_INFO, first_payload)) pre_packets.append((CTOS_PLAYER_INFO, first_payload))
# make a dummy external_address_payload # make a dummy external_address_payload
external_address_payload = struct.pack("<I", 0) external_address_payload = struct.pack(">I", 0)
else:
raise ValueError(f"Unexpected first packet ID: {first_packet_id}")
player_name = parse_utf16_string(player_info_payload)
_, join_game_payload = await asyncio.wait_for(ctos_read_filter([CTOS_JOIN_GAME], 48, 48), 5.0) _, join_game_payload = await asyncio.wait_for(ctos_read_filter([CTOS_JOIN_GAME], 48, 48), 5.0)
pre_packets.append((CTOS_JOIN_GAME, join_game_payload)) pre_packets.append((CTOS_JOIN_GAME, join_game_payload))
...@@ -155,19 +174,22 @@ async def handle_client(reader: StreamReader, writer: StreamWriter): ...@@ -155,19 +174,22 @@ async def handle_client(reader: StreamReader, writer: StreamWriter):
client_ip = real_ip_str client_ip = real_ip_str
else: else:
if not is_proxy and real_ip_int != 0: if not is_proxy and real_ip_int != 0:
logger.warning(f"Untrusted IP {peer_ip} tried to spoof real_ip={real_ip_str}") logger.warning(f"Untrusted client {player_name} ({peer_ip}) tried to spoof real_ip={real_ip_str}")
client_ip = peer_ip client_ip = peer_ip
entry = match_route(hostname) entry = match_route(hostname)
player_descriptor = f"{player_name} ({client_ip})"
if entry is None: if entry is None:
logger.warning(f"No route found for hostname: {hostname} from {client_ip}") logger.warning(f"No route found for hostname: {hostname} from {player_descriptor}")
await send_chat(f"404 Not Found: Host [{hostname}] not found", player_type=11) await send_chat(f"404 Not Found: Host [{hostname}] not found", player_type=11)
await close_connection() await close_connection()
return return
async def open_menu(menu, menu_chain=[]): async def open_menu(menu, menu_chain=[]):
logger.info(f"{client_ip} requested {hostname} → Opening menu: {menu.get('welcome', 'Unknown')}") state.established = True
logger.info(f"{player_descriptor} requested {hostname} → Opening menu: {menu.get('welcome', 'Unknown')}")
use_tag = len(menu['options']) > 2 use_tag = len(menu['options']) > 2
async def send_info(): async def send_info():
...@@ -220,14 +242,14 @@ async def handle_client(reader: StreamReader, writer: StreamWriter): ...@@ -220,14 +242,14 @@ async def handle_client(reader: StreamReader, writer: StreamWriter):
target_host, target_port = await parse_entry(entry) target_host, target_port = await parse_entry(entry)
logger.info(f"{client_ip} requested {hostname} → forwarding to {target_host}:{target_port}") logger.info(f"{player_descriptor} requested {hostname} → forwarding to {target_host}:{target_port}")
# Connect to target server # Connect to target server
try: try:
remote_reader, remote_writer = await asyncio.wait_for( remote_reader, remote_writer = await asyncio.wait_for(
asyncio.open_connection(target_host, target_port), timeout=5.0) asyncio.open_connection(target_host, target_port), timeout=5.0)
except Exception as e: except Exception as e:
logger.warning(f"Failed to connect to {target_host}:{target_port} for client {client_ip}: {e}") logger.warning(f"Failed to connect to {target_host}:{target_port} for client {player_descriptor}: {e}")
await send_chat(f"502 Bad Gateway: Host [{hostname}] cannot be connected", player_type=11) await send_chat(f"502 Bad Gateway: Host [{hostname}] cannot be connected", player_type=11)
await close_connection() await close_connection()
return return
...@@ -242,26 +264,43 @@ async def handle_client(reader: StreamReader, writer: StreamWriter): ...@@ -242,26 +264,43 @@ async def handle_client(reader: StreamReader, writer: StreamWriter):
remote_writer.write(message) remote_writer.write(message)
await remote_writer.drain() await remote_writer.drain()
async def pipe(src, dst): async def pipe(src: StreamReader, dst: StreamWriter, direction: Union['STOC', 'CTOS']):
try: try:
while not src.at_eof(): while not src.at_eof():
data = await src.read(4096) header = await src.readexactly(3)
if not data: length = struct.unpack('<H', header[:2])[0]
break packet_id = header[2]
dst.write(data) logger.info(f"Received {direction} packet ID {packet_id} with length {length} for {player_descriptor}")
payload = await src.readexactly(length - 1)
if direction == 'STOC':
if packet_id == STOC_JOIN_GAME:
state.established = True
logger.info(f"Connection established for {player_descriptor} to {target_host}:{target_port}")
if packet_id == STOC_ERROR_MSG and state.established:
error_payload = struct.unpack('<BBBBI', payload)
msg = error_payload[0]
code = error_payload[4]
if msg == 1 or msg == 4: # ERRMSG_JOINERR or ERRMSG_VERERROR
logger.warning(f"Received error message for {player_descriptor}: {msg} {code}")
await close_connection()
return
dst.write(header)
dst.write(payload)
await dst.drain() await dst.drain()
except Exception: except Exception as e:
logger.warning(f"Error in pipe {direction} for {player_descriptor}: {e}")
pass pass
finally: finally:
logger.info(f"Closing pipe {direction} for {player_descriptor}")
dst.close() dst.close()
await asyncio.gather( await asyncio.gather(
pipe(reader, remote_writer), pipe(reader, remote_writer, 'CTOS'),
pipe(remote_reader, writer) pipe(remote_reader, writer, 'STOC')
) )
except Exception as e: except Exception as e:
logger.error(f"Error handling client {peer_ip}: {e}") logger.error(f"Error handling client {player_descriptor}: {e}")
try: try:
await stoc_send(STOC_ERROR_MSG, struct.pack('<BBBBI', 1, 0, 0, 0, 9)) await stoc_send(STOC_ERROR_MSG, struct.pack('<BBBBI', 1, 0, 0, 0, 9))
except Exception as send_error: except Exception as send_error:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment