mirror of
https://github.com/Textualize/textual.git
synced 2025-10-17 02:38:12 +03:00
rlock tests
This commit is contained in:
@@ -103,6 +103,7 @@ from .messages import CallbackType
|
||||
from .notifications import Notification, Notifications, Notify, SeverityLevel
|
||||
from .reactive import Reactive
|
||||
from .renderables.blank import Blank
|
||||
from .rlock import RLock
|
||||
from .screen import (
|
||||
ActiveBinding,
|
||||
Screen,
|
||||
@@ -579,7 +580,7 @@ class App(Generic[ReturnType], DOMNode):
|
||||
else None
|
||||
)
|
||||
self._screenshot: str | None = None
|
||||
self._dom_lock = asyncio.Lock()
|
||||
self._dom_lock = RLock()
|
||||
self._dom_ready = False
|
||||
self._batch_count = 0
|
||||
self._notifications = Notifications()
|
||||
@@ -3555,7 +3556,7 @@ class App(Generic[ReturnType], DOMNode):
|
||||
# or one will turn up. Things will work out later.
|
||||
return
|
||||
# Update the toast rack.
|
||||
toast_rack.show(self._notifications)
|
||||
self.call_later(toast_rack.show, self._notifications)
|
||||
|
||||
def notify(
|
||||
self,
|
||||
|
||||
61
src/textual/rlock.py
Normal file
61
src/textual/rlock.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from asyncio import Lock, Task, current_task
|
||||
|
||||
|
||||
class RLock:
|
||||
"""A re-entrant asyncio lock."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._owner: Task | None = None
|
||||
self._count = 0
|
||||
self._lock = Lock()
|
||||
|
||||
async def acquire(self) -> None:
|
||||
"""Wait until the lock can be acquired."""
|
||||
task = current_task()
|
||||
assert task is not None
|
||||
if self._owner is None or self._owner is not task:
|
||||
await self._lock.acquire()
|
||||
self._owner = task
|
||||
self._count += 1
|
||||
|
||||
def release(self) -> None:
|
||||
"""Release a previously acquired lock."""
|
||||
task = current_task()
|
||||
assert task is not None
|
||||
self._count -= 1
|
||||
if self._count < 0:
|
||||
# Should not occur if every acquire as a release
|
||||
raise RuntimeError("RLock.release called too many times")
|
||||
if self._owner is task:
|
||||
if not self._count:
|
||||
self._owner = None
|
||||
self._lock.release()
|
||||
|
||||
@property
|
||||
def is_locked(self):
|
||||
"""Return True if lock is acquired."""
|
||||
return self._lock.locked()
|
||||
|
||||
async def __aenter__(self) -> None:
|
||||
"""Asynchronous context manager to acquire and release lock."""
|
||||
await self.acquire()
|
||||
|
||||
async def __aexit__(self, _type, _value, _traceback) -> None:
|
||||
"""Exit the context manager."""
|
||||
self.release()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from asyncio import Lock
|
||||
|
||||
async def locks():
|
||||
lock = RLock()
|
||||
async with lock:
|
||||
async with lock:
|
||||
print("Hello")
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.run(locks())
|
||||
@@ -183,7 +183,6 @@ class ToastRack(Container, inherit_css=False):
|
||||
Args:
|
||||
notifications: The notifications to show.
|
||||
"""
|
||||
|
||||
# Look for any stale toasts and remove them.
|
||||
for toast in self.query(Toast):
|
||||
if toast._notification not in notifications:
|
||||
|
||||
56
tests/test_rlock.py
Normal file
56
tests/test_rlock.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from textual.rlock import RLock
|
||||
|
||||
|
||||
async def test_simple_lock():
|
||||
lock = RLock()
|
||||
# Starts not locked
|
||||
assert not lock.is_locked
|
||||
# Acquire the lock
|
||||
await lock.acquire()
|
||||
assert lock.is_locked
|
||||
# Acquire a second time (should not block)
|
||||
await lock.acquire()
|
||||
assert lock.is_locked
|
||||
|
||||
# Release the lock
|
||||
lock.release()
|
||||
# Should still be locked
|
||||
assert lock.is_locked
|
||||
# Release the lock
|
||||
lock.release()
|
||||
# Should be released
|
||||
assert not lock.is_locked
|
||||
|
||||
# Another release is a runtime error
|
||||
with pytest.raises(RuntimeError):
|
||||
lock.release()
|
||||
|
||||
|
||||
async def test_multiple_tasks() -> None:
|
||||
"""Check RLock prevents other tasks from acquiring lock."""
|
||||
lock = RLock()
|
||||
|
||||
started: list[int] = []
|
||||
done: list[int] = []
|
||||
|
||||
async def test_task(n: int) -> None:
|
||||
started.append(n)
|
||||
async with lock:
|
||||
done.append(n)
|
||||
|
||||
async with lock:
|
||||
assert done == []
|
||||
task1 = asyncio.create_task(test_task(1))
|
||||
assert sorted(started) == []
|
||||
task2 = asyncio.create_task(test_task(2))
|
||||
await asyncio.sleep(0)
|
||||
assert sorted(started) == [1, 2]
|
||||
|
||||
await task1
|
||||
assert 1 in done
|
||||
await task2
|
||||
assert 2 in done
|
||||
Reference in New Issue
Block a user