From 9569a98c7788bd9f433d0590ee621230c50d906c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rodrigo=20Gir=C3=A3o=20Serr=C3=A3o?= <5621605+rodrigogiraoserrao@users.noreply.github.com> Date: Mon, 29 May 2023 19:49:58 +0100 Subject: [PATCH] Increase worker coverage and fix bug. --- CHANGELOG.md | 1 + src/textual/worker.py | 2 +- tests/test_worker.py | 62 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5cc55decb..a98498540 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - Fixed zero division error https://github.com/Textualize/textual/issues/2673 - Fix `scroll_to_center` when there were nested layers out of view (Compositor full_map not populated fully) https://github.com/Textualize/textual/pull/2684 +- Issue with computing progress in workers https://github.com/Textualize/textual/pull/2686 ### Added diff --git a/src/textual/worker.py b/src/textual/worker.py index 68b11bc95..de4295a8a 100644 --- a/src/textual/worker.py +++ b/src/textual/worker.py @@ -256,7 +256,7 @@ class Worker(Generic[ResultType]): if completed_steps is not None: self._completed_steps += completed_steps if total_steps != -1: - self._total_steps = None if total_steps is None else min(0, total_steps) + self._total_steps = None if total_steps is None else max(0, total_steps) def advance(self, steps: int = 1) -> None: """Advance the number of completed steps. diff --git a/tests/test_worker.py b/tests/test_worker.py index 3ef802f5f..f3c6dc239 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -1,11 +1,15 @@ import asyncio +from re import A import pytest from textual.app import App from textual.worker import ( + DeadlockError, + NoActiveWorker, Worker, WorkerCancelled, + WorkerError, WorkerFailed, WorkerState, get_current_worker, @@ -153,3 +157,61 @@ async def test_get_worker() -> None: worker._start(app) assert await worker.wait() is worker + + +def test_no_active_worker() -> None: + """No active worker raises a specific exception.""" + with pytest.raises(NoActiveWorker): + get_current_worker() + + +async def test_progress_update(): + async def long_work(): + pass + + app = App() + async with app.run_test(): + worker = Worker(app, long_work) + worker._start(app) + worker.update(total_steps=100) + assert worker.progress == 0 + worker.advance(50) + assert worker.progress == 50 + worker.update(completed_steps=23) + assert worker.progress == 73 + + +async def test_double_start(): + async def long_work(): + return 0 + + app = App() + async with app.run_test(): + worker = Worker(app, long_work) + worker._start(app) + worker._start(app) + assert await worker.wait() == 0 + + +async def test_self_referential_deadlock(): + async def self_referential_work(): + await get_current_worker().wait() + + app = App() + async with app.run_test(): + worker = Worker(app, self_referential_work) + worker._start(app) + with pytest.raises(WorkerFailed) as exc: + await worker.wait() + assert exc.type is DeadlockError + + +async def test_wait_without_start(): + async def work(): + return + + app = App() + async with app.run_test(): + worker = Worker(app, work) + with pytest.raises(WorkerError): + await worker.wait()