relay.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408
  1. import argparse
  2. import asyncio
  3. import os
  4. import logging
  5. import aiohttp
  6. from aiohttp import web
  # -----------------------------------------------------------------------------
  # Tunables, overridable through the environment (4MB is aiohttp's own default)
  # -----------------------------------------------------------------------------
# Maximum size in bytes of one WebSocket message accepted by the server-side
# socket; taken from WSLINK_MAX_MSG_SIZE when set, otherwise 4 MiB.
MAX_MSG_SIZE = int(os.environ.get("WSLINK_MAX_MSG_SIZE", 4 * 1024 * 1024))

# Heartbeat interval handed to aiohttp for the server-side socket; taken from
# WSLINK_HEART_BEAT when set, otherwise 30 (seconds).
HEART_BEAT = int(os.environ.get("WSLINK_HEART_BEAT", 30))

# Module logger; its own level is pinned to INFO.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
  # -----------------------------------------------------------------------------
  # Helper classes
  # -----------------------------------------------------------------------------
class WsClientConnection:
    """Outgoing WebSocket (aiohttp client) that forwards what it receives.

    Messages received from the remote endpoint are handed to the bound
    destination's ``send``; calling ``send`` on this object pushes a message
    to the remote endpoint. Must be constructed while an event loop is running
    (``asyncio.get_running_loop`` is used to create the ``ready`` future).
    """

    def __init__(self, propagate_disconnect=True):
        self._url = None
        self._session = None
        self._ws = None
        # 0 = never connected, 1 = connected, -1 = disconnected
        self._connected = 0
        self._destination = None
        # Resolved to True once the socket is open, or to False when
        # connect() leaves without ever opening it.
        self._ready = asyncio.get_running_loop().create_future()
        self.propagate_disconnect = propagate_disconnect

    def bind(self, value):
        """Set the peer that receives our incoming messages and disconnects."""
        self._destination = value

    @property
    def ready(self):
        """Future resolving to True when connected, False when connecting failed."""
        return self._ready

    async def connect(self, url):
        """Connect to ``url`` and forward incoming messages until it closes.

        Does nothing (besides logging) if a session was already assigned.
        Exceptions raised while connecting propagate after ``ready`` has been
        resolved to False.
        """
        logger.debug("client::connect::%s", url)
        self._url = url
        if self._session is None:
            async with aiohttp.ClientSession() as session:
                logger.debug("client::connect::session")
                self._session = session
                try:
                    async with session.ws_connect(self._url) as ws:
                        logger.debug("client::connect::ws")
                        self._ws = ws
                        self._connected += 1
                        self._ready.set_result(True)
                        async for msg in ws:
                            logger.debug("client::connect::ws::msg")
                            if self._connected < 1:
                                logger.debug("client::connect::ws::msg::disconnect")
                                break
                            if self._destination:
                                logger.debug("client::connect::ws::msg::send")
                                await self._destination.send(msg)
                            else:
                                logger.error("ws-client: No destination for message")

                        # Disconnect
                        self.disconnect()

                        # Cleanup connection
                        if not self._ws.closed:
                            await self._ws.close()

                        self._ws = None
                        self._session = None
                finally:
                    # Only report failure if success was never reported:
                    # set_result() on an already-resolved future raises
                    # InvalidStateError, which used to fail every normal
                    # shutdown of this task.
                    if not self._ready.done():
                        self._ready.set_result(False)

        logger.debug("client::connect::exit")

    async def send(self, msg):
        """Forward ``msg`` (an aiohttp WebSocket message) to the remote endpoint.

        Logs an error instead of sending when not connected.
        """
        ws = self._ws
        if self._connected > 0 and ws is not None and not ws.closed:
            logger.debug("client::send")
            if msg.type == aiohttp.WSMsgType.TEXT:
                await ws.send_str(msg.data)
            elif msg.type == aiohttp.WSMsgType.BINARY:
                await ws.send_bytes(msg.data)
            elif msg.type == aiohttp.WSMsgType.PING:
                await ws.ping(msg.data)
            elif msg.type == aiohttp.WSMsgType.PONG:
                await ws.pong(msg.data)
            elif msg.type == aiohttp.WSMsgType.CLOSE:
                await ws.close()
            else:
                logger.error("Invalid message to forward")
        else:
            logger.error("client::send - NO SEND")
            # _ws is reset to None after teardown, so it must not be
            # dereferenced unconditionally while reporting the failure.
            logger.error(
                "%s - %s", self._connected, ws.closed if ws is not None else None
            )
            logger.error("-" * 60)

    def disconnect(self):
        """Mark this side disconnected (once) and optionally notify the peer."""
        logger.debug("client::disconnect %s", self._connected)
        if self._connected > 0:
            self._connected = -1
            if self._destination and self.propagate_disconnect:
                self._destination.disconnect()

    async def close(self):
        """Close the underlying socket, if there is one."""
        if self._ws is not None:
            await self._ws.close()
  # -----------------------------------------------------------------------------
