Merge pull request #705 from Textualize/optimize-styles-apply

Optimize applying CSS styles
This commit is contained in:
Will McGugan
2022-08-26 15:05:03 +01:00
committed by GitHub
9 changed files with 186 additions and 41 deletions

View File

@@ -51,4 +51,3 @@ Button {
.started #reset { .started #reset {
visibility: hidden visibility: hidden
} }

View File

@@ -183,6 +183,9 @@ class Compositor:
# Note this may be a superset of self.map.keys() as some widgets may be invisible for various reasons # Note this may be a superset of self.map.keys() as some widgets may be invisible for various reasons
self.widgets: set[Widget] = set() self.widgets: set[Widget] = set()
# A lazy cache of visible (on screen) widgets
self._visible_widgets: set[Widget] | None = set()
# The top level widget # The top level widget
self.root: Widget | None = None self.root: Widget | None = None
@@ -269,6 +272,7 @@ class Compositor:
# Replace map and widgets # Replace map and widgets
self.map = map self.map = map
self.widgets = widgets self.widgets = widgets
self._visible_widgets = None
# Get a map of regions # Get a map of regions
self.regions = { self.regions = {
@@ -305,6 +309,22 @@ class Compositor:
resized=resized_widgets, resized=resized_widgets,
) )
@property
def visible_widgets(self) -> set[Widget]:
"""Get a set of visible widgets.
Returns:
set[Widget]: Widgets in the screen.
"""
if self._visible_widgets is None:
in_screen = self.size.region.__contains__
self._visible_widgets = {
widget
for widget, (region, clip) in self.regions.items()
if in_screen(region)
}
return self._visible_widgets
def _arrange_root( def _arrange_root(
self, root: Widget, size: Size self, root: Widget, size: Size
) -> tuple[CompositorMap, set[Widget]]: ) -> tuple[CompositorMap, set[Widget]]:

View File

@@ -224,7 +224,7 @@ class App(Generic[ReturnType], DOMNode):
self.design = DEFAULT_COLORS self.design = DEFAULT_COLORS
self.stylesheet = Stylesheet(variables=self.get_css_variables()) self.stylesheet = Stylesheet(variables=self.get_css_variables())
self._require_stylesheet_update = False self._require_stylesheet_update: set[DOMNode] = set()
self.css_path = css_path or self.CSS_PATH self.css_path = css_path or self.CSS_PATH
self._registry: WeakSet[DOMNode] = WeakSet() self._registry: WeakSet[DOMNode] = WeakSet()
@@ -714,13 +714,18 @@ class App(Generic[ReturnType], DOMNode):
""" """
return self.screen.get_child(id) return self.screen.get_child(id)
def update_styles(self) -> None: def update_styles(self, node: DOMNode | None = None) -> None:
"""Request update of styles. """Request update of styles.
Should be called whenever CSS classes / pseudo classes change. Should be called whenever CSS classes / pseudo classes change.
""" """
self._require_stylesheet_update = True self._require_stylesheet_update.add(self.screen if node is None else node)
self.check_idle()
def update_visible_styles(self) -> None:
"""Update visible styles only."""
self._require_stylesheet_update.update(self.screen.visible_widgets)
self.check_idle() self.check_idle()
def mount(self, *anon_widgets: Widget, **widgets: Widget) -> None: def mount(self, *anon_widgets: Widget, **widgets: Widget) -> None:
@@ -1137,16 +1142,21 @@ class App(Generic[ReturnType], DOMNode):
self.set_timer(screenshot_timer, on_screenshot, name="screenshot timer") self.set_timer(screenshot_timer, on_screenshot, name="screenshot timer")
def on_mount(self) -> None: def _on_mount(self) -> None:
widgets = self.compose() widgets = self.compose()
if widgets: if widgets:
self.mount_all(widgets) self.mount_all(widgets)
async def on_idle(self) -> None: def _on_idle(self) -> None:
"""Perform actions when there are no messages in the queue.""" """Perform actions when there are no messages in the queue."""
if self._require_stylesheet_update: if self._require_stylesheet_update:
self._require_stylesheet_update = False nodes: set[DOMNode] = {
self.stylesheet.update(self, animate=True) child
for node in self._require_stylesheet_update
for child in node.walk_children()
}
self._require_stylesheet_update.clear()
self.stylesheet.update_nodes(nodes, animate=True)
def _register_child(self, parent: DOMNode, child: Widget) -> bool: def _register_child(self, parent: DOMNode, child: Widget) -> bool:
if child not in self._registry: if child not in self._registry:

