diff --git a/docs/api/directory_tree.md b/docs/api/directory_tree.md
new file mode 100644
index 000000000..f9d26e0e0
--- /dev/null
+++ b/docs/api/directory_tree.md
@@ -0,0 +1 @@
+::: textual.widgets.DirectoryTree
diff --git a/docs/api/tree.md b/docs/api/tree.md
new file mode 100644
index 000000000..73f20ee30
--- /dev/null
+++ b/docs/api/tree.md
@@ -0,0 +1 @@
+::: textual.widgets.Tree
diff --git a/docs/api/tree_node.md b/docs/api/tree_node.md
new file mode 100644
index 000000000..ad122443e
--- /dev/null
+++ b/docs/api/tree_node.md
@@ -0,0 +1 @@
+::: textual.widgets.TreeNode
diff --git a/docs/examples/widgets/directory_tree.py b/docs/examples/widgets/directory_tree.py
new file mode 100644
index 000000000..e0c14a92c
--- /dev/null
+++ b/docs/examples/widgets/directory_tree.py
@@ -0,0 +1,12 @@
+from textual.app import App, ComposeResult
+from textual.widgets import DirectoryTree
+
+
+class DirectoryTreeApp(App):
+ def compose(self) -> ComposeResult:
+ yield DirectoryTree("./")
+
+
+if __name__ == "__main__":
+ app = DirectoryTreeApp()
+ app.run()
diff --git a/docs/examples/widgets/tree.py b/docs/examples/widgets/tree.py
new file mode 100644
index 000000000..7b6ff27d7
--- /dev/null
+++ b/docs/examples/widgets/tree.py
@@ -0,0 +1,18 @@
+from textual.app import App, ComposeResult
+from textual.widgets import Tree
+
+
+class TreeApp(App):
+ def compose(self) -> ComposeResult:
+ tree: Tree = Tree("Dune")
+ tree.root.expand()
+ characters = tree.root.add("Characters", expand=True)
+ characters.add_leaf("Paul")
+ characters.add_leaf("Jessica")
+ characters.add_leaf("Channi")
+ yield tree
+
+
+if __name__ == "__main__":
+ app = TreeApp()
+ app.run()
diff --git a/docs/widgets/directory_tree.md b/docs/widgets/directory_tree.md
new file mode 100644
index 000000000..2c5e327c1
--- /dev/null
+++ b/docs/widgets/directory_tree.md
@@ -0,0 +1,36 @@
+# DirectoryTree
+
+A tree control to navigate the contents of your filesystem.
+
+- [x] Focusable
+- [ ] Container
+
+
+## Example
+
+The example below creates a simple tree to navigate the current working directory.
+
+```python
+--8<-- "docs/examples/widgets/directory_tree.py"
+```
+
+## Events
+
+| Event | Default handler | Description |
+| ------------------- | --------------------------------- | --------------------------------------- |
+| `Tree.FileSelected` | `on_directory_tree_file_selected` | Sent when the user selects a file node. |
+
+
+## Reactive Attributes
+
+| Name | Type | Default | Description |
+| ------------- | ------ | ------- | ----------------------------------------------- |
+| `show_root` | `bool` | `True` | Show the root node. |
+| `show_guides` | `bool` | `True` | Show guide lines between levels. |
+| `guide_depth` | `int` | `4` | Amount of indentation between parent and child. |
+
+
+## See Also
+
+* [Tree][textual.widgets.DirectoryTree] code reference
+* [Tree][textual.widgets.Tree] code reference
diff --git a/docs/widgets/tree.md b/docs/widgets/tree.md
new file mode 100644
index 000000000..d87e1a966
--- /dev/null
+++ b/docs/widgets/tree.md
@@ -0,0 +1,46 @@
+# Tree
+
+A tree control widget.
+
+- [x] Focusable
+- [ ] Container
+
+
+## Example
+
+The example below creates a simple tree.
+
+=== "Output"
+
+ ```{.textual path="docs/examples/widgets/tree.py"}
+ ```
+
+=== "tree.py"
+
+ ```python
+ --8<-- "docs/examples/widgets/tree.py"
+ ```
+
+A each tree widget has a "root" attribute which is an instance of a [TreeNode][textual.widgets.TreeNode]. Call [add()][textual.widgets.TreeNode.add] or [add_leaf()][textual.widgets.TreeNode.add_leaf] to add new nodes underneath the root. Both these methods return a TreeNode for the child, so you can add more levels.
+
+## Events
+
+| Event | Default handler | Description |
+| -------------------- | ------------------------ | ------------------------------------------------ |
+| `Tree.NodeSelected` | `on_tree_node_selected` | Sent when the user selects a tree node. |
+| `Tree.NodeExpanded` | `on_tree_node_expanded` | Sent when the user expands a node in the tree. |
+| `Tree.NodeCollapsed` | `on_tree_node_collapsed` | Sent when the user collapsed a node in the tree. |
+
+## Reactive Attributes
+
+| Name | Type | Default | Description |
+| ------------- | ------ | ------- | ----------------------------------------------- |
+| `show_root` | `bool` | `True` | Show the root node. |
+| `show_guides` | `bool` | `True` | Show guide lines between levels. |
+| `guide_depth` | `int` | `4` | Amount of indentation between parent and child. |
+
+
+## See Also
+
+* [Tree][textual.widgets.Tree] code reference
+* [TreeNode][textual.widgets.TreeNode] code reference
diff --git a/docs/widgets/tree_control.md b/docs/widgets/tree_control.md
deleted file mode 100644
index 1155acfcc..000000000
--- a/docs/widgets/tree_control.md
+++ /dev/null
@@ -1 +0,0 @@
-# TreeControl
diff --git a/examples/json_tree.py b/examples/json_tree.py
new file mode 100644
index 000000000..d844556bb
--- /dev/null
+++ b/examples/json_tree.py
@@ -0,0 +1,79 @@
+import json
+
+from rich.text import Text
+
+from textual.app import App, ComposeResult
+from textual.widgets import Header, Footer, Tree, TreeNode
+
+
+class TreeApp(App):
+
+ BINDINGS = [
+ ("a", "add", "Add node"),
+ ("c", "clear", "Clear"),
+ ("t", "toggle_root", "Toggle root"),
+ ]
+
+ def compose(self) -> ComposeResult:
+ yield Header()
+ yield Footer()
+ yield Tree("Root")
+
+ @classmethod
+ def add_json(cls, node: TreeNode, json_data: object) -> None:
+ """Adds JSON data to a node.
+
+ Args:
+ node (TreeNode): A Tree node.
+ json_data (object): An object decoded from JSON.
+ """
+
+ from rich.highlighter import ReprHighlighter
+
+ highlighter = ReprHighlighter()
+
+ def add_node(name: str, node: TreeNode, data: object) -> None:
+ if isinstance(data, dict):
+ node._label = Text(f"{{}} {name}")
+ for key, value in data.items():
+ new_node = node.add("")
+ add_node(key, new_node, value)
+ elif isinstance(data, list):
+ node._label = Text(f"[] {name}")
+ for index, value in enumerate(data):
+ new_node = node.add("")
+ add_node(str(index), new_node, value)
+ else:
+ node._allow_expand = False
+ if name:
+ label = Text.assemble(
+ Text.from_markup(f"[b]{name}[/b]="), highlighter(repr(data))
+ )
+ else:
+ label = Text(repr(data))
+ node._label = label
+
+ add_node("JSON", node, json_data)
+
+ def on_mount(self) -> None:
+ with open("food.json") as data_file:
+ self.json_data = json.load(data_file)
+
+ def action_add(self) -> None:
+ tree = self.query_one(Tree)
+ json_node = tree.root.add("JSON")
+ self.add_json(json_node, self.json_data)
+ tree.root.expand()
+
+ def action_clear(self) -> None:
+ tree = self.query_one(Tree)
+ tree.clear()
+
+ def action_toggle_root(self) -> None:
+ tree = self.query_one(Tree)
+ tree.show_root = not tree.show_root
+
+
+if __name__ == "__main__":
+ app = TreeApp()
+ app.run()
diff --git a/examples/tree.py b/examples/tree.py
deleted file mode 100644
index 688a505d7..000000000
--- a/examples/tree.py
+++ /dev/null
@@ -1,46 +0,0 @@
-import json
-
-from textual.app import App, ComposeResult
-from textual.widgets import Header, Footer, Tree, DirectoryTree
-
-
-with open("food.json") as data_file:
- data = json.load(data_file)
-
-from rich import print
-
-print(data)
-
-
-class TreeApp(App):
-
- BINDINGS = [
- ("a", "add", "Add node"),
- ("c", "clear", "Clear"),
- ("t", "toggle_root", "Toggle root"),
- ]
-
- def compose(self) -> ComposeResult:
- yield Header()
- yield Footer()
- yield DirectoryTree("../")
-
- def action_add(self) -> None:
- tree = self.query_one(Tree)
-
- json_node = tree.root.add("JSON")
- tree.root.expand()
- tree.add_json(json_node, data)
-
- def action_clear(self) -> None:
- tree = self.query_one(Tree)
- tree.clear()
-
- def action_toggle_root(self) -> None:
- tree = self.query_one(Tree)
- tree.show_root = not tree.show_root
-
-
-if __name__ == "__main__":
- app = TreeApp()
- app.run()
diff --git a/mkdocs.yml b/mkdocs.yml
index 332298e5d..24edb147b 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -89,16 +89,17 @@ nav:
- "styles/visibility.md"
- "styles/width.md"
- Widgets:
- - "widgets/index.md"
- "widgets/button.md"
- "widgets/checkbox.md"
- "widgets/data_table.md"
+ - "widgets/directory_tree.md"
- "widgets/footer.md"
- "widgets/header.md"
+ - "widgets/index.md"
- "widgets/input.md"
- "widgets/label.md"
- "widgets/static.md"
- - "widgets/tree_control.md"
+ - "widgets/tree.md"
- API:
- "api/index.md"
- "api/app.md"
diff --git a/src/textual/widgets/__init__.py b/src/textual/widgets/__init__.py
index e856b2bbf..4cf014383 100644
--- a/src/textual/widgets/__init__.py
+++ b/src/textual/widgets/__init__.py
@@ -8,22 +8,23 @@ from ..case import camel_to_snake
# but also to the `__init__.pyi` file in this same folder - otherwise text editors and type checkers won't
# be able to "see" them.
if typing.TYPE_CHECKING:
- from ..widget import Widget
+
from ._button import Button
from ._checkbox import Checkbox
from ._data_table import DataTable
from ._directory_tree import DirectoryTree
from ._footer import Footer
from ._header import Header
+ from ._input import Input
from ._label import Label
from ._placeholder import Placeholder
from ._pretty import Pretty
from ._static import Static
- from ._input import Input
from ._text_log import TextLog
from ._tree import Tree
- from ._tree_control import TreeControl
+ from ._tree_node import TreeNode
from ._welcome import Welcome
+ from ..widget import Widget
__all__ = [
"Button",
@@ -39,7 +40,7 @@ __all__ = [
"Static",
"TextLog",
"Tree",
- "TreeControl",
+ "TreeNode",
"Welcome",
]
diff --git a/src/textual/widgets/__init__.pyi b/src/textual/widgets/__init__.pyi
index 7530aef42..1d6b5c920 100644
--- a/src/textual/widgets/__init__.pyi
+++ b/src/textual/widgets/__init__.pyi
@@ -12,5 +12,5 @@ from ._static import Static as Static
from ._input import Input as Input
from ._text_log import TextLog as TextLog
from ._tree import Tree as Tree
-from ._tree_control import TreeControl as TreeControl
+from ._tree_node import TreeNode as TreeNode
from ._welcome import Welcome as Welcome
diff --git a/src/textual/widgets/_tree.py b/src/textual/widgets/_tree.py
index 5de425494..83f9a3f67 100644
--- a/src/textual/widgets/_tree.py
+++ b/src/textual/widgets/_tree.py
@@ -56,6 +56,8 @@ class _TreeLine:
@rich.repr.auto
class TreeNode(Generic[TreeDataType]):
+ """An object that represents a "node" in a tree control."""
+
def __init__(
self,
tree: Tree[TreeDataType],
@@ -90,8 +92,9 @@ class TreeNode(Generic[TreeDataType]):
self._selected_ = False
self._updates += 1
+ @property
def line(self) -> int:
- """Get the line number for this node, or -1 if it is not displayed."""
+ """int: Get the line number for this node, or -1 if it is not displayed."""
return self._line
@property
@@ -116,9 +119,33 @@ class TreeNode(Generic[TreeDataType]):
@property
def id(self) -> NodeID:
- """Get the node ID."""
+ """NodeID: Get the node ID."""
return self._id
+ @property
+ def is_expanded(self) -> bool:
+ """bool: Check if the node is expanded."""
+ return self._expanded
+
+ @property
+ def is_last(self) -> bool:
+ """bool: Check if this is the last child."""
+ if self._parent is None:
+ return True
+ return bool(
+ self._parent._children and self._parent._children[-1] == self,
+ )
+
+ @property
+ def allow_expand(self) -> bool:
+ """bool: Check if the node is allowed to expand."""
+ return self._allow_expand
+
+ @allow_expand.setter
+ def allow_expand(self, allow_expand: bool) -> None:
+ self._allow_expand = allow_expand
+ self._updates += 1
+
def expand(self) -> None:
"""Expand a node (show its children)."""
self._expanded = True
@@ -147,30 +174,6 @@ class TreeNode(Generic[TreeDataType]):
text_label = self._tree.process_label(label)
self._label = text_label
- @property
- def is_expanded(self) -> bool:
- """bool: Check if the node is expanded."""
- return self._expanded
-
- @property
- def is_last(self) -> bool:
- """bool: Check if this is the last child."""
- if self._parent is None:
- return True
- return bool(
- self._parent._children and self._parent._children[-1] == self,
- )
-
- @property
- def allow_expand(self) -> bool:
- """bool: Check if the node is allowed to expand."""
- return self._allow_expand
-
- @allow_expand.setter
- def allow_expand(self, allow_expand: bool) -> bool:
- self._allow_expand = allow_expand
- self._updates += 1
-
def add(
self,
label: TextType,
@@ -199,6 +202,21 @@ class TreeNode(Generic[TreeDataType]):
self._tree._invalidate()
return node
+ def add_leaf(
+ self, label: TextType, data: TreeDataType | None = None
+ ) -> TreeNode[TreeDataType]:
+ """Add a 'leaf' node (a node that can not expand).
+
+ Args:
+ label (TextType): Label for the node.
+ data (TreeDataType | None, optional): Optional data. Defaults to None.
+
+ Returns:
+ TreeNode[TreeDataType]: New node.
+ """
+ node = self.add(label, data, expand=False, allow_expand=False)
+ return node
+
class Tree(Generic[TreeDataType], ScrollView, can_focus=True):
@@ -451,36 +469,6 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True):
else:
return line.node
- def add_json(self, node: TreeNode, json_data: object) -> None:
-
- from rich.highlighter import ReprHighlighter
-
- highlighter = ReprHighlighter()
-
- def add_node(name: str, node: TreeNode, data: object) -> None:
- if isinstance(data, dict):
- node._label = Text(f"{{}} {name}")
- for key, value in data.items():
- new_node = node.add("")
- add_node(key, new_node, value)
- elif isinstance(data, list):
- node._label = Text(f"[] {name}")
- for index, value in enumerate(data):
- new_node = node.add("")
- add_node(str(index), new_node, value)
- else:
- node._allow_expand = False
- if name:
- label = Text.assemble(
- Text.from_markup(f"[b]{name}[/b]="), highlighter(repr(data))
- )
- else:
- label = Text(repr(data))
- node._label = label
-
- add_node("JSON", node, json_data)
- self._invalidate()
-
def validate_cursor_line(self, value: int) -> int:
return clamp(value, 0, len(self._tree_lines) - 1)
@@ -642,10 +630,11 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True):
width = self.size.width
self.virtual_size = Size(width, len(lines))
- if self.cursor_node is not None:
- self.cursor_line = self.cursor_node._line
- if self.cursor_line >= len(lines):
- self.cursor_line = -1
+ if self.cursor_line != -1:
+ if self.cursor_node is not None:
+ self.cursor_line = self.cursor_node._line
+ if self.cursor_line >= len(lines):
+ self.cursor_line = -1
self.refresh()
def render_line(self, y: int) -> list[Segment]:
@@ -813,7 +802,8 @@ class Tree(Generic[TreeDataType], ScrollView, can_focus=True):
def action_cursor_down(self) -> None:
if self.cursor_line == -1:
self.cursor_line = 0
- self.cursor_line += 1
+ else:
+ self.cursor_line += 1
self.scroll_to_line(self.cursor_line)
def action_page_down(self) -> None:
diff --git a/src/textual/widgets/_tree_control.py b/src/textual/widgets/_tree_control.py
deleted file mode 100644
index c471e0686..000000000
--- a/src/textual/widgets/_tree_control.py
+++ /dev/null
@@ -1,427 +0,0 @@
-from __future__ import annotations
-
-
-from typing import ClassVar, Generic, Iterator, NewType, TypeVar
-
-import rich.repr
-from rich.console import RenderableType
-from rich.style import Style, NULL_STYLE
-from rich.text import Text, TextType
-from rich.tree import Tree
-
-from ..geometry import Region, Size
-from .. import events
-from ..reactive import Reactive
-from .._types import MessageTarget
-from ..widgets import Static
-from ..message import Message
-from .. import messages
-
-
-NodeID = NewType("NodeID", int)
-
-
-NodeDataType = TypeVar("NodeDataType")
-EventNodeDataType = TypeVar("EventNodeDataType")
-
-
-@rich.repr.auto
-class TreeNode(Generic[NodeDataType]):
- def __init__(
- self,
- parent: TreeNode[NodeDataType] | None,
- node_id: NodeID,
- control: TreeControl,
- tree: Tree,
- label: TextType,
- data: NodeDataType,
- ) -> None:
- self.parent = parent
- self.id = node_id
- self._control = control
- self._tree = tree
- self.label = label
- self.data = data
- self.loaded = False
- self._expanded = False
- self._empty = False
- self._tree.expanded = False
- self.children: list[TreeNode] = []
-
- def __rich_repr__(self) -> rich.repr.Result:
- yield "id", self.id
- yield "label", self.label
- yield "data", self.data
-
- @property
- def control(self) -> TreeControl:
- return self._control
-
- @property
- def empty(self) -> bool:
- return self._empty
-
- @property
- def expanded(self) -> bool:
- return self._expanded
-
- @property
- def is_cursor(self) -> bool:
- return self.control.cursor == self.id and self.control.show_cursor
-
- @property
- def tree(self) -> Tree:
- return self._tree
-
- @property
- def next_node(self) -> TreeNode[NodeDataType] | None:
- """The next node in the tree, or None if at the end."""
-
- if self.expanded and self.children:
- return self.children[0]
- else:
-
- sibling = self.next_sibling
- if sibling is not None:
- return sibling
-
- node = self
- while True:
- if node.parent is None:
- return None
- sibling = node.parent.next_sibling
- if sibling is not None:
- return sibling
- else:
- node = node.parent
-
- @property
- def previous_node(self) -> TreeNode[NodeDataType] | None:
- """The previous node in the tree, or None if at the end."""
-
- sibling = self.previous_sibling
- if sibling is not None:
-
- def last_sibling(node) -> TreeNode[NodeDataType]:
- if node.expanded and node.children:
- return last_sibling(node.children[-1])
- else:
- return (
- node.children[-1] if (node.children and node.expanded) else node
- )
-
- return last_sibling(sibling)
-
- if self.parent is None:
- return None
- return self.parent
-
- @property
- def next_sibling(self) -> TreeNode[NodeDataType] | None:
- """The next sibling, or None if last sibling."""
- if self.parent is None:
- return None
- iter_siblings = iter(self.parent.children)
- try:
- for node in iter_siblings:
- if node is self:
- return next(iter_siblings)
- except StopIteration:
- pass
- return None
-
- @property
- def previous_sibling(self) -> TreeNode[NodeDataType] | None:
- """Previous sibling or None if first sibling."""
- if self.parent is None:
- return None
- iter_siblings = iter(self.parent.children)
- sibling: TreeNode[NodeDataType] | None = None
-
- for node in iter_siblings:
- if node is self:
- return sibling
- sibling = node
- return None
-
- def expand(self, expanded: bool = True) -> None:
- self._expanded = expanded
- self._tree.expanded = expanded
- self._control.refresh(layout=True)
-
- def toggle(self) -> None:
- self.expand(not self._expanded)
-
- def add(self, label: TextType, data: NodeDataType) -> None:
- self._control.add(self.id, label, data=data)
- self._control.refresh(layout=True)
- self._empty = False
-
- def __rich__(self) -> RenderableType:
- return self._control.render_node(self)
-
-
-class TreeControl(Generic[NodeDataType], Static, can_focus=True):
- DEFAULT_CSS = """
- TreeControl {
- color: $text;
- height: auto;
- width: 100%;
- link-style: not underline;
- }
-
- TreeControl > .tree--guides {
- color: $success;
- }
-
- TreeControl > .tree--guides-highlight {
- color: $success;
- text-style: uu;
- }
-
- TreeControl > .tree--guides-cursor {
- color: $secondary;
- text-style: bold;
- }
-
- TreeControl > .tree--labels {
- color: $text;
- }
-
- TreeControl > .tree--cursor {
- background: $secondary;
- color: $text;
- }
-
- """
-
- COMPONENT_CLASSES: ClassVar[set[str]] = {
- "tree--guides",
- "tree--guides-highlight",
- "tree--guides-cursor",
- "tree--labels",
- "tree--cursor",
- }
-
- class NodeSelected(Generic[EventNodeDataType], Message, bubble=False):
- def __init__(
- self, sender: MessageTarget, node: TreeNode[EventNodeDataType]
- ) -> None:
- self.node = node
- super().__init__(sender)
-
- def __init__(
- self,
- label: TextType,
- data: NodeDataType,
- *,
- name: str | None = None,
- id: str | None = None,
- classes: str | None = None,
- ) -> None:
- super().__init__(name=name, id=id, classes=classes)
- self.data = data
-
- self.node_id = NodeID(0)
- self.nodes: dict[NodeID, TreeNode[NodeDataType]] = {}
- self._tree = Tree(label)
-
- self.root: TreeNode[NodeDataType] = TreeNode(
- None, self.node_id, self, self._tree, label, data
- )
-
- self._tree.label = self.root
- self.nodes[NodeID(self.node_id)] = self.root
-
- self.auto_links = False
-
- hover_node: Reactive[NodeID | None] = Reactive(None)
- cursor: Reactive[NodeID] = Reactive(NodeID(0))
- cursor_line: Reactive[int] = Reactive(0)
- show_cursor: Reactive[bool] = Reactive(False)
-
- def watch_cursor_line(self, value: int) -> None:
- line_region = Region(0, value, self.size.width, 1)
- self.emit_no_wait(messages.ScrollToRegion(self, line_region))
-
- def get_content_height(self, container: Size, viewport: Size, width: int) -> int:
- def get_size(tree: Tree) -> int:
- return 1 + sum(
- get_size(child) if child.expanded else 1 for child in tree.children
- )
-
- size = get_size(self._tree)
- return size
-
- def add(
- self,
- node_id: NodeID,
- label: TextType,
- data: NodeDataType,
- ) -> None:
-
- parent = self.nodes[node_id]
- self.node_id = NodeID(self.node_id + 1)
- child_tree = parent._tree.add(label)
- child_tree.guide_style = self._guide_style
- child_node: TreeNode[NodeDataType] = TreeNode(
- parent, self.node_id, self, child_tree, label, data
- )
- parent.children.append(child_node)
- child_tree.label = child_node
- self.nodes[self.node_id] = child_node
-
- self.refresh(layout=True)
-
- def find_cursor(self) -> int | None:
- """Find the line location for the cursor node."""
-
- node_id = self.cursor
- line = 0
-
- stack: list[Iterator[TreeNode[NodeDataType]]]
- stack = [iter([self.root])]
-
- pop = stack.pop
- push = stack.append
- while stack:
- iter_children = pop()
- try:
- node = next(iter_children)
- except StopIteration:
- continue
- else:
- if node.id == node_id:
- return line
- line += 1
- push(iter_children)
- if node.children and node.expanded:
- push(iter(node.children))
- return None
-
- def render(self) -> RenderableType:
- guide_style = self._guide_style
-
- def update_guide_style(tree: Tree) -> None:
- tree.guide_style = guide_style
- for child in tree.children:
- if child.expanded:
- update_guide_style(child)
-
- update_guide_style(self._tree)
- if self.hover_node is not None:
- hover = self.nodes.get(self.hover_node)
- if hover is not None:
- hover._tree.guide_style = self._highlight_guide_style
- if self.cursor is not None and self.show_cursor:
- cursor = self.nodes.get(self.cursor)
- if cursor is not None:
- cursor._tree.guide_style = self._cursor_guide_style
- return self._tree
-
- def render_node(self, node: TreeNode[NodeDataType]) -> RenderableType:
- label_style = self.get_component_styles("tree--labels").rich_style
- label = (
- Text(node.label, no_wrap=True, style=label_style, overflow="ellipsis")
- if isinstance(node.label, str)
- else node.label
- )
- if node.id == self.hover_node:
- label.stylize("underline")
- label.apply_meta({"@click": f"click_label({node.id})", "tree_node": node.id})
- return label
-
- def action_click_label(self, node_id: NodeID) -> None:
- node = self.nodes[node_id]
- self.cursor = node.id
- self.cursor_line = self.find_cursor() or 0
- self.show_cursor = True
- self.post_message_no_wait(self.NodeSelected(self, node))
-
- def on_mount(self) -> None:
- self._tree.guide_style = self._guide_style
-
- @property
- def _guide_style(self) -> Style:
- return self.get_component_rich_style("tree--guides")
-
- @property
- def _highlight_guide_style(self) -> Style:
- return self.get_component_rich_style("tree--guides-highlight")
-
- @property
- def _cursor_guide_style(self) -> Style:
- return self.get_component_rich_style("tree--guides-cursor")
-
- def on_mouse_move(self, event: events.MouseMove) -> None:
- self.hover_node = event.style.meta.get("tree_node")
-
- def key_down(self, event: events.Key) -> None:
- event.stop()
- self.cursor_down()
-
- def key_up(self, event: events.Key) -> None:
- event.stop()
- self.cursor_up()
-
- def key_pagedown(self) -> None:
- assert self.parent is not None
- height = self.container_viewport.height
-
- cursor = self.cursor
- cursor_line = self.cursor_line
- for _ in range(height):
- cursor_node = self.nodes[cursor]
- next_node = cursor_node.next_node
- if next_node is not None:
- cursor_line += 1
- cursor = next_node.id
- self.cursor = cursor
- self.cursor_line = cursor_line
-
- def key_pageup(self) -> None:
- assert self.parent is not None
- height = self.container_viewport.height
- cursor = self.cursor
- cursor_line = self.cursor_line
- for _ in range(height):
- cursor_node = self.nodes[cursor]
- previous_node = cursor_node.previous_node
- if previous_node is not None:
- cursor_line -= 1
- cursor = previous_node.id
- self.cursor = cursor
- self.cursor_line = cursor_line
-
- def key_home(self) -> None:
- self.cursor_line = 0
- self.cursor = NodeID(0)
-
- def key_end(self) -> None:
- self.cursor = self.nodes[NodeID(0)].children[-1].id
- self.cursor_line = self.find_cursor() or 0
-
- def key_enter(self, event: events.Key) -> None:
- cursor_node = self.nodes[self.cursor]
- event.stop()
- self.post_message_no_wait(self.NodeSelected(self, cursor_node))
-
- def cursor_down(self) -> None:
- if not self.show_cursor:
- self.show_cursor = True
- return
- cursor_node = self.nodes[self.cursor]
- next_node = cursor_node.next_node
- if next_node is not None:
- self.cursor_line += 1
- self.cursor = next_node.id
-
- def cursor_up(self) -> None:
- if not self.show_cursor:
- self.show_cursor = True
- return
- cursor_node = self.nodes[self.cursor]
- previous_node = cursor_node.previous_node
- if previous_node is not None:
- self.cursor_line -= 1
- self.cursor = previous_node.id
diff --git a/src/textual/widgets/_tree_node.py b/src/textual/widgets/_tree_node.py
new file mode 100644
index 000000000..e6c57fb61
--- /dev/null
+++ b/src/textual/widgets/_tree_node.py
@@ -0,0 +1 @@
+from ._tree import TreeNode as TreeNode
diff --git a/tests/snapshot_tests/__snapshots__/test_snapshots.ambr b/tests/snapshot_tests/__snapshots__/test_snapshots.ambr
index 354a78d0f..977ab4286 100644
--- a/tests/snapshot_tests/__snapshots__/test_snapshots.ambr
+++ b/tests/snapshot_tests/__snapshots__/test_snapshots.ambr
@@ -6792,6 +6792,162 @@
'''
# ---
+# name: test_tree_example
+ '''
+
+
+ '''
+# ---
# name: test_vertical_layout
'''