Merge branch 'main' into alt-compose

This commit is contained in:
Will McGugan
2023-02-21 10:46:45 +00:00
committed by GitHub
42 changed files with 1563 additions and 539 deletions

View File

@@ -5,12 +5,31 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/) The format is based on [Keep a Changelog](http://keepachangelog.com/)
and this project adheres to [Semantic Versioning](http://semver.org/). and this project adheres to [Semantic Versioning](http://semver.org/).
## Unreleased
## [0.12.0] - Unreleased
### Added
- Added `App.batch_update` https://github.com/Textualize/textual/pull/1832
- Added horizontal rule to Markdown https://github.com/Textualize/textual/pull/1832
- Added `Widget.disabled` https://github.com/Textualize/textual/pull/1785
### Changed ### Changed
- Scrolling by page now adds to current position.
- Markdown lists have been polished: a selection of bullets, better alignment of numbers, style tweaks https://github.com/Textualize/textual/pull/1832
- Added alternative method of composing Widgets https://github.com/Textualize/textual/pull/1847 - Added alternative method of composing Widgets https://github.com/Textualize/textual/pull/1847
### Removed
- Removed `screen.visible_widgets` and `screen.widgets`
### Fixed
- Numbers in a descendant-combined selector no longer cause an error https://github.com/Textualize/textual/issues/1836
## [0.11.1] - 2023-02-17 ## [0.11.1] - 2023-02-17
### Fixed ### Fixed

View File

@@ -28,12 +28,12 @@
Screen { Screen {
layers: ruler; layers: ruler;
overflow: hidden;
} }
Ruler { Ruler {
layer: ruler; layer: ruler;
dock: right; dock: right;
overflow: hidden;
width: 1; width: 1;
background: $accent; background: $accent;
} }

View File

@@ -315,6 +315,8 @@ The `background: green` is only applied to the Button underneath the mouse curso
Here are some other pseudo classes: Here are some other pseudo classes:
- `:disabled` Matches widgets which are in a disabled state.
- `:enabled` Matches widgets which are in an enabled state.
- `:focus` Matches widgets which have input focus. - `:focus` Matches widgets which have input focus.
- `:focus-within` Matches widgets with a focused a child widget. - `:focus-within` Matches widgets with a focused a child widget.

View File

@@ -8,9 +8,9 @@ Input {
} }
#results { #results {
width: auto; width: 100%;
min-height: 100%; height: auto;
padding: 0 1;
} }
#results-container { #results-container {

View File

@@ -7,11 +7,10 @@ try:
except ImportError: except ImportError:
raise ImportError("Please install httpx with 'pip install httpx' ") raise ImportError("Please install httpx with 'pip install httpx' ")
from rich.markdown import Markdown
from textual.app import App, ComposeResult from textual.app import App, ComposeResult
from textual.containers import Content from textual.containers import Content
from textual.widgets import Input, Static from textual.widgets import Input, Markdown
class DictionaryApp(App): class DictionaryApp(App):
@@ -36,17 +35,22 @@ class DictionaryApp(App):
asyncio.create_task(self.lookup_word(message.value)) asyncio.create_task(self.lookup_word(message.value))
else: else:
# Clear the results # Clear the results
self.query_one("#results", Static).update() await self.query_one("#results", Markdown).update("")
async def lookup_word(self, word: str) -> None: async def lookup_word(self, word: str) -> None:
"""Looks up a word.""" """Looks up a word."""
url = f"https://api.dictionaryapi.dev/api/v2/entries/en/{word}" url = f"https://api.dictionaryapi.dev/api/v2/entries/en/{word}"
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
results = (await client.get(url)).json() response = await client.get(url)
try:
results = response.json()
except Exception:
self.query_one("#results", Static).update(response.text)
return
if word == self.query_one(Input).value: if word == self.query_one(Input).value:
markdown = self.make_word_markdown(results) markdown = self.make_word_markdown(results)
self.query_one("#results", Static).update(Markdown(markdown)) await self.query_one("#results", Markdown).update(markdown)
def make_word_markdown(self, results: object) -> str: def make_word_markdown(self, results: object) -> str:
"""Convert the results in to markdown.""" """Convert the results in to markdown."""

View File

@@ -42,6 +42,32 @@ Two tildes indicates strikethrough, e.g. `~~cross out~~` render ~~cross out~~.
Inline code is indicated by backticks. e.g. `import this`. Inline code is indicated by backticks. e.g. `import this`.
## Lists
1. Lists can be ordered
2. Lists can be unordered
- I must not fear.
- Fear is the mind-killer.
- Fear is the little-death that brings total obliteration.
- I will face my fear.
- I will permit it to pass over me and through me.
- And when it has gone past, I will turn the inner eye to see its path.
- Where the fear has gone there will be nothing. Only I will remain.
### Longer list
1. **Duke Leto I Atreides**, head of House Atreides
2. **Lady Jessica**, Bene Gesserit and concubine of Leto, and mother of Paul and Alia
3. **Paul Atreides**, son of Leto and Jessica
4. **Alia Atreides**, daughter of Leto and Jessica
5. **Gurney Halleck**, troubadour warrior of House Atreides
6. **Thufir Hawat**, Mentat and Master of Assassins of House Atreides
7. **Duncan Idaho**, swordmaster of House Atreides
8. **Dr. Wellington Yueh**, Suk doctor of House Atreides
9. **Leto**, first son of Paul and Chani who dies as a toddler
10. **Esmar Tuek**, a smuggler on Arrakis
11. **Staban Tuek**, son of Esmar
## Fences ## Fences
Fenced code blocks are introduced with three back-ticks and the optional parser. Here we are rendering the code in a sub-widget with syntax highlighting and indent guides. Fenced code blocks are introduced with three back-ticks and the optional parser. Here we are rendering the code in a sub-widget with syntax highlighting and indent guides.

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
"""Simple version of 5x5, developed for/with Textual.""" """Simple version of 5x5, developed for/with Textual."""
from __future__ import annotations
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, cast from typing import TYPE_CHECKING, cast
@@ -192,8 +192,7 @@ class Game(Screen):
Args: Args:
playable (bool): Should the game currently be playable? playable (bool): Should the game currently be playable?
""" """
for cell in self.query(GameCell): self.query_one(GameGrid).disabled = not playable
cell.disabled = not playable
def cell(self, row: int, col: int) -> GameCell: def cell(self, row: int, col: int) -> GameCell:
"""Get the cell at a given location. """Get the cell at a given location.

View File