View File

@@ -159,9 +159,13 @@ class RuleSet:
selector_set: list[SelectorSet] = field(default_factory=list) selector_set: list[SelectorSet] = field(default_factory=list)
styles: Styles = field(default_factory=Styles) styles: Styles = field(default_factory=Styles)
errors: list[tuple[Token, str]] = field(default_factory=list) errors: list[tuple[Token, str]] = field(default_factory=list)
classes: set[str] = field(default_factory=set)
is_default_rules: bool = False is_default_rules: bool = False
tie_breaker: int = 0 tie_breaker: int = 0
selector_names: set[str] = field(default_factory=set)
def __hash__(self):
return id(self)
@classmethod @classmethod
def _selector_to_css(cls, selectors: list[Selector]) -> str: def _selector_to_css(cls, selectors: list[Selector]) -> str:
@@ -195,11 +199,37 @@ class RuleSet:
def _post_parse(self) -> None: def _post_parse(self) -> None:
"""Called after the RuleSet is parsed.""" """Called after the RuleSet is parsed."""
# Build a set of the class names that have been updated # Build a set of the class names that have been updated
update = self.classes.update
class_type = SelectorType.CLASS class_type = SelectorType.CLASS
id_type = SelectorType.ID
type_type = SelectorType.TYPE
universal_type = SelectorType.UNIVERSAL
update_selectors = self.selector_names.update
for selector_set in self.selector_set: for selector_set in self.selector_set:
update( update_selectors(
"*"
for selector in selector_set.selectors
if selector.type == universal_type
)
update_selectors(
selector.name selector.name
for selector in selector_set.selectors for selector in selector_set.selectors
if selector.type == type_type
)
update_selectors(
f".{selector.name}"
for selector in selector_set.selectors
if selector.type == class_type if selector.type == class_type
) )
update_selectors(
f"#{selector.name}"
for selector in selector_set.selectors
if selector.type == id_type
)
update_selectors(
f":{pseudo_class}"
for selector in selector_set.selectors
for pseudo_class in selector.pseudo_classes
)

View File

