Extracting logic into DownloadManager

This commit is contained in:
Darren Burns
2024-08-07 11:07:04 +01:00
parent 76d6f45a09
commit 87b9f0cfbc
3 changed files with 103 additions and 20 deletions

View File

@@ -10,12 +10,13 @@ from asyncio.subprocess import Process
import logging
from importlib.metadata import version
import uuid
import rich.repr
log = logging.getLogger("textual-serve")
from textual_serve.download_manager import DownloadManager
DOWNLOAD_TIMEOUT = 4
log = logging.getLogger("textual-serve")
@rich.repr.auto
@@ -33,8 +34,11 @@ class AppService:
write_bytes: Callable[[bytes], Awaitable[None]],
write_str: Callable[[str], Awaitable[None]],
close: Callable[[], Awaitable[None]],
download_manager: DownloadManager,
debug: bool = False,
) -> None:
self.app_service_id: str = uuid.uuid4().hex
"""The unique ID of this running app service."""
self.command = command
"""The command to launch the Textual app subprocess."""
self.remote_write_bytes = write_bytes
@@ -50,17 +54,7 @@ class AppService:
self._task: asyncio.Task[None] | None = None
self._stdin: asyncio.StreamWriter | None = None
self._exit_event = asyncio.Event()
self._active_downloads: dict[str, asyncio.Queue[bytes | None]] = {}
"""Set of active deliveries (string 'delivery keys' -> queue of bytes objects).
When a delivery key is received in a meta packet, it is added to this set.
When the user hits the "/download/{key}" endpoint, we ensure the key is in
this set and start the download by requesting chunks from the app process.
When the download is complete, the app process sends a "deliver_file_end"
meta packet, and we remove the key from this set.
"""
self.download_manager = download_manager
@property
def stdin(self) -> asyncio.StreamWriter:
@@ -315,7 +309,7 @@ class AppService:
try:
# Record this delivery key as available for download.
delivery_key = str(meta_data["key"])
self._active_downloads[delivery_key] = asyncio.Queue[bytes | None]()
await self.download_manager.start_download(delivery_key, self)
except KeyError:
log.error("Missing key in `deliver_file_start` meta packet")
return
@@ -335,16 +329,13 @@ class AppService:
# )
elif meta_type == "deliver_file_end":
try:
key = str(meta_data["key"])
delivery_key = str(meta_data["key"])
await self.download_manager.finish_download(delivery_key)
except KeyError:
log.error("Missing key in `deliver_file_end` meta packet")
return
else:
queue = self._active_downloads[key]
await queue.put(None)
with suppress(asyncio.TimeoutError):
await asyncio.wait_for(queue.join(), timeout=DOWNLOAD_TIMEOUT)
del self._active_downloads[key]
await self.download_manager.finish_download(delivery_key)
else:
log.warning(
f"Unknown meta type: {meta_type!r}. You may need to update `textual-serve`."

View File

@@ -0,0 +1,87 @@
import asyncio
from contextlib import suppress
from typing import Tuple
from textual_serve.app_service import AppService
DOWNLOAD_TIMEOUT = 4
DownloadKey = Tuple[str, str]
"""A tuple of (app_service_id, delivery_key)."""
class DownloadManager:
"""Class which manages downloads for the server.
Serves as the link between the web server and app processes during downloads.
A single server has a single download manager, which manages all downloads for all
running app processes.
"""
def __init__(self):
self.running_app_sessions_lock = asyncio.Lock()
self.running_app_sessions: list[AppService] = []
"""A list of running app sessions. An `AppService` will be added here when a browser
client connects and removed when it disconnects."""
self._active_downloads_lock = asyncio.Lock()
self._active_downloads: dict[DownloadKey, asyncio.Queue[bytes | None]] = {}
"""Set of active deliveries (string 'delivery keys' -> queue of bytes objects).
When a delivery key is received in a meta packet, it is added to this set.
When the user hits the "/download/{key}" endpoint, we ensure the key is in
this set and start the download by requesting chunks from the app process.
When the download is complete, the app process sends a "deliver_file_end"
meta packet, and we remove the key from this set.
"""
async def register_app_service(self, app_service: AppService) -> None:
"""Register an app service with the download manager.
Args:
app_service: The app service to register.
"""
async with self.running_app_sessions_lock:
self.running_app_sessions.append(app_service)
async def unregister_app_service(self, app_service: AppService) -> None:
"""Unregister an app service from the download manager.
Args:
app_service: The app service to unregister.
"""
# TODO - remove any downloads for this app service.
async with self.running_app_sessions_lock:
self.running_app_sessions.remove(app_service)
async def start_download(self, app_service: AppService, delivery_key: str) -> None:
"""Start a download for the given delivery key on the given app service.
Args:
app_service: The app service to start the download for.
delivery_key: The delivery key to start the download for.
"""
async with self.running_app_sessions_lock:
if app_service not in self.running_app_sessions:
raise ValueError("App service not registered")
# Create a queue to write the received chunks to.
self._active_downloads[(app_service.app_service_id, delivery_key)] = (
asyncio.Queue[bytes | None]()
)
async def finish_download(self, app_service: AppService, delivery_key: str) -> None:
"""Finish a download for the given delivery key.
Args:
app_service: The app service to finish the download for.
delivery_key: The delivery key to finish the download for.
"""
download_key = (app_service.app_service_id, delivery_key)
queue = self._active_downloads[download_key]
await queue.put(None)
with suppress(asyncio.TimeoutError):
await asyncio.wait_for(queue.join(), timeout=DOWNLOAD_TIMEOUT)
del self._active_downloads[download_key]

View File

@@ -22,6 +22,8 @@ from rich.console import Console
from rich.logging import RichHandler
from rich.highlighter import RegexHighlighter
from textual_serve.download_manager import DownloadManager
from .app_service import AppService
log = logging.getLogger("textual-serve")
@@ -100,6 +102,7 @@ class Server:
self.statics_path = base_path / statics_path
self.templates_path = base_path / templates_path
self.console = Console()
self.download_manager = DownloadManager()
def initialize_logging(self) -> None:
"""Initialize logging.
@@ -152,6 +155,7 @@ class Server:
async def handle_download(self, request: web.Request) -> web.Response:
"""Handle a download request."""
key = request.match_info["key"]
# TODO
return web.Response()
async def on_shutdown(self, app: web.Application) -> None:
@@ -292,6 +296,7 @@ class Server:
write_bytes=websocket.send_bytes,
write_str=websocket.send_str,
close=websocket.close,
download_manager=self.download_manager,
debug=self.debug,
)
await app_service.start(width, height)