123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408 |
- import argparse
- import asyncio
- import os
- import logging
- import aiohttp
- from aiohttp import web
- # -----------------------------------------------------------------------------
- # 4MB is the default inside aiohttp
- # -----------------------------------------------------------------------------
# Largest websocket message accepted by the relay, in bytes. Can be overridden
# with the WSLINK_MAX_MSG_SIZE environment variable.
MAX_MSG_SIZE = int(os.getenv("WSLINK_MAX_MSG_SIZE", 4194304))

# Websocket heartbeat period, in seconds. Can be overridden with the
# WSLINK_HEART_BEAT environment variable.
HEART_BEAT = int(os.getenv("WSLINK_HEART_BEAT", 30))

# Module-level logger; INFO and above are emitted by default.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
- # -----------------------------------------------------------------------------
- # Helper classes
- # -----------------------------------------------------------------------------
class WsClientConnection:
    """Outgoing websocket connection whose messages are pushed to a peer.

    The peer (set with :meth:`bind`) is any object exposing an awaitable
    ``send(msg)`` and a ``disconnect()`` method. ``_connected`` is ``0`` before
    the websocket is opened, positive while it is open and ``-1`` once
    :meth:`disconnect` has been called on an open connection.
    """

    def __init__(self, propagate_disconnect=True):
        """Create an unconnected client.

        Must be called while an event loop is running because the ``ready``
        future is created on the running loop.

        Args:
            propagate_disconnect: when true, ``disconnect()`` is also called
                on the bound destination.
        """
        self._url = None
        self._session = None
        self._ws = None
        self._connected = 0
        self._destination = None
        self._ready = asyncio.get_running_loop().create_future()
        self.propagate_disconnect = propagate_disconnect

    def bind(self, value):
        """Set (or clear, with ``None``) the destination of received messages."""
        self._destination = value

    @property
    def ready(self):
        """Future resolved with ``True`` once the websocket is open, or with
        ``False`` when :meth:`connect` ends without having opened it."""
        return self._ready

    async def connect(self, url):
        """Open a websocket to ``url`` and forward its messages until it ends.

        Does nothing when a session is already attached to this instance.
        Exceptions raised while connecting or forwarding propagate to the
        caller after the connection state has been reset.
        """
        logger.debug("client::connect::%s", url)
        self._url = url
        if self._session is None:
            async with aiohttp.ClientSession() as session:
                logger.debug("client::connect::session")
                self._session = session
                try:
                    async with session.ws_connect(self._url) as ws:
                        logger.debug("client::connect::ws")
                        self._ws = ws
                        self._connected += 1
                        self._ready.set_result(True)

                        async for msg in ws:
                            logger.debug("client::connect::ws::msg")
                            if self._connected < 1:
                                logger.debug("client::connect::ws::msg::disconnect")
                                break

                            if self._destination:
                                logger.debug("client::connect::ws::msg::send")
                                await self._destination.send(msg)
                            else:
                                logger.error("ws-client: No destination for message")
                finally:
                    # A future can only be resolved once: report failure only
                    # when the websocket was never opened. Calling set_result()
                    # again would raise asyncio.InvalidStateError.
                    if not self._ready.done():
                        self._ready.set_result(False)

                    # Disconnect (no-op when the websocket was never opened)
                    self.disconnect()

                    # Cleanup connection. The websocket and the session are
                    # closed by their own ``async with`` blocks; only the
                    # references need to be dropped, also on error paths.
                    self._ws = None
                    self._session = None

        logger.debug("client::connect::exit")

    async def send(self, msg):
        """Forward an aiohttp websocket message to the remote end.

        The message is dropped (and the situation logged) when the connection
        is not open.
        """
        ws = self._ws
        if self._connected > 0 and ws is not None and not ws.closed:
            logger.debug("client::send")
            if msg.type == aiohttp.WSMsgType.TEXT:
                await ws.send_str(msg.data)
            elif msg.type == aiohttp.WSMsgType.BINARY:
                await ws.send_bytes(msg.data)
            elif msg.type == aiohttp.WSMsgType.PING:
                await ws.ping(msg.data)
            elif msg.type == aiohttp.WSMsgType.PONG:
                await ws.pong(msg.data)
            elif msg.type == aiohttp.WSMsgType.CLOSE:
                await ws.close()
            else:
                logger.error("Invalid message to forward")
        else:
            # The websocket may not exist (never connected or already cleaned
            # up); report it as closed instead of dereferencing None.
            closed = ws is None or ws.closed
            logger.error("client::send - NO SEND")
            logger.error("%s - %s", self._connected, closed)
            logger.error("-" * 60)

    def disconnect(self):
        """Mark an open connection as disconnected and optionally notify the peer.

        The flag is flipped before the peer is notified, so two connections
        bound to each other do not call back and forth indefinitely.
        """
        logger.debug("client::disconnect %s", self._connected)
        if self._connected > 0:
            self._connected = -1
            if self._destination and self.propagate_disconnect:
                self._destination.disconnect()

    async def close(self):
        """Close the underlying websocket, if there is one."""
        if self._ws is not None:
            await self._ws.close()
