Oppdatert til mer stabil versjon.

This commit is contained in:
Anders Knutsen 2025-09-02 15:25:48 +02:00
parent f51134e2ff
commit a2ddc7cc1c
5 changed files with 340 additions and 172 deletions

View File

@@ -2,8 +2,8 @@
Et Python-basert verktøy som splitter innkommende TCP-trafikk til flere mål samtidig.
## Start
Kjør `tcp_splitter.py` for å starte programmet.
Scriptet `worker.py` blir startet automatisk av hovedscriptet, og restartes med jevne mellomrom basert på verdien satt i `config.json`.
Kjør `worker.exe` for å starte programmet.
Scriptet `tcp_splitter.exe` blir startet automatisk av hovedscriptet, og restartes med jevne mellomrom basert på verdien satt i `config.json`.
---

BIN
tcp_splitter.exe Normal file

Binary file not shown.

View File

@@ -1,26 +1,300 @@
import subprocess
import time
#!/usr/bin/env python3
# tcp_splitter.py
import argparse
import asyncio
import json
import logging
import signal
import socket
import sys
import os
import ctypes
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Awaitable, List, Optional
# NOTE(review): this span is a merged diff view — lines from the OLD worker
# watchdog loop (while True / restart_time / Popen) are interleaved with the
# NEW module constants (BUF_SIZE, DEFAULT_MIRROR_BUF). Untangle before running;
# indentation below is reconstructed, the original view had none.
while True:
    # Old worker code: timestamp each (re)start of the child process.
    # Get the current time and print it when starting the process
    restart_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    print(f"Starting worker script at {restart_time}")
    # New module constants (belong at module top level in the rewritten file):
    BUF_SIZE = 64 * 1024  # per-read chunk size for the socket pumps
    DEFAULT_MIRROR_BUF = 1 * 1024 * 1024  # 1 MiB mirror backlog budget
    # Old worker code: start the worker script in a new Command Prompt window with a custom title
    process = subprocess.Popen(['start', 'cmd', '/c', 'title TCP_SPLITTER && worker.exe'], shell=True)
def set_console_title(title: Optional[str]):
    """Best-effort console window title update (Windows only; no-op elsewhere)."""
    if os.name != "nt" or not title:
        return
    try:
        ctypes.windll.kernel32.SetConsoleTitleW(str(title))
    except Exception:
        # The title is purely cosmetic; never let it break startup.
        pass
