#!/usr/bin/env python3
# tcp_splitter.py
"""Asyncio TCP splitter: proxy each client to a primary target and mirror to backups.

Every accepted client connection is forwarded bidirectionally to the primary
target. Bytes sent by the client are additionally copied ("mirrored") to every
reachable backup target; whatever the backups send back is read and discarded.
A backup whose per-connection buffer overflows is dropped for that connection,
so a slow backup can never stall the primary path.
"""
import argparse
import asyncio
import ctypes
import json
import logging
import os
import signal
import socket
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Awaitable, Callable, List, Optional, Set, Tuple, Union

BUF_SIZE = 64 * 1024  # max bytes requested per read() on any stream
DEFAULT_MIRROR_BUF = 1 * 1024 * 1024  # 1 MiB of queued data allowed per backup


def set_console_title(title: Optional[str]) -> None:
    """Set the console window title on Windows; do nothing elsewhere or if *title* is empty."""
    if os.name == "nt" and title:
        try:
            ctypes.windll.kernel32.SetConsoleTitleW(str(title))
        except Exception:  # purely cosmetic: never fail startup over a window title
            pass


def set_socket_opts(writer: asyncio.StreamWriter) -> None:
    """Enable TCP_NODELAY and SO_KEEPALIVE on the socket behind *writer* (best-effort)."""
    sock = writer.get_extra_info("socket")
    if not sock:
        return
    for opt in (
        (socket.IPPROTO_TCP, socket.TCP_NODELAY, 1),
        (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
    ):
        try:
            sock.setsockopt(*opt)
        except OSError:  # option unsupported on this platform/socket: keep going
            pass


@dataclass
class Target:
    """A host/port pair to connect to."""

    host: str
    port: int


@dataclass
class Config:
    """Runtime configuration, normally produced by :func:`load_config`."""

    listen_host: str
    listen_port: int
    primary: Target  # receives the client's traffic and answers it
    backups: List[Target]  # receive a copy of client->primary traffic only
    run_duration: int  # seconds before automatic shutdown; 0 = run until signalled
    log_level: int
    mirror_buffer_bytes: int = DEFAULT_MIRROR_BUF  # per-backup, per-connection queue cap
    idle_timeout: float = 0.0  # seconds without data before a direction stops; 0 = never
    connect_timeout: float = 0.0  # seconds allowed for each outbound connect; 0 = no limit


def _level_from_name(name: Union[str, int]) -> int:
    """Map a level name (``"debug"``, ``"WARNING"``…) or a numeric level to a logging level.

    Unknown names fall back to ``logging.INFO``.
    """
    if isinstance(name, int) and not isinstance(name, bool):
        return name
    try:
        lvl = getattr(logging, str(name).upper())
    except AttributeError:
        return logging.INFO
    # Guard against non-level module attributes such as logging.BASIC_FORMAT.
    return lvl if isinstance(lvl, int) else logging.INFO


def resolve_default_config_path() -> Path:
    """Return ``config.json`` next to the script/executable, else in the current directory."""
    base = Path(sys.executable).parent if getattr(sys, "frozen", False) else Path(__file__).resolve().parent
    cfg = base / "config.json"
    return cfg if cfg.exists() else Path.cwd() / "config.json"


def load_config(path: Union[str, Path]) -> Config:
    """Load and validate the JSON configuration at *path*.

    The first entry of ``target_hosts`` is the primary; the rest are backups.
    Optional numeric keys that are absent or ``null`` take their defaults.

    Raises:
        FileNotFoundError: *path* does not exist.
        ValueError: invalid JSON, non-object root, or empty ``target_hosts``.
        KeyError: a required key is missing.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"Config not found: {path}")
    raw = json.loads(path.read_text(encoding="utf-8"))
    if not isinstance(raw, dict):
        raise ValueError("config root must be a JSON object")

    def _opt(key: str, default):
        # Treat an explicit JSON null like a missing key instead of crashing in int()/float().
        value = raw.get(key)
        return default if value is None else value

    targets = raw["target_hosts"]
    if not isinstance(targets, list) or not targets:
        raise ValueError("config.target_hosts must contain at least one entry")
    primary, *backups = [Target(str(t["host"]), int(t["port"])) for t in targets]
    return Config(
        listen_host=str(raw["listen_host"]),
        listen_port=int(raw["listen_port"]),
        primary=primary,
        backups=backups,
        run_duration=int(raw.get("run_duration", 0) or 0),
        log_level=_level_from_name(raw.get("log_level", "INFO")),
        mirror_buffer_bytes=int(_opt("mirror_buffer_bytes", DEFAULT_MIRROR_BUF)),
        idle_timeout=float(_opt("idle_timeout", 0.0)),
        connect_timeout=float(_opt("connect_timeout", 0.0)),
    )


def _describe(exc: BaseException) -> str:
    """Return a log-friendly description of *exc* (some, e.g. timeouts, have an empty str())."""
    return str(exc) or type(exc).__name__


async def _close_writer(writer: asyncio.StreamWriter) -> None:
    """Close *writer* and wait for its transport to shut down, ignoring errors.

    Closing is best-effort: the peer may already have reset the connection,
    in which case ``wait_closed`` re-raises that error.
    """
    try:
        writer.close()
        await writer.wait_closed()
    except Exception as exc:
        logging.debug("close ignored error: %s", _describe(exc))


async def _open_connection(host: str, port: int,
                           timeout: float = 0.0) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
    """``asyncio.open_connection`` bounded by *timeout* seconds (0 or less = unbounded)."""
    coro = asyncio.open_connection(host, port)
    if timeout and timeout > 0:
        return await asyncio.wait_for(coro, timeout)
    return await coro


class MirrorBuffer:
    """Byte-bounded FIFO of chunks waiting to be written to one backup.

    No method awaits in the middle of its bookkeeping, so each update is
    atomic with respect to other coroutines on the same event loop. The
    methods stay ``async`` for API compatibility.
    """

    def __init__(self, max_bytes: int):
        self._q: "asyncio.Queue[bytes]" = asyncio.Queue()
        self._bytes = 0  # payload bytes currently queued
        self._max = max_bytes
        self._closed = False

    @property
    def closed(self) -> bool:
        return self._closed

    async def close(self) -> None:
        """Drop all queued data and wake a consumer blocked in :meth:`get`.

        The woken consumer receives ``b""``. Calling this again is a no-op.
        """
        if self._closed:
            return
        self._closed = True
        while True:
            try:
                self._q.get_nowait()
            except asyncio.QueueEmpty:
                break
            self._q.task_done()
        self._bytes = 0
        # Sentinel: without it a reader parked in get() would wait forever.
        self._q.put_nowait(b"")

    async def put_nowait(self, data: bytes) -> None:
        """Queue *data*; silently ignored once closed.

        Raises:
            asyncio.QueueFull: accepting *data* would exceed the byte limit.
        """
        if self._closed:
            return
        new_total = self._bytes + len(data)
        if new_total > self._max:
            raise asyncio.QueueFull()
        self._bytes = new_total
        self._q.put_nowait(data)

    async def get(self) -> bytes:
        """Wait for and return the next chunk (``b""`` once the buffer is closed)."""
        data = await self._q.get()
        self._bytes -= len(data)
        self._q.task_done()
        return data

    def empty(self) -> bool:
        """True when no payload is pending (always true after :meth:`close`)."""
        return self._closed or self._q.empty()


@dataclass
class MirrorCtl:
    """Per-connection handle for one backup mirror, shared with :func:`pump`."""

    name: str
    buf: MirrorBuffer
    disable: Callable[[], Awaitable[None]]  # closes the buffer and the backup connection
    active: bool = True


async def pump(reader: asyncio.StreamReader,
               writer: asyncio.StreamWriter,
               *,
               label: str,
               idle_timeout: float = 0.0,
               mirrors: Optional[List[MirrorCtl]] = None) -> None:
    """Copy bytes from *reader* to *writer* until EOF, a connection error, or idle timeout.

    Each chunk successfully forwarded is also queued on every active mirror.
    A mirror whose buffer overflows is marked inactive and disabled, so a slow
    backup never throttles the main path.

    Args:
        label: Prefix for log messages.
        idle_timeout: Seconds to wait for data before giving up; ``0`` waits forever.
        mirrors: Optional mirrors that receive a copy of the forwarded bytes.
    """
    set_socket_opts(writer)
    while True:
        try:
            if idle_timeout > 0:
                data = await asyncio.wait_for(reader.read(BUF_SIZE), timeout=idle_timeout)
            else:
                data = await reader.read(BUF_SIZE)
        except asyncio.TimeoutError:
            logging.info("%s: idle timeout; closing direction", label)
            break
        except ConnectionError as exc:
            logging.info("%s: upstream error (%s); stopping", label, _describe(exc))
            break
        if not data:  # EOF
            break
        writer.write(data)
        try:
            await writer.drain()
        except ConnectionError:
            logging.info("%s: downstream closed; stopping", label)
            break
        for m in mirrors or ():
            if not m.active or m.buf.closed:
                continue
            try:
                await m.buf.put_nowait(data)
            except asyncio.QueueFull:
                m.active = False
                logging.warning("%s: mirror '%s' overflow -> disabling", label, m.name)
                try:
                    await m.disable()
                except Exception as exc:
                    logging.debug("%s: disabling mirror '%s' failed: %s", label, m.name, _describe(exc))


async def mirror_writer_task(name: str, writer: asyncio.StreamWriter, mbuf: MirrorBuffer) -> None:
    """Drain *mbuf* into the backup connection *writer* until the buffer is closed.

    On exit for any reason, the buffer is closed too. This makes :func:`pump`
    stop queueing for a dead backup, and the backup connection is closed.
    """
    set_socket_opts(writer)
    try:
        while not mbuf.closed:
            data = await mbuf.get()
            if not data:  # close() sentinel; loop condition now ends the task
                continue
            writer.write(data)
            try:
                await writer.drain()
            except ConnectionError:
                logging.info("backup '%s': writer closed", name)
                break
    finally:
        await mbuf.close()
        await _close_writer(writer)


async def discard_task(name: str, reader: asyncio.StreamReader) -> None:
    """Read and throw away everything a backup sends until it closes its side."""
    try:
        while await reader.read(BUF_SIZE):
            pass
    except Exception as e:
        logging.debug("backup '%s': discard exception: %r", name, e)


def _make_disable(w: Optional[asyncio.StreamWriter], buf: MirrorBuffer) -> Callable[[], Awaitable[None]]:
    """Build the ``MirrorCtl.disable`` callback: close *buf*, then the backup writer *w*."""
    async def _disable() -> None:
        await buf.close()
        if w is not None:
            await _close_writer(w)
    return _disable


async def _connect_backups(conn_id: str, cfg: Config
                           ) -> List[Tuple[str, asyncio.StreamReader, asyncio.StreamWriter]]:
    """Connect to all backups concurrently; log and skip the ones that fail.

    Returns ``(name, reader, writer)`` for each backup that connected, in config order.
    """
    if not cfg.backups:
        return []
    attempts = [asyncio.ensure_future(_open_connection(t.host, t.port, cfg.connect_timeout))
                for t in cfg.backups]
    try:
        await asyncio.wait(attempts)
    except asyncio.CancelledError:
        # Don't leak sockets that connected before we were cancelled.
        for a in attempts:
            if not a.done():
                a.cancel()
            elif not a.cancelled() and a.exception() is None:
                a.result()[1].close()
        raise
    connected = []
    for tgt, a in zip(cfg.backups, attempts):
        name = f"{tgt.host}:{tgt.port}"
        exc = a.exception()
        if exc is not None:
            logging.warning("[%s] cannot connect backup %s: %s (disabled)", conn_id, name, _describe(exc))
            continue
        b_reader, b_writer = a.result()
        connected.append((name, b_reader, b_writer))
    return connected


async def handle_client(client_reader: asyncio.StreamReader,
                        client_writer: asyncio.StreamWriter,
                        cfg: Config) -> None:
    """Serve one client connection until either proxied direction finishes.

    1. Connect to the primary. On failure, log it and drop the client.
    2. Connect to the backups and start one writer and one discard task per backup.
    3. Pump client<->primary. When either direction ends, or this coroutine is
       cancelled, everything is torn down and every subtask is reaped.
    """
    peer = client_writer.get_extra_info("peername")
    conn_id = f"{peer[0]}:{peer[1]}" if isinstance(peer, tuple) else "unknown"
    logging.info("[%s] accepted", conn_id)
    set_socket_opts(client_writer)

    primary = cfg.primary
    try:
        p_reader, p_writer = await _open_connection(primary.host, primary.port, cfg.connect_timeout)
    except Exception as e:
        logging.error("[%s] failed to connect primary %s:%d: %s",
                      conn_id, primary.host, primary.port, _describe(e))
        await _close_writer(client_writer)
        return
    set_socket_opts(p_writer)
    logging.debug("[%s] connected primary %s:%d", conn_id, primary.host, primary.port)

    mirror_ctls: List[MirrorCtl] = []
    tasks: List[asyncio.Task] = []
    try:
        for name, b_reader, b_writer in await _connect_backups(conn_id, cfg):
            set_socket_opts(b_writer)
            mbuf = MirrorBuffer(cfg.mirror_buffer_bytes)
            mirror_ctls.append(MirrorCtl(name=name, buf=mbuf, disable=_make_disable(b_writer, mbuf), active=True))
            tasks.append(asyncio.create_task(mirror_writer_task(name, b_writer, mbuf)))
            tasks.append(asyncio.create_task(discard_task(name, b_reader)))
            logging.debug("[%s] connected backup %s", conn_id, name)

        c2p = asyncio.create_task(pump(client_reader, p_writer, label=f"[{conn_id}] client->primary",
                                       idle_timeout=cfg.idle_timeout, mirrors=mirror_ctls))
        p2c = asyncio.create_task(pump(p_reader, client_writer, label=f"[{conn_id}] primary->client",
                                       idle_timeout=cfg.idle_timeout))
        tasks += [c2p, p2c]
        await asyncio.wait([c2p, p2c], return_when=asyncio.FIRST_COMPLETED)
    finally:
        for t in tasks:
            if not t.done():
                t.cancel()
        for ctl in mirror_ctls:
            try:
                await ctl.disable()
            except Exception as exc:
                logging.debug("[%s] disabling mirror %s failed: %s", conn_id, ctl.name, _describe(exc))
        await _close_writer(p_writer)
        await _close_writer(client_writer)
        # gather(return_exceptions=True) also absorbs the CancelledError of tasks we
        # just cancelled; a plain `await t` would re-raise it and abort this cleanup.
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for res in results:
            if isinstance(res, Exception):
                logging.debug("[%s] subtask ended with error: %s", conn_id, _describe(res))
        logging.info("[%s] closed", conn_id)


async def _cancel_connections(active: Set[asyncio.Task]) -> None:
    """Cancel every live client-handler task in *active* and wait until all have finished."""
    await asyncio.sleep(0)  # let handlers of just-accepted sockets start and register
    while active:
        pending = list(active)
        for t in pending:
            t.cancel()
        await asyncio.gather(*pending, return_exceptions=True)


async def run_server(cfg: Config) -> None:
    """Serve until SIGINT/SIGTERM or until ``cfg.run_duration`` seconds have elapsed.

    On shutdown the listener is closed and every live connection is cancelled
    before waiting on the server. Since Python 3.12, ``Server.wait_closed``
    waits for open connections, so skipping the cancel could hang forever.
    """
    active: Set[asyncio.Task] = set()

    async def _on_client(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        task = asyncio.current_task()
        active.add(task)
        try:
            await handle_client(reader, writer, cfg)
        finally:
            active.discard(task)

    server = await asyncio.start_server(
        _on_client,
        host=cfg.listen_host,
        port=cfg.listen_port,
        start_serving=True,
    )
    sockets = ", ".join(str(s.getsockname()) for s in (server.sockets or []))
    logging.info("Listening on %s | primary=%s:%d | backups=%s | mirror_buffer=%d",
                 sockets, cfg.primary.host, cfg.primary.port,
                 ",".join(f"{b.host}:{b.port}" for b in cfg.backups) or "none",
                 cfg.mirror_buffer_bytes)

    stop_event = asyncio.Event()

    def _signal_stop() -> None:
        logging.info("Shutdown requested…")
        stop_event.set()

    loop = asyncio.get_running_loop()
    installed = []
    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(sig, _signal_stop)
            installed.append(sig)
        except (NotImplementedError, RuntimeError):
            pass  # Windows event loops / non-main thread: rely on KeyboardInterrupt instead

    # Wait on tasks (not bare coroutines) so the loser can be cancelled cleanly.
    waiters = [asyncio.create_task(stop_event.wait())]
    if cfg.run_duration > 0:
        waiters.append(asyncio.create_task(asyncio.sleep(cfg.run_duration)))
    try:
        await asyncio.wait(waiters, return_when=asyncio.FIRST_COMPLETED)
    finally:
        for t in waiters:
            t.cancel()
        await asyncio.gather(*waiters, return_exceptions=True)
        for sig in installed:
            loop.remove_signal_handler(sig)
        server.close()  # stop accepting first, then drop live connections
        await _cancel_connections(active)
        await server.wait_closed()
    logging.info("Server stopped.")


def main() -> None:
    """Command-line entry point: parse arguments, load the config and run the server."""
    ap = argparse.ArgumentParser(description="Asyncio TCP splitter with mirrored backups (config-driven).")
    ap.add_argument("--config", help="Path to JSON config file.")
    ap.add_argument("--title", help="Console window title (Windows).")
    args = ap.parse_args()

    cfg_path = Path(args.config) if args.config else resolve_default_config_path()
    try:
        cfg = load_config(cfg_path)
    except (OSError, ValueError, KeyError, TypeError) as exc:
        # Logging is not configured yet (its level comes from the config): report on stderr, exit 1.
        sys.exit(f"Failed to load config {cfg_path}: {type(exc).__name__}: {exc}")

    logging.basicConfig(level=cfg.log_level, format="%(asctime)s %(levelname)s %(message)s")
    logging.info("Using config: %s", cfg_path)
    set_console_title(args.title)
    try:
        asyncio.run(run_server(cfg))
    except KeyboardInterrupt:
        pass


if __name__ == "__main__":
    main()