Handling chunks in an async generator

This commit is contained in:
Darren Burns
2024-08-07 11:44:27 +01:00
parent f9a4f53d15
commit c24eae2ba8
5 changed files with 63 additions and 7 deletions

View File

@@ -12,6 +12,7 @@ dependencies = [
"textual>=0.66.0",
"msgpack>=1.0.8",
"rich",
"msgpack-types>=0.3.0",
]
readme = "README.md"
requires-python = ">= 3.8"

View File

@@ -56,6 +56,8 @@ mdurl==0.1.2
# via markdown-it-py
msgpack==1.0.8
# via textual-serve
msgpack-types==0.3.0
# via textual-serve
multidict==6.0.5
# via aiohttp
# via yarl

View File

@@ -42,6 +42,8 @@ mdurl==0.1.2
# via markdown-it-py
msgpack==1.0.8
# via textual-serve
msgpack-types==0.3.0
# via textual-serve
multidict==6.0.5
# via aiohttp
# via yarl

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import msgpack
import asyncio
from contextlib import suppress
import io
import json
import os
@@ -54,7 +54,7 @@ class AppService:
self._task: asyncio.Task[None] | None = None
self._stdin: asyncio.StreamWriter | None = None
self._exit_event = asyncio.Event()
self.download_manager = download_manager
self._download_manager = download_manager
@property
def stdin(self) -> asyncio.StreamWriter:
@@ -309,7 +309,7 @@ class AppService:
try:
# Record this delivery key as available for download.
delivery_key = str(meta_data["key"])
await self.download_manager.start_download(delivery_key, self)
await self._download_manager.start_download(delivery_key, self)
except KeyError:
log.error("Missing key in `deliver_file_start` meta packet")
return
@@ -330,12 +330,12 @@ class AppService:
elif meta_type == "deliver_file_end":
try:
delivery_key = str(meta_data["key"])
await self.download_manager.finish_download(delivery_key)
await self._download_manager.finish_download(self, delivery_key)
except KeyError:
log.error("Missing key in `deliver_file_end` meta packet")
return
else:
await self.download_manager.finish_download(delivery_key)
await self._download_manager.finish_download(self, delivery_key)
else:
log.warning(
f"Unknown meta type: {meta_type!r}. You may need to update `textual-serve`."
@@ -347,3 +347,9 @@ class AppService:
Args:
payload: Encoded packed data.
"""
unpacked = msgpack.unpackb(payload)
if unpacked[0] == "deliver_file_chunk":
# If we receive a chunk, hand it to the download manager to
# handle distribution to the browser.
_, delivery_key, chunk_bytes = unpacked
await self.download_manager.chunk_received(self, delivery_key, chunk)

View File

@@ -1,7 +1,7 @@
import asyncio
from contextlib import suppress
import logging
from typing import Tuple
from typing import AsyncGenerator, Tuple
from textual_serve.app_service import AppService
@@ -68,7 +68,7 @@ class DownloadManager:
"""
async with self.running_app_sessions_lock:
if app_service not in self.running_app_sessions:
raise ValueError("App service not registered")
raise ValueError("App service not registered.")
# Create a queue to write the received chunks to.
self._active_downloads[(app_service.app_service_id, delivery_key)] = (
@@ -96,3 +96,48 @@ class DownloadManager:
await asyncio.wait_for(queue.join(), timeout=DOWNLOAD_TIMEOUT)
del self._active_downloads[download_key]
async def download(
self, app_service: AppService, delivery_key: str
) -> AsyncGenerator[bytes, None]:
"""Download a file from the given app service.
Args:
app_service: The app service to download from.
delivery_key: The delivery key to download.
"""
download_key: DownloadKey = (app_service.app_service_id, delivery_key)
download_queue = self._active_downloads[download_key]
while True:
# Request a chunk from the app service.
await app_service.send_meta(
{
"type": "deliver_chunk_request",
"key": delivery_key,
"size": 1024 * 64,
}
)
chunk = await download_queue.get()
if chunk is None:
# The app process has finished sending the file.
download_queue.task_done()
raise StopAsyncIteration
else:
download_queue.task_done()
yield chunk
async def chunk_received(
self, app_service: AppService, delivery_key: str, chunk: bytes
) -> None:
"""Handle a chunk received from the app service.
Args:
app_service: The app service that received the chunk.
delivery_key: The delivery key that the chunk was received for.
chunk: The chunk that was received.
"""
download_key = (app_service.app_service_id, delivery_key)
queue = self._active_downloads[download_key]
await queue.put(chunk)