# NOTE(review): residue of the old worker loop left in this merged diff view;
# at module level this sleep would run at import time.
# Wait a moment to allow the process to start
time.sleep(2)
def set_socket_opts(writer: asyncio.StreamWriter) -> None:
    """Enable TCP_NODELAY and SO_KEEPALIVE on the writer's socket, best-effort.

    Missing sockets (e.g. in tests) and unsupported options are ignored.
    """
    sock = writer.get_extra_info("socket")
    if not sock:
        return
    wanted = [
        (socket.IPPROTO_TCP, socket.TCP_NODELAY, 1),
        (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
    ]
    for level, optname, value in wanted:
        try:
            sock.setsockopt(level, optname, value)
        except Exception:
            # Some platforms/socket types reject these; keep going regardless.
            pass
@dataclass
class Target:
    """A single forwarding destination (host, TCP port)."""
    host: str
    port: int
@dataclass
class Config:
    """Runtime configuration parsed from config.json."""
    listen_host: str      # address the splitter binds to
    listen_port: int      # port the splitter listens on
    primary: Target       # first target: full-duplex proxy endpoint
    backups: List[Target] # remaining targets: receive a mirrored copy of client bytes
    run_duration: int     # seconds before self-stop; 0 or negative = run forever
    log_level: int        # logging module level constant (e.g. logging.INFO)
    mirror_buffer_bytes: int = DEFAULT_MIRROR_BUF  # per-mirror backlog cap before it is dropped
    idle_timeout: float = 0.0  # per-direction read timeout in seconds; 0 disables
def _level_from_name(name: str) -> int:
try:
lvl = getattr(logging, str(name).upper())
return lvl if isinstance(lvl, int) else logging.INFO
except Exception:
return logging.INFO
def resolve_default_config_path() -> Path:
    """Locate config.json next to the EXE/script, else fall back to the CWD."""
    if getattr(sys, "frozen", False):
        # PyInstaller bundle: look next to the executable.
        base = Path(sys.executable).parent
    else:
        base = Path(__file__).resolve().parent
    candidate = base / "config.json"
    if candidate.exists():
        return candidate
    return Path.cwd() / "config.json"
def load_config(path: Path) -> Config:
    """Parse *path* (JSON) into a Config.

    Raises FileNotFoundError when the file is missing and ValueError when
    target_hosts is absent or empty. The first entry in target_hosts becomes
    the primary; the rest become mirrored backups.
    """
    if not path.exists():
        raise FileNotFoundError(f"Config not found: {path}")
    raw = json.loads(path.read_text(encoding="utf-8"))
    targets = raw["target_hosts"]
    if not isinstance(targets, list) or not targets:
        raise ValueError("config.target_hosts must contain at least one entry")
    parsed = [Target(str(t["host"]), int(t["port"])) for t in targets]
    return Config(
        listen_host=str(raw["listen_host"]),
        listen_port=int(raw["listen_port"]),
        primary=parsed[0],
        backups=parsed[1:],
        # "or 0" also maps an explicit null to "run forever".
        run_duration=int(raw.get("run_duration", 0) or 0),
        log_level=_level_from_name(raw.get("log_level", "INFO")),
        mirror_buffer_bytes=int(raw.get("mirror_buffer_bytes", DEFAULT_MIRROR_BUF)),
        idle_timeout=float(raw.get("idle_timeout", 0.0)),
    )
class MirrorBuffer:
    """Bounded byte backlog feeding one backup mirror.

    put_nowait() raises asyncio.QueueFull once the total buffered bytes would
    exceed max_bytes; the caller treats that as "mirror too slow — disable it".
    """
    def __init__(self, max_bytes: int):
        # The queue itself is unbounded; the byte budget below is the real cap.
        self._q: asyncio.Queue[bytes] = asyncio.Queue()
        self._bytes = 0        # bytes currently queued
        self._max = max_bytes  # budget; exceeding it raises QueueFull
        self._closed = False
        self._lock = asyncio.Lock()  # guards _bytes/_closed bookkeeping
    @property
    def closed(self) -> bool: return self._closed
    async def close(self):
        # Mark closed and drop any remaining backlog so producers stop feeding us.
        async with self._lock:
            self._closed = True
            while not self._q.empty():
                try:
                    self._q.get_nowait(); self._q.task_done()
                except Exception: break
    async def put_nowait(self, data: bytes):
        # Writes after close are silently ignored; over-budget writes raise.
        if self._closed: return
        async with self._lock:
            new_total = self._bytes + len(data)
            if new_total > self._max:
                raise asyncio.QueueFull()
            self._bytes = new_total
            self._q.put_nowait(data)
    async def get(self) -> bytes:
        # NOTE(review): if close() drains the queue before this await, get()
        # can block forever; callers rely on cancellation to escape — confirm.
        data = await self._q.get()
        async with self._lock:
            self._bytes -= len(data)
            self._q.task_done()
        return data
    def empty(self) -> bool: return self._q.empty()
@dataclass
class MirrorCtl:
    """Control handle for one backup mirror connection."""
    name: str                               # "host:port" label used in logs
    buf: MirrorBuffer                       # pending bytes for this mirror
    disable: Callable[[], Awaitable[None]]  # async callback: close buffer + writer
    active: bool = True                     # cleared when the mirror overflows
async def pump(reader: asyncio.StreamReader, writer: asyncio.StreamWriter, *,
               label: str, idle_timeout: float = 0.0,
               mirrors: Optional[List[MirrorCtl]] = None):
    """Copy bytes reader -> writer until EOF/timeout, feeding mirrors a copy.

    NOTE(review): the tasklist/"TCP_SPLITTER" lines below are residue of the
    OLD worker watchdog merged into this diff view — a blocking subprocess.run
    does not belong inside an asyncio pump and would stall the event loop;
    remove them when untangling. Indentation is reconstructed.
    """
    set_socket_opts(writer)
    while True:
        # (old worker code) Check if the specific titled process is still running
        result = subprocess.run('tasklist /v | findstr "TCP_SPLITTER"', shell=True, stdout=subprocess.PIPE, text=True)
        try:
            # idle_timeout == 0 means wait indefinitely for the next chunk.
            data = await (asyncio.wait_for(reader.read(BUF_SIZE), timeout=idle_timeout) if idle_timeout > 0 else reader.read(BUF_SIZE))
        except asyncio.TimeoutError:
            logging.info("%s: idle timeout; closing direction", label); break
        if not data: break  # EOF from the source side
        writer.write(data)
        try: await writer.drain()
        except ConnectionError:
            logging.info("%s: downstream closed; stopping", label); break
        # Fan the same chunk out to every still-active mirror.
        if mirrors:
            for m in mirrors:
                if not m.active or m.buf.closed: continue
                try: await m.buf.put_nowait(data)
                except asyncio.QueueFull:
                    # Mirror cannot keep up: drop it rather than stall the primary path.
                    m.active = False
                    logging.warning("%s: mirror '%s' overflow -> disabling", label, m.name)
                    try: await m.disable()
                    except Exception: pass
        # (old worker code) If "TCP_SPLITTER" title is not found, the worker script has stopped
        if "TCP_SPLITTER" not in result.stdout:
            exit_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
            print(f"Worker script exited at {exit_time}. Restarting...")
            break # Exit the inner loop to restart the outer loop and launch the script again
async def mirror_writer_task(name: str, writer: asyncio.StreamWriter, mbuf: MirrorBuffer):
    """Drain mbuf and write each chunk to the backup until closed/cancelled."""
    set_socket_opts(writer)
    try:
        while not mbuf.closed:
            # NOTE(review): this re-tests mbuf.closed immediately after the
            # loop guard proved it False, with no await in between — the break
            # can never fire; confirm the intended drain-then-exit behavior.
            if mbuf.empty() and mbuf.closed: break
            data = await mbuf.get()
            if not data: continue
            writer.write(data)
            try: await writer.drain()
            except ConnectionError:
                logging.info("backup '%s': writer closed", name); break
    except asyncio.CancelledError:
        pass  # normal shutdown path: handle_client cancels us
    finally:
        # Always close the backup socket; ignore teardown errors.
        try: writer.close(); await writer.wait_closed()
        except Exception: pass
# NOTE(review): the two lines below are residue of the OLD worker loop left in
# this merged diff view; at module level the sleep would run at import time.
# Check every second if the script is still running
time.sleep(1)
async def discard_task(name: str, reader: asyncio.StreamReader):
    """Read and throw away everything a backup sends back, until EOF."""
    try:
        # Empty read means EOF; keep draining until then.
        while await reader.read(BUF_SIZE):
            pass
    except asyncio.CancelledError:
        pass  # normal shutdown path
    except Exception as e:
        logging.debug("backup '%s': discard exception: %r", name, e)
async def handle_client(client_reader: asyncio.StreamReader, client_writer: asyncio.StreamWriter, cfg: Config):
    """Proxy one client to the primary target, mirroring client bytes to backups.

    Client <-> primary is full duplex. Each backup only receives a copy of the
    client->primary stream; whatever backups send back is read and discarded.
    """
    peer = client_writer.get_extra_info("peername")
    conn_id = f"{peer[0]}:{peer[1]}" if isinstance(peer, tuple) else "unknown"
    logging.info("[%s] accepted", conn_id)
    set_socket_opts(client_writer)
    # The primary is mandatory: if it is unreachable, drop the client now.
    try:
        p_reader, p_writer = await asyncio.open_connection(cfg.primary.host, cfg.primary.port)
        set_socket_opts(p_writer)
        logging.debug("[%s] connected primary %s:%d", conn_id, cfg.primary.host, cfg.primary.port)
    except Exception as e:
        logging.error("[%s] failed to connect primary %s:%d: %s", conn_id, cfg.primary.host, cfg.primary.port, e)
        client_writer.close(); await client_writer.wait_closed(); return
    mirror_ctls: List[MirrorCtl] = []
    mirror_tasks: List[asyncio.Task] = []
    # Factory for an async callback that tears down one mirror (buffer + socket).
    async def mk_disable(w: Optional[asyncio.StreamWriter], buf: MirrorBuffer):
        async def _disable():
            try: await buf.close()
            except Exception: pass
            if w:
                try: w.close(); await w.wait_closed()
                except Exception: pass
        return _disable
    # Backups are best-effort: a failed connect only disables that mirror.
    for tgt in cfg.backups:
        name = f"{tgt.host}:{tgt.port}"
        try:
            b_reader, b_writer = await asyncio.open_connection(tgt.host, tgt.port)
            set_socket_opts(b_writer)
            mbuf = MirrorBuffer(cfg.mirror_buffer_bytes)
            disable_cb = await mk_disable(b_writer, mbuf)
            ctl = MirrorCtl(name=name, buf=mbuf, disable=disable_cb, active=True)
            mirror_ctls.append(ctl)
            # One writer task draining the buffer, one task discarding replies.
            mirror_tasks += [
                asyncio.create_task(mirror_writer_task(name, b_writer, mbuf)),
                asyncio.create_task(discard_task(name, b_reader)),
            ]
            logging.debug("[%s] connected backup %s", conn_id, name)
        except Exception as e:
            logging.warning("[%s] cannot connect backup %s: %s (disabled)", conn_id, name, e)
    # Two pumps, one per direction; only client->primary feeds the mirrors.
    c2p = asyncio.create_task(pump(client_reader, p_writer, label=f"[{conn_id}] client->primary",
                                   idle_timeout=cfg.idle_timeout, mirrors=mirror_ctls))
    p2c = asyncio.create_task(pump(p_reader, client_writer, label=f"[{conn_id}] primary->client",
                                   idle_timeout=cfg.idle_timeout))
    try:
        # Either direction finishing (EOF / timeout / error) ends the session.
        await asyncio.wait([c2p, p2c], return_when=asyncio.FIRST_COMPLETED)
    finally:
        # Teardown order: sockets first, then mirrors, then cancel + await tasks.
        for obj in (p_writer, client_writer):
            try: obj.close(); await obj.wait_closed()
            except Exception: pass
        for ctl in mirror_ctls:
            try: await ctl.disable()
            except Exception: pass
        for t in mirror_tasks + [c2p, p2c]:
            if not t.done(): t.cancel()
        for t in mirror_tasks + [c2p, p2c]:
            try: await t
            except Exception: pass
        logging.info("[%s] closed", conn_id)
async def run_server(cfg: Config):
    """Run the listener until a shutdown signal arrives or run_duration elapses."""
    server = await asyncio.start_server(
        lambda r, w: handle_client(r, w, cfg),
        host=cfg.listen_host,
        port=cfg.listen_port,
        start_serving=True,
    )
    sockets = ", ".join(str(s.getsockname()) for s in (server.sockets or []))
    logging.info("Listening on %s | primary=%s:%d | backups=%s | mirror_buffer=%d",
                 sockets, cfg.primary.host, cfg.primary.port,
                 ",".join(f"{b.host}:{b.port}" for b in cfg.backups) or "none",
                 cfg.mirror_buffer_bytes)
    stop_event = asyncio.Event()
    def _signal_stop():
        logging.info("Shutdown requested…")
        stop_event.set()
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(sig, _signal_stop)
        except NotImplementedError:
            pass # Windows quirk: add_signal_handler is unsupported there
    # Make TASKS (not bare coroutines) so asyncio.wait accepts them.
    stop_task = asyncio.create_task(stop_event.wait())
    # run_duration <= 0 means "run until signalled" — no timer task at all.
    duration_task = (asyncio.create_task(asyncio.sleep(cfg.run_duration))
                     if cfg.run_duration and cfg.run_duration > 0 else None)
    async with server:
        tasks = [stop_task] + ([duration_task] if duration_task else [])
        done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        # Cancel whichever of (stop, timer) did not fire, and reap it.
        for t in pending:
            t.cancel()
        await asyncio.gather(*pending, return_exceptions=True)
    server.close()
    await server.wait_closed()
    logging.info("Server stopped.")
def main():
    """CLI entry point: parse arguments, load config, then run the splitter."""
    parser = argparse.ArgumentParser(description="Asyncio TCP splitter with mirrored backups (config-driven).")
    parser.add_argument("--config", help="Path to JSON config file.")
    parser.add_argument("--title", help="Console window title (Windows).")
    opts = parser.parse_args()
    cfg_path = Path(opts.config) if opts.config else resolve_default_config_path()
    cfg = load_config(cfg_path)
    logging.basicConfig(level=cfg.log_level, format="%(asctime)s %(levelname)s %(message)s")
    logging.info("Using config: %s", cfg_path)
    set_console_title(opts.title)
    try:
        asyncio.run(run_server(cfg))
    except KeyboardInterrupt:
        # Ctrl+C during startup/shutdown: exit quietly.
        pass


if __name__ == "__main__":
    main()

BIN
worker.exe Normal file

Binary file not shown.

196
worker.py
View File

@@ -1,162 +1,56 @@
import socket
import threading
import json
#!/usr/bin/env python3
# worker.py
import subprocess
import time
import logging
from datetime import datetime
from crccheck.crc import CrcArc
import sys
from pathlib import Path
# NOTE(review): merged diff view — the OLD worker's config load and relay
# settings are interleaved with the NEW watchdog constants. Untangle before
# running; indentation below is reconstructed, the original view had none.
# (old worker code) Load configuration from config.json
with open('config.json', 'r') as config_file:
    config = json.load(config_file)
# New watchdog constants:
TITLE = "TCP_SPLITTER"                  # console title given to the child window
SPLITTER_EXE_NAME = "tcp_splitter.exe"  # child executable the watchdog launches
CONFIG_NAME = "config.json"             # config file passed to the child
# (old worker code) Configuration from config.json
LISTEN_HOST = config['listen_host']
LISTEN_PORT = config['listen_port']
TARGET_HOSTS = config['target_hosts']
RUN_DURATION = config['run_duration']
LOG_LEVEL = config['log_level']
# Windows creation flag: create child in a new console window
CREATE_NEW_CONSOLE = 0x00000010
# (old worker code) Map log level strings to logging module constants
log_level_mapping = {
    "DEBUG": logging.DEBUG,
    "INFO": logging.INFO,
    "WARNING": logging.WARNING,
    "ERROR": logging.ERROR,
    "CRITICAL": logging.CRITICAL
}
def base_dir() -> Path:
    """Directory containing this script, or the EXE when frozen by PyInstaller."""
    if getattr(sys, "frozen", False):
        # When frozen, sys.executable is the EXE path.
        return Path(sys.executable).parent
    return Path(__file__).resolve().parent
# NOTE(review): old worker code left in this merged diff view.
# Validate and set the log level
if LOG_LEVEL not in log_level_mapping:
    raise ValueError(f"Invalid log level: {LOG_LEVEL}. Valid options are: {list(log_level_mapping.keys())}")
def launch() -> subprocess.Popen:
    """Launch tcp_splitter.exe in its own console window and return the process.

    NOTE(review): the logging setup and relay_to_server below are residue of
    the OLD worker merged into this diff view — in the old file they live at
    module level, not inside launch(). Indentation is reconstructed.
    """
    base = base_dir()
    splitter = base / SPLITTER_EXE_NAME
    cfg = base / CONFIG_NAME
    # (old worker code) Logging configuration
    logging.basicConfig(
        level=log_level_mapping[LOG_LEVEL],
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler('relay_server.log', encoding='utf-8'), # Log to a file
            logging.StreamHandler() # Log to the console
        ]
    )
    logger = logging.getLogger(__name__)
    # Fail fast when either required file is missing next to the watchdog.
    if not splitter.exists():
        raise FileNotFoundError(f"Cannot find {splitter}")
    if not cfg.exists():
        raise FileNotFoundError(f"Cannot find {cfg}")
    # (old worker code) One-shot TCP exchange: connect, send, read one reply.
    def relay_to_server(server_host, server_port, data):
        """Relay data to a server and return the response."""
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket:
                server_socket.connect((server_host, server_port))
                server_socket.sendall(data)
                response = server_socket.recv(4096)
                return response
        except Exception as e:
            logger.error(f"Failed to relay data to {server_host}:{server_port}: {e}")
            return None
    # Launch tcp_splitter.exe with its own console and a readable title
    args = [str(splitter), "--config", str(cfg), "--title", TITLE]
    proc = subprocess.Popen(args, creationflags=CREATE_NEW_CONSOLE)
    return proc
