authentik/passbook/app_gw/websocket/consumer.py

84 lines
3.0 KiB
Python
Raw Normal View History

2019-04-10 18:03:42 +01:00
"""websocket proxy consumer"""
import threading
from ssl import CERT_NONE
import websocket
from channels.generic.websocket import WebsocketConsumer
2019-10-01 09:24:10 +01:00
from structlog import get_logger
2019-04-10 18:03:42 +01:00
from passbook.app_gw.models import ApplicationGatewayProvider
2019-10-01 09:24:10 +01:00
# Module-level structlog logger shared by everything in this file;
# `__name__` is handed to structlog's `get_logger` as the logger argument.
LOGGER = get_logger(__name__)
2019-04-10 18:03:42 +01:00
class ProxyConsumer(WebsocketConsumer):
    """Proxy websocket connection to upstream

    On connect, the request's ``host`` header is used to look up an enabled
    ``ApplicationGatewayProvider``. If one is found, a ``websocket.WebSocketApp``
    client is opened against the provider's first upstream in a separate
    thread, and frames are relayed in both directions. The downstream
    connection is only accepted once the upstream connection has opened.
    """

    # Downstream request headers decoded to str; replaced on every connect().
    _headers_dict = {}
    # ApplicationGatewayProvider matched for this connection, or None.
    _app_gw = None
    # websocket.WebSocketApp talking to the upstream, or None until connect()
    # has found a matching provider.
    _client = None
    # Thread running the upstream client's run_forever() loop.
    _thread = None

    def _fix_headers(self, input_dict):
        """Fix headers from bytestrings to normal strings

        `input_dict` may be a mapping or an iterable of (key, value) pairs of
        bytes; the result is a dict with both keys and values decoded as UTF-8.
        """
        return {
            key.decode('utf-8'): value.decode('utf-8')
            for key, value in dict(input_dict).items()
        }

    @staticmethod
    def _to_websocket_url(url):
        """Turn a leading ``http``/``https`` scheme into ``ws``/``wss``.

        Only the prefix is rewritten, so an "http" substring later in the URL
        (for example inside the hostname) is left untouched. URLs that do not
        start with ``http`` are returned unchanged.
        """
        if url.startswith('http'):
            return 'ws' + url[len('http'):]
        return url

    def connect(self):
        """Extract host header, lookup in database and proxy connection

        The connection is rejected (closed without being accepted) when the
        request carries no ``host`` header or no enabled provider matches it.
        """
        # `headers` and `query_string` are optional keys in the ASGI scope,
        # so fall back to empty values instead of failing on None.
        self._headers_dict = self._fix_headers(self.scope.get('headers') or [])
        host = self._headers_dict.pop('host', None)
        if host is None:
            LOGGER.warning("Rejecting websocket connection without host header")
            self.close()
            return
        query_string = (self.scope.get('query_string') or b'').decode('utf-8')
        # first() returns None when nothing matches, which saves the extra
        # exists() query.
        self._app_gw = ApplicationGatewayProvider.objects.filter(
            server_name__contains=[host],
            enabled=True).first()
        if self._app_gw is None:
            # Without an explicit close the handshake would be left hanging.
            LOGGER.warning("Rejecting websocket connection for unknown host", host=host)
            self.close()
            return
        # TODO: pick an upstream that already uses a ws:// or wss:// scheme
        # (the original note was truncated).
        upstream = self._to_websocket_url(self._app_gw.upstream[0]) + self.scope.get('path')
        if query_string:
            upstream += '?' + query_string
        sslopt = {}
        if not self._app_gw.upstream_ssl_verification:
            sslopt = {"cert_reqs": CERT_NONE}
        # NOTE(review): every remaining downstream header (including the
        # websocket handshake headers) is forwarded upstream as-is; confirm
        # this is intended.
        self._client = websocket.WebSocketApp(
            url=upstream,
            subprotocols=self.scope.get('subprotocols'),
            header=self._headers_dict,
            on_message=self._client_on_message_handler(),
            on_error=self._client_on_error_handler(),
            on_close=self._client_on_close_handler(),
            on_open=self._client_on_open_handler())
        LOGGER.debug("Accepting connection for %s", host)
        # run_forever() blocks, so the upstream client gets its own thread.
        self._thread = threading.Thread(
            target=self._client.run_forever,
            kwargs={'sslopt': sslopt})
        self._thread.start()

    def _client_on_open_handler(self):
        """Return the upstream on_open callback.

        It accepts the downstream connection, passing along the subprotocol
        reported in the upstream handshake response.
        """
        return lambda ws: self.accept(self._client.sock.handshake_response.subprotocol)

    def _client_on_message_handler(self):
        """Return the upstream on_message callback, relaying frames downstream."""
        # pylint: disable=unused-argument,invalid-name
        def message_handler(ws, message):
            if isinstance(message, str):
                self.send(text_data=message)
            else:
                self.send(bytes_data=message)
        return message_handler

    def _client_on_error_handler(self):
        """Return the upstream on_error callback, which logs the error."""
        # pylint: disable=unused-argument,invalid-name
        def error_handler(ws, error):
            LOGGER.warning("Upstream websocket error", error=str(error))
        return error_handler

    def _client_on_close_handler(self):
        """Return the upstream on_close callback.

        Extra positional arguments are accepted and ignored because some
        websocket-client releases also pass the close status code and reason.
        """
        # pylint: disable=unused-argument,invalid-name
        def close_handler(ws, *args):
            self.disconnect(0)
        return close_handler

    def disconnect(self, code):
        """Close the upstream client, if one was ever created."""
        # _client stays None when connect() rejected the connection.
        if self._client is not None:
            self._client.close()

    def receive(self, text_data=None, bytes_data=None):
        """Relay a frame received from the downstream client to the upstream.

        Binary data is sent with the binary opcode, text with the text
        opcode. Empty frames are forwarded as well; a call carrying neither
        kind of payload is ignored.
        """
        if self._client is None:
            return
        # Compare against None rather than truthiness so that empty frames
        # ('' or b'') still get a defined opcode.
        if bytes_data is not None:
            payload, opcode = bytes_data, websocket.ABNF.OPCODE_BINARY
        elif text_data is not None:
            payload, opcode = text_data, websocket.ABNF.OPCODE_TEXT
        else:
            return
        self._client.send(payload, opcode)