keyboard control of cursor

This commit is contained in:
Will McGugan
2022-06-21 17:49:28 +01:00
parent a9be6aa32a
commit b2ed540c50
3 changed files with 178 additions and 15 deletions

View File

@@ -47,7 +47,13 @@ class TableApp(App):
height = 1 height = 1
row = [f"row [b]{n}[/b] col [i]{c}[/i]" for c in range(6)] row = [f"row [b]{n}[/b] col [i]{c}[/i]" for c in range(6)]
if n == 10: if n == 10:
row[1] = Syntax(CODE, "python", line_numbers=True, indent_guides=True) row[1] = Syntax(
CODE,
"python",
theme="ansi_dark",
line_numbers=True,
indent_guides=True,
)
height = 13 height = 13
if n == 30: if n == 30:

View File

@@ -624,6 +624,67 @@ class Widget(DOMNode):
return any(scrolls) return any(scrolls)
def scroll_to_region(self, region: Region, *, animate: bool = True) -> bool:
"""Scrolls a given region in to view.
Args:
region (Region): A region that should be visible.
animate (bool, optional): Enable animation. Defaults to True.
Returns:
bool: True if the window was scrolled.
"""
scroll_x, scroll_y = self.scroll_offset
width, height = self.region.size
container_region = Region(scroll_x, scroll_y, width, height)
if region in container_region:
# Widget is visible, nothing to do
return False
(
container_left,
container_top,
container_right,
container_bottom,
) = container_region.corners
(
child_left,
child_top,
child_right,
child_bottom,
) = region.corners
delta_x = 0
delta_y = 0
if not (
(container_right >= child_left > container_left)
and (container_right >= child_right > container_left)
):
delta_x = min(
child_left - container_left,
child_left - (container_right - region.width),
key=abs,
)
if not (
(container_bottom >= child_top > container_top)
and (container_bottom >= child_bottom > container_top)
):
delta_y = min(
child_top - container_top,
child_top - (container_bottom - region.height),
key=abs,
)
scrolled = self.scroll_relative(
delta_x or None, delta_y or None, animate=abs(delta_y) != 1, duration=0.2
)
return scrolled
def __init_subclass__( def __init_subclass__(
cls, cls,
can_focus: bool = True, can_focus: bool = True,

View File

@@ -2,7 +2,8 @@ from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
from itertools import chain from itertools import chain
from typing import Callable, ClassVar, Generic, TypeVar, cast import sys
from typing import ClassVar, Generic, TypeVar, cast
from rich.console import RenderableType from rich.console import RenderableType
from rich.padding import Padding from rich.padding import Padding
@@ -11,16 +12,25 @@ from rich.segment import Segment
from rich.style import Style from rich.style import Style
from rich.text import Text, TextType from rich.text import Text, TextType
from .. import events
from .._cache import LRUCache from .._cache import LRUCache
from .._segment_tools import line_crop from .._segment_tools import line_crop
from .._types import Lines from .._types import Lines
from ..geometry import Region, Size from ..geometry import clamp, Region, Size
from ..reactive import Reactive from ..reactive import Reactive
from .._profile import timer from .._profile import timer
from ..scroll_view import ScrollView from ..scroll_view import ScrollView
from ..widget import Widget from ..widget import Widget
from .. import messages from .. import messages
if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal
CursorType = Literal["cell", "row", "column"]
CELL: CursorType = "cell"
CellType = TypeVar("CellType") CellType = TypeVar("CellType")
@@ -44,6 +54,7 @@ class Column:
class Row: class Row:
index: int index: int
height: int height: int
y: int
cell_renderables: list[RenderableType] = field(default_factory=list) cell_renderables: list[RenderableType] = field(default_factory=list)
@@ -82,6 +93,11 @@ class DataTable(ScrollView, Generic[CellType]):
background: $primary 10%; background: $primary 10%;
} }
DataTable > .datatable--cursor {
background: $secondary;
color: $text-secondary;
}
.-dark-mode DataTable > .datatable--even-row { .-dark-mode DataTable > .datatable--even-row {
background: $primary 15%; background: $primary 15%;
} }
@@ -98,6 +114,7 @@ class DataTable(ScrollView, Generic[CellType]):
"datatable--odd-row", "datatable--odd-row",
"datatable--even-row", "datatable--even-row",
"datatable--highlight", "datatable--highlight",
"datatable--cursor",
} }
def __init__( def __init__(
@@ -123,11 +140,17 @@ class DataTable(ScrollView, Generic[CellType]):
self._line_cache: LRUCache[tuple[int, int, int, int], list[Segment]] self._line_cache: LRUCache[tuple[int, int, int, int], list[Segment]]
self._line_cache = LRUCache(1000) self._line_cache = LRUCache(1000)
self._line_no = 0
show_header = Reactive(True) show_header = Reactive(True)
fixed_rows = Reactive(0) fixed_rows = Reactive(0)
fixed_columns = Reactive(1) fixed_columns = Reactive(0)
zebra_stripes = Reactive(False) zebra_stripes = Reactive(False)
header_height = Reactive(1) header_height = Reactive(1)
show_cursor = Reactive(True)
cursor_type = Reactive(CELL)
cursor_row = Reactive(0)
cursor_column = Reactive(0)
def _clear_caches(self) -> None: def _clear_caches(self) -> None:
self._row_render_cache.clear() self._row_render_cache.clear()
@@ -151,6 +174,12 @@ class DataTable(ScrollView, Generic[CellType]):
def watch_zebra_stripes(self, zebra_stripes: bool) -> None: def watch_zebra_stripes(self, zebra_stripes: bool) -> None:
self._clear_caches() self._clear_caches()
def validate_cursor_row(self, value: int) -> int:
return clamp(value, 0, self.row_count - 1)
def validate_cursor_column(self, value: int) -> int:
return clamp(value, self.fixed_columns, len(self.columns) - 1)
def _update_dimensions(self) -> None: def _update_dimensions(self) -> None:
"""Called to recalculate the virtual (scrollable) size.""" """Called to recalculate the virtual (scrollable) size."""
total_width = sum(column.width for column in self.columns) total_width = sum(column.width for column in self.columns)
@@ -159,6 +188,16 @@ class DataTable(ScrollView, Generic[CellType]):
len(self._y_offsets) + (self.header_height if self.show_header else 0), len(self._y_offsets) + (self.header_height if self.show_header else 0),
) )
def _get_cursor_region(self, row_index: int, column_index: int) -> Region:
row = self.rows[row_index]
x = sum(column.width for column in self.columns[:column_index])
width = self.columns[column_index].width
height = row.height
y = row.y
if self.show_header:
y += self.header_height
return Region(x, y, width, height)
def add_column(self, label: TextType, *, width: int = 10) -> None: def add_column(self, label: TextType, *, width: int = 10) -> None:
"""Add a column to the table. """Add a column to the table.
@@ -179,12 +218,13 @@ class DataTable(ScrollView, Generic[CellType]):
""" """
row_index = self.row_count row_index = self.row_count
self.data[row_index] = list(cells) self.data[row_index] = list(cells)
self.rows[row_index] = Row(row_index, height=height) self.rows[row_index] = Row(row_index, height, self._line_no)
for line_no in range(height): for line_no in range(height):
self._y_offsets.append((row_index, line_no)) self._y_offsets.append((row_index, line_no))
self.row_count += 1 self.row_count += 1
self._line_no += height
self._update_dimensions() self._update_dimensions()
self.refresh() self.refresh()
@@ -210,7 +250,12 @@ class DataTable(ScrollView, Generic[CellType]):
return [default_cell_formatter(datum) or empty for datum in data] return [default_cell_formatter(datum) or empty for datum in data]
def _render_cell( def _render_cell(
self, row_index: int, column_index: int, style: Style, width: int self,
row_index: int,
column_index: int,
style: Style,
width: int,
cursor: bool = False,
) -> Lines: ) -> Lines:
"""Render the given cell. """Render the given cell.
@@ -223,6 +268,8 @@ class DataTable(ScrollView, Generic[CellType]):
Returns: Returns:
Lines: A list of segments per line. Lines: A list of segments per line.
""" """
if cursor:
style += self.component_styles["datatable--cursor"].node.rich_style
cell_key = (row_index, column_index, style) cell_key = (row_index, column_index, style)
if cell_key not in self._cell_render_cache: if cell_key not in self._cell_render_cache:
style += Style.from_meta({"row": row_index, "column": column_index}) style += Style.from_meta({"row": row_index, "column": column_index})
@@ -239,7 +286,7 @@ class DataTable(ScrollView, Generic[CellType]):
return self._cell_render_cache[cell_key] return self._cell_render_cache[cell_key]
def _render_row( def _render_row(
self, row_index: int, line_no: int, base_style: Style self, row_index: int, line_no: int, base_style: Style, cursor: int = -1
) -> tuple[Lines, Lines]: ) -> tuple[Lines, Lines]:
"""Render a row in to lines for each cell. """Render a row in to lines for each cell.
@@ -281,7 +328,13 @@ class DataTable(ScrollView, Generic[CellType]):
row_style = base_style row_style = base_style
scrollable_row = [ scrollable_row = [
render_cell(row_index, column.index, row_style, column.width)[line_no] render_cell(
row_index,
column.index,
row_style,
column.width,
cursor=cursor == column.index,
)[line_no]
for column in self.columns for column in self.columns
] ]
@@ -319,7 +372,7 @@ class DataTable(ScrollView, Generic[CellType]):
list[Segment]: List of segments for rendering. list[Segment]: List of segments for rendering.
""" """
width = self.content_region.width width = self.region.width
cache_key = (y, x1, x2, width) cache_key = (y, x1, x2, width)
if cache_key in self._line_cache: if cache_key in self._line_cache:
@@ -327,7 +380,14 @@ class DataTable(ScrollView, Generic[CellType]):
row_index, line_no = self._get_offsets(y) row_index, line_no = self._get_offsets(y)
fixed, scrollable = self._render_row(row_index, line_no, base_style) fixed, scrollable = self._render_row(
row_index,
line_no,
base_style,
cursor=self.cursor_column
if (self.show_cursor and self.cursor_row == row_index)
else -1,
)
fixed_width = sum(column.width for column in self.columns[: self.fixed_columns]) fixed_width = sum(column.width for column in self.columns[: self.fixed_columns])
fixed_line: list[Segment] = list(chain.from_iterable(fixed)) if fixed else [] fixed_line: list[Segment] = list(chain.from_iterable(fixed)) if fixed else []
@@ -335,11 +395,11 @@ class DataTable(ScrollView, Generic[CellType]):
segments = fixed_line + line_crop(scrollable_line, x1 + fixed_width, x2, width) segments = fixed_line + line_crop(scrollable_line, x1 + fixed_width, x2, width)
remaining_width = width - (fixed_width + min(width, (x2 - x1 + fixed_width))) # remaining_width = width - (fixed_width + min(width, (x2 - x1 + fixed_width)))
if remaining_width > 0: # if remaining_width > 0:
segments.append(Segment(" " * remaining_width, base_style)) # segments.append(Segment(" " * remaining_width, base_style))
elif remaining_width < 0: # elif remaining_width < 0:
segments = Segment.adjust_line_length(segments, width, style=base_style) segments = Segment.adjust_line_length(segments, width, style=base_style)
simplified_segments = list(Segment.simplify(segments)) simplified_segments = list(Segment.simplify(segments))
@@ -382,3 +442,39 @@ class DataTable(ScrollView, Generic[CellType]):
def on_mouse_move(self, event): def on_mouse_move(self, event):
print(self.get_style_at(event.x, event.y).meta) print(self.get_style_at(event.x, event.y).meta)
async def on_key(self, event) -> None:
await self.dispatch_key(event)
def _scroll_cursor_in_to_view(self) -> None:
region = self._get_cursor_region(self.cursor_row, self.cursor_column)
print("CURSOR", region)
self.scroll_to_region(region)
def key_down(self, event: events.Key):
self.cursor_row += 1
self._clear_caches()
event.stop()
event.prevent_default()
self._scroll_cursor_in_to_view()
def key_up(self, event: events.Key):
self.cursor_row -= 1
self._clear_caches()
event.stop()
event.prevent_default()
self._scroll_cursor_in_to_view()
def key_right(self, event: events.Key):
self.cursor_column += 1
self._clear_caches()
event.stop()
event.prevent_default()
self._scroll_cursor_in_to_view()
def key_left(self, event: events.Key):
self.cursor_column -= 1
self._clear_caches()
event.stop()
event.prevent_default()
self._scroll_cursor_in_to_view()