301 lines
11 KiB
Python
301 lines
11 KiB
Python
#!/usr/bin/env python3
|
|
# tcp_splitter.py
|
|
import argparse
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import signal
|
|
import socket
|
|
import sys
|
|
import os
|
|
import ctypes
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Callable, Awaitable, List, Optional
|
|
|
|
# Bytes requested per read() on a socket (64 KiB).
BUF_SIZE = 64 << 10

# Default per-backup cap on queued mirror data, in bytes (1 MiB).
DEFAULT_MIRROR_BUF = 1 << 20
|
|
|
|
def set_console_title(title: Optional[str]):
    """Set the console window title on Windows; do nothing elsewhere.

    The call is skipped when *title* is empty/None or when ``os.name`` is not
    ``"nt"``.  Any failure of the Win32 call is ignored: the title is purely
    cosmetic and must not prevent start-up.
    """
    if os.name != "nt" or not title:
        return
    try:
        kernel32 = ctypes.windll.kernel32
        kernel32.SetConsoleTitleW(str(title))
    except Exception:
        pass  # best effort only
|
|
|
|
def set_socket_opts(writer: asyncio.StreamWriter) -> None:
    """Enable TCP_NODELAY and SO_KEEPALIVE on the socket behind *writer*.

    Writers whose ``get_extra_info("socket")`` is falsy are left untouched.
    Each option is applied independently and errors are ignored, so an
    unsupported option never prevents the other one from being set.
    """
    sock = writer.get_extra_info("socket")
    if sock:
        wanted = [
            # Disable Nagle's algorithm so small chunks are forwarded at once.
            (socket.IPPROTO_TCP, socket.TCP_NODELAY, 1),
            # Send TCP keepalive probes on otherwise idle connections.
            (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
        ]
        for level, option, value in wanted:
            try:
                sock.setsockopt(level, option, value)
            except Exception:
                pass
|
|
|
|
@dataclass
class Target:
    """A TCP endpoint that the splitter connects to."""

    host: str  # host name or IP literal handed to asyncio.open_connection
    port: int  # TCP port number
|
|
|
|
@dataclass
class Config:
    """Runtime settings for the splitter, as built by ``load_config``."""

    # Address and port the proxy listens on.
    listen_host: str
    listen_port: int
    # Every client connection is proxied to this target; its replies are
    # sent back to the client.
    primary: Target
    # Mirror-only targets: they get a copy of the client->primary bytes and
    # whatever they send back is read and discarded.
    backups: List[Target]
    # Seconds to serve before shutting down; 0 or negative means run until
    # a stop signal arrives.
    run_duration: int
    # Numeric logging level handed to logging.basicConfig.
    log_level: int
    # Per-backup cap, in bytes, on data queued for one mirror before that
    # mirror is disabled for the connection.
    mirror_buffer_bytes: int = DEFAULT_MIRROR_BUF
    # Seconds without incoming data before a direction is closed; values
    # <= 0 disable the idle timeout.
    idle_timeout: float = 0.0
|
|
|
|
def _level_from_name(name: str) -> int:
|
|
try:
|
|
lvl = getattr(logging, str(name).upper())
|
|
return lvl if isinstance(lvl, int) else logging.INFO
|
|
except Exception:
|
|
return logging.INFO
|
|
|
|
def resolve_default_config_path() -> Path:
    """Locate ``config.json`` when no ``--config`` argument was given.

    First look in the directory of the executable (frozen builds, i.e. when
    ``sys.frozen`` is set by a freezing tool) or of this script (normal runs).
    If no file exists there, return ``config.json`` in the current working
    directory, whether or not that file exists.
    """
    if getattr(sys, "frozen", False):
        home = Path(sys.executable).parent
    else:
        home = Path(__file__).resolve().parent
    candidate = home / "config.json"
    if candidate.exists():
        return candidate
    return Path.cwd() / "config.json"
|
|
|
|
def load_config(path: Path) -> Config:
    """Read and validate the JSON configuration file at *path*.

    Required keys:
      * ``listen_host`` and ``listen_port``;
      * ``target_hosts``: a non-empty list of ``{"host": ..., "port": ...}``
        objects.  The first entry is the primary and the rest are mirrored
        backups.

    Optional keys: ``run_duration``, ``log_level``, ``mirror_buffer_bytes``
    and ``idle_timeout``.  An explicit JSON ``null`` for any of them is
    treated the same as omitting it.

    Raises:
        FileNotFoundError: *path* does not exist.
        json.JSONDecodeError: the file is not valid JSON.
        ValueError: the top level is not a JSON object, or ``target_hosts`` is
            missing, not a list, or empty.
        KeyError: ``listen_host``/``listen_port`` or a target's
            ``host``/``port`` is missing.
    """
    if not path.exists():
        raise FileNotFoundError(f"Config not found: {path}")
    raw = json.loads(path.read_text(encoding="utf-8"))
    if not isinstance(raw, dict):
        raise ValueError("config root must be a JSON object")

    targets = raw.get("target_hosts")
    if not isinstance(targets, list) or not targets:
        raise ValueError("config.target_hosts must contain at least one entry")
    primary, *backups = [Target(str(t["host"]), int(t["port"])) for t in targets]

    def opt(key, default):
        # A missing key and an explicit null both mean "use the default".
        value = raw.get(key)
        return default if value is None else value

    return Config(
        listen_host=str(raw["listen_host"]),
        listen_port=int(raw["listen_port"]),
        primary=primary,
        backups=backups,
        run_duration=int(opt("run_duration", 0) or 0),
        log_level=_level_from_name(opt("log_level", "INFO")),
        mirror_buffer_bytes=int(opt("mirror_buffer_bytes", DEFAULT_MIRROR_BUF)),
        idle_timeout=float(opt("idle_timeout", 0.0)),
    )
|
|
|
|
class MirrorBuffer:
    """Byte-bounded FIFO of chunks queued for one backup (mirror) connection.

    Producers call :meth:`put_nowait` and a single consumer awaits
    :meth:`get`.  The combined size of queued chunks never exceeds
    ``max_bytes``: a put that would cross the limit raises
    :class:`asyncio.QueueFull` instead of waiting, so a slow backup cannot
    stall the producer.

    After :meth:`close`:
      * new data is silently dropped;
      * queued data is discarded;
      * a consumer blocked in :meth:`get` is woken with the empty ``b""``
        sentinel.
    """

    def __init__(self, max_bytes: int):
        self._q: asyncio.Queue[bytes] = asyncio.Queue()
        self._bytes = 0  # total size of chunks queued or being dequeued
        self._max = max_bytes
        self._closed = False
        self._lock = asyncio.Lock()  # serialises updates of _bytes/_closed

    @property
    def closed(self) -> bool:
        """True once :meth:`close` has been called."""
        return self._closed

    async def close(self):
        """Mark the buffer closed, drop queued data and wake a blocked consumer.

        Safe to call more than once.
        """
        async with self._lock:
            self._closed = True
            while True:
                try:
                    dropped = self._q.get_nowait()
                except asyncio.QueueEmpty:
                    break
                self._bytes -= len(dropped)
                self._q.task_done()
            # Empty sentinel: unblocks a consumer waiting in get() on an empty
            # queue.  b"" is falsy, so a consumer that skips empty chunks will
            # simply loop and observe `closed`.
            self._q.put_nowait(b"")

    async def put_nowait(self, data: bytes):
        """Queue *data* or raise :class:`asyncio.QueueFull` if over budget.

        Silently does nothing once the buffer is closed.
        """
        if self._closed:
            return
        async with self._lock:
            # Re-check: close() may have run while we waited for the lock.
            if self._closed:
                return
            new_total = self._bytes + len(data)
            if new_total > self._max:
                raise asyncio.QueueFull()
            self._bytes = new_total
            self._q.put_nowait(data)

    async def get(self) -> bytes:
        """Wait for and return the next chunk (``b""`` after close)."""
        data = await self._q.get()
        async with self._lock:
            self._bytes -= len(data)
            self._q.task_done()
        return data

    def empty(self) -> bool:
        """True when nothing (not even the close sentinel) is queued."""
        return self._q.empty()
|
|
|
|
@dataclass
class MirrorCtl:
    """Per-connection control record for one backup mirror."""

    # "host:port" of the backup; used in log messages.
    name: str
    # Bounded queue of client bytes waiting to be written to the backup.
    buf: MirrorBuffer
    # Coroutine factory that shuts this mirror down (invoked on overflow and
    # at connection teardown).
    disable: Callable[[], Awaitable[None]]
    # Cleared once the mirror has been disabled, so no more data is queued.
    active: bool = True
|
|
|
|
async def pump(reader: asyncio.StreamReader, writer: asyncio.StreamWriter, *,
               label: str, idle_timeout: float = 0.0,
               mirrors: Optional[List[MirrorCtl]] = None):
    """Copy bytes from *reader* to *writer* until EOF, idle timeout or error.

    Each chunk is written and drained to *writer* first; only then is it
    offered to every still-active mirror.  A mirror whose buffer would
    overflow is marked inactive and disabled, so one slow backup never holds
    up this direction.

    Args:
        reader: source stream.
        writer: destination stream; its socket options are set first.
        label: prefix used in log messages.
        idle_timeout: seconds to wait for each read; ``<= 0`` waits forever.
        mirrors: optional mirror controls fed with every copied chunk.
    """
    set_socket_opts(writer)
    while True:
        try:
            if idle_timeout > 0:
                data = await asyncio.wait_for(reader.read(BUF_SIZE), timeout=idle_timeout)
            else:
                data = await reader.read(BUF_SIZE)
        except asyncio.TimeoutError:
            logging.info("%s: idle timeout; closing direction", label)
            break
        except ConnectionError as e:
            # Source reset/aborted: log it and end this direction instead of
            # letting the exception propagate out of the pump task.
            logging.info("%s: source connection lost (%r); stopping", label, e)
            break
        if not data:
            break

        writer.write(data)
        try:
            await writer.drain()
        except ConnectionError:
            logging.info("%s: downstream closed; stopping", label)
            break

        for m in mirrors or ():
            if not m.active or m.buf.closed:
                continue
            try:
                await m.buf.put_nowait(data)
            except asyncio.QueueFull:
                m.active = False
                logging.warning("%s: mirror '%s' overflow -> disabling", label, m.name)
                try:
                    await m.disable()
                except Exception as e:
                    # Disabling is best effort; keep pumping the primary.
                    logging.debug("%s: disabling mirror '%s' failed: %r", label, m.name, e)
|
|
|
|
async def mirror_writer_task(name: str, writer: asyncio.StreamWriter, mbuf: MirrorBuffer):
    """Forward chunks queued in *mbuf* to the backup behind *writer*.

    Loops while the buffer is open, skipping empty chunks.  It stops when:
      * the buffer is seen closed at the top of the loop;
      * a drain fails with ConnectionError;
      * the task is cancelled (the cancellation is absorbed here).

    The backup writer is always closed on exit, best effort.
    """
    set_socket_opts(writer)
    try:
        while not mbuf.closed:
            chunk = await mbuf.get()
            if chunk:
                writer.write(chunk)
                try:
                    await writer.drain()
                except ConnectionError:
                    logging.info("backup '%s': writer closed", name)
                    break
    except asyncio.CancelledError:
        pass
    finally:
        try:
            writer.close()
            await writer.wait_closed()
        except Exception:
            pass
|
|
|
|
async def discard_task(name: str, reader: asyncio.StreamReader):
    """Read and throw away everything a backup sends until EOF.

    Backups only receive mirrored traffic; consuming their replies keeps them
    from accumulating unread in the StreamReader.  Cancellation ends the task
    quietly, and other errors are logged at DEBUG level.
    """
    try:
        while await reader.read(BUF_SIZE):
            pass
    except asyncio.CancelledError:
        pass
    except Exception as e:
        logging.debug("backup '%s': discard exception: %r", name, e)
|
|
|
|
async def handle_client(client_reader: asyncio.StreamReader, client_writer: asyncio.StreamWriter, cfg: Config):
    """Serve one accepted client: proxy it to the primary and mirror to backups.

    1. Connect to ``cfg.primary``; if that fails, close the client and return.
    2. For each reachable backup, start a writer task fed through a
       ``MirrorBuffer`` and a task that discards the backup's replies.
       Backups that cannot be reached are logged and skipped.
    3. Pump bytes client->primary (also feeding the mirrors) and
       primary->client until either direction finishes.
    4. Close every connection and collect every helper task, then log.
    """

    async def _close_quietly(w: asyncio.StreamWriter) -> None:
        # Best effort: the peer may already have reset the connection, and a
        # failing close must not abort the rest of the teardown.
        try:
            w.close()
            await w.wait_closed()
        except Exception:
            pass

    def _make_disable(w: asyncio.StreamWriter, buf):
        # Build the MirrorCtl.disable callback: drop buffered data, then close
        # the backup connection.
        async def _disable():
            try:
                await buf.close()
            except Exception:
                pass
            await _close_quietly(w)
        return _disable

    peer = client_writer.get_extra_info("peername")
    conn_id = f"{peer[0]}:{peer[1]}" if isinstance(peer, tuple) else "unknown"
    logging.info("[%s] accepted", conn_id)
    set_socket_opts(client_writer)

    try:
        p_reader, p_writer = await asyncio.open_connection(cfg.primary.host, cfg.primary.port)
        set_socket_opts(p_writer)
        logging.debug("[%s] connected primary %s:%d", conn_id, cfg.primary.host, cfg.primary.port)
    except Exception as e:
        logging.error("[%s] failed to connect primary %s:%d: %s", conn_id, cfg.primary.host, cfg.primary.port, e)
        await _close_quietly(client_writer)
        return

    mirror_ctls: List[MirrorCtl] = []
    mirror_tasks: List[asyncio.Task] = []
    pumps: List[asyncio.Task] = []
    # Everything after the primary connect runs under try/finally so that an
    # error or cancellation while connecting backups still closes the client
    # and primary connections.
    try:
        for tgt in cfg.backups:
            name = f"{tgt.host}:{tgt.port}"
            try:
                b_reader, b_writer = await asyncio.open_connection(tgt.host, tgt.port)
            except Exception as e:
                logging.warning("[%s] cannot connect backup %s: %s (disabled)", conn_id, name, e)
                continue
            set_socket_opts(b_writer)
            mbuf = MirrorBuffer(cfg.mirror_buffer_bytes)
            mirror_ctls.append(
                MirrorCtl(name=name, buf=mbuf, disable=_make_disable(b_writer, mbuf), active=True)
            )
            mirror_tasks += [
                asyncio.create_task(mirror_writer_task(name, b_writer, mbuf)),
                asyncio.create_task(discard_task(name, b_reader)),
            ]
            logging.debug("[%s] connected backup %s", conn_id, name)

        c2p = asyncio.create_task(pump(client_reader, p_writer, label=f"[{conn_id}] client->primary",
                                       idle_timeout=cfg.idle_timeout, mirrors=mirror_ctls))
        p2c = asyncio.create_task(pump(p_reader, client_writer, label=f"[{conn_id}] primary->client",
                                       idle_timeout=cfg.idle_timeout))
        pumps = [c2p, p2c]
        await asyncio.wait(pumps, return_when=asyncio.FIRST_COMPLETED)
    finally:
        for w in (p_writer, client_writer):
            await _close_quietly(w)
        for ctl in mirror_ctls:
            try:
                await ctl.disable()
            except Exception:
                pass
        tasks = mirror_tasks + pumps
        for t in tasks:
            if not t.done():
                t.cancel()
        # gather(return_exceptions=True) also absorbs the CancelledError of the
        # tasks cancelled above.  A per-task `except Exception` would not,
        # because CancelledError derives from BaseException (Python 3.8+).
        await asyncio.gather(*tasks, return_exceptions=True)
        logging.info("[%s] closed", conn_id)
|
|
|
|
async def run_server(cfg: Config):
    """Listen on the configured address and proxy clients until told to stop.

    Every accepted connection is served by ``handle_client``.  The server
    stops when either of these happens:
      * SIGINT/SIGTERM arrives (where the loop supports signal handlers);
      * ``cfg.run_duration`` seconds elapse, when that value is positive.
    """
    server = await asyncio.start_server(
        lambda r, w: handle_client(r, w, cfg),
        host=cfg.listen_host,
        port=cfg.listen_port,
        start_serving=True,
    )
    bound = ", ".join(str(sock.getsockname()) for sock in (server.sockets or []))
    backup_list = ",".join(f"{b.host}:{b.port}" for b in cfg.backups) or "none"
    logging.info("Listening on %s | primary=%s:%d | backups=%s | mirror_buffer=%d",
                 bound, cfg.primary.host, cfg.primary.port,
                 backup_list, cfg.mirror_buffer_bytes)

    stop_event = asyncio.Event()

    def _signal_stop():
        logging.info("Shutdown requested…")
        stop_event.set()

    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(sig, _signal_stop)
        except NotImplementedError:
            pass  # the loop has no signal-handler support (e.g. on Windows)

    # asyncio.wait() needs Tasks, not bare coroutines.
    waiters = [asyncio.create_task(stop_event.wait())]
    if cfg.run_duration and cfg.run_duration > 0:
        waiters.append(asyncio.create_task(asyncio.sleep(cfg.run_duration)))

    async with server:
        _done, pending = await asyncio.wait(waiters, return_when=asyncio.FIRST_COMPLETED)
        # Whichever stop condition did not fire is no longer needed.
        for leftover in pending:
            leftover.cancel()
        await asyncio.gather(*pending, return_exceptions=True)

    server.close()
    await server.wait_closed()
    logging.info("Server stopped.")
|
|
|
|
def main():
    """Command-line entry point: parse arguments, load the config, serve."""
    parser = argparse.ArgumentParser(description="Asyncio TCP splitter with mirrored backups (config-driven).")
    parser.add_argument("--config", help="Path to JSON config file.")
    parser.add_argument("--title", help="Console window title (Windows).")
    options = parser.parse_args()

    if options.config:
        cfg_path = Path(options.config)
    else:
        cfg_path = resolve_default_config_path()
    cfg = load_config(cfg_path)

    logging.basicConfig(level=cfg.log_level, format="%(asctime)s %(levelname)s %(message)s")
    logging.info("Using config: %s", cfg_path)
    set_console_title(options.title)

    try:
        asyncio.run(run_server(cfg))
    except KeyboardInterrupt:
        pass  # Ctrl+C that was not routed through a loop signal handler


if __name__ == "__main__":
    main()
|