diff --git a/src/textual/dom.py b/src/textual/dom.py index e68939b39..56b9e8770 100644 --- a/src/textual/dom.py +++ b/src/textual/dom.py @@ -1,12 +1,12 @@ from __future__ import annotations import re +from functools import lru_cache from inspect import getfile from typing import ( TYPE_CHECKING, ClassVar, Iterable, - Iterator, Sequence, Type, TypeVar, @@ -224,13 +224,14 @@ class DOMNode(MessagePump): Reactive._initialize_object(self) @property - def _node_bases(self) -> Iterator[Type[DOMNode]]: + def _node_bases(self) -> Sequence[Type[DOMNode]]: """The DOMNode bases classes (including self.__class__)""" # Node bases are in reversed order so that the base class is lower priority return self._css_bases(self.__class__) @classmethod - def _css_bases(cls, base: Type[DOMNode]) -> Iterator[Type[DOMNode]]: + @lru_cache(maxsize=None) + def _css_bases(cls, base: Type[DOMNode]) -> Sequence[Type[DOMNode]]: """Get the DOMNode base classes, which inherit CSS. Args: @@ -239,9 +240,10 @@ class DOMNode(MessagePump): Returns: An iterable of DOMNode classes. """ + classes: list[type[DOMNode]] = [] _class = base while True: - yield _class + classes.append(_class) if not _class._inherit_css: break for _base in _class.__bases__: @@ -250,6 +252,7 @@ class DOMNode(MessagePump): break else: break + return classes @classmethod def _merge_bindings(cls) -> Bindings: @@ -314,7 +317,9 @@ class DOMNode(MessagePump): return css_stack - def _get_component_classes(self) -> set[str]: + @classmethod + @lru_cache(maxsize=None) + def _get_component_classes(cls) -> frozenset[str]: """Gets the component classes for this class and inherited from bases. Component classes are inherited from base classes, unless @@ -325,12 +330,12 @@ class DOMNode(MessagePump): """ component_classes: set[str] = set() - for base in self._node_bases: + for base in cls._css_bases(cls): component_classes.update(base.__dict__.get("COMPONENT_CLASSES", set())) if not base.__dict__.get("_inherit_component_classes", True): break - return component_classes + return frozenset(component_classes) @property def parent(self) -> DOMNode | None: diff --git a/tests/test_dom.py b/tests/test_dom.py index f366d0706..b78a958c9 100644 --- a/tests/test_dom.py +++ b/tests/test_dom.py @@ -157,7 +157,7 @@ def test_component_classes_inheritance(): f = F() f_cc = f._get_component_classes() - assert node_cc == set() + assert node_cc == frozenset() assert a_cc == {"a-1", "a-2"} assert b_cc == {"b-1"} assert c_cc == {"b-1", "c-1", "c-2"}