class WsServerConnection:
    """Incoming WebSocket (aiohttp server response) that forwards what it receives.

    Messages received from the connected peer are handed to the bound
    destination's ``send``; calling ``send`` on this object pushes a message
    to the connected peer.
    """

    def __init__(self, propagate_disconnect=True):
        self._ws = None
        self._destination = None
        # 0 = never connected, 1 = connected, -1 = disconnected
        self._connected = 0
        self.propagate_disconnect = propagate_disconnect

    def bind(self, value):
        """Set the peer that receives our incoming messages and disconnects."""
        self._destination = value

    async def connect(self, request):
        """Upgrade ``request`` to a WebSocket and forward incoming messages.

        Returns once the socket closes or this side is disconnected. The
        disconnect notification and socket cleanup also run when forwarding
        raises.
        """
        logger.debug("server::connect")
        self._ws = web.WebSocketResponse(
            max_msg_size=MAX_MSG_SIZE, heartbeat=HEART_BEAT
        )
        await self._ws.prepare(request)
        logger.debug("server::connect::prepare")
        self._connected += 1
        try:
            # Only a fresh instance (counter was 0) gets a positive count here.
            if self._connected > 0:
                async for msg in self._ws:
                    logger.debug("server::connect::ws::msg")
                    if self._connected < 1:
                        break
                    if self._destination:
                        logger.debug("server::connect::ws::msg::send-begin")
                        await self._destination.send(msg)
                        logger.debug("server::connect::ws::msg::send-end")
                    else:
                        logger.error("ws-server: No destination for message")
        finally:
            # Disconnect (notifying the peer when propagating) and release the
            # socket even if forwarding raised, so the peer is never left
            # bound to a dead connection.
            self.disconnect()
            if not self._ws.closed:
                await self._ws.close()
            self._ws = None
        logger.debug("server::connect::exit")

    async def send(self, msg):
        """Forward ``msg`` (an aiohttp WebSocket message) to the connected peer.

        Logs an error instead of sending when not connected.
        """
        ws = self._ws
        if self._connected > 0 and ws is not None and not ws.closed:
            logger.debug("server::send")
            if msg.type == aiohttp.WSMsgType.TEXT:
                await ws.send_str(msg.data)
            elif msg.type == aiohttp.WSMsgType.BINARY:
                await ws.send_bytes(msg.data)
            elif msg.type == aiohttp.WSMsgType.PING:
                await ws.ping(msg.data)
            elif msg.type == aiohttp.WSMsgType.PONG:
                await ws.pong(msg.data)
            elif msg.type == aiohttp.WSMsgType.CLOSE:
                await ws.close()
            else:
                logger.error("Invalid message to forward")
        else:
            logger.error("server::send - NO SEND")
            # _ws is None before connect() and after teardown; don't
            # dereference it while reporting the failure.
            logger.error(
                "%s - %s", self._connected, ws.closed if ws is not None else None
            )
            logger.error("-" * 60)

    def disconnect(self):
        """Mark this side disconnected (once) and optionally notify the peer."""
        logger.debug("server::disconnect %s", self._connected)
        if self._connected > 0:
            self._connected = -1
            if self._destination and self.propagate_disconnect:
                self._destination.disconnect()

    async def close(self):
        """Close the underlying socket, if there is one."""
        if self._ws is not None:
            await self._ws.close()
  # -----------------------------------------------------------------------------
