fix(expr-common): Coerce to Decimal(20, 0) when combining UInt64 with signed integers (#14223)

* fix(expr-common): Coerce to Decimal(20, 0) when combining UInt64 with signed integers

Previously, when combining UInt64 with any signed integer, the resulting type would be Int64, which silently loses information: values above i64::MAX cannot be represented. Now, combining UInt64 with a signed integer results in a Decimal(20, 0), which can encode every value of all 64-bit (and narrower) integer types. A short sketch of the tradeoff follows this list.

* test: Add sqllogictest for #14208

* refactor: Move unsigned integer and decimal coercion to coerce_numeric_type_to_decimal

* fix: Also handle unsigned integers when coercing to Decimal256

* fix: Coerce UInt64 combined with other unsigned integers to UInt64
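
To make the tradeoff concrete, here is a minimal standalone sketch (plain Rust, no DataFusion dependency; illustrative only) of why coercing to Int64 was lossy while a 20-digit, scale-0 decimal is not:

```rust
fn main() {
    let big: u64 = u64::MAX; // 18446744073709551615
    // The old rule effectively pushed UInt64 values through Int64; the
    // cast wraps for anything above i64::MAX, losing information:
    let lossy = big as i64;
    assert_eq!(lossy, -1);
    // Decimal128(20, 0) is backed by i128; 20 decimal digits are enough
    // to represent every u64 (and every i64) exactly:
    let exact = big as i128;
    assert_eq!(exact.to_string().len(), 20);
    println!("as i64: {lossy}, as i128-backed decimal: {exact}");
}
```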
nuno-faria authored Jan 24, 2025
1 parent 49f95af commit e0f9f65
Showing 5 changed files with 69 additions and 34 deletions.
58 changes: 33 additions & 25 deletions datafusion/expr-common/src/type_coercion/binary.rs
@@ -777,29 +777,19 @@ pub fn binary_numeric_coercion(
         (_, Float32) | (Float32, _) => Some(Float32),
         // The following match arms encode the following logic: Given the two
         // integral types, we choose the narrowest possible integral type that
-        // accommodates all values of both types. Note that some information
-        // loss is inevitable when we have a signed type and a `UInt64`, in
-        // which case we use `Int64`; i.e. the widest signed integral type.
-
-        // TODO: For i64 and u64, we can use decimal or float64
-        // Postgres has no unsigned type :(
-        // DuckDB v.0.10.0 has double (double precision floating-point number (8 bytes))
-        // for largest signed (signed sixteen-byte integer) and unsigned integer (unsigned sixteen-byte integer)
+        // accommodates all values of both types. Note that to avoid information
+        // loss when combining UInt64 with signed integers we use Decimal128(20, 0).
+        (UInt64, Int64 | Int32 | Int16 | Int8)
+        | (Int64 | Int32 | Int16 | Int8, UInt64) => Some(Decimal128(20, 0)),
+        (UInt64, _) | (_, UInt64) => Some(UInt64),
         (Int64, _)
         | (_, Int64)
-        | (UInt64, Int8)
-        | (Int8, UInt64)
-        | (UInt64, Int16)
-        | (Int16, UInt64)
-        | (UInt64, Int32)
-        | (Int32, UInt64)
         | (UInt32, Int8)
         | (Int8, UInt32)
         | (UInt32, Int16)
         | (Int16, UInt32)
         | (UInt32, Int32)
         | (Int32, UInt32) => Some(Int64),
-        (UInt64, _) | (_, UInt64) => Some(UInt64),
         (Int32, _)
         | (_, Int32)
         | (UInt16, Int16)
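
A hedged sketch of how a caller observes the new rule, assuming `binary_numeric_coercion` is reachable at `datafusion_expr_common::type_coercion::binary` and takes two `&DataType`, returning `Option<DataType>` (its parameter list is elided in the hunk header above); the expectations mirror the tests added further down in this file:

```rust
use arrow::datatypes::DataType;
use datafusion_expr_common::type_coercion::binary::binary_numeric_coercion;

fn main() {
    // UInt64 combined with a signed integer now widens to Decimal128(20, 0):
    assert_eq!(
        binary_numeric_coercion(&DataType::UInt64, &DataType::Int64),
        Some(DataType::Decimal128(20, 0))
    );
    // UInt64 combined with a narrower unsigned integer stays UInt64:
    assert_eq!(
        binary_numeric_coercion(&DataType::UInt64, &DataType::UInt8),
        Some(DataType::UInt64)
    );
}
```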
@@ -928,16 +918,16 @@ pub fn get_wider_type(lhs: &DataType, rhs: &DataType) -> Result<DataType> {
 }
 
 /// Convert the numeric data type to the decimal data type.
-/// Now, we just support the signed integer type and floating-point type.
+/// We support signed and unsigned integer types and floating-point type.
 fn coerce_numeric_type_to_decimal(numeric_type: &DataType) -> Option<DataType> {
     use arrow::datatypes::DataType::*;
     // This conversion rule is from spark
     // https://github.com/apache/spark/blob/1c81ad20296d34f137238dadd67cc6ae405944eb/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala#L127
     match numeric_type {
-        Int8 => Some(Decimal128(3, 0)),
-        Int16 => Some(Decimal128(5, 0)),
-        Int32 => Some(Decimal128(10, 0)),
-        Int64 => Some(Decimal128(20, 0)),
+        Int8 | UInt8 => Some(Decimal128(3, 0)),
+        Int16 | UInt16 => Some(Decimal128(5, 0)),
+        Int32 | UInt32 => Some(Decimal128(10, 0)),
+        Int64 | UInt64 => Some(Decimal128(20, 0)),
         // TODO if we convert the floating-point data to the decimal type, it maybe overflow.
         Float32 => Some(Decimal128(14, 7)),
         Float64 => Some(Decimal128(30, 15)),
@@ -946,16 +936,16 @@ fn coerce_numeric_type_to_decimal(numeric_type: &DataType) -> Option<DataType> {
 }
 
 /// Convert the numeric data type to the decimal data type.
-/// Now, we just support the signed integer type and floating-point type.
+/// We support signed and unsigned integer types and floating-point type.
 fn coerce_numeric_type_to_decimal256(numeric_type: &DataType) -> Option<DataType> {
     use arrow::datatypes::DataType::*;
     // This conversion rule is from spark
     // https://github.com/apache/spark/blob/1c81ad20296d34f137238dadd67cc6ae405944eb/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala#L127
     match numeric_type {
-        Int8 => Some(Decimal256(3, 0)),
-        Int16 => Some(Decimal256(5, 0)),
-        Int32 => Some(Decimal256(10, 0)),
-        Int64 => Some(Decimal256(20, 0)),
+        Int8 | UInt8 => Some(Decimal256(3, 0)),
+        Int16 | UInt16 => Some(Decimal256(5, 0)),
+        Int32 | UInt32 => Some(Decimal256(10, 0)),
+        Int64 | UInt64 => Some(Decimal256(20, 0)),
         // TODO if we convert the floating-point data to the decimal type, it maybe overflow.
         Float32 => Some(Decimal256(14, 7)),
         Float64 => Some(Decimal256(30, 15)),
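
A quick standalone check (plain Rust, not DataFusion code) of why reusing the signed-type precisions for the unsigned variants is safe: each unsigned type's maximum value fits in the precision chosen for its signed counterpart.

```rust
// Count the decimal digits needed to print an unsigned value.
fn digits(n: u128) -> usize {
    n.to_string().len()
}

fn main() {
    assert!(digits(u8::MAX as u128) <= 3); // 255 fits Decimal(3, 0)
    assert!(digits(u16::MAX as u128) <= 5); // 65535 fits Decimal(5, 0)
    assert!(digits(u32::MAX as u128) <= 10); // 4294967295 fits Decimal(10, 0)
    assert!(digits(u64::MAX as u128) <= 20); // 18446744073709551615 fits Decimal(20, 0)
}
```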
@@ -1994,6 +1984,18 @@ mod tests {
             Operator::Gt,
             DataType::UInt32
         );
+        test_coercion_binary_rule!(
+            DataType::UInt64,
+            DataType::UInt8,
+            Operator::Eq,
+            DataType::UInt64
+        );
+        test_coercion_binary_rule!(
+            DataType::UInt64,
+            DataType::Int64,
+            Operator::Eq,
+            DataType::Decimal128(20, 0)
+        );
         // numeric/decimal
         test_coercion_binary_rule!(
             DataType::Int64,
@@ -2025,6 +2027,12 @@
             Operator::GtEq,
             DataType::Decimal128(15, 3)
         );
+        test_coercion_binary_rule!(
+            DataType::UInt64,
+            DataType::Decimal128(20, 0),
+            Operator::Eq,
+            DataType::Decimal128(20, 0)
+        );
 
         // Binary
         test_coercion_binary_rule!(
6 changes: 3 additions & 3 deletions datafusion/sqllogictest/test_files/coalesce.slt
@@ -105,13 +105,13 @@ select
 ----
 2 Int64
 
-# TODO: Got types (i64, u64), casting to decimal or double or even i128 if supported
-query IT
+# i64 and u64, cast to decimal128(20, 0)
+query RT
 select
   coalesce(2, arrow_cast(3, 'UInt64')),
   arrow_typeof(coalesce(2, arrow_cast(3, 'UInt64')));
 ----
-2 Int64
+2 Decimal128(20, 0)
 
 statement ok
 create table t1 (a bigint, b int, c int) as values (null, null, 1), (null, 2, null);
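
For readers unfamiliar with the sqllogictest directives: the letters after `query` declare the expected column types, where (to the best of my knowledge) `I` means integer, `R` means real (floating point or decimal), and `T` means text. The coercion to Decimal128 is what flips `IT` to `RT` here, and likewise `TII` to `TIR` in joins.slt and `TI` to `TR` in union.slt below.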
27 changes: 27 additions & 0 deletions datafusion/sqllogictest/test_files/insert.slt
@@ -433,3 +433,30 @@ drop table test_column_defaults
 
 statement error DataFusion error: Error during planning: Column reference is not allowed in the DEFAULT expression : Schema error: No field named a.
 create table test_column_defaults(a int, b int default a+1)
+
+
+# test inserting UInt64 and signed integers into a bigint unsigned column
+statement ok
+create table unsigned_bigint_test (v bigint unsigned)
+
+query I
+insert into unsigned_bigint_test values (10000000000000000000), (18446744073709551615)
+----
+2
+
+query I
+insert into unsigned_bigint_test values (10000000000000000001), (1), (10000000000000000002)
+----
+3
+
+query I rowsort
+select * from unsigned_bigint_test
+----
+1
+10000000000000000000
+10000000000000000001
+10000000000000000002
+18446744073709551615
+
+statement ok
+drop table unsigned_bigint_test
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/test_files/joins.slt
@@ -1828,7 +1828,7 @@ AS VALUES
 ('BB', 6, 1),
 ('BB', 6, 1);
 
-query TII
+query TIR
 select col1, col2, coalesce(sum_col3, 0) as sum_col3
 from (select distinct col2 from tbl) AS q1
 cross join (select distinct col1 from tbl) AS q2
10 changes: 5 additions & 5 deletions datafusion/sqllogictest/test_files/union.slt
@@ -413,23 +413,23 @@ explain SELECT c1, c9 FROM aggregate_test_100 UNION ALL SELECT c1, c3 FROM aggre
 logical_plan
 01)Sort: aggregate_test_100.c9 DESC NULLS FIRST, fetch=5
 02)--Union
-03)----Projection: aggregate_test_100.c1, CAST(aggregate_test_100.c9 AS Int64) AS c9
+03)----Projection: aggregate_test_100.c1, CAST(aggregate_test_100.c9 AS Decimal128(20, 0)) AS c9
 04)------TableScan: aggregate_test_100 projection=[c1, c9]
-05)----Projection: aggregate_test_100.c1, CAST(aggregate_test_100.c3 AS Int64) AS c9
+05)----Projection: aggregate_test_100.c1, CAST(aggregate_test_100.c3 AS Decimal128(20, 0)) AS c9
 06)------TableScan: aggregate_test_100 projection=[c1, c3]
 physical_plan
 01)SortPreservingMergeExec: [c9@1 DESC], fetch=5
 02)--UnionExec
 03)----SortExec: TopK(fetch=5), expr=[c9@1 DESC], preserve_partitioning=[true]
-04)------ProjectionExec: expr=[c1@0 as c1, CAST(c9@1 AS Int64) as c9]
+04)------ProjectionExec: expr=[c1@0 as c1, CAST(c9@1 AS Decimal128(20, 0)) as c9]
 05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
 06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c9], has_header=true
 07)----SortExec: TopK(fetch=5), expr=[c9@1 DESC], preserve_partitioning=[true]
-08)------ProjectionExec: expr=[c1@0 as c1, CAST(c3@1 AS Int64) as c9]
+08)------ProjectionExec: expr=[c1@0 as c1, CAST(c3@1 AS Decimal128(20, 0)) as c9]
 09)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
 10)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c3], has_header=true
 
-query TI
+query TR
 SELECT c1, c9 FROM aggregate_test_100 UNION ALL SELECT c1, c3 FROM aggregate_test_100 ORDER BY c9 DESC LIMIT 5
 ----
 c 4268716378
