Skip to content

Commit

Permalink
Fixed #34955 -- Made Concat() use || operator on PostgreSQL.
Browse files Browse the repository at this point in the history
This also avoids casting string based expressions in Concat() on
PostgreSQL.
  • Loading branch information
charettes authored and felixxm committed Nov 14, 2023
1 parent bdf30b9 commit 6364b6e
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 15 deletions.
32 changes: 18 additions & 14 deletions django/db/models/functions/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class ConcatPair(Func):

function = "CONCAT"

def as_sqlite(self, compiler, connection, **extra_context):
def pipes_concat_sql(self, compiler, connection, **extra_context):
coalesced = self.coalesce()
return super(ConcatPair, coalesced).as_sql(
compiler,
Expand All @@ -83,19 +83,19 @@ def as_sqlite(self, compiler, connection, **extra_context):
**extra_context,
)

as_sqlite = pipes_concat_sql

def as_postgresql(self, compiler, connection, **extra_context):
copy = self.copy()
copy.set_source_expressions(
c = self.copy()
c.set_source_expressions(
[
Cast(expression, TextField())
for expression in copy.get_source_expressions()
expression
if isinstance(expression.output_field, (CharField, TextField))
else Cast(expression, TextField())
for expression in c.get_source_expressions()
]
)
return super(ConcatPair, copy).as_sql(
compiler,
connection,
**extra_context,
)
return c.pipes_concat_sql(compiler, connection, **extra_context)

def as_mysql(self, compiler, connection, **extra_context):
# Use CONCAT_WS with an empty separator so that NULLs are ignored.
Expand Down Expand Up @@ -132,16 +132,20 @@ class Concat(Func):
def __init__(self, *expressions, **extra):
if len(expressions) < 2:
raise ValueError("Concat must take at least two expressions")
paired = self._paired(expressions)
paired = self._paired(expressions, output_field=extra.get("output_field"))
super().__init__(paired, **extra)

def _paired(self, expressions):
def _paired(self, expressions, output_field):
# wrap pairs of expressions in successive concat functions
# exp = [a, b, c, d]
# -> ConcatPair(a, ConcatPair(b, ConcatPair(c, d))))
if len(expressions) == 2:
return ConcatPair(*expressions)
return ConcatPair(expressions[0], self._paired(expressions[1:]))
return ConcatPair(*expressions, output_field=output_field)
return ConcatPair(
expressions[0],
self._paired(expressions[1:], output_field=output_field),
output_field=output_field,
)


class Left(Func):
Expand Down
20 changes: 19 additions & 1 deletion tests/db_functions/text/test_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,10 @@ def test_mixed_char_text(self):
expected = article.title + " - " + article.text
self.assertEqual(expected.upper(), article.title_text)

@skipUnless(connection.vendor == "sqlite", "sqlite specific implementation detail.")
@skipUnless(
connection.vendor in ("sqlite", "postgresql"),
"SQLite and PostgreSQL specific implementation detail.",
)
def test_coalesce_idempotent(self):
pair = ConcatPair(V("a"), V("b"))
# Check nodes counts
Expand All @@ -89,3 +92,18 @@ def test_sql_generation_idempotency(self):
qs = Article.objects.annotate(description=Concat("title", V(": "), "summary"))
# Multiple compilations should not alter the generated query.
self.assertEqual(str(qs.query), str(qs.all().query))

def test_concat_non_str(self):
Author.objects.create(name="The Name", age=42)
with self.assertNumQueries(1) as ctx:
author = Author.objects.annotate(
name_text=Concat(
"name", V(":"), "alias", V(":"), "age", output_field=TextField()
),
).get()
self.assertEqual(author.name_text, "The Name::42")
# Only non-string columns are casted on PostgreSQL.
self.assertEqual(
ctx.captured_queries[0]["sql"].count("::text"),
1 if connection.vendor == "postgresql" else 0,
)

0 comments on commit 6364b6e

Please sign in to comment.