diff --git a/src/textual/widgets/_directory_tree.py b/src/textual/widgets/_directory_tree.py index db06d89c3..dcca8e2b8 100644 --- a/src/textual/widgets/_directory_tree.py +++ b/src/textual/widgets/_directory_tree.py @@ -3,7 +3,7 @@ from __future__ import annotations import os from dataclasses import dataclass from pathlib import Path -from typing import Callable, ClassVar +from typing import ClassVar, Iterable from rich.style import Style from rich.text import Text, TextType @@ -90,11 +90,9 @@ class DirectoryTree(Tree[DirEntry]): id: str | None = None, classes: str | None = None, disabled: bool = False, - path_filter: Callable[[Path], bool] | None = None, ) -> None: str_path = os.fspath(path) self.path = str_path - self._path_filter = path_filter or (lambda _: True) super().__init__( str_path, data=DirEntry(str_path, True), @@ -152,12 +150,23 @@ class DirectoryTree(Tree[DirEntry]): text = Text.assemble(prefix, node_label) return text + def filter_paths(self, paths: Iterable[Path]) -> Iterable[Path]: + """Filter the paths before adding them to the tree. + + Args: + paths: The paths to be filtered. + + Returns: + The filtered paths. + """ + return paths + def load_directory(self, node: TreeNode[DirEntry]) -> None: assert node.data is not None dir_path = Path(node.data.path) node.data.loaded = True directory = sorted( - [entry for entry in dir_path.iterdir() if self._path_filter(entry)], + self.filter_paths(dir_path.iterdir()), key=lambda path: (not path.is_dir(), path.name.lower()), ) for path in directory: