From 7d5d9e9ea2f26a00524ee800fcf3b252de21c2f5 Mon Sep 17 00:00:00 2001 From: Themis Valtinos <73662635+themisvaltinos@users.noreply.github.com> Date: Wed, 18 Dec 2024 20:42:31 +0200 Subject: [PATCH] Add unit tests --- sqlmesh/core/snapshot/evaluator.py | 4 +- tests/core/test_snapshot_evaluator.py | 121 ++++++++++++++++++++++++++ 2 files changed, 123 insertions(+), 2 deletions(-) diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 89fa6926e..49531978d 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -1013,10 +1013,10 @@ def _execute_virtual_statements( adapter = self._get_adapter(snapshot.model_gateway) snapshot_deps = {snapshots[p_sid].name: snapshots[p_sid] for p_sid in snapshot.parents} snapshot_deps[snapshot.name] = snapshot - if virtual_statements := snapshot.model.on_virtual_update: + if on_virtual_update := snapshot.model.on_virtual_update: adapter.execute( snapshot.model._render_statements( - virtual_statements, + on_virtual_update, start=start, end=end, execution_time=execution_time, diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py index 35a0abf60..4ce431355 100644 --- a/tests/core/test_snapshot_evaluator.py +++ b/tests/core/test_snapshot_evaluator.py @@ -2618,6 +2618,127 @@ def model_with_statements(context, **kwargs): assert post_calls[0].sql(dialect="postgres") == expected_call +def test_on_virtual_update_statements(mocker: MockerFixture, adapter_mock, make_snapshot): + evaluator = SnapshotEvaluator(adapter_mock) + + model = load_sql_based_model( + d.parse( + """ + MODEL ( + name test_schema.test_model, + kind FULL, + dialect postgres, + ); + + SELECT a::int FROM tbl; + + CREATE INDEX IF NOT EXISTS test_idx ON test_schema.test_model(a); + + ON_VIRTUAL_UPDATE_BEGIN; + JINJA_STATEMENT_BEGIN; + GRANT SELECT ON VIEW test_schema.test_model TO ROLE admin; + JINJA_END; + GRANT REFERENCES, SELECT ON FUTURE VIEWS IN DATABASE demo_db TO ROLE owner_name; + ON_VIRTUAL_UPDATE_END; + + """ + ), + ) + + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + evaluator.create([snapshot], {}, DeployabilityIndex.none_deployable()) + evaluator._execute_virtual_statements( + [snapshot], + start="2020-01-01", + end="2020-01-01", + execution_time="2020-01-01", + snapshots={snapshot.name: snapshot}, + environment_naming_info=EnvironmentNamingInfo(name="test_env"), + ) + + call_args = adapter_mock.execute.call_args_list + post_calls = call_args[1][0][0] + assert len(post_calls) == 1 + assert ( + post_calls[0].sql(dialect="postgres") + == f'CREATE INDEX IF NOT EXISTS "test_idx" ON "sqlmesh__test_schema"."test_schema__test_model__{snapshot.version}" /* test_schema.test_model */("a")' + ) + + on_virtual_update_calls = call_args[2][0][0] + assert ( + on_virtual_update_calls[0].sql(dialect="postgres") + == 'GRANT SELECT ON VIEW "test_schema__test_env"."test_model" /* test_schema.test_model */ TO ROLE "admin"' + ) + assert ( + on_virtual_update_calls[1].sql(dialect="postgres") + == "GRANT REFERENCES, SELECT ON FUTURE VIEWS IN DATABASE demo_db TO ROLE owner_name" + ) + + +def test_on_virtual_update_python_model_macro(mocker: MockerFixture, adapter_mock, make_snapshot): + evaluator = SnapshotEvaluator(adapter_mock) + + @macro() + def create_index( + evaluator: MacroEvaluator, + index_name: str, + model_name: str, + column: str, + ): + return f"CREATE INDEX IF NOT EXISTS {index_name} ON {model_name}({column});" + + @model( + "db.test_model", + kind="full", + columns={"id": "string", "name": "string"}, + on_virtual_update=["@CREATE_INDEX('idx', 'db.test_model', id)"], + ) + def model_with_statements(context, **kwargs): + return pd.DataFrame( + [ + { + "id": context.var("1"), + "name": context.var("var"), + } + ] + ) + + python_model = model.get_registry()["db.test_model"].model( + module_path=Path("."), + path=Path("."), + macros=macro.get_registry(), + dialect="postgres", + ) + + assert len(python_model.python_env) == 3 + assert len(python_model.on_virtual_update) == 1 + assert isinstance(python_model.python_env["create_index"], Executable) + assert isinstance(python_model.on_virtual_update[0], MacroFunc) + + snapshot = make_snapshot(python_model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + evaluator.create([snapshot], {}, DeployabilityIndex.none_deployable()) + + evaluator._execute_virtual_statements( + [snapshot], + start="2020-01-01", + end="2020-01-01", + execution_time="2020-01-01", + snapshots={snapshot.name: snapshot}, + environment_naming_info=EnvironmentNamingInfo(name="prod"), + ) + + call_args = adapter_mock.execute.call_args_list + on_virtual_update_call = call_args[2][0][0][0] + assert ( + on_virtual_update_call.sql(dialect="postgres") + == 'CREATE INDEX IF NOT EXISTS "idx" ON "db"."test_model" /* db.test_model */("id")' + ) + + def test_evaluate_incremental_by_partition(mocker: MockerFixture, make_snapshot, adapter_mock): model = SqlModel( name="test_schema.test_model",