Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integer array type conversion in array_dist to compute hamming distance #191

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 27 additions & 15 deletions src/hnsw.c
Original file line number Diff line number Diff line change
Expand Up @@ -291,10 +291,26 @@ static float4 array_dist(ArrayType *a, ArrayType *b, usearch_metric_kind_t metri
elog(ERROR, "expected equally sized arrays but got arrays with dimensions %d and %d", a_dim, b_dim);
}

float4 *ax = (float4 *)ARR_DATA_PTR(a);
float4 *bx = (float4 *)ARR_DATA_PTR(b);
float4 result;
bool is_int_array = (metric_kind == usearch_metric_hamming_k);

return usearch_dist(ax, bx, metric_kind, a_dim, usearch_scalar_f32_k);
if(is_int_array) {
int32 *ax_int = (int32 *)ARR_DATA_PTR(a);
int32 *bx_int = (int32 *)ARR_DATA_PTR(b);

// passing usearch_scalar_f32_k here even though it's an integer array is fine:
// the hamming distance in usearch actually ignores the scalar type,
// and the data will be cast appropriately in usearch even with this scalar type
result = usearch_dist(ax_int, bx_int, metric_kind, a_dim, usearch_scalar_f32_k);

} else {
float4 *ax = (float4 *)ARR_DATA_PTR(a);
float4 *bx = (float4 *)ARR_DATA_PTR(b);

result = usearch_dist(ax, bx, metric_kind, a_dim, usearch_scalar_f32_k);
}

return result;
}

static float8 vector_dist(Vector *a, Vector *b, usearch_metric_kind_t metric_kind)
Expand Down Expand Up @@ -330,7 +346,7 @@ Datum hamming_dist(PG_FUNCTION_ARGS)
{
ArrayType *a = PG_GETARG_ARRAYTYPE_P(0);
ArrayType *b = PG_GETARG_ARRAYTYPE_P(1);
PG_RETURN_INT32(array_dist(a, b, usearch_metric_hamming_k));
PG_RETURN_INT32((int32)array_dist(a, b, usearch_metric_hamming_k));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This returns the distance between the vectors, not the vectors themselves. Why should this distance be an integer?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hamming distance is always an integer, and we also pass it to PG_RETURN_INT32, which returns an integer in Postgres. So I added this explicit cast, since array_dist returns a float4.

}

PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_l2sq_dist);
Expand Down Expand Up @@ -371,36 +387,32 @@ HnswColumnType GetIndexColumnType(Relation index)
}

/*
* Given vector data and vector type, convert it to a float4 array
* Given vector data and vector type, read it as either a float4 or int32 array and return as void*
*/
float4 *DatumGetSizedFloatArray(Datum datum, HnswColumnType type, int dimensions)
void *DatumGetSizedArray(Datum datum, HnswColumnType type, int dimensions)
{
if(type == VECTOR) {
Vector *vector = DatumGetVector(datum);
if(vector->dim != dimensions) {
elog(ERROR, "Expected vector with dimension %d, got %d", dimensions, vector->dim);
}
return vector->x;
return (void *)vector->x;
} else if(type == REAL_ARRAY) {
ArrayType *array = DatumGetArrayTypePCopy(datum);
int array_dim = ArrayGetNItems(ARR_NDIM(array), ARR_DIMS(array));
if(array_dim != dimensions) {
elog(ERROR, "Expected real array with dimension %d, got %d", dimensions, array_dim);
}
return (float4 *)ARR_DATA_PTR(array);
return (void *)((float4 *)ARR_DATA_PTR(array));
} else if(type == INT_ARRAY) {
ArrayType *array = DatumGetArrayTypePCopy(datum);
int array_dim = ArrayGetNItems(ARR_NDIM(array), ARR_DIMS(array));
if(array_dim != dimensions) {
elog(ERROR, "Expected int array with dimension %d, got %d", dimensions, array_dim);
}
int *intArray = (int *)ARR_DATA_PTR(array);
float4 *floatArray = (float4 *)palloc(sizeof(float) * array_dim);
for(int i = 0; i < array_dim; i++) {
floatArray[ i ] = (float)intArray[ i ];
}
// todo:: free this array
return floatArray;

int32 *intArray = (int32 *)ARR_DATA_PTR(array);
return (void *)intArray;
} else {
elog(ERROR, "Unsupported type");
}
Expand Down
2 changes: 1 addition & 1 deletion src/hnsw.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ PGDLLEXPORT Datum cos_dist(PG_FUNCTION_ARGS);

HnswColumnType GetColumnTypeFromOid(Oid oid);
HnswColumnType GetIndexColumnType(Relation index);
float4 *DatumGetSizedFloatArray(Datum datum, HnswColumnType type, int dimensions);
void *DatumGetSizedArray(Datum datum, HnswColumnType type, int dimensions);

#define LDB_UNUSED(x) (void)(x)

Expand Down
7 changes: 4 additions & 3 deletions src/hnsw/build.c
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,17 @@ static void AddTupleToUsearchIndex(ItemPointer tid, Datum *values, HnswBuildStat
usearch_error_t error = NULL;
Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[ 0 ]));
usearch_scalar_kind_t usearch_scalar;
float4 *vector = DatumGetSizedFloatArray(value, buildstate->columnType, buildstate->dimensions);

