Skip to content

Commit

Permalink
tui core is working
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi committed Jun 9, 2024
1 parent 0dc723d commit 0b077e3
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 40 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "sklearn-smithy"
version = "0.0.10"
version = "0.1.0"
description = "Toolkit to forge scikit-learn compatible estimators."
requires-python = ">=3.10"

Expand All @@ -19,6 +19,7 @@ keywords = [
"python",
"cli",
"webui",
"tui",
"data-science",
"machine-learning",
"scikit-learn"
Expand Down
20 changes: 19 additions & 1 deletion sksmithy/_static/tui.tcss
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,31 @@ Input.-valid:focus {
border: tall $success;
}

ForgeRow {
grid-size: 5 2;
grid-gutter: 1;
grid-rows: 1fr 1fr 1fr 2fr;
grid-columns: 1fr;
min-height: 40vh;
}

ForgeButton {
row-span: 2;
}

DestinationFile {
column-span: 2;
row-span: 2;
height: 100%;
}


.container {
height: auto;
width: auto;
min-height: 10vh;
}


.label {
height: 3;
content-align: right middle;
Expand Down
108 changes: 72 additions & 36 deletions sksmithy/tui/_components.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import sys
from pathlib import Path

from result import Err, Ok
from textual import on
from textual.app import ComposeResult
from textual.containers import Container, Horizontal
from textual.containers import Container, Grid, Horizontal
from textual.widgets import Button, Input, Select, Static, Switch

from sksmithy._models import EstimatorType
Expand All @@ -14,6 +15,7 @@
PROMPT_LINEAR,
PROMPT_NAME,
PROMPT_OPTIONAL,
PROMPT_OUTPUT,
PROMPT_PREDICT_PROBA,
PROMPT_REQUIRED,
PROMPT_SAMPLE_WEIGHT,
Expand All @@ -31,6 +33,10 @@ class Prompt(Static):
pass


class ForgeRow(Grid):
pass


class Name(Container):
"""Name input component."""

Expand All @@ -48,8 +54,8 @@ def on_input_change(self: Self, event: Input.Changed) -> None:
timeout=5,
)
else:
...
# TODO: Update filename component
output_file: Input = self.app.query_one("#output_file")
output_file.value = f"{event.value.lower()}.py"


class Estimator(Container):
Expand Down Expand Up @@ -94,18 +100,23 @@ def on_input_change(self: Self, event: Input.Submitted) -> None:
timeout=5,
)

# optional: Input = self.app.query_one("#optional").value or ""
# if optional and event.value (duplicates_result := check_duplicates(
# event.value.split(","),
# optional.split(",")
# )):

# self.notify(
# message=duplicates_result,
# title="Duplicate Parameter",
# severity="error",
# timeout=5,
# )
optional: Input = self.app.query_one("#optional").value or ""
if (
optional
and event.value
and (
duplicates_result := check_duplicates(
event.value.split(","),
optional.split(","),
)
)
):
self.notify(
message=duplicates_result,
title="Duplicate Parameter",
severity="error",
timeout=5,
)


class Optional(Container):
Expand All @@ -125,18 +136,23 @@ def on_optional_change(self: Self, event: Input.Submitted) -> None:
timeout=5,
)

# required: Input = self.app.query_one("#required").value or ""
# if required and event.value and (duplicates_result := check_duplicates(
# required.split(","),
# event.value.split(","),
# )):

# self.notify(
# message=duplicates_result,
# title="Duplicate Parameter",
# severity="error",
# timeout=5,
# )
required: Input = self.app.query_one("#required").value or ""
if (
required
and event.value
and (
duplicates_result := check_duplicates(
required.split(","),
event.value.split(","),
)
)
):
self.notify(
message=duplicates_result,
title="Duplicate Parameter",
severity="error",
timeout=5,
)


class SampleWeight(Container):
Expand Down Expand Up @@ -189,15 +205,25 @@ def compose(self: Self) -> ComposeResult:
)


class DestinationFile(Container):
"""Destination file input component."""

def compose(self: Self) -> ComposeResult:
yield Prompt(PROMPT_OUTPUT, classes="label", id="output_prompt")
yield Input(placeholder="mightyestimator.py", id="output_file")


class ForgeButton(Container):
"""forge button component."""

def compose(self: Self) -> ComposeResult:
yield Button.success(
label="Forge ⚒️",
id="forge_btn",
)

@on(Button.Pressed, "#forge_btn")
def on_button_pressed(self, event: Button.Pressed) -> None:
def on_button_pressed(self: Self, _: Button.Pressed) -> None:
errors = []

name_input: str = self.app.query_one("#name").value
Expand All @@ -210,6 +236,8 @@ def on_button_pressed(self, event: Button.Pressed) -> None:
predict_proba: bool = self.app.query_one("#predict_proba").value
decision_function: bool = self.app.query_one("#decision_function").value

output_file: str = self.app.query_one("#output_file").value

match name_parser(name_input):
case Ok(name):
pass
Expand Down Expand Up @@ -240,9 +268,12 @@ def on_button_pressed(self, event: Button.Pressed) -> None:
if required_is_valid and optional_is_valid and (msg_duplicated_params := check_duplicates(required, optional)):
errors.append(msg_duplicated_params)

if not output_file:
errors.append("Outfile file cannot be empty")

if errors:
self.notify(
message="\n".join(errors),
message="\n".join([f"- {e}" for e in errors]),
title="Invalid inputs",
severity="error",
timeout=5,
Expand All @@ -261,16 +292,21 @@ def on_button_pressed(self, event: Button.Pressed) -> None:
tags=None,
)

# destination_file = Path(output_file)
# destination_file.parent.mkdir(parents=True, exist_ok=True)
destination_file = Path(output_file)
destination_file.parent.mkdir(parents=True, exist_ok=True)

with destination_file.open(mode="w") as destination:
destination.write(forged_template)

# with destination_file.open(mode="w") as destination:
# destination.write(forged_template)
self.notify(
message=f"Template forged at {destination_file}",
title="Success!",
severity="information",
timeout=5,
)

# TODO: Validate inputs
# TODO: Render
self.app.exit()

forge_row = ForgeRow(Static(), Static(), ForgeButton(), DestinationFile(), Static(), Static(), id="forge_row")

# class Version(Static):
# def render(self: Self) -> RenderableType:
Expand Down
5 changes: 3 additions & 2 deletions sksmithy/tui/tui.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
from sksmithy.tui._components import (
DecisionFunction,
Estimator,
ForgeButton,
Linear,
Name,
Optional,
PredictProba,
Required,
SampleWeight,
forge_row,
)

if sys.version_info >= (3, 11): # pragma: no cover
Expand Down Expand Up @@ -61,7 +61,8 @@ def compose(self: Self) -> ComposeResult:
Horizontal(SampleWeight(), Linear()),
Horizontal(PredictProba(), DecisionFunction()),
Rule(),
ForgeButton(),
forge_row,
Rule(),
)
# yield Sidebar(classes="-hidden")
yield Footer()
Expand Down

0 comments on commit 0b077e3

Please sign in to comment.