class ForwardConnection:
    """Bridge one browser WebSocket (``request``) to a backend WebSocket ``url``."""

    def __init__(self, request, url):
        self._req = request
        self._url = url
        self._ws_client = WsClientConnection()
        self._ws_server = WsServerConnection()
        # Each side forwards what it receives to the other one.
        self._ws_server.bind(self._ws_client)
        self._ws_client.bind(self._ws_server)
        # Strong reference to the backend task; the event loop only keeps
        # weak references to tasks.
        self._client_task = None

    def _on_client_done(self, task):
        """Disconnect the browser side when the backend task ends; log its error."""
        self._ws_server.disconnect()
        if not task.cancelled() and task.exception() is not None:
            logger.error("ws-client: %s", task.exception())

    async def connect(self):
        """Open the backend connection, then serve the browser until it leaves.

        Raises:
            web.HTTPBadGateway: the backend connection could not be opened.
        """
        task = asyncio.create_task(self._ws_client.connect(self._url))
        self._client_task = task
        task.add_done_callback(self._on_client_done)
        ready = self._ws_client.ready
        # Also wake up if the client task dies before resolving ``ready``.
        await asyncio.wait({task, ready}, return_when=asyncio.FIRST_COMPLETED)
        if not (ready.done() and ready.result()):
            # Don't upgrade the browser to a socket that leads nowhere.
            raise web.HTTPBadGateway()
        try:
            await self._ws_server.connect(self._req)
        finally:
            # The client loop only checks for a disconnect when the backend
            # sends something; closing the socket ends it right away.
            await self._ws_client.close()

    def disconnect(self):
        """Disconnect both sides of the bridge."""
        self._ws_client.disconnect()
        self._ws_server.disconnect()
  # -----------------------------------------------------------------------------
class SinkConnection:
    """Relay between one process socket and at most one browser socket at a time.

    The request the sink is created with is the process side; other requests
    attach as the browser side, which may be replaced after it disconnects.
    """

    def __init__(self, request):
        self._process_req = request
        self._client_req = None
        self._process_ws = None
        self._client_ws = None

    def can_handle(self, request):
        """Return True if ``request`` may be served by this sink."""
        return (
            request is self._process_req
            or request is self._client_req
            or self._client_ws is None
        )

    async def connect(self, request):
        """Serve ``request``; return True only if it was the process socket.

        When the process socket ends, any attached browser socket is closed.
        A browser socket is detached (on success or error) so another browser
        can attach afterwards. Returns False for browser sockets.
        """
        if request is self._process_req:
            # First connection is the actual server. Cannot reconnect.
            self._process_ws = WsServerConnection()
            try:
                await self._process_ws.connect(request)
            finally:
                if self._client_ws is not None:
                    await self._client_ws.close()
            return True

        if self._client_req is None:
            # Second connection is the browser. Can reconnect.
            self._client_req = request
            self._client_ws = WsServerConnection(propagate_disconnect=False)
            self._client_ws.bind(self._process_ws)
            self._process_ws.bind(self._client_ws)
            try:
                await self._client_ws.connect(request)
            finally:
                # Always detach; otherwise a failed browser session would
                # leave _client_req set and block every later browser.
                self._client_ws.bind(None)
                self._process_ws.bind(None)
                self._client_req = None
                self._client_ws = None
        return False
  # -----------------------------------------------------------------------------
  # Handlers
  # -----------------------------------------------------------------------------
async def _root_handler(request):
    """Redirect ``/`` to ``index.html``, keeping any query string.

    Raises:
        web.HTTPFound: always; aiohttp turns it into a 302 redirect.
    """
    location = "index.html"
    if request.query_string:
        location = f"{location}?{request.query_string}"
    # Returning an HTTPException instance is deprecated in aiohttp; raising
    # it is the supported way to answer with a redirect.
    raise web.HTTPFound(location)
  # -----------------------------------------------------------------------------
