keyboard control of cursor

This commit is contained in:
Will McGugan
2022-06-21 17:49:28 +01:00
parent a9be6aa32a
commit b2ed540c50
3 changed files with 178 additions and 15 deletions

View File

@@ -47,7 +47,13 @@ class TableApp(App):
height = 1
row = [f"row [b]{n}[/b] col [i]{c}[/i]" for c in range(6)]
if n == 10:
row[1] = Syntax(CODE, "python", line_numbers=True, indent_guides=True)
row[1] = Syntax(
CODE,
"python",
theme="ansi_dark",
line_numbers=True,
indent_guides=True,
)
height = 13
if n == 30:

View File

@@ -624,6 +624,67 @@ class Widget(DOMNode):
return any(scrolls)
def scroll_to_region(self, region: Region, *, animate: bool = True) -> bool:
"""Scrolls a given region in to view.
Args:
region (Region): A region that should be visible.
animate (bool, optional): Enable animation. Defaults to True.
Returns:
bool: True if the window was scrolled.
"""
scroll_x, scroll_y = self.scroll_offset
width, height = self.region.size
container_region = Region(scroll_x, scroll_y, width, height)
if region in container_region:
# Widget is visible, nothing to do
return False
(
container_left,
container_top,
container_right,
container_bottom,
) = container_region.corners
(
child_left,
child_top,
child_right,
child_bottom,
) = region.corners
delta_x = 0
delta_y = 0
if not (
(container_right >= child_left > container_left)
and (container_right >= child_right > container_left)
):
delta_x = min(
child_left - container_left,
child_left - (container_right - region.width),
key=abs,
)
if not (
(container_bottom >= child_top > container_top)
and (container_bottom >= child_bottom > container_top)
):
delta_y = min(
child_top - container_top,
child_top - (container_bottom - region.height),
key=abs,
)
scrolled = self.scroll_relative(
delta_x or None, delta_y or None, animate=abs(delta_y) != 1, duration=0.2
)
return scrolled
def __init_subclass__(
cls,
can_focus: bool = True,

View File

@@ -2,7 +2,8 @@ from __future__ import annotations
from dataclasses import dataclass, field
from itertools import chain
from typing import Callable, ClassVar, Generic, TypeVar, cast
import sys
from typing import ClassVar, Generic, TypeVar, cast
from rich.console import RenderableType
from rich.padding import Padding
@@ -11,16 +12,25 @@ from rich.segment import Segment
from rich.style import Style
from rich.text import Text, TextType
from .. import events
from .._cache import LRUCache
from .._segment_tools import line_crop
from .._types import Lines
from ..geometry import Region, Size
from ..geometry import clamp, Region, Size
from ..reactive import Reactive
from .._profile import timer
from ..scroll_view import ScrollView
from ..widget import Widget
from .. import messages
if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal
CursorType = Literal["cell", "row", "column"]
CELL: CursorType = "cell"
CellType = TypeVar("CellType")
@@ -44,6 +54,7 @@ class Column:
class Row:
index: int
height: int
y: int
cell_renderables: list[RenderableType] = field(default_factory=list)
@@ -82,6 +93,11 @@ class DataTable(ScrollView, Generic[CellType]):
background: $primary 10%;
}
DataTable > .datatable--cursor {
background: $secondary;
color: $text-secondary;
}
.-dark-mode DataTable > .datatable--even-row {
background: $primary 15%;
}
@@ -98,6 +114,7 @@ class DataTable(ScrollView, Generic[CellType]):
"datatable--odd-row",
"datatable--even-row",
"datatable--highlight",
"datatable--cursor",
}
def __init__(
@@ -123,11 +140,17 @@ class DataTable(ScrollView, Generic[CellType]):
self._line_cache: LRUCache[tuple[int, int, int, int], list[Segment]]
self._line_cache = LRUCache(1000)
self._line_no = 0
show_header = Reactive(True)
fixed_rows = Reactive(0)
fixed_columns = Reactive(1)
fixed_columns = Reactive(0)
zebra_stripes = Reactive(False)
header_height = Reactive(1)
show_cursor = Reactive(True)
cursor_type = Reactive(CELL)
cursor_row = Reactive(0)
cursor_column = Reactive(0)
def _clear_caches(self) -> None:
self._row_render_cache.clear()
@@ -151,6 +174,12 @@ class DataTable(ScrollView, Generic[CellType]):
def watch_zebra_stripes(self, zebra_stripes: bool) -> None:
self._clear_caches()
def validate_cursor_row(self, value: int) -> int:
return clamp(value, 0, self.row_count - 1)
def validate_cursor_column(self, value: int) -> int:
return clamp(value, self.fixed_columns, len(self.columns) - 1)
def _update_dimensions(self) -> None:
"""Called to recalculate the virtual (scrollable) size."""
total_width = sum(column.width for column in self.columns)
@@ -159,6 +188,16 @@ class DataTable(ScrollView, Generic[CellType]):
len(self._y_offsets) + (self.header_height if self.show_header else 0),
)
def _get_cursor_region(self, row_index: int, column_index: int) -> Region:
row = self.rows[row_index]
x = sum(column.width for column in self.columns[:column_index])
width = self.columns[column_index].width
height = row.height
y = row.y
if self.show_header:
y += self.header_height
return Region(x, y, width, height)
def add_column(self, label: TextType, *, width: int = 10) -> None:
"""Add a column to the table.
@@ -179,12 +218,13 @@ class DataTable(ScrollView, Generic[CellType]):
"""
row_index = self.row_count
self.data[row_index] = list(cells)
self.rows[row_index] = Row(row_index, height=height)
self.rows[row_index] = Row(row_index, height, self._line_no)
for line_no in range(height):
self._y_offsets.append((row_index, line_no))
self.row_count += 1
self._line_no += height
self._update_dimensions()
self.refresh()
@@ -210,7 +250,12 @@ class DataTable(ScrollView, Generic[CellType]):
return [default_cell_formatter(datum) or empty for datum in data]
def _render_cell(
self, row_index: int, column_index: int, style: Style, width: int
self,
row_index: int,
column_index: int,
style: Style,
width: int,
cursor: bool = False,
) -> Lines:
"""Render the given cell.
@@ -223,6 +268,8 @@ class DataTable(ScrollView, Generic[CellType]):
Returns:
Lines: A list of segments per line.
"""
if cursor:
style += self.component_styles["datatable--cursor"].node.rich_style
cell_key = (row_index, column_index, style)
if cell_key not in self._cell_render_cache:
style += Style.from_meta({"row": row_index, "column": column_index})
@@ -239,7 +286,7 @@ class DataTable(ScrollView, Generic[CellType]):
return self._cell_render_cache[cell_key]
def _render_row(
self, row_index: int, line_no: int, base_style: Style
self, row_index: int, line_no: int, base_style: Style, cursor: int = -1
) -> tuple[Lines, Lines]:
"""Render a row in to lines for each cell.
@@ -281,7 +328,13 @@ class DataTable(ScrollView, Generic[CellType]):
row_style = base_style
scrollable_row = [
render_cell(row_index, column.index, row_style, column.width)[line_no]
render_cell(
row_index,
column.index,
row_style,
column.width,
cursor=cursor == column.index,
)[line_no]
for column in self.columns
]
@@ -319,7 +372,7 @@ class DataTable(ScrollView, Generic[CellType]):
list[Segment]: List of segments for rendering.
"""
width = self.content_region.width
width = self.region.width
cache_key = (y, x1, x2, width)
if cache_key in self._line_cache:
@@ -327,7 +380,14 @@ class DataTable(ScrollView, Generic[CellType]):
row_index, line_no = self._get_offsets(y)
fixed, scrollable = self._render_row(row_index, line_no, base_style)
fixed, scrollable = self._render_row(
row_index,
line_no,
base_style,
cursor=self.cursor_column
if (self.show_cursor and self.cursor_row == row_index)
else -1,
)
fixed_width = sum(column.width for column in self.columns[: self.fixed_columns])
fixed_line: list[Segment] = list(chain.from_iterable(fixed)) if fixed else []
@@ -335,11 +395,11 @@ class DataTable(ScrollView, Generic[CellType]):
segments = fixed_line + line_crop(scrollable_line, x1 + fixed_width, x2, width)
remaining_width = width - (fixed_width + min(width, (x2 - x1 + fixed_width)))
if remaining_width > 0:
segments.append(Segment(" " * remaining_width, base_style))
elif remaining_width < 0:
segments = Segment.adjust_line_length(segments, width, style=base_style)
# remaining_width = width - (fixed_width + min(width, (x2 - x1 + fixed_width)))
# if remaining_width > 0:
# segments.append(Segment(" " * remaining_width, base_style))
# elif remaining_width < 0:
segments = Segment.adjust_line_length(segments, width, style=base_style)
simplified_segments = list(Segment.simplify(segments))
@@ -382,3 +442,39 @@ class DataTable(ScrollView, Generic[CellType]):
def on_mouse_move(self, event):
print(self.get_style_at(event.x, event.y).meta)
async def on_key(self, event) -> None:
await self.dispatch_key(event)
def _scroll_cursor_in_to_view(self) -> None:
region = self._get_cursor_region(self.cursor_row, self.cursor_column)
print("CURSOR", region)
self.scroll_to_region(region)
def key_down(self, event: events.Key):
self.cursor_row += 1
self._clear_caches()
event.stop()
event.prevent_default()
self._scroll_cursor_in_to_view()
def key_up(self, event: events.Key):
self.cursor_row -= 1
self._clear_caches()
event.stop()
event.prevent_default()
self._scroll_cursor_in_to_view()
def key_right(self, event: events.Key):
self.cursor_column += 1
self._clear_caches()
event.stop()
event.prevent_default()
self._scroll_cursor_in_to_view()
def key_left(self, event: events.Key):
self.cursor_column -= 1
self._clear_caches()
event.stop()
event.prevent_default()
self._scroll_cursor_in_to_view()