Skip to content

Commit

Permalink
gettin there
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi committed Jun 9, 2024
1 parent dfa4453 commit 0dc723d
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 54 deletions.
4 changes: 2 additions & 2 deletions sksmithy/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ class EstimatorType(str, Enum):
"""

ClassifierMixin = "classifier"
OutlierMixin = "outlier"
RegressorMixin = "regressor"
TransformerMixin = "transformer"
OutlierMixin = "outlier"
ClusterMixin = "cluster"
TransformerMixin = "transformer"
SelectorMixin = "feature-selector"


Expand Down
6 changes: 5 additions & 1 deletion sksmithy/_static/tui.tcss
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ Prompt {
Switch {
height: auto;
width: auto;
transition: offset 200ms;
}

Switch:disabled {
background: darkslategrey;
}

Input.-valid {
Expand All @@ -51,6 +54,7 @@ Input.-valid:focus {
.container {
height: auto;
width: auto;
min-height: 10vh;
}


Expand Down
127 changes: 119 additions & 8 deletions sksmithy/tui/_components.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import sys

from result import Err, Ok
from textual import on
from textual.app import ComposeResult
from textual.containers import Container, Horizontal
from textual.widgets import Input, Select, Static, Switch
from textual.widgets import Button, Input, Select, Static, Switch

from sksmithy._models import EstimatorType
from sksmithy._parsers import check_duplicates, name_parser, params_parser
from sksmithy._prompts import (
PROMPT_DECISION_FUNCTION,
PROMPT_ESTIMATOR,
Expand All @@ -16,6 +18,7 @@
PROMPT_REQUIRED,
PROMPT_SAMPLE_WEIGHT,
)
from sksmithy._utils import render_template
from sksmithy.tui._validators import NameValidator, ParamsValidator

if sys.version_info >= (3, 11): # pragma: no cover
Expand Down Expand Up @@ -44,6 +47,8 @@ def on_input_change(self: Self, event: Input.Changed) -> None:
severity="error",
timeout=5,
)
else:
...
# TODO: Update filename component


Expand Down Expand Up @@ -79,8 +84,8 @@ def compose(self: Self) -> ComposeResult:
yield Prompt(PROMPT_REQUIRED, classes="label")
yield Input(placeholder="alpha,beta", id="required", validators=[ParamsValidator()])

@on(Input.Changed, "#required")
def on_input_change(self: Self, event: Input.Changed) -> None:
@on(Input.Submitted, "#required")
def on_input_change(self: Self, event: Input.Submitted) -> None:
if not event.validation_result.is_valid:
self.notify(
message="\n".join(event.validation_result.failure_descriptions),
Expand All @@ -89,7 +94,18 @@ def on_input_change(self: Self, event: Input.Changed) -> None:
timeout=5,
)

# TODO: Add check for duplicates with optional
# optional: Input = self.app.query_one("#optional").value or ""
# if optional and event.value (duplicates_result := check_duplicates(
# event.value.split(","),
# optional.split(",")
# )):

# self.notify(
# message=duplicates_result,
# title="Duplicate Parameter",
# severity="error",
# timeout=5,
# )


class Optional(Container):
Expand All @@ -99,16 +115,28 @@ def compose(self: Self) -> ComposeResult:
yield Prompt(PROMPT_OPTIONAL, classes="label")
yield Input(placeholder="mu,sigma", id="optional", validators=[ParamsValidator()])

@on(Input.Changed, "#optional")
def on_input_change(self: Self, event: Input.Changed) -> None:
@on(Input.Submitted, "#optional")
def on_optional_change(self: Self, event: Input.Submitted) -> None:
if not event.validation_result.is_valid:
self.notify(
message="\n".join(event.validation_result.failure_descriptions),
title="Invalid Parameter",
severity="error",
timeout=5,
)
# TODO: Add check for duplicates with required

# required: Input = self.app.query_one("#required").value or ""
# if required and event.value and (duplicates_result := check_duplicates(
# required.split(","),
# event.value.split(","),
# )):

# self.notify(
# message=duplicates_result,
# title="Duplicate Parameter",
# severity="error",
# timeout=5,
# )


class SampleWeight(Container):
Expand All @@ -133,7 +161,7 @@ def compose(self: Self) -> ComposeResult:
)

@on(Switch.Changed, "#linear")
def on_switch_changed(self, event: Switch.Changed) -> None:
def on_switch_changed(self: Self, event: Switch.Changed) -> None:
decision_function: Switch = self.app.query_one("#decision_function")
decision_function.disabled = event.value
decision_function.value = decision_function.value and (not decision_function.disabled)
Expand Down Expand Up @@ -161,6 +189,89 @@ def compose(self: Self) -> ComposeResult:
)


class ForgeButton(Container):
def compose(self: Self) -> ComposeResult:
yield Button.success(
label="Forge ⚒️",
id="forge_btn",
)

@on(Button.Pressed, "#forge_btn")
def on_button_pressed(self, event: Button.Pressed) -> None:
errors = []

name_input: str = self.app.query_one("#name").value
estimator: str | None = self.app.query_one("#estimator").value
required_params: str = self.app.query_one("#required").value
optional_params: str = self.app.query_one("#optional").value

sample_weight: bool = self.app.query_one("#linear").value
linear: bool = self.app.query_one("#linear").value
predict_proba: bool = self.app.query_one("#predict_proba").value
decision_function: bool = self.app.query_one("#decision_function").value

match name_parser(name_input):
case Ok(name):
pass
case Err(name_error_msg):
errors.append(name_error_msg)

match estimator:
case str(v):
estimator_type = EstimatorType(v)
case Select.BLANK:
errors.append("Estimator cannot be None")

match params_parser(required_params):
case Ok(required):
required_is_valid = True
case Err(required_err_msg):
required_is_valid = False
errors.append(required_err_msg)

match params_parser(optional_params):
case Ok(optional):
optional_is_valid = True

case Err(optional_err_msg):
optional_is_valid = False
errors.append(optional_err_msg)

if required_is_valid and optional_is_valid and (msg_duplicated_params := check_duplicates(required, optional)):
errors.append(msg_duplicated_params)

if errors:
self.notify(
message="\n".join(errors),
title="Invalid inputs",
severity="error",
timeout=5,
)

else:
forged_template = render_template(
name=name,
estimator_type=estimator_type,
required=required,
optional=optional,
linear=linear,
sample_weight=sample_weight,
predict_proba=predict_proba,
decision_function=decision_function,
tags=None,
)

# destination_file = Path(output_file)
# destination_file.parent.mkdir(parents=True, exist_ok=True)

# with destination_file.open(mode="w") as destination:
# destination.write(forged_template)

# TODO: Validate inputs
# TODO: Render
self.app.exit()


# class Version(Static):
# def render(self: Self) -> RenderableType:
# return f"Version: [b]{version('sklearn-smithy')}"
Expand Down
8 changes: 0 additions & 8 deletions sksmithy/tui/_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,3 @@ class ParamsValidator(_BaseValidator):
@staticmethod
def parser(value: str) -> Result[list[str], str]:
return params_parser(value)


# class DuplicateParamValidator(Validator):
# def validate(self, value: str) -> ValidationResult:
# required = self.required_
# optional = self.optional_
# result = check_duplicates(required, optional)
# return self.failure(result) if result else self.success()
38 changes: 3 additions & 35 deletions sksmithy/tui/tui.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
from importlib import resources
from typing import ClassVar

from textual import on
from textual.app import App, ComposeResult
from textual.containers import Horizontal, ScrollableContainer
from textual.reactive import reactive
from textual.widgets import Button, Footer, Header, Input, Rule
from textual.widgets import Footer, Header, Rule

from sksmithy.tui._components import (
DecisionFunction,
Estimator,
ForgeButton,
Linear,
Name,
Optional,
Expand Down Expand Up @@ -61,7 +61,7 @@ def compose(self: Self) -> ComposeResult:
Horizontal(SampleWeight(), Linear()),
Horizontal(PredictProba(), DecisionFunction()),
Rule(),
Button(),
ForgeButton(),
)
# yield Sidebar(classes="-hidden")
yield Footer()
Expand All @@ -80,38 +80,6 @@ def action_toggle_dark(self: Self) -> None:
# self.screen.set_focus(None)
# sidebar.add_class("-hidden")

# @on(Input.Changed, "#name")
# def show_invalid_name(self: Self, event: Input.Changed) -> None:
# if not event.validation_result.is_valid:
# self.name_ = None
# self.notify(
# message=event.validation_result.failure_descriptions[0],
# title="Invalid Name",
# severity="error",
# timeout=5,
# )
# else:
# self.name_ = event.value

@on(Input.Changed, "#required,#optional")
def show_invalid_required(self: Self, event: Input.Changed) -> None:
if not event.validation_result.is_valid:
if event.input.id == "required":
self.required_ = None
else:
self.optional_ = None

self.notify(
message="\n".join(event.validation_result.failure_descriptions),
title="Invalid Parameter",
severity="error",
timeout=5,
)
elif event.input.id == "required":
self.required_ = event.value.split(",")
else:
self.optional_ = event.value.split(",")


if __name__ == "__main__":
tui = TUI()
Expand Down

0 comments on commit 0dc723d

Please sign in to comment.