@@ -128,4 +128,4 @@ def arrange(
placements.extend(layout_placements) placements.extend(layout_placements)
return placements, arrange_widgets, scroll_spacing return DockArrangeResult(placements, arrange_widgets, scroll_spacing)

View File

@@ -167,6 +167,7 @@ class Compositor:
def __init__(self) -> None: def __init__(self) -> None:
# A mapping of Widget on to its "render location" (absolute position / depth) # A mapping of Widget on to its "render location" (absolute position / depth)
self.map: CompositorMap = {} self.map: CompositorMap = {}
self._full_map: CompositorMap | None = None
self._layers: list[tuple[Widget, MapGeometry]] | None = None self._layers: list[tuple[Widget, MapGeometry]] | None = None
# All widgets considered in the arrangement # All widgets considered in the arrangement
@@ -241,29 +242,27 @@ class Compositor:
size: Size of the area to be filled. size: Size of the area to be filled.
Returns: Returns:
Hidden shown and resized widgets. Hidden, shown, and resized widgets.
""" """
self._cuts = None self._cuts = None
self._layers = None self._layers = None
self._layers_visible = None self._layers_visible = None
self._visible_widgets = None self._visible_widgets = None
self._full_map = None
self.root = parent self.root = parent
self.size = size self.size = size
# Keep a copy of the old map because we're going to compare it with the update # Keep a copy of the old map because we're going to compare it with the update
old_map = self.map.copy() old_map = self.map
old_widgets = old_map.keys() old_widgets = old_map.keys()
map, widgets = self._arrange_root(parent, size) map, widgets = self._arrange_root(parent, size)
new_widgets = map.keys()
# Newly visible widgets new_widgets = map.keys()
shown_widgets = new_widgets - old_widgets
# Newly hidden widgets
hidden_widgets = old_widgets - new_widgets
# Replace map and widgets # Replace map and widgets
self.map = map self.map = map
self._full_map = map
self.widgets = widgets self.widgets = widgets
# Contains widgets + geometry for every widget that changed (added, removed, or updated) # Contains widgets + geometry for every widget that changed (added, removed, or updated)
@@ -272,13 +271,7 @@ class Compositor:
# Widgets in both new and old # Widgets in both new and old
common_widgets = old_widgets & new_widgets common_widgets = old_widgets & new_widgets
# Widgets with changed size # Mark dirty regions.
resized_widgets = {
widget
for widget, (region, *_) in changes
if (widget in common_widgets and old_map[widget].region[2:] != region[2:])
}
screen_region = size.region screen_region = size.region
if screen_region not in self._dirty_regions: if screen_region not in self._dirty_regions:
regions = { regions = {
@@ -291,12 +284,80 @@ class Compositor:
} }
self._dirty_regions.update(regions) self._dirty_regions.update(regions)
resized_widgets = {
widget
for widget, (region, *_) in changes
if (widget in common_widgets and old_map[widget].region[2:] != region[2:])
}
# Newly visible widgets
shown_widgets = new_widgets - old_widgets
# Newly hidden widgets
hidden_widgets = self.widgets - widgets
return ReflowResult( return ReflowResult(
hidden=hidden_widgets, hidden=hidden_widgets,
shown=shown_widgets, shown=shown_widgets,
resized=resized_widgets, resized=resized_widgets,
) )
def reflow_visible(self, parent: Widget, size: Size) -> set[Widget]:
"""Reflow only the visible children.
This is a fast-path for scrolling.
Args:
parent: The root widget.
size: Size of the area to be filled.
Returns:
Set of widgets that were exposed by the scroll.
"""
self._cuts = None
self._layers = None
self._layers_visible = None
self._visible_widgets = None
self._full_map = None
self.root = parent
self.size = size
# Keep a copy of the old map because we're going to compare it with the update
old_map = self.map
map, widgets = self._arrange_root(parent, size, visible_only=True)
exposed_widgets = map.keys() - old_map.keys()
# Replace map and widgets
self.map = map
self.widgets = widgets
# Contains widgets + geometry for every widget that changed (added, removed, or updated)
changes = map.items() ^ old_map.items()
# Mark dirty regions.
screen_region = size.region
if screen_region not in self._dirty_regions:
regions = {
region
for region in (
map_geometry.clip.intersection(map_geometry.region)
for _, map_geometry in changes
)
if region
}
self._dirty_regions.update(regions)
return exposed_widgets
@property
def full_map(self) -> CompositorMap:
"""Lazily built compositor map that covers all widgets."""
if self.root is None or not self.map:
return {}
if self._full_map is None:
map, widgets = self._arrange_root(self.root, self.size, visible_only=False)
self._full_map = map
return self._full_map
@property @property
def visible_widgets(self) -> dict[Widget, tuple[Region, Region]]: def visible_widgets(self) -> dict[Widget, tuple[Region, Region]]:
"""Get a mapping of widgets on to region and clip. """Get a mapping of widgets on to region and clip.
@@ -322,9 +383,9 @@ class Compositor:
return self._visible_widgets return self._visible_widgets
def _arrange_root( def _arrange_root(
self, root: Widget, size: Size self, root: Widget, size: Size, visible_only: bool = True
) -> tuple[CompositorMap, set[Widget]]: ) -> tuple[CompositorMap, set[Widget]]:
"""Arrange a widgets children based on its layout attribute. """Arrange a widget's children based on its layout attribute.
Args: Args:
root: Top level widget. root: Top level widget.
@@ -337,6 +398,7 @@ class Compositor:
map: CompositorMap = {} map: CompositorMap = {}
widgets: set[Widget] = set() widgets: set[Widget] = set()
add_new_widget = widgets.add
layer_order: int = 0 layer_order: int = 0
def add_widget( def add_widget(
@@ -362,7 +424,7 @@ class Compositor:
visible = visibility == "visible" visible = visibility == "visible"
if visible: if visible:
widgets.add(widget) add_new_widget(widget)
styles_offset = widget.styles.offset styles_offset = widget.styles.offset
layout_offset = ( layout_offset = (
styles_offset.resolve(region.size, clip.size) styles_offset.resolve(region.size, clip.size)
@@ -389,69 +451,75 @@ class Compositor:
if widget.is_container: if widget.is_container:
# Arrange the layout # Arrange the layout
placements, arranged_widgets, spacing = widget._arrange( arrange_result = widget._arrange(child_region.size)
child_region.size arranged_widgets = arrange_result.widgets
) spacing = arrange_result.spacing
widgets.update(arranged_widgets) widgets.update(arranged_widgets)
if placements: if visible_only:
# An offset added to all placements placements = arrange_result.get_visible_placements(
placement_offset = container_region.offset container_size.region + widget.scroll_offset
placement_scroll_offset = ( )
placement_offset - widget.scroll_offset else:
placements = arrange_result.placements
total_region = total_region.union(arrange_result.total_region)
# An offset added to all placements
placement_offset = container_region.offset
placement_scroll_offset = placement_offset - widget.scroll_offset
_layers = widget.layers
layers_to_index = {
layer_name: index for index, layer_name in enumerate(_layers)
}
get_layer_index = layers_to_index.get
# Add all the widgets
for sub_region, margin, sub_widget, z, fixed in reversed(
placements
):
# Combine regions with children to calculate the "virtual size"
if fixed:
widget_region = sub_region + placement_offset
else:
total_region = total_region.union(
sub_region.grow(spacing + margin)
)
widget_region = sub_region + placement_scroll_offset
widget_order = (
*order,
get_layer_index(sub_widget.layer, 0),
z,
layer_order,
) )
_layers = widget.layers add_widget(
layers_to_index = { sub_widget,
layer_name: index sub_region,
for index, layer_name in enumerate(_layers) widget_region,
} widget_order,
get_layer_index = layers_to_index.get layer_order,
sub_clip,
visible,
)
# Add all the widgets layer_order -= 1
for sub_region, margin, sub_widget, z, fixed in reversed(
placements
):
# Combine regions with children to calculate the "virtual size"
if fixed:
widget_region = sub_region + placement_offset
else:
total_region = total_region.union(
sub_region.grow(spacing + margin)
)
widget_region = sub_region + placement_scroll_offset
widget_order = (
*order,
get_layer_index(sub_widget.layer, 0),
z,
layer_order,
)
add_widget(
sub_widget,
sub_region,
widget_region,
widget_order,
layer_order,
sub_clip,
visible,
)
layer_order -= 1
if visible: if visible:
# Add any scrollbars # Add any scrollbars
for chrome_widget, chrome_region in widget._arrange_scrollbars( if any(widget.scrollbars_enabled):
container_region for chrome_widget, chrome_region in widget._arrange_scrollbars(
): container_region
map[chrome_widget] = _MapGeometry( ):
chrome_region + layout_offset, map[chrome_widget] = _MapGeometry(
order, chrome_region + layout_offset,
clip, order,
container_size, clip,
container_size, container_size,
chrome_region, container_size,
) chrome_region,
)
map[widget] = _MapGeometry( map[widget] = _MapGeometry(
region + layout_offset, region + layout_offset,
@@ -519,7 +587,10 @@ class Compositor:
try: try:
return self.map[widget].region.offset return self.map[widget].region.offset
except KeyError: except KeyError:
raise errors.NoWidget("Widget is not in layout") try:
return self.full_map[widget].region.offset
except KeyError:
raise errors.NoWidget("Widget is not in layout")
def get_widget_at(self, x: int, y: int) -> tuple[Widget, Region]: def get_widget_at(self, x: int, y: int) -> tuple[Widget, Region]:
"""Get the widget under a given coordinate. """Get the widget under a given coordinate.
@@ -601,10 +672,15 @@ class Compositor:
Widget's composition information. Widget's composition information.
""" """
if self.root is None or not self.map:
raise errors.NoWidget("Widget is not in layout")
try: try:
region = self.map[widget] region = self.map[widget]
except KeyError: except KeyError:
raise errors.NoWidget("Widget is not in layout") try:
return self.full_map[widget]
except KeyError:
raise errors.NoWidget("Widget is not in layout")
else: else:
return region return region
@@ -788,6 +864,7 @@ class Compositor:
widget: Widget to update. widget: Widget to update.
""" """
self._full_map = None
regions: list[Region] = [] regions: list[Region] = []
add_region = regions.append add_region = regions.append
get_widget = self.visible_widgets.__getitem__ get_widget = self.visible_widgets.__getitem__

View File

@@ -1,8 +1,10 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, NamedTuple from typing import TYPE_CHECKING, ClassVar, NamedTuple
from ._spatial_map import SpatialMap
from .geometry import Region, Size, Spacing from .geometry import Region, Size, Spacing
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -11,7 +13,55 @@ if TYPE_CHECKING:
from .widget import Widget from .widget import Widget
ArrangeResult: TypeAlias = "tuple[list[WidgetPlacement], set[Widget]]" ArrangeResult: TypeAlias = "tuple[list[WidgetPlacement], set[Widget]]"
DockArrangeResult: TypeAlias = "tuple[list[WidgetPlacement], set[Widget], Spacing]"
@dataclass
class DockArrangeResult:
placements: list[WidgetPlacement]
"""A `WidgetPlacement` for every widget to describe it's location on screen."""
widgets: set[Widget]
"""A set of widgets in the arrangement."""
spacing: Spacing
"""Shared spacing around the widgets."""
_spatial_map: SpatialMap[WidgetPlacement] | None = None
@property
def spatial_map(self) -> SpatialMap[WidgetPlacement]:
"""A lazy-calculated spatial map."""
if self._spatial_map is None:
self._spatial_map = SpatialMap()
self._spatial_map.insert(
(
placement.region.grow(placement.margin),
placement.fixed,
placement,
)
for placement in self.placements
)
return self._spatial_map
@property
def total_region(self) -> Region:
"""The total area occupied by the arrangement.
Returns:
A Region.
"""
return self.spatial_map.total_region
def get_visible_placements(self, region: Region) -> list[WidgetPlacement]:
"""Get the placements visible within the given region.
Args:
region: A region.
Returns:
Set of placements.
"""
visible_placements = self.spatial_map.get_values_in_region(region)
return visible_placements
class WidgetPlacement(NamedTuple): class WidgetPlacement(NamedTuple):
@@ -61,7 +111,7 @@ class Layout(ABC):
width = 0 width = 0
else: else:
# Use a size of 0, 0 to ignore relative sizes, since those are flexible anyway # Use a size of 0, 0 to ignore relative sizes, since those are flexible anyway
placements, _, _ = widget._arrange(Size(0, 0)) placements = widget._arrange(Size(0, 0)).placements
width = max( width = max(
[ [
placement.region.right + placement.margin.right placement.region.right + placement.margin.right
@@ -89,7 +139,7 @@ class Layout(ABC):
height = 0 height = 0
else: else:
# Use a height of zero to ignore relative heights # Use a height of zero to ignore relative heights
placements, _, _ = widget._arrange(Size(width, 0)) placements = widget._arrange(Size(width, 0)).placements
height = max( height = max(
[ [
placement.region.bottom + placement.margin.bottom placement.region.bottom + placement.margin.bottom

103
src/textual/_spatial_map.py Normal file
View File

@@ -0,0 +1,103 @@
from __future__ import annotations
from collections import defaultdict
from itertools import product
from typing import Generic, Iterable, TypeVar
from typing_extensions import TypeAlias
from .geometry import Region
ValueType = TypeVar("ValueType")
GridCoordinate: TypeAlias = "tuple[int, int]"
class SpatialMap(Generic[ValueType]):
"""A spatial map allows for data to be associated with rectangular regions
in Euclidean space, and efficiently queried.
When the SpatialMap is populated, a reference to each value is placed into one or
more buckets associated with a regular grid that covers 2D space.
The SpatialMap is able to quickly retrieve the values under a given "window" region
by combining the values in the grid squares under the visible area.
"""
def __init__(self, grid_width: int = 100, grid_height: int = 20) -> None:
"""Create a spatial map with the given grid size.
Args:
grid_width: Width of a grid square.
grid_height: Height of a grid square.
"""
self._grid_size = (grid_width, grid_height)
self.total_region = Region()
self._map: defaultdict[GridCoordinate, list[ValueType]] = defaultdict(list)
self._fixed: list[ValueType] = []
def _region_to_grid_coordinates(self, region: Region) -> Iterable[GridCoordinate]:
"""Get the grid squares under a region.
Args:
region: A region.
Returns:
Iterable of grid coordinates (tuple of 2 values).
"""
# (x1, y1) is the coordinate of the top left cell
# (x2, y2) is the coordinate of the bottom right cell
x1, y1, width, height = region
x2 = x1 + width - 1
y2 = y1 + height - 1
grid_width, grid_height = self._grid_size
return product(
range(x1 // grid_width, x2 // grid_width + 1),
range(y1 // grid_height, y2 // grid_height + 1),
)
def insert(
self, regions_and_values: Iterable[tuple[Region, bool, ValueType]]
) -> None:
"""Insert values into the Spatial map.
Values are associated with their region in Euclidean space, and a boolean that
indicates fixed regions. Fixed regions don't scroll and are always visible.
Args:
regions_and_values: An iterable of (REGION, FIXED, VALUE).
"""
append_fixed = self._fixed.append
get_grid_list = self._map.__getitem__
_region_to_grid = self._region_to_grid_coordinates
total_region = self.total_region
for region, fixed, value in regions_and_values:
total_region = total_region.union(region)
if fixed:
append_fixed(value)
else:
for grid in _region_to_grid(region):
get_grid_list(grid).append(value)
self.total_region = total_region
def get_values_in_region(self, region: Region) -> list[ValueType]:
"""Get a superset of all the values that intersect with a given region.
Note that this may return false positives.
Args:
region: A region.
Returns:
Values under the region.
"""
results: list[ValueType] = self._fixed.copy()
add_results = results.extend
get_grid_values = self._map.get
for grid_coordinate in self._region_to_grid_coordinates(region):
grid_values = get_grid_values(grid_coordinate)
if grid_values is not None:
add_results(grid_values)
unique_values = list(dict.fromkeys(results))
return unique_values

View File

@@ -11,7 +11,12 @@ import unicodedata
import warnings import warnings
from asyncio import Task from asyncio import Task
from concurrent.futures import Future from concurrent.futures import Future
from contextlib import asynccontextmanager, redirect_stderr, redirect_stdout from contextlib import (
asynccontextmanager,
contextmanager,
redirect_stderr,
redirect_stdout,
)
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
from pathlib import Path, PurePath from pathlib import Path, PurePath
@@ -22,6 +27,7 @@ from typing import (
Any, Any,
Awaitable, Awaitable,
Callable, Callable,
Generator,
Generic, Generic,
Iterable, Iterable,
List, List,
@@ -242,6 +248,11 @@ class App(Generic[ReturnType], DOMNode):
background: $background; background: $background;
color: $text; color: $text;
} }
*:disabled {
opacity: 0.6;
text-opacity: 0.8;
}
""" """
SCREENS: dict[str, Screen | Callable[[], Screen]] = {} SCREENS: dict[str, Screen | Callable[[], Screen]] = {}
@@ -415,6 +426,7 @@ class App(Generic[ReturnType], DOMNode):
self._screenshot: str | None = None self._screenshot: str | None = None
self._dom_lock = asyncio.Lock() self._dom_lock = asyncio.Lock()
self._dom_ready = False self._dom_ready = False
self._batch_count = 0
self.set_class(self.dark, "-dark-mode") self.set_class(self.dark, "-dark-mode")
@property @property
@@ -430,6 +442,30 @@ class App(Generic[ReturnType], DOMNode):
except ScreenError: except ScreenError:
return () return ()
@contextmanager
def batch_update(self) -> Generator[None, None, None]:
"""Suspend all repaints until the end of the batch."""
self._begin_batch()
try:
yield
finally:
self._end_batch()
def _begin_batch(self) -> None:
"""Begin a batch update."""
self._batch_count += 1
def _end_batch(self) -> None:
"""End a batch update."""
self._batch_count -= 1
assert self._batch_count >= 0, "This won't happen if you use `batch_update`"
if not self._batch_count:
try:
self.screen.check_idle()
except ScreenStackError:
pass
self.check_idle()
def animate( def animate(
self, self,
attribute: str, attribute: str,
@@ -1508,28 +1544,29 @@ class App(Generic[ReturnType], DOMNode):
if inspect.isawaitable(ready_result): if inspect.isawaitable(ready_result):
await ready_result await ready_result
try: with self.batch_update():
try: try:
await self._dispatch_message(events.Compose(sender=self)) try:
await self._dispatch_message(events.Mount(sender=self)) await self._dispatch_message(events.Compose(sender=self))
await self._dispatch_message(events.Mount(sender=self))
finally:
self._mounted_event.set()
Reactive._initialize_object(self)
self.stylesheet.update(self)
self.refresh()
await self.animator.start()
except Exception:
await self.animator.stop()
raise
finally: finally:
self._mounted_event.set() self._running = True
await self._ready()
Reactive._initialize_object(self) await invoke_ready_callback()
self.stylesheet.update(self)
self.refresh()
await self.animator.start()
except Exception:
await self.animator.stop()
raise
finally:
self._running = True
await self._ready()
await invoke_ready_callback()
try: try:
await self._process_messages_loop() await self._process_messages_loop()
@@ -1615,11 +1652,12 @@ class App(Generic[ReturnType], DOMNode):
raise TypeError( raise TypeError(
f"{self!r} compose() returned an invalid response; {error}" f"{self!r} compose() returned an invalid response; {error}"
) from error ) from error
await self.mount_all(widgets) await self.mount_all(widgets)
def _on_idle(self) -> None: def _on_idle(self) -> None:
"""Perform actions when there are no messages in the queue.""" """Perform actions when there are no messages in the queue."""
if self._require_stylesheet_update: if self._require_stylesheet_update and not self._batch_count:
nodes: set[DOMNode] = { nodes: set[DOMNode] = {
child child
for node in self._require_stylesheet_update for node in self._require_stylesheet_update
@@ -1782,6 +1820,7 @@ class App(Generic[ReturnType], DOMNode):
await child._close_messages() await child._close_messages()
async def _shutdown(self) -> None: async def _shutdown(self) -> None:
self._begin_update() # Prevents any layout / repaint while shutting down
driver = self._driver driver = self._driver
self._running = False self._running = False
if driver is not None: if driver is not None:
@@ -1799,6 +1838,7 @@ class App(Generic[ReturnType], DOMNode):
self._writer_thread.stop() self._writer_thread.stop()
async def _on_exit_app(self) -> None: async def _on_exit_app(self) -> None:
self._begin_batch() # Prevent repaint / layout while shutting down
await self._message_queue.put(None) await self._message_queue.put(None)
def refresh(self, *, repaint: bool = True, layout: bool = False) -> None: def refresh(self, *, repaint: bool = True, layout: bool = False) -> None:
@@ -1907,7 +1947,6 @@ class App(Generic[ReturnType], DOMNode):
# Handle input events that haven't been forwarded # Handle input events that haven't been forwarded
# If the event has been forwarded it may have bubbled up back to the App # If the event has been forwarded it may have bubbled up back to the App
if isinstance(event, events.Compose): if isinstance(event, events.Compose):
self.log(event)
screen = Screen(id="_default") screen = Screen(id="_default")
self._register(self, screen) self._register(self, screen)
self._screen_stack.append(screen) self._screen_stack.append(screen)

View File

@@ -68,7 +68,6 @@ class ColorsApp(App):
content.mount(ColorsView()) content.mount(ColorsView())
def on_button_pressed(self, event: Button.Pressed) -> None: def on_button_pressed(self, event: Button.Pressed) -> None:
self.bell()
self.query(ColorGroup).remove_class("-active") self.query(ColorGroup).remove_class("-active")
group = self.query_one(f"#group-{event.button.id}", ColorGroup) group = self.query_one(f"#group-{event.button.id}", ColorGroup)
group.add_class("-active") group.add_class("-active")

View File

@@ -75,7 +75,7 @@ expect_selector_continue = Expect(
selector_id=r"\#[a-zA-Z_\-][a-zA-Z0-9_\-]*", selector_id=r"\#[a-zA-Z_\-][a-zA-Z0-9_\-]*",
selector_class=r"\.[a-zA-Z_\-][a-zA-Z0-9_\-]*", selector_class=r"\.[a-zA-Z_\-][a-zA-Z0-9_\-]*",
selector_universal=r"\*", selector_universal=r"\*",
selector=r"[a-zA-Z_\-]+", selector=IDENTIFIER,
combinator_child=">", combinator_child=">",
new_selector=r",", new_selector=r",",
declaration_set_start=r"\{", declaration_set_start=r"\{",

View File

@@ -872,6 +872,16 @@ class DOMNode(MessagePump):
else: else:
self.remove_class(*class_names) self.remove_class(*class_names)
def _update_styles(self) -> None:
"""Request an update of this node's styles.
Should be called whenever CSS classes / pseudo classes change.
"""
try:
self.app.update_styles(self)
except NoActiveAppError:
pass
def add_class(self, *class_names: str) -> None: def add_class(self, *class_names: str) -> None:
"""Add class names to this Node. """Add class names to this Node.
@@ -884,10 +894,7 @@ class DOMNode(MessagePump):
self._classes.update(class_names) self._classes.update(class_names)
if old_classes == self._classes: if old_classes == self._classes:
return return
try: self._update_styles()
self.app.update_styles(self)
except NoActiveAppError:
pass
def remove_class(self, *class_names: str) -> None: def remove_class(self, *class_names: str) -> None:
"""Remove class names from this Node. """Remove class names from this Node.
@@ -900,10 +907,7 @@ class DOMNode(MessagePump):
self._classes.difference_update(class_names) self._classes.difference_update(class_names)
if old_classes == self._classes: if old_classes == self._classes:
return return
try: self._update_styles()
self.app.update_styles(self)
except NoActiveAppError:
pass
def toggle_class(self, *class_names: str) -> None: def toggle_class(self, *class_names: str) -> None:
"""Toggle class names on this Node. """Toggle class names on this Node.
@@ -916,10 +920,7 @@ class DOMNode(MessagePump):
self._classes.symmetric_difference_update(class_names) self._classes.symmetric_difference_update(class_names)
if old_classes == self._classes: if old_classes == self._classes:
return return
try: self._update_styles()
self.app.update_styles(self)
except NoActiveAppError:
pass
def has_pseudo_class(self, *class_names: str) -> bool: def has_pseudo_class(self, *class_names: str) -> bool:
"""Check for pseudo classes (such as hover, focus etc) """Check for pseudo classes (such as hover, focus etc)

View File

@@ -45,12 +45,24 @@ class Update(Message, verbose=True):
@rich.repr.auto @rich.repr.auto
class Layout(Message, verbose=True): class Layout(Message, verbose=True):
"""Sent by Textual when a layout is required."""
def can_replace(self, message: Message) -> bool: def can_replace(self, message: Message) -> bool:
return isinstance(message, Layout) return isinstance(message, Layout)
@rich.repr.auto
class UpdateScroll(Message, verbose=True):
"""Sent by Textual when a scroll update is required."""
def can_replace(self, message: Message) -> bool:
return isinstance(message, UpdateScroll)
@rich.repr.auto @rich.repr.auto
class InvokeLater(Message, verbose=True, bubble=False): class InvokeLater(Message, verbose=True, bubble=False):
"""Sent by Textual to invoke a callback."""
def __init__(self, sender: MessagePump, callback: CallbackType) -> None: def __init__(self, sender: MessagePump, callback: CallbackType) -> None:
self.callback = callback self.callback = callback
super().__init__(sender) super().__init__(sender)

View File

@@ -143,24 +143,25 @@ class Reactive(Generic[ReactiveType]):
self.name = name self.name = name
# The internal name where the attribute's value is stored # The internal name where the attribute's value is stored
self.internal_name = f"_reactive_{name}" self.internal_name = f"_reactive_{name}"
self.compute_name = f"compute_{name}"
default = self._default default = self._default
setattr(owner, f"_default_{name}", default) setattr(owner, f"_default_{name}", default)
def __get__(self, obj: Reactable, obj_type: type[object]) -> ReactiveType: def __get__(self, obj: Reactable, obj_type: type[object]) -> ReactiveType:
_rich_traceback_omit = True internal_name = self.internal_name
if not hasattr(obj, internal_name):
self._initialize_reactive(obj, self.name)
self._initialize_reactive(obj, self.name) if hasattr(obj, self.compute_name):
value: ReactiveType
value: ReactiveType old_value = getattr(obj, internal_name)
compute_method = getattr(self, f"compute_{self.name}", None) _rich_traceback_omit = True
if compute_method is not None: value = getattr(obj, self.compute_name)()
old_value = getattr(obj, self.internal_name) setattr(obj, internal_name, value)
value = getattr(obj, f"compute_{self.name}")()
setattr(obj, self.internal_name, value)
self._check_watchers(obj, self.name, old_value) self._check_watchers(obj, self.name, old_value)
return value
else: else:
value = getattr(obj, self.internal_name) return getattr(obj, internal_name)
return value
def __set__(self, obj: Reactable, value: ReactiveType) -> None: def __set__(self, obj: Reactable, value: ReactiveType) -> None:
_rich_traceback_omit = True _rich_traceback_omit = True

View File

@@ -80,16 +80,6 @@ class Screen(Widget):
) )
return self._update_timer return self._update_timer
@property
def widgets(self) -> list[Widget]:
"""Get all widgets."""
return list(self._compositor.map.keys())
@property
def visible_widgets(self) -> list[Widget]:
"""Get a list of visible widgets."""
return list(self._compositor.visible_widgets)
def render(self) -> RenderableType: def render(self) -> RenderableType:
background = self.styles.background background = self.styles.background
if background.is_transparent: if background.is_transparent:
@@ -159,11 +149,7 @@ class Screen(Widget):
@property @property
def focus_chain(self) -> list[Widget]: def focus_chain(self) -> list[Widget]:
"""Get widgets that may receive focus, in focus order. """A list of widgets that may receive focus, in focus order."""
Returns:
List of Widgets in focus order.
"""
widgets: list[Widget] = [] widgets: list[Widget] = []
add_widget = widgets.append add_widget = widgets.append
stack: list[Iterator[Widget]] = [iter(self.focusable_children)] stack: list[Iterator[Widget]] = [iter(self.focusable_children)]
@@ -177,7 +163,7 @@ class Screen(Widget):
else: else:
if node.is_container and node.can_focus_children: if node.is_container and node.can_focus_children:
push(iter(node.focusable_children)) push(iter(node.focusable_children))
if node.can_focus: if node.focusable:
add_widget(node) add_widget(node)
return widgets return widgets
@@ -314,7 +300,7 @@ class Screen(Widget):
# It may have been made invisible # It may have been made invisible
# Move to a sibling if possible # Move to a sibling if possible
for sibling in widget.visible_siblings: for sibling in widget.visible_siblings:
if sibling not in avoiding and sibling.can_focus: if sibling not in avoiding and sibling.focusable:
self.set_focus(sibling) self.set_focus(sibling)
break break
else: else:
@@ -351,7 +337,7 @@ class Screen(Widget):
self.focused.post_message_no_wait(events.Blur(self)) self.focused.post_message_no_wait(events.Blur(self))
self.focused = None self.focused = None
self.log.debug("focus was removed") self.log.debug("focus was removed")
elif widget.can_focus: elif widget.focusable:
if self.focused != widget: if self.focused != widget:
if self.focused is not None: if self.focused is not None:
# Blur currently focused widget # Blur currently focused widget
@@ -368,13 +354,18 @@ class Screen(Widget):
# Check for any widgets marked as 'dirty' (needs a repaint) # Check for any widgets marked as 'dirty' (needs a repaint)
event.prevent_default() event.prevent_default()
if self.is_current: if not self.app._batch_count and self.is_current:
async with self.app._dom_lock: async with self.app._dom_lock:
if self.is_current: if self.is_current:
if self._layout_required: if self._layout_required:
self._refresh_layout() self._refresh_layout()
self._layout_required = False self._layout_required = False
self._scroll_required = False
self._dirty_widgets.clear() self._dirty_widgets.clear()
elif self._scroll_required:
self._refresh_layout(scroll=True)
self._scroll_required = False
if self._repaint_required: if self._repaint_required:
self._dirty_widgets.clear() self._dirty_widgets.clear()
self._dirty_widgets.add(self) self._dirty_widgets.add(self)
@@ -423,7 +414,9 @@ class Screen(Widget):
self._callbacks.append(callback) self._callbacks.append(callback)
self.check_idle() self.check_idle()
def _refresh_layout(self, size: Size | None = None, full: bool = False) -> None: def _refresh_layout(
self, size: Size | None = None, full: bool = False, scroll: bool = False
) -> None:
"""Refresh the layout (can change size and positions of widgets).""" """Refresh the layout (can change size and positions of widgets)."""
size = self.outer_size if size is None else size size = self.outer_size if size is None else size
if not size: if not size:
@@ -431,35 +424,64 @@ class Screen(Widget):
self._compositor.update_widgets(self._dirty_widgets) self._compositor.update_widgets(self._dirty_widgets)
self.update_timer.pause() self.update_timer.pause()
ResizeEvent = events.Resize
try: try:
hidden, shown, resized = self._compositor.reflow(self, size) if scroll:
Hide = events.Hide exposed_widgets = self._compositor.reflow_visible(self, size)
Show = events.Show if exposed_widgets:
layers = self._compositor.layers
for widget in hidden: for widget, (
widget.post_message_no_wait(Hide(self)) region,
_order,
_clip,
virtual_size,
container_size,
_,
) in layers:
if widget in exposed_widgets:
if widget._size_updated(
region.size,
virtual_size,
container_size,
layout=False,
):
widget.post_message_no_wait(
ResizeEvent(
self,
region.size,
virtual_size,
container_size,
)
)
else:
hidden, shown, resized = self._compositor.reflow(self, size)
Hide = events.Hide
Show = events.Show
# We want to send a resize event to widgets that were just added or change since last layout for widget in hidden:
send_resize = shown | resized widget.post_message_no_wait(Hide(self))
ResizeEvent = events.Resize
layers = self._compositor.layers # We want to send a resize event to widgets that were just added or change since last layout
for widget, ( send_resize = shown | resized
region,
_order,
_clip,
virtual_size,
container_size,
_,
) in layers:
widget._size_updated(region.size, virtual_size, container_size)
if widget in send_resize:
widget.post_message_no_wait(
ResizeEvent(self, region.size, virtual_size, container_size)
)
for widget in shown: layers = self._compositor.layers
widget.post_message_no_wait(Show(self)) for widget, (
region,
_order,
_clip,
virtual_size,
container_size,
_,
) in layers:
widget._size_updated(region.size, virtual_size, container_size)
if widget in send_resize:
widget.post_message_no_wait(
ResizeEvent(self, region.size, virtual_size, container_size)
)
for widget in shown:
widget.post_message_no_wait(Show(self))
except Exception as error: except Exception as error:
self.app._handle_exception(error) self.app._handle_exception(error)
@@ -484,6 +506,12 @@ class Screen(Widget):
self._layout_required = True self._layout_required = True
self.check_idle() self.check_idle()
async def _on_update_scroll(self, message: messages.UpdateScroll) -> None:
message.stop()
message.prevent_default()
self._scroll_required = True
self.check_idle()
def _screen_resized(self, size: Size): def _screen_resized(self, size: Size):
"""Called by App when the screen is resized.""" """Called by App when the screen is resized."""
self._refresh_layout(size, full=True) self._refresh_layout(size, full=True)
@@ -547,7 +575,7 @@ class Screen(Widget):
except errors.NoWidget: except errors.NoWidget:
self.set_focus(None) self.set_focus(None)
else: else:
if isinstance(event, events.MouseUp) and widget.can_focus: if isinstance(event, events.MouseUp) and widget.focusable:
if self.focused is not widget: if self.focused is not widget:
self.set_focus(widget) self.set_focus(widget)
event.stop() event.stop()

View File

@@ -69,14 +69,18 @@ class ScrollView(Widget):
return self.virtual_size.height return self.virtual_size.height
def _size_updated( def _size_updated(
self, size: Size, virtual_size: Size, container_size: Size self, size: Size, virtual_size: Size, container_size: Size, layout: bool = True
) -> None: ) -> bool:
"""Called when size is updated. """Called when size is updated.
Args: Args:
size: New size. size: New size.
virtual_size: New virtual size. virtual_size: New virtual size.
container_size: New container size. container_size: New container size.
layout: Perform layout if required.
Returns:
True if anything changed, or False if nothing changed.
""" """
if self._size != size or container_size != container_size: if self._size != size or container_size != container_size:
self.refresh() self.refresh()
@@ -90,6 +94,9 @@ class ScrollView(Widget):
self._container_size = size - self.styles.gutter.totals self._container_size = size - self.styles.gutter.totals
self._scroll_update(virtual_size) self._scroll_update(virtual_size)
self.scroll_to(self.scroll_x, self.scroll_y, animate=False) self.scroll_to(self.scroll_x, self.scroll_y, animate=False)
return True
else:
return False
def render(self) -> RenderableType: def render(self) -> RenderableType:
"""Render the scrollable region (if `render_lines` is not implemented). """Render the scrollable region (if `render_lines` is not implemented).

View File

@@ -112,8 +112,14 @@ class ScrollBarRender:
if window_size and size and virtual_size and size != virtual_size: if window_size and size and virtual_size and size != virtual_size:
step_size = virtual_size / size step_size = virtual_size / size
thumb_size = window_size / step_size * len_bars
if thumb_size < len_bars:
virtual_size += step_size
step_size = virtual_size / size
start = int(position / step_size * len_bars) start = int(position / step_size * len_bars)
end = start + max(len_bars, int(ceil(window_size / step_size * len_bars))) end = start + max(len_bars, ceil(thumb_size))
start_index, start_bar = divmod(max(0, start), len_bars) start_index, start_bar = divmod(max(0, start), len_bars)
end_index, end_bar = divmod(max(0, end), len_bars) end_index, end_bar = divmod(max(0, end), len_bars)
@@ -246,6 +252,7 @@ class ScrollBar(Widget):
yield "thickness", self.thickness yield "thickness", self.thickness
def render(self) -> RenderableType: def render(self) -> RenderableType:
assert self.parent is not None
styles = self.parent.styles styles = self.parent.styles
if self.grabbed: if self.grabbed:
background = styles.scrollbar_background_active background = styles.scrollbar_background_active
@@ -258,11 +265,25 @@ class ScrollBar(Widget):
color = styles.scrollbar_color color = styles.scrollbar_color
color = background + color color = background + color
scrollbar_style = Style.from_color(color.rich_color, background.rich_color) scrollbar_style = Style.from_color(color.rich_color, background.rich_color)
return self._render_bar(scrollbar_style)
def _render_bar(self, scrollbar_style: Style) -> RenderableType:
"""Get a renderable for the scrollbar with given style.
Args:
scrollbar_style: Scrollbar style.
Returns:
Scrollbar renderable.
"""
window_size = (
self.window_size if self.window_size < self.window_virtual_size else 0
)
virtual_size = self.window_virtual_size
return self.renderer( return self.renderer(
virtual_size=self.window_virtual_size, virtual_size=ceil(virtual_size),
window_size=( window_size=ceil(window_size),
self.window_size if self.window_size < self.window_virtual_size else 0
),
position=self.position, position=self.position,
thickness=self.thickness, thickness=self.thickness,
vertical=self.vertical, vertical=self.vertical,
@@ -311,19 +332,31 @@ class ScrollBar(Widget):
x: float | None = None x: float | None = None
y: float | None = None y: float | None = None
if self.vertical: if self.vertical:
size = self.size.height
virtual_size = self.window_virtual_size
step_size = virtual_size / size
thumb_size = self.window_size / step_size
if thumb_size < 1:
virtual_size = ceil(virtual_size + step_size)
y = round( y = round(
self.grabbed_position self.grabbed_position
+ ( + (
(event.screen_y - self.grabbed.y) (event.screen_y - self.grabbed.y)
* (self.window_virtual_size / self.window_size) * (virtual_size / self.window_size)
) )
) )
else: else:
size = self.size.width
virtual_size = self.window_virtual_size
step_size = virtual_size / size
thumb_size = self.window_size / step_size
if thumb_size < 1:
virtual_size = ceil(virtual_size + step_size)
x = round( x = round(
self.grabbed_position self.grabbed_position
+ ( + (
(event.screen_x - self.grabbed.x) (event.screen_x - self.grabbed.x)
* (self.window_virtual_size / self.window_size) * (virtual_size / self.window_size)
) )
) )
await self.post_message(ScrollTo(self, x=x, y=y)) await self.post_message(ScrollTo(self, x=x, y=y))

View File

@@ -41,6 +41,7 @@ from ._animator import DEFAULT_EASING, Animatable, BoundAnimator, EasingFunction
from ._arrange import DockArrangeResult, arrange from ._arrange import DockArrangeResult, arrange
from ._asyncio import create_task from ._asyncio import create_task
from ._compose import compose from ._compose import compose
from ._cache import FIFOCache
from ._context import active_app from ._context import active_app
from ._easing import DEFAULT_SCROLL_EASING from ._easing import DEFAULT_SCROLL_EASING
from ._layout import Layout from ._layout import Layout
@@ -228,6 +229,8 @@ class Widget(DOMNode):
"""Rich renderable may shrink.""" """Rich renderable may shrink."""
auto_links = Reactive(True) auto_links = Reactive(True)
"""Widget will highlight links automatically.""" """Widget will highlight links automatically."""
disabled = Reactive(False)
"""The disabled state of the widget. `True` if disabled, `False` if not."""
hover_style: Reactive[Style] = Reactive(Style, repaint=False) hover_style: Reactive[Style] = Reactive(Style, repaint=False)
highlight_link_id: Reactive[str] = Reactive("") highlight_link_id: Reactive[str] = Reactive("")
@@ -238,11 +241,13 @@ class Widget(DOMNode):
name: str | None = None, name: str | None = None,
id: str | None = None, id: str | None = None,
classes: str | None = None, classes: str | None = None,
disabled: bool = False,
) -> None: ) -> None:
self._size = Size(0, 0) self._size = Size(0, 0)
self._container_size = Size(0, 0) self._container_size = Size(0, 0)
self._layout_required = False self._layout_required = False
self._repaint_required = False self._repaint_required = False
self._scroll_required = False
self._default_layout = VerticalLayout() self._default_layout = VerticalLayout()
self._animate: BoundAnimator | None = None self._animate: BoundAnimator | None = None
self.highlight_style: Style | None = None self.highlight_style: Style | None = None
@@ -262,8 +267,9 @@ class Widget(DOMNode):
self._content_width_cache: tuple[object, int] = (None, 0) self._content_width_cache: tuple[object, int] = (None, 0)
self._content_height_cache: tuple[object, int] = (None, 0) self._content_height_cache: tuple[object, int] = (None, 0)
self._arrangement_cache_key: tuple[Size, int] = (Size(), -1) self._arrangement_cache: FIFOCache[
self._cached_arrangement: DockArrangeResult | None = None tuple[Size, int], DockArrangeResult
] = FIFOCache(4)
self._styles_cache = StylesCache() self._styles_cache = StylesCache()
self._rich_style_cache: dict[str, tuple[Style, Style]] = {} self._rich_style_cache: dict[str, tuple[Style, Style]] = {}
@@ -280,6 +286,7 @@ class Widget(DOMNode):
raise WidgetError("A widget can't be its own parent") raise WidgetError("A widget can't be its own parent")
self._add_children(*children) self._add_children(*children)
self.disabled = disabled
virtual_size = Reactive(Size(0, 0), layout=True) virtual_size = Reactive(Size(0, 0), layout=True)
auto_width = Reactive(True) auto_width = Reactive(True)
@@ -495,14 +502,11 @@ class Widget(DOMNode):
assert self.is_container assert self.is_container
cache_key = (size, self._nodes._updates) cache_key = (size, self._nodes._updates)
if ( cached_result = self._arrangement_cache.get(cache_key)
self._arrangement_cache_key == cache_key if cached_result is not None:
and self._cached_arrangement is not None return cached_result
):
return self._cached_arrangement
self._arrangement_cache_key = cache_key arrangement = self._arrangement_cache[cache_key] = arrange(
arrangement = self._cached_arrangement = arrange(
self, self._nodes, size, self.screen.size self, self._nodes, size, self.screen.size
) )
@@ -510,7 +514,7 @@ class Widget(DOMNode):
def _clear_arrangement_cache(self) -> None: def _clear_arrangement_cache(self) -> None:
"""Clear arrangement cache, forcing a new arrange operation.""" """Clear arrangement cache, forcing a new arrange operation."""
self._cached_arrangement = None self._arrangement_cache.clear()
def _get_virtual_dom(self) -> Iterable[Widget]: def _get_virtual_dom(self) -> Iterable[Widget]:
"""Get widgets not part of the DOM. """Get widgets not part of the DOM.
@@ -1195,6 +1199,20 @@ class Widget(DOMNode):
""" """
return self.virtual_region.grow(self.styles.margin) return self.virtual_region.grow(self.styles.margin)
@property
def _self_or_ancestors_disabled(self) -> bool:
"""Is this widget or any of its ancestors disabled?"""
return any(
node.disabled
for node in self.ancestors_with_self
if isinstance(node, Widget)
)
@property
def focusable(self) -> bool:
"""Can this widget currently receive focus?"""
return self.can_focus and not self._self_or_ancestors_disabled
@property @property
def focusable_children(self) -> list[Widget]: def focusable_children(self) -> list[Widget]:
"""Get the children which may be focused. """Get the children which may be focused.
@@ -1732,7 +1750,7 @@ class Widget(DOMNode):
""" """
return self.scroll_to( return self.scroll_to(
y=self.scroll_target_y - self.container_size.height, y=self.scroll_y - self.container_size.height,
animate=animate, animate=animate,
speed=speed, speed=speed,
duration=duration, duration=duration,
@@ -1764,7 +1782,7 @@ class Widget(DOMNode):
""" """
return self.scroll_to( return self.scroll_to(
y=self.scroll_target_y + self.container_size.height, y=self.scroll_y + self.container_size.height,
animate=animate, animate=animate,
speed=speed, speed=speed,
duration=duration, duration=duration,
@@ -1798,7 +1816,7 @@ class Widget(DOMNode):
if speed is None and duration is None: if speed is None and duration is None:
duration = 0.3 duration = 0.3
return self.scroll_to( return self.scroll_to(
x=self.scroll_target_x - self.container_size.width, x=self.scroll_x - self.container_size.width,
animate=animate, animate=animate,
speed=speed, speed=speed,
duration=duration, duration=duration,
@@ -1832,7 +1850,7 @@ class Widget(DOMNode):
if speed is None and duration is None: if speed is None and duration is None:
duration = 0.3 duration = 0.3
return self.scroll_to( return self.scroll_to(
x=self.scroll_target_x + self.container_size.width, x=self.scroll_x + self.container_size.width,
animate=animate, animate=animate,
speed=speed, speed=speed,
duration=duration, duration=duration,
@@ -2102,6 +2120,14 @@ class Widget(DOMNode):
Names of the pseudo classes. Names of the pseudo classes.
""" """
node = self
while isinstance(node, Widget):
if node.disabled:
yield "disabled"
break
node = node._parent
else:
yield "enabled"
if self.mouse_over: if self.mouse_over:
yield "hover" yield "hover"
if self.has_focus: if self.has_focus:
@@ -2149,21 +2175,29 @@ class Widget(DOMNode):
def watch_mouse_over(self, value: bool) -> None: def watch_mouse_over(self, value: bool) -> None:
"""Update from CSS if mouse over state changes.""" """Update from CSS if mouse over state changes."""
if self._has_hover_style: if self._has_hover_style:
self.app.update_styles(self) self._update_styles()
def watch_has_focus(self, value: bool) -> None: def watch_has_focus(self, value: bool) -> None:
"""Update from CSS if has focus state changes.""" """Update from CSS if has focus state changes."""
self.app.update_styles(self) self._update_styles()
def watch_disabled(self) -> None:
"""Update the styles of the widget and its children when disabled is toggled."""
self._update_styles()
def _size_updated( def _size_updated(
self, size: Size, virtual_size: Size, container_size: Size self, size: Size, virtual_size: Size, container_size: Size, layout: bool = True
) -> None: ) -> bool:
"""Called when the widget's size is updated. """Called when the widget's size is updated.
Args: Args:
size: Screen size. size: Screen size.
virtual_size: Virtual (scrollable) size. virtual_size: Virtual (scrollable) size.
container_size: Container size (size of parent). container_size: Container size (size of parent).
layout: Perform layout if required.
Returns:
True if anything changed, or False if nothing changed.
""" """
if ( if (
self._size != size self._size != size
@@ -2171,11 +2205,16 @@ class Widget(DOMNode):
or self._container_size != container_size or self._container_size != container_size
): ):
self._size = size self._size = size
self.virtual_size = virtual_size if layout:
self.virtual_size = virtual_size
else:
self._reactive_virtual_size = virtual_size
self._container_size = container_size self._container_size = container_size
if self.is_scrollable: if self.is_scrollable:
self._scroll_update(virtual_size) self._scroll_update(virtual_size)
self.refresh() return True
else:
return False
def _scroll_update(self, virtual_size: Size) -> None: def _scroll_update(self, virtual_size: Size) -> None:
"""Update scrollbars visibility and dimensions. """Update scrollbars visibility and dimensions.
@@ -2286,7 +2325,7 @@ class Widget(DOMNode):
def _refresh_scroll(self) -> None: def _refresh_scroll(self) -> None:
"""Refreshes the scroll position.""" """Refreshes the scroll position."""
self._layout_required = True self._scroll_required = True
self.check_idle() self.check_idle()
def refresh( def refresh(
@@ -2313,8 +2352,7 @@ class Widget(DOMNode):
repaint: Repaint the widget (will call render() again). Defaults to True. repaint: Repaint the widget (will call render() again). Defaults to True.
layout: Also layout widgets in the view. Defaults to False. layout: Also layout widgets in the view. Defaults to False.
""" """
if layout and not self._layout_required:
if layout:
self._layout_required = True self._layout_required = True
for ancestor in self.ancestors: for ancestor in self.ancestors:
if not isinstance(ancestor, Widget): if not isinstance(ancestor, Widget):
@@ -2395,6 +2433,9 @@ class Widget(DOMNode):
except NoScreen: except NoScreen:
pass pass
else: else:
if self._scroll_required:
self._scroll_required = False
screen.post_message_no_wait(messages.UpdateScroll(self))
if self._repaint_required: if self._repaint_required:
self._repaint_required = False self._repaint_required = False
screen.post_message_no_wait(messages.Update(self, self)) screen.post_message_no_wait(messages.Update(self, self))
@@ -2443,6 +2484,18 @@ class Widget(DOMNode):
""" """
self.app.capture_mouse(None) self.app.capture_mouse(None)
def check_message_enabled(self, message: Message) -> bool:
# Do the normal checking and get out if that fails.
if not super().check_message_enabled(message):
return False
# Otherwise, if this is a mouse event, the widget receiving the
# event must not be disabled at this moment.
return (
not self._self_or_ancestors_disabled
if isinstance(message, (events.MouseEvent, events.Enter, events.Leave))
else True
)
async def broker_event(self, event_name: str, event: events.Event) -> bool: async def broker_event(self, event_name: str, event: events.Event) -> bool:
return await self.app._broker_event(event_name, event, default_namespace=self) return await self.app._broker_event(event_name, event, default_namespace=self)
@@ -2501,11 +2554,11 @@ class Widget(DOMNode):
def _on_descendant_blur(self, event: events.DescendantBlur) -> None: def _on_descendant_blur(self, event: events.DescendantBlur) -> None:
if self._has_focus_within: if self._has_focus_within:
self.app.update_styles(self) self._update_styles()
def _on_descendant_focus(self, event: events.DescendantBlur) -> None: def _on_descendant_focus(self, event: events.DescendantBlur) -> None:
if self._has_focus_within: if self._has_focus_within:
self.app.update_styles(self) self._update_styles()
def _on_mouse_scroll_down(self, event: events.MouseScrollDown) -> None: def _on_mouse_scroll_down(self, event: events.MouseScrollDown) -> None:
if event.ctrl or event.shift: if event.ctrl or event.shift:

View File

@@ -39,11 +39,6 @@ class Button(Static, can_focus=True):
text-style: bold; text-style: bold;
} }
Button.-disabled {
opacity: 0.4;
text-opacity: 0.7;
}
Button:focus { Button:focus {
text-style: bold reverse; text-style: bold reverse;
} }
@@ -156,9 +151,6 @@ class Button(Static, can_focus=True):
variant = reactive("default") variant = reactive("default")
"""The variant name for the button.""" """The variant name for the button."""
disabled = reactive(False)
"""The disabled state of the button; `True` if disabled, `False` if not."""
class Pressed(Message, bubble=True): class Pressed(Message, bubble=True):
"""Event sent when a `Button` is pressed. """Event sent when a `Button` is pressed.
@@ -176,45 +168,35 @@ class Button(Static, can_focus=True):
def __init__( def __init__(
self, self,
label: TextType | None = None, label: TextType | None = None,
disabled: bool = False,
variant: ButtonVariant = "default", variant: ButtonVariant = "default",
*, *,
name: str | None = None, name: str | None = None,
id: str | None = None, id: str | None = None,
classes: str | None = None, classes: str | None = None,
disabled: bool = False,
): ):
"""Create a Button widget. """Create a Button widget.
Args: Args:
label: The text that appears within the button. label: The text that appears within the button.
disabled: Whether the button is disabled or not.
variant: The variant of the button. variant: The variant of the button.
name: The name of the button. name: The name of the button.
id: The ID of the button in the DOM. id: The ID of the button in the DOM.
classes: The CSS classes of the button. classes: The CSS classes of the button.
disabled: Whether the button is disabled or not.
""" """
super().__init__(name=name, id=id, classes=classes) super().__init__(name=name, id=id, classes=classes, disabled=disabled)
if label is None: if label is None:
label = self.css_identifier_styled label = self.css_identifier_styled
self.label = self.validate_label(label) self.label = self.validate_label(label)
self.disabled = disabled
if disabled:
self.add_class("-disabled")
self.variant = self.validate_variant(variant) self.variant = self.validate_variant(variant)
def __rich_repr__(self) -> rich.repr.Result: def __rich_repr__(self) -> rich.repr.Result:
yield from super().__rich_repr__() yield from super().__rich_repr__()
yield "variant", self.variant, "default" yield "variant", self.variant, "default"
yield "disabled", self.disabled, False
def watch_mouse_over(self, value: bool) -> None:
"""Update from CSS if mouse over state changes."""
if self._has_hover_style and not self.disabled:
self.app.update_styles(self)
def validate_variant(self, variant: str) -> str: def validate_variant(self, variant: str) -> str:
if variant not in _VALID_BUTTON_VARIANTS: if variant not in _VALID_BUTTON_VARIANTS:
@@ -227,10 +209,6 @@ class Button(Static, can_focus=True):
self.remove_class(f"-{old_variant}") self.remove_class(f"-{old_variant}")
self.add_class(f"-{variant}") self.add_class(f"-{variant}")
def watch_disabled(self, disabled: bool) -> None:
self.set_class(disabled, "-disabled")
self.can_focus = not disabled
def validate_label(self, label: RenderableType) -> RenderableType: def validate_label(self, label: RenderableType) -> RenderableType:
"""Parse markup for self.label""" """Parse markup for self.label"""
if isinstance(label, str): if isinstance(label, str):
@@ -272,11 +250,11 @@ class Button(Static, can_focus=True):
def success( def success(
cls, cls,
label: TextType | None = None, label: TextType | None = None,
disabled: bool = False,
*, *,
name: str | None = None, name: str | None = None,
id: str | None = None, id: str | None = None,
classes: str | None = None, classes: str | None = None,
disabled: bool = False,
) -> Button: ) -> Button:
"""Utility constructor for creating a success Button variant. """Utility constructor for creating a success Button variant.
@@ -292,22 +270,22 @@ class Button(Static, can_focus=True):
""" """
return Button( return Button(
label=label, label=label,
disabled=disabled,
variant="success", variant="success",
name=name, name=name,
id=id, id=id,
classes=classes, classes=classes,
disabled=disabled,
) )
@classmethod @classmethod
def warning( def warning(
cls, cls,
label: TextType | None = None, label: TextType | None = None,
disabled: bool = False,
*, *,
name: str | None = None, name: str | None = None,
id: str | None = None, id: str | None = None,
classes: str | None = None, classes: str | None = None,
disabled: bool = False,
) -> Button: ) -> Button:
"""Utility constructor for creating a warning Button variant. """Utility constructor for creating a warning Button variant.
@@ -323,22 +301,22 @@ class Button(Static, can_focus=True):
""" """
return Button( return Button(
label=label, label=label,
disabled=disabled,
variant="warning", variant="warning",
name=name, name=name,
id=id, id=id,
classes=classes, classes=classes,
disabled=disabled,
) )
@classmethod @classmethod
def error( def error(
cls, cls,
label: TextType | None = None, label: TextType | None = None,
disabled: bool = False,
*, *,
name: str | None = None, name: str | None = None,
id: str | None = None, id: str | None = None,
classes: str | None = None, classes: str | None = None,
disabled: bool = False,
) -> Button: ) -> Button:
"""Utility constructor for creating an error Button variant. """Utility constructor for creating an error Button variant.
@@ -354,9 +332,9 @@ class Button(Static, can_focus=True):
""" """
return Button( return Button(
label=label, label=label,
disabled=disabled,
variant="error", variant="error",
name=name, name=name,
id=id, id=id,
classes=classes, classes=classes,
disabled=disabled,
) )

View File

@@ -473,8 +473,9 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
name: str | None = None, name: str | None = None,
id: str | None = None, id: str | None = None,
classes: str | None = None, classes: str | None = None,
disabled: bool = False,
) -> None: ) -> None:
super().__init__(name=name, id=id, classes=classes) super().__init__(name=name, id=id, classes=classes, disabled=disabled)
self._data: dict[RowKey, dict[ColumnKey, CellType]] = {} self._data: dict[RowKey, dict[ColumnKey, CellType]] = {}
"""Contains the cells of the table, indexed by row key and column key. """Contains the cells of the table, indexed by row key and column key.
The final positioning of a cell on screen cannot be determined solely by this The final positioning of a cell on screen cannot be determined solely by this

View File

@@ -29,6 +29,7 @@ class DirectoryTree(Tree[DirEntry]):
name: The name of the widget, or None for no name. Defaults to None. name: The name of the widget, or None for no name. Defaults to None.
id: The ID of the widget in the DOM, or None for no ID. Defaults to None. id: The ID of the widget in the DOM, or None for no ID. Defaults to None.
classes: A space-separated list of classes, or None for no classes. Defaults to None. classes: A space-separated list of classes, or None for no classes. Defaults to None.
disabled: Whether the directory tree is disabled or not.
""" """
COMPONENT_CLASSES: ClassVar[set[str]] = { COMPONENT_CLASSES: ClassVar[set[str]] = {
@@ -87,6 +88,7 @@ class DirectoryTree(Tree[DirEntry]):
name: str | None = None, name: str | None = None,
id: str | None = None, id: str | None = None,
classes: str | None = None, classes: str | None = None,
disabled: bool = False,
) -> None: ) -> None:
self.path = path self.path = path
super().__init__( super().__init__(
@@ -95,6 +97,7 @@ class DirectoryTree(Tree[DirEntry]):
name=name, name=name,
id=id, id=id,
classes=classes, classes=classes,
disabled=disabled,
) )
def process_label(self, label: TextType): def process_label(self, label: TextType):

View File

@@ -110,9 +110,6 @@ class Input(Widget, can_focus=True):
height: 1; height: 1;
min-height: 1; min-height: 1;
} }
Input.-disabled {
opacity: 0.6;
}
Input:focus { Input:focus {
border: tall $accent; border: tall $accent;
} }
@@ -179,6 +176,7 @@ class Input(Widget, can_focus=True):
name: str | None = None, name: str | None = None,
id: str | None = None, id: str | None = None,
classes: str | None = None, classes: str | None = None,
disabled: bool = False,
) -> None: ) -> None:
"""Initialise the `Input` widget. """Initialise the `Input` widget.
@@ -190,8 +188,9 @@ class Input(Widget, can_focus=True):
name: Optional name for the input widget. name: Optional name for the input widget.
id: Optional ID for the widget. id: Optional ID for the widget.
classes: Optional initial classes for the widget. classes: Optional initial classes for the widget.
disabled: Whether the input is disabled or not.
""" """
super().__init__(name=name, id=id, classes=classes) super().__init__(name=name, id=id, classes=classes, disabled=disabled)
if value is not None: if value is not None:
self.value = value self.value = value
self.placeholder = placeholder self.placeholder = placeholder

View File

@@ -73,6 +73,7 @@ class ListView(Vertical, can_focus=True, can_focus_children=False):
name: str | None = None, name: str | None = None,
id: str | None = None, id: str | None = None,
classes: str | None = None, classes: str | None = None,
disabled: bool = False,
) -> None: ) -> None:
""" """
Args: Args:
@@ -81,8 +82,11 @@ class ListView(Vertical, can_focus=True, can_focus_children=False):
name: The name of the widget. name: The name of the widget.
id: The unique ID of the widget used in CSS/query selection. id: The unique ID of the widget used in CSS/query selection.
classes: The CSS classes of the widget. classes: The CSS classes of the widget.
disabled: Whether the ListView is disabled or not.
""" """
super().__init__(*children, name=name, id=id, classes=classes) super().__init__(
*children, name=name, id=id, classes=classes, disabled=disabled
)
self._index = initial_index self._index = initial_index
def on_mount(self) -> None: def on_mount(self) -> None:

View File

@@ -10,7 +10,7 @@ from rich.text import Text
from typing_extensions import TypeAlias from typing_extensions import TypeAlias
from ..app import ComposeResult from ..app import ComposeResult
from ..containers import Vertical from ..containers import Horizontal, Vertical
from ..message import Message from ..message import Message
from ..reactive import reactive, var from ..reactive import reactive, var
from ..widget import Widget from ..widget import Widget
@@ -198,6 +198,19 @@ class MarkdownH6(MarkdownHeader):
""" """
class MarkdownHorizontalRule(MarkdownBlock):
"""A horizontal rule."""
DEFAULT_CSS = """
MarkdownHorizontalRule {
border-bottom: heavy $primary;
height: 1;
padding-top: 1;
margin-bottom: 1;
}
"""
class MarkdownParagraph(MarkdownBlock): class MarkdownParagraph(MarkdownBlock):
"""A paragraph Markdown block.""" """A paragraph Markdown block."""
@@ -225,37 +238,83 @@ class MarkdownBlockQuote(MarkdownBlock):
""" """
class MarkdownBulletList(MarkdownBlock): class MarkdownList(MarkdownBlock):
DEFAULT_CSS = """
MarkdownList {
width: 1fr;
}
MarkdownList MarkdownList {
margin: 0;
padding-top: 0;
}
"""
class MarkdownBulletList(MarkdownList):
"""A Bullet list Markdown block.""" """A Bullet list Markdown block."""
DEFAULT_CSS = """ DEFAULT_CSS = """
MarkdownBulletList { MarkdownBulletList {
margin: 0; margin: 0 0 1 0;
padding: 0 0; padding: 0 0;
} }
MarkdownBulletList MarkdownBulletList { MarkdownBulletList Horizontal {
margin: 0; height: auto;
padding-top: 0; width: 1fr;
}
MarkdownBulletList Vertical {
height: auto;
width: 1fr;
} }
""" """
def compose(self) -> ComposeResult:
for block in self._blocks:
if isinstance(block, MarkdownListItem):
bullet = MarkdownBullet()
bullet.symbol = block.bullet
yield Horizontal(bullet, Vertical(*block._blocks))
self._blocks.clear()
class MarkdownOrderedList(MarkdownBlock):
class MarkdownOrderedList(MarkdownList):
"""An ordered list Markdown block.""" """An ordered list Markdown block."""
DEFAULT_CSS = """ DEFAULT_CSS = """
MarkdownOrderedList { MarkdownOrderedList {
margin: 0; margin: 0 0 1 0;
padding: 0 0; padding: 0 0;
} }
Markdown OrderedList MarkdownOrderedList { MarkdownOrderedList Horizontal {
margin: 0; height: auto;
padding-top: 0; width: 1fr;
}
MarkdownOrderedList Vertical {
height: auto;
width: 1fr;
} }
""" """
def compose(self) -> ComposeResult:
symbol_size = max(
len(block.bullet)
for block in self._blocks
if isinstance(block, MarkdownListItem)
)
for block in self._blocks:
if isinstance(block, MarkdownListItem):
bullet = MarkdownBullet()
bullet.symbol = block.bullet.rjust(symbol_size + 1)
yield Horizontal(bullet, Vertical(*block._blocks))
self._blocks.clear()
class MarkdownTable(MarkdownBlock): class MarkdownTable(MarkdownBlock):
"""A Table markdown Block.""" """A Table markdown Block."""
@@ -329,10 +388,12 @@ class MarkdownBullet(Widget):
DEFAULT_CSS = """ DEFAULT_CSS = """
MarkdownBullet { MarkdownBullet {
width: auto; width: auto;
color: $success;
text-style: bold;
} }
""" """
symbol = reactive("●​ ") symbol = reactive("●​")
"""The symbol for the bullet.""" """The symbol for the bullet."""
def render(self) -> Text: def render(self) -> Text:
@@ -359,13 +420,13 @@ class MarkdownListItem(MarkdownBlock):
self.bullet = bullet self.bullet = bullet
super().__init__() super().__init__()
def compose(self) -> ComposeResult:
bullet = MarkdownBullet()
bullet.symbol = self.bullet
yield bullet
yield Vertical(*self._blocks)
self._blocks.clear() class MarkdownOrderedListItem(MarkdownListItem):
pass
class MarkdownUnorderedListItem(MarkdownListItem):
pass
class MarkdownFence(MarkdownBlock): class MarkdownFence(MarkdownBlock):
@@ -439,6 +500,8 @@ class Markdown(Widget):
""" """
COMPONENT_CLASSES = {"em", "strong", "s", "code_inline"} COMPONENT_CLASSES = {"em", "strong", "s", "code_inline"}
BULLETS = ["", "", "", "", ""]
def __init__( def __init__(
self, self,
markdown: str | None = None, markdown: str | None = None,
@@ -501,7 +564,7 @@ class Markdown(Widget):
markdown = path.read_text(encoding="utf-8") markdown = path.read_text(encoding="utf-8")
except Exception: except Exception:
return False return False
await self.query("MarkdownBlock").remove()
await self.update(markdown) await self.update(markdown)
return True return True
@@ -524,6 +587,8 @@ class Markdown(Widget):
if token.type == "heading_open": if token.type == "heading_open":
block_id += 1 block_id += 1
stack.append(HEADINGS[token.tag](id=f"block{block_id}")) stack.append(HEADINGS[token.tag](id=f"block{block_id}"))
elif token.type == "hr":
output.append(MarkdownHorizontalRule())
elif token.type == "paragraph_open": elif token.type == "paragraph_open":
stack.append(MarkdownParagraph()) stack.append(MarkdownParagraph())
elif token.type == "blockquote_open": elif token.type == "blockquote_open":
@@ -533,9 +598,20 @@ class Markdown(Widget):
elif token.type == "ordered_list_open": elif token.type == "ordered_list_open":
stack.append(MarkdownOrderedList()) stack.append(MarkdownOrderedList())
elif token.type == "list_item_open": elif token.type == "list_item_open":
stack.append( if token.info:
MarkdownListItem(f"{token.info}. " if token.info else "") stack.append(MarkdownOrderedListItem(f"{token.info}. "))
) else:
item_count = sum(
1
for block in stack
if isinstance(block, MarkdownUnorderedListItem)
)
stack.append(
MarkdownUnorderedListItem(
self.BULLETS[item_count % len(self.BULLETS)]
)
)
elif token.type == "table_open": elif token.type == "table_open":
stack.append(MarkdownTable()) stack.append(MarkdownTable())
elif token.type == "tbody_open": elif token.type == "tbody_open":
@@ -565,6 +641,8 @@ class Markdown(Widget):
for child in token.children: for child in token.children:
if child.type == "text": if child.type == "text":
content.append(child.content, style_stack[-1]) content.append(child.content, style_stack[-1])
if child.type == "softbreak":
content.append(" ")
elif child.type == "code_inline": elif child.type == "code_inline":
content.append( content.append(
child.content, child.content,
@@ -627,7 +705,10 @@ class Markdown(Widget):
await self.post_message( await self.post_message(
Markdown.TableOfContentsUpdated(table_of_contents, sender=self) Markdown.TableOfContentsUpdated(table_of_contents, sender=self)
) )
await self.mount(*output) with self.app.batch_update():
await self.query("MarkdownBlock").remove()
await self.mount(*output)
self.refresh(layout=True)
class MarkdownTableOfContents(Widget, can_focus_children=True): class MarkdownTableOfContents(Widget, can_focus_children=True):

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
from itertools import cycle from itertools import cycle
from rich.console import RenderableType
from typing_extensions import Literal from typing_extensions import Literal
from .. import events from .. import events
@@ -61,10 +62,10 @@ class Placeholder(Widget):
overflow: hidden; overflow: hidden;
color: $text; color: $text;
} }
Placeholder.-text { Placeholder.-text {
padding: 1; padding: 1;
} }
""" """
# Consecutive placeholders get assigned consecutive colors. # Consecutive placeholders get assigned consecutive colors.
@@ -73,7 +74,7 @@ class Placeholder(Widget):
variant: Reactive[PlaceholderVariant] = reactive("default") variant: Reactive[PlaceholderVariant] = reactive("default")
_renderables: dict[PlaceholderVariant, RenderResult] _renderables: dict[PlaceholderVariant, str]
@classmethod @classmethod
def reset_color_cycle(cls) -> None: def reset_color_cycle(cls) -> None:
@@ -119,7 +120,7 @@ class Placeholder(Widget):
while next(self._variants_cycle) != self.variant: while next(self._variants_cycle) != self.variant:
pass pass
def render(self) -> RenderResult: def render(self) -> RenderableType:
return self._renderables[self.variant] return self._renderables[self.variant]
def cycle_variant(self) -> None: def cycle_variant(self) -> None:
@@ -147,6 +148,6 @@ class Placeholder(Widget):
def on_resize(self, event: events.Resize) -> None: def on_resize(self, event: events.Resize) -> None:
"""Update the placeholder "size" variant with the new placeholder size.""" """Update the placeholder "size" variant with the new placeholder size."""
self._renderables["size"] = self._SIZE_RENDER_TEMPLATE.format(*self.size) self._renderables["size"] = self._SIZE_RENDER_TEMPLATE.format(*event.size)
if self.variant == "size": if self.variant == "size":
self.refresh(layout=True) self.refresh(layout=False)

View File

@@ -36,6 +36,7 @@ class Static(Widget, inherit_bindings=False):
name: Name of widget. Defaults to None. name: Name of widget. Defaults to None.
id: ID of Widget. Defaults to None. id: ID of Widget. Defaults to None.
classes: Space separated list of class names. Defaults to None. classes: Space separated list of class names. Defaults to None.
disabled: Whether the static is disabled or not.
""" """
DEFAULT_CSS = """ DEFAULT_CSS = """
@@ -56,8 +57,9 @@ class Static(Widget, inherit_bindings=False):
name: str | None = None, name: str | None = None,
id: str | None = None, id: str | None = None,
classes: str | None = None, classes: str | None = None,
disabled: bool = False,
) -> None: ) -> None:
super().__init__(name=name, id=id, classes=classes) super().__init__(name=name, id=id, classes=classes, disabled=disabled)
self.expand = expand self.expand = expand
self.shrink = shrink self.shrink = shrink
self.markup = markup self.markup = markup

View File

@@ -100,6 +100,7 @@ class Switch(Widget, can_focus=True):
name: str | None = None, name: str | None = None,
id: str | None = None, id: str | None = None,
classes: str | None = None, classes: str | None = None,
disabled: bool = False,
): ):
"""Initialise the switch. """Initialise the switch.
@@ -109,8 +110,9 @@ class Switch(Widget, can_focus=True):
name: The name of the switch. name: The name of the switch.
id: The ID of the switch in the DOM. id: The ID of the switch in the DOM.
classes: The CSS classes of the switch. classes: The CSS classes of the switch.
disabled: Whether the switch is disabled or not.
""" """
super().__init__(name=name, id=id, classes=classes) super().__init__(name=name, id=id, classes=classes, disabled=disabled)
if value: if value:
self.slider_pos = 1.0 self.slider_pos = 1.0
self._reactive_value = value self._reactive_value = value

View File

@@ -43,8 +43,9 @@ class TextLog(ScrollView, can_focus=True):
name: str | None = None, name: str | None = None,
id: str | None = None, id: str | None = None,
classes: str | None = None, classes: str | None = None,
disabled: bool = False,
) -> None: ) -> None:
super().__init__(name=name, id=id, classes=classes) super().__init__(name=name, id=id, classes=classes, disabled=disabled)
self.max_lines = max_lines self.max_lines = max_lines
self._start_line: int = 0 self._start_line: int = 0
self.lines: list[Strip] = [] self.lines: list[Strip] = []

View File

@@ -473,8 +473,9 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True):
name: str | None = None, name: str | None = None,
id: str | None = None, id: str | None = None,
classes: str | None = None, classes: str | None = None,
disabled: bool = False,
) -> None: ) -> None:
super().__init__(name=name, id=id, classes=classes) super().__init__(name=name, id=id, classes=classes, disabled=disabled)
text_label = self.process_label(label) text_label = self.process_label(label)

View File

@@ -8,7 +8,7 @@ from textual.css.parse import substitute_references
from textual.css.scalar import Scalar, Unit from textual.css.scalar import Scalar, Unit
from textual.css.stylesheet import Stylesheet, StylesheetParseError from textual.css.stylesheet import Stylesheet, StylesheetParseError
from textual.css.tokenize import tokenize from textual.css.tokenize import tokenize
from textual.css.tokenizer import ReferencedBy, Token from textual.css.tokenizer import ReferencedBy, Token, TokenError
from textual.css.transition import Transition from textual.css.transition import Transition
from textual.geometry import Spacing from textual.geometry import Spacing
from textual.layouts.vertical import VerticalLayout from textual.layouts.vertical import VerticalLayout
@@ -1189,3 +1189,40 @@ class TestParseTextAlign:
stylesheet = Stylesheet() stylesheet = Stylesheet()
stylesheet.add_source(css) stylesheet.add_source(css)
assert stylesheet.rules[0].styles.text_align == "start" assert stylesheet.rules[0].styles.text_align == "start"
class TestTypeNames:
def test_type_no_number(self):
stylesheet = Stylesheet()
stylesheet.add_source("TestType {}")
assert len(stylesheet.rules) == 1
def test_type_with_number(self):
stylesheet = Stylesheet()
stylesheet.add_source("TestType1 {}")
assert len(stylesheet.rules) == 1
def test_type_starts_with_number(self):
stylesheet = Stylesheet()
stylesheet.add_source("1TestType {}")
with pytest.raises(TokenError):
stylesheet.parse()
def test_combined_type_no_number(self):
for separator in " >,":
stylesheet = Stylesheet()
stylesheet.add_source(f"StartType {separator} TestType {{}}")
assert len(stylesheet.rules) == 1
def test_combined_type_with_number(self):
for separator in " >,":
stylesheet = Stylesheet()
stylesheet.add_source(f"StartType {separator} TestType1 {{}}")
assert len(stylesheet.rules) == 1
def test_combined_type_starts_with_number(self):
for separator in " >,":
stylesheet = Stylesheet()
stylesheet.add_source(f"StartType {separator} 1TestType {{}}")
with pytest.raises(TokenError):
stylesheet.parse()

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,84 @@
from textual.app import App, ComposeResult
from textual.containers import Vertical, Horizontal
from textual.widgets import (
Header,
Footer,
Button,
DataTable,
Input,
ListView,
ListItem,
Label,
Markdown,
MarkdownViewer,
Tree,
TextLog,
)
class WidgetDisableTestApp(App[None]):
CSS = """
Horizontal {
height: auto;
}
DataTable, ListView, Tree, TextLog {
height: 2;
}
Markdown, MarkdownViewer {
height: 1fr;
}
"""
@property
def data_table(self) -> DataTable:
data_table = DataTable[str]()
data_table.add_columns("Column 1", "Column 2", "Column 3", "Column 4")
data_table.add_rows(
[(str(n), str(n * 10), str(n * 100), str(n * 1000)) for n in range(100)]
)
return data_table
@property
def list_view(self) -> ListView:
return ListView(*[ListItem(Label(f"This is list item {n}")) for n in range(20)])
@property
def test_tree(self) -> Tree:
tree = Tree[None](label="This is a test tree")
for n in range(10):
tree.root.add_leaf(f"Leaf {n}")
tree.root.expand()
return tree
def compose(self) -> ComposeResult:
yield Header()
yield Vertical(
Horizontal(
Button(),
Button(variant="primary"),
Button(variant="success"),
Button(variant="warning"),
Button(variant="error"),
),
self.data_table,
self.list_view,
self.test_tree,
TextLog(),
Input(),
Input(placeholder="This is an empty input with a placeholder"),
Input("This is some text in an input"),
Markdown("# Hello, World!"),
MarkdownViewer("# Hello, World!"),
id="test-container",
)
yield Footer()
def on_mount(self) -> None:
self.query_one(TextLog).write("Hello, World!")
self.query_one("#test-container", Vertical).disabled = True
if __name__ == "__main__":
WidgetDisableTestApp().run()

View File

@@ -231,3 +231,7 @@ def test_auto_width_input(snap_compare):
def test_screen_switch(snap_compare): def test_screen_switch(snap_compare):
assert snap_compare(SNAPSHOT_APPS_DIR / "screen_switch.py", press=["a", "b"]) assert snap_compare(SNAPSHOT_APPS_DIR / "screen_switch.py", press=["a", "b"])
def test_disabled_widgets(snap_compare):
assert snap_compare(SNAPSHOT_APPS_DIR / "disable_widgets.py")

17
tests/test_app.py Normal file
View File

@@ -0,0 +1,17 @@
from textual.app import App
def test_batch_update():
"""Test `batch_update` context manager"""
app = App()
assert app._batch_count == 0 # Start at zero
with app.batch_update():
assert app._batch_count == 1 # Increments in context manager
with app.batch_update():
assert app._batch_count == 2 # Nested updates
assert app._batch_count == 1 # Exiting decrements
assert app._batch_count == 0 # Back to zero

View File

@@ -9,10 +9,10 @@ from textual.widget import Widget
def test_arrange_empty(): def test_arrange_empty():
container = Widget(id="container") container = Widget(id="container")
placements, widgets, spacing = arrange(container, [], Size(80, 24), Size(80, 24)) result = arrange(container, [], Size(80, 24), Size(80, 24))
assert placements == [] assert result.placements == []
assert widgets == set() assert result.widgets == set()
assert spacing == Spacing(0, 0, 0, 0) assert result.spacing == Spacing(0, 0, 0, 0)
def test_arrange_dock_top(): def test_arrange_dock_top():
@@ -22,17 +22,16 @@ def test_arrange_dock_top():
header.styles.dock = "top" header.styles.dock = "top"
header.styles.height = "1" header.styles.height = "1"
placements, widgets, spacing = arrange( result = arrange(container, [child, header], Size(80, 24), Size(80, 24))
container, [child, header], Size(80, 24), Size(80, 24)
) assert result.placements == [
assert placements == [
WidgetPlacement( WidgetPlacement(
Region(0, 0, 80, 1), Spacing(), header, order=TOP_Z, fixed=True Region(0, 0, 80, 1), Spacing(), header, order=TOP_Z, fixed=True
), ),
WidgetPlacement(Region(0, 1, 80, 23), Spacing(), child, order=0, fixed=False), WidgetPlacement(Region(0, 1, 80, 23), Spacing(), child, order=0, fixed=False),
] ]
assert widgets == {child, header} assert result.widgets == {child, header}
assert spacing == Spacing(1, 0, 0, 0) assert result.spacing == Spacing(1, 0, 0, 0)
def test_arrange_dock_left(): def test_arrange_dock_left():
@@ -42,17 +41,15 @@ def test_arrange_dock_left():
header.styles.dock = "left" header.styles.dock = "left"
header.styles.width = "10" header.styles.width = "10"
placements, widgets, spacing = arrange( result = arrange(container, [child, header], Size(80, 24), Size(80, 24))
container, [child, header], Size(80, 24), Size(80, 24) assert result.placements == [
)
assert placements == [
WidgetPlacement( WidgetPlacement(
Region(0, 0, 10, 24), Spacing(), header, order=TOP_Z, fixed=True Region(0, 0, 10, 24), Spacing(), header, order=TOP_Z, fixed=True
), ),
WidgetPlacement(Region(10, 0, 70, 24), Spacing(), child, order=0, fixed=False), WidgetPlacement(Region(10, 0, 70, 24), Spacing(), child, order=0, fixed=False),
] ]
assert widgets == {child, header} assert result.widgets == {child, header}
assert spacing == Spacing(0, 0, 0, 10) assert result.spacing == Spacing(0, 0, 0, 10)
def test_arrange_dock_right(): def test_arrange_dock_right():
@@ -62,17 +59,15 @@ def test_arrange_dock_right():
header.styles.dock = "right" header.styles.dock = "right"
header.styles.width = "10" header.styles.width = "10"
placements, widgets, spacing = arrange( result = arrange(container, [child, header], Size(80, 24), Size(80, 24))
container, [child, header], Size(80, 24), Size(80, 24) assert result.placements == [
)
assert placements == [
WidgetPlacement( WidgetPlacement(
Region(70, 0, 10, 24), Spacing(), header, order=TOP_Z, fixed=True Region(70, 0, 10, 24), Spacing(), header, order=TOP_Z, fixed=True
), ),
WidgetPlacement(Region(0, 0, 70, 24), Spacing(), child, order=0, fixed=False), WidgetPlacement(Region(0, 0, 70, 24), Spacing(), child, order=0, fixed=False),
] ]
assert widgets == {child, header} assert result.widgets == {child, header}
assert spacing == Spacing(0, 10, 0, 0) assert result.spacing == Spacing(0, 10, 0, 0)
def test_arrange_dock_bottom(): def test_arrange_dock_bottom():
@@ -82,17 +77,15 @@ def test_arrange_dock_bottom():
header.styles.dock = "bottom" header.styles.dock = "bottom"
header.styles.height = "1" header.styles.height = "1"
placements, widgets, spacing = arrange( result = arrange(container, [child, header], Size(80, 24), Size(80, 24))
container, [child, header], Size(80, 24), Size(80, 24) assert result.placements == [
)
assert placements == [
WidgetPlacement( WidgetPlacement(
Region(0, 23, 80, 1), Spacing(), header, order=TOP_Z, fixed=True Region(0, 23, 80, 1), Spacing(), header, order=TOP_Z, fixed=True
), ),
WidgetPlacement(Region(0, 0, 80, 23), Spacing(), child, order=0, fixed=False), WidgetPlacement(Region(0, 0, 80, 23), Spacing(), child, order=0, fixed=False),
] ]
assert widgets == {child, header} assert result.widgets == {child, header}
assert spacing == Spacing(0, 0, 1, 0) assert result.spacing == Spacing(0, 0, 1, 0)
def test_arrange_dock_badly(): def test_arrange_dock_badly():

84
tests/test_disabled.py Normal file
View File

@@ -0,0 +1,84 @@
"""Test Widget.disabled."""
from textual.app import App, ComposeResult
from textual.containers import Vertical
from textual.widgets import (
Button,
DataTable,
DirectoryTree,
Input,
ListView,
Markdown,
MarkdownViewer,
Switch,
TextLog,
Tree,
)
class DisableApp(App[None]):
"""Application for testing Widget.disabled."""
def compose(self) -> ComposeResult:
"""Compose the child widgets."""
yield Vertical(
Button(),
DataTable(),
DirectoryTree("."),
Input(),
ListView(),
Switch(),
TextLog(),
Tree("Test"),
Markdown(),
MarkdownViewer(),
id="test-container",
)
async def test_all_initially_enabled() -> None:
"""All widgets should start out enabled."""
async with DisableApp().run_test() as pilot:
assert all(
not node.disabled for node in pilot.app.screen.query("#test-container > *")
)
async def test_enabled_widgets_have_enabled_pseudo_class() -> None:
"""All enabled widgets should have the :enabled pseudoclass."""
async with DisableApp().run_test() as pilot:
assert all(
node.has_pseudo_class("enabled") and not node.has_pseudo_class("disabled")
for node in pilot.app.screen.query("#test-container > *")
)
async def test_all_individually_disabled() -> None:
"""Post-disable all widgets should report being disabled."""
async with DisableApp().run_test() as pilot:
for node in pilot.app.screen.query("Vertical > *"):
node.disabled = True
assert all(
node.disabled for node in pilot.app.screen.query("#test-container > *")
)
async def test_disabled_widgets_have_disabled_pseudo_class() -> None:
"""All disabled widgets should have the :disabled pseudoclass."""
async with DisableApp().run_test() as pilot:
for node in pilot.app.screen.query("#test-container > *"):
node.disabled = True
assert all(
node.has_pseudo_class("disabled") and not node.has_pseudo_class("enabled")
for node in pilot.app.screen.query("#test-container > *")
)
async def test_disable_via_container() -> None:
"""All child widgets should appear (to CSS) as disabled by a container being disabled."""
async with DisableApp().run_test() as pilot:
pilot.app.screen.query_one("#test-container", Vertical).disabled = True
assert all(
node.has_pseudo_class("disabled") and not node.has_pseudo_class("enabled")
for node in pilot.app.screen.query("#test-container > *")
)

View File

@@ -328,6 +328,33 @@ async def test_reactive_inheritance():
assert tertiary.baz == "baz" assert tertiary.baz == "baz"
async def test_compute():
"""Check compute method is called."""
class ComputeApp(App):
count = var(0)
count_double = var(0)
def __init__(self) -> None:
self.start = 0
super().__init__()
def compute_count_double(self) -> int:
return self.start + self.count * 2
app = ComputeApp()
async with app.run_test():
assert app.count_double == 0
app.count = 1
assert app.count_double == 2
assert app.count_double == 2
app.count = 2
assert app.count_double == 4
app.start = 10
assert app.count_double == 14
async def test_watch_compute(): async def test_watch_compute():
"""Check that watching a computed attribute works.""" """Check that watching a computed attribute works."""
@@ -347,7 +374,9 @@ async def test_watch_compute():
app = Calculator() app = Calculator()
async with app.run_test() as pilot: # Referencing the value calls compute
# Setting any reactive values calls compute
async with app.run_test():
assert app.show_ac is True assert app.show_ac is True
app.value = "1" app.value = "1"
assert app.show_ac is False assert app.show_ac is False
@@ -356,4 +385,4 @@ async def test_watch_compute():
app.numbers = "123" app.numbers = "123"
assert app.show_ac is False assert app.show_ac is False
assert watch_called == [True, False, True, False] assert watch_called == [True, True, False, False, True, True, False, False]

64
tests/test_spatial_map.py Normal file
View File

@@ -0,0 +1,64 @@
import pytest
from textual._spatial_map import SpatialMap
from textual.geometry import Region
@pytest.mark.parametrize(
"region,grid",
[
(
Region(0, 0, 10, 10),
[
(0, 0),
],
),
(
Region(10, 10, 10, 10),
[
(1, 1),
],
),
(
Region(0, 0, 11, 11),
[(0, 0), (0, 1), (1, 0), (1, 1)],
),
(
Region(5, 5, 15, 3),
[(0, 0), (1, 0)],
),
(
Region(5, 5, 2, 15),
[(0, 0), (0, 1)],
),
],
)
def test_region_to_grid(region, grid):
spatial_map = SpatialMap(10, 10)
assert list(spatial_map._region_to_grid_coordinates(region)) == grid
def test_get_values_in_region() -> None:
spatial_map: SpatialMap[str] = SpatialMap(20, 10)
spatial_map.insert(
[
(Region(10, 5, 5, 5), False, "foo"),
(Region(5, 20, 5, 5), False, "bar"),
(Region(0, 0, 40, 1), True, "title"),
]
)
assert spatial_map.get_values_in_region(Region(0, 0, 10, 5)) == [
"title",
"foo",
]
assert spatial_map.get_values_in_region(Region(0, 1, 10, 5)) == ["title", "foo"]
assert spatial_map.get_values_in_region(Region(0, 10, 10, 5)) == ["title"]
assert spatial_map.get_values_in_region(Region(0, 20, 10, 5)) == ["title", "bar"]
assert spatial_map.get_values_in_region(Region(5, 5, 50, 50)) == [
"title",
"foo",
"bar",
]

View File

@@ -26,21 +26,18 @@ class VisibleTester(App[None]):
async def test_visibility_changes() -> None: async def test_visibility_changes() -> None:
"""Test changing visibility via code and CSS.""" """Test changing visibility via code and CSS."""
async with VisibleTester().run_test() as pilot: async with VisibleTester().run_test() as pilot:
assert len(pilot.app.screen.visible_widgets) == 5
assert pilot.app.query_one("#keep").visible is True assert pilot.app.query_one("#keep").visible is True
assert pilot.app.query_one("#hide-via-code").visible is True assert pilot.app.query_one("#hide-via-code").visible is True
assert pilot.app.query_one("#hide-via-css").visible is True assert pilot.app.query_one("#hide-via-css").visible is True
pilot.app.query_one("#hide-via-code").styles.visibility = "hidden" pilot.app.query_one("#hide-via-code").styles.visibility = "hidden"
await pilot.pause(0) await pilot.pause(0)
assert len(pilot.app.screen.visible_widgets) == 4
assert pilot.app.query_one("#keep").visible is True assert pilot.app.query_one("#keep").visible is True
assert pilot.app.query_one("#hide-via-code").visible is False assert pilot.app.query_one("#hide-via-code").visible is False
assert pilot.app.query_one("#hide-via-css").visible is True assert pilot.app.query_one("#hide-via-css").visible is True
pilot.app.query_one("#hide-via-css").set_class(True, "hidden") pilot.app.query_one("#hide-via-css").set_class(True, "hidden")
await pilot.pause(0) await pilot.pause(0)
assert len(pilot.app.screen.visible_widgets) == 3
assert pilot.app.query_one("#keep").visible is True assert pilot.app.query_one("#keep").visible is True
assert pilot.app.query_one("#hide-via-code").visible is False assert pilot.app.query_one("#hide-via-code").visible is False
assert pilot.app.query_one("#hide-via-css").visible is False assert pilot.app.query_one("#hide-via-css").visible is False