@@ -4,10 +4,10 @@ import os
from collections import defaultdict from collections import defaultdict
from operator import itemgetter from operator import itemgetter
from pathlib import Path, PurePath from pathlib import Path, PurePath
from typing import cast, Iterable, NamedTuple from typing import Iterable, NamedTuple, cast
import rich.repr import rich.repr
from rich.console import RenderableType, RenderResult, Console, ConsoleOptions from rich.console import Console, ConsoleOptions, RenderableType, RenderResult
from rich.markup import render from rich.markup import render
from rich.padding import Padding from rich.padding import Padding
from rich.panel import Panel from rich.panel import Panel
@@ -15,17 +15,18 @@ from rich.style import Style
from rich.syntax import Syntax from rich.syntax import Syntax
from rich.text import Text from rich.text import Text
from textual.widget import Widget from .. import messages
from .._profile import timer
from ..dom import DOMNode
from ..widget import Widget
from .errors import StylesheetError from .errors import StylesheetError
from .match import _check_selectors from .match import _check_selectors
from .model import RuleSet from .model import RuleSet
from .parse import parse from .parse import parse
from .styles import RulesMap, Styles from .styles import RulesMap, Styles
from .tokenize import tokenize_values, Token from .tokenize import Token, tokenize_values
from .tokenizer import TokenError from .tokenizer import TokenError
from .types import Specificity3, Specificity6 from .types import Specificity3, Specificity6
from ..dom import DOMNode
from .. import messages
class StylesheetParseError(StylesheetError): class StylesheetParseError(StylesheetError):
@@ -135,6 +136,7 @@ class CssSource(NamedTuple):
class Stylesheet: class Stylesheet:
def __init__(self, *, variables: dict[str, str] | None = None) -> None: def __init__(self, *, variables: dict[str, str] | None = None) -> None:
self._rules: list[RuleSet] = [] self._rules: list[RuleSet] = []
self._rules_map: dict[str, list[RuleSet]] | None = None
self.variables = variables or {} self.variables = variables or {}
self.source: dict[str, CssSource] = {} self.source: dict[str, CssSource] = {}
self._require_parse = False self._require_parse = False
@@ -144,12 +146,32 @@ class Stylesheet:
@property @property
def rules(self) -> list[RuleSet]: def rules(self) -> list[RuleSet]:
"""List of rule sets.
Returns:
list[RuleSet]: List of rules sets for this stylesheet.
"""
if self._require_parse: if self._require_parse:
self.parse() self.parse()
self._require_parse = False self._require_parse = False
assert self._rules is not None assert self._rules is not None
return self._rules return self._rules
@property
def rules_map(self) -> dict[str, list[RuleSet]]:
"""Structure that maps a selector on to a list of rules.
Returns:
dict[str, list[RuleSet]]: Mapping of selector to rule sets.
"""
if self._rules_map is None:
rules_map: dict[str, list[RuleSet]] = defaultdict(list)
for rule in self.rules:
for name in rule.selector_names:
rules_map[name].append(rule)
self._rules_map = dict(rules_map)
return self._rules_map
@property @property
def css(self) -> str: def css(self) -> str:
return "\n\n".join(rule_set.css for rule_set in self.rules) return "\n\n".join(rule_set.css for rule_set in self.rules)
@@ -283,6 +305,7 @@ class Stylesheet:
add_rules(css_rules) add_rules(css_rules)
self._rules = rules self._rules = rules
self._require_parse = False self._require_parse = False
self._rules_map = None
def reparse(self) -> None: def reparse(self) -> None:
"""Re-parse source, applying new variables. """Re-parse source, applying new variables.
@@ -300,15 +323,24 @@ class Stylesheet:
) )
stylesheet.parse() stylesheet.parse()
self._rules = stylesheet.rules self._rules = stylesheet.rules
self._rules_map = None
self.source = stylesheet.source self.source = stylesheet.source
@classmethod @classmethod
def _check_rule(cls, rule: RuleSet, node: DOMNode) -> Iterable[Specificity3]: def _check_rule(
cls, rule: RuleSet, css_path_nodes: list[DOMNode]
) -> Iterable[Specificity3]:
for selector_set in rule.selector_set: for selector_set in rule.selector_set:
if _check_selectors(selector_set.selectors, node.css_path_nodes): if _check_selectors(selector_set.selectors, css_path_nodes):
yield selector_set.specificity yield selector_set.specificity
def apply(self, node: DOMNode, animate: bool = False) -> None: def apply(
self,
node: DOMNode,
*,
limit_rules: set[RuleSet] | None = None,
animate: bool = False,
) -> None:
"""Apply the stylesheet to a DOM node. """Apply the stylesheet to a DOM node.
Args: Args:
@@ -319,32 +351,34 @@ class Stylesheet:
rule will be applied. rule will be applied.
animate (bool, optional): Animate changed rules. Defaults to ``False``. animate (bool, optional): Animate changed rules. Defaults to ``False``.
""" """
# TODO: Need to optimize to make applying stylesheet more efficient
# I think we can pre-calculate which rules may be applicable to a given node
# Dictionary of rule attribute names e.g. "text_background" to list of tuples. # Dictionary of rule attribute names e.g. "text_background" to list of tuples.
# The tuples contain the rule specificity, and the value for that rule. # The tuples contain the rule specificity, and the value for that rule.
# We can use this to determine, for a given rule, whether we should apply it # We can use this to determine, for a given rule, whether we should apply it
# or not by examining the specificity. If we have two rules for the # or not by examining the specificity. If we have two rules for the
# same attribute, then we can choose the most specific rule and use that. # same attribute, then we can choose the most specific rule and use that.
rule_attributes: dict[str, list[tuple[Specificity6, object]]] rule_attributes: defaultdict[str, list[tuple[Specificity6, object]]]
rule_attributes = {} rule_attributes = defaultdict(list)
_check_rule = self._check_rule _check_rule = self._check_rule
css_path_nodes = node.css_path_nodes
rules: Iterable[RuleSet]
if limit_rules:
rules = [rule for rule in reversed(self.rules) if rule in limit_rules]
else:
rules = reversed(self.rules)
# Collect the rules defined in the stylesheet # Collect the rules defined in the stylesheet
for rule in reversed(self.rules): for rule in rules:
is_default_rules = rule.is_default_rules is_default_rules = rule.is_default_rules
tie_breaker = rule.tie_breaker tie_breaker = rule.tie_breaker
for base_specificity in _check_rule(rule, node): for base_specificity in _check_rule(rule, css_path_nodes):
for key, rule_specificity, value in rule.styles.extract_rules( for key, rule_specificity, value in rule.styles.extract_rules(
base_specificity, is_default_rules, tie_breaker base_specificity, is_default_rules, tie_breaker
): ):
rule_attributes.setdefault(key, []).append( rule_attributes[key].append((rule_specificity, value))
(rule_specificity, value)
)
if not rule_attributes:
return
# For each rule declared for this node, keep only the most specific one # For each rule declared for this node, keep only the most specific one
get_first_item = itemgetter(0) get_first_item = itemgetter(0)
node_rules: RulesMap = cast( node_rules: RulesMap = cast(
@@ -354,7 +388,6 @@ class Stylesheet:
for name, specificity_rules in rule_attributes.items() for name, specificity_rules in rule_attributes.items()
}, },
) )
self.replace_rules(node, node_rules, animate=animate) self.replace_rules(node, node_rules, animate=animate)
node._component_styles.clear() node._component_styles.clear()
@@ -381,7 +414,7 @@ class Stylesheet:
base_styles = styles.base base_styles = styles.base
# Styles currently used on new rules # Styles currently used on new rules
modified_rule_keys = {*base_styles.get_rules().keys(), *rules.keys()} modified_rule_keys = base_styles.get_rules().keys() | rules.keys()
# Current render rules (missing rules are filled with default) # Current render rules (missing rules are filled with default)
current_render_rules = styles.get_render_rules() current_render_rules = styles.get_render_rules()
@@ -434,10 +467,34 @@ class Stylesheet:
node.post_message_no_wait(messages.StylesUpdated(sender=node)) node.post_message_no_wait(messages.StylesUpdated(sender=node))
def update(self, root: DOMNode, animate: bool = False) -> None: def update(self, root: DOMNode, animate: bool = False) -> None:
"""Update a node and its children.""" """Update styles on node and its children.
Args:
root (DOMNode): Root note to update.
animate (bool, optional): Enable CSS animation. Defaults to False.
"""
self.update_nodes(root.walk_children(), animate=animate)
def update_nodes(self, nodes: Iterable[DOMNode], animate: bool = False) -> None:
"""Update styles for nodes.
Args:
nodes (DOMNode): Nodes to update.
animate (bool, optional): Enable CSS animation. Defaults to False.
"""
rules_map = self.rules_map
apply = self.apply apply = self.apply
for node in root.walk_children():
apply(node, animate=animate) for node in nodes:
rules = {
rule
for name in node._selector_names
if name in rules_map
for rule in rules_map[name]
}
apply(node, limit_rules=rules, animate=animate)
if isinstance(node, Widget) and node.is_scrollable: if isinstance(node, Widget) and node.is_scrollable:
if node.show_vertical_scrollbar: if node.show_vertical_scrollbar:
apply(node.vertical_scrollbar) apply(node.vertical_scrollbar)

