mirror of
https://github.com/Textualize/textual.git
synced 2025-10-17 02:38:12 +03:00
Additional testing for devtools client/server
This commit is contained in:
@@ -30,7 +30,7 @@ from ._event_broker import extract_handler_actions, NoHandler
|
|||||||
from ._profile import timer
|
from ._profile import timer
|
||||||
from .binding import Bindings, NoBinding
|
from .binding import Bindings, NoBinding
|
||||||
from .css.stylesheet import Stylesheet, StylesheetParseError, StylesheetError
|
from .css.stylesheet import Stylesheet, StylesheetParseError, StylesheetError
|
||||||
from .devtools_client import DevtoolsClient
|
from .devtools_client import DevtoolsClient, DevtoolsConnectionError
|
||||||
from .dom import DOMNode
|
from .dom import DOMNode
|
||||||
from .driver import Driver
|
from .driver import Driver
|
||||||
from .file_monitor import FileMonitor
|
from .file_monitor import FileMonitor
|
||||||
@@ -415,10 +415,15 @@ class App(DOMNode):
|
|||||||
log(f"driver={self.driver_class}")
|
log(f"driver={self.driver_class}")
|
||||||
|
|
||||||
if os.getenv("TEXTUAL_DEVTOOLS") == "1":
|
if os.getenv("TEXTUAL_DEVTOOLS") == "1":
|
||||||
await self.devtools.connect()
|
try:
|
||||||
self.log_file.write(f"Connected to devtools ({self.devtools.url})\n")
|
await self.devtools.connect()
|
||||||
self.log_file.flush()
|
self.log_file.write(f"Connected to devtools ({self.devtools.url})\n")
|
||||||
|
self.log_file.flush()
|
||||||
|
except DevtoolsConnectionError:
|
||||||
|
self.log_file.write(
|
||||||
|
f"Couldn't connect to devtools ({self.devtools.url})\n"
|
||||||
|
)
|
||||||
|
self.log_file.flush()
|
||||||
try:
|
try:
|
||||||
if self.css_file is not None:
|
if self.css_file is not None:
|
||||||
self.stylesheet.read(self.css_file)
|
self.stylesheet.read(self.css_file)
|
||||||
@@ -456,7 +461,8 @@ class App(DOMNode):
|
|||||||
await self.animator.start()
|
await self.animator.start()
|
||||||
await super().process_messages()
|
await super().process_messages()
|
||||||
log("PROCESS END")
|
log("PROCESS END")
|
||||||
await self._disconnect_devtools()
|
if self.devtools.is_connected:
|
||||||
|
await self._disconnect_devtools()
|
||||||
with timer("animator.stop()"):
|
with timer("animator.stop()"):
|
||||||
await self.animator.stop()
|
await self.animator.stop()
|
||||||
with timer("self.close_all()"):
|
with timer("self.close_all()"):
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
@@ -20,12 +22,13 @@ from aiohttp.web_ws import WebSocketResponse
|
|||||||
from dateutil.tz import tz
|
from dateutil.tz import tz
|
||||||
from rich.align import Align
|
from rich.align import Align
|
||||||
from rich.console import Console, ConsoleOptions, RenderResult
|
from rich.console import Console, ConsoleOptions, RenderResult
|
||||||
|
from rich.markup import escape
|
||||||
from rich.rule import Rule
|
from rich.rule import Rule
|
||||||
from rich.segment import Segments, Segment
|
from rich.segment import Segments, Segment
|
||||||
from rich.table import Table
|
from rich.table import Table
|
||||||
|
|
||||||
DEFAULT_PORT = 8081
|
DEFAULT_PORT = 8081
|
||||||
SIZE_CHANGE_POLL_DELAY_SECONDS = 2
|
DEFAULT_SIZE_CHANGE_POLL_DELAY_SECONDS = 2
|
||||||
QUEUEABLE_TYPES = {"client_log", "client_spillover"}
|
QUEUEABLE_TYPES = {"client_log", "client_spillover"}
|
||||||
|
|
||||||
|
|
||||||
@@ -82,7 +85,9 @@ class DevtoolsInternalMessage:
|
|||||||
yield Rule(self.message, style=level_to_style.get(self.level, "dim"))
|
yield Rule(self.message, style=level_to_style.get(self.level, "dim"))
|
||||||
|
|
||||||
|
|
||||||
async def enqueue_size_changes(console: Console, outgoing_queue: Queue):
|
async def _enqueue_size_changes(
|
||||||
|
console: Console, outgoing_queue: Queue, poll_delay: int
|
||||||
|
):
|
||||||
current_width = console.width
|
current_width = console.width
|
||||||
current_height = console.height
|
current_height = console.height
|
||||||
while True:
|
while True:
|
||||||
@@ -90,13 +95,13 @@ async def enqueue_size_changes(console: Console, outgoing_queue: Queue):
|
|||||||
height = console.height
|
height = console.height
|
||||||
dimensions_changed = width != current_width or height != current_height
|
dimensions_changed = width != current_width or height != current_height
|
||||||
if dimensions_changed:
|
if dimensions_changed:
|
||||||
await enqueue_server_info(outgoing_queue, width, height)
|
await _enqueue_server_info(outgoing_queue, width, height)
|
||||||
current_width = width
|
current_width = width
|
||||||
current_height = height
|
current_height = height
|
||||||
await asyncio.sleep(SIZE_CHANGE_POLL_DELAY_SECONDS)
|
await asyncio.sleep(poll_delay)
|
||||||
|
|
||||||
|
|
||||||
async def enqueue_server_info(outgoing_queue: Queue, width: int, height: int) -> None:
|
async def _enqueue_server_info(outgoing_queue: Queue, width: int, height: int) -> None:
|
||||||
await outgoing_queue.put(
|
await outgoing_queue.put(
|
||||||
{
|
{
|
||||||
"type": "server_info",
|
"type": "server_info",
|
||||||
@@ -108,7 +113,7 @@ async def enqueue_server_info(outgoing_queue: Queue, width: int, height: int) ->
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def consume_incoming(console: Console, incoming_queue: Queue[dict]) -> None:
|
async def _consume_incoming(console: Console, incoming_queue: Queue[dict]) -> None:
|
||||||
while True:
|
while True:
|
||||||
message_json = await incoming_queue.get()
|
message_json = await incoming_queue.get()
|
||||||
type = message_json["type"]
|
type = message_json["type"]
|
||||||
@@ -137,7 +142,7 @@ async def consume_incoming(console: Console, incoming_queue: Queue[dict]) -> Non
|
|||||||
incoming_queue.task_done()
|
incoming_queue.task_done()
|
||||||
|
|
||||||
|
|
||||||
async def consume_outgoing(
|
async def _consume_outgoing(
|
||||||
outgoing_queue: Queue[dict], websocket: WebSocketResponse
|
outgoing_queue: Queue[dict], websocket: WebSocketResponse
|
||||||
) -> None:
|
) -> None:
|
||||||
while True:
|
while True:
|
||||||
@@ -153,21 +158,28 @@ async def websocket_handler(request: Request):
|
|||||||
request.app["websockets"].add(websocket)
|
request.app["websockets"].add(websocket)
|
||||||
|
|
||||||
console = request.app["console"]
|
console = request.app["console"]
|
||||||
|
size_change_poll_delay = request.app["size_change_poll_delay_secs"]
|
||||||
|
|
||||||
incoming_queue: Queue[dict] = Queue()
|
incoming_queue: Queue[dict] = Queue()
|
||||||
outgoing_queue: Queue[dict] = Queue()
|
outgoing_queue: Queue[dict] = Queue()
|
||||||
|
|
||||||
request.app["tasks"].extend(
|
request.app["tasks"].extend(
|
||||||
(
|
(
|
||||||
asyncio.create_task(consume_outgoing(outgoing_queue, websocket)),
|
asyncio.create_task(_consume_outgoing(outgoing_queue, websocket)),
|
||||||
asyncio.create_task(enqueue_size_changes(console, outgoing_queue)),
|
asyncio.create_task(
|
||||||
asyncio.create_task(consume_incoming(console, incoming_queue)),
|
_enqueue_size_changes(
|
||||||
|
console, outgoing_queue, poll_delay=size_change_poll_delay
|
||||||
|
)
|
||||||
|
),
|
||||||
|
asyncio.create_task(_consume_incoming(console, incoming_queue)),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
console.print(DevtoolsInternalMessage(f"Client '{request.remote}' connected"))
|
console.print(
|
||||||
|
DevtoolsInternalMessage(f"Client '{escape(request.remote)}' connected")
|
||||||
|
)
|
||||||
|
|
||||||
await enqueue_server_info(
|
await _enqueue_server_info(
|
||||||
outgoing_queue, width=console.width, height=console.height
|
outgoing_queue, width=console.width, height=console.height
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
@@ -177,7 +189,7 @@ async def websocket_handler(request: Request):
|
|||||||
try:
|
try:
|
||||||
message_json = json.loads(message.data)
|
message_json = json.loads(message.data)
|
||||||
except JSONDecodeError:
|
except JSONDecodeError:
|
||||||
console.print(f"{message.data}")
|
console.print(escape(str(message.data)))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
type = message_json.get("type")
|
type = message_json.get("type")
|
||||||
@@ -196,13 +208,13 @@ async def websocket_handler(request: Request):
|
|||||||
request.app["websockets"].discard(websocket)
|
request.app["websockets"].discard(websocket)
|
||||||
console.print()
|
console.print()
|
||||||
console.print(
|
console.print(
|
||||||
DevtoolsInternalMessage(f"Client '{request.remote}' disconnected")
|
DevtoolsInternalMessage(f"Client '{escape(request.remote)}' disconnected")
|
||||||
)
|
)
|
||||||
|
|
||||||
return websocket
|
return websocket
|
||||||
|
|
||||||
|
|
||||||
async def on_shutdown(app: Application) -> None:
|
async def _on_shutdown(app: Application) -> None:
|
||||||
for task in app["tasks"]:
|
for task in app["tasks"]:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
with suppress(CancelledError):
|
with suppress(CancelledError):
|
||||||
@@ -214,13 +226,16 @@ async def on_shutdown(app: Application) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def run_devtools(port: int) -> None:
|
def _run_devtools(port: int) -> None:
|
||||||
app = make_aiohttp_app()
|
app = _make_devtools_aiohttp_app()
|
||||||
run_app(app, port=port)
|
run_app(app, port=port)
|
||||||
|
|
||||||
|
|
||||||
def make_aiohttp_app():
|
def _make_devtools_aiohttp_app(
|
||||||
|
size_change_poll_delay_secs: float = DEFAULT_SIZE_CHANGE_POLL_DELAY_SECONDS,
|
||||||
|
):
|
||||||
app = Application()
|
app = Application()
|
||||||
|
app["size_change_poll_delay_secs"] = size_change_poll_delay_secs
|
||||||
app["console"] = Console()
|
app["console"] = Console()
|
||||||
app["websockets"] = weakref.WeakSet()
|
app["websockets"] = weakref.WeakSet()
|
||||||
app["tasks"] = []
|
app["tasks"] = []
|
||||||
@@ -229,7 +244,7 @@ def make_aiohttp_app():
|
|||||||
get("/textual-devtools-websocket", websocket_handler),
|
get("/textual-devtools-websocket", websocket_handler),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
app.on_shutdown.append(on_shutdown)
|
app.on_shutdown.append(_on_shutdown)
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
@@ -238,4 +253,4 @@ if __name__ == "__main__":
|
|||||||
port = int(sys.argv[1])
|
port = int(sys.argv[1])
|
||||||
else:
|
else:
|
||||||
port = DEFAULT_PORT
|
port = DEFAULT_PORT
|
||||||
run_devtools(port)
|
_run_devtools(port)
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ class DevtoolsConnectionError(Exception):
|
|||||||
class DevtoolsClient:
|
class DevtoolsClient:
|
||||||
def __init__(self, address: str = "127.0.0.1", port: int = DEFAULT_PORT):
|
def __init__(self, address: str = "127.0.0.1", port: int = DEFAULT_PORT):
|
||||||
self.url: str = f"ws://{address}:{port}"
|
self.url: str = f"ws://{address}:{port}"
|
||||||
|
self.session: aiohttp.ClientSession | None = None
|
||||||
self.log_queue_task: Task | None = None
|
self.log_queue_task: Task | None = None
|
||||||
self.update_console_task: Task | None = None
|
self.update_console_task: Task | None = None
|
||||||
self.console: DevtoolsConsole = DevtoolsConsole(file=StringIO())
|
self.console: DevtoolsConsole = DevtoolsConsole(file=StringIO())
|
||||||
@@ -52,7 +53,7 @@ class DevtoolsClient:
|
|||||||
self.spillover: int = 0
|
self.spillover: int = 0
|
||||||
|
|
||||||
async def connect(self) -> None:
|
async def connect(self) -> None:
|
||||||
self.session: aiohttp.ClientSession = aiohttp.ClientSession()
|
self.session = aiohttp.ClientSession()
|
||||||
self.log_queue: Queue[str | Type[DetachDevtools]] = Queue(
|
self.log_queue: Queue[str | Type[DetachDevtools]] = Queue(
|
||||||
maxsize=LOG_QUEUE_MAXSIZE
|
maxsize=LOG_QUEUE_MAXSIZE
|
||||||
)
|
)
|
||||||
|
|||||||
24
tests/conftest.py
Normal file
24
tests/conftest.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from textual.devtools import _make_devtools_aiohttp_app
|
||||||
|
from textual.devtools_client import DevtoolsClient
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def server(aiohttp_server, unused_tcp_port):
|
||||||
|
app = _make_devtools_aiohttp_app(
|
||||||
|
size_change_poll_delay_secs=0.001,
|
||||||
|
)
|
||||||
|
server = await aiohttp_server(app, port=unused_tcp_port)
|
||||||
|
yield server
|
||||||
|
await server.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def devtools(aiohttp_client, server):
|
||||||
|
client = await aiohttp_client(server)
|
||||||
|
devtools = DevtoolsClient(address=client.host, port=client.port)
|
||||||
|
await devtools.connect()
|
||||||
|
yield devtools
|
||||||
|
await devtools.disconnect()
|
||||||
|
await client.close()
|
||||||
88
tests/test_devtools.py
Normal file
88
tests/test_devtools.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import time_machine
|
||||||
|
from rich.align import Align
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.segment import Segment
|
||||||
|
|
||||||
|
from tests.utilities.render import wait_for_predicate
|
||||||
|
from textual.devtools import DevtoolsLogMessage, DevtoolsInternalMessage
|
||||||
|
|
||||||
|
TIMESTAMP = 1649166819
|
||||||
|
WIDTH = 40
|
||||||
|
# The string "Hello, world!" is encoded in the payload below
|
||||||
|
EXAMPLE_LOG = {
|
||||||
|
"type": "client_log",
|
||||||
|
"payload": {
|
||||||
|
"encoded_segments": "gASVQgAAAAAAAABdlCiMDHJpY2guc2VnbWVudJSMB1NlZ"
|
||||||
|
"21lbnSUk5SMDUhlbGxvLCB3b3JsZCGUTk6HlIGUaAOMAQqUTk6HlIGUZS4=",
|
||||||
|
"line_number": 123,
|
||||||
|
"path": "abc/hello.py",
|
||||||
|
"timestamp": TIMESTAMP,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def console():
|
||||||
|
return Console(width=WIDTH)
|
||||||
|
|
||||||
|
|
||||||
|
@time_machine.travel(TIMESTAMP)
|
||||||
|
def test_log_message_render(console):
|
||||||
|
message = DevtoolsLogMessage(
|
||||||
|
[Segment("content")],
|
||||||
|
path="abc/hello.py",
|
||||||
|
line_number=123,
|
||||||
|
unix_timestamp=TIMESTAMP,
|
||||||
|
)
|
||||||
|
table = next(iter(message.__rich_console__(console, console.options)))
|
||||||
|
|
||||||
|
assert len(table.rows) == 1
|
||||||
|
|
||||||
|
columns = list(table.columns)
|
||||||
|
left_cells = list(columns[0].cells)
|
||||||
|
left = left_cells[0]
|
||||||
|
right_cells = list(columns[1].cells)
|
||||||
|
right: Align = right_cells[0]
|
||||||
|
|
||||||
|
assert left == " [#888177]15:53:39 [dim]BST[/]"
|
||||||
|
assert right.align == "right"
|
||||||
|
assert "hello.py:123" in right.renderable
|
||||||
|
|
||||||
|
|
||||||
|
def test_internal_message_render(console):
|
||||||
|
message = DevtoolsInternalMessage("hello")
|
||||||
|
rule = next(iter(message.__rich_console__(console, console.options)))
|
||||||
|
assert rule.title == "hello"
|
||||||
|
assert rule.characters == "─"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_devtools_valid_client_log(devtools):
|
||||||
|
await devtools.websocket.send_json(EXAMPLE_LOG)
|
||||||
|
assert devtools.is_connected
|
||||||
|
|
||||||
|
|
||||||
|
async def test_devtools_string_not_json_message(devtools):
|
||||||
|
await devtools.websocket.send_str("ABCDEFG")
|
||||||
|
assert devtools.is_connected
|
||||||
|
|
||||||
|
|
||||||
|
async def test_devtools_invalid_json_message(devtools):
|
||||||
|
await devtools.websocket.send_json({"invalid": "json"})
|
||||||
|
assert devtools.is_connected
|
||||||
|
|
||||||
|
|
||||||
|
async def test_devtools_spillover_message(devtools):
|
||||||
|
await devtools.websocket.send_json(
|
||||||
|
{"type": "client_spillover", "payload": {"spillover": 123}}
|
||||||
|
)
|
||||||
|
assert devtools.is_connected
|
||||||
|
|
||||||
|
|
||||||
|
async def test_devtools_console_size_change(server, devtools):
|
||||||
|
# Update the width of the console on the server-side
|
||||||
|
server.app["console"].width = 124
|
||||||
|
# Wait for the client side to update the console on their end
|
||||||
|
await wait_for_predicate(lambda: devtools.console.width == 124)
|
||||||
@@ -9,29 +9,12 @@ from aiohttp.web_ws import WebSocketResponse
|
|||||||
from rich.console import ConsoleDimensions
|
from rich.console import ConsoleDimensions
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
|
|
||||||
from textual.devtools import make_aiohttp_app
|
from tests.utilities.render import wait_for_predicate
|
||||||
from textual.devtools_client import DevtoolsClient
|
from textual.devtools_client import DevtoolsClient
|
||||||
|
|
||||||
TIMESTAMP = 1649166819
|
TIMESTAMP = 1649166819
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def server(aiohttp_server, unused_tcp_port):
|
|
||||||
server = await aiohttp_server(make_aiohttp_app(), port=unused_tcp_port)
|
|
||||||
yield server
|
|
||||||
await server.close()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def devtools(aiohttp_client, server):
|
|
||||||
client = await aiohttp_client(server)
|
|
||||||
devtools = DevtoolsClient(address=client.host, port=client.port)
|
|
||||||
await devtools.connect()
|
|
||||||
yield devtools
|
|
||||||
await devtools.disconnect()
|
|
||||||
await client.close()
|
|
||||||
|
|
||||||
|
|
||||||
def test_devtools_client_initialize_defaults():
|
def test_devtools_client_initialize_defaults():
|
||||||
devtools = DevtoolsClient()
|
devtools = DevtoolsClient()
|
||||||
assert devtools.url == "ws://127.0.0.1:8081"
|
assert devtools.url == "ws://127.0.0.1:8081"
|
||||||
@@ -50,7 +33,7 @@ async def test_devtools_log_places_encodes_and_queues_message(devtools):
|
|||||||
assert queued_log_json == {
|
assert queued_log_json == {
|
||||||
"payload": {
|
"payload": {
|
||||||
"encoded_segments": "gASVQgAAAAAAAABdlCiMDHJpY2guc2VnbWVudJSMB1NlZ"
|
"encoded_segments": "gASVQgAAAAAAAABdlCiMDHJpY2guc2VnbWVudJSMB1NlZ"
|
||||||
"21lbnSUk5SMDUhlbGxvLCB3b3JsZCGUTk6HlIGUaAOMAQqUTk6HlIGUZS4=",
|
"21lbnSUk5SMDUhlbGxvLCB3b3JsZCGUTk6HlIGUaAOMAQqUTk6HlIGUZS4=",
|
||||||
"line_number": 0,
|
"line_number": 0,
|
||||||
"path": "",
|
"path": "",
|
||||||
"timestamp": TIMESTAMP,
|
"timestamp": TIMESTAMP,
|
||||||
@@ -99,12 +82,17 @@ async def test_devtools_log_spillover(devtools):
|
|||||||
|
|
||||||
# Ensure we're informing the server of spillover rate-limiting
|
# Ensure we're informing the server of spillover rate-limiting
|
||||||
spillover_message = await devtools.log_queue.get()
|
spillover_message = await devtools.log_queue.get()
|
||||||
assert json.loads(spillover_message) == {"type": "client_spillover", "payload": {"spillover": 2}}
|
assert json.loads(spillover_message) == {
|
||||||
|
"type": "client_spillover",
|
||||||
|
"payload": {"spillover": 2},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
async def test_devtools_client_update_console_dimensions(devtools, server):
|
async def test_devtools_client_update_console_dimensions(devtools, server):
|
||||||
server_websocket: WebSocketResponse = next(iter(server.app["websockets"]))
|
"""Sending new server info through websocket from server to client should (eventually)
|
||||||
# Send new server information from the server via the websocket
|
result in the dimensions of the devtools client console being updated to match.
|
||||||
|
"""
|
||||||
|
server_to_client: WebSocketResponse = next(iter(server.app["websockets"]))
|
||||||
server_info = {
|
server_info = {
|
||||||
"type": "server_info",
|
"type": "server_info",
|
||||||
"payload": {
|
"payload": {
|
||||||
@@ -112,13 +100,7 @@ async def test_devtools_client_update_console_dimensions(devtools, server):
|
|||||||
"height": 456,
|
"height": 456,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
await server_websocket.send_json(server_info)
|
await server_to_client.send_json(server_info)
|
||||||
timer = 0
|
await wait_for_predicate(
|
||||||
poll_period = .1
|
lambda: devtools.console.size == ConsoleDimensions(123, 456)
|
||||||
while True:
|
)
|
||||||
if timer > 3:
|
|
||||||
pytest.fail("The devtools client dimensions did not update")
|
|
||||||
if devtools.console.size == ConsoleDimensions(123, 456):
|
|
||||||
break
|
|
||||||
await asyncio.sleep(.1)
|
|
||||||
timer += poll_period
|
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
|
import asyncio
|
||||||
import io
|
import io
|
||||||
import re
|
import re
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import pytest
|
||||||
from rich.console import Console, RenderableType
|
from rich.console import Console, RenderableType
|
||||||
|
|
||||||
|
|
||||||
re_link_ids = re.compile(r"id=[\d\.\-]*?;.*?\x1b")
|
re_link_ids = re.compile(r"id=[\d\.\-]*?;.*?\x1b")
|
||||||
|
|
||||||
|
|
||||||
@@ -23,3 +25,32 @@ def render(renderable: RenderableType, no_wrap: bool = False) -> str:
|
|||||||
console.print(renderable, no_wrap=no_wrap, end="")
|
console.print(renderable, no_wrap=no_wrap, end="")
|
||||||
output = replace_link_ids(capture.get())
|
output = replace_link_ids(capture.get())
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
async def wait_for_predicate(
|
||||||
|
predicate: Callable[[], bool],
|
||||||
|
timeout_secs: float = 2,
|
||||||
|
poll_delay_secs: float = 0.001,
|
||||||
|
) -> None:
|
||||||
|
"""Wait for the given predicate to become True by evaluating it every `poll_delay_secs`
|
||||||
|
seconds. Fail the pytest test if the predicate does not become True after `timeout_secs`
|
||||||
|
seconds.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
predicate (Callable[[], bool]): The predicate function which will be called repeatedly.
|
||||||
|
timeout_secs (float): If the predicate doesn't evaluate to True after this number of
|
||||||
|
seconds, the test will fail.
|
||||||
|
poll_delay_secs (float): The number of seconds to wait between each call to the
|
||||||
|
predicate function.
|
||||||
|
"""
|
||||||
|
time_taken = 0
|
||||||
|
while True:
|
||||||
|
result = predicate()
|
||||||
|
if result:
|
||||||
|
return
|
||||||
|
await asyncio.sleep(poll_delay_secs)
|
||||||
|
time_taken += poll_delay_secs
|
||||||
|
if time_taken > timeout_secs:
|
||||||
|
pytest.fail(
|
||||||
|
f"Predicate {predicate} did not return True after {timeout_secs} seconds."
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user