diff --git a/readme.txt b/readme.txt index eb05055..b6a1ba7 100644 --- a/readme.txt +++ b/readme.txt @@ -2,8 +2,8 @@ Et Python-basert verktøy som splitter innkommende TCP-trafikk til flere mål samtidig. ## Start -Kjør `tcp_splitter.py` for å starte programmet. -Scriptet `worker.py` blir startet automatisk av hovedscriptet, og restartes med jevne mellomrom basert på verdien satt i `config.json`. +Kjør `worker.exe` for å starte programmet. +Scriptet `tcp_splitter.exe` blir startet automatisk av hovedscriptet, og restartes med jevne mellomrom basert på verdien satt i `config.json`. --- diff --git a/tcp_splitter.exe b/tcp_splitter.exe new file mode 100644 index 0000000..b6d6955 Binary files /dev/null and b/tcp_splitter.exe differ diff --git a/tcp_splitter.py b/tcp_splitter.py index 91d5f2a..14b41fc 100644 --- a/tcp_splitter.py +++ b/tcp_splitter.py @@ -1,26 +1,300 @@ -import subprocess -import time +#!/usr/bin/env python3 +# tcp_splitter.py +import argparse +import asyncio +import json +import logging +import signal +import socket +import sys +import os +import ctypes +from dataclasses import dataclass +from pathlib import Path +from typing import Callable, Awaitable, List, Optional -while True: - # Get the current time and print it when starting the process - restart_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) - print(f"Starting worker script at {restart_time}") +BUF_SIZE = 64 * 1024 +DEFAULT_MIRROR_BUF = 1 * 1024 * 1024 # 1 MiB - # Start the worker script in a new Command Prompt window with a custom title - process = subprocess.Popen(['start', 'cmd', '/c', 'title TCP_SPLITTER && worker.exe'], shell=True) +def set_console_title(title: Optional[str]): + if os.name == "nt" and title: + try: + ctypes.windll.kernel32.SetConsoleTitleW(str(title)) + except Exception: + pass - # Wait a moment to allow the process to start - time.sleep(2) +def set_socket_opts(writer: asyncio.StreamWriter) -> None: + sock = writer.get_extra_info("socket") + if not sock: + return + for opt in ( + (socket.IPPROTO_TCP, socket.TCP_NODELAY, 1), + (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1), + ): + try: sock.setsockopt(*opt) + except Exception: pass +@dataclass +class Target: + host: str + port: int + +@dataclass +class Config: + listen_host: str + listen_port: int + primary: Target + backups: List[Target] + run_duration: int + log_level: int + mirror_buffer_bytes: int = DEFAULT_MIRROR_BUF + idle_timeout: float = 0.0 + +def _level_from_name(name: str) -> int: + try: + lvl = getattr(logging, str(name).upper()) + return lvl if isinstance(lvl, int) else logging.INFO + except Exception: + return logging.INFO + +def resolve_default_config_path() -> Path: + base = Path(sys.executable).parent if getattr(sys, "frozen", False) else Path(__file__).resolve().parent + cfg = base / "config.json" + return cfg if cfg.exists() else Path.cwd() / "config.json" + +def load_config(path: Path) -> Config: + if not path.exists(): + raise FileNotFoundError(f"Config not found: {path}") + raw = json.loads(path.read_text(encoding="utf-8")) + targets = raw["target_hosts"] + if not isinstance(targets, list) or not targets: + raise ValueError("config.target_hosts must contain at least one entry") + primary = Target(str(targets[0]["host"]), int(targets[0]["port"])) + backups = [Target(str(t["host"]), int(t["port"])) for t in targets[1:]] + return Config( + listen_host=str(raw["listen_host"]), + listen_port=int(raw["listen_port"]), + primary=primary, + backups=backups, + run_duration=int(raw.get("run_duration", 0) or 0), + log_level=_level_from_name(raw.get("log_level", "INFO")), + mirror_buffer_bytes=int(raw.get("mirror_buffer_bytes", DEFAULT_MIRROR_BUF)), + idle_timeout=float(raw.get("idle_timeout", 0.0)), + ) + +class MirrorBuffer: + def __init__(self, max_bytes: int): + self._q: asyncio.Queue[bytes] = asyncio.Queue() + self._bytes = 0 + self._max = max_bytes + self._closed = False + self._lock = asyncio.Lock() + @property + def closed(self) -> bool: return self._closed + async def close(self): + async with self._lock: + self._closed = True + while not self._q.empty(): + try: + self._q.get_nowait(); self._q.task_done() + except Exception: break + async def put_nowait(self, data: bytes): + if self._closed: return + async with self._lock: + new_total = self._bytes + len(data) + if new_total > self._max: + raise asyncio.QueueFull() + self._bytes = new_total + self._q.put_nowait(data) + async def get(self) -> bytes: + data = await self._q.get() + async with self._lock: + self._bytes -= len(data) + self._q.task_done() + return data + def empty(self) -> bool: return self._q.empty() + +@dataclass +class MirrorCtl: + name: str + buf: MirrorBuffer + disable: Callable[[], Awaitable[None]] + active: bool = True + +async def pump(reader: asyncio.StreamReader, writer: asyncio.StreamWriter, *, + label: str, idle_timeout: float = 0.0, + mirrors: Optional[List[MirrorCtl]] = None): + set_socket_opts(writer) while True: - # Check if the specific titled process is still running - result = subprocess.run('tasklist /v | findstr "TCP_SPLITTER"', shell=True, stdout=subprocess.PIPE, text=True) + try: + data = await (asyncio.wait_for(reader.read(BUF_SIZE), timeout=idle_timeout) if idle_timeout > 0 else reader.read(BUF_SIZE)) + except asyncio.TimeoutError: + logging.info("%s: idle timeout; closing direction", label); break + if not data: break + writer.write(data) + try: await writer.drain() + except ConnectionError: + logging.info("%s: downstream closed; stopping", label); break + if mirrors: + for m in mirrors: + if not m.active or m.buf.closed: continue + try: await m.buf.put_nowait(data) + except asyncio.QueueFull: + m.active = False + logging.warning("%s: mirror '%s' overflow -> disabling", label, m.name) + try: await m.disable() + except Exception: pass - # If "TCP_SPLITTER" title is not found, the worker script has stopped - if "TCP_SPLITTER" not in result.stdout: - exit_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) - print(f"Worker script exited at {exit_time}. Restarting...") - break # Exit the inner loop to restart the outer loop and launch the script again +async def mirror_writer_task(name: str, writer: asyncio.StreamWriter, mbuf: MirrorBuffer): + set_socket_opts(writer) + try: + while not mbuf.closed: + if mbuf.empty() and mbuf.closed: break + data = await mbuf.get() + if not data: continue + writer.write(data) + try: await writer.drain() + except ConnectionError: + logging.info("backup '%s': writer closed", name); break + except asyncio.CancelledError: + pass + finally: + try: writer.close(); await writer.wait_closed() + except Exception: pass - # Check every second if the script is still running - time.sleep(1) +async def discard_task(name: str, reader: asyncio.StreamReader): + try: + while True: + if not await reader.read(BUF_SIZE): break + except asyncio.CancelledError: + pass + except Exception as e: + logging.debug("backup '%s': discard exception: %r", name, e) + +async def handle_client(client_reader: asyncio.StreamReader, client_writer: asyncio.StreamWriter, cfg: Config): + peer = client_writer.get_extra_info("peername") + conn_id = f"{peer[0]}:{peer[1]}" if isinstance(peer, tuple) else "unknown" + logging.info("[%s] accepted", conn_id) + set_socket_opts(client_writer) + try: + p_reader, p_writer = await asyncio.open_connection(cfg.primary.host, cfg.primary.port) + set_socket_opts(p_writer) + logging.debug("[%s] connected primary %s:%d", conn_id, cfg.primary.host, cfg.primary.port) + except Exception as e: + logging.error("[%s] failed to connect primary %s:%d: %s", conn_id, cfg.primary.host, cfg.primary.port, e) + client_writer.close(); await client_writer.wait_closed(); return + + mirror_ctls: List[MirrorCtl] = [] + mirror_tasks: List[asyncio.Task] = [] + + async def mk_disable(w: Optional[asyncio.StreamWriter], buf: MirrorBuffer): + async def _disable(): + try: await buf.close() + except Exception: pass + if w: + try: w.close(); await w.wait_closed() + except Exception: pass + return _disable + + for tgt in cfg.backups: + name = f"{tgt.host}:{tgt.port}" + try: + b_reader, b_writer = await asyncio.open_connection(tgt.host, tgt.port) + set_socket_opts(b_writer) + mbuf = MirrorBuffer(cfg.mirror_buffer_bytes) + disable_cb = await mk_disable(b_writer, mbuf) + ctl = MirrorCtl(name=name, buf=mbuf, disable=disable_cb, active=True) + mirror_ctls.append(ctl) + mirror_tasks += [ + asyncio.create_task(mirror_writer_task(name, b_writer, mbuf)), + asyncio.create_task(discard_task(name, b_reader)), + ] + logging.debug("[%s] connected backup %s", conn_id, name) + except Exception as e: + logging.warning("[%s] cannot connect backup %s: %s (disabled)", conn_id, name, e) + + c2p = asyncio.create_task(pump(client_reader, p_writer, label=f"[{conn_id}] client->primary", + idle_timeout=cfg.idle_timeout, mirrors=mirror_ctls)) + p2c = asyncio.create_task(pump(p_reader, client_writer, label=f"[{conn_id}] primary->client", + idle_timeout=cfg.idle_timeout)) + + try: + await asyncio.wait([c2p, p2c], return_when=asyncio.FIRST_COMPLETED) + finally: + for obj in (p_writer, client_writer): + try: obj.close(); await obj.wait_closed() + except Exception: pass + for ctl in mirror_ctls: + try: await ctl.disable() + except Exception: pass + for t in mirror_tasks + [c2p, p2c]: + if not t.done(): t.cancel() + for t in mirror_tasks + [c2p, p2c]: + try: await t + except Exception: pass + logging.info("[%s] closed", conn_id) + +async def run_server(cfg: Config): + server = await asyncio.start_server( + lambda r, w: handle_client(r, w, cfg), + host=cfg.listen_host, + port=cfg.listen_port, + start_serving=True, + ) + sockets = ", ".join(str(s.getsockname()) for s in (server.sockets or [])) + logging.info("Listening on %s | primary=%s:%d | backups=%s | mirror_buffer=%d", + sockets, cfg.primary.host, cfg.primary.port, + ",".join(f"{b.host}:{b.port}" for b in cfg.backups) or "none", + cfg.mirror_buffer_bytes) + + stop_event = asyncio.Event() + + def _signal_stop(): + logging.info("Shutdown requested…") + stop_event.set() + + loop = asyncio.get_running_loop() + for sig in (signal.SIGINT, signal.SIGTERM): + try: + loop.add_signal_handler(sig, _signal_stop) + except NotImplementedError: + pass # Windows quirk + + # Make TASKS (not bare coroutines) + stop_task = asyncio.create_task(stop_event.wait()) + duration_task = (asyncio.create_task(asyncio.sleep(cfg.run_duration)) + if cfg.run_duration and cfg.run_duration > 0 else None) + + async with server: + tasks = [stop_task] + ([duration_task] if duration_task else []) + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + + # Cancel whatever is left + for t in pending: + t.cancel() + await asyncio.gather(*pending, return_exceptions=True) + + server.close() + await server.wait_closed() + + logging.info("Server stopped.") + +def main(): + ap = argparse.ArgumentParser(description="Asyncio TCP splitter with mirrored backups (config-driven).") + ap.add_argument("--config", help="Path to JSON config file.") + ap.add_argument("--title", help="Console window title (Windows).") + args = ap.parse_args() + + cfg_path = Path(args.config) if args.config else resolve_default_config_path() + cfg = load_config(cfg_path) + logging.basicConfig(level=cfg.log_level, format="%(asctime)s %(levelname)s %(message)s") + logging.info("Using config: %s", cfg_path) + + set_console_title(args.title) + try: + asyncio.run(run_server(cfg)) + except KeyboardInterrupt: + pass + +if __name__ == "__main__": + main() diff --git a/worker.exe b/worker.exe new file mode 100644 index 0000000..f36079f Binary files /dev/null and b/worker.exe differ diff --git a/worker.py b/worker.py index 37e69a9..a35dbf6 100644 --- a/worker.py +++ b/worker.py @@ -1,162 +1,56 @@ -import socket -import threading -import json +#!/usr/bin/env python3 +# worker.py +import subprocess import time -import logging -from datetime import datetime -from crccheck.crc import CrcArc +import sys +from pathlib import Path -# Load configuration from config.json -with open('config.json', 'r') as config_file: - config = json.load(config_file) +TITLE = "TCP_SPLITTER" +SPLITTER_EXE_NAME = "tcp_splitter.exe" +CONFIG_NAME = "config.json" -# Configuration from config.json -LISTEN_HOST = config['listen_host'] -LISTEN_PORT = config['listen_port'] -TARGET_HOSTS = config['target_hosts'] -RUN_DURATION = config['run_duration'] -LOG_LEVEL = config['log_level'] +# Windows creation flag: create child in a new console window +CREATE_NEW_CONSOLE = 0x00000010 -# Map log level strings to logging module constants -log_level_mapping = { - "DEBUG": logging.DEBUG, - "INFO": logging.INFO, - "WARNING": logging.WARNING, - "ERROR": logging.ERROR, - "CRITICAL": logging.CRITICAL -} +def base_dir() -> Path: + # When frozen by PyInstaller, sys.executable is the EXE path + return Path(sys.executable).parent if getattr(sys, "frozen", False) else Path(__file__).resolve().parent -# Validate and set the log level -if LOG_LEVEL not in log_level_mapping: - raise ValueError(f"Invalid log level: {LOG_LEVEL}. Valid options are: {list(log_level_mapping.keys())}") +def launch() -> subprocess.Popen: + base = base_dir() + splitter = base / SPLITTER_EXE_NAME + cfg = base / CONFIG_NAME -# Logging configuration -logging.basicConfig( - level=log_level_mapping[LOG_LEVEL], - format='%(asctime)s - %(levelname)s - %(message)s', - handlers=[ - logging.FileHandler('relay_server.log', encoding='utf-8'), # Log to a file - logging.StreamHandler() # Log to the console - ] -) -logger = logging.getLogger(__name__) + if not splitter.exists(): + raise FileNotFoundError(f"Cannot find {splitter}") + if not cfg.exists(): + raise FileNotFoundError(f"Cannot find {cfg}") -def relay_to_server(server_host, server_port, data): - """Relay data to a server and return the response.""" - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket: - server_socket.connect((server_host, server_port)) - server_socket.sendall(data) - response = server_socket.recv(4096) - return response - except Exception as e: - logger.error(f"Failed to relay data to {server_host}:{server_port}: {e}") - return None + # Launch tcp_splitter.exe with its own console and a readable title + args = [str(splitter), "--config", str(cfg), "--title", TITLE] + proc = subprocess.Popen(args, creationflags=CREATE_NEW_CONSOLE) + return proc -def calculate_crc(data): - """Calculate CRC using CrcArc.""" - crc = CrcArc.calc(data.encode()) # Encode data to bytes - crc_hex = hex(crc).split('x')[1].upper().zfill(4) # Format as uppercase hex (4 chars) - return crc_hex +def main(): + while True: + start_ts = time.strftime("%Y-%m-%d %H:%M:%S") + print(f"[worker] Starting splitter at {start_ts}") + try: + proc = launch() + except Exception as e: + print(f"[worker] Launch error: {e}") + time.sleep(3) + continue - -def calculate_length(data): - """Calculate the length of the data in hexadecimal.""" - length = len(data.encode()) # Length in bytes - length_hex = hex(length)[2:].upper().zfill(4) # Convert to hex, ensure 4 digits - return length_hex - -def handle_client(client_socket, client_address): - """Handle a client connection.""" - try: - # Receive data from the client - data = client_socket.recv(4096) - if not data: - logger.warning(f"No data received from {client_address}") - return - - logger.debug(f"Received data from {client_address}: {data}") - - # Decode the data and check for NULL message - plain_data = data.decode("utf-8", errors='ignore') - # Validate message integrity - if not "SIA-DCS" in plain_data or not "ADM-CID" in plain_data: - return - if '"NULL"' in plain_data: - logger.debug("NULL Message detected, sending ACK without relaying signal!") - orig_text = plain_data[15:].strip() - ack_text = '"ACK"' + orig_text - ack_text_bytes = ack_text.encode() - logger.debug(f"ACK Message: {ack_text}") - - # Calculate CRC and format it - crc = CrcArc.calc(ack_text_bytes) - crcstr = str(hex(crc)).split('x') - crcstr = str(crcstr[1].upper()) - if len(crcstr) == 2: - crcstr = '00' + crcstr - if len(crcstr) == 3: - crcstr = '0' + crcstr - - # Calculate length and format it - length = str(hex(len(ack_text_bytes))).split('x') - logger.debug(f"CRC & Length: {crcstr}, 00{length[1]}") - # Construct the ACK message - ack_msg = '\n' + crcstr + '00' + length[1].upper() + ack_text + '\r' - - # Send the ACK response - ack_bytes = bytes(ack_msg, 'ASCII') - client_socket.sendall(ack_bytes) - logger.debug(f"Response sent to client: {ack_msg.strip()}") - return - - # Relay the data to all target hosts - responses = [] - for target in TARGET_HOSTS: - logger.info(f"Relaying data to {target['host']}:{target['port']}") - response = relay_to_server(target['host'], target['port'], data) - if response: - responses.append(response) - logger.debug(f"Received response from {target['host']}:{target['port']}: {response}") - - # Send only the first server's response back to the client - if responses: - client_socket.sendall(responses[0]) - logger.info(f"Sent response to {client_address}: {responses[0]}") - else: - logger.warning(f"No responses received from target hosts for {client_address}") - - except Exception as e: - logger.error(f"Error handling client {client_address}: {e}") - finally: - client_socket.close() - logger.debug(f"Closed connection with {client_address}") - -def start_relay(): - """Start the TCP relay server.""" - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket: - server_socket.bind((LISTEN_HOST, LISTEN_PORT)) - server_socket.listen(5) - logger.info(f"Relay server started listening on {LISTEN_HOST}:{LISTEN_PORT}") - - # Set a timer to stop the server after RUN_DURATION seconds - stop_time = time.time() + RUN_DURATION - logger.debug(f"Server will run for {RUN_DURATION} seconds.") - - while time.time() < stop_time: - try: - # Set a timeout to periodically check if the run duration has elapsed - server_socket.settimeout(1) - client_socket, client_address = server_socket.accept() - logger.debug(f"Accepted connection from {client_address}") - client_thread = threading.Thread(target=handle_client, args=(client_socket, client_address)) - client_thread.start() - except socket.timeout: - continue - except Exception as e: - logger.error(f"Error accepting connection: {e}") - - logger.info("Server run duration elapsed. Shutting down.") + # Poll until it exits; restart on exit + while True: + rc = proc.poll() + if rc is not None: + end_ts = time.strftime("%Y-%m-%d %H:%M:%S") + print(f"[worker] Splitter exited (code {rc}) at {end_ts}. Restarting…") + time.sleep(1) + break + time.sleep(1) if __name__ == "__main__": - start_relay() \ No newline at end of file + main()