From edd1c28b794e8b358eef5c96e1cf098c9f004682 Mon Sep 17 00:00:00 2001 From: Will McGugan Date: Wed, 22 Feb 2023 13:41:05 +0000 Subject: [PATCH] restore css bases --- src/textual/css/stylesheet.py | 2 +- src/textual/dom.py | 27 ++++++++++++--------------- tests/test_dom.py | 26 +++++++++++++------------- 3 files changed, 26 insertions(+), 29 deletions(-) diff --git a/src/textual/css/stylesheet.py b/src/textual/css/stylesheet.py index d1c886923..492e231d2 100644 --- a/src/textual/css/stylesheet.py +++ b/src/textual/css/stylesheet.py @@ -406,7 +406,7 @@ class Stylesheet: ) self.replace_rules(node, node_rules, animate=animate) - component_classes = DOMNode._get_component_classes(type(node)) + component_classes = node._get_component_classes() if component_classes: old_component_styles = node._component_styles.copy() node._component_styles.clear() diff --git a/src/textual/dom.py b/src/textual/dom.py index c0b5596c4..e68939b39 100644 --- a/src/textual/dom.py +++ b/src/textual/dom.py @@ -1,12 +1,12 @@ from __future__ import annotations import re -from functools import lru_cache from inspect import getfile from typing import ( TYPE_CHECKING, ClassVar, Iterable, + Iterator, Sequence, Type, TypeVar, @@ -219,18 +219,18 @@ class DOMNode(MessagePump): styles = self._component_styles[name] return styles - def _post_mount(self) -> None: + def _post_mount(self): """Called after the object has been mounted.""" Reactive._initialize_object(self) @property - def _node_bases(self) -> Iterable[Type[DOMNode]]: + def _node_bases(self) -> Iterator[Type[DOMNode]]: """The DOMNode bases classes (including self.__class__)""" # Node bases are in reversed order so that the base class is lower priority return self._css_bases(self.__class__) - @staticmethod - def _css_bases(base: Type[DOMNode]) -> Iterable[Type[DOMNode]]: + @classmethod + def _css_bases(cls, base: Type[DOMNode]) -> Iterator[Type[DOMNode]]: """Get the DOMNode base classes, which inherit CSS. Args: @@ -240,10 +240,8 @@ class DOMNode(MessagePump): An iterable of DOMNode classes. """ _class = base - node_classes: list[Type[DOMNode]] = [] - add_class = node_classes.append while True: - add_class(_class) + yield _class if not _class._inherit_css: break for _base in _class.__bases__: @@ -252,7 +250,6 @@ class DOMNode(MessagePump): break else: break - return node_classes @classmethod def _merge_bindings(cls) -> Bindings: @@ -317,8 +314,7 @@ class DOMNode(MessagePump): return css_stack - @staticmethod - def _get_component_classes(node_type: type[DOMNode]) -> Iterable[str]: + def _get_component_classes(self) -> set[str]: """Gets the component classes for this class and inherited from bases. Component classes are inherited from base classes, unless @@ -327,13 +323,14 @@ class DOMNode(MessagePump): Returns: A set with all the component classes available. """ - component_classes: list[str] = [] - for base in DOMNode._css_bases(node_type): - component_classes.extend(base.__dict__.get("COMPONENT_CLASSES", ())) + + component_classes: set[str] = set() + for base in self._node_bases: + component_classes.update(base.__dict__.get("COMPONENT_CLASSES", set())) if not base.__dict__.get("_inherit_component_classes", True): break - return sorted(component_classes) + return component_classes @property def parent(self) -> DOMNode | None: diff --git a/tests/test_dom.py b/tests/test_dom.py index b88b4052c..f366d0706 100644 --- a/tests/test_dom.py +++ b/tests/test_dom.py @@ -143,27 +143,27 @@ def test_component_classes_inheritance(): COMPONENT_CLASSES = {"f-1"} node = DOMNode() - node_cc = DOMNode._get_component_classes(type(node)) + node_cc = node._get_component_classes() a = A() - a_cc = DOMNode._get_component_classes(type(a)) + a_cc = a._get_component_classes() b = B() - b_cc = DOMNode._get_component_classes(type(b)) + b_cc = b._get_component_classes() c = C() - c_cc = DOMNode._get_component_classes(type(c)) + c_cc = c._get_component_classes() d = D() - d_cc = DOMNode._get_component_classes(type(d)) + d_cc = d._get_component_classes() e = E() - e_cc = DOMNode._get_component_classes(type(e)) + e_cc = e._get_component_classes() f = F() - f_cc = DOMNode._get_component_classes(type(f)) + f_cc = f._get_component_classes() - assert node_cc == [] - assert a_cc == ["a-1", "a-2"] - assert b_cc == ["b-1"] - assert c_cc == ["b-1", "c-1", "c-2"] + assert node_cc == set() + assert a_cc == {"a-1", "a-2"} + assert b_cc == {"b-1"} + assert c_cc == {"b-1", "c-1", "c-2"} assert d_cc == c_cc - assert e_cc == ["b-1", "c-1", "c-2", "e-1"] - assert f_cc == ["f-1"] + assert e_cc == {"b-1", "c-1", "c-2", "e-1"} + assert f_cc == {"f-1"} @pytest.fixture