mirror of
https://github.com/Textualize/textual.git
synced 2025-10-17 02:38:12 +03:00
added some more tests
This commit is contained in:
@@ -15,7 +15,7 @@ where the overhead of the cache is a small fraction of the total processing time
|
||||
from __future__ import annotations
|
||||
|
||||
from threading import Lock
|
||||
from typing import Dict, Generic, TypeVar, overload
|
||||
from typing import Dict, Generic, KeysView, TypeVar, overload
|
||||
|
||||
CacheKey = TypeVar("CacheKey")
|
||||
CacheValue = TypeVar("CacheValue")
|
||||
@@ -38,22 +38,30 @@ class LRUCache(Generic[CacheKey, CacheValue]):
|
||||
"""
|
||||
|
||||
def __init__(self, maxsize: int) -> None:
|
||||
self.maxsize = maxsize
|
||||
self.cache: Dict[CacheKey, list[object]] = {}
|
||||
self.full = False
|
||||
self.head: list[object] = []
|
||||
self._maxsize = maxsize
|
||||
self._cache: Dict[CacheKey, list[object]] = {}
|
||||
self._full = False
|
||||
self._head: list[object] = []
|
||||
self._lock = Lock()
|
||||
super().__init__()
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self._cache)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.cache)
|
||||
return len(self._cache)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear the cache."""
|
||||
with self._lock:
|
||||
self.cache.clear()
|
||||
self.full = False
|
||||
self.head = []
|
||||
self._cache.clear()
|
||||
self._full = False
|
||||
self._head = []
|
||||
|
||||
def keys(self) -> KeysView[CacheKey]:
|
||||
"""Get cache keys."""
|
||||
# Mostly for tests
|
||||
return self._cache.keys()
|
||||
|
||||
def set(self, key: CacheKey, value: CacheValue) -> None:
|
||||
"""Set a value.
|
||||
@@ -63,28 +71,28 @@ class LRUCache(Generic[CacheKey, CacheValue]):
|
||||
value (CacheValue): Value.
|
||||
"""
|
||||
with self._lock:
|
||||
link = self.cache.get(key)
|
||||
link = self._cache.get(key)
|
||||
if link is None:
|
||||
head = self.head
|
||||
head = self._head
|
||||
if not head:
|
||||
# First link references itself
|
||||
self.head[:] = [head, head, key, value]
|
||||
self._head[:] = [head, head, key, value]
|
||||
else:
|
||||
# Add a new root to the beginning
|
||||
self.head = [head[0], head, key, value]
|
||||
self._head = [head[0], head, key, value]
|
||||
# Updated references on previous root
|
||||
head[0][1] = self.head # type: ignore[index]
|
||||
head[0] = self.head
|
||||
self.cache[key] = self.head
|
||||
head[0][1] = self._head # type: ignore[index]
|
||||
head[0] = self._head
|
||||
self._cache[key] = self._head
|
||||
|
||||
if self.full or len(self.cache) > self.maxsize:
|
||||
if self._full or len(self._cache) > self._maxsize:
|
||||
# Cache is full, we need to evict the oldest one
|
||||
self.full = True
|
||||
head = self.head
|
||||
self._full = True
|
||||
head = self._head
|
||||
last = head[0]
|
||||
last[0][1] = head # type: ignore[index]
|
||||
head[0] = last[0] # type: ignore[index]
|
||||
del self.cache[last[2]] # type: ignore[index]
|
||||
del self._cache[last[2]] # type: ignore[index]
|
||||
|
||||
__setitem__ = set
|
||||
|
||||
@@ -108,33 +116,33 @@ class LRUCache(Generic[CacheKey, CacheValue]):
|
||||
Returns:
|
||||
Union[CacheValue, Optional[DefaultValue]]: Either the value or a default.
|
||||
"""
|
||||
link = self.cache.get(key)
|
||||
link = self._cache.get(key)
|
||||
if link is None:
|
||||
return default
|
||||
with self._lock:
|
||||
if link is not self.head:
|
||||
if link is not self._head:
|
||||
# Remove link from list
|
||||
link[0][1] = link[1] # type: ignore[index]
|
||||
link[1][0] = link[0] # type: ignore[index]
|
||||
head = self.head
|
||||
head = self._head
|
||||
# Move link to head of list
|
||||
link[0] = head[0]
|
||||
link[1] = head
|
||||
self.head = head[0][1] = head[0] = link # type: ignore[index]
|
||||
self._head = head[0][1] = head[0] = link # type: ignore[index]
|
||||
|
||||
return link[3] # type: ignore[return-value]
|
||||
|
||||
def __getitem__(self, key: CacheKey) -> CacheValue:
|
||||
link = self.cache[key]
|
||||
link = self._cache[key]
|
||||
with self._lock:
|
||||
if link is not self.head:
|
||||
if link is not self._head:
|
||||
link[0][1] = link[1] # type: ignore[index]
|
||||
link[1][0] = link[0] # type: ignore[index]
|
||||
head = self.head
|
||||
head = self._head
|
||||
link[0] = head[0]
|
||||
link[1] = head
|
||||
self.head = head[0][1] = head[0] = link # type: ignore[index]
|
||||
self._head = head[0][1] = head[0] = link # type: ignore[index]
|
||||
return link[3] # type: ignore[return-value]
|
||||
|
||||
def __contains__(self, key: CacheKey) -> bool:
|
||||
return key in self.cache
|
||||
return key in self._cache
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import pytest
|
||||
|
||||
from textual._cache import LRUCache
|
||||
|
||||
|
||||
@@ -57,3 +59,70 @@ def test_lru_cache_get():
|
||||
# Check it kicked out the 'oldest' key
|
||||
assert "egg" not in cache
|
||||
assert "eggegg" in cache
|
||||
|
||||
|
||||
def test_lru_cache_mapping():
|
||||
"""Test cache values can be set and read back."""
|
||||
cache = LRUCache(3)
|
||||
cache["foo"] = 1
|
||||
cache.set("bar", 2)
|
||||
cache.set("baz", 3)
|
||||
assert cache["foo"] == 1
|
||||
assert cache["bar"] == 2
|
||||
assert cache.get("baz") == 3
|
||||
|
||||
|
||||
def test_lru_cache_clear():
|
||||
cache = LRUCache(3)
|
||||
assert len(cache) == 0
|
||||
cache["foo"] = 1
|
||||
assert "foo" in cache
|
||||
assert len(cache) == 1
|
||||
cache.clear()
|
||||
assert "foo" not in cache
|
||||
assert len(cache) == 0
|
||||
|
||||
|
||||
def test_lru_cache_bool():
|
||||
cache = LRUCache(3)
|
||||
assert not cache
|
||||
cache["foo"] = "bar"
|
||||
assert cache
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"keys,expected",
|
||||
[
|
||||
((), ()),
|
||||
(("foo",), ("foo",)),
|
||||
(("foo", "bar"), ("foo", "bar")),
|
||||
(("foo", "bar", "baz"), ("foo", "bar", "baz")),
|
||||
(("foo", "bar", "baz", "egg"), ("bar", "baz", "egg")),
|
||||
(("foo", "bar", "baz", "egg", "bob"), ("baz", "egg", "bob")),
|
||||
],
|
||||
)
|
||||
def test_lru_cache_evicts(keys: list[str], expected: list[str]):
|
||||
"""Test adding adding additional values evicts oldest key"""
|
||||
cache = LRUCache(3)
|
||||
for value, key in enumerate(keys):
|
||||
cache[key] = value
|
||||
assert tuple(cache.keys()) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"keys,expected_len",
|
||||
[
|
||||
((), 0),
|
||||
(("foo",), 1),
|
||||
(("foo", "bar"), 2),
|
||||
(("foo", "bar", "baz"), 3),
|
||||
(("foo", "bar", "baz", "egg"), 3),
|
||||
(("foo", "bar", "baz", "egg", "bob"), 3),
|
||||
],
|
||||
)
|
||||
def test_lru_cache_len(keys: list[str], expected_len: int):
|
||||
"""Test adding adding additional values evicts oldest key"""
|
||||
cache = LRUCache(3)
|
||||
for value, key in enumerate(keys):
|
||||
cache[key] = value
|
||||
assert len(cache) == expected_len
|
||||
|
||||
Reference in New Issue
Block a user