Skip to content

Commit

Permalink
eureka
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi committed Jun 9, 2024
1 parent 0b077e3 commit 8f8b656
Show file tree
Hide file tree
Showing 9 changed files with 222 additions and 161 deletions.
11 changes: 10 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,16 @@ Website = "https://sklearn-smithy.streamlit.app/"

[project.optional-dependencies]
streamlit = ["streamlit>=1.34.0"]
textual = ["textual>=0.65.0"]

all = [
"streamlit>=1.34.0",
"textual>=0.65.0",
]

[project.scripts]
smith = "sksmithy.__main__:cli"
smith-tui = "sksmithy.tui.__main__:forge_tui"

[tool.hatch.build.targets.sdist]
only-include = ["sksmithy"]
Expand Down Expand Up @@ -105,4 +112,6 @@ omit = [
"sksmithy/_arguments.py",
"sksmithy/_logger.py",
"sksmithy/_prompts.py",
]
"sksmithy/tui/__init__.py",
"sksmithy/tui/__main__.py",
]
12 changes: 12 additions & 0 deletions sksmithy/_static/description.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
Writing scikit-learn compatible estimators might be harder than expected.

While everyone knows about the `fit` and `predict`, there are other behaviours, methods and attributes that scikit-learn might be expecting from your estimator depending on:

- The type of estimator you're writing.
- The signature of the estimator.
- The signature of the `.fit(...)` method.

Scikit-learn Smithy to the rescue: this tool aims to help you crafting your own estimator by asking a few questions about it, and then generating the boilerplate code.

In this way you will be able to fully focus on the core implementation logic, and not on nitty-gritty details of the
scikit-learn API.
74 changes: 41 additions & 33 deletions sksmithy/_static/tui.tcss
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
.container {
height: auto;
width: auto;
min-height: 10vh;
}

.label {
height: 3;
content-align: right middle;
width: auto;
}

Screen {
align: center middle;
}
Expand Down Expand Up @@ -69,44 +81,40 @@ DestinationFile {
height: 100%;
}


