Merge pull request #808 from Textualize/get-widget

Get widget
This commit is contained in:
Will McGugan
2022-09-30 12:33:57 +01:00
committed by GitHub
6 changed files with 43 additions and 15 deletions

View File

@@ -13,25 +13,23 @@ without having to render the entire screen.
from __future__ import annotations from __future__ import annotations
import sys
from itertools import chain from itertools import chain
from operator import itemgetter from operator import itemgetter
import sys from typing import TYPE_CHECKING, Callable, Iterable, Iterator, NamedTuple, cast
from typing import Callable, cast, Iterator, Iterable, NamedTuple, TYPE_CHECKING
import rich.repr import rich.repr
from rich.console import Console, ConsoleOptions, RenderableType, RenderResult
from rich.console import Console, ConsoleOptions, RenderResult, RenderableType
from rich.control import Control from rich.control import Control
from rich.segment import Segment from rich.segment import Segment
from rich.style import Style from rich.style import Style
from . import errors from . import errors
from .geometry import Region, Offset, Size
from ._cells import cell_len from ._cells import cell_len
from ._profile import timer
from ._loop import loop_last from ._loop import loop_last
from ._profile import timer
from ._types import Lines from ._types import Lines
from .geometry import Offset, Region, Size
if sys.version_info >= (3, 10): if sys.version_info >= (3, 10):
from typing import TypeAlias from typing import TypeAlias
@@ -203,6 +201,8 @@ class Compositor:
# Regions that require an update # Regions that require an update
self._dirty_regions: set[Region] = set() self._dirty_regions: set[Region] = set()
self._layers_visible: dict[int, list[tuple[Widget, Region]]] | None = None
@classmethod @classmethod
def _regions_to_spans( def _regions_to_spans(
cls, regions: Iterable[Region] cls, regions: Iterable[Region]
@@ -257,6 +257,7 @@ class Compositor:
""" """
self._cuts = None self._cuts = None
self._layers = None self._layers = None
self._layers_visible = None
self.root = parent self.root = parent
self.size = size self.size = size
@@ -475,6 +476,26 @@ class Compositor:
) )
return self._layers return self._layers
@property
def layers_visible(self) -> dict[int, list[tuple[Widget, Region]]]:
"""Visible widgets and regions in layers order."""
if self._layers_visible is None:
layers_visible: dict[int, list[tuple[Widget, Region]]]
screen_region = self.size.region
_, screen_height = self.size
layers_visible = {y: [] for y in screen_region.line_range}
visible_intersection = screen_region.intersection
for widget, region, *_ in self:
(_x, y, _width, height) = region
if y + height > 0 and y < screen_height:
for y in visible_intersection(region).line_range:
layers_visible[y].append((widget, region))
self._layers_visible = layers_visible
return self._layers_visible
def __iter__(self) -> Iterator[tuple[Widget, Region, Region, Size, Size]]: def __iter__(self) -> Iterator[tuple[Widget, Region, Region, Size, Size]]:
"""Iterate map with information regarding each widget and is position """Iterate map with information regarding each widget and is position
@@ -514,9 +535,10 @@ class Compositor:
tuple[Widget, Region]: A tuple of the widget and its region. tuple[Widget, Region]: A tuple of the widget and its region.
""" """
# TODO: Optimize with some line based lookup # TODO: Optimize with some line based lookup
contains = Region.contains contains = Region.contains
for widget, cropped_region, region, *_ in self: for widget, region in self.layers_visible.get(y, []):
if contains(cropped_region, x, y) and widget.visible: if contains(region, x, y) and widget.visible:
return widget, region return widget, region
raise errors.NoWidget(f"No widget under screen coordinate ({x}, {y})") raise errors.NoWidget(f"No widget under screen coordinate ({x}, {y})")
@@ -531,8 +553,8 @@ class Compositor:
Iterable[tuple[Widget, Region]]: Sequence of (WIDGET, REGION) tuples. Iterable[tuple[Widget, Region]]: Sequence of (WIDGET, REGION) tuples.
""" """
contains = Region.contains contains = Region.contains
for widget, cropped_region, region, *_ in self: for widget, region in self.layers_visible.get(y, []):
if contains(cropped_region, x, y) and widget.visible: if contains(region, x, y) and widget.visible:
yield widget, region yield widget, region
def get_style_at(self, x: int, y: int) -> Style: def get_style_at(self, x: int, y: int) -> Style:

View File

@@ -369,9 +369,12 @@ class Stylesheet:
else: else:
rules = reversed(self.rules) rules = reversed(self.rules)
# Collect the rules defined in the stylesheet # Collect the rules defined in the stylesheet
node._has_hover_style = False
for rule in rules: for rule in rules:
is_default_rules = rule.is_default_rules is_default_rules = rule.is_default_rules
tie_breaker = rule.tie_breaker tie_breaker = rule.tie_breaker
if ":hover" in rule.selector_names:
node._has_hover_style = True
for base_specificity in _check_rule(rule, css_path_nodes): for base_specificity in _check_rule(rule, css_path_nodes):
for key, rule_specificity, value in rule.styles.extract_rules( for key, rule_specificity, value in rule.styles.extract_rules(
base_specificity, is_default_rules, tie_breaker base_specificity, is_default_rules, tie_breaker

View File

@@ -128,6 +128,7 @@ class DOMNode(MessagePump):
self._auto_refresh_timer: Timer | None = None self._auto_refresh_timer: Timer | None = None
self._css_types = {cls.__name__ for cls in self._css_bases(self.__class__)} self._css_types = {cls.__name__ for cls in self._css_bases(self.__class__)}
self._bindings = Bindings(self.BINDINGS) self._bindings = Bindings(self.BINDINGS)
self._has_hover_style: bool = False
super().__init__() super().__init__()

View File

@@ -1531,7 +1531,8 @@ class Widget(DOMNode):
def watch_mouse_over(self, value: bool) -> None: def watch_mouse_over(self, value: bool) -> None:
"""Update from CSS if mouse over state changes.""" """Update from CSS if mouse over state changes."""
self.app.update_styles(self) if self._has_hover_style:
self.app.update_styles(self)
def watch_has_focus(self, value: bool) -> None: def watch_has_focus(self, value: bool) -> None:
"""Update from CSS if has focus state changes.""" """Update from CSS if has focus state changes."""

View File

@@ -204,7 +204,7 @@ class Button(Static, can_focus=True):
def watch_mouse_over(self, value: bool) -> None: def watch_mouse_over(self, value: bool) -> None:
"""Update from CSS if mouse over state changes.""" """Update from CSS if mouse over state changes."""
if not self.disabled: if self._has_hover_style and not self.disabled:
self.app.update_styles(self) self.app.update_styles(self)
def validate_variant(self, variant: str) -> str: def validate_variant(self, variant: str) -> str:

View File

@@ -86,12 +86,13 @@ class Static(Widget):
""" """
return self._renderable return self._renderable
def update(self, renderable: RenderableType = "") -> None: def update(self, renderable: RenderableType = "", *, layout: bool = True) -> None:
"""Update the widget's content area with new text or Rich renderable. """Update the widget's content area with new text or Rich renderable.
Args: Args:
renderable (RenderableType, optional): A new rich renderable. Defaults to empty renderable; renderable (RenderableType, optional): A new rich renderable. Defaults to empty renderable;
layout (bool, optional): Perform a layout. Defaults to True.
""" """
_check_renderable(renderable) _check_renderable(renderable)
self.renderable = renderable self.renderable = renderable
self.refresh(layout=True) self.refresh(layout=layout)