Measuring string cells correctly

This commit is contained in:
Darren Burns
2023-02-01 17:10:59 +00:00
parent fd4e13c988
commit 23a34030cd
3 changed files with 20 additions and 16 deletions

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
from rich.cells import cell_len
from rich.console import Console, RenderableType
from rich.protocol import rich_cast
@@ -22,6 +23,9 @@ def measure(
Returns:
Width in cells
"""
if isinstance(renderable, str):
return cell_len(renderable)
width = default
renderable = rich_cast(renderable)
get_console_width = getattr(renderable, "__rich_measure__", None)

View File

@@ -524,8 +524,6 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
update_width: Whether to resize the column width to accommodate
for the new cell content.
"""
value = Text.from_markup(value) if isinstance(value, str) else value
self.data[row_key][column_key] = value
self._update_count += 1
@@ -752,6 +750,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
label_width = measure(console, column.label, 1)
content_width = column.content_width
cell_value = self.data[row_key][column_key]
new_content_width = measure(console, cell_value, 1)
if new_content_width < content_width:
@@ -866,15 +865,13 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
of its current location in the DataTable (it could have moved after being added
due to sorting or insertion/deletion of other columns).
"""
text_label = Text.from_markup(label) if isinstance(label, str) else label
column_key = ColumnKey(key)
column_index = len(self.columns)
content_width = measure(self.app.console, text_label, 1)
content_width = measure(self.app.console, label, 1)
if width is None:
column = Column(
column_key,
text_label,
label,
content_width,
content_width=content_width,
auto_width=True,
@@ -882,7 +879,7 @@ class DataTable(ScrollView, Generic[CellType], can_focus=True):
else:
column = Column(
column_key,
text_label,
label,
width,
content_width=content_width,
)

View File

@@ -178,7 +178,6 @@ async def test_add_columns():
assert len(table.columns) == 3
# TODO: Ensure we can use the key to retrieve the column.
async def test_add_columns_user_defined_keys():
app = DataTableApp()
async with app.run_test():
@@ -223,7 +222,7 @@ async def test_column_labels() -> None:
table = app.query_one(DataTable)
table.add_columns("1", "2", "3")
actual_labels = [col.label for col in table.columns.values()]
expected_labels = [Text("1"), Text("2"), Text("3")]
expected_labels = ["1", "2", "3"]
assert actual_labels == expected_labels
@@ -294,13 +293,17 @@ async def test_get_value_at_exception():
table.get_value_at(Coordinate(9999, 0))
# async def test_update_cell_cell_exists():
# app = DataTableApp()
# async with app.run_test():
# table = app.query_one(DataTable)
# table.add_column("A", key="A")
# table.add_row("1", key="1")
# assert table.get_cell_value()
async def test_update_cell_cell_exists():
app = DataTableApp()
async with app.run_test():
table = app.query_one(DataTable)
table.add_column("A", key="A")
table.add_row("1", key="1")
table.update_cell("1", "A", "NEW_VALUE")
assert table.get_cell_value("1", "A") == "NEW_VALUE"
# TODO: Test update coordinate
def test_key_equals_equivalent_string():