From 123ac81dd71aa9dbf990ca8a6cc30672951e9c24 Mon Sep 17 00:00:00 2001 From: Themis Valtinos <73662635+themisvaltinos@users.noreply.github.com> Date: Wed, 18 Dec 2024 17:39:08 +0200 Subject: [PATCH] Refactor to handle nested jinja and preserve original sql text --- sqlmesh/core/dialect.py | 107 ++++++++++++++++++++----------- sqlmesh/core/model/definition.py | 8 ++- 2 files changed, 74 insertions(+), 41 deletions(-) diff --git a/sqlmesh/core/dialect.py b/sqlmesh/core/dialect.py index 2709b9c3b..9db8de283 100644 --- a/sqlmesh/core/dialect.py +++ b/sqlmesh/core/dialect.py @@ -789,7 +789,7 @@ def _is_virtual_statement_end(tokens: t.List[Token], pos: int) -> bool: return _is_command_statement(ON_VIRTUAL_UPDATE_END, tokens, pos) -def virtual_statement(statement: exp.Expression) -> VirtualUpdateStatement: +def virtual_statement(statement: t.List[exp.Expression]) -> VirtualUpdateStatement: return VirtualUpdateStatement(this=statement) @@ -842,9 +842,10 @@ def parse( while pos < total: token = tokens[pos] if _is_virtual_statement_end(tokens, pos): - pos += 2 + chunks[-1][0].append(token) virtual = False chunks.append(([], ChunkType.SQL)) + pos += 2 elif _is_jinja_end(tokens, pos) or ( chunks[-1][1] == ChunkType.SQL and token.token_type == TokenType.SEMICOLON @@ -856,17 +857,14 @@ def parse( # Jinja end statement chunks[-1][0].append(token) pos += 2 - if virtual and tokens[pos] != ON_VIRTUAL_UPDATE_END: - # This is required for nested Jinja statements that precede - # SQL statements within an ON_VIRTUAL_UPDATE block - chunks.append( - ( - [Token(TokenType.VAR, text=ON_VIRTUAL_UPDATE_BEGIN)], - ChunkType.VIRTUAL_STATEMENT, - ) + chunks.append( + ( + [], + ChunkType.VIRTUAL_STATEMENT + if virtual and tokens[pos] != ON_VIRTUAL_UPDATE_END + else ChunkType.SQL, ) - else: - chunks.append(([], ChunkType.SQL)) + ) elif _is_jinja_query_begin(tokens, pos): chunks.append(([token], ChunkType.JINJA_QUERY)) pos += 2 @@ -874,7 +872,7 @@ def parse( chunks.append( ( [token], - ChunkType.VIRTUAL_JINJA_STATEMENT if virtual else ChunkType.JINJA_STATEMENT, + ChunkType.JINJA_STATEMENT, ) ) pos += 2 @@ -889,32 +887,65 @@ def parse( parser = dialect.parser() expressions: t.List[exp.Expression] = [] - for chunk, chunk_type in chunks: - if chunk_type == ChunkType.SQL: - parsed_expressions: t.List[t.Optional[exp.Expression]] = ( - parser.parse(chunk, sql) if into is None else parser.parse_into(into, chunk, sql) - ) - for expression in parsed_expressions: - if expression: - expression.meta["sql"] = parser._find_sql(chunk[0], chunk[-1]) - expressions.append(expression) - elif chunk_type == ChunkType.VIRTUAL_STATEMENT: - sql_chunk = chunk[1:-1] - for expression in parser.parse(sql_chunk, sql): - if expression: - expression.meta["sql"] = expression.sql(dialect=dialect) - expressions.append(virtual_statement(expression)) + def parse_sql_chunk(chunk: t.List[Token]) -> t.List[exp.Expression]: + parsed_expressions: t.List[t.Optional[exp.Expression]] = ( + parser.parse(chunk, sql) if into is None else parser.parse_into(into, chunk, sql) + ) + expressions = [] + for expression in parsed_expressions: + if expression: + expression.meta["sql"] = parser._find_sql(chunk[0], chunk[-1]) + expressions.append(expression) + return expressions + + def parse_jinja_chunk(chunk: t.List[Token]) -> exp.Expression: + start, *_, end = chunk + segment = sql[start.end + 2 : end.start - 1] + factory = jinja_query if chunk_type == ChunkType.JINJA_QUERY else jinja_statement + expression = factory(segment.strip()) + meta_sql = sql[start.start : end.end + 1] + expression.meta["sql"] = meta_sql + return expression + + def parse_virtual_statement( + chunks: t.List[t.Tuple[t.List[Token], ChunkType]], pos: int + ) -> t.Tuple[t.List[exp.Expression], int]: + # For virtual statements we need to handle both SQL and Jinja nested blocks within the chunk + virtual_update_statements = [] + start = chunks[pos][0][0].start + + while chunks[pos - 1][0] == [] or ( + chunks[pos - 1][0] and chunks[pos - 1][0][-1].text != ON_VIRTUAL_UPDATE_END + ): + chunk, chunk_type = chunks[pos] + if chunk_type == ChunkType.JINJA_STATEMENT: + virtual_update_statements.append(parse_jinja_chunk(chunk)) + else: + virtual_update_statements.extend( + parse_sql_chunk(chunk[int(chunk[0].text == "ON_VIRTUAL_UPDATE_BEGIN") : -1]) + ) + pos += 1 + + if virtual_update_statements: + statement = virtual_statement(virtual_update_statements) + end = chunk[-1].end + 1 + statement.meta["sql"] = sql[start:end] + return [statement], pos + + return [], pos + + pos = 0 + total_chunks = len(chunks) + while pos < total_chunks: + chunk, chunk_type = chunks[pos] + if chunk_type == ChunkType.VIRTUAL_STATEMENT: + virtual_expression, pos = parse_virtual_statement(chunks, pos) + expressions.extend(virtual_expression) + elif chunk_type == ChunkType.SQL: + expressions.extend(parse_sql_chunk(chunk)) else: - start, *_, end = chunk - segment = sql[start.end + 2 : end.start - 1] - factory = jinja_query if chunk_type == ChunkType.JINJA_QUERY else jinja_statement - expression = factory(segment.strip()) - expression.meta["sql"] = sql[start.start : end.end + 1] - expressions.append( - virtual_statement(expression) - if chunk_type == ChunkType.VIRTUAL_JINJA_STATEMENT - else expression - ) + expressions.append(parse_jinja_chunk(chunk)) + pos += 1 return expressions diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index db4cd3115..02d9884ae 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -2113,9 +2113,11 @@ def _split_sql_model_statements( loaded_audit = load_audit([expr, expressions[idx + 1]], dialect=dialect) assert isinstance(loaded_audit, ModelAudit) inline_audits[loaded_audit.name] = loaded_audit - idx += 1 + idx += 2 elif isinstance(expr, d.VirtualUpdateStatement): - on_virtual_update.append(expr.this) + for statement in expr.this: + on_virtual_update.append(statement) + idx += 1 else: if ( isinstance(expr, (exp.Query, d.JinjaQuery)) @@ -2127,7 +2129,7 @@ def _split_sql_model_statements( ): query_positions.append((expr, idx)) sql_statements.append(expr) - idx += 1 + idx += 1 if not query_positions: return None, sql_statements, [], on_virtual_update, inline_audits