Some efficiency improvements in tree-sitter highlighting

This commit is contained in:
Darren Burns
2025-03-12 17:22:56 +00:00
parent b079299c63
commit 7265a52f1c

View File

@@ -81,48 +81,60 @@ class HighlightMap:
BLOCK_SIZE = 50
def __init__(self, text_area_widget: widgets.TextArea):
self.text_area_widget: widgets.TextArea = text_area_widget
self.uncovered_lines: dict[int, range] = {}
def __init__(self, text_area: TextArea):
self.text_area: TextArea = text_area
"""The text area associated with this highlight map."""
# A mapping from line index to a list of Highlight instances.
self._highlights: LineToHighlightsMap = defaultdict(list)
self.reset()
self._highlighted_blocks: set[int] = set()
"""The set of blocks that have been highlighted, identified by the start line index of the block.
(0 represents the first block, 50 the second, 100 the third, etc. - assuming a block size of 50)
"""
self._highlights: dict[int, list[Highlight]] = defaultdict(list)
"""A mapping from line index to a list of Highlight instances."""
def reset(self) -> None:
"""Reset so that future lookups rebuild the highlight map."""
self._highlights.clear()
line_count = self.document.line_count
uncovered_lines = self.uncovered_lines
uncovered_lines.clear()
i = end_range = 0
for i in range(0, line_count, self.BLOCK_SIZE):
end_range = min(i + self.BLOCK_SIZE, line_count)
line_range = range(i, end_range)
uncovered_lines.update({j: line_range for j in line_range})
if end_range < line_count:
line_range = range(i, line_count)
uncovered_lines.update({j: line_range for j in line_range})
self._highlighted_blocks.clear()
@property
def document(self) -> DocumentBase:
"""The text document being highlighted."""
return self.text_area_widget.document
return self.text_area.document
def __getitem__(self, idx: int) -> list[text_area.Highlight]:
if idx in self.uncovered_lines:
self._build_part_of_highlight_map(self.uncovered_lines[idx])
return self._highlights[idx]
def __getitem__(self, index: int) -> list[Highlight]:
start, end = self._get_block_boundaries(index, self.BLOCK_SIZE)
if start not in self._highlighted_blocks:
self._highlighted_blocks.add(start)
self._build_part_of_highlight_map(range(start, end))
return self._highlights[index]
def _get_block_boundaries(self, index: int, block_size: int) -> tuple[int, int]:
"""Get the start and end of the block for the given index.
The start is inclusive and the end is exclusive.
Args:
index: The line index to get we want to know the block range for..
block_size: The size of the bucket.
Returns:
A tuple containing the start and end of the block.
"""
block_index = index // block_size
start = block_index * block_size
end = (block_index + 1) * block_size
return (start, end)
def _build_part_of_highlight_map(self, line_range: range) -> None:
"""Build part of the highlight map."""
highlights = self._highlights
for line_index in line_range:
self.uncovered_lines.pop(line_index)
start_point = (line_range[0], 0)
end_point = (line_range[-1] + 1, 0)
captures = self.document.query_syntax_tree(
self.text_area_widget._highlight_query,
self.text_area._highlight_query,
start_point=start_point,
end_point=end_point,
)
@@ -140,8 +152,9 @@ class HighlightMap:
)
# Add the middle lines - entire row of this node is highlighted
middle_highlight = (0, None, highlight_name)
for node_row in range(node_start_row + 1, node_end_row):
highlights[node_row].append((0, None, highlight_name))
highlights[node_row].append(middle_highlight)
# Add the last line of the node range
highlights[node_end_row].append(
@@ -157,16 +170,16 @@ class HighlightMap:
# to be sorted in ascending order of ``a``. When two highlights have the same
# value of ``a`` then the one with the larger a--b range comes first, with ``None``
# being considered larger than any number.
def sort_key(hl) -> tuple[int, int, int]:
a, b, _ = hl
max_range_ind = 1
def sort_key(highlight: Highlight) -> tuple[int, int, int]:
a, b, _ = highlight
max_range_index = 1
if b is None:
max_range_ind = 0
max_range_index = 0
b = a
return a, max_range_ind, a - b
return a, max_range_index, a - b
for line_index in line_range:
line_highlights = highlights.get(line_index, []).sort(key=sort_key)
highlights.get(line_index, []).sort(key=sort_key)
@dataclass