Skip to content

Commit

Permalink
Refactor to handle nested jinja and preserve original sql text
Browse files Browse the repository at this point in the history
  • Loading branch information
themisvaltinos committed Dec 18, 2024
1 parent b0b8fed commit 123ac81
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 41 deletions.
107 changes: 69 additions & 38 deletions sqlmesh/core/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,7 +789,7 @@ def _is_virtual_statement_end(tokens: t.List[Token], pos: int) -> bool:
return _is_command_statement(ON_VIRTUAL_UPDATE_END, tokens, pos)


def virtual_statement(statement: exp.Expression) -> VirtualUpdateStatement:
def virtual_statement(statement: t.List[exp.Expression]) -> VirtualUpdateStatement:
return VirtualUpdateStatement(this=statement)


Expand Down Expand Up @@ -842,9 +842,10 @@ def parse(
while pos < total:
token = tokens[pos]
if _is_virtual_statement_end(tokens, pos):
pos += 2
chunks[-1][0].append(token)
virtual = False
chunks.append(([], ChunkType.SQL))
pos += 2
elif _is_jinja_end(tokens, pos) or (
chunks[-1][1] == ChunkType.SQL
and token.token_type == TokenType.SEMICOLON
Expand All @@ -856,25 +857,22 @@ def parse(
# Jinja end statement
chunks[-1][0].append(token)
pos += 2
if virtual and tokens[pos] != ON_VIRTUAL_UPDATE_END:
# This is required for nested Jinja statements that precede
# SQL statements within an ON_VIRTUAL_UPDATE block
chunks.append(
(
[Token(TokenType.VAR, text=ON_VIRTUAL_UPDATE_BEGIN)],
ChunkType.VIRTUAL_STATEMENT,
)
chunks.append(
(
[],
ChunkType.VIRTUAL_STATEMENT
if virtual and tokens[pos] != ON_VIRTUAL_UPDATE_END
else ChunkType.SQL,
)
else:
chunks.append(([], ChunkType.SQL))
)
elif _is_jinja_query_begin(tokens, pos):
chunks.append(([token], ChunkType.JINJA_QUERY))
pos += 2
elif _is_jinja_statement_begin(tokens, pos):
chunks.append(
(
[token],
ChunkType.VIRTUAL_JINJA_STATEMENT if virtual else ChunkType.JINJA_STATEMENT,
ChunkType.JINJA_STATEMENT,
)
)
pos += 2
Expand All @@ -889,32 +887,65 @@ def parse(
parser = dialect.parser()
expressions: t.List[exp.Expression] = []

for chunk, chunk_type in chunks:
if chunk_type == ChunkType.SQL:
parsed_expressions: t.List[t.Optional[exp.Expression]] = (
parser.parse(chunk, sql) if into is None else parser.parse_into(into, chunk, sql)
)
for expression in parsed_expressions:
if expression:
expression.meta["sql"] = parser._find_sql(chunk[0], chunk[-1])
expressions.append(expression)
elif chunk_type == ChunkType.VIRTUAL_STATEMENT:
sql_chunk = chunk[1:-1]
for expression in parser.parse(sql_chunk, sql):
if expression:
expression.meta["sql"] = expression.sql(dialect=dialect)
expressions.append(virtual_statement(expression))
def parse_sql_chunk(chunk: t.List[Token]) -> t.List[exp.Expression]:
parsed_expressions: t.List[t.Optional[exp.Expression]] = (
parser.parse(chunk, sql) if into is None else parser.parse_into(into, chunk, sql)
)
expressions = []
for expression in parsed_expressions:
if expression:
expression.meta["sql"] = parser._find_sql(chunk[0], chunk[-1])
expressions.append(expression)
return expressions

def parse_jinja_chunk(chunk: t.List[Token]) -> exp.Expression:
start, *_, end = chunk
segment = sql[start.end + 2 : end.start - 1]
factory = jinja_query if chunk_type == ChunkType.JINJA_QUERY else jinja_statement
expression = factory(segment.strip())
meta_sql = sql[start.start : end.end + 1]
expression.meta["sql"] = meta_sql
return expression

def parse_virtual_statement(
chunks: t.List[t.Tuple[t.List[Token], ChunkType]], pos: int
) -> t.Tuple[t.List[exp.Expression], int]:
# For virtual statements we need to handle both SQL and Jinja nested blocks within the chunk
virtual_update_statements = []
start = chunks[pos][0][0].start

while chunks[pos - 1][0] == [] or (
chunks[pos - 1][0] and chunks[pos - 1][0][-1].text != ON_VIRTUAL_UPDATE_END
):
chunk, chunk_type = chunks[pos]
if chunk_type == ChunkType.JINJA_STATEMENT:
virtual_update_statements.append(parse_jinja_chunk(chunk))
else:
virtual_update_statements.extend(
parse_sql_chunk(chunk[int(chunk[0].text == "ON_VIRTUAL_UPDATE_BEGIN") : -1])
)
pos += 1

if virtual_update_statements:
statement = virtual_statement(virtual_update_statements)
end = chunk[-1].end + 1
statement.meta["sql"] = sql[start:end]
return [statement], pos

return [], pos

pos = 0
total_chunks = len(chunks)
while pos < total_chunks:
chunk, chunk_type = chunks[pos]
if chunk_type == ChunkType.VIRTUAL_STATEMENT:
virtual_expression, pos = parse_virtual_statement(chunks, pos)
expressions.extend(virtual_expression)
elif chunk_type == ChunkType.SQL:
expressions.extend(parse_sql_chunk(chunk))
else:
start, *_, end = chunk
segment = sql[start.end + 2 : end.start - 1]
factory = jinja_query if chunk_type == ChunkType.JINJA_QUERY else jinja_statement
expression = factory(segment.strip())
expression.meta["sql"] = sql[start.start : end.end + 1]
expressions.append(
virtual_statement(expression)
if chunk_type == ChunkType.VIRTUAL_JINJA_STATEMENT
else expression
)
expressions.append(parse_jinja_chunk(chunk))
pos += 1

return expressions

Expand Down
8 changes: 5 additions & 3 deletions sqlmesh/core/model/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -2113,9 +2113,11 @@ def _split_sql_model_statements(
loaded_audit = load_audit([expr, expressions[idx + 1]], dialect=dialect)
assert isinstance(loaded_audit, ModelAudit)
inline_audits[loaded_audit.name] = loaded_audit
idx += 1
idx += 2
elif isinstance(expr, d.VirtualUpdateStatement):
on_virtual_update.append(expr.this)
for statement in expr.this:
on_virtual_update.append(statement)
idx += 1
else:
if (
isinstance(expr, (exp.Query, d.JinjaQuery))
Expand All @@ -2127,7 +2129,7 @@ def _split_sql_model_statements(
):
query_positions.append((expr, idx))
sql_statements.append(expr)
idx += 1
idx += 1

if not query_positions:
return None, sql_statements, [], on_virtual_update, inline_audits
Expand Down

0 comments on commit 123ac81

Please sign in to comment.