- # -----------------------------------------------------------------------------
class WsServerConnection:
    """Incoming websocket connection whose messages are pushed to a peer.

    The peer (set with :meth:`bind`) is any object exposing an awaitable
    ``send(msg)`` and a ``disconnect()`` method. ``_connected`` is ``0`` before
    the websocket is prepared, positive while it is open and ``-1`` once
    :meth:`disconnect` has been called on an open connection.
    """

    def __init__(self, propagate_disconnect=True):
        """Create an unconnected server-side endpoint.

        Args:
            propagate_disconnect: when true, ``disconnect()`` is also called
                on the bound destination.
        """
        self._ws = None
        self._destination = None
        self._connected = 0
        self.propagate_disconnect = propagate_disconnect

    def bind(self, value):
        """Set (or clear, with ``None``) the destination of received messages."""
        self._destination = value

    async def connect(self, request):
        """Upgrade ``request`` to a websocket and forward messages until it ends.

        Exceptions raised while forwarding propagate to the caller, but only
        after the peer has been notified and the websocket released.
        """
        logger.debug("server::connect")
        self._ws = web.WebSocketResponse(
            max_msg_size=MAX_MSG_SIZE, heartbeat=HEART_BEAT
        )
        await self._ws.prepare(request)
        logger.debug("server::connect::prepare")
        self._connected += 1

        try:
            if self._connected > 0:
                async for msg in self._ws:
                    logger.debug("server::connect::ws::msg")
                    if self._connected < 1:
                        break

                    if self._destination:
                        logger.debug("server::connect::ws::msg::send-begin")
                        await self._destination.send(msg)
                        logger.debug("server::connect::ws::msg::send-end")
                    else:
                        logger.error("ws-server: No destination for message")
        finally:
            # Run on error paths too, so a failure while forwarding still
            # notifies the peer and releases the websocket.

            # Disconnect
            self.disconnect()

            # Cleanup connection
            if not self._ws.closed:
                await self._ws.close()
            self._ws = None

        logger.debug("server::connect::exit")

    async def send(self, msg):
        """Forward an aiohttp websocket message to the connected peer.

        The message is dropped (and the situation logged) when the connection
        is not open.
        """
        ws = self._ws
        if self._connected > 0 and ws is not None and not ws.closed:
            logger.debug("server::send")
            if msg.type == aiohttp.WSMsgType.TEXT:
                await ws.send_str(msg.data)
            elif msg.type == aiohttp.WSMsgType.BINARY:
                await ws.send_bytes(msg.data)
            elif msg.type == aiohttp.WSMsgType.PING:
                await ws.ping(msg.data)
            elif msg.type == aiohttp.WSMsgType.PONG:
                await ws.pong(msg.data)
            elif msg.type == aiohttp.WSMsgType.CLOSE:
                await ws.close()
            else:
                logger.error("Invalid message to forward")
        else:
            # The websocket may not exist (never connected or already cleaned
            # up); report it as closed instead of dereferencing None.
            closed = ws is None or ws.closed
            logger.error("server::send - NO SEND")
            logger.error("%s - %s", self._connected, closed)
            logger.error("-" * 60)

    def disconnect(self):
        """Mark an open connection as disconnected and optionally notify the peer.

        The flag is flipped before the peer is notified, so two connections
        bound to each other do not call back and forth indefinitely.
        """
        logger.debug("server::disconnect %s", self._connected)
        if self._connected > 0:
            self._connected = -1
            if self._destination and self.propagate_disconnect:
                self._destination.disconnect()

    async def close(self):
        """Close the underlying websocket, if there is one."""
        if self._ws is not None:
            await self._ws.close()
