docstrings and types

This commit is contained in:
Will McGugan
2023-02-17 18:11:55 +00:00
parent c0c49978bd
commit 0ac7eef4b5
2 changed files with 15 additions and 9 deletions

View File

@@ -4,17 +4,23 @@ from collections import defaultdict
from itertools import product
from typing import Generic, Iterable, TypeVar
from typing_extensions import TypeAlias
from .geometry import Region
ValueType = TypeVar("ValueType")
GridCoordinate: TypeAlias = "tuple[int, int]"
class SpatialMap(Generic[ValueType]):
"""A spatial map allows for data to be associated with rectangular regions
in Euclidean space, and efficiently queried.
When the SpatialMap is populated, a reference to each value is placed in a bucket associated
with a regular grid that covers 2D space.
When the SpatialMap is populated, a reference to each value is placed into one or
more buckets associated with a regular grid that covers 2D space.
The SpatialMap is able to quickly retrieve the values under a given "window" region
by combining the values in the grid squares under the visible area.
"""
@@ -27,17 +33,17 @@ class SpatialMap(Generic[ValueType]):
"""
self._grid_size = (grid_width, grid_height)
self.total_region = Region()
self._map: defaultdict[tuple[int, int], list[ValueType]] = defaultdict(list)
self._map: defaultdict[GridCoordinate, list[ValueType]] = defaultdict(list)
self._fixed: list[ValueType] = []
def _region_to_grid(self, region: Region) -> Iterable[tuple[int, int]]:
def _region_to_grid_coordinate(self, region: Region) -> Iterable[GridCoordinate]:
"""Get the grid squares under a region.
Args:
region: A region.
Returns:
Iterable of grid squares (tuple of 2 values).
Iterable of grid coordinates (tuple of 2 values).
"""
# (x1, y1) is the coordinate of the top left cell
# (x2, y2) is the coordinate of the bottom right cell
@@ -64,7 +70,7 @@ class SpatialMap(Generic[ValueType]):
"""
append_fixed = self._fixed.append
get_grid_list = self._map.__getitem__
_region_to_grid = self._region_to_grid
_region_to_grid = self._region_to_grid_coordinate
total_region = self.total_region
for region, fixed, value in regions_and_values:
total_region = total_region.union(region)
@@ -89,8 +95,8 @@ class SpatialMap(Generic[ValueType]):
results: list[ValueType] = self._fixed.copy()
add_results = results.extend
get_grid_values = self._map.get
for grid in self._region_to_grid(region):
grid_values = get_grid_values(grid)
for grid_coordinate in self._region_to_grid_coordinate(region):
grid_values = get_grid_values(grid_coordinate)
if grid_values is not None:
add_results(grid_values)
unique_values = list(dict.fromkeys(results))

View File

@@ -36,7 +36,7 @@ from textual.geometry import Region
def test_region_to_grid(region, grid):
spatial_map = SpatialMap(10, 10)
assert list(spatial_map._region_to_grid(region)) == grid
assert list(spatial_map._region_to_grid_coordinate(region)) == grid
def test_get_values_in_region() -> None: