Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pipeline.jl: refactor to make API uniform #80

Merged
merged 3 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 54 additions & 47 deletions src/pipeline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

using .FmtSQL: fmt_join, fmt_read, fmt_select

export create_tbl, set_tbl_col
export create_tbl, tbl_select, as_table

# default options (for now)
_read_opts = pairs((header = true, skip = 1))
Expand Down Expand Up @@ -192,26 +192,41 @@
return getproperty(base, on)
end

function _set_tbl_col_impl(
con::DB,
source::String,
idx::Vector,
vals::Vector;
on::Symbol,
col::Symbol,
opts...,
)
df = DF.DataFrame([idx, vals], [on, col])
tmp_tbl = "t_col_$(col)"
register_data_frame(con, df, tmp_tbl)
# FIXME: should be fill=error (currently not implemented)
res = create_tbl(con, source, tmp_tbl; on = [on], cols = [col], fill = false, opts...)
unregister_data_frame(con, tmp_tbl)
return res
"""
    as_table(op::Function, con::DB, name::String, args...)

Expose a Julia object as a temporary table inside a DuckDB session.

`args...` are forwarded unchanged to the `DataFrame` constructor, and
the resulting data frame is registered on the connection `con` under
the table name `name`.  `op` is then called as `op(con, name)` and its
result is returned.  The table is unregistered once `op` is done,
whether it returns normally or throws.

Because `op` is the first argument, the function can be used with a
`do`-block:

    julia> as_table(con, "mytbl", (;col=collect(1:5))) do con, name
               DD.query(con, "SELECT col, col+2 as 'shift_2' FROM '\$name'")
           end |> DataFrame
    5×2 DataFrame
     Row │ col     shift_2
         │ Int64?  Int64?
    ─────┼─────────────────
       1 │      1        3
       2 │      2        4
       3 │      3        5
       4 │      4        6
       5 │      5        7

"""
function as_table(op::Function, con::DB, name::String, args...)
    register_data_frame(con, DF.DataFrame(args...), name)
    try
        return op(con, name)
    finally
        # always drop the temporary registration, even when `op` throws
        unregister_data_frame(con, name)
    end
end

"""
set_tbl_col(
create_tbl(
con::DB,
source::String,
cols::Dict{Symbol,Vector{T}};
Expand All @@ -233,7 +248,7 @@
All other options behave as in the two-source version of `create_tbl`.

"""
function set_tbl_col(
function create_tbl(
con::DB,
source::String,
cols::Dict{Symbol, Vector{T}};
Expand All @@ -256,10 +271,9 @@
end

idx = _get_index(con, source, on)
vals = first(values(cols))
if length(idx) != length(vals)
if !all(length(idx) .== map(length, values(cols)))
msg = "Length of index column and values are different\n"
_cols = [idx, vals]
_cols = [idx, values(cols)...]
data =
[get.(_cols, i, "-") for i in 1:maximum(length, _cols)] |>
Iterators.flatten |>
Expand All @@ -268,21 +282,14 @@
msg *= pretty_table(String, data; header = ["index", "value"])
throw(DimensionMismatch(msg))
end
_set_tbl_col_impl(
con,
source,
idx,
vals;
on = on,
col = first(keys(cols)),
name = name,
tmp = tmp,
show = show,
)
col_names = keys(cols) |> collect
as_table(con, "t_$(join(col_names, '_'))", merge(cols, Dict(on => idx))) do con, tname
create_tbl(con, source, tname; on = [on], cols = col_names, fill = false, name, tmp, show)
end
end

"""
set_tbl_col(
create_tbl(
con::DB,
source::String,
cols::Dict{Symbol, T};
Expand All @@ -304,7 +311,7 @@
function.

"""
function set_tbl_col(
function create_tbl(
con::DB,
source::String,
cols::Dict{Symbol, T};
Expand All @@ -328,18 +335,18 @@
return _create_tbl_impl(con, query; name = name, tmp = tmp, show = show)
end

function set_tbl_col(
con::DB,
source::String;
on::Symbol,
col::Symbol,
name::String,
apply::Function,
tmp::Bool = false,
show::Bool = false,
) end

function select(
# function create_tbl(
# con::DB,
# source::String;
# on::Symbol,
# col::Symbol,
# name::String,
# apply::Function,
# tmp::Bool = false,
# show::Bool = false,
# ) end
suvayu marked this conversation as resolved.
Show resolved Hide resolved

function tbl_select(

Check warning on line 349 in src/pipeline.jl

View check run for this annotation

Codecov / codecov/patch

src/pipeline.jl#L349

Added line #L349 was not covered by tests
con::DB,
source::String,
expression::String;
Expand Down
21 changes: 8 additions & 13 deletions test/test-pipeline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,40 +226,35 @@ end
@testset "w/ vector" begin
con = DBInterface.connect(DB)
df_exp = DF.DataFrame(CSV.File(csv_copy; header = 2))
df_res = TIO.set_tbl_col(con, csv_path, Dict(:investable => df_exp.investable); opts...)
df_res = TIO.create_tbl(con, csv_path, Dict(:investable => df_exp.investable); opts...)
# NOTE: row order is different, join to determine equality
cmp = join_cmp(df_exp, df_res, ["name", "investable"]; on = :name)
investable = cmp[!, [c for c in propertynames(cmp) if occursin("investable", String(c))]]
@test isequal.(investable[!, 1], investable[!, 2]) |> all

# stupid Julia! grow up!
args = [con, csv_path, Dict(:investable => df_exp.investable[2:end])]
@test_throws DimensionMismatch TIO.set_tbl_col(args...; opts...)
@test_throws DimensionMismatch TIO.create_tbl(args...; opts...)
# The message-pattern checks below are only run on Julia 1.8 or newer.
# Compare the whole version number: testing `major` and `minor`
# separately would wrongly skip these checks on e.g. Julia 2.0.
if VERSION >= v"1.8"
    @test_throws r"Length.+different" TIO.set_tbl_col(args...; opts...)
    @test_throws r"index.+value" TIO.set_tbl_col(args...; opts...)
    @test_throws r"Length.+different" TIO.create_tbl(args...; opts...)
    @test_throws r"index.+value" TIO.create_tbl(args...; opts...)
end
end

@testset "w/ constant" begin
con = DBInterface.connect(DB)
df_res = TIO.set_tbl_col(con, csv_path, Dict(:investable => true); opts...)
df_res = TIO.create_tbl(con, csv_path, Dict(:investable => true); opts...)
@test df_res.investable |> all

table_name = TIO.set_tbl_col(con, csv_path, Dict(:investable => true); on = :name)
table_name = TIO.create_tbl(con, csv_path, Dict(:investable => true); on = :name)
@test "assets_data" == table_name
end

@testset "w/ constant after filtering" begin
con = DBInterface.connect(DB)
where_clause = TIO.FmtSQL.@where_(lifetime in 25:50, name % "Valhalla_%")
df_res = TIO.set_tbl_col(
con,
csv_path,
Dict(:investable => true);
opts...,
where_ = where_clause,
)
df_res =
TIO.create_tbl(con, csv_path, Dict(:investable => true); opts..., where_ = where_clause)
@test shape(df_res) == shape(df_org)
df_res =
filter(row -> 25 <= row.lifetime <= 50 && startswith(row.name, "Valhalla_"), df_res)
Expand Down
Loading