def calculate_crc(data):
    """Return the CRC-16/ARC of *data* (UTF-8 bytes) as 4-char uppercase hex."""
    checksum = CrcArc.calc(data.encode())
    # Uppercase hex, zero-padded on the left to four characters.
    return format(checksum, "X").zfill(4)
def main():
    """Watchdog loop: keep relaunching the splitter whenever it exits.

    NOTE(review): in this merged diff view the body is split — the poll/restart
    loop that belongs inside this while loop appears detached further down,
    after the old relay functions.
    """
    while True:
        start_ts = time.strftime("%Y-%m-%d %H:%M:%S")
        print(f"[worker] Starting splitter at {start_ts}")
        try:
            proc = launch()
        except Exception as e:
            # Transient launch failure (missing file, locked EXE, ...): retry shortly.
            print(f"[worker] Launch error: {e}")
            time.sleep(3)
            continue
def calculate_length(data):
    """Return the UTF-8 byte length of *data* as 4-char uppercase hexadecimal."""
    n_bytes = len(data.encode())
    # Uppercase hex, left-padded with zeros to four characters.
    return format(n_bytes, "X").zfill(4)
def handle_client(client_socket, client_address):
    """Handle a client connection: ACK NULL messages locally, relay the rest."""
    try:
        # Receive data from the client (single read; assumes one message <= 4096 bytes)
        data = client_socket.recv(4096)
        if not data:
            logger.warning(f"No data received from {client_address}")
            return
        logger.debug(f"Received data from {client_address}: {data}")
        # Decode the data and check for NULL message
        plain_data = data.decode("utf-8", errors='ignore')
        # Validate message integrity
        # NOTE(review): suspected bug — this requires BOTH "SIA-DCS" AND
        # "ADM-CID" in the same message (De Morgan on the two `not ... or`
        # terms). A frame normally carries one ID token, so this likely drops
        # valid traffic and makes the NULL branch below unreachable; probably
        # meant `and` (reject only when NEITHER marker is present). Confirm.
        if not "SIA-DCS" in plain_data or not "ADM-CID" in plain_data:
            return
        if '"NULL"' in plain_data:
            logger.debug("NULL Message detected, sending ACK without relaying signal!")
            # Build the ACK body from the frame minus its first 15 characters.
            orig_text = plain_data[15:].strip()
            ack_text = '"ACK"' + orig_text
            ack_text_bytes = ack_text.encode()
            logger.debug(f"ACK Message: {ack_text}")
            # Calculate CRC and format it (left-pad the hex to 4 chars)
            crc = CrcArc.calc(ack_text_bytes)
            crcstr = str(hex(crc)).split('x')
            crcstr = str(crcstr[1].upper())
            if len(crcstr) == 2:
                crcstr = '00' + crcstr
            if len(crcstr) == 3:
                crcstr = '0' + crcstr
            # Calculate length and format it
            # NOTE(review): '00' + hex-length assumes the length is exactly two
            # hex digits; frames shorter than 16 or longer than 255 bytes would
            # be mis-framed — confirm expected message sizes.
            length = str(hex(len(ack_text_bytes))).split('x')
            logger.debug(f"CRC & Length: {crcstr}, 00{length[1]}")
            # Construct the ACK message: LF + CRC(4) + length(4) + body + CR
            ack_msg = '\n' + crcstr + '00' + length[1].upper() + ack_text + '\r'
            # Send the ACK response
            ack_bytes = bytes(ack_msg, 'ASCII')
            client_socket.sendall(ack_bytes)
            logger.debug(f"Response sent to client: {ack_msg.strip()}")
            return
        # Relay the data to all target hosts
        responses = []
        for target in TARGET_HOSTS:
            logger.info(f"Relaying data to {target['host']}:{target['port']}")
            response = relay_to_server(target['host'], target['port'], data)
            if response:
                responses.append(response)
                logger.debug(f"Received response from {target['host']}:{target['port']}: {response}")
        # Send only the first server's response back to the client
        if responses:
            client_socket.sendall(responses[0])
            logger.info(f"Sent response to {client_address}: {responses[0]}")
        else:
            logger.warning(f"No responses received from target hosts for {client_address}")
    except Exception as e:
        logger.error(f"Error handling client {client_address}: {e}")
    finally:
        client_socket.close()
        logger.debug(f"Closed connection with {client_address}")
