# Patch commentary (this preamble sits ahead of the first `diff --git` header and is
# not part of any hunk).
#
# src/pipeline.jl
#   * Export list: `set_tbl_col` is dropped; `tbl_select` and `as_table` are added.
#   * `_set_tbl_col_impl` is removed. Its register / create_tbl / unregister sequence
#     is replaced by the new `as_table(op, con, name, args...)`, which builds a
#     `DataFrame` from `args...`, registers it on `con` under `name`, calls
#     `op(con, name)`, and unregisters the table in a `finally` clause.
#   * The `Dict{Symbol, Vector{T}}` and `Dict{Symbol, T}` methods of `set_tbl_col`
#     are renamed to `create_tbl` (docstring signatures updated to match).
#   * Vector method: the length check now compares the index against every column
#     in `cols` instead of only the first, and the temporary table is named
#     "t_" followed by the column names joined with '_'.
#   * The keyword-only `apply` stub is kept only as commented-out code, renamed to
#     `create_tbl`.
#   * `select` is renamed to `tbl_select`.
#
# test/test-pipeline.jl
#   * Calls to `TIO.set_tbl_col` become `TIO.create_tbl`; no assertions change.
#
# NOTE(review): the mismatch report still passes `header = ["index", "value"]` while
#   `_cols` is now `[idx, values(cols)...]`; the lines between the two hunks are not
#   shown here, so confirm the table still has two columns when `cols` has several
#   entries.
# NOTE(review): `(VERSION.major >= 1) && (VERSION.minor >= 8)` is false for any
#   release whose minor number is below 8 (e.g. 2.0); `VERSION >= v"1.8"` is the
#   usual spelling -- confirm intent.
# NOTE(review): line breaks inside the hunks appear to have been collapsed to
#   spaces in this copy -- presumably they must be restored before the patch can
#   be applied; verify against the original commit.
diff --git a/src/pipeline.jl b/src/pipeline.jl index 2266a2e..e152120 100644 --- a/src/pipeline.jl +++ b/src/pipeline.jl @@ -3,7 +3,7 @@ using DuckDB: DB, DBInterface, Stmt, register_data_frame, unregister_data_frame using .FmtSQL: fmt_join, fmt_read, fmt_select -export create_tbl, set_tbl_col +export create_tbl, tbl_select, as_table # default options (for now) _read_opts = pairs((header = true, skip = 1)) @@ -192,26 +192,51 @@ function _get_index(con::DB, source::String, on::Symbol) return getproperty(base, on) end -function _set_tbl_col_impl( - con::DB, - source::String, - idx::Vector, - vals::Vector; - on::Symbol, - col::Symbol, - opts..., -) - df = DF.DataFrame([idx, vals], [on, col]) - tmp_tbl = "t_col_$(col)" - register_data_frame(con, df, tmp_tbl) - # FIXME: should be fill=error (currently not implemented) - res = create_tbl(con, source, tmp_tbl; on = [on], cols = [col], fill = false, opts...) - unregister_data_frame(con, tmp_tbl) - return res +""" + as_table(op::Function, con::DB, name::String, args...) + +Temporarily "import" a Julia object into a DuckDB session. It does it +by first creating a `DataFrame`. `args...` are passed on to the +`DataFrame` constructor as is. It is registered with the DuckDB +connection `con` as the table `name`. This function can be used with +a `do`-block like this: + +```jldoctest +using DuckDB: DBInterface, DB, query +using DataFrames: DataFrame + +con = DBInterface.connect(DB) + +as_table(con, "mytbl", (;col=collect(1:5))) do con, name + query(con, "SELECT col, col+2 as 'shift_2' FROM '\$name'") +end |> DataFrame + +# output + +5×2 DataFrame + Row │ col shift_2 + │ Int64? Int64? +─────┼───────────────── + 1 │ 1 3 + 2 │ 2 4 + 3 │ 3 5 + 4 │ 4 6 + 5 │ 5 7 +``` + +""" +function as_table(op::Function, con::DB, name::String, args...) + df = DF.DataFrame(args...) 
+ register_data_frame(con, df, name) + try + op(con, name) + finally + unregister_data_frame(con, name) + end end """ - set_tbl_col( + create_tbl( con::DB, source::String, cols::Dict{Symbol,Vector{T}}; @@ -233,7 +258,7 @@ created table is returned. All other options behave as the two source version of `create_tbl`. """ -function set_tbl_col( +function create_tbl( con::DB, source::String, cols::Dict{Symbol, Vector{T}}; @@ -256,10 +281,9 @@ function set_tbl_col( end idx = _get_index(con, source, on) - vals = first(values(cols)) - if length(idx) != length(vals) + if !all(length(idx) .== map(length, values(cols))) msg = "Length of index column and values are different\n" - _cols = [idx, vals] + _cols = [idx, values(cols)...] data = [get.(_cols, i, "-") for i in 1:maximum(length, _cols)] |> Iterators.flatten |> @@ -268,21 +292,14 @@ function set_tbl_col( msg *= pretty_table(String, data; header = ["index", "value"]) throw(DimensionMismatch(msg)) end - _set_tbl_col_impl( - con, - source, - idx, - vals; - on = on, - col = first(keys(cols)), - name = name, - tmp = tmp, - show = show, - ) + col_names = keys(cols) |> collect + as_table(con, "t_$(join(col_names, '_'))", merge(cols, Dict(on => idx))) do con, tname + create_tbl(con, source, tname; on = [on], cols = col_names, fill = false, name, tmp, show) + end end """ - set_tbl_col( + create_tbl( con::DB, source::String, cols::Dict{Symbol, T}; @@ -304,7 +321,7 @@ All other options and behaviour are same as the vector variant of this function. 
""" -function set_tbl_col( +function create_tbl( con::DB, source::String, cols::Dict{Symbol, T}; @@ -328,18 +345,18 @@ function set_tbl_col( return _create_tbl_impl(con, query; name = name, tmp = tmp, show = show) end -function set_tbl_col( - con::DB, - source::String; - on::Symbol, - col::Symbol, - name::String, - apply::Function, - tmp::Bool = false, - show::Bool = false, -) end - -function select( +# function create_tbl( +# con::DB, +# source::String; +# on::Symbol, +# col::Symbol, +# name::String, +# apply::Function, +# tmp::Bool = false, +# show::Bool = false, +# ) end + +function tbl_select( con::DB, source::String, expression::String; diff --git a/test/test-pipeline.jl b/test/test-pipeline.jl index 4ad61f6..90efd97 100644 --- a/test/test-pipeline.jl +++ b/test/test-pipeline.jl @@ -226,7 +226,7 @@ end @testset "w/ vector" begin con = DBInterface.connect(DB) df_exp = DF.DataFrame(CSV.File(csv_copy; header = 2)) - df_res = TIO.set_tbl_col(con, csv_path, Dict(:investable => df_exp.investable); opts...) + df_res = TIO.create_tbl(con, csv_path, Dict(:investable => df_exp.investable); opts...) # NOTE: row order is different, join to determine equality cmp = join_cmp(df_exp, df_res, ["name", "investable"]; on = :name) investable = cmp[!, [c for c in propertynames(cmp) if occursin("investable", String(c))]] @@ -234,32 +234,27 @@ end # stupid Julia! grow up! args = [con, csv_path, Dict(:investable => df_exp.investable[2:end])] - @test_throws DimensionMismatch TIO.set_tbl_col(args...; opts...) + @test_throws DimensionMismatch TIO.create_tbl(args...; opts...) if (VERSION.major >= 1) && (VERSION.minor >= 8) - @test_throws r"Length.+different" TIO.set_tbl_col(args...; opts...) - @test_throws r"index.+value" TIO.set_tbl_col(args...; opts...) + @test_throws r"Length.+different" TIO.create_tbl(args...; opts...) + @test_throws r"index.+value" TIO.create_tbl(args...; opts...) 
end end @testset "w/ constant" begin con = DBInterface.connect(DB) - df_res = TIO.set_tbl_col(con, csv_path, Dict(:investable => true); opts...) + df_res = TIO.create_tbl(con, csv_path, Dict(:investable => true); opts...) @test df_res.investable |> all - table_name = TIO.set_tbl_col(con, csv_path, Dict(:investable => true); on = :name) + table_name = TIO.create_tbl(con, csv_path, Dict(:investable => true); on = :name) @test "assets_data" == table_name end @testset "w/ constant after filtering" begin con = DBInterface.connect(DB) where_clause = TIO.FmtSQL.@where_(lifetime in 25:50, name % "Valhalla_%") - df_res = TIO.set_tbl_col( - con, - csv_path, - Dict(:investable => true); - opts..., - where_ = where_clause, - ) + df_res = + TIO.create_tbl(con, csv_path, Dict(:investable => true); opts..., where_ = where_clause) @test shape(df_res) == shape(df_org) df_res = filter(row -> 25 <= row.lifetime <= 50 && startswith(row.name, "Valhalla_"), df_res)