From 0b077e3db083a046b4a9fbbc57b76e129d50c0e7 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sun, 9 Jun 2024 13:53:06 +0200 Subject: [PATCH] tui core is working --- pyproject.toml | 3 +- sksmithy/_static/tui.tcss | 20 ++++++- sksmithy/tui/_components.py | 108 ++++++++++++++++++++++++------------ sksmithy/tui/tui.py | 5 +- 4 files changed, 96 insertions(+), 40 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4649f93..83e4270 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "sklearn-smithy" -version = "0.0.10" +version = "0.1.0" description = "Toolkit to forge scikit-learn compatible estimators." requires-python = ">=3.10" @@ -19,6 +19,7 @@ keywords = [ "python", "cli", "webui", + "tui", "data-science", "machine-learning", "scikit-learn" diff --git a/sksmithy/_static/tui.tcss b/sksmithy/_static/tui.tcss index e1799f4..e2c9aa4 100644 --- a/sksmithy/_static/tui.tcss +++ b/sksmithy/_static/tui.tcss @@ -51,13 +51,31 @@ Input.-valid:focus { border: tall $success; } +ForgeRow { + grid-size: 5 2; + grid-gutter: 1; + grid-rows: 1fr 1fr 1fr 2fr; + grid-columns: 1fr; + min-height: 40vh; +} + +ForgeButton { + row-span: 2; +} + +DestinationFile { + column-span: 2; + row-span: 2; + height: 100%; +} + + .container { height: auto; width: auto; min-height: 10vh; } - .label { height: 3; content-align: right middle; diff --git a/sksmithy/tui/_components.py b/sksmithy/tui/_components.py index 0f15b2c..056eac2 100644 --- a/sksmithy/tui/_components.py +++ b/sksmithy/tui/_components.py @@ -1,9 +1,10 @@ import sys +from pathlib import Path from result import Err, Ok from textual import on from textual.app import ComposeResult -from textual.containers import Container, Horizontal +from textual.containers import Container, Grid, Horizontal from textual.widgets import Button, Input, Select, Static, Switch from sksmithy._models import EstimatorType @@ -14,6 +15,7 @@ PROMPT_LINEAR, PROMPT_NAME, PROMPT_OPTIONAL, + PROMPT_OUTPUT, PROMPT_PREDICT_PROBA, PROMPT_REQUIRED, PROMPT_SAMPLE_WEIGHT, @@ -31,6 +33,10 @@ class Prompt(Static): pass +class ForgeRow(Grid): + pass + + class Name(Container): """Name input component.""" @@ -48,8 +54,8 @@ def on_input_change(self: Self, event: Input.Changed) -> None: timeout=5, ) else: - ... - # TODO: Update filename component + output_file: Input = self.app.query_one("#output_file") + output_file.value = f"{event.value.lower()}.py" class Estimator(Container): @@ -94,18 +100,23 @@ def on_input_change(self: Self, event: Input.Submitted) -> None: timeout=5, ) - # optional: Input = self.app.query_one("#optional").value or "" - # if optional and event.value (duplicates_result := check_duplicates( - # event.value.split(","), - # optional.split(",") - # )): - - # self.notify( - # message=duplicates_result, - # title="Duplicate Parameter", - # severity="error", - # timeout=5, - # ) + optional: Input = self.app.query_one("#optional").value or "" + if ( + optional + and event.value + and ( + duplicates_result := check_duplicates( + event.value.split(","), + optional.split(","), + ) + ) + ): + self.notify( + message=duplicates_result, + title="Duplicate Parameter", + severity="error", + timeout=5, + ) class Optional(Container): @@ -125,18 +136,23 @@ def on_optional_change(self: Self, event: Input.Submitted) -> None: timeout=5, ) - # required: Input = self.app.query_one("#required").value or "" - # if required and event.value and (duplicates_result := check_duplicates( - # required.split(","), - # event.value.split(","), - # )): - - # self.notify( - # message=duplicates_result, - # title="Duplicate Parameter", - # severity="error", - # timeout=5, - # ) + required: Input = self.app.query_one("#required").value or "" + if ( + required + and event.value + and ( + duplicates_result := check_duplicates( + required.split(","), + event.value.split(","), + ) + ) + ): + self.notify( + message=duplicates_result, + title="Duplicate Parameter", + severity="error", + timeout=5, + ) class SampleWeight(Container): @@ -189,7 +205,17 @@ def compose(self: Self) -> ComposeResult: ) +class DestinationFile(Container): + """Destination file input component.""" + + def compose(self: Self) -> ComposeResult: + yield Prompt(PROMPT_OUTPUT, classes="label", id="output_prompt") + yield Input(placeholder="mightyestimator.py", id="output_file") + + class ForgeButton(Container): + """forge button component.""" + def compose(self: Self) -> ComposeResult: yield Button.success( label="Forge ⚒️", @@ -197,7 +223,7 @@ def compose(self: Self) -> ComposeResult: ) @on(Button.Pressed, "#forge_btn") - def on_button_pressed(self, event: Button.Pressed) -> None: + def on_button_pressed(self: Self, _: Button.Pressed) -> None: errors = [] name_input: str = self.app.query_one("#name").value @@ -210,6 +236,8 @@ def on_button_pressed(self, event: Button.Pressed) -> None: predict_proba: bool = self.app.query_one("#predict_proba").value decision_function: bool = self.app.query_one("#decision_function").value + output_file: str = self.app.query_one("#output_file").value + match name_parser(name_input): case Ok(name): pass @@ -240,9 +268,12 @@ def on_button_pressed(self, event: Button.Pressed) -> None: if required_is_valid and optional_is_valid and (msg_duplicated_params := check_duplicates(required, optional)): errors.append(msg_duplicated_params) + if not output_file: + errors.append("Outfile file cannot be empty") + if errors: self.notify( - message="\n".join(errors), + message="\n".join([f"- {e}" for e in errors]), title="Invalid inputs", severity="error", timeout=5, @@ -261,16 +292,21 @@ def on_button_pressed(self, event: Button.Pressed) -> None: tags=None, ) - # destination_file = Path(output_file) - # destination_file.parent.mkdir(parents=True, exist_ok=True) + destination_file = Path(output_file) + destination_file.parent.mkdir(parents=True, exist_ok=True) + + with destination_file.open(mode="w") as destination: + destination.write(forged_template) - # with destination_file.open(mode="w") as destination: - # destination.write(forged_template) + self.notify( + message=f"Template forged at {destination_file}", + title="Success!", + severity="information", + timeout=5, + ) - # TODO: Validate inputs - # TODO: Render - self.app.exit() +forge_row = ForgeRow(Static(), Static(), ForgeButton(), DestinationFile(), Static(), Static(), id="forge_row") # class Version(Static): # def render(self: Self) -> RenderableType: diff --git a/sksmithy/tui/tui.py b/sksmithy/tui/tui.py index 5fb00a9..57915c2 100644 --- a/sksmithy/tui/tui.py +++ b/sksmithy/tui/tui.py @@ -10,13 +10,13 @@ from sksmithy.tui._components import ( DecisionFunction, Estimator, - ForgeButton, Linear, Name, Optional, PredictProba, Required, SampleWeight, + forge_row, ) if sys.version_info >= (3, 11): # pragma: no cover @@ -61,7 +61,8 @@ def compose(self: Self) -> ComposeResult: Horizontal(SampleWeight(), Linear()), Horizontal(PredictProba(), DecisionFunction()), Rule(), - ForgeButton(), + forge_row, + Rule(), ) # yield Sidebar(classes="-hidden") yield Footer()