From 529d0a99228d0a06d0fac2a6e46395f80998cf22 Mon Sep 17 00:00:00 2001 From: Suvayu Ali Date: Tue, 22 Oct 2024 02:16:30 +0200 Subject: [PATCH 1/3] pipeline.jl: refactor to make API uniform Everything is now handled with function dispatch. Also add a nicer way to join with Julia objects: `as_table` that support `do`-blocks. --- src/pipeline.jl | 101 ++++++++++++++++++++++-------------------- test/test-pipeline.jl | 21 ++++----- 2 files changed, 62 insertions(+), 60 deletions(-) diff --git a/src/pipeline.jl b/src/pipeline.jl index 2266a2e..a00eaf1 100644 --- a/src/pipeline.jl +++ b/src/pipeline.jl @@ -3,7 +3,7 @@ using DuckDB: DB, DBInterface, Stmt, register_data_frame, unregister_data_frame using .FmtSQL: fmt_join, fmt_read, fmt_select -export create_tbl, set_tbl_col +export create_tbl, tbl_select, as_table # default options (for now) _read_opts = pairs((header = true, skip = 1)) @@ -192,26 +192,41 @@ function _get_index(con::DB, source::String, on::Symbol) return getproperty(base, on) end -function _set_tbl_col_impl( - con::DB, - source::String, - idx::Vector, - vals::Vector; - on::Symbol, - col::Symbol, - opts..., -) - df = DF.DataFrame([idx, vals], [on, col]) - tmp_tbl = "t_col_$(col)" - register_data_frame(con, df, tmp_tbl) - # FIXME: should be fill=error (currently not implemented) - res = create_tbl(con, source, tmp_tbl; on = [on], cols = [col], fill = false, opts...) - unregister_data_frame(con, tmp_tbl) - return res +""" + as_table(op::Function, con::DB, name::String, args...) + +Temporarily "import" a Julia object into a DuckDB session. It does it +by first creating a `DataFrame`. `args...` are passed on to the +`DataFrame` constructor as is. It is registered with the DuckDB +connection `con` as the table `name`. This function can be used with +a `do`-block like this: + + julia> as_table(con, "mytbl", (;col=collect(1:5))) do con, name + DD.query(con, "SELECT col, col+2 as 'shift_2' FROM '\$name'") + end |> DataFrame + 5×2 DataFrame + Row │ col shift_2 + │ Int64? Int64? + ─────┼───────────────── + 1 │ 1 3 + 2 │ 2 4 + 3 │ 3 5 + 4 │ 4 6 + 5 │ 5 7 + +""" +function as_table(op::Function, con::DB, name::String, args...) + df = DF.DataFrame(args...) + register_data_frame(con, df, name) + try + op(con, name) + finally + unregister_data_frame(con, name) + end end """ - set_tbl_col( + create_tbl( con::DB, source::String, cols::Dict{Symbol,Vector{T}}; @@ -233,7 +248,7 @@ created table is returned. All other options behave as the two source version of `create_tbl`. """ -function set_tbl_col( +function create_tbl( con::DB, source::String, cols::Dict{Symbol, Vector{T}}; @@ -256,10 +271,9 @@ function set_tbl_col( end idx = _get_index(con, source, on) - vals = first(values(cols)) - if length(idx) != length(vals) + if !all(length(idx) .== map(length, values(cols))) msg = "Length of index column and values are different\n" - _cols = [idx, vals] + _cols = [idx, values(cols)...] data = [get.(_cols, i, "-") for i in 1:maximum(length, _cols)] |> Iterators.flatten |> @@ -268,21 +282,14 @@ function set_tbl_col( msg *= pretty_table(String, data; header = ["index", "value"]) throw(DimensionMismatch(msg)) end - _set_tbl_col_impl( - con, - source, - idx, - vals; - on = on, - col = first(keys(cols)), - name = name, - tmp = tmp, - show = show, - ) + col_names = keys(cols) |> collect + as_table(con, "t_$(join(col_names, '_'))", merge(cols, Dict(on => idx))) do con, tname + create_tbl(con, source, tname; on = [on], cols = col_names, fill = false, name, tmp, show) + end end """ - set_tbl_col( + create_tbl( con::DB, source::String, cols::Dict{Symbol, T}; @@ -304,7 +311,7 @@ All other options and behaviour are same as the vector variant of this function. """ -function set_tbl_col( +function create_tbl( con::DB, source::String, cols::Dict{Symbol, T}; @@ -328,18 +335,18 @@ function set_tbl_col( return _create_tbl_impl(con, query; name = name, tmp = tmp, show = show) end -function set_tbl_col( - con::DB, - source::String; - on::Symbol, - col::Symbol, - name::String, - apply::Function, - tmp::Bool = false, - show::Bool = false, -) end - -function select( +# function create_tbl( +# con::DB, +# source::String; +# on::Symbol, +# col::Symbol, +# name::String, +# apply::Function, +# tmp::Bool = false, +# show::Bool = false, +# ) end + +function tbl_select( con::DB, source::String, expression::String; diff --git a/test/test-pipeline.jl b/test/test-pipeline.jl index 4ad61f6..90efd97 100644 --- a/test/test-pipeline.jl +++ b/test/test-pipeline.jl @@ -226,7 +226,7 @@ end @testset "w/ vector" begin con = DBInterface.connect(DB) df_exp = DF.DataFrame(CSV.File(csv_copy; header = 2)) - df_res = TIO.set_tbl_col(con, csv_path, Dict(:investable => df_exp.investable); opts...) + df_res = TIO.create_tbl(con, csv_path, Dict(:investable => df_exp.investable); opts...) # NOTE: row order is different, join to determine equality cmp = join_cmp(df_exp, df_res, ["name", "investable"]; on = :name) investable = cmp[!, [c for c in propertynames(cmp) if occursin("investable", String(c))]] @@ -234,32 +234,27 @@ end # stupid Julia! grow up! args = [con, csv_path, Dict(:investable => df_exp.investable[2:end])] - @test_throws DimensionMismatch TIO.set_tbl_col(args...; opts...) + @test_throws DimensionMismatch TIO.create_tbl(args...; opts...) if (VERSION.major >= 1) && (VERSION.minor >= 8) - @test_throws r"Length.+different" TIO.set_tbl_col(args...; opts...) - @test_throws r"index.+value" TIO.set_tbl_col(args...; opts...) + @test_throws r"Length.+different" TIO.create_tbl(args...; opts...) + @test_throws r"index.+value" TIO.create_tbl(args...; opts...) end end @testset "w/ constant" begin con = DBInterface.connect(DB) - df_res = TIO.set_tbl_col(con, csv_path, Dict(:investable => true); opts...) + df_res = TIO.create_tbl(con, csv_path, Dict(:investable => true); opts...) @test df_res.investable |> all - table_name = TIO.set_tbl_col(con, csv_path, Dict(:investable => true); on = :name) + table_name = TIO.create_tbl(con, csv_path, Dict(:investable => true); on = :name) @test "assets_data" == table_name end @testset "w/ constant after filtering" begin con = DBInterface.connect(DB) where_clause = TIO.FmtSQL.@where_(lifetime in 25:50, name % "Valhalla_%") - df_res = TIO.set_tbl_col( - con, - csv_path, - Dict(:investable => true); - opts..., - where_ = where_clause, - ) + df_res = + TIO.create_tbl(con, csv_path, Dict(:investable => true); opts..., where_ = where_clause) @test shape(df_res) == shape(df_org) df_res = filter(row -> 25 <= row.lifetime <= 50 && startswith(row.name, "Valhalla_"), df_res) From 8a6512566c593b0df2c61f8d2d8be52e8218a227 Mon Sep 17 00:00:00 2001 From: Suvayu Ali Date: Tue, 22 Oct 2024 10:57:10 +0200 Subject: [PATCH 2/3] pipeline.jl: convert code example to doctest --- src/pipeline.jl | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/src/pipeline.jl b/src/pipeline.jl index a00eaf1..65f64e1 100644 --- a/src/pipeline.jl +++ b/src/pipeline.jl @@ -201,18 +201,27 @@ by first creating a `DataFrame`. `args...` are passed on to the connection `con` as the table `name`. This function can be used with a `do`-block like this: - julia> as_table(con, "mytbl", (;col=collect(1:5))) do con, name - DD.query(con, "SELECT col, col+2 as 'shift_2' FROM '\$name'") - end |> DataFrame - 5×2 DataFrame - Row │ col shift_2 - │ Int64? Int64? - ─────┼───────────────── - 1 │ 1 3 - 2 │ 2 4 - 3 │ 3 5 - 4 │ 4 6 - 5 │ 5 7 +```jldoctest +using DuckDB: DBInterface, DB + +con = DBInterface.connect(DB) + +as_table(con, "mytbl", (;col=collect(1:5))) do con, name + DD.query(con, "SELECT col, col+2 as 'shift_2' FROM '\$name'") +end |> DataFrame + +# output + +5×2 DataFrame +Row │ col shift_2 + │ Int64? Int64? +─────┼───────────────── + 1 │ 1 3 + 2 │ 2 4 + 3 │ 3 5 + 4 │ 4 6 + 5 │ 5 7 +``` """ function as_table(op::Function, con::DB, name::String, args...) From 5fc3ba5eaa5c60a871ba95c9fdb3a705759a6ae6 Mon Sep 17 00:00:00 2001 From: Suvayu Ali Date: Tue, 22 Oct 2024 11:00:33 +0200 Subject: [PATCH 3/3] pipeline.jl: fix typo in `as_table` code example --- src/pipeline.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/pipeline.jl b/src/pipeline.jl index 65f64e1..e152120 100644 --- a/src/pipeline.jl +++ b/src/pipeline.jl @@ -202,18 +202,19 @@ connection `con` as the table `name`. This function can be used with a `do`-block like this: ```jldoctest -using DuckDB: DBInterface, DB +using DuckDB: DBInterface, DB, query +using DataFrames: DataFrame con = DBInterface.connect(DB) as_table(con, "mytbl", (;col=collect(1:5))) do con, name - DD.query(con, "SELECT col, col+2 as 'shift_2' FROM '\$name'") + query(con, "SELECT col, col+2 as 'shift_2' FROM '\$name'") end |> DataFrame # output 5×2 DataFrame -Row │ col shift_2 + Row │ col shift_2 │ Int64? Int64? ─────┼───────────────── 1 │ 1 3