diff --git a/docs/concepts/macros/macro_variables.md b/docs/concepts/macros/macro_variables.md
index 0774dd3fa..3fab383b5 100644
--- a/docs/concepts/macros/macro_variables.md
+++ b/docs/concepts/macros/macro_variables.md
@@ -128,6 +128,6 @@ SQLMesh provides two other predefined variables used to modify model behavior ba
 * 'evaluating' - The model query logic is being evaluated.
 * 'testing' - The model query logic is being evaluated in the context of a unit test.
 * @gateway - A string value containing the name of the current [gateway](../../guides/connections.md).
-* @this_model - A string value containing the name of the physical table the model view selects from. Typically used to create [generic audits](../audits.md#generic-audits).
+* @this_model - A string value containing the name of the physical table the model view selects from. Typically used to create [generic audits](../audits.md#generic-audits). In the case of [on_virtual_update statements](../models/sql_models.md#optional-on-virtual-update-statements), it contains the qualified view name instead.
 * Can be used in model definitions when SQLGlot cannot fully parse a statement and you need to reference the model's underlying physical table directly.
 * Can be passed as an argument to macros that access or interact with the underlying physical table.
diff --git a/docs/concepts/models/python_models.md b/docs/concepts/models/python_models.md
index 3b0faade3..a66a5649f 100644
--- a/docs/concepts/models/python_models.md
+++ b/docs/concepts/models/python_models.md
@@ -164,6 +164,41 @@ def execute(
     context.fetchdf("CREATE INDEX idx ON example.pre_post_statements (id);")
 ```
+## Optional on-virtual-update statements
+
+The optional on-virtual-update statements allow you to execute SQL commands after the completion of the [Virtual Update](#virtual-update).
+
+These can be used, for example, to grant privileges on views of the virtual layer.
+
+Similar to pre/post-statements, you can set the `on_virtual_update` argument in the `@model` decorator to a list of SQL strings, SQLGlot expressions, or macro calls.
+
+``` python linenums="1" hl_lines="8"
+@model(
+    "db.test_model",
+    kind="full",
+    columns={
+        "id": "int",
+        "name": "text",
+    },
+    on_virtual_update=["GRANT SELECT ON VIEW @this_model TO ROLE dev_role"],
+)
+def execute(
+    context: ExecutionContext,
+    start: datetime,
+    end: datetime,
+    execution_time: datetime,
+    **kwargs: t.Any,
+) -> pd.DataFrame:
+
+    return pd.DataFrame([
+        {"id": 1, "name": "name"}
+    ])
+```
+
+!!! note
+
+    Table resolution for these statements occurs at the virtual layer. This means that table names, including the `@this_model` macro, are resolved to their qualified view names. For instance, when running the plan in an environment named `dev`, `db.test_model` and `@this_model` would resolve to `db__dev.test_model` and not to the physical table name.
+
 ## Dependencies
 
 In order to fetch data from an upstream model, you first get the table name using `context`'s `resolve_table` method. This returns the appropriate table name for the current runtime [environment](../environments.md):
diff --git a/docs/concepts/models/seed_models.md b/docs/concepts/models/seed_models.md
index d1970f958..ec3e6a128 100644
--- a/docs/concepts/models/seed_models.md
+++ b/docs/concepts/models/seed_models.md
@@ -194,3 +194,32 @@ ALTER SESSION SET TIMEZONE = 'UTC';
 -- These are post-statements
 ALTER SESSION SET TIMEZONE = 'PST';
 ```
+
+## On-virtual-update statements
+
+Seed models also support on-virtual-update statements, which are executed after the completion of the [Virtual Update](#virtual-update).
+
+These must be enclosed within an `ON_VIRTUAL_UPDATE_BEGIN;` ...; `ON_VIRTUAL_UPDATE_END;` block:
+
+```sql linenums="1" hl_lines="8-13"
+MODEL (
+  name test_db.national_holidays,
+  kind SEED (
+    path 'national_holidays.csv'
+  )
+);
+
+ON_VIRTUAL_UPDATE_BEGIN;
+GRANT SELECT ON VIEW @this_model TO ROLE dev_role;
+JINJA_STATEMENT_BEGIN;
+GRANT SELECT ON VIEW {{ this_model }} TO ROLE admin_role;
+JINJA_END;
+ON_VIRTUAL_UPDATE_END;
+```
+
+[Jinja expressions](../macros/jinja_macros.md) can also be used within them, as demonstrated in the example above. These expressions must be properly nested within a `JINJA_STATEMENT_BEGIN;` and `JINJA_END;` block.
+
+!!! note
+
+    Table resolution for these statements occurs at the virtual layer. This means that table names, including the `@this_model` macro, are resolved to their qualified view names. For instance, when running the plan in an environment named `dev`, `db.customers` and `@this_model` would resolve to `db__dev.customers` and not to the physical table name.
\ No newline at end of file
diff --git a/docs/concepts/models/sql_models.md b/docs/concepts/models/sql_models.md
index d5f6d910f..56e2a5955 100644
--- a/docs/concepts/models/sql_models.md
+++ b/docs/concepts/models/sql_models.md
@@ -10,6 +10,7 @@ The SQL-based definition of SQL models is the most common one, and consists of t
 * Optional pre-statements
 * A single query
 * Optional post-statements
+* Optional on-virtual-update statements
 
 These models are designed to look and feel like you're simply using SQL, but they can be customized for advanced use cases.
@@ -90,6 +91,38 @@ MODEL (
 Note that the SQL command `UNCACHE TABLE countries` inside the `@IF()` macro does **not** end with a semi-colon. Instead, the semi-colon comes after the `@IF()` macro's closing parenthesis.
+### Optional on-virtual-update statements
+
+The optional on-virtual-update statements allow you to execute SQL commands after the completion of the [Virtual Update](#virtual-update).
+
+These can be used, for example, to grant privileges on views of the virtual layer.
+
+These SQL statements must be enclosed within an `ON_VIRTUAL_UPDATE_BEGIN;` ...; `ON_VIRTUAL_UPDATE_END;` block like this:
+
+```sql linenums="1" hl_lines="10-15"
+MODEL (
+  name db.customers,
+  kind FULL
+);
+
+SELECT
+  r.id::INT
+FROM raw.restaurants AS r;
+
+ON_VIRTUAL_UPDATE_BEGIN;
+GRANT SELECT ON VIEW @this_model TO ROLE role_name;
+JINJA_STATEMENT_BEGIN;
+GRANT SELECT ON VIEW {{ this_model }} TO ROLE admin;
+JINJA_END;
+ON_VIRTUAL_UPDATE_END;
+```
+
+[Jinja expressions](../macros/jinja_macros.md) can also be used within them, as demonstrated in the example above. These expressions must be properly nested within a `JINJA_STATEMENT_BEGIN;` and `JINJA_END;` block.
+
+!!! note
+
+    Table resolution for these statements occurs at the virtual layer. This means that table names, including the `@this_model` macro, are resolved to their qualified view names. For instance, when running the plan in an environment named `dev`, `db.customers` and `@this_model` would resolve to `db__dev.customers` and not to the physical table name.
+
 ### The model query
 
 The model must contain a standalone query, which can be a single `SELECT` expression, or multiple `SELECT` expressions combined with the `UNION`, `INTERSECT`, or `EXCEPT` operators. The result of this query will be used to populate the model's table or view.
@@ -98,7 +131,7 @@ The model must contain a standalone query, which can be a single `SELECT` expres
 The Python-based definition of SQL models consists of a single python function, decorated with SQLMesh's `@model` [decorator](https://wiki.python.org/moin/PythonDecorators). The decorator is required to have the `is_sql` keyword argument set to `True` to distinguish it from [Python models](./python_models.md) that return DataFrame instances.
-This function's return value serves as the model's query, and it must be either a SQL string or a [SQLGlot expression](https://github.com/tobymao/sqlglot/blob/main/sqlglot/expressions.py). The `@model` decorator is used to define the model's [metadata](#MODEL-DDL) and, optionally its pre/post-statements that are also in the form of SQL strings or SQLGlot expressions.
+This function's return value serves as the model's query, and it must be either a SQL string or a [SQLGlot expression](https://github.com/tobymao/sqlglot/blob/main/sqlglot/expressions.py). The `@model` decorator is used to define the model's [metadata](#MODEL-DDL) and, optionally, its pre/post-statements or on-virtual-update statements, which likewise take the form of SQL strings or SQLGlot expressions.
 
 Defining a SQL model using Python can be beneficial in cases where its query is too complex to express cleanly in SQL, for example due to having many dynamic components that would require heavy use of [macros](../macros/overview/). Since Python-based models generate SQL, they support the same features as regular SQL models, such as column-level [lineage](../glossary/#lineage).
@@ -120,6 +153,7 @@ from sqlmesh.core.macros import MacroEvaluator
     kind="FULL",
     pre_statements=["CACHE TABLE countries AS SELECT * FROM raw.countries"],
     post_statements=["UNCACHE TABLE countries"],
+    on_virtual_update=["GRANT SELECT ON VIEW @this_model TO ROLE dev_role"],
 )
 def entrypoint(evaluator: MacroEvaluator) -> str | exp.Expression:
     return (
@@ -139,7 +173,7 @@ One could also define this model by simply returning a string that contained the
 The `@model` decorator is the Python equivalent of the `MODEL` DDL.
 
-In addition to model metadata and configuration information, one can also set the keyword arguments `pre_statements` and `post_statements` to a list of SQL strings and/or SQLGlot expressions to define the pre/post-statements of the model, respectively.
+In addition to model metadata and configuration information, one can also set the keyword arguments `pre_statements`, `post_statements` and `on_virtual_update` to a list of SQL strings and/or SQLGlot expressions to define the pre/post-statements and on-virtual-update-statements of the model, respectively. !!! note diff --git a/sqlmesh/core/dialect.py b/sqlmesh/core/dialect.py index 7b5ac2655..5af4a0e7f 100644 --- a/sqlmesh/core/dialect.py +++ b/sqlmesh/core/dialect.py @@ -61,6 +61,10 @@ class JinjaStatement(Jinja): pass +class VirtualUpdateStatement(exp.Expression): + arg_types = {"expressions": True} + + class ModelKind(exp.Expression): arg_types = {"this": True, "expressions": False} @@ -772,6 +776,8 @@ def _is_command_statement(command: str, tokens: t.List[Token], pos: int) -> bool JINJA_QUERY_BEGIN = "JINJA_QUERY_BEGIN" JINJA_STATEMENT_BEGIN = "JINJA_STATEMENT_BEGIN" JINJA_END = "JINJA_END" +ON_VIRTUAL_UPDATE_BEGIN = "ON_VIRTUAL_UPDATE_BEGIN" +ON_VIRTUAL_UPDATE_END = "ON_VIRTUAL_UPDATE_END" def _is_jinja_statement_begin(tokens: t.List[Token], pos: int) -> bool: @@ -794,10 +800,24 @@ def jinja_statement(statement: str) -> JinjaStatement: return JinjaStatement(this=exp.Literal.string(statement.strip())) +def _is_virtual_statement_begin(tokens: t.List[Token], pos: int) -> bool: + return _is_command_statement(ON_VIRTUAL_UPDATE_BEGIN, tokens, pos) + + +def _is_virtual_statement_end(tokens: t.List[Token], pos: int) -> bool: + return _is_command_statement(ON_VIRTUAL_UPDATE_END, tokens, pos) + + +def virtual_statement(statements: t.List[exp.Expression]) -> VirtualUpdateStatement: + return VirtualUpdateStatement(expressions=statements) + + class ChunkType(Enum): JINJA_QUERY = auto() JINJA_STATEMENT = auto() SQL = auto() + VIRTUAL_STATEMENT = auto() + VIRTUAL_JINJA_STATEMENT = auto() def parse_one( @@ -837,9 +857,15 @@ def parse( total = len(tokens) pos = 0 + virtual = False while pos < total: token = tokens[pos] - if _is_jinja_end(tokens, pos) or ( + if 
_is_virtual_statement_end(tokens, pos):
+            chunks[-1][0].append(token)
+            virtual = False
+            chunks.append(([], ChunkType.SQL))
+            pos += 2
+        elif _is_jinja_end(tokens, pos) or (
             chunks[-1][1] == ChunkType.SQL
             and token.token_type == TokenType.SEMICOLON
             and pos < total - 1
@@ -850,13 +876,24 @@
             # Jinja end statement
             chunks[-1][0].append(token)
             pos += 2
-            chunks.append(([], ChunkType.SQL))
+            chunks.append(
+                (
+                    [],
+                    ChunkType.VIRTUAL_STATEMENT
+                    if virtual and tokens[pos].text.upper() != ON_VIRTUAL_UPDATE_END
+                    else ChunkType.SQL,
+                )
+            )
         elif _is_jinja_query_begin(tokens, pos):
             chunks.append(([token], ChunkType.JINJA_QUERY))
             pos += 2
         elif _is_jinja_statement_begin(tokens, pos):
             chunks.append(([token], ChunkType.JINJA_STATEMENT))
             pos += 2
+        elif _is_virtual_statement_begin(tokens, pos):
+            chunks.append(([token], ChunkType.VIRTUAL_STATEMENT))
+            pos += 2
+            virtual = True
         else:
             chunks[-1][0].append(token)
             pos += 1
@@ -864,22 +901,68 @@
     parser = dialect.parser()
     expressions: t.List[exp.Expression] = []
 
-    for chunk, chunk_type in chunks:
-        if chunk_type == ChunkType.SQL:
-            parsed_expressions: t.List[t.Optional[exp.Expression]] = (
-                parser.parse(chunk, sql) if into is None else parser.parse_into(into, chunk, sql)
-            )
-            for expression in parsed_expressions:
-                if expression:
+    def parse_sql_chunk(chunk: t.List[Token], meta_sql: bool = True) -> t.List[exp.Expression]:
+        parsed_expressions: t.List[t.Optional[exp.Expression]] = (
+            parser.parse(chunk, sql) if into is None else parser.parse_into(into, chunk, sql)
+        )
+        expressions = []
+        for expression in parsed_expressions:
+            if expression:
+                if meta_sql:
                     expression.meta["sql"] = parser._find_sql(chunk[0], chunk[-1])
-                    expressions.append(expression)
-        else:
-            start, *_, end = chunk
-            segment = sql[start.end + 2 : end.start - 1]
-            factory = jinja_query if chunk_type == ChunkType.JINJA_QUERY else jinja_statement
-            expression = factory(segment.strip())
+                expressions.append(expression)
+        return expressions
+
+    def 
parse_jinja_chunk(chunk: t.List[Token], meta_sql: bool = True) -> exp.Expression: + start, *_, end = chunk + segment = sql[start.end + 2 : end.start - 1] + factory = jinja_query if chunk_type == ChunkType.JINJA_QUERY else jinja_statement + expression = factory(segment.strip()) + if meta_sql: expression.meta["sql"] = sql[start.start : end.end + 1] - expressions.append(expression) + return expression + + def parse_virtual_statement( + chunks: t.List[t.Tuple[t.List[Token], ChunkType]], pos: int + ) -> t.Tuple[t.List[exp.Expression], int]: + # For virtual statements we need to handle both SQL and Jinja nested blocks within the chunk + virtual_update_statements = [] + start = chunks[pos][0][0].start + + while ( + chunks[pos - 1][0] == [] or chunks[pos - 1][0][-1].text.upper() != ON_VIRTUAL_UPDATE_END + ): + chunk, chunk_type = chunks[pos] + if chunk_type == ChunkType.JINJA_STATEMENT: + virtual_update_statements.append(parse_jinja_chunk(chunk, False)) + else: + virtual_update_statements.extend( + parse_sql_chunk( + chunk[int(chunk[0].text.upper() == ON_VIRTUAL_UPDATE_BEGIN) : -1], False + ), + ) + pos += 1 + + if virtual_update_statements: + statements = virtual_statement(virtual_update_statements) + end = chunk[-1].end + 1 + statements.meta["sql"] = sql[start:end] + return [statements], pos + + return [], pos + + pos = 0 + total_chunks = len(chunks) + while pos < total_chunks: + chunk, chunk_type = chunks[pos] + if chunk_type == ChunkType.VIRTUAL_STATEMENT: + virtual_expression, pos = parse_virtual_statement(chunks, pos) + expressions.extend(virtual_expression) + elif chunk_type == ChunkType.SQL: + expressions.extend(parse_sql_chunk(chunk)) + else: + expressions.append(parse_jinja_chunk(chunk)) + pos += 1 return expressions diff --git a/sqlmesh/core/model/common.py b/sqlmesh/core/model/common.py index 47ae1f73b..be5268ca3 100644 --- a/sqlmesh/core/model/common.py +++ b/sqlmesh/core/model/common.py @@ -287,6 +287,7 @@ def depends_on(cls: t.Type, v: t.Any, values: 
t.Dict[str, t.Any]) -> t.Optional[ "expressions_", "pre_statements_", "post_statements_", + "on_virtual_update_", "unique_key", mode="before", check_fields=False, diff --git a/sqlmesh/core/model/decorator.py b/sqlmesh/core/model/decorator.py index cff43e974..61313cb98 100644 --- a/sqlmesh/core/model/decorator.py +++ b/sqlmesh/core/model/decorator.py @@ -135,7 +135,7 @@ def model( **self.kwargs, } - for key in ("pre_statements", "post_statements"): + for key in ("pre_statements", "post_statements", "on_virtual_update"): statements = common_kwargs.get(key) if statements: common_kwargs[key] = [ diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index 4207d1780..df0c50d28 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -134,6 +134,9 @@ class _Model(ModelMeta, frozen=True): post_statements_: t.Optional[t.List[exp.Expression]] = Field( default=None, alias="post_statements" ) + on_virtual_update_: t.Optional[t.List[exp.Expression]] = Field( + default=None, alias="on_virtual_update" + ) _expressions_validator = expression_validator @@ -418,6 +421,30 @@ def render_post_statements( **kwargs, ) + def render_on_virtual_update( + self, + *, + start: t.Optional[TimeLike] = None, + end: t.Optional[TimeLike] = None, + execution_time: t.Optional[TimeLike] = None, + snapshots: t.Optional[t.Dict[str, Snapshot]] = None, + expand: t.Iterable[str] = tuple(), + deployability_index: t.Optional[DeployabilityIndex] = None, + engine_adapter: t.Optional[EngineAdapter] = None, + **kwargs: t.Any, + ) -> t.List[exp.Expression]: + return self._render_statements( + self.on_virtual_update, + start=start, + end=end, + execution_time=execution_time, + snapshots=snapshots, + expand=expand, + deployability_index=deployability_index, + engine_adapter=engine_adapter, + **kwargs, + ) + def render_audit_query( self, audit: Audit, @@ -501,10 +528,18 @@ def pre_statements(self) -> t.List[exp.Expression]: def post_statements(self) -> 
t.List[exp.Expression]: return self.post_statements_ or [] + @property + def on_virtual_update(self) -> t.List[exp.Expression]: + return self.on_virtual_update_ or [] + @property def macro_definitions(self) -> t.List[d.MacroDef]: """All macro definitions from the list of expressions.""" - return [s for s in self.pre_statements + self.post_statements if isinstance(s, d.MacroDef)] + return [ + s + for s in self.pre_statements + self.post_statements + self.on_virtual_update + if isinstance(s, d.MacroDef) + ] def _render_statements( self, @@ -1030,6 +1065,9 @@ def _additional_metadata(self) -> t.List[str]: if self._is_metadata_statement(statement): additional_metadata.append(gen(statement)) + for statement in self.on_virtual_update: + additional_metadata.append(gen(statement)) + return additional_metadata def _is_metadata_statement(self, statement: exp.Expression) -> bool: @@ -1098,6 +1136,7 @@ class SqlModel(_Model): query: The main query representing the model. pre_statements: The list of SQL statements that precede the model's query. post_statements: The list of SQL statements that follow after the model's query. + on_virtual_update: The list of SQL statements to be executed after the virtual update. 
""" query: t.Union[exp.Query, d.JinjaQuery, d.MacroFunc] @@ -1159,6 +1198,7 @@ def render_definition( result.extend(self.pre_statements) result.append(self.query) result.extend(self.post_statements) + result.extend(self.on_virtual_update) return result @property @@ -1733,7 +1773,7 @@ def load_sql_based_model( rendered_meta = rendered_meta_exprs[0] # Extract the query and any pre/post statements - query_or_seed_insert, pre_statements, post_statements, inline_audits = ( + query_or_seed_insert, pre_statements, post_statements, on_virtual_update, inline_audits = ( _split_sql_model_statements(expressions[1:], path, dialect=dialect) ) @@ -1776,6 +1816,7 @@ def load_sql_based_model( common_kwargs = dict( pre_statements=pre_statements, post_statements=post_statements, + on_virtual_update=on_virtual_update, defaults=defaults, path=path, module_path=module_path, @@ -2027,6 +2068,8 @@ def _create_model( statements.append(kwargs["query"]) if "post_statements" in kwargs: statements.extend(kwargs["post_statements"]) + if "on_virtual_update" in kwargs: + statements.extend(kwargs["on_virtual_update"]) jinja_macro_references, used_variables = extract_macro_references_and_variables( *(gen(e) for e in statements) @@ -2116,6 +2159,7 @@ def _split_sql_model_statements( t.Optional[exp.Expression], t.List[exp.Expression], t.List[exp.Expression], + t.List[exp.Expression], UniqueKeyDict[str, ModelAudit], ]: """Extracts the SELECT query from a sequence of expressions. 
@@ -2134,6 +2178,7 @@ def _split_sql_model_statements( query_positions = [] sql_statements = [] + on_virtual_update = [] inline_audits: UniqueKeyDict[str, ModelAudit] = UniqueKeyDict("inline_audits") idx = 0 @@ -2146,6 +2191,10 @@ def _split_sql_model_statements( assert isinstance(loaded_audit, ModelAudit) inline_audits[loaded_audit.name] = loaded_audit idx += 2 + elif isinstance(expr, d.VirtualUpdateStatement): + for statement in expr.expressions: + on_virtual_update.append(statement) + idx += 1 else: if ( isinstance(expr, (exp.Query, d.JinjaQuery)) @@ -2160,13 +2209,13 @@ def _split_sql_model_statements( idx += 1 if not query_positions: - return None, sql_statements, [], inline_audits + return None, sql_statements, [], on_virtual_update, inline_audits elif len(query_positions) > 1: raise_config_error("Only one SELECT query is allowed per model", path) query, pos = query_positions[0] - return query, sql_statements[:pos], sql_statements[pos + 1 :], inline_audits + return query, sql_statements[:pos], sql_statements[pos + 1 :], on_virtual_update, inline_audits def _resolve_session_properties( diff --git a/sqlmesh/core/plan/evaluator.py b/sqlmesh/core/plan/evaluator.py index 8893cb77d..64f36547d 100644 --- a/sqlmesh/core/plan/evaluator.py +++ b/sqlmesh/core/plan/evaluator.py @@ -25,7 +25,7 @@ from sqlmesh.core.notification_target import ( NotificationTarget, ) -from sqlmesh.core.snapshot.definition import Interval +from sqlmesh.core.snapshot.definition import Interval, to_view_mapping from sqlmesh.core.plan.definition import EvaluatablePlan from sqlmesh.core.scheduler import Scheduler from sqlmesh.core.snapshot import ( @@ -45,6 +45,7 @@ from sqlmesh.schedulers.airflow.mwaa_client import MWAAClient from sqlmesh.utils.errors import PlanError, SQLMeshError from sqlmesh.utils.dag import DAG +from sqlmesh.utils.date import now logger = logging.getLogger(__name__) @@ -314,6 +315,7 @@ def _update_views( environment.naming_info, deployability_index=deployability_index, 
on_complete=lambda s: self.console.update_promotion_progress(s, True), + snapshots=snapshots, ) if promotion_result.removed_environment_naming_info: self._demote_snapshots( @@ -322,6 +324,7 @@ def _update_views( promotion_result.removed_environment_naming_info, on_complete=lambda s: self.console.update_promotion_progress(s, False), ) + self.state_sync.finalize(environment) completed = True finally: @@ -332,12 +335,23 @@ def _promote_snapshots( plan: EvaluatablePlan, target_snapshots: t.Iterable[Snapshot], environment_naming_info: EnvironmentNamingInfo, + snapshots: t.Dict[SnapshotId, Snapshot], deployability_index: t.Optional[DeployabilityIndex] = None, on_complete: t.Optional[t.Callable[[SnapshotInfoLike], None]] = None, ) -> None: self.snapshot_evaluator.promote( target_snapshots, - environment_naming_info, + start=plan.start, + end=plan.end, + execution_time=plan.execution_time or now(), + snapshots=snapshots, + table_mapping=to_view_mapping( + snapshots.values(), + environment_naming_info, + default_catalog=self.default_catalog, + dialect=self.snapshot_evaluator.adapter.dialect, + ), + environment_naming_info=environment_naming_info, deployability_index=deployability_index, on_complete=on_complete, ) diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index c16c0bd45..26a906998 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -27,6 +27,7 @@ Interval, expand_range, get_next_model_interval_start, + parent_snapshots_by_name, ) from sqlmesh.core.state_sync import StateSync from sqlmesh.utils import format_exception @@ -182,10 +183,7 @@ def evaluate( """ validate_date_range(start, end) - snapshots = { - self.snapshots[p_sid].name: self.snapshots[p_sid] for p_sid in snapshot.parents - } - snapshots[snapshot.name] = snapshot + snapshots = parent_snapshots_by_name(snapshot, self.snapshots) is_deployable = deployability_index.is_deployable(snapshot) diff --git a/sqlmesh/core/snapshot/definition.py 
b/sqlmesh/core/snapshot/definition.py index ceae32e85..31b5773c5 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -1662,6 +1662,21 @@ def to_table_mapping( } +def to_view_mapping( + snapshots: t.Iterable[Snapshot], + environment_naming_info: EnvironmentNamingInfo, + default_catalog: t.Optional[str] = None, + dialect: t.Optional[str] = None, +) -> t.Dict[str, str]: + return { + snapshot.name: snapshot.display_name( + environment_naming_info, default_catalog=default_catalog, dialect=dialect + ) + for snapshot in snapshots + if snapshot.is_model + } + + def has_paused_forward_only( targets: t.Iterable[SnapshotIdLike], snapshots: t.Union[t.List[Snapshot], t.Dict[SnapshotId, Snapshot]], @@ -2020,6 +2035,16 @@ def apply_auto_restatements( ] +def parent_snapshots_by_name( + snapshot: Snapshot, snapshots: t.Dict[SnapshotId, Snapshot] +) -> t.Dict[str, Snapshot]: + parent_snapshots_by_name = { + snapshots[p_sid].name: snapshots[p_sid] for p_sid in snapshot.parents + } + parent_snapshots_by_name[snapshot.name] = snapshot + return parent_snapshots_by_name + + def _contiguous_intervals(intervals: Intervals) -> t.List[Intervals]: """Given a list of intervals with gaps, returns a list of sequences of contiguous intervals.""" contiguous_intervals = [] diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 2983a180f..6c9225fc0 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -60,6 +60,7 @@ SnapshotInfoLike, SnapshotTableCleanupTask, ) +from sqlmesh.core.snapshot.definition import parent_snapshots_by_name from sqlmesh.utils import random_id from sqlmesh.utils.concurrency import ( concurrent_apply_to_snapshots, @@ -203,6 +204,11 @@ def promote( target_snapshots: t.Iterable[Snapshot], environment_naming_info: EnvironmentNamingInfo, deployability_index: t.Optional[DeployabilityIndex] = None, + start: t.Optional[TimeLike] = None, + end: t.Optional[TimeLike] = 
None, + execution_time: t.Optional[TimeLike] = None, + snapshots: t.Optional[t.Dict[SnapshotId, Snapshot]] = None, + table_mapping: t.Optional[t.Dict[str, str]] = None, on_complete: t.Optional[t.Callable[[SnapshotInfoLike], None]] = None, ) -> None: """Promotes the given collection of snapshots in the target environment by replacing a corresponding @@ -229,9 +235,14 @@ def promote( target_snapshots, lambda s: self._promote_snapshot( s, - environment_naming_info, - deployability_index, # type: ignore - on_complete, + start=start, + end=end, + execution_time=execution_time, + snapshots=snapshots, + table_mapping=table_mapping, + environment_naming_info=environment_naming_info, + deployability_index=deployability_index, # type: ignore + on_complete=on_complete, ), self.ddl_concurrent_tasks, ) @@ -721,17 +732,12 @@ def _create_snapshot( if not snapshot.is_model: return - parent_snapshots_by_name = { - snapshots[p_sid].name: snapshots[p_sid] for p_sid in snapshot.parents - } - parent_snapshots_by_name[snapshot.name] = snapshot - deployability_index = deployability_index or DeployabilityIndex.all_deployable() adapter = self._get_adapter(snapshot.model.gateway) common_render_kwargs: t.Dict[str, t.Any] = dict( engine_adapter=adapter, - snapshots=parent_snapshots_by_name, + snapshots=parent_snapshots_by_name(snapshot, snapshots), runtime_stage=RuntimeStage.CREATING, ) pre_post_render_kwargs = dict( @@ -820,18 +826,13 @@ def _migrate_snapshot( if not needs_migration: return - parent_snapshots_by_name = { - snapshots[p_sid].name: snapshots[p_sid] for p_sid in snapshot.parents - } - parent_snapshots_by_name[snapshot.name] = snapshot - tmp_table_name = snapshot.table_name(is_deployable=False) target_table_name = snapshot.table_name() _evaluation_strategy(snapshot, adapter).migrate( target_table_name=target_table_name, source_table_name=tmp_table_name, snapshot=snapshot, - snapshots=parent_snapshots_by_name, + snapshots=parent_snapshots_by_name(snapshot, snapshots), 
allow_destructive_snapshots=allow_destructive_snapshots, ) @@ -841,6 +842,11 @@ def _promote_snapshot( environment_naming_info: EnvironmentNamingInfo, deployability_index: DeployabilityIndex, on_complete: t.Optional[t.Callable[[SnapshotInfoLike], None]], + start: t.Optional[TimeLike] = None, + end: t.Optional[TimeLike] = None, + execution_time: t.Optional[TimeLike] = None, + snapshots: t.Optional[t.Dict[SnapshotId, Snapshot]] = None, + table_mapping: t.Optional[t.Dict[str, str]] = None, ) -> None: if snapshot.is_model: adapter = self.adapter @@ -854,6 +860,16 @@ def _promote_snapshot( model=snapshot.model, environment=environment_naming_info.name, ) + render_kwargs: t.Dict[str, t.Any] = dict( + start=start, + end=end, + execution_time=execution_time, + engine_adapter=adapter, + snapshots=snapshots, + deployability_index=deployability_index, + table_mapping=table_mapping, + ) + adapter.execute(snapshot.model.render_on_virtual_update(**render_kwargs)) if on_complete is not None: on_complete(snapshot) diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 68d91c1a8..a4cf7973d 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -986,6 +986,40 @@ def test_seed_pre_statements_only(): assert not model.post_statements +def test_seed_on_virtual_update_statements(): + expressions = d.parse( + """ + MODEL ( + name db.seed, + kind SEED ( + path '../seeds/waiter_names.csv', + batch_size 100, + ) + ); + + JINJA_STATEMENT_BEGIN; + CREATE TABLE x{{ 1 + 1 }}; + JINJA_END; + + ON_VIRTUAL_UPDATE_BEGIN; + JINJA_STATEMENT_BEGIN; + GRANT SELECT ON VIEW {{ this_model }} TO ROLE dev_role; + JINJA_END; + DROP TABLE x2; + ON_VIRTUAL_UPDATE_END; + + """ + ) + + model = load_sql_based_model(expressions, path=Path("./examples/sushi/models/test_model.sql")) + + assert model.pre_statements == [d.jinja_statement("CREATE TABLE x{{ 1 + 1 }};")] + assert model.on_virtual_update == [ + d.jinja_statement("GRANT SELECT ON VIEW {{ this_model }} TO ROLE dev_role;"), + 
*d.parse("DROP TABLE x2;"), + ] + + def test_seed_model_custom_types(tmp_path): model_csv_path = (tmp_path / "model.csv").absolute() @@ -6689,3 +6723,161 @@ def dummy_model_entry(evaluator: MacroEvaluator) -> exp.Select: ) assert isinstance(context._get_engine_adapter("duckdb"), DuckDBEngineAdapter) assert len(context._engine_adapters) == 2 + + +def test_model_on_virtual_update(make_snapshot: t.Callable): + # Macro to test resolution within virtual statement + @macro() + def resolve_parent_name(evaluator, name): + return evaluator.resolve_table(name.name) + + virtual_update_statements = """ + CREATE OR REPLACE VIEW test_view FROM demo_db.table; + GRANT SELECT ON VIEW @this_model TO ROLE owner_name; + JINJA_STATEMENT_BEGIN; + GRANT SELECT ON VIEW {{this_model}} TO ROLE admin; + JINJA_END; + GRANT REFERENCES, SELECT ON FUTURE VIEWS IN DATABASE demo_db TO ROLE owner_name; + @resolve_parent_name('parent'); + GRANT SELECT ON VIEW demo_db.table /* sqlglot.meta replace=false */ TO ROLE admin; + """ + + expressions = d.parse( + f""" + MODEL ( + name demo_db.table, + owner owner_name, + ); + + SELECT id from parent; + + on_virtual_update_begin; + + {virtual_update_statements} + + on_virtual_update_end; + + """ + ) + + parent_expressions = d.parse( + """ + MODEL ( + name parent, + ); + + SELECT 1 from id; + + ON_VIRTUAL_UPDATE_BEGIN; + JINJA_STATEMENT_BEGIN; + GRANT SELECT ON VIEW {{this_model}} TO ROLE admin; + JINJA_END; + ON_VIRTUAL_UPDATE_END; + + """ + ) + + model = load_sql_based_model(expressions) + parent = load_sql_based_model(parent_expressions) + + parent_snapshot = make_snapshot(parent) + parent_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + version = parent_snapshot.version + + assert model.on_virtual_update == d.parse(virtual_update_statements) + + assert parent.on_virtual_update == d.parse( + "JINJA_STATEMENT_BEGIN; GRANT SELECT ON VIEW {{this_model}} TO ROLE admin; JINJA_END;" + ) + + table_mapping = {'"demo_db"."table"': "demo_db__dev.table"} + 
+    snapshots = {'"parent"': parent_snapshot}
+
+    rendered_statements = model._render_statements(
+        model.on_virtual_update, snapshots=snapshots, table_mapping=table_mapping
+    )
+
+    assert len(rendered_statements) == 6
+    assert (
+        rendered_statements[0].sql()
+        == 'CREATE OR REPLACE VIEW "test_view" AS SELECT * FROM "demo_db__dev"."table" AS "table" /* demo_db.table */'
+    )
+    assert (
+        rendered_statements[1].sql()
+        == 'GRANT SELECT ON VIEW "demo_db__dev"."table" /* demo_db.table */ TO ROLE "owner_name"'
+    )
+    assert (
+        rendered_statements[3].sql()
+        == "GRANT REFERENCES, SELECT ON FUTURE VIEWS IN DATABASE demo_db TO ROLE owner_name"
+    )
+    assert rendered_statements[4].sql() == f'"sqlmesh__default"."parent__{version}"'
+
+    # When replace=false the table should remain as is
+    assert (
+        rendered_statements[5].sql()
+        == 'GRANT SELECT ON VIEW "demo_db"."table" /* sqlglot.meta replace=false */ TO ROLE "admin"'
+    )
+
+    rendered_parent_statements = model._render_statements(
+        parent.on_virtual_update, snapshots=snapshots, table_mapping=table_mapping
+    )
+    assert (
+        rendered_statements[2].sql()
+        == rendered_parent_statements[0].sql()
+        == 'GRANT SELECT ON VIEW "demo_db__dev"."table" /* demo_db.table */ TO ROLE "admin"'
+    )
+
+
+def test_python_model_on_virtual_update():
+    macros = """
+    {% macro index_name(v) %}{{ v }}{% endmacro %}
+    """
+
+    jinja_macros = JinjaMacroRegistry()
+    jinja_macros.add_macros(MacroExtractor().extract(macros))
+
+    @model(
+        "db.test_model",
+        kind="full",
+        columns={"id": "string", "name": "string"},
+        on_virtual_update=[
+            "JINJA_STATEMENT_BEGIN;\nCREATE INDEX {{index_name('id_index')}} ON db.test_model(id);\nJINJA_END;",
+            parse_one("GRANT SELECT ON VIEW @this_model TO ROLE dev_role;"),
+            "GRANT REFERENCES, SELECT ON FUTURE VIEWS IN DATABASE db TO ROLE dev_role;",
+        ],
+    )
+    def model_with_virtual_statements(context, **kwargs):
+        return pd.DataFrame(
+            [
+                {
+                    "id": context.var("1"),
+                    "name": context.var("var"),
+                }
+            ]
+        )
+
+    python_model = model.get_registry()["db.test_model"].model(
+        module_path=Path("."), path=Path("."), dialect="duckdb", jinja_macros=jinja_macros
+    )
+
+    assert len(jinja_macros.root_macros) == 1
+    assert len(python_model.jinja_macros.root_macros) == 1
+    assert "index_name" in python_model.jinja_macros.root_macros
+    assert len(python_model.on_virtual_update) == 3
+
+    rendered_statements = python_model._render_statements(
+        python_model.on_virtual_update, table_mapping={'"db"."test_model"': "db.test_model"}
+    )
+
+    assert (
+        rendered_statements[0].sql()
+        == 'CREATE INDEX "id_index" ON "db"."test_model" /* db.test_model */("id" NULLS LAST)'
+    )
+    assert (
+        rendered_statements[1].sql()
+        == 'GRANT SELECT ON VIEW "db"."test_model" /* db.test_model */ TO ROLE "dev_role"'
+    )
+    assert (
+        rendered_statements[2].sql()
+        == "GRANT REFERENCES, SELECT ON FUTURE VIEWS IN DATABASE db TO ROLE dev_role"
+    )
diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py
index 10ee3de66..bb0c1d948 100644
--- a/tests/core/test_snapshot_evaluator.py
+++ b/tests/core/test_snapshot_evaluator.py
@@ -48,6 +48,7 @@
     SnapshotEvaluator,
     SnapshotTableCleanupTask,
 )
+from sqlmesh.core.snapshot.definition import to_view_mapping
 from sqlmesh.core.snapshot.evaluator import CustomMaterialization
 from sqlmesh.utils.concurrency import NodeExecutionFailedError
 from sqlmesh.utils.date import to_timestamp
@@ -2714,6 +2715,140 @@
     assert post_calls[0].sql(dialect="postgres") == expected_call
+
+def test_on_virtual_update_statements(mocker: MockerFixture, adapter_mock, make_snapshot):
+    evaluator = SnapshotEvaluator(adapter_mock)
+
+    model = load_sql_based_model(
+        d.parse(
+            """
+            MODEL (
+                name test_schema.test_model,
+                kind FULL,
+                dialect postgres,
+            );
+
+            SELECT a::int FROM tbl;
+
+            CREATE INDEX IF NOT EXISTS test_idx ON test_schema.test_model(a);
+
+            ON_VIRTUAL_UPDATE_BEGIN;
+            JINJA_STATEMENT_BEGIN;
+            GRANT SELECT ON VIEW test_schema.test_model TO ROLE admin;
+            JINJA_END;
+            GRANT REFERENCES, SELECT ON FUTURE VIEWS IN DATABASE demo_db TO ROLE owner_name;
+            ON_VIRTUAL_UPDATE_END;
+
+            """
+        ),
+    )
+
+    snapshot = make_snapshot(model)
+    snapshot.categorize_as(SnapshotChangeCategory.BREAKING)
+
+    evaluator.create([snapshot], {}, DeployabilityIndex.none_deployable())
+
+    snapshots = {snapshot.name: snapshot}
+    environment_naming_info = EnvironmentNamingInfo(name="test_env")
+    evaluator.promote(
+        [snapshot],
+        start="2020-01-01",
+        end="2020-01-01",
+        execution_time="2020-01-01",
+        snapshots=snapshots,
+        environment_naming_info=environment_naming_info,
+        table_mapping=to_view_mapping(
+            snapshots.values(),
+            environment_naming_info,
+        ),
+    )
+
+    call_args = adapter_mock.execute.call_args_list
+    post_calls = call_args[1][0][0]
+    assert len(post_calls) == 1
+    assert (
+        post_calls[0].sql(dialect="postgres")
+        == f'CREATE INDEX IF NOT EXISTS "test_idx" ON "sqlmesh__test_schema"."test_schema__test_model__{snapshot.version}" /* test_schema.test_model */("a")'
+    )
+
+    on_virtual_update_calls = call_args[2][0][0]
+    assert (
+        on_virtual_update_calls[0].sql(dialect="postgres")
+        == 'GRANT SELECT ON VIEW "test_schema__test_env"."test_model" /* test_schema.test_model */ TO ROLE "admin"'
+    )
+    assert (
+        on_virtual_update_calls[1].sql(dialect="postgres")
+        == "GRANT REFERENCES, SELECT ON FUTURE VIEWS IN DATABASE demo_db TO ROLE owner_name"
+    )
+
+
+def test_on_virtual_update_python_model_macro(mocker: MockerFixture, adapter_mock, make_snapshot):
+    evaluator = SnapshotEvaluator(adapter_mock)
+
+    @macro()
+    def create_index_2(
+        evaluator: MacroEvaluator,
+        index_name: str,
+        model_name: str,
+        column: str,
+    ):
+        return f"CREATE INDEX IF NOT EXISTS {index_name} ON {model_name}({column});"
+
+    @model(
+        "db.test_model_3",
+        kind="full",
+        columns={"id": "string", "name": "string"},
+        on_virtual_update=["@CREATE_INDEX_2('idx', 'db.test_model_3', id)"],
+    )
+    def model_with_statements(context, **kwargs):
+        return pd.DataFrame(
+            [
+                {
+                    "id": context.var("1"),
+                    "name": context.var("var"),
+                }
+            ]
+        )
+
+    python_model = model.get_registry()["db.test_model_3"].model(
+        module_path=Path("."),
+        path=Path("."),
+        macros=macro.get_registry(),
+        dialect="postgres",
+    )
+
+    assert len(python_model.python_env) == 3
+    assert len(python_model.on_virtual_update) == 1
+    assert isinstance(python_model.python_env["create_index_2"], Executable)
+    assert isinstance(python_model.on_virtual_update[0], MacroFunc)
+
+    snapshot = make_snapshot(python_model)
+    snapshot.categorize_as(SnapshotChangeCategory.BREAKING)
+
+    evaluator.create([snapshot], {}, DeployabilityIndex.none_deployable())
+
+    snapshots = {snapshot.name: snapshot}
+    environment_naming_info = EnvironmentNamingInfo(name="prod")
+    evaluator.promote(
+        [snapshot],
+        start="2020-01-01",
+        end="2020-01-01",
+        execution_time="2020-01-01",
+        snapshots=snapshots,
+        environment_naming_info=environment_naming_info,
+        table_mapping=to_view_mapping(
+            snapshots.values(),
+            environment_naming_info,
+        ),
+    )
+
+    call_args = adapter_mock.execute.call_args_list
+    on_virtual_update_call = call_args[2][0][0][0]
+    assert (
+        on_virtual_update_call.sql(dialect="postgres")
+        == 'CREATE INDEX IF NOT EXISTS "idx" ON "db"."test_model_3" /* db.test_model_3 */("id")'
+    )
+
+
 def test_evaluate_incremental_by_partition(mocker: MockerFixture, make_snapshot, adapter_mock):
     model = SqlModel(
         name="test_schema.test_model",
@@ -3519,7 +3654,7 @@
     assert view_args[0][0][0] == "db__test_env.multi_engine_test_model"
 
     # For the pre/post statements verify the model-specific gateway was used
-    engine_adapters["default"].execute.assert_not_called()
+    engine_adapters["default"].execute.assert_called_once()
     assert len(engine_adapters["secondary"].execute.call_args_list) == 2
     # Validate that the get_catalog_type method was called only on the secondary engine from the macro evaluator