Oppdatert til mer stabil versjon.

This commit is contained in:
Anders Knutsen 2025-09-02 15:25:48 +02:00
parent f51134e2ff
commit a2ddc7cc1c
5 changed files with 340 additions and 172 deletions

View File

@ -2,8 +2,8 @@
Et Python-basert verktøy som splitter innkommende TCP-trafikk til flere mål samtidig. Et Python-basert verktøy som splitter innkommende TCP-trafikk til flere mål samtidig.
## Start ## Start
Kjør `tcp_splitter.py` for å starte programmet. Kjør `worker.exe` for å starte programmet.
Scriptet `worker.py` blir startet automatisk av hovedscriptet, og restartes med jevne mellomrom basert på verdien satt i `config.json`. Scriptet `tcp_splitter.exe` blir startet automatisk av hovedscriptet, og restartes med jevne mellomrom basert på verdien satt i `config.json`.
--- ---

BIN
tcp_splitter.exe Normal file

Binary file not shown.

View File

@ -1,26 +1,300 @@
import subprocess #!/usr/bin/env python3
import time # tcp_splitter.py
import argparse
import asyncio
import json
import logging
import signal
import socket
import sys
import os
import ctypes
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Awaitable, List, Optional
while True: BUF_SIZE = 64 * 1024
# Get the current time and print it when starting the process DEFAULT_MIRROR_BUF = 1 * 1024 * 1024 # 1 MiB
restart_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
print(f"Starting worker script at {restart_time}")
# Start the worker script in a new Command Prompt window with a custom title def set_console_title(title: Optional[str]):
process = subprocess.Popen(['start', 'cmd', '/c', 'title TCP_SPLITTER && worker.exe'], shell=True) if os.name == "nt" and title:
try:
ctypes.windll.kernel32.SetConsoleTitleW(str(title))
except Exception:
pass
# Wait a moment to allow the process to start def set_socket_opts(writer: asyncio.StreamWriter) -> None:
time.sleep(2) sock = writer.get_extra_info("socket")
if not sock:
return
for opt in (
(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1),
(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
):
try: sock.setsockopt(*opt)
except Exception: pass
@dataclass
class Target:
host: str
port: int
@dataclass
class Config:
listen_host: str
listen_port: int
primary: Target
backups: List[Target]
run_duration: int
log_level: int
mirror_buffer_bytes: int = DEFAULT_MIRROR_BUF
idle_timeout: float = 0.0
def _level_from_name(name: str) -> int:
try:
lvl = getattr(logging, str(name).upper())
return lvl if isinstance(lvl, int) else logging.INFO
except Exception:
return logging.INFO
def resolve_default_config_path() -> Path:
base = Path(sys.executable).parent if getattr(sys, "frozen", False) else Path(__file__).resolve().parent
cfg = base / "config.json"
return cfg if cfg.exists() else Path.cwd() / "config.json"
def load_config(path: Path) -> Config:
if not path.exists():
raise FileNotFoundError(f"Config not found: {path}")
raw = json.loads(path.read_text(encoding="utf-8"))
targets = raw["target_hosts"]
if not isinstance(targets, list) or not targets:
raise ValueError("config.target_hosts must contain at least one entry")
primary = Target(str(targets[0]["host"]), int(targets[0]["port"]))
backups = [Target(str(t["host"]), int(t["port"])) for t in targets[1:]]
return Config(
listen_host=str(raw["listen_host"]),
listen_port=int(raw["listen_port"]),
primary=primary,
backups=backups,
run_duration=int(raw.get("run_duration", 0) or 0),
log_level=_level_from_name(raw.get("log_level", "INFO")),
mirror_buffer_bytes=int(raw.get("mirror_buffer_bytes", DEFAULT_MIRROR_BUF)),
idle_timeout=float(raw.get("idle_timeout", 0.0)),
)
class MirrorBuffer:
def __init__(self, max_bytes: int):
self._q: asyncio.Queue[bytes] = asyncio.Queue()
self._bytes = 0
self._max = max_bytes
self._closed = False
self._lock = asyncio.Lock()
@property
def closed(self) -> bool: return self._closed
async def close(self):
async with self._lock:
self._closed = True
while not self._q.empty():
try:
self._q.get_nowait(); self._q.task_done()
except Exception: break
async def put_nowait(self, data: bytes):
if self._closed: return
async with self._lock:
new_total = self._bytes + len(data)
if new_total > self._max:
raise asyncio.QueueFull()
self._bytes = new_total
self._q.put_nowait(data)
async def get(self) -> bytes:
data = await self._q.get()
async with self._lock:
self._bytes -= len(data)
self._q.task_done()
return data
def empty(self) -> bool: return self._q.empty()
@dataclass
class MirrorCtl:
name: str
buf: MirrorBuffer
disable: Callable[[], Awaitable[None]]
active: bool = True
async def pump(reader: asyncio.StreamReader, writer: asyncio.StreamWriter, *,
label: str, idle_timeout: float = 0.0,
mirrors: Optional[List[MirrorCtl]] = None):
set_socket_opts(writer)
while True: while True:
# Check if the specific titled process is still running try:
result = subprocess.run('tasklist /v | findstr "TCP_SPLITTER"', shell=True, stdout=subprocess.PIPE, text=True) data = await (asyncio.wait_for(reader.read(BUF_SIZE), timeout=idle_timeout) if idle_timeout > 0 else reader.read(BUF_SIZE))
except asyncio.TimeoutError:
logging.info("%s: idle timeout; closing direction", label); break
if not data: break
writer.write(data)
try: await writer.drain()
except ConnectionError:
logging.info("%s: downstream closed; stopping", label); break
if mirrors:
for m in mirrors:
if not m.active or m.buf.closed: continue
try: await m.buf.put_nowait(data)
except asyncio.QueueFull:
m.active = False
logging.warning("%s: mirror '%s' overflow -> disabling", label, m.name)
try: await m.disable()
except Exception: pass
# If "TCP_SPLITTER" title is not found, the worker script has stopped async def mirror_writer_task(name: str, writer: asyncio.StreamWriter, mbuf: MirrorBuffer):
if "TCP_SPLITTER" not in result.stdout: set_socket_opts(writer)
exit_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) try:
print(f"Worker script exited at {exit_time}. Restarting...") while not mbuf.closed:
break # Exit the inner loop to restart the outer loop and launch the script again if mbuf.empty() and mbuf.closed: break
data = await mbuf.get()
if not data: continue
writer.write(data)
try: await writer.drain()
except ConnectionError:
logging.info("backup '%s': writer closed", name); break
except asyncio.CancelledError:
pass
finally:
try: writer.close(); await writer.wait_closed()
except Exception: pass
# Check every second if the script is still running async def discard_task(name: str, reader: asyncio.StreamReader):
time.sleep(1) try:
while True:
if not await reader.read(BUF_SIZE): break
except asyncio.CancelledError:
pass
except Exception as e:
logging.debug("backup '%s': discard exception: %r", name, e)
async def handle_client(client_reader: asyncio.StreamReader, client_writer: asyncio.StreamWriter, cfg: Config):
peer = client_writer.get_extra_info("peername")
conn_id = f"{peer[0]}:{peer[1]}" if isinstance(peer, tuple) else "unknown"
logging.info("[%s] accepted", conn_id)
set_socket_opts(client_writer)
try:
p_reader, p_writer = await asyncio.open_connection(cfg.primary.host, cfg.primary.port)
set_socket_opts(p_writer)
logging.debug("[%s] connected primary %s:%d", conn_id, cfg.primary.host, cfg.primary.port)
except Exception as e:
logging.error("[%s] failed to connect primary %s:%d: %s", conn_id, cfg.primary.host, cfg.primary.port, e)
client_writer.close(); await client_writer.wait_closed(); return
mirror_ctls: List[MirrorCtl] = []
mirror_tasks: List[asyncio.Task] = []
async def mk_disable(w: Optional[asyncio.StreamWriter], buf: MirrorBuffer):
async def _disable():
try: await buf.close()
except Exception: pass
if w:
try: w.close(); await w.wait_closed()
except Exception: pass
return _disable
for tgt in cfg.backups:
name = f"{tgt.host}:{tgt.port}"
try:
b_reader, b_writer = await asyncio.open_connection(tgt.host, tgt.port)
set_socket_opts(b_writer)
mbuf = MirrorBuffer(cfg.mirror_buffer_bytes)
disable_cb = await mk_disable(b_writer, mbuf)
ctl = MirrorCtl(name=name, buf=mbuf, disable=disable_cb, active=True)
mirror_ctls.append(ctl)
mirror_tasks += [
asyncio.create_task(mirror_writer_task(name, b_writer, mbuf)),
asyncio.create_task(discard_task(name, b_reader)),
]
logging.debug("[%s] connected backup %s", conn_id, name)
except Exception as e:
logging.warning("[%s] cannot connect backup %s: %s (disabled)", conn_id, name, e)
c2p = asyncio.create_task(pump(client_reader, p_writer, label=f"[{conn_id}] client->primary",
idle_timeout=cfg.idle_timeout, mirrors=mirror_ctls))
p2c = asyncio.create_task(pump(p_reader, client_writer, label=f"[{conn_id}] primary->client",
idle_timeout=cfg.idle_timeout))
try:
await asyncio.wait([c2p, p2c], return_when=asyncio.FIRST_COMPLETED)
finally:
for obj in (p_writer, client_writer):
try: obj.close(); await obj.wait_closed()
except Exception: pass
for ctl in mirror_ctls:
try: await ctl.disable()
except Exception: pass
for t in mirror_tasks + [c2p, p2c]:
if not t.done(): t.cancel()
for t in mirror_tasks + [c2p, p2c]:
try: await t
except Exception: pass
logging.info("[%s] closed", conn_id)
async def run_server(cfg: Config):
server = await asyncio.start_server(
lambda r, w: handle_client(r, w, cfg),
host=cfg.listen_host,
port=cfg.listen_port,
start_serving=True,
)
sockets = ", ".join(str(s.getsockname()) for s in (server.sockets or []))
logging.info("Listening on %s | primary=%s:%d | backups=%s | mirror_buffer=%d",
sockets, cfg.primary.host, cfg.primary.port,
",".join(f"{b.host}:{b.port}" for b in cfg.backups) or "none",
cfg.mirror_buffer_bytes)
stop_event = asyncio.Event()
def _signal_stop():
logging.info("Shutdown requested…")
stop_event.set()
loop = asyncio.get_running_loop()
for sig in (signal.SIGINT, signal.SIGTERM):
try:
loop.add_signal_handler(sig, _signal_stop)
except NotImplementedError:
pass # Windows quirk
# Make TASKS (not bare coroutines)
stop_task = asyncio.create_task(stop_event.wait())
duration_task = (asyncio.create_task(asyncio.sleep(cfg.run_duration))
if cfg.run_duration and cfg.run_duration > 0 else None)
async with server:
tasks = [stop_task] + ([duration_task] if duration_task else [])
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
# Cancel whatever is left
for t in pending:
t.cancel()
await asyncio.gather(*pending, return_exceptions=True)
server.close()
await server.wait_closed()
logging.info("Server stopped.")
def main():
ap = argparse.ArgumentParser(description="Asyncio TCP splitter with mirrored backups (config-driven).")
ap.add_argument("--config", help="Path to JSON config file.")
ap.add_argument("--title", help="Console window title (Windows).")
args = ap.parse_args()
cfg_path = Path(args.config) if args.config else resolve_default_config_path()
cfg = load_config(cfg_path)
logging.basicConfig(level=cfg.log_level, format="%(asctime)s %(levelname)s %(message)s")
logging.info("Using config: %s", cfg_path)
set_console_title(args.title)
try:
asyncio.run(run_server(cfg))
except KeyboardInterrupt:
pass
if __name__ == "__main__":
main()

