diff --git a/scripts/sanitizers/run_sanitizers.sh b/scripts/sanitizers/run_sanitizers.sh index 1f0df5553..de35c3191 100755 --- a/scripts/sanitizers/run_sanitizers.sh +++ b/scripts/sanitizers/run_sanitizers.sh @@ -55,7 +55,14 @@ docker exec -i -u root lantern-sanitizers /bin/bash < #include "../hnsw/utils.h" +#include "op_rewrite.h" #include "plan_tree_walker.h" #include "utils.h" @@ -73,10 +74,12 @@ void ExecutorStart_hook_with_operator_check(QueryDesc *queryDesc, int eflags) if(oidList != NULL) { // oidList will be NULL if LanternDB extension is not fully initialized // e.g. in statements executed as a result of CREATE EXTENSION ... statement + ldb_rewrite_ops(queryDesc->plannedstmt->planTree, oidList, queryDesc->plannedstmt->rtable); validate_operator_usage(queryDesc->plannedstmt->planTree, oidList); ListCell *lc; foreach(lc, queryDesc->plannedstmt->subplans) { Plan *subplan = (Plan *)lfirst(lc); + ldb_rewrite_ops(subplan, oidList, queryDesc->plannedstmt->rtable); validate_operator_usage(subplan, oidList); } list_free(oidList); diff --git a/src/hooks/op_rewrite.c b/src/hooks/op_rewrite.c new file mode 100644 index 000000000..d0b40c907 --- /dev/null +++ b/src/hooks/op_rewrite.c @@ -0,0 +1,285 @@ +#include + +#include "op_rewrite.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "plan_tree_walker.h" +#include "utils.h" + +#if PG_VERSION_NUM < 120000 +#include +#include +#else +#include +#endif + +static Node *operator_rewriting_mutator(Node *node, void *ctx); + +void base_plan_mutator(Plan *plan, void *context) +{ + plan->lefttree = (Plan *)operator_rewriting_mutator((Node *)plan->lefttree, context); + plan->righttree = (Plan *)operator_rewriting_mutator((Node *)plan->righttree, context); + plan->initPlan = (List *)operator_rewriting_mutator((Node *)plan->initPlan, context); + // checking qual and target list at the end covers some edge cases, if you modify this leave them here + plan->qual = (List *)operator_rewriting_mutator((Node *)plan->qual, context); + plan->targetlist = (List *)operator_rewriting_mutator((Node *)plan->targetlist, context); +} + +// recursively descend the plan tree searching for expressions with the <-> operator that are part of a non-index scan +// src/include/nodes/plannodes.h and src/include/nodes/nodes.h contain relevant definitions +Node *plan_tree_mutator(Plan *plan, void *context) +{ + check_stack_depth(); + + switch(nodeTag(plan)) { + case T_SubqueryScan: + { + SubqueryScan *subqueryscan = (SubqueryScan *)plan; + base_plan_mutator(&(subqueryscan->scan.plan), context); + subqueryscan->subplan = (Plan *)operator_rewriting_mutator((Node *)subqueryscan->subplan, context); + return (Node *)subqueryscan; + } + case T_CteScan: + { + CteScan *ctescan = (CteScan *)plan; + base_plan_mutator(&(ctescan->scan.plan), context); + return (Node *)ctescan; + } +#if PG_VERSION_NUM < 160000 + case T_Join: + { + Join *join = (Join *)plan; + base_plan_mutator(&(join->plan), context); + join->joinqual = (List *)operator_rewriting_mutator((Node *)join->joinqual, context); + return (Node *)join; + } +#endif + case T_NestLoop: + { + NestLoop *nestloop = (NestLoop *)plan; + base_plan_mutator((Plan *)&(nestloop->join), context); + return (Node *)nestloop; + } + case T_Result: + { + Result *result = (Result *)plan; + base_plan_mutator(&(result->plan), context); + result->resconstantqual = operator_rewriting_mutator((Node *)result->resconstantqual, context); + return (Node *)result; + } + case T_Limit: + { + Limit *limit = (Limit *)plan; + base_plan_mutator(&(limit->plan), context); + limit->limitOffset = operator_rewriting_mutator((Node *)limit->limitOffset, context); + limit->limitCount = operator_rewriting_mutator((Node *)limit->limitCount, context); + return (Node *)limit; + } + case T_Append: + { + Append *append = (Append *)plan; + base_plan_mutator(&(append->plan), context); + append->appendplans = (List *)operator_rewriting_mutator((Node *)append->appendplans, context); + return (Node *)append; + } + // case T_IncrementalSort: // We will eventually support this + case T_Agg: + case T_Group: + case T_Sort: + case T_Unique: + case T_SetOp: + case T_Hash: + case T_HashJoin: + case T_WindowAgg: + case T_LockRows: + { + base_plan_mutator(plan, context); + return (Node *)plan; + } + case T_ModifyTable: // No order by when modifying a table (update/delete etc) + case T_BitmapAnd: // We do not provide a bitmap index + case T_BitmapOr: + case T_BitmapHeapScan: + case T_BitmapIndexScan: + case T_FunctionScan: // SELECT * FROM fn(x, y, z) + case T_ValuesScan: // VALUES (1), (2) + case T_Material: // https://stackoverflow.com/questions/31410030/ +#if PG_VERSION_NUM >= 140000 + case T_Memoize: // memoized inner loop must have an index to be memoized +#endif + case T_WorkTableScan: // temporary table, shouldn't have index + case T_ProjectSet: // "execute set returning functions" feels safe to exclude + case T_TableFuncScan: // scan of a function that returns a table, shouldn't have an index + case T_ForeignScan: // if the relation is foreign we can't determine if it has an index + default: + break; + } + return (Node *)plan; +} + +// To write syscache calls look for the 'static const struct cachedesc cacheinfo[]' in utils/cache/syscache.c +// These describe the different caches that will be initialized into SysCache and the keys they support in searches +// The anums tell you the table and the column that the key will be compared to this is afaict the only way to match +// them to SQL for example pg_am.oid -> Anum_pg_am_oid the keys must be in order but they need not all be included the +// comment next to the top label is the name of the #defined cacheid that you should use as your first argument you can +// destructure the tuple int a From_(table_name) with GETSTRUCT to pull individual rows out +static Oid get_func_id_from_index(Relation index) +{ + Oid hnswamoid = get_index_am_oid("hnsw", false); + if(index->rd_rel->relam != hnswamoid) return InvalidOid; + + // indclass is inaccessible on the form data + // https://www.postgresql.org/docs/current/system-catalog-declarations.html + bool isNull; + Oid idxopclassoid; + Datum classDatum = SysCacheGetAttr(INDEXRELID, index->rd_indextuple, Anum_pg_index_indclass, &isNull); + if(!isNull) { + oidvector *indclass = (oidvector *)DatumGetPointer(classDatum); + assert(indclass->dim1 == 1); + idxopclassoid = indclass->values[ 0 ]; + } else { + index_close(index, AccessShareLock); + elog(ERROR, "Failed to retrieve indclass oid from index class"); + } + + // SELECT * FROM pg_opclass WHERE opcmethod=hnswamoid AND opcname=dist_cos_ops + HeapTuple opclassTuple = SearchSysCache1(CLAOID, ObjectIdGetDatum(idxopclassoid)); + if(!HeapTupleIsValid(opclassTuple)) { + index_close(index, AccessShareLock); + elog(ERROR, "Failed to find operator class for key column"); + } + + Oid opclassOid = ((Form_pg_opclass)GETSTRUCT(opclassTuple))->opcfamily; + ReleaseSysCache(opclassTuple); + + // SELECT * FROM pg_amproc WHERE amprocfamily=opclassOid + // SearchSysCache1 is what we want and in fact it runs fine against release builds. However debug builds assert that + // AMPROCNUM takes only 1 arg which isn't true and so they fail. We therefore have to use SearchSysCacheList1 since + // it doesn't enforce this invariant. Ideally we would call SearchCatCache1 directly but postgres doesn't expose + // necessary constants + CatCList *opList = SearchSysCacheList1(AMPROCNUM, ObjectIdGetDatum(opclassOid)); + assert(opList->n_members == 1); + HeapTuple opTuple = &opList->members[ 0 ]->tuple; + if(!HeapTupleIsValid(opTuple)) { + index_close(index, AccessShareLock); + elog(ERROR, "Failed to find the function for operator class"); + } + Oid functionId = ((Form_pg_amproc)GETSTRUCT(opTuple))->amproc; + ReleaseCatCacheList(opList); + + return functionId; +} + +static Node *operator_rewriting_mutator(Node *node, void *ctx) +{ + OpRewriterContext *context = (OpRewriterContext *)ctx; + + if(node == NULL) return node; + + if(IsA(node, OpExpr)) { + OpExpr *opExpr = (OpExpr *)node; + if(list_member_oid(context->ldb_ops, opExpr->opno)) { + if(context->indices == NULL) { + return node; + } else { + ListCell *lc; + foreach(lc, context->indices) { + uintptr_t intermediate = (uintptr_t)lfirst(lc); + Oid indexid = (Oid)intermediate; + Relation index = index_open(indexid, AccessShareLock); + Oid indexfunc = get_func_id_from_index(index); + if(OidIsValid(indexfunc)) { + MemoryContext old = MemoryContextSwitchTo(MessageContext); + FuncExpr *fnExpr = makeNode(FuncExpr); + fnExpr->funcresulttype = opExpr->opresulttype; + fnExpr->funcretset = opExpr->opretset; + fnExpr->funccollid = opExpr->opcollid; + fnExpr->inputcollid = opExpr->inputcollid; + fnExpr->args = opExpr->args; + fnExpr->location = opExpr->location; + // operators can't take variadic arguments + fnExpr->funcvariadic = false; + // print it as a function + fnExpr->funcformat = COERCE_EXPLICIT_CALL; + fnExpr->funcid = indexfunc; + MemoryContextSwitchTo(old); + + index_close(index, AccessShareLock); + + return (Node *)fnExpr; + } + index_close(index, AccessShareLock); + } + return node; + } + } + } + + if(IsA(node, IndexScan) || IsA(node, IndexOnlyScan)) { + return node; + } + if(IsA(node, SeqScan) || IsA(node, SampleScan)) { + Scan *scan = (Scan *)node; + Plan *scanPlan = &scan->plan; + Oid rtrelid = scan->scanrelid; + RangeTblEntry *rte = rt_fetch(rtrelid, context->rtable); + Oid relid = rte->relid; + Relation rel = relation_open(relid, AccessShareLock); + if(rel->rd_indexvalid) { + context->indices = RelationGetIndexList(rel); + } + relation_close(rel, AccessShareLock); + + base_plan_mutator(scanPlan, context); + return (Node *)scan; + } + + if(IsA(node, List)) { + MemoryContext old = MemoryContextSwitchTo(MessageContext); + List *list = (List *)node; + List *ret = NIL; + ListCell *lc; + foreach(lc, list) { + ret = lappend(ret, operator_rewriting_mutator((Node *)lfirst(lc), ctx)); + } + MemoryContextSwitchTo(old); + return (Node *)ret; + } + + if(is_plan_node(node)) { + return (Node *)plan_tree_mutator((Plan *)node, ctx); + } else { + return expression_tree_mutator(node, operator_rewriting_mutator, ctx); + } +} + +bool ldb_rewrite_ops(Plan *plan, List *oidList, List *rtable) +{ + Node *node = (Node *)plan; + + OpRewriterContext context; + context.ldb_ops = oidList; + context.indices = NULL; + context.rtable = rtable; + + if(IsA(node, IndexScan) || IsA(node, IndexOnlyScan)) { + return false; + } + + operator_rewriting_mutator(node, (void *)&context); + return true; +} diff --git a/src/hooks/op_rewrite.h b/src/hooks/op_rewrite.h new file mode 100644 index 000000000..8db3a04e7 --- /dev/null +++ b/src/hooks/op_rewrite.h @@ -0,0 +1,15 @@ +#ifndef LDB_HOOKS_OP_REWRITE_H +#define LDB_HOOKS_OP_REWRITE_H + +#include +#include + +typedef struct OpRewriterContext +{ + List *ldb_ops; + List *indices; + List *rtable; +} OpRewriterContext; + +bool ldb_rewrite_ops(Plan *plan, List *oidList, List *rtable); +#endif diff --git a/src/hooks/plan_tree_walker.h b/src/hooks/plan_tree_walker.h index 7207badc1..03f885c43 100644 --- a/src/hooks/plan_tree_walker.h +++ b/src/hooks/plan_tree_walker.h @@ -17,4 +17,4 @@ static inline bool is_plan_node(Node *node) bool plan_tree_walker(Plan *plan, bool (*walker_func)(Node *node, void *context), void *context); -#endif // LDB_HOOKS_PLAN_TREE_WALKER_H \ No newline at end of file +#endif // LDB_HOOKS_PLAN_TREE_WALKER_H diff --git a/src/hooks/utils.c b/src/hooks/utils.c index 34ed1adbc..bf263cecb 100644 --- a/src/hooks/utils.c +++ b/src/hooks/utils.c @@ -24,4 +24,4 @@ List *ldb_get_operator_oids() list_free(nameList); return oidList; -} \ No newline at end of file +} diff --git a/src/hooks/utils.h b/src/hooks/utils.h index ea3516b3c..be89baaf4 100644 --- a/src/hooks/utils.h +++ b/src/hooks/utils.h @@ -7,4 +7,6 @@ List *ldb_get_operator_oids(); -#endif // LDB_HOOKS_UTILS_H \ No newline at end of file +List *ldb_get_operator_class_oids(Oid amId); + +#endif // LDB_HOOKS_UTILS_H diff --git a/test/expected/hnsw_create_expr.out b/test/expected/hnsw_create_expr.out index 7527d2080..690beb7bd 100644 --- a/test/expected/hnsw_create_expr.out +++ b/test/expected/hnsw_create_expr.out @@ -87,7 +87,10 @@ ERROR: data type text has no default operator class for access method "hnsw" -- This should result in error about multicolumn expressions support CREATE INDEX ON test_table USING hnsw (int_to_fixed_binary_real_array(id), int_to_dynamic_binary_real_array(id)) WITH (M=2); ERROR: access method "hnsw" does not support multicolumn indexes --- This currently results in an error about using the operator outside of index --- This case should be fixed SELECT id FROM test_table ORDER BY int_to_fixed_binary_real_array(id) <-> '{0,0,0}'::REAL[] LIMIT 2; -ERROR: Operator <-> can only be used inside of an index + id +---- + 0 + 1 +(2 rows) + diff --git a/test/expected/hnsw_select.out b/test/expected/hnsw_select.out index 1fcc450a7..8b4b42fc7 100644 --- a/test/expected/hnsw_select.out +++ b/test/expected/hnsw_select.out @@ -204,6 +204,201 @@ SELECT has_index_scan('EXPLAIN WITH t AS (SELECT id FROM test1 ORDER BY v <-> '' t (1 row) +-- Validate <-> is replaced with the matching function when an index is present +set enable_seqscan = true; +set enable_indexscan = false; +EXPLAIN (COSTS false) SELECT * from small_world ORDER BY v <-> '{1,1,1}'; + QUERY PLAN +----------------------------------------------- + Sort + Sort Key: (l2sq_dist(v, '{1,1,1}'::real[])) + -> Seq Scan on small_world +(3 rows) + +SELECT * from small_world ORDER BY v <-> '{1,1,1}'; + id | b | v +-----+---+--------- + 111 | t | {1,1,1} + 101 | f | {1,0,1} + 110 | f | {1,1,0} + 011 | t | {0,1,1} + 100 | f | {1,0,0} + 001 | t | {0,0,1} + 010 | f | {0,1,0} + 000 | t | {0,0,0} +(8 rows) + +begin; +INSERT INTO test2 (v) VALUES ('{1,4}'); +INSERT INTO test2 (v) VALUES ('{2,4}'); +CREATE INDEX test2_cos ON test2 USING hnsw(v dist_cos_ops); +INFO: done init usearch index +INFO: inserted 3 elements +INFO: done saving 3 vectors +EXPLAIN (COSTS false) SELECT * from test2 ORDER BY v <-> '{1,4}'; + QUERY PLAN +-------------------------------------------- + Sort + Sort Key: (cos_dist(v, '{1,4}'::real[])) + -> Seq Scan on test2 +(3 rows) + +-- Some additional cases that trigger operator rewriting +-- SampleScan +EXPLAIN (COSTS false) SELECT * FROM small_world TABLESAMPLE BERNOULLI (20) ORDER BY v <-> '{1,1,1}' ASC; + QUERY PLAN +----------------------------------------------- + Sort + Sort Key: (l2sq_dist(v, '{1,1,1}'::real[])) + -> Sample Scan on small_world + Sampling: bernoulli ('20'::real) +(4 rows) + +-- can't compare direct equality here because it's random +SELECT results_match('EXPLAIN SELECT * FROM small_world TABLESAMPLE BERNOULLI (20) ORDER BY v <-> ''{1,1,1}'' ASC', + 'EXPLAIN SELECT * FROM small_world TABLESAMPLE BERNOULLI (20) ORDER BY l2sq_dist(v, ''{1,1,1}'') ASC'); + results_match +--------------- + t +(1 row) + +-- SetOpt/HashSetOp +EXPLAIN (COSTS false) (SELECT * FROM small_world ORDER BY v <-> '{1,0,1}' ASC ) EXCEPT (SELECT * FROM small_world ORDER by v <-> '{1,1,1}' ASC LIMIT 5); + QUERY PLAN +------------------------------------------------------------------------------------- + HashSetOp Except + -> Append + -> Subquery Scan on "*SELECT* 1" + -> Sort + Sort Key: (l2sq_dist(small_world.v, '{1,0,1}'::real[])) + -> Seq Scan on small_world + -> Subquery Scan on "*SELECT* 2" + -> Limit + -> Sort + Sort Key: (l2sq_dist(small_world_1.v, '{1,1,1}'::real[])) + -> Seq Scan on small_world small_world_1 +(11 rows) + +SELECT results_match('(SELECT * FROM small_world ORDER BY v <-> ''{1,0,1}'' ASC ) EXCEPT (SELECT * FROM small_world ORDER by v <-> ''{1,1,1}'' ASC LIMIT 5)', + '(SELECT * FROM small_world ORDER BY l2sq_dist(v, ''{1,0,1}'') ASC ) EXCEPT (SELECT * FROM small_world ORDER by l2sq_dist(v, ''{1,1,1}'') ASC LIMIT 5)'); + results_match +--------------- + t +(1 row) + +-- HashAggregate +EXPLAIN (COSTS false) SELECT v, COUNT(*) FROM small_world GROUP BY v ORDER BY v <-> '{1,1,1}'; + QUERY PLAN +----------------------------------------------- + Sort + Sort Key: (l2sq_dist(v, '{1,1,1}'::real[])) + -> HashAggregate + Group Key: v + -> Seq Scan on small_world +(5 rows) + +SELECT results_match('SELECT v, COUNT(*) FROM small_world GROUP BY v ORDER BY v <-> ''{1,1,1}''', + 'SELECT v, COUNT(*) FROM small_world GROUP BY v ORDER BY l2sq_dist(v, ''{1,1,1}'')'); + results_match +--------------- + t +(1 row) + +-- GroupBy this +EXPLAIN (COSTS false) SELECT * FROM small_world GROUP BY id, v, b ORDER BY v <-> '{1,1,1}'; + QUERY PLAN +----------------------------------------------- + Sort + Sort Key: (l2sq_dist(v, '{1,1,1}'::real[])) + -> HashAggregate + Group Key: id, v, b + -> Seq Scan on small_world +(5 rows) + +SELECT results_match('SELECT * FROM small_world GROUP BY id, v, b ORDER BY v <-> ''{1,1,1}''', + 'SELECT * FROM small_world GROUP BY id, v, b ORDER BY l2sq_dist(v, ''{1,1,1}'')'); + results_match +--------------- + t +(1 row) + +-- HashJoin/Hash +CREATE TABLE small_world_2 AS (SELECT * FROM small_world); +EXPLAIN (COSTS false) SELECT * FROM small_world JOIN small_world_2 using (v) ORDER BY v <-> '{1,1,1}'; + QUERY PLAN +----------------------------------------------------------- + Sort + Sort Key: (l2sq_dist(small_world.v, '{1,1,1}'::real[])) + -> Hash Join + Hash Cond: (small_world_2.v = small_world.v) + -> Seq Scan on small_world_2 + -> Hash + -> Seq Scan on small_world +(7 rows) + +SELECT results_match('SELECT * FROM small_world JOIN small_world_2 using (v) ORDER BY v <-> ''{1,1,1}''', + 'SELECT * FROM small_world JOIN small_world_2 using (v) ORDER BY l2sq_dist(v, ''{1,1,1}'')'); + results_match +--------------- + t +(1 row) + +-- MixedAggregate (this doesn't require additional logic, but I include it here as an example of generating the path) +EXPLAIN (COSTS false) SELECT v FROM small_world GROUP BY ROLLUP(v) ORDER BY v <-> '{1,1,1}'; + QUERY PLAN +----------------------------------------------- + Sort + Sort Key: (l2sq_dist(v, '{1,1,1}'::real[])) + -> MixedAggregate + Hash Key: v + Group Key: () + -> Seq Scan on small_world +(6 rows) + +SELECT results_match('SELECT v FROM small_world GROUP BY ROLLUP(v) ORDER BY v <-> ''{1,1,1}''', + 'SELECT v FROM small_world GROUP BY ROLLUP(v) ORDER BY l2sq_dist(v, ''{1,1,1}'')'); + results_match +--------------- + t +(1 row) + +-- WindowAgg +EXPLAIN (COSTS false) SELECT v, EVERY(b) OVER () FROM small_world ORDER BY v <-> '{1,1,1}'; + QUERY PLAN +----------------------------------------------- + Sort + Sort Key: (l2sq_dist(v, '{1,1,1}'::real[])) + -> WindowAgg + -> Seq Scan on small_world +(4 rows) + +SELECT results_match('SELECT v, EVERY(b) OVER () FROM small_world ORDER BY v <-> ''{1,1,1}''', + 'SELECT v, EVERY(b) OVER () FROM small_world ORDER BY l2sq_dist(v, ''{1,1,1}'')'); + results_match +--------------- + t +(1 row) + +-- LockRows +EXPLAIN (COSTS false) SELECT * FROM small_world ORDER BY v <-> '{1,1,1}' ASC FOR UPDATE; + QUERY PLAN +----------------------------------------------------- + LockRows + -> Sort + Sort Key: (l2sq_dist(v, '{1,1,1}'::real[])) + -> Seq Scan on small_world +(4 rows) + +SELECT results_match('SELECT * FROM small_world ORDER BY v <-> ''{1,1,1}'' ASC FOR UPDATE', + 'SELECT * FROM small_world ORDER BY l2sq_dist(v, ''{1,1,1}'') ASC FOR UPDATE'); + results_match +--------------- + t +(1 row) + +rollback; +set enable_indexscan = true; +set enable_seqscan = false; -- todo:: Verify joins work and still use index -- todo:: Verify incremental sorts work -- Validate index data structures diff --git a/test/expected/hnsw_todo.out b/test/expected/hnsw_todo.out index cb3a610cd..e65164191 100644 --- a/test/expected/hnsw_todo.out +++ b/test/expected/hnsw_todo.out @@ -38,7 +38,15 @@ EXPLAIN (COSTS FALSE) SELECT id, ROUND(l2sq_dist(vector_int, array[0,1,0])::numeric, 2) as dist FROM small_world_l2 ORDER BY vector_int <-> array[0,1,0] LIMIT 7; -ERROR: Operator <-> can only be used inside of an index + QUERY PLAN +----------------------------------------------------------------------- + Limit + -> Result + -> Sort + Sort Key: (l2sq_dist(vector_int, '{0,1,0}'::integer[])) + -> Seq Scan on small_world_l2 +(5 rows) + --- Test scenarious --- ----------------------------------------- -- Case: diff --git a/test/sql/hnsw_create_expr.sql b/test/sql/hnsw_create_expr.sql index cae3ad888..9ee5f4aac 100644 --- a/test/sql/hnsw_create_expr.sql +++ b/test/sql/hnsw_create_expr.sql @@ -83,6 +83,4 @@ CREATE INDEX ON test_table USING hnsw (int_to_string(id)) WITH (M=2); -- This should result in error about multicolumn expressions support CREATE INDEX ON test_table USING hnsw (int_to_fixed_binary_real_array(id), int_to_dynamic_binary_real_array(id)) WITH (M=2); --- This currently results in an error about using the operator outside of index --- This case should be fixed SELECT id FROM test_table ORDER BY int_to_fixed_binary_real_array(id) <-> '{0,0,0}'::REAL[] LIMIT 2; diff --git a/test/sql/hnsw_select.sql b/test/sql/hnsw_select.sql index d0d0efecd..5bd9c90cf 100644 --- a/test/sql/hnsw_select.sql +++ b/test/sql/hnsw_select.sql @@ -71,6 +71,56 @@ SELECT has_index_scan('EXPLAIN (SELECT id FROM test1 ORDER BY v <-> ''{1,4}'') U -- Validate CTEs work and still use index SELECT has_index_scan('EXPLAIN WITH t AS (SELECT id FROM test1 ORDER BY v <-> ''{1,4}'') SELECT id FROM t UNION SELECT id FROM t'); +-- Validate <-> is replaced with the matching function when an index is present +set enable_seqscan = true; +set enable_indexscan = false; +EXPLAIN (COSTS false) SELECT * from small_world ORDER BY v <-> '{1,1,1}'; +SELECT * from small_world ORDER BY v <-> '{1,1,1}'; +begin; +INSERT INTO test2 (v) VALUES ('{1,4}'); +INSERT INTO test2 (v) VALUES ('{2,4}'); +CREATE INDEX test2_cos ON test2 USING hnsw(v dist_cos_ops); +EXPLAIN (COSTS false) SELECT * from test2 ORDER BY v <-> '{1,4}'; +-- Some additional cases that trigger operator rewriting +-- SampleScan +EXPLAIN (COSTS false) SELECT * FROM small_world TABLESAMPLE BERNOULLI (20) ORDER BY v <-> '{1,1,1}' ASC; +-- can't compare direct equality here because it's random +SELECT results_match('EXPLAIN SELECT * FROM small_world TABLESAMPLE BERNOULLI (20) ORDER BY v <-> ''{1,1,1}'' ASC', + 'EXPLAIN SELECT * FROM small_world TABLESAMPLE BERNOULLI (20) ORDER BY l2sq_dist(v, ''{1,1,1}'') ASC'); +-- SetOpt/HashSetOp +EXPLAIN (COSTS false) (SELECT * FROM small_world ORDER BY v <-> '{1,0,1}' ASC ) EXCEPT (SELECT * FROM small_world ORDER by v <-> '{1,1,1}' ASC LIMIT 5); +SELECT results_match('(SELECT * FROM small_world ORDER BY v <-> ''{1,0,1}'' ASC ) EXCEPT (SELECT * FROM small_world ORDER by v <-> ''{1,1,1}'' ASC LIMIT 5)', + '(SELECT * FROM small_world ORDER BY l2sq_dist(v, ''{1,0,1}'') ASC ) EXCEPT (SELECT * FROM small_world ORDER by l2sq_dist(v, ''{1,1,1}'') ASC LIMIT 5)'); +-- HashAggregate +EXPLAIN (COSTS false) SELECT v, COUNT(*) FROM small_world GROUP BY v ORDER BY v <-> '{1,1,1}'; +SELECT results_match('SELECT v, COUNT(*) FROM small_world GROUP BY v ORDER BY v <-> ''{1,1,1}''', + 'SELECT v, COUNT(*) FROM small_world GROUP BY v ORDER BY l2sq_dist(v, ''{1,1,1}'')'); +-- GroupBy this +EXPLAIN (COSTS false) SELECT * FROM small_world GROUP BY id, v, b ORDER BY v <-> '{1,1,1}'; +SELECT results_match('SELECT * FROM small_world GROUP BY id, v, b ORDER BY v <-> ''{1,1,1}''', + 'SELECT * FROM small_world GROUP BY id, v, b ORDER BY l2sq_dist(v, ''{1,1,1}'')'); +-- HashJoin/Hash +CREATE TABLE small_world_2 AS (SELECT * FROM small_world); +EXPLAIN (COSTS false) SELECT * FROM small_world JOIN small_world_2 using (v) ORDER BY v <-> '{1,1,1}'; +SELECT results_match('SELECT * FROM small_world JOIN small_world_2 using (v) ORDER BY v <-> ''{1,1,1}''', + 'SELECT * FROM small_world JOIN small_world_2 using (v) ORDER BY l2sq_dist(v, ''{1,1,1}'')'); +-- MixedAggregate (this doesn't require additional logic, but I include it here as an example of generating the path) +EXPLAIN (COSTS false) SELECT v FROM small_world GROUP BY ROLLUP(v) ORDER BY v <-> '{1,1,1}'; +SELECT results_match('SELECT v FROM small_world GROUP BY ROLLUP(v) ORDER BY v <-> ''{1,1,1}''', + 'SELECT v FROM small_world GROUP BY ROLLUP(v) ORDER BY l2sq_dist(v, ''{1,1,1}'')'); +-- WindowAgg +EXPLAIN (COSTS false) SELECT v, EVERY(b) OVER () FROM small_world ORDER BY v <-> '{1,1,1}'; +SELECT results_match('SELECT v, EVERY(b) OVER () FROM small_world ORDER BY v <-> ''{1,1,1}''', + 'SELECT v, EVERY(b) OVER () FROM small_world ORDER BY l2sq_dist(v, ''{1,1,1}'')'); +-- LockRows +EXPLAIN (COSTS false) SELECT * FROM small_world ORDER BY v <-> '{1,1,1}' ASC FOR UPDATE; +SELECT results_match('SELECT * FROM small_world ORDER BY v <-> ''{1,1,1}'' ASC FOR UPDATE', + 'SELECT * FROM small_world ORDER BY l2sq_dist(v, ''{1,1,1}'') ASC FOR UPDATE'); + +rollback; +set enable_indexscan = true; +set enable_seqscan = false; + -- todo:: Verify joins work and still use index -- todo:: Verify incremental sorts work diff --git a/test/sql/utils/common.sql b/test/sql/utils/common.sql index 2f93bb7e1..89ae94f35 100644 --- a/test/sql/utils/common.sql +++ b/test/sql/utils/common.sql @@ -63,3 +63,28 @@ BEGIN RETURN found; END; $$ LANGUAGE plpgsql; + +-- Determine if the two queries provided return the same results +-- At the moment this only works on queries that return rows with the same entries as one another +-- if you try to compare uneven numbers of columns or columns of different types it will generate an error +CREATE OR REPLACE FUNCTION results_match(left_query text, right_query text) RETURNS boolean AS $$ +DECLARE + left_cursor REFCURSOR; + left_row RECORD; + + right_cursor REFCURSOR; + right_row RECORD; +BEGIN + OPEN left_cursor FOR EXECUTE left_query; + OPEN right_cursor FOR EXECUTE right_query; + LOOP + FETCH NEXT FROM left_cursor INTO left_row; + FETCH NEXT FROM right_cursor INTO right_row; + IF left_row != right_row THEN + RETURN false; + ELSEIF left_row IS NULL AND right_row IS NULL THEN + RETURN true; + END IF; + END LOOP; +END; +$$ LANGUAGE plpgsql;