Files
textual-web/src/textual_web/ganglion_client.py
Will McGugan 0576678d63 first commit
2023-08-15 14:21:06 +01:00

340 lines
12 KiB
Python

from __future__ import annotations
import asyncio
from typing import TYPE_CHECKING, cast
from functools import partial
import aiohttp
import logging
import msgpack
import signal
from pathlib import Path
from .packets import (
NotifyTerminalSize,
SessionClose,
SessionData,
RoutePing,
RoutePong,
)
from .environment import Environment
from .session import SessionConnector
from .session_manager import SessionManager
from .types import RouteKey, SessionID, Meta
from . import packets
from . import constants
from .packets import Packet, Handlers, PACKET_MAP
from .retry import Retry
from .terminal_session import Poller
if TYPE_CHECKING:
from .config import Config
# Logger used by everything in this module, named "textual-web".
log = logging.getLogger("textual-web")
# Type alias used to annotate the items of a packet envelope in
# GanglionClient.decode_envelope.
# NOTE(review): unlike annotations (deferred by `from __future__ import
# annotations`), this union is evaluated at import time, so the `X | Y`
# syntax requires Python 3.10+; confirm against the supported versions.
PacketDataType = int | bytes | str | None
class PacketError(Exception):
    """Raised when a packet envelope cannot be decoded or fails validation."""
class _ClientConnector(SessionConnector):
    """Relays events from one session back to the Ganglion server.

    Each instance is bound to a single session (``session_id``) and its
    route (``route_key``), and sends packets through the owning client.
    """

    def __init__(
        self, client: GanglionClient, session_id: SessionID, route_key: RouteKey
    ) -> None:
        # Everything needed to address packets for this session.
        self.client = client
        self.session_id = session_id
        self.route_key = route_key

    async def on_data(self, data: bytes) -> None:
        """Forward data received from the session process to the server.

        Args:
            data: Raw bytes produced by the process.
        """
        outgoing = packets.SessionData(self.route_key, data)
        await self.client.send(outgoing)

    async def on_meta(self, meta: Meta) -> None:
        """Receive session metadata; it is currently ignored."""

    async def on_close(self) -> None:
        """Notify the server the session closed, then tell the session manager."""
        closing = packets.SessionClose(self.session_id, self.route_key)
        await self.client.send(closing)
        self.client.session_manager.on_session_end(self.session_id)
