Skip to content

Commit

Permalink
Add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
themisvaltinos committed Dec 18, 2024
1 parent 123ac81 commit 7d5d9e9
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 2 deletions.
4 changes: 2 additions & 2 deletions sqlmesh/core/snapshot/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,10 +1013,10 @@ def _execute_virtual_statements(
adapter = self._get_adapter(snapshot.model_gateway)
snapshot_deps = {snapshots[p_sid].name: snapshots[p_sid] for p_sid in snapshot.parents}
snapshot_deps[snapshot.name] = snapshot
if virtual_statements := snapshot.model.on_virtual_update:
if on_virtual_update := snapshot.model.on_virtual_update:
adapter.execute(
snapshot.model._render_statements(
virtual_statements,
on_virtual_update,
start=start,
end=end,
execution_time=execution_time,
Expand Down
121 changes: 121 additions & 0 deletions tests/core/test_snapshot_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2618,6 +2618,127 @@ def model_with_statements(context, **kwargs):
assert post_calls[0].sql(dialect="postgres") == expected_call


def test_on_virtual_update_statements(mocker: MockerFixture, adapter_mock, make_snapshot):
evaluator = SnapshotEvaluator(adapter_mock)

model = load_sql_based_model(
d.parse(
"""
MODEL (
name test_schema.test_model,
kind FULL,
dialect postgres,
);
SELECT a::int FROM tbl;
CREATE INDEX IF NOT EXISTS test_idx ON test_schema.test_model(a);
ON_VIRTUAL_UPDATE_BEGIN;
JINJA_STATEMENT_BEGIN;
GRANT SELECT ON VIEW test_schema.test_model TO ROLE admin;
JINJA_END;
GRANT REFERENCES, SELECT ON FUTURE VIEWS IN DATABASE demo_db TO ROLE owner_name;
ON_VIRTUAL_UPDATE_END;
"""
),
)

snapshot = make_snapshot(model)
snapshot.categorize_as(SnapshotChangeCategory.BREAKING)

evaluator.create([snapshot], {}, DeployabilityIndex.none_deployable())
evaluator._execute_virtual_statements(
[snapshot],
start="2020-01-01",
end="2020-01-01",
execution_time="2020-01-01",
snapshots={snapshot.name: snapshot},
environment_naming_info=EnvironmentNamingInfo(name="test_env"),
)

call_args = adapter_mock.execute.call_args_list
post_calls = call_args[1][0][0]
assert len(post_calls) == 1
assert (
post_calls[0].sql(dialect="postgres")
== f'CREATE INDEX IF NOT EXISTS "test_idx" ON "sqlmesh__test_schema"."test_schema__test_model__{snapshot.version}" /* test_schema.test_model */("a")'
)

on_virtual_update_calls = call_args[2][0][0]
assert (
on_virtual_update_calls[0].sql(dialect="postgres")
== 'GRANT SELECT ON VIEW "test_schema__test_env"."test_model" /* test_schema.test_model */ TO ROLE "admin"'
)
assert (
on_virtual_update_calls[1].sql(dialect="postgres")
== "GRANT REFERENCES, SELECT ON FUTURE VIEWS IN DATABASE demo_db TO ROLE owner_name"
)


def test_on_virtual_update_python_model_macro(mocker: MockerFixture, adapter_mock, make_snapshot):
evaluator = SnapshotEvaluator(adapter_mock)

@macro()
def create_index(
evaluator: MacroEvaluator,
index_name: str,
model_name: str,
column: str,
):
return f"CREATE INDEX IF NOT EXISTS {index_name} ON {model_name}({column});"

@model(
"db.test_model",
kind="full",
columns={"id": "string", "name": "string"},
on_virtual_update=["@CREATE_INDEX('idx', 'db.test_model', id)"],
)
def model_with_statements(context, **kwargs):
return pd.DataFrame(
[
{
"id": context.var("1"),
"name": context.var("var"),
}
]
)

python_model = model.get_registry()["db.test_model"].model(
module_path=Path("."),
path=Path("."),
macros=macro.get_registry(),
dialect="postgres",
)

assert len(python_model.python_env) == 3
assert len(python_model.on_virtual_update) == 1
assert isinstance(python_model.python_env["create_index"], Executable)
assert isinstance(python_model.on_virtual_update[0], MacroFunc)

snapshot = make_snapshot(python_model)
snapshot.categorize_as(SnapshotChangeCategory.BREAKING)

evaluator.create([snapshot], {}, DeployabilityIndex.none_deployable())

evaluator._execute_virtual_statements(
[snapshot],
start="2020-01-01",
end="2020-01-01",
execution_time="2020-01-01",
snapshots={snapshot.name: snapshot},
environment_naming_info=EnvironmentNamingInfo(name="prod"),
)

call_args = adapter_mock.execute.call_args_list
on_virtual_update_call = call_args[2][0][0][0]
assert (
on_virtual_update_call.sql(dialect="postgres")
== 'CREATE INDEX IF NOT EXISTS "idx" ON "db"."test_model" /* db.test_model */("id")'
)


def test_evaluate_incremental_by_partition(mocker: MockerFixture, make_snapshot, adapter_mock):
model = SqlModel(
name="test_schema.test_model",
Expand Down

0 comments on commit 7d5d9e9

Please sign in to comment.