Merge pull request #670 from Textualize/sidebar-fix

fix for footer
This commit is contained in:
Will McGugan
2022-08-15 07:50:03 +01:00
committed by GitHub
27 changed files with 569 additions and 216 deletions

View File

@@ -14,9 +14,9 @@ class AddRemoveApp(App):
CSS = """
#buttons {
dock: top;
height: auto;
height: auto;
}
Button {
#buttons Button {
width: 1fr;
}
#items {
@@ -26,8 +26,8 @@ class AddRemoveApp(App):
Thing {
height: 5;
background: $panel;
border: wide $primary;
margin: 0 1;
border: tall $primary;
margin: 1 1;
content-align: center middle;
}
"""

View File

@@ -16,11 +16,13 @@ App > Screen {
background: $surface;
color: $text-surface;
layers: sidebar;
layers: base sidebar;
color: $text-background;
background: $background;
layout: vertical;
overflow: hidden;
}
@@ -47,7 +49,7 @@ DataTable {
/* opacity: 50%; */
padding: 1;
margin: 1 2;
height: 12;
height: 24;
}
#sidebar {
@@ -55,6 +57,7 @@ DataTable {
background: $panel;
dock: left;
width: 30;
margin-bottom: 1;
offset-x: -100%;
transition: offset 500ms in_out_cubic;
@@ -88,14 +91,7 @@ DataTable {
content-align: center middle;
}
#header {
color: $text-secondary-background;
background: $secondary-background;
height: 1;
content-align: center middle;
dock: top;
}
Tweet {
@@ -120,7 +116,7 @@ Tweet {
overflow-x: auto;
overflow-y: scroll;
margin: 1 2;
height: 20;
height: 24;
align-horizontal: center;
layout: vertical;
}

View File

@@ -6,8 +6,8 @@ from rich.text import Text
from textual.app import App, ComposeResult
from textual.reactive import Reactive
from textual.widget import Widget
from textual.widgets import Static, DataTable, DirectoryTree
from textual.layout import Vertical
from textual.widgets import Static, DataTable, DirectoryTree, Header, Footer
from textual.layout import Container
CODE = '''
from __future__ import annotations
@@ -109,19 +109,18 @@ class BasicApp(App, css_path="basic.css"):
def on_load(self):
"""Bind keys here."""
self.bind("s", "toggle_class('#sidebar', '-active')")
self.bind("s", "toggle_class('#sidebar', '-active')", description="Sidebar")
self.bind("d", "toggle_dark", description="Dark mode")
self.bind("q", "quit", description="Quit")
self.bind("f", "query_test", description="Query test")
def compose(self):
yield Header()
def compose(self) -> ComposeResult:
table = DataTable()
self.scroll_to_target = Tweet(TweetBody())
yield Static(
Text.from_markup(
"[b]This is a [u]Textual[/u] app, running in the terminal"
),
id="header",
)
yield from (
yield Container(
Tweet(TweetBody()),
Widget(
Static(
@@ -143,7 +142,6 @@ class BasicApp(App, css_path="basic.css"):
Tweet(TweetBody(), classes="scroll-horizontal"),
Tweet(TweetBody(), classes="scroll-horizontal"),
)
yield Widget(id="footer")
yield Widget(
Widget(classes="title"),
Widget(classes="user"),
@@ -153,6 +151,7 @@ class BasicApp(App, css_path="basic.css"):
Widget(classes="content"),
id="sidebar",
)
yield Footer()
table.add_column("Foo", width=20)
table.add_column("Bar", width=20)
@@ -164,12 +163,32 @@ class BasicApp(App, css_path="basic.css"):
for n in range(100):
table.add_row(*[f"Cell ([b]{n}[/b], {col})" for col in range(6)])
def on_mount(self):
self.sub_title = "Widget demo"
async def on_key(self, event) -> None:
await self.dispatch_key(event)
def key_d(self):
def action_toggle_dark(self):
self.dark = not self.dark
def action_query_test(self):
query = self.query("Tweet")
self.log(query)
self.log(query.nodes)
self.log(query)
self.log(query.nodes)
query.set_styles("outline: outer red;")
query = query.exclude(".scroll-horizontal")
self.log(query)
self.log(query.nodes)
# query = query.filter(".rubbish")
# self.log(query)
# self.log(query.first())
async def key_q(self):
await self.shutdown()

8
sandbox/will/fill.css Normal file
View File

@@ -0,0 +1,8 @@
App Static {
border: heavy white;
background: blue;
color: white;
height: 100%;
box-sizing: border-box;
}

10
sandbox/will/fill.py Normal file
View File

@@ -0,0 +1,10 @@
from textual.app import App, ComposeResult
from textual.widgets import Static
class FillApp(App):
def compose(self) -> ComposeResult:
yield Static("Hello")
app = FillApp(css_path="fill.css")

17
sandbox/will/footer.py Normal file
View File

@@ -0,0 +1,17 @@
from textual.app import App
from textual.widgets import Header, Footer
class FooterApp(App):
def on_mount(self):
self.sub_title = "Header and footer example"
self.bind("b", "app.bell", description="Play the Bell")
self.bind("d", "dark", description="Toggle dark")
self.bind("f1", "app.bell", description="Hello World")
def action_dark(self):
self.dark = not self.dark
def compose(self):
yield Header()
yield Footer()

View File

@@ -1082,7 +1082,6 @@ class App(Generic[ReturnType], DOMNode):
Returns:
bool: True if an action was processed.
"""
event.stop()
try:
style = getattr(event, "style")
except AttributeError:
@@ -1091,6 +1090,8 @@ class App(Generic[ReturnType], DOMNode):
modifiers, action = extract_handler_actions(event_name, style.meta)
except NoHandler:
return False
else:
event.stop()
if isinstance(action, str):
await self.action(
action, default_namespace=default_namespace, modifiers=modifiers

View File

@@ -61,10 +61,11 @@ def get_box_model(
)
else:
# An explicit width
content_width = styles.width.resolve_dimension(
styles_width = styles.width
content_width = styles_width.resolve_dimension(
sizing_container - styles.margin.totals, viewport, fraction_unit
)
if is_border_box:
if is_border_box and styles_width.excludes_border:
content_width -= gutter.width
if styles.min_width is not None:
@@ -92,11 +93,12 @@ def get_box_model(
get_content_height(content_container, viewport, int(content_width))
)
else:
styles_height = styles.height
# Explicit height set
content_height = styles.height.resolve_dimension(
content_height = styles_height.resolve_dimension(
sizing_container - styles.margin.totals, viewport, fraction_unit
)
if is_border_box:
if is_border_box and styles_height.excludes_border:
content_height -= gutter.height
if styles.min_height is not None:

View File

@@ -78,6 +78,9 @@ class HelpText:
self.summary = summary
self.bullets = bullets or []
def __str__(self) -> str:
return self.summary
def __rich_console__(
self, console: Console, options: ConsoleOptions
) -> RenderResult:

View File

@@ -9,11 +9,11 @@ from .tokenizer import TokenError
class DeclarationError(Exception):
def __init__(self, name: str, token: Token, message: str) -> None:
def __init__(self, name: str, token: Token, message: str | HelpText) -> None:
self.name = name
self.token = token
self.message = message
super().__init__(message)
super().__init__(str(message))
class StyleTypeError(TypeError):

View File

@@ -133,6 +133,10 @@ class SelectorSet:
for selector, next_selector in zip(self.selectors, self.selectors[1:]):
selector.advance = int(next_selector.combinator != SAME)
@property
def css(self) -> str:
return RuleSet._selector_to_css(self.selectors)
def __rich_repr__(self) -> rich.repr.Result:
selectors = RuleSet._selector_to_css(self.selectors)
yield selectors

View File

@@ -6,63 +6,108 @@ actions to the nodes in the query.
If this sounds like JQuery, a (once) popular JS library, it is no coincidence.
DOMQuery objects are typically created by Widget.filter method.
DOMQuery objects are typically created by Widget.query method.
Queries are *lazy*. Results will be calculated at the point you iterate over the query, or call
a method which evaluates the query, such as first() and last().
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Iterator, TypeVar, overload
import rich.repr
from typing import Iterator, overload, TYPE_CHECKING
from .errors import DeclarationError
from .match import match
from .parse import parse_selectors
from .model import SelectorSet
from .parse import parse_declarations, parse_selectors
if TYPE_CHECKING:
from ..dom import DOMNode
from ..widget import Widget
class NoMatchingNodesError(Exception):
class QueryError(Exception):
pass
class NoMatchingNodesError(QueryError):
pass
class WrongType(QueryError):
pass
@rich.repr.auto(angular=True)
class DOMQuery:
__slots__ = [
"_node",
"_nodes",
"_filters",
"_excludes",
]
def __init__(
self,
node: DOMNode | None = None,
selector: str | None = None,
nodes: list[Widget] | None = None,
node: DOMNode,
*,
filter: str | None = None,
exclude: str | None = None,
parent: DOMQuery | None = None,
) -> None:
self._node = node
self._nodes: list[Widget] | None = None
self._filters: list[tuple[SelectorSet, ...]] = (
parent._filters.copy() if parent else []
)
self._excludes: list[tuple[SelectorSet, ...]] = (
parent._excludes.copy() if parent else []
)
if filter is not None:
self._filters.append(parse_selectors(filter))
if exclude is not None:
self._excludes.append(parse_selectors(exclude))
@property
def node(self) -> DOMNode:
return self._node
@property
def nodes(self) -> list[Widget]:
"""Lazily evaluate nodes."""
from ..widget import Widget
self._selector = selector
self._nodes: list[Widget] = []
if nodes is not None:
if self._nodes is None:
nodes = [
node
for node in self._node.walk_children(Widget)
if all(match(selector_set, node) for selector_set in self._filters)
]
nodes = [
node
for node in nodes
if not any(match(selector_set, node) for selector_set in self._excludes)
]
self._nodes = nodes
elif node is not None:
self._nodes = [node for node in node.walk_children()]
else:
self._nodes = []
if selector is not None:
selector_set = parse_selectors(selector)
self._nodes = [_node for _node in self._nodes if match(selector_set, _node)]
return self._nodes
def __len__(self) -> int:
return len(self._nodes)
return len(self.nodes)
def __bool__(self) -> bool:
"""True if non-empty, otherwise False."""
return bool(self._nodes)
return bool(self.nodes)
def __iter__(self) -> Iterator[Widget]:
return iter(self._nodes)
return iter(self.nodes)
def __reversed__(self) -> Iterator[Widget]:
return reversed(self.nodes)
@overload
def __getitem__(self, index: int) -> Widget:
@@ -73,10 +118,20 @@ class DOMQuery:
...
def __getitem__(self, index: int | slice) -> Widget | list[Widget]:
return self._nodes[index]
return self.nodes[index]
def __rich_repr__(self) -> rich.repr.Result:
yield self._nodes
yield self.node
if self._filters:
yield "filter", " AND ".join(
",".join(selector.css for selector in selectors)
for selectors in self._filters
)
if self._excludes:
yield "exclude", " OR ".join(
",".join(selector.css for selector in selectors)
for selectors in self._excludes
)
def filter(self, selector: str) -> DOMQuery:
"""Filter this set by the given CSS selector.
@@ -88,11 +143,7 @@ class DOMQuery:
DOMQuery: New DOM Query.
"""
selector_set = parse_selectors(selector)
query = DOMQuery(
nodes=[_node for _node in self._nodes if match(selector_set, _node)]
)
return query
return DOMQuery(self.node, filter=selector, parent=self)
def exclude(self, selector: str) -> DOMQuery:
"""Exclude nodes that match a given selector.
@@ -103,70 +154,145 @@ class DOMQuery:
Returns:
DOMQuery: New DOM query.
"""
selector_set = parse_selectors(selector)
query = DOMQuery(
nodes=[_node for _node in self._nodes if not match(selector_set, _node)]
)
return query
return DOMQuery(self.node, exclude=selector, parent=self)
ExpectType = TypeVar("ExpectType")
@overload
def first(self) -> Widget:
"""Get the first matched node.
...
@overload
def first(self, expect_type: type[ExpectType]) -> ExpectType:
...
def first(self, expect_type: type[ExpectType] | None = None) -> Widget | ExpectType:
"""Get the *first* match node.
Args:
expect_type (type[ExpectType] | None, optional): Require matched node is of this type,
or None for any type. Defaults to None.
Raises:
WrongType: If the wrong type was found.
NoMatchingNodesError: If there are no matching nodes in the query.
Returns:
DOMNode: A DOM Node.
Widget | ExpectType: The matching Widget.
"""
if self._nodes:
return self._nodes[0]
if self.nodes:
first = self.nodes[0]
if expect_type is not None:
if not isinstance(first, expect_type):
raise WrongType(
f"Query value is wrong type; expected {expect_type}, got {type(first)}"
)
return first
else:
raise NoMatchingNodesError(
f"No nodes match the selector {self._selector!r}"
)
raise NoMatchingNodesError(f"No nodes match {self!r}")
@overload
def last(self) -> Widget:
"""Get the last matched node.
...
@overload
def last(self, expect_type: type[ExpectType]) -> ExpectType:
...
def last(self, expect_type: type[ExpectType] | None = None) -> Widget | ExpectType:
"""Get the *last* match node.
Args:
expect_type (type[ExpectType] | None, optional): Require matched node is of this type,
or None for any type. Defaults to None.
Raises:
WrongType: If the wrong type was found.
NoMatchingNodesError: If there are no matching nodes in the query.
Returns:
DOMNode: A DOM Node.
Widget | ExpectType: The matching Widget.
"""
if self._nodes:
return self._nodes[-1]
if self.nodes:
last = self.nodes[-1]
if expect_type is not None:
if not isinstance(last, expect_type):
raise WrongType(
f"Query value is wrong type; expected {expect_type}, got {type(last)}"
)
return last
else:
raise NoMatchingNodesError(
f"No nodes match the selector {self._selector!r}"
)
raise NoMatchingNodesError(f"No nodes match {self!r}")
@overload
def results(self) -> Iterator[Widget]:
...
@overload
def results(self, filter_type: type[ExpectType]) -> Iterator[ExpectType]:
...
def results(
self, filter_type: type[ExpectType] | None = None
) -> Iterator[Widget | ExpectType]:
"""Get query results, optionally filtered by a given type.
Args:
filter_type (type[ExpectType] | None): A Widget class to filter results,
or None for no filter. Defaults to None.
Yields:
Iterator[Widget | ExpectType]: An iterator of Widget instances.
"""
if filter_type is None:
yield from self
else:
for node in self:
if isinstance(node, filter_type):
yield node
def add_class(self, *class_names: str) -> DOMQuery:
"""Add the given class name(s) to nodes."""
for node in self._nodes:
for node in self:
node.add_class(*class_names)
return self
def remove_class(self, *class_names: str) -> DOMQuery:
"""Remove the given class names from the nodes."""
for node in self._nodes:
for node in self:
node.remove_class(*class_names)
return self
def toggle_class(self, *class_names: str) -> DOMQuery:
"""Toggle the given class names from matched nodes."""
for node in self._nodes:
for node in self:
node.toggle_class(*class_names)
return self
def remove(self) -> DOMQuery:
"""Remove matched nodes from the DOM"""
for node in self._nodes:
for node in self:
node.remove()
return self
def set_styles(self, css: str | None = None, **styles: str) -> DOMQuery:
def set_styles(self, css: str | None = None, **update_styles) -> DOMQuery:
"""Set styles on matched nodes.
Args:
css (str, optional): CSS declarations to parser, or None. Defaults to None.
"""
for node in self._nodes:
node.set_styles(css, **styles)
_rich_traceback_omit = True
for node in self:
node.set_styles(**update_styles)
if css is not None:
try:
new_styles = parse_declarations(css, path="set_styles")
except DeclarationError as error:
raise DeclarationError(error.name, error.token, error.message) from None
for node in self:
node._inline_styles.merge(new_styles)
node.refresh(layout=True)
return self
def refresh(self, *, repaint: bool = True, layout: bool = False) -> DOMQuery:
@@ -179,6 +305,6 @@ class DOMQuery:
Returns:
DOMQuery: Query for chaining.
"""
for node in self._nodes:
for node in self:
node.refresh(repaint=repaint, layout=layout)
return self

View File

@@ -37,6 +37,8 @@ class Unit(Enum):
AUTO = 8
UNIT_EXCLUDES_BORDER = {Unit.CELLS, Unit.FRACTION, Unit.VIEW_WIDTH, Unit.VIEW_HEIGHT}
UNIT_SYMBOL = {
Unit.CELLS: "",
Unit.FRACTION: "fr",
@@ -199,6 +201,10 @@ class Scalar(NamedTuple):
"""Check if the unit is a fraction."""
return self.unit == Unit.FRACTION
@property
def excludes_border(self) -> bool:
return self.unit in UNIT_EXCLUDES_BORDER
@property
def cells(self) -> int | None:
"""Check if the unit is explicit cells."""

View File

@@ -12,7 +12,7 @@ from rich.style import Style
from .._animator import Animation, EasingFunction
from ..color import Color
from ..geometry import Offset, Size, Spacing
from ..geometry import Offset, Spacing
from ._style_properties import (
AlignProperty,
BorderProperty,
@@ -223,8 +223,6 @@ class StylesBase(ABC):
layers = NameListProperty()
transitions = TransitionsProperty()
rich_style = StyleProperty()
tint = ColorProperty("transparent")
scrollbar_color = ColorProperty("ansi_bright_magenta")
scrollbar_color_hover = ColorProperty("ansi_yellow")
@@ -800,6 +798,12 @@ class RenderStyles(StylesBase):
"""Quick access to the inline styles."""
return self._inline_styles
@property
def rich_style(self) -> Style:
"""Get a Rich style for this Styles object."""
assert self.node is not None
return self.node.rich_style
def __rich_repr__(self) -> rich.repr.Result:
for rule_name in RULE_NAMES:
if self.has_rule(rule_name):

View File

@@ -23,7 +23,7 @@ from .parse import parse
from .styles import RulesMap, Styles
from .tokenize import tokenize_values, Token
from .tokenizer import TokenError
from .types import Specificity3, Specificity5
from .types import Specificity3, Specificity6
from ..dom import DOMNode
from .. import messages
@@ -325,7 +325,7 @@ class Stylesheet:
# We can use this to determine, for a given rule, whether we should apply it
# or not by examining the specificity. If we have two rules for the
# same attribute, then we can choose the most specific rule and use that.
rule_attributes: dict[str, list[tuple[Specificity5, object]]]
rule_attributes: dict[str, list[tuple[Specificity6, object]]]
rule_attributes = defaultdict(list)
_check_rule = self._check_rule
@@ -352,12 +352,12 @@ class Stylesheet:
self.replace_rules(node, node_rules, animate=animate)
node.component_styles.clear()
node._component_styles.clear()
for component in node.COMPONENT_CLASSES:
virtual_node = DOMNode(classes=component)
virtual_node.set_parent(node)
self.apply(virtual_node, animate=False)
node.component_styles[component] = virtual_node.styles
node._component_styles[component] = virtual_node.styles
@classmethod
def replace_rules(

View File

@@ -1,8 +1,7 @@
from __future__ import annotations
from inspect import getfile
from operator import attrgetter
from typing import ClassVar, Iterable, Iterator, Type, TYPE_CHECKING
from typing import ClassVar, Iterable, Iterator, Type, overload, TypeVar, TYPE_CHECKING
import rich.repr
from rich.highlighter import ReprHighlighter
@@ -16,7 +15,7 @@ from ._node_list import NodeList
from .color import Color, WHITE, BLACK
from .css._error_tools import friendly_list
from .css.constants import VALID_DISPLAY, VALID_VISIBILITY
from .css.errors import StyleValueError
from .css.errors import StyleValueError, DeclarationError
from .css.parse import parse_declarations
from .css.styles import Styles, RenderStyles
from .css.query import NoMatchingNodesError
@@ -24,7 +23,6 @@ from .message_pump import MessagePump
if TYPE_CHECKING:
from .app import App
from .css.styles import StylesBase
from .css.query import DOMQuery
from .screen import Screen
from .widget import Widget
@@ -68,7 +66,7 @@ class DOMNode(MessagePump):
self._inline_styles: Styles = Styles(self)
self.styles = RenderStyles(self, self._css_styles, self._inline_styles)
# A mapping of class names to Styles set in COMPONENT_CLASSES
self.component_styles: dict[str, StylesBase] = {}
self._component_styles: dict[str, RenderStyles] = {}
super().__init__()
@@ -80,6 +78,23 @@ class DOMNode(MessagePump):
css_type_names.add(base.__name__.lower())
cls._css_type_names = frozenset(css_type_names)
def get_component_styles(self, name: str) -> RenderStyles:
"""Get a "component" styles object (must be defined in COMPONENT_CLASSES classvar).
Args:
name (str): Name of the component.
Raises:
KeyError: If the component class doesn't exist.
Returns:
RenderStyles: A Styles object.
"""
if name not in self._component_styles:
raise KeyError(f"No {name!r} key in COMPONENT_CLASSES")
styles = self._component_styles[name]
return styles
@property
def _node_bases(self) -> Iterator[Type[DOMNode]]:
"""Get the DOMNode bases classes (including self.__class__)
@@ -445,7 +460,7 @@ class DOMNode(MessagePump):
node._set_dirty()
node._layout_required = True
def add_child(self, node: DOMNode) -> None:
def add_child(self, node: Widget) -> None:
"""Add a new child node.
Args:
@@ -454,7 +469,7 @@ class DOMNode(MessagePump):
self.children._append(node)
node.set_parent(self)
def add_children(self, *nodes: DOMNode, **named_nodes: DOMNode) -> None:
def add_children(self, *nodes: Widget, **named_nodes: Widget) -> None:
"""Add multiple children to this node.
Args:
@@ -470,19 +485,45 @@ class DOMNode(MessagePump):
_append(node)
node.id = node_id
def walk_children(self, with_self: bool = True) -> Iterable[DOMNode]:
"""Generate all descendants of this node.
WalkType = TypeVar("WalkType")
@overload
def walk_children(
self,
filter_type: type[WalkType],
*,
with_self: bool = True,
) -> Iterable[WalkType]:
...
@overload
def walk_children(self, *, with_self: bool = True) -> Iterable[DOMNode]:
...
def walk_children(
self,
filter_type: type[WalkType] | None = None,
*,
with_self: bool = True,
) -> Iterable[DOMNode | WalkType]:
"""Generate descendant nodes.
Args:
with_self (bool, optional): Also include self in the results. Defaults to True.
filter_type (type[WalkType] | None, optional): Filter only this type, or None for no filter.
Defaults to None.
with_self (bool, optional): Also yield self in addition to descendants. Defaults to True.
Returns:
Iterable[DOMNode | WalkType]: An iterable of nodes.
"""
stack: list[Iterator[DOMNode]] = [iter(self.children)]
pop = stack.pop
push = stack.append
check_type = filter_type or DOMNode
if with_self:
if with_self and isinstance(self, check_type):
yield self
while stack:
@@ -490,7 +531,8 @@ class DOMNode(MessagePump):
if node is None:
pop()
else:
yield node
if isinstance(node, check_type):
yield node
if node.children:
push(iter(node.children))
@@ -522,10 +564,28 @@ class DOMNode(MessagePump):
"""
from .css.query import DOMQuery
return DOMQuery(self, selector)
return DOMQuery(self, filter=selector)
ExpectType = TypeVar("ExpectType")
@overload
def query_one(self, selector: str) -> Widget:
"""Get the first Widget matching the given selector.
...
@overload
def query_one(self, selector: type[ExpectType]) -> ExpectType:
...
@overload
def query_one(self, selector: str, expect_type: type[ExpectType]) -> ExpectType:
...
def query_one(
self,
selector: str | type[ExpectType],
expect_type: type[ExpectType] | None = None,
) -> ExpectType | Widget:
"""Get the first Widget matching the given selector or selector type.
Args:
selector (str | None, optional): A selector.
@@ -535,19 +595,31 @@ class DOMNode(MessagePump):
"""
from .css.query import DOMQuery
query = DOMQuery(self.screen, selector)
return query.first()
if isinstance(selector, str):
query_selector = selector
else:
query_selector = selector.__name__
query = DOMQuery(self.screen, filter=query_selector)
def set_styles(self, css: str | None = None, **styles) -> None:
if expect_type is None:
return query.first()
else:
return query.first(expect_type)
def set_styles(self, css: str | None = None, **update_styles) -> None:
"""Set custom styles on this object."""
# TODO: This can be done more efficiently
kwarg_css = "\n".join(
f"{key.replace('_', '-')}: {value}" for key, value in styles.items()
)
apply_css = f"{css or ''}\n{kwarg_css}\n"
new_styles = parse_declarations(apply_css, f"<custom styles for ${self!r}>")
self.styles.merge(new_styles)
self.refresh()
if css is not None:
try:
new_styles = parse_declarations(css, path="set_styles")
except DeclarationError as error:
raise DeclarationError(error.name, error.token, error.message) from None
self._inline_styles.merge(new_styles)
self.refresh(layout=True)
styles = self.styles
for key, value in update_styles.items():
setattr(styles, key, value)
def has_class(self, *class_names: str) -> bool:
"""Check if the Node has all the given class names.

View File

@@ -342,10 +342,13 @@ class MessagePump(metaclass=MessagePumpMeta):
message (Message): Message object.
"""
private_method = f"_{method_name}"
for cls in self.__class__.__mro__:
if message._no_default_action:
break
method = cls.__dict__.get(method_name, None)
method = cls.__dict__.get(private_method, None) or cls.__dict__.get(
method_name, None
)
if method is not None:
yield cls, method.__get__(self, cls)

View File

@@ -137,7 +137,7 @@ class Reactive(Generic[ReactiveType]):
def watch(
obj: Reactable, attribute_name: str, callback: Callable[[Any], Awaitable[None]]
obj: Reactable, attribute_name: str, callback: Callable[[Any], object]
) -> None:
watcher_name = f"__{attribute_name}_watchers"
current_value = getattr(obj, attribute_name, None)

View File

@@ -57,7 +57,7 @@ class Screen(Widget):
"""Timer used to perform updates."""
if self._update_timer is None:
self._update_timer = self.set_interval(
UPDATE_PERIOD, self._on_update, name="screen_update", pause=True
UPDATE_PERIOD, self._on_timer_update, name="screen_update", pause=True
)
return self._update_timer
@@ -131,7 +131,7 @@ class Screen(Widget):
# The Screen is idle - a good opportunity to invoke the scheduled callbacks
await self._invoke_and_clear_callbacks()
def _on_update(self) -> None:
def _on_timer_update(self) -> None:
"""Called by the _update_timer."""
# Render widgets together
if self._dirty_widgets:
@@ -228,7 +228,7 @@ class Screen(Widget):
async def on_resize(self, event: events.Resize) -> None:
event.stop()
async def _on_mouse_move(self, event: events.MouseMove) -> None:
async def _handle_mouse_move(self, event: events.MouseMove) -> None:
try:
if self.app.mouse_captured:
widget = self.app.mouse_captured
@@ -265,7 +265,7 @@ class Screen(Widget):
elif isinstance(event, events.MouseMove):
event.style = self.get_style_at(event.screen_x, event.screen_y)
await self._on_mouse_move(event)
await self._handle_mouse_move(event)
elif isinstance(event, events.MouseEvent):
try:

View File

@@ -135,6 +135,16 @@ class Widget(DOMNode):
show_vertical_scrollbar = Reactive(False, layout=True)
show_horizontal_scrollbar = Reactive(False, layout=True)
@property
def allow_vertical_scroll(self) -> bool:
"""Check if vertical scroll is permitted."""
return self.is_scrollable and self.show_vertical_scrollbar
@property
def allow_horizontal_scroll(self) -> bool:
"""Check if horizontal scroll is permitted."""
return self.is_scrollable and self.show_horizontal_scrollbar
def _arrange(self, size: Size) -> DockArrangeResult:
"""Arrange children.
@@ -938,15 +948,13 @@ class Widget(DOMNode):
def watch(self, attribute_name, callback: Callable[[Any], Awaitable[None]]) -> None:
watch(self, attribute_name, callback)
def _render_styled(self) -> RenderableType:
def post_render(self, renderable: RenderableType) -> RenderableType:
"""Applies style attributes to the default renderable.
Returns:
RenderableType: A new renderable.
"""
renderable = self.render()
if isinstance(renderable, str):
renderable = Text.from_markup(renderable)
@@ -1002,7 +1010,8 @@ class Widget(DOMNode):
def _render_content(self) -> None:
"""Render all lines."""
width, height = self.size
renderable = self._render_styled()
renderable = self.render()
renderable = self.post_render(renderable)
options = self.console.options.update_dimensions(width, height).update(
highlight=False
)
@@ -1155,7 +1164,7 @@ class Widget(DOMNode):
assert self.parent
self.parent.refresh(layout=True)
def on_mount(self, event: events.Mount) -> None:
def _on_mount(self, event: events.Mount) -> None:
widgets = list(self.compose())
if widgets:
self.mount(*widgets)
@@ -1196,12 +1205,12 @@ class Widget(DOMNode):
break
def on_mouse_scroll_down(self, event) -> None:
if self.is_scrollable:
if self.allow_vertical_scroll:
if self.scroll_down(animate=False):
event.stop()
def on_mouse_scroll_up(self, event) -> None:
if self.is_scrollable:
if self.allow_vertical_scroll:
if self.scroll_up(animate=False):
event.stop()

View File

@@ -349,9 +349,9 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
Lines: A list of segments per line.
"""
if hover:
style += self.component_styles["datatable--highlight"].node.rich_style
style += self.get_component_styles("datatable--highlight").rich_style
if cursor:
style += self.component_styles["datatable--cursor"].node.rich_style
style += self.get_component_styles("datatable--cursor").rich_style
cell_key = (row_index, column_index, style, cursor, hover)
if cell_key not in self._cell_render_cache:
style += Style.from_meta({"row": row_index, "column": column_index})
@@ -394,7 +394,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
render_cell = self._render_cell
if self.fixed_columns:
fixed_style = self.component_styles["datatable--fixed"].node.rich_style
fixed_style = self.get_component_styles("datatable--fixed").rich_style
fixed_style += Style.from_meta({"fixed": True})
fixed_row = [
render_cell(row_index, column.index, fixed_style, column.width)[line_no]
@@ -404,13 +404,13 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
fixed_row = []
if row_index == -1:
row_style = self.component_styles["datatable--header"].node.rich_style
row_style = self.get_component_styles("datatable--header").rich_style
else:
if self.zebra_stripes:
component_row_style = (
"datatable--odd-row" if row_index % 2 else "datatable--even-row"
)
row_style = self.component_styles[component_row_style].node.rich_style
row_style = self.get_component_styles(component_row_style).rich_style
else:
row_style = base_style

View File

@@ -106,7 +106,7 @@ class DirectoryTree(TreeControl[DirEntry]):
return icon_label
async def on_mount(self, event: events.Mount) -> None:
await self.load_directory(self.root)
self.call_later(self.load_directory, self.root)
async def load_directory(self, node: TreeNode[DirEntry]):
path = node.data.path

View File

@@ -12,10 +12,41 @@ from ..widget import Widget
@rich.repr.auto
class Footer(Widget):
CSS = """
Footer {
background: $accent;
color: $text-accent;
dock: bottom;
height: 1;
}
Footer > .footer--highlight {
background: $accent-darken-1;
color: $text-accent-darken-1;
}
Footer > .footer--highlight-key {
background: $secondary;
color: $text-secondary;
text-style: bold;
}
Footer > .footer--key {
text-style: bold;
background: $accent-darken-2;
color: $text-accent-darken-2;
}
"""
COMPONENT_CLASSES = {
"footer--description",
"footer--key",
"footer--highlight",
"footer--highlight-key",
}
def __init__(self) -> None:
self.keys: list[tuple[str, str]] = []
super().__init__()
self.layout_size = 1
self._key_text: Text | None = None
highlight_key: Reactive[str | None] = Reactive(None)
@@ -37,13 +68,19 @@ class Footer(Widget):
def make_key_text(self) -> Text:
"""Create text containing all the keys."""
base_style = self.rich_style
text = Text(
style="white on dark_green",
style=self.rich_style,
no_wrap=True,
overflow="ellipsis",
justify="left",
end="",
)
highlight_style = self.get_component_styles("footer--highlight").rich_style
highlight_key_style = self.get_component_styles(
"footer--highlight-key"
).rich_style
key_style = self.get_component_styles("footer--key").rich_style
for binding in self.app.bindings.shown_keys:
key_display = (
binding.key.upper()
@@ -52,13 +89,19 @@ class Footer(Widget):
)
hovered = self.highlight_key == binding.key
key_text = Text.assemble(
(f" {key_display} ", "reverse" if hovered else "default on default"),
f" {binding.description} ",
(f" {key_display} ", highlight_key_style if hovered else key_style),
(
f" {binding.description} ",
highlight_style if hovered else base_style,
),
meta={"@click": f"app.press('{binding.key}')", "key": binding.key},
)
text.append_text(key_text)
return text
def post_render(self, renderable):
return renderable
def render(self) -> RenderableType:
if self._key_text is None:
self._key_text = self.make_key_text()

View File

@@ -1,78 +1,103 @@
from __future__ import annotations
from datetime import datetime
from logging import getLogger
from rich.console import RenderableType
from rich.panel import Panel
from rich.repr import Result
from rich.style import StyleType, Style
from rich.table import Table
from rich.text import Text
from .. import events
from ..reactive import watch, Reactive
from ..widget import Widget
from ..reactive import Reactive, watch
log = getLogger("rich")
class HeaderIcon(Widget):
"""Display an 'icon' on the left of the header."""
CSS = """
HeaderIcon {
dock: left;
padding: 0 1;
width: 10;
content-align: left middle;
}
"""
icon = Reactive("")
def render(self):
return self.icon
class HeaderClock(Widget):
"""Display a clock on the right of the header."""
CSS = """
HeaderClock {
dock: right;
width: auto;
padding: 0 1;
background: $secondary-background-lighten-1;
color: $text-secondary-background;
opacity: 85%;
content-align: center middle;
}
"""
def on_mount(self) -> None:
self.set_interval(1, callback=self.refresh)
def render(self):
return Text(datetime.now().time().strftime("%X"))
class HeaderTitle(Widget):
"""Display the title / subtitle in the header."""
CSS = """
HeaderTitle {
content-align: center middle;
width: 100%;
}
"""
text: Reactive[str] = Reactive("Hello World")
sub_text = Reactive("Test")
def render(self) -> Text:
text = Text(self.text, no_wrap=True, overflow="ellipsis")
if self.sub_text:
text.append(f" - {self.sub_text}", "dim")
return text
class Header(Widget):
def __init__(
self,
*,
tall: bool = True,
style: StyleType = "white on dark_green",
clock: bool = True,
) -> None:
super().__init__()
self.tall = tall
self.style = style
self.clock = clock
"""A header widget with icon and clock."""
tall: Reactive[bool] = Reactive(True, layout=True)
style: Reactive[StyleType] = Reactive("white on blue")
clock: Reactive[bool] = Reactive(True)
title: Reactive[str] = Reactive("")
sub_title: Reactive[str] = Reactive("")
CSS = """
Header {
dock: top;
width: 100%;
background: $secondary-background;
color: $text-secondary-background;
height: 1;
}
Header.tall {
height: 3;
}
"""
@property
def full_title(self) -> str:
return f"{self.title} - {self.sub_title}" if self.sub_title else self.title
async def on_click(self, event):
self.toggle_class("tall")
def __rich_repr__(self) -> Result:
yield from super().__rich_repr__()
yield "title", self.title
def on_mount(self) -> None:
def set_title(title: str) -> None:
self.query_one(HeaderTitle).text = title
async def watch_tall(self, tall: bool) -> None:
self.layout_size = 3 if tall else 1
def get_clock(self) -> str:
return datetime.now().time().strftime("%X")
def render(self) -> RenderableType:
header_table = Table.grid(padding=(0, 1), expand=True)
header_table.style = self.style
header_table.add_column(justify="left", ratio=0, width=8)
header_table.add_column("title", justify="center", ratio=1)
header_table.add_column("clock", justify="right", width=8)
header_table.add_row(
"🐞", self.full_title, self.get_clock() if self.clock else ""
)
header: RenderableType
header = Panel(header_table, style=self.style) if self.tall else header_table
return header
async def on_mount(self, event: events.Mount) -> None:
self.set_interval(1.0, callback=self.refresh)
async def set_title(title: str) -> None:
self.title = title
async def set_sub_title(sub_title: str) -> None:
self.sub_title = sub_title
def set_sub_title(sub_title: str) -> None:
self.query_one(HeaderTitle).sub_text = sub_title
watch(self.app, "title", set_title)
watch(self.app, "sub_title", set_sub_title)
self.add_class("tall")
async def on_click(self, event: events.Click) -> None:
self.tall = not self.tall
def compose(self):
yield HeaderIcon()
yield HeaderTitle()
yield HeaderClock()

View File

@@ -267,7 +267,7 @@ class TreeControl(Generic[NodeDataType], Widget, can_focus=True):
return None
def render(self) -> RenderableType:
self._tree.guide_style = self.component_styles["tree--guides"].node.rich_style
self._tree.guide_style = self._component_styles["tree--guides"].node.rich_style
return self._tree
def render_node(self, node: TreeNode[NodeDataType]) -> RenderableType:

View File

@@ -1,14 +1,11 @@
from textual.dom import DOMNode
from textual.widget import Widget
def test_query():
class Widget(DOMNode):
class View(Widget):
pass
class View(DOMNode):
pass
class App(DOMNode):
class App(Widget):
pass
app = App()
@@ -52,7 +49,15 @@ def test_query():
assert list(app.query("View#main")) == [main_view]
assert list(app.query("#widget1")) == [widget1]
assert list(app.query("#widget2")) == [widget2]
assert list(app.query("Widget.float")) == [sidebar, tooltip, helpbar]
assert list(app.query("Widget.float").results(Widget)) == [
sidebar,
tooltip,
helpbar,
]
assert list(app.query("Widget.float").results(View)) == []
assert list(app.query("Widget.float.transient")) == [tooltip]
assert list(app.query("App > View")) == [main_view, help_view]

View File

@@ -164,7 +164,7 @@ class AppTest(App):
screen.refresh(repaint=repaint, layout=layout)
# We also have to make sure we have at least one dirty widget, or `screen._on_update()` will early return:
screen._dirty_widgets.add(screen)
screen._on_update()
screen._on_timer_update()
await let_asyncio_process_some_events()