Streaming response, cleaning up unused code

This commit is contained in:
Darren Burns
2024-08-07 14:31:23 +01:00
parent 1d6ef49f0f
commit 24a24b34f2
3 changed files with 106 additions and 72 deletions

View File

@@ -1,4 +1,5 @@
from __future__ import annotations
from pathlib import Path
import msgpack
import asyncio
@@ -309,7 +310,12 @@ class AppService:
try:
# Record this delivery key as available for download.
delivery_key = str(meta_data["key"])
await self._download_manager.start_download(delivery_key, self)
await self._download_manager.create_download(
app_service=self,
delivery_key=delivery_key,
file_name=Path(str(meta_data["path"])).name,
open_method=str(meta_data["open_method"]),
)
except KeyError:
log.error("Missing key in `deliver_file_start` meta packet")
return
@@ -330,12 +336,12 @@ class AppService:
elif meta_type == "deliver_file_end":
try:
delivery_key = str(meta_data["key"])
await self._download_manager.finish_download(self, delivery_key)
await self._download_manager.finish_download(delivery_key)
except KeyError:
log.error("Missing key in `deliver_file_end` meta packet")
return
else:
await self._download_manager.finish_download(self, delivery_key)
await self._download_manager.finish_download(delivery_key)
else:
log.warning(
f"Unknown meta type: {meta_type!r}. You may need to update `textual-serve`."
@@ -352,4 +358,4 @@ class AppService:
# If we receive a chunk, hand it to the download manager to
# handle distribution to the browser.
_, delivery_key, chunk_bytes = unpacked
await self.download_manager.chunk_received(self, delivery_key, chunk_bytes)
await self._download_manager.chunk_received(self, delivery_key, chunk_bytes)

View File

@@ -1,7 +1,8 @@
import asyncio
from contextlib import suppress
from dataclasses import dataclass, field
import logging
from typing import AsyncGenerator, Tuple
from typing import AsyncGenerator
from textual_serve.app_service import AppService
@@ -9,8 +10,14 @@ log = logging.getLogger("textual-serve")
DOWNLOAD_TIMEOUT = 4
DownloadKey = Tuple[str, str]
"""A tuple of (app_service_id, delivery_key)."""
@dataclass
class Download:
app_service: AppService
delivery_key: str
file_name: str
open_method: str
incoming_chunks: asyncio.Queue[bytes | None] = field(default_factory=asyncio.Queue)
class DownloadManager:
@@ -23,14 +30,9 @@ class DownloadManager:
"""
def __init__(self):
self.running_app_sessions_lock = asyncio.Lock()
self.running_app_sessions: list[AppService] = []
"""A list of running app sessions. An `AppService` will be added here when a browser
client connects and removed when it disconnects."""
self._active_downloads_lock = asyncio.Lock()
self._active_downloads: dict[DownloadKey, asyncio.Queue[bytes | None]] = {}
"""Set of active deliveries (string 'delivery keys' -> queue of bytes objects).
self._active_downloads: dict[str, Download] = {}
"""A dictionary of active downloads.
When a delivery key is received in a meta packet, it is added to this set.
When the user hits the "/download/{key}" endpoint, we ensure the key is in
@@ -40,74 +42,63 @@ class DownloadManager:
meta packet, and we remove the key from this set.
"""
async def register_app_service(self, app_service: AppService) -> None:
"""Register an app service with the download manager.
Args:
app_service: The app service to register.
"""
async with self.running_app_sessions_lock:
self.running_app_sessions.append(app_service)
async def unregister_app_service(self, app_service: AppService) -> None:
"""Unregister an app service from the download manager.
Args:
app_service: The app service to unregister.
"""
# TODO - remove any downloads for this app service.
async with self.running_app_sessions_lock:
self.running_app_sessions.remove(app_service)
async def start_download(self, app_service: AppService, delivery_key: str) -> None:
"""Start a download for the given delivery key on the given app service.
async def create_download(
self,
*,
app_service: AppService,
delivery_key: str,
file_name: str,
open_method: str,
) -> None:
"""Prepare for a new download.
Args:
app_service: The app service to start the download for.
delivery_key: The delivery key to start the download for.
file_name: The name of the file to download.
open_method: The method to open the file with.
"""
async with self.running_app_sessions_lock:
if app_service not in self.running_app_sessions:
raise ValueError("App service not registered.")
# Create a queue to write the received chunks to.
self._active_downloads[(app_service.app_service_id, delivery_key)] = (
asyncio.Queue[bytes | None]()
async with self._active_downloads_lock:
self._active_downloads[delivery_key] = Download(
app_service,
delivery_key,
file_name,
open_method,
)
async def finish_download(self, app_service: AppService, delivery_key: str) -> None:
async def finish_download(self, delivery_key: str) -> None:
"""Finish a download for the given delivery key.
Args:
app_service: The app service to finish the download for.
delivery_key: The delivery key to finish the download for.
"""
download_key = (app_service.app_service_id, delivery_key)
try:
queue = self._active_downloads[download_key]
download = self._active_downloads[delivery_key]
except KeyError:
log.error(f"Download {download_key!r} not found")
log.error(f"Download {delivery_key!r} not found")
return
# Shut down the download queue. Attempt graceful shutdown, but
# timeout after DOWNLOAD_TIMEOUT seconds if the queue doesn't clear.
await queue.put(None)
await download.incoming_chunks.put(None)
with suppress(asyncio.TimeoutError):
await asyncio.wait_for(queue.join(), timeout=DOWNLOAD_TIMEOUT)
await asyncio.wait_for(
download.incoming_chunks.join(), timeout=DOWNLOAD_TIMEOUT
)
del self._active_downloads[download_key]
async with self._active_downloads_lock:
del self._active_downloads[delivery_key]
async def download(
self, app_service: AppService, delivery_key: str
) -> AsyncGenerator[bytes, None]:
async def download(self, delivery_key: str) -> AsyncGenerator[bytes, None]:
"""Download a file from the given app service.
Args:
app_service: The app service to download from.
delivery_key: The delivery key to download.
"""
download_key: DownloadKey = (app_service.app_service_id, delivery_key)
download_queue = self._active_downloads[download_key]
app_service = await self._get_app_service(delivery_key)
download = self._active_downloads[delivery_key]
incoming_chunks = download.incoming_chunks
while True:
# Request a chunk from the app service.
@@ -119,25 +110,45 @@ class DownloadManager:
}
)
chunk = await download_queue.get()
chunk = await incoming_chunks.get()
if chunk is None:
# The app process has finished sending the file.
download_queue.task_done()
incoming_chunks.task_done()
raise StopAsyncIteration
else:
download_queue.task_done()
incoming_chunks.task_done()
yield chunk
async def chunk_received(
self, app_service: AppService, delivery_key: str, chunk: bytes
) -> None:
"""Handle a chunk received from the app service.
async def chunk_received(self, delivery_key: str, chunk: bytes) -> None:
"""Handle a chunk received from the app service for a download.
Args:
app_service: The app service that received the chunk.
delivery_key: The delivery key that the chunk was received for.
chunk: The chunk that was received.
"""
download_key = (app_service.app_service_id, delivery_key)
queue = self._active_downloads[download_key]
await queue.put(chunk)
download = self._active_downloads[delivery_key]
await download.incoming_chunks.put(chunk)
async def _get_app_service(self, delivery_key: str) -> AppService:
"""Get the app service that the given delivery key is linked to.
Args:
delivery_key: The delivery key to get the app service for.
"""
async with self._active_downloads_lock:
for key in self._active_downloads.keys():
if key == delivery_key:
return self._active_downloads[key].app_service
else:
raise ValueError(
f"No active download for delivery key {delivery_key!r}"
)
async def get_download_metadata(self, delivery_key: str) -> Download:
"""Get the metadata for a download.
Args:
delivery_key: The delivery key to get the metadata for.
"""
async with self._active_downloads_lock:
return self._active_downloads[delivery_key]

View File

@@ -152,11 +152,28 @@ class Server:
app.on_shutdown.append(self.on_shutdown)
return app
async def handle_download(self, request: web.Request) -> web.Response:
async def handle_download(self, request: web.Request) -> web.StreamResponse:
"""Handle a download request."""
key = request.match_info["key"]
# TODO
return web.Response()
download_meta = await self.download_manager.get_download_metadata(key)
download_stream = self.download_manager.download(key)
response = web.StreamResponse()
response.headers["Content-Type"] = "application/octet-stream"
disposition = (
"attachment" if download_meta.open_method == "download" else "inline"
)
response.headers["Content-Disposition"] = (
f"{disposition}; filename={download_meta.file_name}"
)
await response.prepare(request)
async for chunk in download_stream:
await response.write(chunk)
await response.write_eof()
return response
async def on_shutdown(self, app: web.Application) -> None:
"""Called on shutdown.