class WsHandler:
    """Provide the aiohttp WebSocket handlers for the forward and relay modes."""

    def __init__(self):
        # target url -> ForwardConnection currently bridging to it
        self._forward_map = {}
        # request path -> SinkConnection serving it
        self._relay_map = {}

    def get_handler(self, mode="forward"):
        """Return the request handler for ``mode``.

        Raises:
            ValueError: ``mode`` is neither ``"forward"`` nor ``"relay"``.
        """
        logger.info("get_handler %s", mode)
        if mode == "forward":
            return self.forward_connect
        if mode == "relay":
            return self.relay_connect
        logger.error("No handler !!!")
        # Fail here with a clear message instead of handing None to aiohttp.
        raise ValueError(f"Unknown relay mode: {mode!r}")

    # -----------------------------
    # forward infrastructure
    # -----------------------------
    async def forward_connect(self, request):
        """Bridge the browser socket to ``ws://{host}:{port}/{path}``.

        Only one browser may be bridged to a given target at a time.

        Raises:
            web.HTTPBadRequest: ``port`` is not an integer.
            web.HTTPForbidden: the target is already being forwarded.
        """
        host = request.match_info.get("host", "localhost")
        try:
            port = int(request.match_info.get("port", "1234"))
        except ValueError:
            raise web.HTTPBadRequest() from None
        path = request.match_info.get("path", "ws")
        target_url = f"ws://{host}:{port}/{path}"
        logger.info("=> %s", target_url)

        if target_url in self._forward_map:
            raise web.HTTPForbidden()

        forwarder = ForwardConnection(request, target_url)
        self._forward_map[target_url] = forwarder
        try:
            await forwarder.connect()
        finally:
            # Release the target even when the bridge failed; otherwise
            # every later request for it would be rejected as forbidden.
            self._forward_map.pop(target_url, None)

    # -----------------------------
    # relay server infrastructure
    # -----------------------------
    async def relay_connect(self, request):
        """Attach ``request`` to the relay identified by its path.

        The first request on a path creates the sink and is its process side;
        later requests attach as the browser side.

        Raises:
            web.HTTPForbidden: the sink refuses the request.
        """
        relay_id = request.path
        sink = self._relay_map.get(relay_id)
        # The request that creates the sink is its process connection, which
        # is exactly the case where SinkConnection.connect returns True.
        is_process = sink is None
        if is_process:
            sink = SinkConnection(request)
            self._relay_map[relay_id] = sink

        if not sink.can_handle(request):
            raise web.HTTPForbidden()

        try:
            await sink.connect(request)
        finally:
            # Only pop when the server dies (also when it dies with an error,
            # since it can never reconnect to this sink).
            if is_process and self._relay_map.get(relay_id) is sink:
                del self._relay_map[relay_id]
  # -----------------------------------------------------------------------------
  # Executable
  # -----------------------------------------------------------------------------
def main(host=None, port=None, www_path=None, proxy_route=None, mode=None):
    """Parse the command line and run the relay web server (blocks).

    Any keyword argument that is not None overrides the matching
    command-line option; unrecognised command-line arguments are ignored.
    """
    # Handle CLI
    parser = argparse.ArgumentParser(
        description="Start ws relay with static content delivery",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--host",
        type=str,
        default="localhost",
        # The help text used to claim 0.0.0.0, which is not the default.
        help="the interface for the web-server to listen on (default: localhost)",
        dest="host",
    )
    parser.add_argument(
        "-p",
        "--port",
        type=int,
        default=8080,
        help="port number for the web-server to listen on (default: 8080)",
        dest="port",
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="forward",
        # Reject unknown modes at parse time instead of at route setup.
        choices=["forward", "relay"],
        help="Working mode [forward, relay] (default: forward)",
    )
    parser.add_argument("--www", type=str, help="Directory to serve", dest="www_path")
    parser.add_argument(
        "--proxy-route",
        type=str,
        help="Proxy URL pattern (default: /proxy/{port}) mode::forward(ws://{host=localhost}:{port=1234}/{path=ws})",
        default="/proxy/{port}",
        dest="proxy_route",
    )
    args, _ = parser.parse_known_args()

    # Explicit function arguments take precedence over the command line.
    if host is None:
        host = args.host
    if port is None:
        port = args.port
    if mode is None:
        mode = args.mode
    if www_path is None:
        www_path = args.www_path
    if proxy_route is None:
        proxy_route = args.proxy_route

    logging.basicConfig()

    ws_relay = WsHandler()

    # Manage routes
    routes = []

    # Need to be first: static delivery should be a fallback
    if proxy_route is not None:
        logger.info("Proxy route: %s", proxy_route)
        routes.append(web.get(proxy_route, ws_relay.get_handler(mode)))

    # Serve static content
    if www_path is not None:
        logger.info("WWW: %s", www_path)
        routes.append(web.get("/", _root_handler))
        routes.append(web.static("/", www_path))

    # Setup web app
    logger.info("Starting relay server: %s %s", host, port)
    web_app = web.Application()
    web_app.add_routes(routes)
    web.run_app(web_app, host=host, port=port)
  # -----------------------------------------------------------------------------
  # Main
  # -----------------------------------------------------------------------------
# Executing this file as a script starts the relay using the CLI arguments.
if __name__ == "__main__":
    main()