From 24a24b34f200238797962abf924ab15a7a89f142 Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Wed, 7 Aug 2024 14:31:23 +0100 Subject: [PATCH] Streaming response, cleaning up unused code --- src/textual_serve/app_service.py | 14 ++- src/textual_serve/download_manager.py | 141 ++++++++++++++------------ src/textual_serve/server.py | 23 ++++- 3 files changed, 106 insertions(+), 72 deletions(-) diff --git a/src/textual_serve/app_service.py b/src/textual_serve/app_service.py index 0387d92..f730939 100644 --- a/src/textual_serve/app_service.py +++ b/src/textual_serve/app_service.py @@ -1,4 +1,5 @@ from __future__ import annotations +from pathlib import Path import msgpack import asyncio @@ -309,7 +310,12 @@ class AppService: try: # Record this delivery key as available for download. delivery_key = str(meta_data["key"]) - await self._download_manager.start_download(delivery_key, self) + await self._download_manager.create_download( + app_service=self, + delivery_key=delivery_key, + file_name=Path(str(meta_data["path"])).name, + open_method=str(meta_data["open_method"]), + ) except KeyError: log.error("Missing key in `deliver_file_start` meta packet") return @@ -330,12 +336,12 @@ class AppService: elif meta_type == "deliver_file_end": try: delivery_key = str(meta_data["key"]) - await self._download_manager.finish_download(self, delivery_key) + await self._download_manager.finish_download(delivery_key) except KeyError: log.error("Missing key in `deliver_file_end` meta packet") return else: - await self._download_manager.finish_download(self, delivery_key) + await self._download_manager.finish_download(delivery_key) else: log.warning( f"Unknown meta type: {meta_type!r}. You may need to update `textual-serve`." @@ -352,4 +358,4 @@ class AppService: # If we receive a chunk, hand it to the download manager to # handle distribution to the browser. _, delivery_key, chunk_bytes = unpacked - await self.download_manager.chunk_received(self, delivery_key, chunk_bytes) + await self._download_manager.chunk_received(self, delivery_key, chunk_bytes) diff --git a/src/textual_serve/download_manager.py b/src/textual_serve/download_manager.py index ffa0b86..9ed7538 100644 --- a/src/textual_serve/download_manager.py +++ b/src/textual_serve/download_manager.py @@ -1,7 +1,8 @@ import asyncio from contextlib import suppress +from dataclasses import dataclass, field import logging -from typing import AsyncGenerator, Tuple +from typing import AsyncGenerator from textual_serve.app_service import AppService @@ -9,8 +10,14 @@ log = logging.getLogger("textual-serve") DOWNLOAD_TIMEOUT = 4 -DownloadKey = Tuple[str, str] -"""A tuple of (app_service_id, delivery_key).""" + +@dataclass +class Download: + app_service: AppService + delivery_key: str + file_name: str + open_method: str + incoming_chunks: asyncio.Queue[bytes | None] = field(default_factory=asyncio.Queue) class DownloadManager: @@ -23,15 +30,10 @@ class DownloadManager: """ def __init__(self): - self.running_app_sessions_lock = asyncio.Lock() - self.running_app_sessions: list[AppService] = [] - """A list of running app sessions. An `AppService` will be added here when a browser - client connects and removed when it disconnects.""" - self._active_downloads_lock = asyncio.Lock() - self._active_downloads: dict[DownloadKey, asyncio.Queue[bytes | None]] = {} - """Set of active deliveries (string 'delivery keys' -> queue of bytes objects). - + self._active_downloads: dict[str, Download] = {} + """A dictionary of active downloads. + When a delivery key is received in a meta packet, it is added to this set. When the user hits the "/download/{key}" endpoint, we ensure the key is in this set and start the download by requesting chunks from the app process. @@ -40,74 +42,63 @@ class DownloadManager: meta packet, and we remove the key from this set. """ - async def register_app_service(self, app_service: AppService) -> None: - """Register an app service with the download manager. - - Args: - app_service: The app service to register. - """ - async with self.running_app_sessions_lock: - self.running_app_sessions.append(app_service) - - async def unregister_app_service(self, app_service: AppService) -> None: - """Unregister an app service from the download manager. - - Args: - app_service: The app service to unregister. - """ - # TODO - remove any downloads for this app service. - async with self.running_app_sessions_lock: - self.running_app_sessions.remove(app_service) - - async def start_download(self, app_service: AppService, delivery_key: str) -> None: - """Start a download for the given delivery key on the given app service. + async def create_download( + self, + *, + app_service: AppService, + delivery_key: str, + file_name: str, + open_method: str, + ) -> None: + """Prepare for a new download. Args: app_service: The app service to start the download for. delivery_key: The delivery key to start the download for. + file_name: The name of the file to download. + open_method: The method to open the file with. """ - async with self.running_app_sessions_lock: - if app_service not in self.running_app_sessions: - raise ValueError("App service not registered.") + async with self._active_downloads_lock: + self._active_downloads[delivery_key] = Download( + app_service, + delivery_key, + file_name, + open_method, + ) - # Create a queue to write the received chunks to. - self._active_downloads[(app_service.app_service_id, delivery_key)] = ( - asyncio.Queue[bytes | None]() - ) - - async def finish_download(self, app_service: AppService, delivery_key: str) -> None: + async def finish_download(self, delivery_key: str) -> None: """Finish a download for the given delivery key. Args: - app_service: The app service to finish the download for. delivery_key: The delivery key to finish the download for. """ - download_key = (app_service.app_service_id, delivery_key) try: - queue = self._active_downloads[download_key] + download = self._active_downloads[delivery_key] except KeyError: - log.error(f"Download {download_key!r} not found") + log.error(f"Download {delivery_key!r} not found") return # Shut down the download queue. Attempt graceful shutdown, but # timeout after DOWNLOAD_TIMEOUT seconds if the queue doesn't clear. - await queue.put(None) + await download.incoming_chunks.put(None) with suppress(asyncio.TimeoutError): - await asyncio.wait_for(queue.join(), timeout=DOWNLOAD_TIMEOUT) + await asyncio.wait_for( + download.incoming_chunks.join(), timeout=DOWNLOAD_TIMEOUT + ) - del self._active_downloads[download_key] + async with self._active_downloads_lock: + del self._active_downloads[delivery_key] - async def download( - self, app_service: AppService, delivery_key: str - ) -> AsyncGenerator[bytes, None]: + async def download(self, delivery_key: str) -> AsyncGenerator[bytes, None]: """Download a file from the given app service. Args: - app_service: The app service to download from. delivery_key: The delivery key to download. """ - download_key: DownloadKey = (app_service.app_service_id, delivery_key) - download_queue = self._active_downloads[download_key] + + app_service = await self._get_app_service(delivery_key) + download = self._active_downloads[delivery_key] + incoming_chunks = download.incoming_chunks while True: # Request a chunk from the app service. @@ -119,25 +110,45 @@ class DownloadManager: } ) - chunk = await download_queue.get() + chunk = await incoming_chunks.get() if chunk is None: # The app process has finished sending the file. - download_queue.task_done() + incoming_chunks.task_done() raise StopAsyncIteration else: - download_queue.task_done() + incoming_chunks.task_done() yield chunk - async def chunk_received( - self, app_service: AppService, delivery_key: str, chunk: bytes - ) -> None: - """Handle a chunk received from the app service. + async def chunk_received(self, delivery_key: str, chunk: bytes) -> None: + """Handle a chunk received from the app service for a download. Args: - app_service: The app service that received the chunk. delivery_key: The delivery key that the chunk was received for. chunk: The chunk that was received. """ - download_key = (app_service.app_service_id, delivery_key) - queue = self._active_downloads[download_key] - await queue.put(chunk) + download = self._active_downloads[delivery_key] + await download.incoming_chunks.put(chunk) + + async def _get_app_service(self, delivery_key: str) -> AppService: + """Get the app service that the given delivery key is linked to. + + Args: + delivery_key: The delivery key to get the app service for. + """ + async with self._active_downloads_lock: + for key in self._active_downloads.keys(): + if key == delivery_key: + return self._active_downloads[key].app_service + else: + raise ValueError( + f"No active download for delivery key {delivery_key!r}" + ) + + async def get_download_metadata(self, delivery_key: str) -> Download: + """Get the metadata for a download. + + Args: + delivery_key: The delivery key to get the metadata for. + """ + async with self._active_downloads_lock: + return self._active_downloads[delivery_key] diff --git a/src/textual_serve/server.py b/src/textual_serve/server.py index 48f66af..2f41804 100644 --- a/src/textual_serve/server.py +++ b/src/textual_serve/server.py @@ -152,11 +152,28 @@ class Server: app.on_shutdown.append(self.on_shutdown) return app - async def handle_download(self, request: web.Request) -> web.Response: + async def handle_download(self, request: web.Request) -> web.StreamResponse: """Handle a download request.""" key = request.match_info["key"] - # TODO - return web.Response() + + download_meta = await self.download_manager.get_download_metadata(key) + download_stream = self.download_manager.download(key) + + response = web.StreamResponse() + response.headers["Content-Type"] = "application/octet-stream" + disposition = ( + "attachment" if download_meta.open_method == "download" else "inline" + ) + response.headers["Content-Disposition"] = ( + f"{disposition}; filename={download_meta.file_name}" + ) + await response.prepare(request) + + async for chunk in download_stream: + await response.write(chunk) + + await response.write_eof() + return response async def on_shutdown(self, app: web.Application) -> None: """Called on shutdown.