diff --git a/src/textual/screen.py b/src/textual/screen.py index 97881e668..c06c86936 100644 --- a/src/textual/screen.py +++ b/src/textual/screen.py @@ -169,9 +169,8 @@ class Screen(Widget): else: if node.is_container and node.can_focus_children: push(iter(node.focusable_children)) - else: - if node.can_focus: - add_widget(node) + if node.can_focus: + add_widget(node) return widgets diff --git a/tests/test_focus.py b/tests/test_focus.py index 811817b92..67a8b1a92 100644 --- a/tests/test_focus.py +++ b/tests/test_focus.py @@ -10,8 +10,11 @@ class Focusable(Widget, can_focus=True): class NonFocusable(Widget, can_focus=False, can_focus_children=False): pass +class ChildrenFocusableOnly(Widget, can_focus=False, can_focus_children=True): + pass -async def test_focus_chain(): + +def test_focus_chain(): app = App() app._set_active() app.push_screen(Screen()) @@ -27,13 +30,14 @@ async def test_focus_chain(): Focusable(Focusable(id="Paul"), id="container1"), NonFocusable(Focusable(id="Jessica"), id="container2"), Focusable(id="baz"), + ChildrenFocusableOnly(Focusable(id="child")), ) - focused = [widget.id for widget in screen.focus_chain] - assert focused == ["foo", "Paul", "baz"] + focus_chain = [widget.id for widget in screen.focus_chain] + assert focus_chain == ["foo", "container1", "Paul", "baz", "child"] -async def test_focus_next_and_previous(): +def test_focus_next_and_previous(): app = App() app._set_active() app.push_screen(Screen()) @@ -46,11 +50,16 @@ async def test_focus_next_and_previous(): Focusable(Focusable(id="Paul"), id="container1"), NonFocusable(Focusable(id="Jessica"), id="container2"), Focusable(id="baz"), + ChildrenFocusableOnly(Focusable(id="child")), ) assert screen.focus_next().id == "foo" + assert screen.focus_next().id == "container1" assert screen.focus_next().id == "Paul" assert screen.focus_next().id == "baz" + assert screen.focus_next().id == "child" + assert screen.focus_previous().id == "baz" assert screen.focus_previous().id == "Paul" + assert screen.focus_previous().id == "container1" assert screen.focus_previous().id == "foo"