Files
textual/tests/suggester/test_suggester.py
Rodrigo Girão Serrão fc86682dfa Add tests for suggester.
2023-05-25 17:32:26 +01:00

112 lines
3.1 KiB
Python

from __future__ import annotations
import pytest
from textual.dom import DOMNode
from textual.suggester import Suggester, SuggestionReady
class FillSuggester(Suggester):
    """Suggester that right-pads short values with ``x`` to 10 characters.

    Values longer than 10 characters get no suggestion (``None``), which lets
    tests exercise the "nothing to suggest" path.
    """

    async def get_suggestion(self, value: str) -> str | None:
        # Anything wider than the fill width has nothing to complete.
        if len(value) > 10:
            return None
        return value.ljust(10, "x")
class LogListNode(DOMNode):
    """Stand-in DOM node that records ``SuggestionReady`` messages in a list.

    ``DOMNode.__init__`` is not called, and ``post_message`` is overridden, so
    messages can be inspected without creating a full app.
    """

    def __init__(self, log_list: list[tuple[str, str]]) -> None:
        # The list is owned by the test, which inspects it after each lookup.
        self.log_list = log_list

    def post_message(self, message: SuggestionReady):
        # Intercept instead of queueing: store (suggestion, value) so tests can
        # compare against plain tuples.
        entry = (message.suggestion, message.value)
        self.log_list.append(entry)
async def test_cache_on():
    """With caching enabled, a repeated value reaches get_suggestion only once."""
    calls = []

    class MySuggester(Suggester):
        async def get_suggestion(self, value: str):
            calls.append(value)
            return value

    suggester = MySuggester(use_cache=True)

    # First lookup misses the cache and is computed.
    await suggester._get_suggestion(DOMNode(), "hello")
    assert calls == ["hello"]

    # Second lookup of the same value must not call get_suggestion again.
    await suggester._get_suggestion(DOMNode(), "hello")
    assert calls == ["hello"]
async def test_cache_off():
    """With caching disabled, every lookup reaches get_suggestion."""
    calls = []

    class MySuggester(Suggester):
        async def get_suggestion(self, value: str):
            calls.append(value)
            return value

    suggester = MySuggester(use_cache=False)

    await suggester._get_suggestion(DOMNode(), "hello")
    assert calls == ["hello"]

    # Same value again: without a cache it is recomputed.
    await suggester._get_suggestion(DOMNode(), "hello")
    assert calls == ["hello", "hello"]
async def test_suggestion_ready_message():
    """Each suggestion is delivered to the requester as a SuggestionReady."""
    received = []
    suggester = FillSuggester()

    await suggester._get_suggestion(LogListNode(received), "hello")
    assert received == [("helloxxxxx", "hello")]

    await suggester._get_suggestion(LogListNode(received), "world")
    assert received == [
        ("helloxxxxx", "hello"),
        ("worldxxxxx", "world"),
    ]
async def test_no_message_if_no_suggestion():
    """A ``None`` suggestion must not post any message to the requester."""
    received = []
    # Longer than FillSuggester's 10-character width, so it yields None.
    too_long = "this is a longer string"
    await FillSuggester()._get_suggestion(LogListNode(received), too_long)
    assert not received
async def test_suggestion_ready_message_on_cache_hit():
    """A suggestion served from the cache is still delivered as a message.

    Calls to ``get_suggestion`` are counted so the test confirms the second
    lookup really was a cache hit. Without that check, a broken cache would
    recompute the suggestion, produce the same two messages, and still pass.
    """
    calls = []

    class CountingFillSuggester(FillSuggester):
        async def get_suggestion(self, value: str):
            calls.append(value)
            return await super().get_suggestion(value)

    log = []
    suggester = CountingFillSuggester(use_cache=True)

    await suggester._get_suggestion(LogListNode(log), "hello")
    assert log == [("helloxxxxx", "hello")]
    assert calls == ["hello"]

    await suggester._get_suggestion(LogListNode(log), "hello")
    assert log == [("helloxxxxx", "hello"), ("helloxxxxx", "hello")]
    # get_suggestion was not reached again, so the second message came
    # from the cache.
    assert calls == ["hello"]
@pytest.mark.parametrize(
    "value",
    [
        "hello",
        "HELLO",
        "HeLlO",
        "Hello",
        "hELLO",
    ],
)
async def test_case_insensitive_suggestions(value):
    """With ``case_sensitive=False``, get_suggestion sees the lower-cased value.

    Received values are recorded and checked after the call. An assertion
    placed inside the callback would let the test pass vacuously if
    ``get_suggestion`` were never invoked.
    """
    received = []

    class MySuggester(Suggester):
        async def get_suggestion(self, value: str):
            received.append(value)
            return None

    suggester = MySuggester(use_cache=False, case_sensitive=False)
    await suggester._get_suggestion(DOMNode(), value)
    assert received == ["hello"]
async def test_case_insensitive_cache_hits():
    """Differently-cased spellings of one word share a single cache entry."""
    calls = 0

    class MySuggester(Suggester):
        async def get_suggestion(self, value: str):
            nonlocal calls
            calls += 1
            return f"{value}abc"

    suggester = MySuggester(use_cache=True, case_sensitive=False)
    for spelling in ("hello", "HELLO", "HeLlO", "Hello", "hELLO"):
        await suggester._get_suggestion(DOMNode(), spelling)

    # Only the first spelling should miss the cache.
    assert calls == 1