.container {
Sidebar {
width: 80;
height: auto;
width: auto;
min-height: 10vh;
}
background: $panel;
transition: offset 200ms in_out_cubic;
layer: overlay;

.label {
height: 3;
content-align: right middle;
width: auto;
}

Sidebar:focus-within {
offset: 0 0 !important;
}

Sidebar.-hidden {
offset-x: -100%;
}

Sidebar Title {
background: $boost;
color: $secondary;
padding: 2 0 1 0;
border-right: vkey $background;
dock: top;
text-align: center;
text-style: bold;
}

# Sidebar {
# width: 80;
# background: $panel;
# transition: offset 200ms in_out_cubic;
# layer: overlay;

# }

# Sidebar:focus-within {
# offset: 0 0 !important;
# }

# Sidebar.-hidden {
# offset-x: -100%;
# }
OptionGroup {
background: $boost;
color: $text;
height: 1fr;
border-right: vkey $background;
}

# Sidebar Title {
# background: $boost;
# color: $secondary;
# padding: 2 0 1 0;
# border-right: vkey $background;
# dock: top;
# text-align: center;
# text-style: bold;
# }
Message {
margin: 0 1;
}
3 changes: 3 additions & 0 deletions sksmithy/tui/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from sksmithy.tui._tui import ForgeTUI

__all__ = ("ForgeTUI",)
11 changes: 11 additions & 0 deletions sksmithy/tui/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from sksmithy.tui._tui import ForgeTUI


def forge_tui() -> None:
"""Entrypoint function."""
tui = ForgeTUI()
tui.run()


if __name__ == "__main__":
forge_tui()
93 changes: 56 additions & 37 deletions sksmithy/tui/_components.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import sys
from importlib import metadata, resources
from pathlib import Path

from result import Err, Ok
from rich.console import RenderableType
from textual import on
from textual.app import ComposeResult
from textual.containers import Container, Grid, Horizontal
from textual.containers import Container, Grid, Horizontal, ScrollableContainer
from textual.widgets import Button, Input, Select, Static, Switch

from sksmithy._models import EstimatorType
Expand All @@ -29,11 +31,10 @@
from typing_extensions import Self


class Prompt(Static):
pass
SIDEBAR_MSG: str = (resources.files("sksmithy") / "_static" / "description.md").read_text()


class ForgeRow(Grid):
class Prompt(Static):
pass


Expand All @@ -46,15 +47,15 @@ def compose(self: Self) -> ComposeResult:

@on(Input.Changed, "#name")
def on_input_change(self: Self, event: Input.Changed) -> None:
if not event.validation_result.is_valid:
if not event.validation_result.is_valid: # type: ignore[union-attr]
self.notify(
message=event.validation_result.failure_descriptions[0],
message=event.validation_result.failure_descriptions[0], # type: ignore[union-attr]
title="Invalid Name",
severity="error",
timeout=5,
)
else:
output_file: Input = self.app.query_one("#output_file")
output_file = self.app.query_one("#output_file", Input)
output_file.value = f"{event.value.lower()}.py"


Expand All @@ -70,9 +71,9 @@ def compose(self: Self) -> ComposeResult:

@on(Select.Changed, "#estimator")
def on_select_change(self: Self, event: Select.Changed) -> None:
linear: Switch = self.app.query_one("#linear")
predict_proba: Switch = self.app.query_one("#predict_proba")
decision_function: Switch = self.app.query_one("#decision_function")
linear = self.app.query_one("#linear", Switch)
predict_proba = self.app.query_one("#predict_proba", Switch)
decision_function = self.app.query_one("#decision_function", Switch)

linear.disabled = event.value not in {"classifier", "regressor"}
predict_proba.disabled = event.value not in {"classifier", "outlier"}
Expand All @@ -92,15 +93,15 @@ def compose(self: Self) -> ComposeResult:

@on(Input.Submitted, "#required")
def on_input_change(self: Self, event: Input.Submitted) -> None:
if not event.validation_result.is_valid:
if not event.validation_result.is_valid: # type: ignore[union-attr]
self.notify(
message="\n".join(event.validation_result.failure_descriptions),
message="\n".join(event.validation_result.failure_descriptions), # type: ignore[union-attr]
title="Invalid Parameter",
severity="error",
timeout=5,
)

optional: Input = self.app.query_one("#optional").value or ""
optional = self.app.query_one("#optional", Input).value or ""
if (
optional
and event.value
Expand Down Expand Up @@ -128,15 +129,15 @@ def compose(self: Self) -> ComposeResult:

@on(Input.Submitted, "#optional")
def on_optional_change(self: Self, event: Input.Submitted) -> None:
if not event.validation_result.is_valid:
if not event.validation_result.is_valid: # type: ignore[union-attr]
self.notify(
message="\n".join(event.validation_result.failure_descriptions),
message="\n".join(event.validation_result.failure_descriptions), # type: ignore[union-attr]
title="Invalid Parameter",
severity="error",
timeout=5,
)

required: Input = self.app.query_one("#required").value or ""
required = self.app.query_one("#required", Input).value or ""
if (
required
and event.value
Expand Down Expand Up @@ -178,7 +179,7 @@ def compose(self: Self) -> ComposeResult:

@on(Switch.Changed, "#linear")
def on_switch_changed(self: Self, event: Switch.Changed) -> None:
decision_function: Switch = self.app.query_one("#decision_function")
decision_function = self.app.query_one("#decision_function", Switch)
decision_function.disabled = event.value
decision_function.value = decision_function.value and (not decision_function.disabled)

Expand Down Expand Up @@ -223,20 +224,20 @@ def compose(self: Self) -> ComposeResult:
)

@on(Button.Pressed, "#forge_btn")
def on_button_pressed(self: Self, _: Button.Pressed) -> None:
def on_forge(self: Self, _: Button.Pressed) -> None: # noqa: C901
errors = []

name_input: str = self.app.query_one("#name").value
estimator: str | None = self.app.query_one("#estimator").value
required_params: str = self.app.query_one("#required").value
optional_params: str = self.app.query_one("#optional").value
name_input = self.app.query_one("#name", Input).value
estimator = self.app.query_one("#estimator", Select).value
required_params = self.app.query_one("#required", Input).value
optional_params = self.app.query_one("#optional", Input).value

sample_weight: bool = self.app.query_one("#linear").value
linear: bool = self.app.query_one("#linear").value
predict_proba: bool = self.app.query_one("#predict_proba").value
decision_function: bool = self.app.query_one("#decision_function").value
sample_weight = self.app.query_one("#linear", Switch).value
linear = self.app.query_one("#linear", Switch).value
predict_proba = self.app.query_one("#predict_proba", Switch).value
decision_function = self.app.query_one("#decision_function", Switch).value

output_file: str = self.app.query_one("#output_file").value
output_file = self.app.query_one("#output_file", Input).value

match name_parser(name_input):
case Ok(name):
Expand All @@ -248,7 +249,7 @@ def on_button_pressed(self: Self, _: Button.Pressed) -> None:
case str(v):
estimator_type = EstimatorType(v)
case Select.BLANK:
errors.append("Estimator cannot be None")
errors.append("Estimator cannot be None!")

match params_parser(required_params):
case Ok(required):
Expand All @@ -269,7 +270,7 @@ def on_button_pressed(self: Self, _: Button.Pressed) -> None:
errors.append(msg_duplicated_params)

if not output_file:
errors.append("Outfile file cannot be empty")
errors.append("Outfile file cannot be empty!")

if errors:
self.notify(
Expand Down Expand Up @@ -306,15 +307,33 @@ def on_button_pressed(self: Self, _: Button.Pressed) -> None:
)


class ForgeRow(Grid):
"""Row grid for forge."""



forge_row = ForgeRow(Static(), Static(), ForgeButton(), DestinationFile(), Static(), Static(), id="forge_row")

# class Version(Static):
# def render(self: Self) -> RenderableType:
# return f"Version: [b]{version('sklearn-smithy')}"

class Title(Static):
pass


# class Sidebar(Container):
# def compose(self: Self) -> ComposeResult:
# yield Title("Description")
# yield Container(MarkdownViewer(SIDEBAR_MSG))
# yield Version()
class OptionGroup(ScrollableContainer):
pass


class Message(Static):
pass


class Version(Static):
def render(self: Self) -> RenderableType:
return f"Version: [b]{metadata.version('sklearn-smithy')}"


class Sidebar(Container):
def compose(self: Self) -> ComposeResult:
yield Title("Description")
yield OptionGroup(Message(SIDEBAR_MSG), Version())
yield Version()
Loading

0 comments on commit 8f8b656

Please sign in to comment.