void *vector = DatumGetSizedArray(value, buildstate->columnType, buildstate->dimensions);
switch(buildstate->columnType) {
case REAL_ARRAY:
case VECTOR:
usearch_scalar = usearch_scalar_f32_k;
break;
case INT_ARRAY:
// q:: I think in this case we need to do a type conversion from int to float
// before passing the buffer to usearch
// this is fine, since we only use integer arrays with hamming distance metric
// and hamming distance in usearch doesn't care about scalar type
// also, usearch will appropriately cast integer arrays even with this scalar type
usearch_scalar = usearch_scalar_f32_k;
break;
default:
Expand Down
2 changes: 1 addition & 1 deletion src/hnsw/insert.c
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ bool ldb_aminsert(Relation index,
assert(!error);

datum = PointerGetDatum(PG_DETOAST_DATUM(values[ 0 ]));
float4 *vector = DatumGetSizedFloatArray(datum, insertstate->columnType, opts.dimensions);
void *vector = DatumGetSizedArray(datum, insertstate->columnType, opts.dimensions);

#if LANTERNDB_COPYNODES
// currently not fully ported to the latest changes
Expand Down
8 changes: 4 additions & 4 deletions src/hnsw/scan.c
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ bool ldb_amgettuple(IndexScanDesc scan, ScanDirection dir)
if(scanstate->first) {
int num_returned;
Datum value;
float4 *vec;
void *vec;
usearch_error_t error = NULL;
int k = ldb_hnsw_init_k;

Expand All @@ -183,7 +183,7 @@ bool ldb_amgettuple(IndexScanDesc scan, ScanDirection dir)
Assert(!VARATT_IS_COMPRESSED(DatumGetPointer(value)));
Assert(!VARATT_IS_EXTENDED(DatumGetPointer(value)));

vec = DatumGetSizedFloatArray(value, scanstate->columnType, scanstate->dimensions);
vec = DatumGetSizedArray(value, scanstate->columnType, scanstate->dimensions);

if(scanstate->distances == NULL) {
scanstate->distances = palloc(k * sizeof(float));
Expand All @@ -209,7 +209,7 @@ bool ldb_amgettuple(IndexScanDesc scan, ScanDirection dir)
if(scanstate->current == scanstate->count) {
int num_returned;
Datum value;
float4 *vec;
void *vec;
usearch_error_t error = NULL;
int k = scanstate->count * 2;
int index_size = usearch_size(scanstate->usearch_index, &error);
Expand All @@ -221,7 +221,7 @@ bool ldb_amgettuple(IndexScanDesc scan, ScanDirection dir)

value = scan->orderByData->sk_argument;

vec = DatumGetSizedFloatArray(value, scanstate->columnType, scanstate->dimensions);
vec = DatumGetSizedArray(value, scanstate->columnType, scanstate->dimensions);

/* double k and reallocate arrays to account for increased size */
scanstate->distances = repalloc(scanstate->distances, k * sizeof(float));
Expand Down
21 changes: 20 additions & 1 deletion test/expected/hnsw_dist_func.out
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ INFO: inserted 0 elements
INFO: done saving 0 vectors
INSERT INTO small_world_l2 SELECT id, v FROM small_world;
INSERT INTO small_world_cos SELECT id, v FROM small_world;
INSERT INTO small_world_ham SELECT id, v FROM small_world;
INSERT INTO small_world_ham SELECT id, ARRAY[CAST(v[1] AS INTEGER), CAST(v[2] AS INTEGER), CAST(v[3] AS INTEGER)] FROM small_world;
SET enable_seqscan = false;
-- Verify that the distance functions work (check distances)
SELECT ROUND(l2sq_dist(v, '{0,1,0}')::numeric, 2) FROM small_world_l2 ORDER BY v <-> '{0,1,0}';
Expand Down Expand Up @@ -220,3 +220,22 @@ WITH t AS (SELECT id FROM test1 ORDER BY v <-> '{1,2}' LIMIT 1) SELECT id, COUNT
ERROR: Operator <-> can only be used inside of an index
WITH t AS (SELECT id FROM test1 ORDER BY v <-> '{1,2}') SELECT id FROM t UNION SELECT id FROM t;
ERROR: Operator <-> can only be used inside of an index
-- Check that hamming distance query results are sorted correctly
CREATE TABLE extra_small_world_ham (
id SERIAL PRIMARY KEY,
v INT[2]
);
INSERT INTO extra_small_world_ham (v) VALUES ('{0,0}'), ('{1,1}'), ('{2,2}'), ('{3,3}');
CREATE INDEX ON extra_small_world_ham USING hnsw (v dist_hamming_ops) WITH (dim=2);
INFO: done init usearch index
INFO: inserted 4 elements
INFO: done saving 4 vectors
SELECT ROUND(hamming_dist(v, '{0,0}')::numeric, 2) FROM extra_small_world_ham ORDER BY v <-> '{0,0}';
round
-------
0.00
2.00
2.00
4.00
(4 rows)

19 changes: 0 additions & 19 deletions test/expected/hnsw_todo.out
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,6 @@ SELECT id, ROUND(l2sq_dist(vector_int, array[0,1,0])::numeric, 2) as dist
FROM small_world_l2
ORDER BY vector_int <-> array[0,1,0] LIMIT 7;
ERROR: Operator <-> can only be used inside of an index
-- this result is not sorted correctly
CREATE TABLE small_world_ham (
id SERIAL PRIMARY KEY,
v INT[2]
);
INSERT INTO small_world_ham (v) VALUES ('{0,0}'), ('{1,1}'), ('{2,2}'), ('{3,3}');
CREATE INDEX ON small_world_ham USING hnsw (v dist_hamming_ops) WITH (dim=2);
INFO: done init usearch index
INFO: inserted 4 elements
INFO: done saving 4 vectors
SELECT ROUND(hamming_dist(v, '{0,0}')::numeric, 2) FROM small_world_ham ORDER BY v <-> '{0,0}';
round
-------
0.00
2.00
4.00
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there still anything wrong with this query?
If not, it should be removed from hnsw_todo.sql (and moved to hnsw_dist_func.sql, unless an equivalent test is already there).
If something is still wrong, please add a relevant comment.

2.00
(4 rows)

--- Test scenarious ---
-----------------------------------------
-- Case:
Expand Down
13 changes: 11 additions & 2 deletions test/sql/hnsw_dist_func.sql
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ CREATE INDEX ON small_world_ham USING hnsw (v dist_hamming_ops) WITH (dim=3);

INSERT INTO small_world_l2 SELECT id, v FROM small_world;
INSERT INTO small_world_cos SELECT id, v FROM small_world;
INSERT INTO small_world_ham SELECT id, v FROM small_world;
INSERT INTO small_world_ham SELECT id, ARRAY[CAST(v[1] AS INTEGER), CAST(v[2] AS INTEGER), CAST(v[3] AS INTEGER)] FROM small_world;

SET enable_seqscan = false;

Expand Down Expand Up @@ -88,4 +88,13 @@ SELECT 1 FROM test1 ORDER BY v <-> (SELECT '{1,3}'::real[]);
SELECT t2_results.id FROM test1 t1 JOIN LATERAL (SELECT t2.id FROM test2 t2 ORDER BY t1.v <-> t2.v LIMIT 1) t2_results ON TRUE;
WITH t AS (SELECT id FROM test1 ORDER BY v <-> '{1,2}' LIMIT 1) SELECT DISTINCT id FROM t;
WITH t AS (SELECT id FROM test1 ORDER BY v <-> '{1,2}' LIMIT 1) SELECT id, COUNT(*) FROM t GROUP BY 1;
WITH t AS (SELECT id FROM test1 ORDER BY v <-> '{1,2}') SELECT id FROM t UNION SELECT id FROM t;
WITH t AS (SELECT id FROM test1 ORDER BY v <-> '{1,2}') SELECT id FROM t UNION SELECT id FROM t;

-- Check that hamming distance query results are sorted correctly
CREATE TABLE extra_small_world_ham (
id SERIAL PRIMARY KEY,
v INT[2]
);
INSERT INTO extra_small_world_ham (v) VALUES ('{0,0}'), ('{1,1}'), ('{2,2}'), ('{3,3}');
CREATE INDEX ON extra_small_world_ham USING hnsw (v dist_hamming_ops) WITH (dim=2);
SELECT ROUND(hamming_dist(v, '{0,0}')::numeric, 2) FROM extra_small_world_ham ORDER BY v <-> '{0,0}';
9 changes: 0 additions & 9 deletions test/sql/hnsw_todo.sql
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,6 @@ SELECT id, ROUND(l2sq_dist(vector_int, array[0,1,0])::numeric, 2) as dist
FROM small_world_l2
ORDER BY vector_int <-> array[0,1,0] LIMIT 7;

-- this result is not sorted correctly
CREATE TABLE small_world_ham (
id SERIAL PRIMARY KEY,
v INT[2]
);
INSERT INTO small_world_ham (v) VALUES ('{0,0}'), ('{1,1}'), ('{2,2}'), ('{3,3}');
CREATE INDEX ON small_world_ham USING hnsw (v dist_hamming_ops) WITH (dim=2);
SELECT ROUND(hamming_dist(v, '{0,0}')::numeric, 2) FROM small_world_ham ORDER BY v <-> '{0,0}';

--- Test scenarious ---
-----------------------------------------
-- Case:
Expand Down