View File

@@ -163,6 +163,7 @@ class DevtoolsClient:
if isinstance(log, str): if isinstance(log, str):
await websocket.send_str(log) await websocket.send_str(log)
else: else:
assert isinstance(log, bytes)
await websocket.send_bytes(log) await websocket.send_bytes(log)
log_queue.task_done() log_queue.task_done()

View File

@@ -87,6 +87,7 @@ class DOMNode(MessagePump):
self._auto_refresh: float | None = None self._auto_refresh: float | None = None
self._auto_refresh_timer: Timer | None = None self._auto_refresh_timer: Timer | None = None
self._css_types = {cls.__name__ for cls in self._css_bases(self.__class__)}
super().__init__() super().__init__()
@@ -311,6 +312,23 @@ class DOMNode(MessagePump):
append(node) append(node)
return result[::-1] return result[::-1]
@property
def _selector_names(self) -> list[str]:
"""Get a set of selectors applicable to this widget.
Returns:
set[str]: Set of selector names.
"""
selectors: list[str] = [
"*",
*(f".{class_name}" for class_name in self._classes),
*(f":{class_name}" for class_name in self.get_pseudo_classes()),
*self._css_types,
]
if self._id is not None:
selectors.append(f"#{self._id}")
return selectors
@property @property
def display(self) -> bool: def display(self) -> bool:
""" """
@@ -699,7 +717,7 @@ class DOMNode(MessagePump):
if old_classes == self._classes: if old_classes == self._classes:
return return
try: try:
self.app.stylesheet.update(self.app, animate=True) self.app.update_styles(self)
except NoActiveAppError: except NoActiveAppError:
pass pass
@@ -715,7 +733,7 @@ class DOMNode(MessagePump):
if old_classes == self._classes: if old_classes == self._classes:
return return
try: try:
self.app.stylesheet.update(self.app, animate=True) self.app.update_styles(self)
except NoActiveAppError: except NoActiveAppError:
pass pass
@@ -731,7 +749,7 @@ class DOMNode(MessagePump):
if old_classes == self._classes: if old_classes == self._classes:
return return
try: try:
self.app.stylesheet.update(self.app, animate=True) self.app.update_styles(self)
except NoActiveAppError: except NoActiveAppError:
pass pass

View File

@@ -63,6 +63,16 @@ class Screen(Widget):
) )
return self._update_timer return self._update_timer
@property
def widgets(self) -> list[Widget]:
"""Get all widgets."""
return list(self._compositor.map.keys())
@property
def visible_widgets(self) -> list[Widget]:
"""Get a list of visible widgets."""
return list(self._compositor.visible_widgets)
def watch_dark(self, dark: bool) -> None: def watch_dark(self, dark: bool) -> None:
pass pass

View File

@@ -1227,11 +1227,11 @@ class Widget(DOMNode):
def watch_mouse_over(self, value: bool) -> None: def watch_mouse_over(self, value: bool) -> None:
"""Update from CSS if mouse over state changes.""" """Update from CSS if mouse over state changes."""
self.app.update_styles() self.app.update_styles(self)
def watch_has_focus(self, value: bool) -> None: def watch_has_focus(self, value: bool) -> None:
"""Update from CSS if has focus state changes.""" """Update from CSS if has focus state changes."""
self.app.update_styles() self.app.update_styles(self)
def size_updated( def size_updated(
self, size: Size, virtual_size: Size, container_size: Size self, size: Size, virtual_size: Size, container_size: Size