added some more tests

commit ad0507cb27 (parent 637b916ce7)
Author: Will McGugan
Date:   2022-06-21 14:07:27 +01:00

2 changed files with 106 additions and 29 deletions

textual/_cache.py

@@ -15,7 +15,7 @@ where the overhead of the cache is a small fraction of the total processing time
 from __future__ import annotations
 
 from threading import Lock
-from typing import Dict, Generic, TypeVar, overload
+from typing import Dict, Generic, KeysView, TypeVar, overload
 
 CacheKey = TypeVar("CacheKey")
 CacheValue = TypeVar("CacheValue")
@@ -38,22 +38,30 @@ class LRUCache(Generic[CacheKey, CacheValue]):
     """
 
     def __init__(self, maxsize: int) -> None:
-        self.maxsize = maxsize
-        self.cache: Dict[CacheKey, list[object]] = {}
-        self.full = False
-        self.head: list[object] = []
+        self._maxsize = maxsize
+        self._cache: Dict[CacheKey, list[object]] = {}
+        self._full = False
+        self._head: list[object] = []
         self._lock = Lock()
         super().__init__()
 
+    def __bool__(self) -> bool:
+        return bool(self._cache)
+
     def __len__(self) -> int:
-        return len(self.cache)
+        return len(self._cache)
 
     def clear(self) -> None:
        """Clear the cache."""
         with self._lock:
-            self.cache.clear()
-            self.full = False
-            self.head = []
+            self._cache.clear()
+            self._full = False
+            self._head = []
+
+    def keys(self) -> KeysView[CacheKey]:
+        """Get cache keys."""
+        # Mostly for tests
+        return self._cache.keys()
 
     def set(self, key: CacheKey, value: CacheValue) -> None:
         """Set a value.
@@ -63,28 +71,28 @@ class LRUCache(Generic[CacheKey, CacheValue]):
             value (CacheValue): Value.
         """
         with self._lock:
-            link = self.cache.get(key)
+            link = self._cache.get(key)
             if link is None:
-                head = self.head
+                head = self._head
                 if not head:
                     # First link references itself
-                    self.head[:] = [head, head, key, value]
+                    self._head[:] = [head, head, key, value]
                 else:
                     # Add a new root to the beginning
-                    self.head = [head[0], head, key, value]
+                    self._head = [head[0], head, key, value]
                     # Update references on the previous root
-                    head[0][1] = self.head  # type: ignore[index]
-                    head[0] = self.head
-                self.cache[key] = self.head
+                    head[0][1] = self._head  # type: ignore[index]
+                    head[0] = self._head
+                self._cache[key] = self._head
 
-                if self.full or len(self.cache) > self.maxsize:
+                if self._full or len(self._cache) > self._maxsize:
                     # Cache is full; we need to evict the oldest entry
-                    self.full = True
-                    head = self.head
+                    self._full = True
+                    head = self._head
                     last = head[0]
                     last[0][1] = head  # type: ignore[index]
                     head[0] = last[0]  # type: ignore[index]
-                    del self.cache[last[2]]  # type: ignore[index]
+                    del self._cache[last[2]]  # type: ignore[index]
 
     __setitem__ = set
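
Aside on the pointer juggling above: each cache entry is a plain four-element list of the form [prev, next, key, value], and _head is the most recently used link in a circular chain, so head[0] (the link before the head) is always the least recently used entry and therefore the eviction target. A minimal standalone sketch of the same structure, using illustrative names that are not part of the commit:

# Stripped-down sketch of the circular [prev, next, key, value] links used
# above. PREV/NEXT/KEY/VALUE are illustrative constants, not textual's API.
PREV, NEXT, KEY, VALUE = 0, 1, 2, 3

head: list = []
head[:] = [head, head, "a", 1]     # first link references itself

link = [head[PREV], head, "b", 2]  # splice a new link in before the head
head[PREV][NEXT] = link
head[PREV] = link
head = link                        # "b" is now the most recently used

assert head[PREV][KEY] == "a"      # the tail is the least recently used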
@@ -108,33 +116,33 @@ class LRUCache(Generic[CacheKey, CacheValue]):
         Returns:
             Union[CacheValue, Optional[DefaultValue]]: Either the value or a default.
         """
-        link = self.cache.get(key)
+        link = self._cache.get(key)
         if link is None:
             return default
         with self._lock:
-            if link is not self.head:
+            if link is not self._head:
                 # Remove link from list
                 link[0][1] = link[1]  # type: ignore[index]
                 link[1][0] = link[0]  # type: ignore[index]
-                head = self.head
+                head = self._head
                 # Move link to head of list
                 link[0] = head[0]
                 link[1] = head
-                self.head = head[0][1] = head[0] = link  # type: ignore[index]
+                self._head = head[0][1] = head[0] = link  # type: ignore[index]
         return link[3]  # type: ignore[return-value]
 
     def __getitem__(self, key: CacheKey) -> CacheValue:
-        link = self.cache[key]
+        link = self._cache[key]
         with self._lock:
-            if link is not self.head:
+            if link is not self._head:
                 link[0][1] = link[1]  # type: ignore[index]
                 link[1][0] = link[0]  # type: ignore[index]
-                head = self.head
+                head = self._head
                 link[0] = head[0]
                 link[1] = head
-                self.head = head[0][1] = head[0] = link  # type: ignore[index]
+                self._head = head[0][1] = head[0] = link  # type: ignore[index]
         return link[3]  # type: ignore[return-value]
 
     def __contains__(self, key: CacheKey) -> bool:
-        return key in self.cache
+        return key in self._cache
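
Taken together, the rename leaves LRUCache with a small public surface: set/__setitem__, get/__getitem__, __contains__, __len__, __bool__, keys, and clear. A quick usage sketch of that API, with behaviour as implied by the diff (the import path follows the test file below):

from textual._cache import LRUCache

cache: LRUCache[str, int] = LRUCache(maxsize=2)
cache["foo"] = 1                   # __setitem__ is an alias for set()
cache.set("bar", 2)
assert cache["foo"] == 1           # reading "foo" makes it most recently used
cache["baz"] = 3                   # evicts "bar", the least recently used key
assert "bar" not in cache
assert cache.get("bar", -1) == -1  # get() accepts an optional default
assert tuple(cache.keys()) == ("foo", "baz")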

tests/test_cache.py

@@ -1,5 +1,7 @@
 from __future__ import unicode_literals
 
+import pytest
+
 from textual._cache import LRUCache
 
 
@@ -57,3 +59,70 @@ def test_lru_cache_get():
     # Check it kicked out the 'oldest' key
     assert "egg" not in cache
     assert "eggegg" in cache
+
+
+def test_lru_cache_mapping():
+    """Test cache values can be set and read back."""
+    cache = LRUCache(3)
+    cache["foo"] = 1
+    cache.set("bar", 2)
+    cache.set("baz", 3)
+    assert cache["foo"] == 1
+    assert cache["bar"] == 2
+    assert cache.get("baz") == 3
+
+
+def test_lru_cache_clear():
+    cache = LRUCache(3)
+    assert len(cache) == 0
+    cache["foo"] = 1
+    assert "foo" in cache
+    assert len(cache) == 1
+    cache.clear()
+    assert "foo" not in cache
+    assert len(cache) == 0
+
+
+def test_lru_cache_bool():
+    cache = LRUCache(3)
+    assert not cache
+    cache["foo"] = "bar"
+    assert cache
+
+
+@pytest.mark.parametrize(
+    "keys,expected",
+    [
+        ((), ()),
+        (("foo",), ("foo",)),
+        (("foo", "bar"), ("foo", "bar")),
+        (("foo", "bar", "baz"), ("foo", "bar", "baz")),
+        (("foo", "bar", "baz", "egg"), ("bar", "baz", "egg")),
+        (("foo", "bar", "baz", "egg", "bob"), ("baz", "egg", "bob")),
+    ],
+)
+def test_lru_cache_evicts(keys: "tuple[str, ...]", expected: "tuple[str, ...]"):
+    """Test that adding additional values evicts the oldest key."""
+    cache = LRUCache(3)
+    for value, key in enumerate(keys):
+        cache[key] = value
+    assert tuple(cache.keys()) == expected
+
+
+@pytest.mark.parametrize(
+    "keys,expected_len",
+    [
+        ((), 0),
+        (("foo",), 1),
+        (("foo", "bar"), 2),
+        (("foo", "bar", "baz"), 3),
+        (("foo", "bar", "baz", "egg"), 3),
+        (("foo", "bar", "baz", "egg", "bob"), 3),
+    ],
+)
+def test_lru_cache_len(keys: "tuple[str, ...]", expected_len: int):
+    """Test that the cache length never exceeds maxsize."""
+    cache = LRUCache(3)
+    for value, key in enumerate(keys):
+        cache[key] = value
+    assert len(cache) == expected_len
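
One subtlety these tests do not yet pin down: get() reorders the recency list but not the underlying dict, so it changes which key is evicted next while keys() keeps plain insertion order. A possible follow-up test along those lines (hypothetical, not part of this commit):

def test_lru_cache_get_protects_key():
    """get() should promote a key so it is not the next eviction."""
    cache = LRUCache(3)
    cache["foo"] = 1
    cache["bar"] = 2
    cache["baz"] = 3
    cache.get("foo")  # "foo" becomes most recently used
    cache["egg"] = 4  # evicts "bar", now the least recently used
    assert "foo" in cache
    assert tuple(cache.keys()) == ("foo", "baz", "egg")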