BIN
worker.exe Normal file

Binary file not shown.

192
worker.py
View File

@ -1,162 +1,56 @@
import socket #!/usr/bin/env python3
import threading # worker.py
import json import subprocess
import time import time
import logging import sys
from datetime import datetime from pathlib import Path
from crccheck.crc import CrcArc
# Load configuration from config.json TITLE = "TCP_SPLITTER"
with open('config.json', 'r') as config_file: SPLITTER_EXE_NAME = "tcp_splitter.exe"
config = json.load(config_file) CONFIG_NAME = "config.json"
# Configuration from config.json # Windows creation flag: create child in a new console window
LISTEN_HOST = config['listen_host'] CREATE_NEW_CONSOLE = 0x00000010
LISTEN_PORT = config['listen_port']
TARGET_HOSTS = config['target_hosts']
RUN_DURATION = config['run_duration']
LOG_LEVEL = config['log_level']
# Map log level strings to logging module constants def base_dir() -> Path:
log_level_mapping = { # When frozen by PyInstaller, sys.executable is the EXE path
"DEBUG": logging.DEBUG, return Path(sys.executable).parent if getattr(sys, "frozen", False) else Path(__file__).resolve().parent
"INFO": logging.INFO,
"WARNING": logging.WARNING,
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL
}
# Validate and set the log level def launch() -> subprocess.Popen:
if LOG_LEVEL not in log_level_mapping: base = base_dir()
raise ValueError(f"Invalid log level: {LOG_LEVEL}. Valid options are: {list(log_level_mapping.keys())}") splitter = base / SPLITTER_EXE_NAME
cfg = base / CONFIG_NAME
# Logging configuration if not splitter.exists():
logging.basicConfig( raise FileNotFoundError(f"Cannot find {splitter}")
level=log_level_mapping[LOG_LEVEL], if not cfg.exists():
format='%(asctime)s - %(levelname)s - %(message)s', raise FileNotFoundError(f"Cannot find {cfg}")
handlers=[
logging.FileHandler('relay_server.log', encoding='utf-8'), # Log to a file
logging.StreamHandler() # Log to the console
]
)
logger = logging.getLogger(__name__)
def relay_to_server(server_host, server_port, data): # Launch tcp_splitter.exe with its own console and a readable title
"""Relay data to a server and return the response.""" args = [str(splitter), "--config", str(cfg), "--title", TITLE]
proc = subprocess.Popen(args, creationflags=CREATE_NEW_CONSOLE)
return proc
def main():
while True:
start_ts = time.strftime("%Y-%m-%d %H:%M:%S")
print(f"[worker] Starting splitter at {start_ts}")
try: try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket: proc = launch()
server_socket.connect((server_host, server_port))
server_socket.sendall(data)
response = server_socket.recv(4096)
return response
except Exception as e: except Exception as e:
logger.error(f"Failed to relay data to {server_host}:{server_port}: {e}") print(f"[worker] Launch error: {e}")
return None time.sleep(3)
def calculate_crc(data):
"""Calculate CRC using CrcArc."""
crc = CrcArc.calc(data.encode()) # Encode data to bytes
crc_hex = hex(crc).split('x')[1].upper().zfill(4) # Format as uppercase hex (4 chars)
return crc_hex
def calculate_length(data):
"""Calculate the length of the data in hexadecimal."""
length = len(data.encode()) # Length in bytes
length_hex = hex(length)[2:].upper().zfill(4) # Convert to hex, ensure 4 digits
return length_hex
def handle_client(client_socket, client_address):
"""Handle a client connection."""
try:
# Receive data from the client
data = client_socket.recv(4096)
if not data:
logger.warning(f"No data received from {client_address}")
return
logger.debug(f"Received data from {client_address}: {data}")
# Decode the data and check for NULL message
plain_data = data.decode("utf-8", errors='ignore')
# Validate message integrity
if not "SIA-DCS" in plain_data or not "ADM-CID" in plain_data:
return
if '"NULL"' in plain_data:
logger.debug("NULL Message detected, sending ACK without relaying signal!")
orig_text = plain_data[15:].strip()
ack_text = '"ACK"' + orig_text
ack_text_bytes = ack_text.encode()
logger.debug(f"ACK Message: {ack_text}")
# Calculate CRC and format it
crc = CrcArc.calc(ack_text_bytes)
crcstr = str(hex(crc)).split('x')
crcstr = str(crcstr[1].upper())
if len(crcstr) == 2:
crcstr = '00' + crcstr
if len(crcstr) == 3:
crcstr = '0' + crcstr
# Calculate length and format it
length = str(hex(len(ack_text_bytes))).split('x')
logger.debug(f"CRC & Length: {crcstr}, 00{length[1]}")
# Construct the ACK message
ack_msg = '\n' + crcstr + '00' + length[1].upper() + ack_text + '\r'
# Send the ACK response
ack_bytes = bytes(ack_msg, 'ASCII')
client_socket.sendall(ack_bytes)
logger.debug(f"Response sent to client: {ack_msg.strip()}")
return
# Relay the data to all target hosts
responses = []
for target in TARGET_HOSTS:
logger.info(f"Relaying data to {target['host']}:{target['port']}")
response = relay_to_server(target['host'], target['port'], data)
if response:
responses.append(response)
logger.debug(f"Received response from {target['host']}:{target['port']}: {response}")
# Send only the first server's response back to the client
if responses:
client_socket.sendall(responses[0])
logger.info(f"Sent response to {client_address}: {responses[0]}")
else:
logger.warning(f"No responses received from target hosts for {client_address}")
except Exception as e:
logger.error(f"Error handling client {client_address}: {e}")
finally:
client_socket.close()
logger.debug(f"Closed connection with {client_address}")
def start_relay():
"""Start the TCP relay server."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket:
server_socket.bind((LISTEN_HOST, LISTEN_PORT))
server_socket.listen(5)
logger.info(f"Relay server started listening on {LISTEN_HOST}:{LISTEN_PORT}")
# Set a timer to stop the server after RUN_DURATION seconds
stop_time = time.time() + RUN_DURATION
logger.debug(f"Server will run for {RUN_DURATION} seconds.")
while time.time() < stop_time:
try:
# Set a timeout to periodically check if the run duration has elapsed
server_socket.settimeout(1)
client_socket, client_address = server_socket.accept()
logger.debug(f"Accepted connection from {client_address}")
client_thread = threading.Thread(target=handle_client, args=(client_socket, client_address))
client_thread.start()
except socket.timeout:
continue continue
except Exception as e:
logger.error(f"Error accepting connection: {e}")
logger.info("Server run duration elapsed. Shutting down.") # Poll until it exits; restart on exit
while True:
rc = proc.poll()
if rc is not None:
end_ts = time.strftime("%Y-%m-%d %H:%M:%S")
print(f"[worker] Splitter exited (code {rc}) at {end_ts}. Restarting…")
time.sleep(1)
break
time.sleep(1)
if __name__ == "__main__": if __name__ == "__main__":
start_relay() main()