- # -----------------------------------------------------------------------------
class ForwardConnection:
    """Bridge one incoming websocket request to a websocket at ``url``.

    A client connection to ``url`` and a server connection for ``request`` are
    bound to each other so every message received on one side is sent on the
    other.
    """

    def __init__(self, request, url):
        """Create the two endpoints and bind them together.

        Must be called while an event loop is running (required by
        ``WsClientConnection``).
        """
        self._req = request
        self._url = url
        self._client_task = None
        self._ws_client = WsClientConnection()
        self._ws_server = WsServerConnection()
        self._ws_server.bind(self._ws_client)
        self._ws_client.bind(self._ws_server)

    def _on_client_done(self, task):
        """Release the server side once the client task has finished."""
        # Retrieve the outcome so asyncio does not report a
        # "Task exception was never retrieved" error later on.
        if not task.cancelled() and task.exception() is not None:
            logger.error(
                "client connection to %s failed: %r", self._url, task.exception()
            )
        self._ws_server.disconnect()

    async def connect(self):
        """Connect to the target, then serve the incoming request until it ends.

        The incoming request is not upgraded when the target connection could
        not be established.
        """
        # Keep a reference on the task: the event loop only holds weak
        # references, so an unreferenced task may be garbage-collected while
        # it is still running.
        self._client_task = asyncio.create_task(self._ws_client.connect(self._url))
        self._client_task.add_done_callback(self._on_client_done)

        if not await self._ws_client.ready:
            # Nothing to forward to: do not accept the incoming websocket.
            logger.error("unable to connect to %s", self._url)
            return

        await self._ws_server.connect(self._req)

    def disconnect(self):
        """Disconnect both sides of the bridge."""
        self._ws_client.disconnect()
        self._ws_server.disconnect()
class SinkConnection:
    """Relay endpoint pairing a long-lived process with a browser connection.

    The request given to the constructor identifies the process side. At most
    one other request (the browser) can be attached at any time; it may
    disconnect and be replaced while the process connection is alive.
    """

    def __init__(self, request):
        """Remember ``request`` as the one belonging to the process side."""
        self._process_req = request
        self._client_req = None
        self._process_ws = None
        self._client_ws = None

    def can_handle(self, request):
        """Return whether ``request`` may be connected through this sink.

        True for the process request, for the currently attached browser
        request, or for any request while no browser is attached.
        """
        if self._process_req == request:
            return True

        if self._client_req == request:
            return True

        if self._client_ws is None:
            return True

        return False

    async def connect(self, request):
        """Serve ``request`` until its websocket ends.

        Returns:
            ``True`` when the process connection ended (the sink is finished),
            ``False`` otherwise.
        """
        if self._process_req == request:
            # First connection is the actual server. Cannot reconnect.
            self._process_ws = WsServerConnection()
            await self._process_ws.connect(request)
            if self._client_ws is not None:
                await self._client_ws.close()
            return True

        if self._client_req is None:
            # Second connection is the browser. Can reconnect.
            self._client_req = request
            self._client_ws = WsServerConnection(propagate_disconnect=False)
            self._client_ws.bind(self._process_ws)
            self._process_ws.bind(self._client_ws)
            try:
                await self._client_ws.connect(request)
            finally:
                # Always detach the browser, even when its connection raised,
                # otherwise the stale request would block any later reconnect.
                self._client_ws.bind(None)
                self._process_ws.bind(None)
                self._client_req = None
                self._client_ws = None

        return False
- # -----------------------------------------------------------------------------
- # Handlers
- # -----------------------------------------------------------------------------
async def _root_handler(request):
    """Redirect the site root to ``index.html``, keeping any query string."""
    location = "index.html"
    if request.query_string:
        location = f"index.html?{request.query_string}"

    # HTTP exceptions are meant to be raised: returning them as a response is
    # deprecated in aiohttp.
    raise web.HTTPFound(location)