def start_relay():
    """Start the TCP relay server; accept clients for RUN_DURATION seconds."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket:
        server_socket.bind((LISTEN_HOST, LISTEN_PORT))
        server_socket.listen(5)
        logger.info(f"Relay server started listening on {LISTEN_HOST}:{LISTEN_PORT}")
        # Set a timer to stop the server after RUN_DURATION seconds
        stop_time = time.time() + RUN_DURATION
        logger.debug(f"Server will run for {RUN_DURATION} seconds.")
        while time.time() < stop_time:
            try:
                # Set a timeout to periodically check if the run duration has elapsed
                server_socket.settimeout(1)
                client_socket, client_address = server_socket.accept()
                logger.debug(f"Accepted connection from {client_address}")
                # One thread per client; threads are not joined on shutdown.
                client_thread = threading.Thread(target=handle_client, args=(client_socket, client_address))
                client_thread.start()
            except socket.timeout:
                continue
            except Exception as e:
                logger.error(f"Error accepting connection: {e}")
        logger.info("Server run duration elapsed. Shutting down.")
# NOTE(review): continuation of the NEW worker main() left detached in this
# merged diff view — this poll/restart loop belongs inside main()'s outer
# while loop, after a successful launch(); it references `proc` from there.
# Poll until it exits; restart on exit
while True:
    rc = proc.poll()
    if rc is not None:
        end_ts = time.strftime("%Y-%m-%d %H:%M:%S")
        print(f"[worker] Splitter exited (code {rc}) at {end_ts}. Restarting…")
        time.sleep(1)
        break
    time.sleep(1)
if __name__ == "__main__":
    # NOTE(review): the merged diff shows both the OLD entry (start_relay) and
    # the NEW entry (main); the rewritten worker should call only main().
    start_relay()
    main()