class GanglionClient(Handlers):
    """Manages a connection to a ganglion server.

    Binary websocket messages are unpacked with msgpack, decoded with
    `decode_envelope`, and passed to `dispatch_packet` (inherited from
    `Handlers`; presumably it routes each packet to the matching `on_*`
    method below).
    """

    def __init__(
        self,
        config_path: str,
        config: Config,
        environment: Environment,
        api_key: str | None,
    ) -> None:
        """Initialise the client.

        Args:
            config_path: Path to the config file; its parent directory is
                passed to the session manager.
            config: Loaded configuration.
            environment: Environment supplying the websocket URL.
            api_key: API key used when the config does not provide one.
        """
        self.environment = environment
        self.websocket_url = environment.url
        path = Path(config_path).absolute().parent
        self.config = config
        self.api_key = api_key
        self._websocket: aiohttp.ClientWebSocketResponse | None = None
        self._poll_reader = Poller()
        self.session_manager = SessionManager(self._poll_reader, path, config.apps)
        self.exit_event = asyncio.Event()
        self._task: asyncio.Task | None = None

    def add_app(self, name: str, command: str, slug: str = "") -> None:
        """Add a new app.

        Args:
            name: Name of the app.
            command: Command to run the app.
            slug: Slug used in URL, or blank to auto-generate on server.
        """
        self.session_manager.add_app(name, command, slug=slug)

    def add_terminal(self, name: str, command: str, slug: str = "") -> None:
        """Add a new terminal.

        Args:
            name: Name of the terminal.
            command: Command to run in the terminal.
            slug: Slug used in URL, or blank to auto-generate on server.
        """
        self.session_manager.add_app(name, command, slug=slug, terminal=True)

    @classmethod
    def decode_envelope(
        cls, packet_envelope: tuple[PacketDataType, ...]
    ) -> Packet | None:
        """Decode a packet envelope.

        Packet envelopes are a list where the first value is an integer denoting the type.
        The type is used to look up the appropriate Packet class which is instantiated with
        the rest of the data.

        If the envelope contains *more* data than required, then that data is silently dropped.
        This is to provide an extension mechanism.

        Args:
            packet_envelope: A decoded envelope (list or tuple).

        Raises:
            PacketError: If the packet_envelope is empty.
            PacketError: If the packet_envelope is not a list or tuple.
            PacketError: If the packet type is not an int.
            PacketError: If the packet data fails to validate.

        Returns:
            One of the Packet classes defined in packets.py or None if the packet was of an unknown type.
        """
        if not packet_envelope:
            raise PacketError("Packet data is empty")
        # Star-unpacking accepts any iterable: a top-level bytes value would be
        # split into ints (b"\x01..." -> type 1) and a dict into its keys.
        # Only sequences are valid envelopes.
        if not isinstance(packet_envelope, (list, tuple)):
            raise PacketError(
                f"Packet envelope expected list, found {type(packet_envelope).__name__}"
            )
        packet_data: list[PacketDataType]
        packet_type, *packet_data = packet_envelope
        if not isinstance(packet_type, int):
            raise PacketError(f"Packet id expected int, found {packet_type!r}")
        packet_class = PACKET_MAP.get(packet_type, None)
        if packet_class is None:
            return None
        try:
            # Extra trailing fields are dropped (extension mechanism).
            packet = packet_class.build(*packet_data[: len(packet_class._attributes)])
        except TypeError as error:
            raise PacketError(f"Packet failed to validate; {error}") from error
        return packet

    async def run(self) -> None:
        """Run the connection loop, shutting down the poller thread on exit."""
        try:
            await self._run()
        finally:
            # Shut down the poller thread
            self._poll_reader.exit()

    def on_keyboard_interrupt(self) -> None:
        """Signal handler to respond to keyboard interrupt."""
        print(
            "\r\033[F"
        )  # Move to start of line, to overwrite "^C" written by the shell (?)
        log.info("Exit requested")
        self.exit_event.set()
        if self._task is not None:
            self._task.cancel()

    async def _run(self) -> None:
        """Install the SIGINT handler, start the poller, and run `connect`."""
        # Inside a coroutine the running loop is the one we want.
        loop = asyncio.get_running_loop()
        loop.add_signal_handler(signal.SIGINT, self.on_keyboard_interrupt)
        self._poll_reader.set_loop(loop)
        self._poll_reader.start()
        self._task = asyncio.create_task(self.connect())
        await self._task

    async def connect(self) -> None:
        """Connect to the Ganglion server."""
        try:
            await self._connect()
        except asyncio.CancelledError:
            # Cancellation is the normal shutdown path (see on_keyboard_interrupt).
            pass

    async def _connect(self) -> None:
        """Internal connect; reconnects in a loop driven by `Retry`."""
        api_key = self.config.account.api_key or self.api_key or None
        if api_key:
            headers = {"GANGLIONAPIKEY": api_key}
        else:
            headers = {}

        retry = Retry()

        async for retry_count in retry:
            if self.exit_event.is_set():
                break
            try:
                if retry_count == 1:
                    log.info("connecting to Ganglion")
                async with aiohttp.ClientSession() as session:
                    async with session.ws_connect(
                        self.websocket_url, headers=headers
                    ) as websocket:
                        self._websocket = websocket
                        retry.success()
                        await self.post_connect()
                        try:
                            await self.run_websocket(websocket, retry)
                        finally:
                            self._websocket = None
                            log.info("Disconnected from Ganglion")
                        if self.exit_event.is_set():
                            break
            except asyncio.CancelledError:
                raise
            except Exception as error:
                # Only warn on the first failed attempt to avoid log spam.
                if retry_count == 1:
                    log.warning(
                        "Unable to connect to Ganglion server. Will reattempt connection soon."
                    )
                if constants.DEBUG:
                    log.error("Unable to connect; %s", error)

    async def run_websocket(
        self, websocket: aiohttp.ClientWebSocketResponse, retry: Retry
    ) -> None:
        """Run the websocket loop.

        Messages are read until the connection closes or reports an error.
        On cancellation, ``retry.done()`` is called (presumably to stop further
        reconnection attempts), all sessions are closed, and the websocket is
        closed and drained.

        Args:
            websocket: Websocket.
            retry: Retry object driving the reconnection loop.
        """
        unpackb = partial(msgpack.unpackb, use_list=True, raw=False)
        BINARY = aiohttp.WSMsgType.BINARY

        async def run_messages() -> None:
            """Read, decode, and dispatch websocket messages."""
            async for message in websocket:
                if message.type == BINARY:
                    try:
                        envelope = unpackb(message.data)
                    except Exception:
                        log.error("Unable to decode %r", message.data)
                        continue
                    try:
                        packet = self.decode_envelope(envelope)
                    except PacketError as error:
                        # A malformed packet is treated like an undecodable
                        # message: log it and keep the connection alive.
                        log.error("Invalid packet %r; %s", envelope, error)
                        continue
                    log.debug("<RECV> %r", packet)
                    if packet is not None:
                        try:
                            await self.dispatch_packet(packet)
                        except Exception:
                            log.exception("error processing %r", packet)
                elif message.type == aiohttp.WSMsgType.ERROR:
                    break

        try:
            await run_messages()
        except asyncio.CancelledError:
            retry.done()
            await self.session_manager.close_all()
            await websocket.close(message=b"Close requested")
            # Drain any remaining messages until the close completes.
            try:
                await run_messages()
            except asyncio.CancelledError:
                pass
        except ConnectionResetError:
            log.info("connection reset")
        except Exception as error:
            log.exception(str(error))

    async def post_connect(self) -> None:
        """Called immediately after connecting to server."""
        # Inform the server about our apps
        apps = [
            app.model_dump(include={"name", "slug", "color", "terminal"})
            for app in self.config.apps
        ]
        await self.send(packets.DeclareApps(apps))

    async def send(self, packet: Packet) -> bool:
        """Send a packet.

        Args:
            packet: Packet to send.

        Returns:
            bool: `True` if the packet was sent, otherwise `False`.
        """
        if self._websocket is None:
            log.warning("Failed to send %r", packet)
            return False
        packet_bytes = msgpack.packb(packet, use_bin_type=True)
        try:
            await self._websocket.send_bytes(packet_bytes)
        except Exception as error:
            log.warning("Failed to send %r; %s", packet, error)
            return False
        else:
            log.debug("<SEND> %r", packet)
            return True

    async def on_ping(self, packet: packets.Ping) -> None:
        """Sent by the server."""
        # Reply to a Ping with an immediate Pong.
        await self.send(packets.Pong(packet.data))

    async def on_log(self, packet: packets.Log) -> None:
        """A log message sent by the server."""
        log.debug("<ganglion> %s", packet.message)

    async def on_info(self, packet: packets.Info) -> None:
        """An info message (higher priority log) sent by the server."""
        log.info("<ganglion> %s", packet.message)

    async def on_session_open(self, packet: packets.SessionOpen) -> None:
        """Open a new session requested by the server.

        Args:
            packet: SessionOpen packet with the app slug, session id and route key.
        """
        session_id = SessionID(packet.session_id)
        route_key = RouteKey(packet.route_key)
        session_process = await self.session_manager.new_session(
            packet.application_slug,
            session_id,
            route_key,
        )
        if session_process is None:
            # Explicit check rather than `assert`, which is stripped under -O.
            # TODO: notify the server that the session could not be opened.
            log.warning(
                "Unable to open session %r for application %r",
                packet.session_id,
                packet.application_slug,
            )
            return
        connector = _ClientConnector(self, session_id, route_key)
        await session_process.start(connector)

    async def on_session_close(self, packet: SessionClose) -> None:
        """Close a session at the server's request."""
        await self.session_manager.close_session(SessionID(packet.session_id))

    async def on_session_data(self, packet: SessionData) -> None:
        """Forward data from the server to the session identified by its route key."""
        session_process = self.session_manager.get_session_by_route_key(
            RouteKey(packet.route_key)
        )
        if session_process is not None:
            await session_process.send_bytes(packet.data)

    async def on_notify_terminal_size(self, packet: NotifyTerminalSize) -> None:
        """Resize the terminal of the given session, if it exists."""
        session_process = self.session_manager.get_session(SessionID(packet.session_id))
        if session_process is not None:
            await session_process.set_terminal_size(packet.width, packet.height)

    async def on_route_ping(self, packet: RoutePing) -> None:
        """Reply to a RoutePing with a RoutePong carrying the same data."""
        await self.send(RoutePong(packet.route_key, packet.data))