diff --git a/sandbox/table.py b/sandbox/table.py index 5f013c338..545cd5f83 100644 --- a/sandbox/table.py +++ b/sandbox/table.py @@ -47,7 +47,13 @@ class TableApp(App): height = 1 row = [f"row [b]{n}[/b] col [i]{c}[/i]" for c in range(6)] if n == 10: - row[1] = Syntax(CODE, "python", line_numbers=True, indent_guides=True) + row[1] = Syntax( + CODE, + "python", + theme="ansi_dark", + line_numbers=True, + indent_guides=True, + ) height = 13 if n == 30: diff --git a/src/textual/widget.py b/src/textual/widget.py index 2100f70fd..f845b5da3 100644 --- a/src/textual/widget.py +++ b/src/textual/widget.py @@ -624,6 +624,67 @@ class Widget(DOMNode): return any(scrolls) + def scroll_to_region(self, region: Region, *, animate: bool = True) -> bool: + """Scrolls a given region in to view. + + Args: + region (Region): A region that should be visible. + animate (bool, optional): Enable animation. Defaults to True. + + Returns: + bool: True if the window was scrolled. + """ + + scroll_x, scroll_y = self.scroll_offset + width, height = self.region.size + container_region = Region(scroll_x, scroll_y, width, height) + + if region in container_region: + # Widget is visible, nothing to do + return False + + ( + container_left, + container_top, + container_right, + container_bottom, + ) = container_region.corners + ( + child_left, + child_top, + child_right, + child_bottom, + ) = region.corners + + delta_x = 0 + delta_y = 0 + + if not ( + (container_right >= child_left > container_left) + and (container_right >= child_right > container_left) + ): + delta_x = min( + child_left - container_left, + child_left - (container_right - region.width), + key=abs, + ) + + if not ( + (container_bottom >= child_top > container_top) + and (container_bottom >= child_bottom > container_top) + ): + delta_y = min( + child_top - container_top, + child_top - (container_bottom - region.height), + key=abs, + ) + + scrolled = self.scroll_relative( + delta_x or None, delta_y or None, animate=abs(delta_y) != 1, duration=0.2 + ) + + return scrolled + def __init_subclass__( cls, can_focus: bool = True, diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index d337ea92f..558f6d612 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -2,7 +2,8 @@ from __future__ import annotations from dataclasses import dataclass, field from itertools import chain -from typing import Callable, ClassVar, Generic, TypeVar, cast +import sys +from typing import ClassVar, Generic, TypeVar, cast from rich.console import RenderableType from rich.padding import Padding @@ -11,16 +12,25 @@ from rich.segment import Segment from rich.style import Style from rich.text import Text, TextType +from .. import events from .._cache import LRUCache from .._segment_tools import line_crop from .._types import Lines -from ..geometry import Region, Size +from ..geometry import clamp, Region, Size from ..reactive import Reactive from .._profile import timer from ..scroll_view import ScrollView from ..widget import Widget from .. import messages + +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + +CursorType = Literal["cell", "row", "column"] +CELL: CursorType = "cell" CellType = TypeVar("CellType") @@ -44,6 +54,7 @@ class Column: class Row: index: int height: int + y: int cell_renderables: list[RenderableType] = field(default_factory=list) @@ -82,6 +93,11 @@ class DataTable(ScrollView, Generic[CellType]): background: $primary 10%; } + DataTable > .datatable--cursor { + background: $secondary; + color: $text-secondary; + } + .-dark-mode DataTable > .datatable--even-row { background: $primary 15%; } @@ -98,6 +114,7 @@ class DataTable(ScrollView, Generic[CellType]): "datatable--odd-row", "datatable--even-row", "datatable--highlight", + "datatable--cursor", } def __init__( @@ -123,11 +140,17 @@ class DataTable(ScrollView, Generic[CellType]): self._line_cache: LRUCache[tuple[int, int, int, int], list[Segment]] self._line_cache = LRUCache(1000) + self._line_no = 0 + show_header = Reactive(True) fixed_rows = Reactive(0) - fixed_columns = Reactive(1) + fixed_columns = Reactive(0) zebra_stripes = Reactive(False) header_height = Reactive(1) + show_cursor = Reactive(True) + cursor_type = Reactive(CELL) + cursor_row = Reactive(0) + cursor_column = Reactive(0) def _clear_caches(self) -> None: self._row_render_cache.clear() @@ -151,6 +174,12 @@ class DataTable(ScrollView, Generic[CellType]): def watch_zebra_stripes(self, zebra_stripes: bool) -> None: self._clear_caches() + def validate_cursor_row(self, value: int) -> int: + return clamp(value, 0, self.row_count - 1) + + def validate_cursor_column(self, value: int) -> int: + return clamp(value, self.fixed_columns, len(self.columns) - 1) + def _update_dimensions(self) -> None: """Called to recalculate the virtual (scrollable) size.""" total_width = sum(column.width for column in self.columns) @@ -159,6 +188,16 @@ class DataTable(ScrollView, Generic[CellType]): len(self._y_offsets) + (self.header_height if self.show_header else 0), ) + def _get_cursor_region(self, row_index: int, column_index: int) -> Region: + row = self.rows[row_index] + x = sum(column.width for column in self.columns[:column_index]) + width = self.columns[column_index].width + height = row.height + y = row.y + if self.show_header: + y += self.header_height + return Region(x, y, width, height) + def add_column(self, label: TextType, *, width: int = 10) -> None: """Add a column to the table. @@ -179,12 +218,13 @@ class DataTable(ScrollView, Generic[CellType]): """ row_index = self.row_count self.data[row_index] = list(cells) - self.rows[row_index] = Row(row_index, height=height) + self.rows[row_index] = Row(row_index, height, self._line_no) for line_no in range(height): self._y_offsets.append((row_index, line_no)) self.row_count += 1 + self._line_no += height self._update_dimensions() self.refresh() @@ -210,7 +250,12 @@ class DataTable(ScrollView, Generic[CellType]): return [default_cell_formatter(datum) or empty for datum in data] def _render_cell( - self, row_index: int, column_index: int, style: Style, width: int + self, + row_index: int, + column_index: int, + style: Style, + width: int, + cursor: bool = False, ) -> Lines: """Render the given cell. @@ -223,6 +268,8 @@ class DataTable(ScrollView, Generic[CellType]): Returns: Lines: A list of segments per line. """ + if cursor: + style += self.component_styles["datatable--cursor"].node.rich_style cell_key = (row_index, column_index, style) if cell_key not in self._cell_render_cache: style += Style.from_meta({"row": row_index, "column": column_index}) @@ -239,7 +286,7 @@ class DataTable(ScrollView, Generic[CellType]): return self._cell_render_cache[cell_key] def _render_row( - self, row_index: int, line_no: int, base_style: Style + self, row_index: int, line_no: int, base_style: Style, cursor: int = -1 ) -> tuple[Lines, Lines]: """Render a row in to lines for each cell. @@ -281,7 +328,13 @@ class DataTable(ScrollView, Generic[CellType]): row_style = base_style scrollable_row = [ - render_cell(row_index, column.index, row_style, column.width)[line_no] + render_cell( + row_index, + column.index, + row_style, + column.width, + cursor=cursor == column.index, + )[line_no] for column in self.columns ] @@ -319,7 +372,7 @@ class DataTable(ScrollView, Generic[CellType]): list[Segment]: List of segments for rendering. """ - width = self.content_region.width + width = self.region.width cache_key = (y, x1, x2, width) if cache_key in self._line_cache: @@ -327,7 +380,14 @@ class DataTable(ScrollView, Generic[CellType]): row_index, line_no = self._get_offsets(y) - fixed, scrollable = self._render_row(row_index, line_no, base_style) + fixed, scrollable = self._render_row( + row_index, + line_no, + base_style, + cursor=self.cursor_column + if (self.show_cursor and self.cursor_row == row_index) + else -1, + ) fixed_width = sum(column.width for column in self.columns[: self.fixed_columns]) fixed_line: list[Segment] = list(chain.from_iterable(fixed)) if fixed else [] @@ -335,11 +395,11 @@ class DataTable(ScrollView, Generic[CellType]): segments = fixed_line + line_crop(scrollable_line, x1 + fixed_width, x2, width) - remaining_width = width - (fixed_width + min(width, (x2 - x1 + fixed_width))) - if remaining_width > 0: - segments.append(Segment(" " * remaining_width, base_style)) - elif remaining_width < 0: - segments = Segment.adjust_line_length(segments, width, style=base_style) + # remaining_width = width - (fixed_width + min(width, (x2 - x1 + fixed_width))) + # if remaining_width > 0: + # segments.append(Segment(" " * remaining_width, base_style)) + # elif remaining_width < 0: + segments = Segment.adjust_line_length(segments, width, style=base_style) simplified_segments = list(Segment.simplify(segments)) @@ -382,3 +442,39 @@ class DataTable(ScrollView, Generic[CellType]): def on_mouse_move(self, event): print(self.get_style_at(event.x, event.y).meta) + + async def on_key(self, event) -> None: + await self.dispatch_key(event) + + def _scroll_cursor_in_to_view(self) -> None: + region = self._get_cursor_region(self.cursor_row, self.cursor_column) + print("CURSOR", region) + self.scroll_to_region(region) + + def key_down(self, event: events.Key): + self.cursor_row += 1 + self._clear_caches() + event.stop() + event.prevent_default() + self._scroll_cursor_in_to_view() + + def key_up(self, event: events.Key): + self.cursor_row -= 1 + self._clear_caches() + event.stop() + event.prevent_default() + self._scroll_cursor_in_to_view() + + def key_right(self, event: events.Key): + self.cursor_column += 1 + self._clear_caches() + event.stop() + event.prevent_default() + self._scroll_cursor_in_to_view() + + def key_left(self, event: events.Key): + self.cursor_column -= 1 + self._clear_caches() + event.stop() + event.prevent_default() + self._scroll_cursor_in_to_view()