- # -----------------------------------------------------------------------------
class WsHandler:
    """Request handlers for the two working modes, ``forward`` and ``relay``."""

    def __init__(self):
        # target url -> ForwardConnection currently serving it
        self._forward_map = {}
        # request path -> SinkConnection currently serving it
        self._relay_map = {}

    def get_handler(self, mode="forward"):
        """Return the request handler for ``mode``.

        Returns ``None`` (after logging an error) when the mode is unknown.
        """
        logger.info("get_handler %s", mode)
        if mode == "forward":
            return self.forward_connect
        if mode == "relay":
            return self.relay_connect

        logger.error("No handler !!!")
        return None

    # -----------------------------
    # forward infrastructure
    # -----------------------------

    async def forward_connect(self, request):
        """Bridge ``request`` to ``ws://{host}:{port}/{path}``.

        ``host``, ``port`` and ``path`` come from the route's match info and
        default to ``localhost``, ``1234`` and ``ws``.

        Raises:
            web.HTTPBadRequest: the port is not an integer.
            web.HTTPForbidden: the target is already being forwarded.
        """
        host = request.match_info.get("host", "localhost")
        raw_port = request.match_info.get("port", "1234")
        path = request.match_info.get("path", "ws")

        try:
            port = int(raw_port)
        except ValueError:
            # The port comes straight from the URL: reject it as a client
            # error rather than failing with an internal server error.
            raise web.HTTPBadRequest(text=f"Invalid port: {raw_port}") from None

        target_url = f"ws://{host}:{port}/{path}"
        logger.info("=> %s", target_url)

        if target_url in self._forward_map:
            raise web.HTTPForbidden()

        forwarder = ForwardConnection(request, target_url)
        self._forward_map[target_url] = forwarder
        try:
            await forwarder.connect()
        finally:
            # Release the target even when the connection raised, otherwise
            # it would stay forbidden for every later request.
            self._forward_map.pop(target_url, None)

    # -----------------------------
    # relay server infrastructure
    # -----------------------------

    async def relay_connect(self, request):
        """Attach ``request`` to the sink registered for its path.

        The first request on a path creates the sink and acts as the process
        side; later requests are attached as the browser side.

        Raises:
            web.HTTPForbidden: the sink cannot accept this request.
        """
        relay_id = request.path
        is_process = relay_id not in self._relay_map
        if is_process:
            self._relay_map[relay_id] = SinkConnection(request)

        handler = self._relay_map[relay_id]
        if not handler.can_handle(request):
            raise web.HTTPForbidden()

        # If the process connection raises, its sink is dead as well and must
        # not linger in the map.
        process_ended = is_process
        try:
            process_ended = await handler.connect(request)
        finally:
            if process_ended:
                # Only pop when the server dies
                self._relay_map.pop(relay_id, None)
- # -----------------------------------------------------------------------------
- # Executable
- # -----------------------------------------------------------------------------
def main(host=None, port=None, www_path=None, proxy_route=None, mode=None):
    """Start the relay web server (blocks until the server stops).

    Every argument left as ``None`` falls back to the matching command-line
    option (or that option's default). Unknown command-line arguments are
    ignored.

    Args:
        host: interface to listen on.
        port: port to listen on.
        www_path: directory served as static content, if any.
        proxy_route: route pattern handled by the websocket handler.
        mode: working mode, ``"forward"`` or ``"relay"``.
    """
    ws_relay = WsHandler()

    # Handle CLI
    parser = argparse.ArgumentParser(
        description="Start ws relay with static content delivery",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    parser.add_argument(
        "--host",
        type=str,
        default="localhost",
        # The help text used to advertise 0.0.0.0, which is not the default.
        help="the interface for the web-server to listen on (default: localhost)",
        dest="host",
    )
    parser.add_argument(
        "-p",
        "--port",
        type=int,
        default=8080,
        help="port number for the web-server to listen on (default: 8080)",
        dest="port",
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="forward",
        help="Working mode [forward, relay] (default: forward)",
    )
    parser.add_argument("--www", type=str, help="Directory to serve", dest="www_path")
    parser.add_argument(
        "--proxy-route",
        type=str,
        help="Proxy URL pattern (default: /proxy/{port}) mode::forward(ws://{host=localhost}:{port=1234}/{path=ws})",
        default="/proxy/{port}",
        dest="proxy_route",
    )

    args, _ = parser.parse_known_args()

    # Explicit arguments win over the command line
    if host is None:
        host = args.host

    if port is None:
        port = args.port

    if mode is None:
        mode = args.mode

    if www_path is None:
        www_path = args.www_path

    if proxy_route is None:
        proxy_route = args.proxy_route

    logging.basicConfig()

    # Manage routes
    routes = []

    # Need to be first: static delivery should be a fallback
    if proxy_route is not None:
        logger.info("Proxy route: %s", proxy_route)
        routes.append(web.get(proxy_route, ws_relay.get_handler(mode)))

    # Serve static content
    if www_path is not None:
        logger.info("WWW: %s", www_path)
        routes.append(web.get("/", _root_handler))
        routes.append(web.static("/", www_path))

    # Setup web app
    logger.info("Starting relay server: %s %s", host, port)
    web_app = web.Application()
    web_app.add_routes(routes)
    web.run_app(web_app, host=host, port=port)
- # -----------------------------------------------------------------------------
- # Main
- # -----------------------------------------------------------------------------
# Run the relay server only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|