This commit is contained in:
Will McGugan
2022-01-09 15:34:15 +00:00
parent ddbaf8f04a
commit 72d7b5915e
2 changed files with 35 additions and 8 deletions

View File

@@ -32,12 +32,12 @@ class DOMQuery:
self,
node: DOMNode | None = None,
selector: str | None = None,
nodes: Iterable[DOMNode] | None = None,
nodes: list[DOMNode] | None = None,
) -> None:
self._nodes: list[DOMNode] = []
if nodes is not None:
self._nodes = list(nodes)
self._nodes = nodes
elif node is not None:
self._nodes = list(node.walk_children())
else:
@@ -70,8 +70,24 @@ class DOMQuery:
DOMQuery: New DOM Query.
"""
selector_set = parse_selectors(selector)
query = DOMQuery()
query._nodes = [_node for _node in self._nodes if match(selector_set, _node)]
query = DOMQuery(
nodes=[_node for _node in self._nodes if match(selector_set, _node)]
)
return query
def exclude(self, selector: str) -> DOMQuery:
"""Exclude nodes that match a given selector.
Args:
selector (str): A CSS selector.
Returns:
DOMQuery: New DOM query.
"""
selector_set = parse_selectors(selector)
query = DOMQuery(
nodes=[_node for _node in self._nodes if not match(selector_set, _node)]
)
return query
def first(self) -> DOMNode:
@@ -83,17 +99,20 @@ class DOMQuery:
# TODO: Better response to empty query than an IndexError
return self._nodes[0]
def add_class(self, *class_names: str) -> None:
def add_class(self, *class_names: str) -> DOMQuery:
"""Add the given class name(s) to nodes."""
for node in self._nodes:
node.add_class(*class_names)
return self
def remove_class(self, *class_names: str) -> None:
def remove_class(self, *class_names: str) -> DOMQuery:
"""Remove the given class names from the nodes."""
for node in self._nodes:
node.remove_class(*class_names)
return self
def toggle_class(self, *class_names: str) -> None:
def toggle_class(self, *class_names: str) -> DOMQuery:
"""Toggle the given class names from matched nodes."""
for node in self._nodes:
node.toggle_class(*class_names)
return self

View File

@@ -226,7 +226,15 @@ class DOMNode(MessagePump):
if node.children:
push(iter(node.children))
def query(self, selector: str) -> DOMQuery:
def query(self, selector: str | None = None) -> DOMQuery:
"""Get a DOM query.
Args:
selector (str, optional): A CSS selector or `None` for all nodes. Defaults to None.
Returns:
DOMQuery: A query object.
"""
from .css.query import DOMQuery
return DOMQuery(self, selector)