add inherited bindings

This commit is contained in:
Will McGugan
2022-11-19 19:28:14 +00:00
parent 811dcd8eaf
commit b48a1402b8
5 changed files with 86 additions and 44 deletions

View File

@@ -30,6 +30,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
- Support lazy-instantiated Screens (callables in App.SCREENS) https://github.com/Textualize/textual/pull/1185 - Support lazy-instantiated Screens (callables in App.SCREENS) https://github.com/Textualize/textual/pull/1185
- Display of keys in footer has more sensible defaults https://github.com/Textualize/textual/pull/1213 - Display of keys in footer has more sensible defaults https://github.com/Textualize/textual/pull/1213
- Add App.get_key_display, allowing custom key_display App-wide https://github.com/Textualize/textual/pull/1213 - Add App.get_key_display, allowing custom key_display App-wide https://github.com/Textualize/textual/pull/1213
- Added "inherited bindings" -- BINDINGS classvar will be merged with base classes, unless inherit_bindings is set to False
### Changed ### Changed

View File

@@ -458,7 +458,6 @@ class Compositor:
# Add top level (root) widget # Add top level (root) widget
add_widget(root, size.region, size.region, ((0,),), layer_order, size.region) add_widget(root, size.region, size.region, ((0,),), layer_order, size.region)
root.log(map)
return map, widgets return map, widgets
@property @property

View File

