tcp_splitter/tcp_splitter.py

301 lines
11 KiB
Python

#!/usr/bin/env python3
# tcp_splitter.py
import argparse
import asyncio
import json
import logging
import signal
import socket
import sys
import os
import ctypes
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Awaitable, List, Optional
# Bytes requested from a StreamReader per read() call when relaying data.
BUF_SIZE = 65_536
# Default per-backup cap on queued-but-unsent mirror bytes (1 MiB); a backup
# whose queue would exceed it is disabled.
DEFAULT_MIRROR_BUF = 1024 * 1024
def set_console_title(title: Optional[str]):
    """Set the console window title on Windows.

    Does nothing on other platforms or when *title* is empty/None. Any error
    from the Win32 call is ignored, since the title is purely cosmetic.
    """
    if os.name != "nt" or not title:
        return
    try:
        ctypes.windll.kernel32.SetConsoleTitleW(str(title))
    except Exception:
        pass
def set_socket_opts(writer: asyncio.StreamWriter) -> None:
    """Best-effort: enable TCP_NODELAY and SO_KEEPALIVE on *writer*'s socket.

    Returns quietly when the writer exposes no socket, and skips any option
    the socket refuses.
    """
    sock = writer.get_extra_info("socket")
    if not sock:
        return
    wanted = [
        (socket.IPPROTO_TCP, socket.TCP_NODELAY, 1),
        (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
    ]
    for level, option, value in wanted:
        try:
            sock.setsockopt(level, option, value)
        except Exception:
            pass
@dataclass
class Target:
    """One upstream endpoint: the primary or a mirrored backup."""

    host: str  # host name or address handed to asyncio.open_connection()
    port: int  # TCP port on *host*
@dataclass
class Config:
    """Runtime settings, built from the JSON config by load_config()."""

    listen_host: str  # address passed to asyncio.start_server()
    listen_port: int  # port the server listens on
    primary: Target  # first target_hosts entry; its responses are relayed to the client
    backups: List[Target]  # remaining entries; each receives a copy of client->primary bytes
    run_duration: int  # seconds before run_server() stops by itself; 0 = until signalled
    log_level: int  # numeric level passed to logging.basicConfig()
    mirror_buffer_bytes: int = DEFAULT_MIRROR_BUF  # per-backup cap on queued bytes before that mirror is disabled
    idle_timeout: float = 0.0  # per-direction read timeout in seconds; 0 disables it
def _level_from_name(name) -> int:
    """Translate a config ``log_level`` value into a numeric logging level.

    Accepts a level name in any case, with surrounding whitespace ignored
    (``"debug"``, ``" Warning "``). It also accepts a numeric level given as
    an int or a digit string (``10``, ``"10"``). Anything unrecognised,
    including None and booleans, falls back to ``logging.INFO``.
    """
    # bool is an int subclass, but True/False are not meaningful levels.
    if isinstance(name, int) and not isinstance(name, bool):
        return name
    try:
        text = str(name).strip()
    except Exception:
        return logging.INFO
    if text.isdigit():
        return int(text)
    level = getattr(logging, text.upper(), None)
    # Only accept real int constants; e.g. logging.BASIC_FORMAT is a str.
    if isinstance(level, int) and not isinstance(level, bool):
        return level
    return logging.INFO
def resolve_default_config_path() -> Path:
    """Pick the config.json to use when --config is not given.

    Look first next to the program: the executable's directory when running
    frozen (``sys.frozen`` set), otherwise this script's directory. If that
    file does not exist, fall back to ``config.json`` in the current working
    directory (which may not exist either).
    """
    if getattr(sys, "frozen", False):
        app_dir = Path(sys.executable).parent
    else:
        app_dir = Path(__file__).resolve().parent
    beside_app = app_dir / "config.json"
    if beside_app.exists():
        return beside_app
    return Path.cwd() / "config.json"
def load_config(path: Path) -> Config:
    """Read and validate the JSON config file at *path*.

    Required keys: ``listen_host``, ``listen_port`` and ``target_hosts``.
    ``target_hosts`` is a non-empty list of ``{"host": ..., "port": ...}``
    objects; the first entry is the primary and the rest are mirrored
    backups. The optional keys ``run_duration``, ``log_level``,
    ``mirror_buffer_bytes`` and ``idle_timeout`` use their defaults when
    absent or JSON ``null``.

    Raises:
        FileNotFoundError: *path* does not exist.
        ValueError: the file is not valid JSON or its root is not an object,
            ``target_hosts`` is missing/empty/not a list, or
            ``mirror_buffer_bytes`` is not positive.
        KeyError: ``listen_host``/``listen_port`` or a target's
            ``host``/``port`` is missing.
    """
    if not path.exists():
        raise FileNotFoundError(f"Config not found: {path}")
    raw = json.loads(path.read_text(encoding="utf-8"))  # JSONDecodeError is a ValueError
    if not isinstance(raw, dict):
        raise ValueError("config root must be a JSON object")

    targets = raw.get("target_hosts")
    if not isinstance(targets, list) or not targets:
        raise ValueError("config.target_hosts must contain at least one entry")
    endpoints = [Target(str(t["host"]), int(t["port"])) for t in targets]

    def _optional(key, default):
        # Treat an explicit JSON null the same as an absent key.
        value = raw.get(key)
        return default if value is None else value

    mirror_buffer_bytes = int(_optional("mirror_buffer_bytes", DEFAULT_MIRROR_BUF))
    if mirror_buffer_bytes <= 0:
        # A non-positive cap would make every mirror overflow on its first chunk.
        raise ValueError("config.mirror_buffer_bytes must be a positive integer")

    return Config(
        listen_host=str(raw["listen_host"]),
        listen_port=int(raw["listen_port"]),
        primary=endpoints[0],
        backups=endpoints[1:],
        run_duration=int(raw.get("run_duration", 0) or 0),
        log_level=_level_from_name(_optional("log_level", "INFO")),
        mirror_buffer_bytes=mirror_buffer_bytes,
        idle_timeout=float(_optional("idle_timeout", 0.0)),
    )
class MirrorBuffer:
    """Byte-bounded FIFO of data chunks waiting to be written to one backup.

    ``put_nowait`` never waits for space. It raises ``asyncio.QueueFull`` when
    accepting a chunk would push the total queued bytes above *max_bytes*.
    ``close`` discards everything still queued, makes later puts no-ops, and
    wakes a consumer blocked in ``get`` with an empty ``b""`` chunk.
    """

    def __init__(self, max_bytes: int):
        self._q: asyncio.Queue[bytes] = asyncio.Queue()
        self._bytes = 0  # total size of the chunks currently in _q
        self._max = max_bytes
        self._closed = False
        self._lock = asyncio.Lock()

    @property
    def closed(self) -> bool:
        return self._closed

    async def close(self):
        """Close the buffer, drop queued chunks and wake any waiting consumer.

        The wake-up is an empty ``b""`` chunk. A consumer that skips empty
        chunks and re-checks ``closed`` therefore exits instead of blocking
        in ``get()`` until it is cancelled.
        """
        async with self._lock:
            self._closed = True
            while not self._q.empty():
                try:
                    self._q.get_nowait()
                    self._q.task_done()
                except (asyncio.QueueEmpty, ValueError):
                    break
            self._bytes = 0  # the dropped chunks no longer count against the cap
            self._q.put_nowait(b"")

    async def put_nowait(self, data: bytes):
        """Queue *data*, or raise ``asyncio.QueueFull`` if it would exceed the cap.

        Ignored once the buffer is closed.
        """
        if self._closed:
            return
        async with self._lock:
            if self._closed:  # closed while we were waiting for the lock
                return
            new_total = self._bytes + len(data)
            if new_total > self._max:
                raise asyncio.QueueFull()
            self._bytes = new_total
            self._q.put_nowait(data)

    async def get(self) -> bytes:
        """Wait for and return the next chunk (``b""`` once after close())."""
        data = await self._q.get()
        async with self._lock:
            self._bytes -= len(data)
            self._q.task_done()
        return data

    def empty(self) -> bool:
        return self._q.empty()
@dataclass
class MirrorCtl:
    """Mirroring state for one backup, shared by handle_client() and pump()."""

    name: str  # "host:port" label used in log messages
    buf: MirrorBuffer  # bounded queue of bytes still to be sent to this backup
    disable: Callable[[], Awaitable[None]]  # awaited by pump() when buf overflows
    active: bool = True  # pump() clears this once the mirror has overflowed
async def pump(reader: asyncio.StreamReader, writer: asyncio.StreamWriter, *,
               label: str, idle_timeout: float = 0.0,
               mirrors: Optional[List[MirrorCtl]] = None):
    """Relay bytes from *reader* to *writer* until EOF, an idle timeout, or a write failure.

    Each chunk is first written and drained to *writer*. It is then offered
    to every mirror in *mirrors* that is still active and open. A mirror
    whose buffer overflows is marked inactive and its ``disable`` callback
    is awaited, with any error from the callback ignored.

    idle_timeout > 0 limits how long each single read may wait; 0 (the
    default) waits indefinitely. Read errors other than the timeout propagate.
    """
    set_socket_opts(writer)
    timed = idle_timeout > 0
    while True:
        pending_read = reader.read(BUF_SIZE)
        if timed:
            pending_read = asyncio.wait_for(pending_read, timeout=idle_timeout)
        try:
            chunk = await pending_read
        except asyncio.TimeoutError:
            logging.info("%s: idle timeout; closing direction", label)
            return
        if not chunk:
            return  # EOF from the source side

        writer.write(chunk)
        try:
            await writer.drain()
        except ConnectionError:
            logging.info("%s: downstream closed; stopping", label)
            return

        for mirror in mirrors or ():
            if not mirror.active or mirror.buf.closed:
                continue
            try:
                await mirror.buf.put_nowait(chunk)
            except asyncio.QueueFull:
                mirror.active = False
                logging.warning("%s: mirror '%s' overflow -> disabling", label, mirror.name)
                try:
                    await mirror.disable()
                except Exception:
                    pass
async def mirror_writer_task(name: str, writer: asyncio.StreamWriter, mbuf: MirrorBuffer):
    """Send chunks queued in *mbuf* to backup *writer* until the buffer is closed.

    The loop also stops when draining raises ConnectionError or the task is
    cancelled. Cancellation is swallowed, so the task then completes
    normally. *writer* is always closed on the way out.
    """
    set_socket_opts(writer)
    try:
        while not mbuf.closed:
            chunk = await mbuf.get()
            if chunk:  # empty chunks carry no data; just re-check `closed`
                writer.write(chunk)
                try:
                    await writer.drain()
                except ConnectionError:
                    logging.info("backup '%s': writer closed", name)
                    break
    except asyncio.CancelledError:
        pass
    finally:
        try:
            writer.close()
            await writer.wait_closed()
        except Exception:
            pass
async def discard_task(name: str, reader: asyncio.StreamReader):
    """Read and throw away everything a backup sends back, until EOF.

    Cancellation ends the task quietly. Any other read error is logged at
    DEBUG and also ends the task. (Presumably the draining exists so backup
    replies never pile up unread on the connection.)
    """
    try:
        chunk = await reader.read(BUF_SIZE)
        while chunk:
            chunk = await reader.read(BUF_SIZE)
    except asyncio.CancelledError:
        pass
    except Exception as e:
        logging.debug("backup '%s': discard exception: %r", name, e)
# Longest time (seconds) a graceful close may spend flushing buffered data
# before the transport is aborted. If a peer has stopped reading,
# wait_closed() would otherwise never return.
_CLOSE_GRACE_SECONDS = 5.0


def _abort_writer(writer: asyncio.StreamWriter) -> None:
    """Drop *writer*'s connection immediately, discarding any unsent data."""
    try:
        writer.transport.abort()
    except Exception as e:
        logging.debug("abort failed: %r", e)


async def _close_writer(writer: asyncio.StreamWriter, grace: float = _CLOSE_GRACE_SECONDS) -> None:
    """Close *writer*, allowing at most *grace* seconds to flush, then abort it."""
    try:
        writer.close()
        # shield(): on timeout, wait_for must not cancel the protocol's shared
        # close-waiter future, which another task may also be awaiting.
        await asyncio.wait_for(asyncio.shield(writer.wait_closed()), timeout=grace)
    except asyncio.TimeoutError:
        _abort_writer(writer)
    except Exception as e:
        # Already reset/closed: nothing is left to flush.
        logging.debug("close failed: %r", e)


def _make_disable(writer: asyncio.StreamWriter, buf: "MirrorBuffer") -> Callable[[], Awaitable[None]]:
    """Build the ``MirrorCtl.disable`` callback for one backup connection.

    pump() awaits this on the client->primary path when the backup falls
    behind. Waiting for the lagging backup to drain would stall the primary,
    so queued data is dropped and the connection aborted immediately.
    """
    async def _disable() -> None:
        try:
            await buf.close()
        except Exception as e:
            logging.debug("mirror buffer close failed: %r", e)
        _abort_writer(writer)
    return _disable


async def handle_client(client_reader: asyncio.StreamReader, client_writer: asyncio.StreamWriter, cfg: Config):
    """Serve one accepted client.

    Client bytes go to the primary and are also copied to each reachable
    backup through a per-backup MirrorBuffer. Primary responses go back to
    the client, and backup responses are read and discarded. Backups that
    cannot be connected are skipped.

    The session ends as soon as either direction finishes (EOF, idle
    timeout, or write error). All sockets are then closed, each with a
    bounded flush, and every helper task is cancelled and awaited.
    """
    peer = client_writer.get_extra_info("peername")
    conn_id = f"{peer[0]}:{peer[1]}" if isinstance(peer, tuple) else "unknown"
    logging.info("[%s] accepted", conn_id)
    set_socket_opts(client_writer)

    try:
        p_reader, p_writer = await asyncio.open_connection(cfg.primary.host, cfg.primary.port)
        set_socket_opts(p_writer)
        logging.debug("[%s] connected primary %s:%d", conn_id, cfg.primary.host, cfg.primary.port)
    except Exception as e:
        logging.error("[%s] failed to connect primary %s:%d: %s", conn_id, cfg.primary.host, cfg.primary.port, e)
        await _close_writer(client_writer)
        return

    mirror_ctls: List[MirrorCtl] = []
    backup_writers: List[asyncio.StreamWriter] = []
    tasks: List[asyncio.Task] = []
    try:
        for tgt in cfg.backups:
            name = f"{tgt.host}:{tgt.port}"
            try:
                b_reader, b_writer = await asyncio.open_connection(tgt.host, tgt.port)
            except Exception as e:
                logging.warning("[%s] cannot connect backup %s: %s (disabled)", conn_id, name, e)
                continue
            set_socket_opts(b_writer)
            mbuf = MirrorBuffer(cfg.mirror_buffer_bytes)
            mirror_ctls.append(MirrorCtl(name=name, buf=mbuf, disable=_make_disable(b_writer, mbuf), active=True))
            backup_writers.append(b_writer)
            tasks.append(asyncio.create_task(mirror_writer_task(name, b_writer, mbuf)))
            tasks.append(asyncio.create_task(discard_task(name, b_reader)))
            logging.debug("[%s] connected backup %s", conn_id, name)

        c2p = asyncio.create_task(pump(client_reader, p_writer, label=f"[{conn_id}] client->primary",
                                       idle_timeout=cfg.idle_timeout, mirrors=mirror_ctls))
        p2c = asyncio.create_task(pump(p_reader, client_writer, label=f"[{conn_id}] primary->client",
                                       idle_timeout=cfg.idle_timeout))
        tasks += [c2p, p2c]
        await asyncio.wait([c2p, p2c], return_when=asyncio.FIRST_COMPLETED)
    finally:
        # Stop mirroring first: drops queued data and wakes the mirror writers.
        for ctl in mirror_ctls:
            try:
                await ctl.buf.close()
            except Exception as e:
                logging.debug("[%s] mirror '%s' close failed: %r", conn_id, ctl.name, e)
        # Close every socket concurrently, each with a bounded flush, so a
        # peer that stopped reading cannot hang the session's cleanup.
        await asyncio.gather(*(_close_writer(w) for w in (p_writer, client_writer, *backup_writers)))
        for t in tasks:
            if not t.done():
                t.cancel()
        # return_exceptions=True also absorbs the CancelledError of the tasks
        # just cancelled. A bare `await t` re-raises it, and on Python 3.8+ it
        # is not an Exception subclass, so it would escape the cleanup.
        for result in await asyncio.gather(*tasks, return_exceptions=True):
            if isinstance(result, Exception):
                logging.debug("[%s] task ended with %r", conn_id, result)
        logging.info("[%s] closed", conn_id)
async def run_server(cfg: Config):
    """Accept clients on the configured address until stopped.

    Stopping is triggered by SIGINT/SIGTERM (where the event loop supports
    signal handlers) or by ``cfg.run_duration`` seconds elapsing, when that
    value is > 0. On stop:

    - The signal handlers are removed, so a second Ctrl+C falls back to the
      default behaviour.
    - The listener is closed.
    - Every in-flight client session is cancelled. Since Python 3.12.1,
      ``Server.wait_closed()`` waits for all accepted connections to finish,
      so live sessions must be ended first or shutdown could hang on idle
      clients.
    """
    sessions = set()  # asyncio tasks currently running handle_client()

    async def _serve(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
        task = asyncio.current_task()
        sessions.add(task)
        try:
            await handle_client(reader, writer, cfg)
        finally:
            sessions.discard(task)

    server = await asyncio.start_server(
        _serve,
        host=cfg.listen_host,
        port=cfg.listen_port,
        start_serving=True,
    )
    sockets = ", ".join(str(s.getsockname()) for s in (server.sockets or []))
    logging.info("Listening on %s | primary=%s:%d | backups=%s | mirror_buffer=%d",
                 sockets, cfg.primary.host, cfg.primary.port,
                 ",".join(f"{b.host}:{b.port}" for b in cfg.backups) or "none",
                 cfg.mirror_buffer_bytes)

    stop_event = asyncio.Event()

    def _signal_stop():
        logging.info("Shutdown requested…")
        stop_event.set()

    loop = asyncio.get_running_loop()
    installed_signals = []
    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(sig, _signal_stop)
            installed_signals.append(sig)
        except NotImplementedError:
            pass  # loop.add_signal_handler is Unix-only

    # Wait on tasks (not bare coroutines) so the loser can be cancelled.
    waiters = [asyncio.create_task(stop_event.wait())]
    if cfg.run_duration and cfg.run_duration > 0:
        waiters.append(asyncio.create_task(asyncio.sleep(cfg.run_duration)))

    async with server:
        try:
            await asyncio.wait(waiters, return_when=asyncio.FIRST_COMPLETED)
        finally:
            for t in waiters:
                if not t.done():
                    t.cancel()
            await asyncio.gather(*waiters, return_exceptions=True)
            for sig in installed_signals:
                loop.remove_signal_handler(sig)
            server.close()
            # Loop in case a connection accepted just before close() registered
            # its session while we were cancelling the others.
            while sessions:
                live = list(sessions)
                for t in live:
                    t.cancel()
                await asyncio.gather(*live, return_exceptions=True)
            await server.wait_closed()
    logging.info("Server stopped.")
def main():
    """Command-line entry point: parse arguments, load the config, run the server.

    A KeyboardInterrupt escaping the event loop ends the program quietly.
    """
    parser = argparse.ArgumentParser(
        description="Asyncio TCP splitter with mirrored backups (config-driven).")
    parser.add_argument("--config", help="Path to JSON config file.")
    parser.add_argument("--title", help="Console window title (Windows).")
    options = parser.parse_args()

    if options.config:
        config_path = Path(options.config)
    else:
        config_path = resolve_default_config_path()
    cfg = load_config(config_path)

    logging.basicConfig(level=cfg.log_level, format="%(asctime)s %(levelname)s %(message)s")
    logging.info("Using config: %s", config_path)
    set_console_title(options.title)

    try:
        asyncio.run(run_server(cfg))
    except KeyboardInterrupt:
        pass


if __name__ == "__main__":
    main()