Skip to content

Commit

Permalink
fixed mount issue
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi committed Jun 10, 2024
1 parent e1c92ac commit 2acd23f
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 37 deletions.
3 changes: 0 additions & 3 deletions sksmithy/tui/_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,9 +312,6 @@ class ForgeRow(Grid):
"""Row grid for forge."""


forge_row = ForgeRow(Static(), Static(), ForgeButton(), DestinationFile(), Static(), Static(), id="forge_row")


class Title(Static):
pass

Expand Down
24 changes: 20 additions & 4 deletions sksmithy/tui/_tui.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,21 @@
from textual.app import App, ComposeResult
from textual.containers import Container, Horizontal, ScrollableContainer
from textual.reactive import reactive
from textual.widgets import Button, Footer, Header, Rule
from textual.widgets import Button, Footer, Header, Rule, Static

from sksmithy.tui._components import (
DecisionFunction,
DestinationFile,
Estimator,
ForgeButton,
ForgeRow,
Linear,
Name,
Optional,
PredictProba,
Required,
SampleWeight,
Sidebar,
forge_row,
)

if sys.version_info >= (3, 11): # pragma: no cover
Expand All @@ -41,6 +43,13 @@ class ForgeTUI(App):

show_sidebar = reactive(False) # noqa: FBT003

def on_mount(self: Self) -> None:
"""Compose on mount.
Q: is this needed?
"""
self.compose()

def compose(self: Self) -> ComposeResult:
"""Create child widgets for the app."""
yield Container(
Expand All @@ -51,7 +60,14 @@ def compose(self: Self) -> ComposeResult:
Horizontal(SampleWeight(), Linear()),
Horizontal(PredictProba(), DecisionFunction()),
Rule(),
forge_row,
ForgeRow(
Static(),
Static(),
ForgeButton(),
DestinationFile(),
Static(),
Static(),
),
Rule(),
),
Sidebar(classes="-hidden"),
Expand Down Expand Up @@ -80,6 +96,6 @@ def action_forge(self: Self) -> None:
forge_btn.press()


if __name__ == "__main__":
if __name__ == "__main__": # pragma: no cover
tui = ForgeTUI()
tui.run()
2 changes: 1 addition & 1 deletion sksmithy/tui/_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

class _BaseValidator(Validator):
@staticmethod
def parser(value: str) -> Result[str | list[str], str]:
def parser(value: str) -> Result[str | list[str], str]: # pragma: no cover
raise NotImplementedError

def validate(self: Self, value: str) -> ValidationResult:
Expand Down
6 changes: 0 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from streamlit.testing.v1 import AppTest

from sksmithy._models import EstimatorType
from sksmithy.tui import ForgeTUI


@pytest.fixture(params=["MightyEstimator"])
Expand Down Expand Up @@ -73,8 +72,3 @@ def tags(request: pytest.FixtureRequest) -> list[str] | None:
@pytest.fixture()
def app() -> AppTest:
return AppTest.from_file("sksmithy/app.py", default_timeout=10)


@pytest.fixture()
def tui() -> ForgeTUI:
return ForgeTUI()
44 changes: 21 additions & 23 deletions tests/test_tui.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,11 @@ async def test_name(name_: str, err_msg: str) -> None:
if notifications:
assert notifications[0].message == err_msg

await pilot.pause()
await pilot.exit(0)


async def test_estimator_interaction(estimator: EstimatorType) -> None:
"""Test that all toggle components interact correctly with the selected estimator."""
app = ForgeTUI()
async with app.run_test(size=None) as pilot:
await pilot.pause()
pilot.app.query_one("#estimator", Select).value = estimator.value
await pilot.pause()

Expand All @@ -72,9 +68,6 @@ async def test_estimator_interaction(estimator: EstimatorType) -> None:
await pilot.pause()
assert pilot.app.query_one("#decision_function", Switch).disabled

await pilot.pause()
await pilot.exit(0)


async def test_valid_params() -> None:
"""Test required and optional params interaction."""
Expand All @@ -95,16 +88,34 @@ async def test_valid_params() -> None:
notifications = list(pilot.app._notifications) # noqa: SLF001
assert not notifications


@pytest.mark.parametrize(("required_", "optional_"), [("a,b", "a"), ("a", "a,b")])
async def test_duplicated_params(required_: str, optional_: str) -> None:
app = ForgeTUI()
msg = "The following parameters are duplicated between required and optional: {'a'}"

async with app.run_test(size=None) as pilot:
required_comp = pilot.app.query_one("#required", Input)
optional_comp = pilot.app.query_one("#optional", Input)

required_comp.value = required_
optional_comp.value = optional_

await required_comp.action_submit()
await optional_comp.action_submit()
await pilot.pause()

forge_btn = pilot.app.query_one("#forge-btn", Button)
forge_btn.action_press()
await pilot.pause()
await pilot.exit(0)

assert all(msg in n.message for n in pilot.app._notifications) # noqa: SLF001


async def test_forge_raise() -> None:
"""Test forge button and all of its interactions."""
app = ForgeTUI()
async with app.run_test(size=None) as pilot:
await pilot.pause()

required_comp = pilot.app.query_one("#required", Input)
optional_comp = pilot.app.query_one("#optional", Input)

Expand All @@ -130,18 +141,13 @@ async def test_forge_raise() -> None:
assert "Found repeated parameters!" in m3
assert "The following parameters are invalid python identifiers: ('b b',)" in m3

await pilot.pause()
await pilot.exit(0)


async def test_forge(tmp_path: Path) -> None:
"""Test forge button and all of its interactions."""
app = ForgeTUI()
name = "MightyEstimator"
estimator = "classifier"
async with app.run_test(size=None) as pilot:
await pilot.pause()

name_comp = pilot.app.query_one("#name", Input)
estimator_comp = pilot.app.query_one("#estimator", Select)
await pilot.pause()
Expand All @@ -167,13 +173,5 @@ async def test_forge(tmp_path: Path) -> None:
assert f"Template forged at {output_file!s}" in notification.message
assert output_file.exists()

await pilot.pause()
await pilot.exit(0)


def test_bindings() -> None: ...


def test_duplicated_params() -> None:
# values: ("a,b", "a", "The following parameters are duplicated between required and optional: {'a'}"),
...

0 comments on commit 2acd23f

Please sign in to comment.