@@ -22,7 +22,7 @@ from rich.tree import Tree
from ._context import NoActiveAppError from ._context import NoActiveAppError
from ._node_list import NodeList from ._node_list import NodeList
from .binding import Bindings, BindingType from .binding import Binding, Bindings, BindingType
from .color import BLACK, WHITE, Color from .color import BLACK, WHITE, Color
from .css._error_tools import friendly_list from .css._error_tools import friendly_list
from .css.constants import VALID_DISPLAY, VALID_VISIBILITY from .css.constants import VALID_DISPLAY, VALID_VISIBILITY
@@ -97,9 +97,16 @@ class DOMNode(MessagePump):
# True if this node inherits the CSS from the base class. # True if this node inherits the CSS from the base class.
_inherit_css: ClassVar[bool] = True _inherit_css: ClassVar[bool] = True
# True to inherit bindings from base class
_inherit_bindings: ClassVar[bool] = True
# List of names of base classes that inherit CSS # List of names of base classes that inherit CSS
_css_type_names: ClassVar[frozenset[str]] = frozenset() _css_type_names: ClassVar[frozenset[str]] = frozenset()
# Generated list of bindings
_merged_bindings: ClassVar[Bindings] | None = None
def __init__( def __init__(
self, self,
*, *,
@@ -127,7 +134,7 @@ class DOMNode(MessagePump):
self._auto_refresh: float | None = None self._auto_refresh: float | None = None
self._auto_refresh_timer: Timer | None = None self._auto_refresh_timer: Timer | None = None
self._css_types = {cls.__name__ for cls in self._css_bases(self.__class__)} self._css_types = {cls.__name__ for cls in self._css_bases(self.__class__)}
self._bindings = Bindings(self.BINDINGS) self._bindings = self._merged_bindings or Bindings()
self._has_hover_style: bool = False self._has_hover_style: bool = False
self._has_focus_within: bool = False self._has_focus_within: bool = False
@@ -152,12 +159,16 @@ class DOMNode(MessagePump):
"""Perform an automatic refresh (set with auto_refresh property).""" """Perform an automatic refresh (set with auto_refresh property)."""
self.refresh() self.refresh()
def __init_subclass__(cls, inherit_css: bool = True) -> None: def __init_subclass__(
cls, inherit_css: bool = True, inherit_bindings: bool = True
) -> None:
super().__init_subclass__() super().__init_subclass__()
cls._inherit_css = inherit_css cls._inherit_css = inherit_css
cls._inherit_bindings = inherit_bindings
css_type_names: set[str] = set() css_type_names: set[str] = set()
for base in cls._css_bases(cls): for base in cls._css_bases(cls):
css_type_names.add(base.__name__) css_type_names.add(base.__name__)
cls._merged_bindings = cls._merge_bindings()
cls._css_type_names = frozenset(css_type_names) cls._css_type_names = frozenset(css_type_names)
def get_component_styles(self, name: str) -> RenderStyles: def get_component_styles(self, name: str) -> RenderStyles:
@@ -205,6 +216,25 @@ class DOMNode(MessagePump):
else: else:
break break
@classmethod
def _merge_bindings(cls) -> Bindings:
"""Merge bindings from base classes.
Returns:
Bindings: Merged bindings.
"""
bindings: list[Bindings] = []
for base in reversed(cls.__mro__):
if issubclass(base, DOMNode):
if not base._inherit_bindings:
bindings.clear()
bindings.append(Bindings(base.BINDINGS))
keys = {}
for bindings_ in bindings:
keys.update(bindings_.keys)
return Bindings(keys.values())
def _post_register(self, app: App) -> None: def _post_register(self, app: App) -> None:
"""Called when the widget is registered """Called when the widget is registered

View File

@@ -174,6 +174,12 @@ class Widget(DOMNode):
BINDINGS = [ BINDINGS = [
Binding("up", "scroll_up", "Scroll Up", show=False), Binding("up", "scroll_up", "Scroll Up", show=False),
Binding("down", "scroll_down", "Scroll Down", show=False), Binding("down", "scroll_down", "Scroll Down", show=False),
Binding("left", "scroll_left", "Scroll Up", show=False),
Binding("right", "scroll_right", "Scroll Right", show=False),
Binding("home", "scroll_home", "Scroll Home", show=False),
Binding("end", "scroll_end", "Scroll End", show=False),
Binding("pageup", "page_up", "Page Up", show=False),
Binding("pagedown", "page_down", "Page Down", show=False),
] ]
DEFAULT_CSS = """ DEFAULT_CSS = """
@@ -1816,9 +1822,13 @@ class Widget(DOMNode):
can_focus: bool | None = None, can_focus: bool | None = None,
can_focus_children: bool | None = None, can_focus_children: bool | None = None,
inherit_css: bool = True, inherit_css: bool = True,
inherit_bindings: bool = True,
) -> None: ) -> None:
base = cls.__mro__[0] base = cls.__mro__[0]
super().__init_subclass__(inherit_css=inherit_css) super().__init_subclass__(
inherit_css=inherit_css,
inherit_bindings=inherit_bindings,
)
if issubclass(base, Widget): if issubclass(base, Widget):
cls.can_focus = base.can_focus if can_focus is None else can_focus cls.can_focus = base.can_focus if can_focus is None else can_focus
cls.can_focus_children = ( cls.can_focus_children = (
@@ -2345,53 +2355,21 @@ class Widget(DOMNode):
def _on_scroll_to_region(self, message: messages.ScrollToRegion) -> None: def _on_scroll_to_region(self, message: messages.ScrollToRegion) -> None:
self.scroll_to_region(message.region, animate=True) self.scroll_to_region(message.region, animate=True)
def _key_home(self) -> bool: def action_scroll_home(self) -> None:
if self._allow_scroll: if self._allow_scroll:
self.scroll_home() self.scroll_home()
return True
return False
def _key_end(self) -> bool: def action_scroll_end(self) -> None:
if self._allow_scroll: if self._allow_scroll:
self.scroll_end() self.scroll_end()
return True
return False
def _key_left(self) -> bool: def action_scroll_left(self) -> None:
if self.allow_horizontal_scroll: if self.allow_horizontal_scroll:
self.scroll_left() self.scroll_left()
return True
return False
def _key_right(self) -> bool: def action_scroll_right(self) -> None:
if self.allow_horizontal_scroll: if self.allow_horizontal_scroll:
self.scroll_right() self.scroll_right()
return True
return False
# def _key_down(self) -> bool:
# if self.allow_vertical_scroll:
# self.scroll_down()
# return True
# return False
# def _key_up(self) -> bool:
# if self.allow_vertical_scroll:
# self.scroll_up()
# return True
# return False
def _key_pagedown(self) -> bool:
if self.allow_vertical_scroll:
self.scroll_page_down()
return True
return False
def _key_pageup(self) -> bool:
if self.allow_vertical_scroll:
self.scroll_page_up()
return True
return False
def action_scroll_up(self) -> None: def action_scroll_up(self) -> None:
if self.allow_vertical_scroll: if self.allow_vertical_scroll:
@@ -2400,3 +2378,11 @@ class Widget(DOMNode):
def action_scroll_down(self) -> None: def action_scroll_down(self) -> None:
if self.allow_vertical_scroll: if self.allow_vertical_scroll:
self.scroll_down() self.scroll_down()
def action_page_down(self) -> None:
if self.allow_vertical_scroll:
self.scroll_page_down()
def action_page_up(self) -> None:
if self.allow_vertical_scroll:
self.scroll_page_up()

View File

@@ -26,7 +26,7 @@ NodeID = NewType("NodeID", int)
TreeDataType = TypeVar("TreeDataType") TreeDataType = TypeVar("TreeDataType")
EventTreeDataType = TypeVar("EventTreeDataType") EventTreeDataType = TypeVar("EventTreeDataType")
LineCacheKey: TypeAlias = tuple[int | tuple[int, ...], ...] LineCacheKey: TypeAlias = tuple[int | tuple, ...]
TOGGLE_STYLE = Style.from_meta({"toggle": True}) TOGGLE_STYLE = Style.from_meta({"toggle": True})
@@ -199,9 +199,9 @@ class TreeNode(Generic[TreeDataType]):
class Tree(Generic[TreeDataType], ScrollView, can_focus=True): class Tree(Generic[TreeDataType], ScrollView, can_focus=True):
BINDINGS = [ BINDINGS = [
Binding("enter", "select_cursor", "Select", show=False),
Binding("up", "cursor_up", "Cursor Up", show=False), Binding("up", "cursor_up", "Cursor Up", show=False),
Binding("down", "cursor_down", "Cursor Down", show=False), Binding("down", "cursor_down", "Cursor Down", show=False),
Binding("enter", "select_cursor", "Select", show=False),
] ]
DEFAULT_CSS = """ DEFAULT_CSS = """
@@ -334,6 +334,10 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True):
"""TreeNode | Node: The currently selected node, or ``None`` if no selection.""" """TreeNode | Node: The currently selected node, or ``None`` if no selection."""
return self._cursor_node return self._cursor_node
@property
def last_line(self) -> int:
return len(self._tree_lines) - 1
def process_label(self, label: TextType): def process_label(self, label: TextType):
"""Process a str or Text in to a label. Maybe overridden in a subclass to change modify how labels are rendered. """Process a str or Text in to a label. Maybe overridden in a subclass to change modify how labels are rendered.
@@ -797,15 +801,37 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True):
def action_cursor_up(self) -> None: def action_cursor_up(self) -> None:
if self.cursor_line == -1: if self.cursor_line == -1:
self.cursor_line = len(self._tree_lines) - 1 self.cursor_line = self.last_line
else: else:
self.cursor_line -= 1 self.cursor_line -= 1
self.scroll_to_line(self.cursor_line) self.scroll_to_line(self.cursor_line)
def action_cursor_down(self) -> None: def action_cursor_down(self) -> None:
if self.cursor_line == -1:
self.cursor_line = 0
self.cursor_line += 1 self.cursor_line += 1
self.scroll_to_line(self.cursor_line) self.scroll_to_line(self.cursor_line)
def action_page_down(self) -> None:
if self.cursor_line == -1:
self.cursor_line = 0
self.cursor_line += self.scrollable_content_region.height - 1
self.scroll_to_line(self.cursor_line)
def action_page_up(self) -> None:
if self.cursor_line == -1:
self.cursor_line = self.last_line
self.cursor_line -= self.scrollable_content_region.height - 1
self.scroll_to_line(self.cursor_line)
def action_scroll_home(self) -> None:
self.cursor_line = 0
self.scroll_to_line(self.cursor_line)
def action_scroll_end(self) -> None:
self.cursor_line = self.last_line
self.scroll_to_line(self.cursor_line)
def action_select_cursor(self) -> None: def action_select_cursor(self) -> None:
try: try:
line = self._tree_lines[self.cursor_line] line = self._tree_lines[self.cursor_line]