diff --git a/src/textual/_cache.py b/src/textual/_cache.py index bbbd79bc1..f025a08b4 100644 --- a/src/textual/_cache.py +++ b/src/textual/_cache.py @@ -15,7 +15,7 @@ where the overhead of the cache is a small fraction of the total processing time from __future__ import annotations from threading import Lock -from typing import Dict, Generic, TypeVar, overload +from typing import Dict, Generic, KeysView, TypeVar, overload CacheKey = TypeVar("CacheKey") CacheValue = TypeVar("CacheValue") @@ -38,22 +38,30 @@ class LRUCache(Generic[CacheKey, CacheValue]): """ def __init__(self, maxsize: int) -> None: - self.maxsize = maxsize - self.cache: Dict[CacheKey, list[object]] = {} - self.full = False - self.head: list[object] = [] + self._maxsize = maxsize + self._cache: Dict[CacheKey, list[object]] = {} + self._full = False + self._head: list[object] = [] self._lock = Lock() super().__init__() + def __bool__(self) -> bool: + return bool(self._cache) + def __len__(self) -> int: - return len(self.cache) + return len(self._cache) def clear(self) -> None: """Clear the cache.""" with self._lock: - self.cache.clear() - self.full = False - self.head = [] + self._cache.clear() + self._full = False + self._head = [] + + def keys(self) -> KeysView[CacheKey]: + """Get cache keys.""" + # Mostly for tests + return self._cache.keys() def set(self, key: CacheKey, value: CacheValue) -> None: """Set a value. @@ -63,28 +71,28 @@ class LRUCache(Generic[CacheKey, CacheValue]): value (CacheValue): Value. """ with self._lock: - link = self.cache.get(key) + link = self._cache.get(key) if link is None: - head = self.head + head = self._head if not head: # First link references itself - self.head[:] = [head, head, key, value] + self._head[:] = [head, head, key, value] else: # Add a new root to the beginning - self.head = [head[0], head, key, value] + self._head = [head[0], head, key, value] # Updated references on previous root - head[0][1] = self.head # type: ignore[index] - head[0] = self.head - self.cache[key] = self.head + head[0][1] = self._head # type: ignore[index] + head[0] = self._head + self._cache[key] = self._head - if self.full or len(self.cache) > self.maxsize: + if self._full or len(self._cache) > self._maxsize: # Cache is full, we need to evict the oldest one - self.full = True - head = self.head + self._full = True + head = self._head last = head[0] last[0][1] = head # type: ignore[index] head[0] = last[0] # type: ignore[index] - del self.cache[last[2]] # type: ignore[index] + del self._cache[last[2]] # type: ignore[index] __setitem__ = set @@ -108,33 +116,33 @@ class LRUCache(Generic[CacheKey, CacheValue]): Returns: Union[CacheValue, Optional[DefaultValue]]: Either the value or a default. """ - link = self.cache.get(key) + link = self._cache.get(key) if link is None: return default with self._lock: - if link is not self.head: + if link is not self._head: # Remove link from list link[0][1] = link[1] # type: ignore[index] link[1][0] = link[0] # type: ignore[index] - head = self.head + head = self._head # Move link to head of list link[0] = head[0] link[1] = head - self.head = head[0][1] = head[0] = link # type: ignore[index] + self._head = head[0][1] = head[0] = link # type: ignore[index] return link[3] # type: ignore[return-value] def __getitem__(self, key: CacheKey) -> CacheValue: - link = self.cache[key] + link = self._cache[key] with self._lock: - if link is not self.head: + if link is not self._head: link[0][1] = link[1] # type: ignore[index] link[1][0] = link[0] # type: ignore[index] - head = self.head + head = self._head link[0] = head[0] link[1] = head - self.head = head[0][1] = head[0] = link # type: ignore[index] + self._head = head[0][1] = head[0] = link # type: ignore[index] return link[3] # type: ignore[return-value] def __contains__(self, key: CacheKey) -> bool: - return key in self.cache + return key in self._cache diff --git a/tests/test_cache.py b/tests/test_cache.py index 83edcfbe1..305060ae4 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,5 +1,7 @@ from __future__ import unicode_literals +import pytest + from textual._cache import LRUCache @@ -57,3 +59,70 @@ def test_lru_cache_get(): # Check it kicked out the 'oldest' key assert "egg" not in cache assert "eggegg" in cache + + +def test_lru_cache_mapping(): + """Test cache values can be set and read back.""" + cache = LRUCache(3) + cache["foo"] = 1 + cache.set("bar", 2) + cache.set("baz", 3) + assert cache["foo"] == 1 + assert cache["bar"] == 2 + assert cache.get("baz") == 3 + + +def test_lru_cache_clear(): + cache = LRUCache(3) + assert len(cache) == 0 + cache["foo"] = 1 + assert "foo" in cache + assert len(cache) == 1 + cache.clear() + assert "foo" not in cache + assert len(cache) == 0 + + +def test_lru_cache_bool(): + cache = LRUCache(3) + assert not cache + cache["foo"] = "bar" + assert cache + + +@pytest.mark.parametrize( + "keys,expected", + [ + ((), ()), + (("foo",), ("foo",)), + (("foo", "bar"), ("foo", "bar")), + (("foo", "bar", "baz"), ("foo", "bar", "baz")), + (("foo", "bar", "baz", "egg"), ("bar", "baz", "egg")), + (("foo", "bar", "baz", "egg", "bob"), ("baz", "egg", "bob")), + ], +) +def test_lru_cache_evicts(keys: list[str], expected: list[str]): + """Test adding adding additional values evicts oldest key""" + cache = LRUCache(3) + for value, key in enumerate(keys): + cache[key] = value + assert tuple(cache.keys()) == expected + + +@pytest.mark.parametrize( + "keys,expected_len", + [ + ((), 0), + (("foo",), 1), + (("foo", "bar"), 2), + (("foo", "bar", "baz"), 3), + (("foo", "bar", "baz", "egg"), 3), + (("foo", "bar", "baz", "egg", "bob"), 3), + ], +) +def test_lru_cache_len(keys: list[str], expected_len: int): + """Test adding adding additional values evicts oldest key""" + cache = LRUCache(3) + for value, key in enumerate(keys): + cache[key] = value + assert len(cache) == expected_len