rlock tests

This commit is contained in:
Will McGugan
2024-06-06 14:59:16 +01:00
parent 792527334b
commit b10d24d361
4 changed files with 120 additions and 3 deletions

View File

@@ -103,6 +103,7 @@ from .messages import CallbackType
from .notifications import Notification, Notifications, Notify, SeverityLevel
from .reactive import Reactive
from .renderables.blank import Blank
from .rlock import RLock
from .screen import (
ActiveBinding,
Screen,
@@ -579,7 +580,7 @@ class App(Generic[ReturnType], DOMNode):
else None
)
self._screenshot: str | None = None
self._dom_lock = asyncio.Lock()
self._dom_lock = RLock()
self._dom_ready = False
self._batch_count = 0
self._notifications = Notifications()
@@ -3555,7 +3556,7 @@ class App(Generic[ReturnType], DOMNode):
# or one will turn up. Things will work out later.
return
# Update the toast rack.
toast_rack.show(self._notifications)
self.call_later(toast_rack.show, self._notifications)
def notify(
self,

61
src/textual/rlock.py Normal file
View File

@@ -0,0 +1,61 @@
from __future__ import annotations
from asyncio import Lock, Task, current_task
class RLock:
"""A re-entrant asyncio lock."""
def __init__(self) -> None:
self._owner: Task | None = None
self._count = 0
self._lock = Lock()
async def acquire(self) -> None:
"""Wait until the lock can be acquired."""
task = current_task()
assert task is not None
if self._owner is None or self._owner is not task:
await self._lock.acquire()
self._owner = task
self._count += 1
def release(self) -> None:
"""Release a previously acquired lock."""
task = current_task()
assert task is not None
self._count -= 1
if self._count < 0:
# Should not occur if every acquire as a release
raise RuntimeError("RLock.release called too many times")
if self._owner is task:
if not self._count:
self._owner = None
self._lock.release()
@property
def is_locked(self):
"""Return True if lock is acquired."""
return self._lock.locked()
async def __aenter__(self) -> None:
"""Asynchronous context manager to acquire and release lock."""
await self.acquire()
async def __aexit__(self, _type, _value, _traceback) -> None:
"""Exit the context manager."""
self.release()
if __name__ == "__main__":
from asyncio import Lock
async def locks():
lock = RLock()
async with lock:
async with lock:
print("Hello")
import asyncio
asyncio.run(locks())

View File

@@ -183,7 +183,6 @@ class ToastRack(Container, inherit_css=False):
Args:
notifications: The notifications to show.
"""
# Look for any stale toasts and remove them.
for toast in self.query(Toast):
if toast._notification not in notifications:

56
tests/test_rlock.py Normal file
View File

@@ -0,0 +1,56 @@
import asyncio
import pytest
from textual.rlock import RLock
async def test_simple_lock():
lock = RLock()
# Starts not locked
assert not lock.is_locked
# Acquire the lock
await lock.acquire()
assert lock.is_locked
# Acquire a second time (should not block)
await lock.acquire()
assert lock.is_locked
# Release the lock
lock.release()
# Should still be locked
assert lock.is_locked
# Release the lock
lock.release()
# Should be released
assert not lock.is_locked
# Another release is a runtime error
with pytest.raises(RuntimeError):
lock.release()
async def test_multiple_tasks() -> None:
"""Check RLock prevents other tasks from acquiring lock."""
lock = RLock()
started: list[int] = []
done: list[int] = []
async def test_task(n: int) -> None:
started.append(n)
async with lock:
done.append(n)
async with lock:
assert done == []
task1 = asyncio.create_task(test_task(1))
assert sorted(started) == []
task2 = asyncio.create_task(test_task(2))
await asyncio.sleep(0)
assert sorted(started) == [1, 2]
await task1
assert 1 in done
await task2
assert 2 in done