Skip to content

Commit f54b995

Browse files
eakmanrqnewtonapple
authored andcommitted
feat: add snowflake grant support (#5433)
1 parent ced6cf8 commit f54b995

8 files changed

Lines changed: 564 additions & 171 deletions

File tree

sqlmesh/core/engine_adapter/_typing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,4 @@
3131

3232
QueryOrDF = t.Union[Query, DF]
3333
GrantsConfig = t.Dict[str, t.List[str]]
34+
DCL = t.TypeVar("DCL", exp.Grant, exp.Revoke)

sqlmesh/core/engine_adapter/postgres.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,7 @@
1818

1919
if t.TYPE_CHECKING:
2020
from sqlmesh.core._typing import TableName
21-
from sqlmesh.core.engine_adapter._typing import DF, GrantsConfig, QueryOrDF
22-
23-
DCL = t.TypeVar("DCL", exp.Grant, exp.Revoke)
21+
from sqlmesh.core.engine_adapter._typing import DCL, DF, GrantsConfig, QueryOrDF
2422

2523
logger = logging.getLogger(__name__)
2624

@@ -38,7 +36,7 @@ class PostgresEngineAdapter(
3836
HAS_VIEW_BINDING = True
3937
CURRENT_CATALOG_EXPRESSION = exp.column("current_catalog")
4038
SUPPORTS_REPLACE_TABLE = False
41-
MAX_IDENTIFIER_LENGTH = 63
39+
MAX_IDENTIFIER_LENGTH: t.Optional[int] = 63
4240
SUPPORTS_QUERY_EXECUTION_TRACKING = True
4341
SCHEMA_DIFFER_KWARGS = {
4442
"parameterized_type_defaults": {

sqlmesh/core/engine_adapter/risingwave.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class RisingwaveEngineAdapter(PostgresEngineAdapter):
3232
SUPPORTS_MATERIALIZED_VIEWS = True
3333
SUPPORTS_TRANSACTIONS = False
3434
MAX_IDENTIFIER_LENGTH = None
35+
SUPPORTS_GRANTS = False
3536

3637
def columns(
3738
self, table_name: TableName, include_pseudo_columns: bool = False

sqlmesh/core/engine_adapter/snowflake.py

Lines changed: 121 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,14 @@
3434
import pandas as pd
3535

3636
from sqlmesh.core._typing import SchemaName, SessionProperties, TableName
37-
from sqlmesh.core.engine_adapter._typing import DF, Query, QueryOrDF, SnowparkSession
37+
from sqlmesh.core.engine_adapter._typing import (
38+
DCL,
39+
DF,
40+
GrantsConfig,
41+
Query,
42+
QueryOrDF,
43+
SnowparkSession,
44+
)
3845
from sqlmesh.core.node import IntervalUnit
3946

4047

@@ -74,6 +81,7 @@ class SnowflakeEngineAdapter(GetCurrentCatalogFromFunctionMixin, ClusteredByMixi
7481
MANAGED_TABLE_KIND = "DYNAMIC TABLE"
7582
SNOWPARK = "snowpark"
7683
SUPPORTS_QUERY_EXECUTION_TRACKING = True
84+
SUPPORTS_GRANTS = True
7785

7886
@contextlib.contextmanager
7987
def session(self, properties: SessionProperties) -> t.Iterator[None]:
@@ -128,6 +136,118 @@ def snowpark(self) -> t.Optional[SnowparkSession]:
128136
def catalog_support(self) -> CatalogSupport:
129137
return CatalogSupport.FULL_SUPPORT
130138

139+
@staticmethod
140+
def _grant_object_kind(table_type: DataObjectType) -> str:
141+
if table_type == DataObjectType.VIEW:
142+
return "VIEW"
143+
if table_type == DataObjectType.MATERIALIZED_VIEW:
144+
return "MATERIALIZED VIEW"
145+
if table_type == DataObjectType.MANAGED_TABLE:
146+
return "DYNAMIC TABLE"
147+
return "TABLE"
148+
149+
def _get_current_schema(self) -> str:
150+
"""Returns the current default schema for the connection."""
151+
result = self.fetchone("SELECT CURRENT_SCHEMA()")
152+
if not result or not result[0]:
153+
raise SQLMeshError("Unable to determine current schema")
154+
return str(result[0])
155+
156+
def _dcl_grants_config_expr(
157+
self,
158+
dcl_cmd: t.Type[DCL],
159+
table: exp.Table,
160+
grant_config: GrantsConfig,
161+
table_type: DataObjectType = DataObjectType.TABLE,
162+
) -> t.List[exp.Expression]:
163+
expressions: t.List[exp.Expression] = []
164+
if not grant_config:
165+
return expressions
166+
167+
object_kind = self._grant_object_kind(table_type)
168+
for privilege, principals in grant_config.items():
169+
for principal in principals:
170+
args: t.Dict[str, t.Any] = {
171+
"privileges": [exp.GrantPrivilege(this=exp.Var(this=privilege))],
172+
"securable": table.copy(),
173+
"principals": [principal],
174+
}
175+
176+
if object_kind:
177+
args["kind"] = exp.Var(this=object_kind)
178+
179+
expressions.append(dcl_cmd(**args)) # type: ignore[arg-type]
180+
181+
return expressions
182+
183+
def _apply_grants_config_expr(
184+
self,
185+
table: exp.Table,
186+
grant_config: GrantsConfig,
187+
table_type: DataObjectType = DataObjectType.TABLE,
188+
) -> t.List[exp.Expression]:
189+
return self._dcl_grants_config_expr(exp.Grant, table, grant_config, table_type)
190+
191+
def _revoke_grants_config_expr(
192+
self,
193+
table: exp.Table,
194+
grant_config: GrantsConfig,
195+
table_type: DataObjectType = DataObjectType.TABLE,
196+
) -> t.List[exp.Expression]:
197+
return self._dcl_grants_config_expr(exp.Revoke, table, grant_config, table_type)
198+
199+
def _get_current_grants_config(self, table: exp.Table) -> GrantsConfig:
200+
schema_identifier = table.args.get("db") or normalize_identifiers(
201+
exp.to_identifier(self._get_current_schema(), quoted=True), dialect=self.dialect
202+
)
203+
catalog_identifier = table.args.get("catalog")
204+
if not catalog_identifier:
205+
current_catalog = self.get_current_catalog()
206+
if not current_catalog:
207+
raise SQLMeshError("Unable to determine current catalog for fetching grants")
208+
catalog_identifier = normalize_identifiers(
209+
exp.to_identifier(current_catalog, quoted=True), dialect=self.dialect
210+
)
211+
catalog_identifier.set("quoted", True)
212+
table_identifier = table.args.get("this")
213+
214+
grant_expr = (
215+
exp.select("privilege_type", "grantee")
216+
.from_(
217+
exp.table_(
218+
"TABLE_PRIVILEGES",
219+
db="INFORMATION_SCHEMA",
220+
catalog=catalog_identifier,
221+
)
222+
)
223+
.where(
224+
exp.and_(
225+
exp.column("table_schema").eq(exp.Literal.string(schema_identifier.this)),
226+
exp.column("table_name").eq(exp.Literal.string(table_identifier.this)), # type: ignore
227+
exp.column("grantor").eq(exp.func("CURRENT_ROLE")),
228+
exp.column("grantee").neq(exp.func("CURRENT_ROLE")),
229+
)
230+
)
231+
)
232+
233+
results = self.fetchall(grant_expr)
234+
235+
grants_dict: GrantsConfig = {}
236+
for privilege_raw, grantee_raw in results:
237+
if privilege_raw is None or grantee_raw is None:
238+
continue
239+
240+
privilege = str(privilege_raw)
241+
grantee = str(grantee_raw)
242+
if not privilege or not grantee:
243+
continue
244+
245+
grantees = grants_dict.setdefault(privilege, [])
246+
if grantee not in grantees:
247+
grantees.append(grantee)
248+
249+
return grants_dict
250+
131251
def _create_catalog(self, catalog_name: exp.Identifier) -> None:
132252
props = exp.Properties(
133253
expressions=[exp.SchemaCommentProperty(this=exp.Literal.string(c.SQLMESH_MANAGED))]

tests/core/engine_adapter/integration/__init__.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
import sys
66
import typing as t
77
import time
8+
from contextlib import contextmanager
89

910
import pandas as pd # noqa: TID253
1011
import pytest
1112
from sqlglot import exp, parse_one
13+
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
1214

1315
from sqlmesh import Config, Context, EngineAdapter
1416
from sqlmesh.core.config import load_config_from_paths
@@ -744,6 +746,55 @@ def upsert_sql_model(self, model_definition: str) -> t.Tuple[Context, SqlModel]:
744746
self._context.upsert_model(model)
745747
return self._context, model
746748

749+
def _get_create_user_or_role(self, username: str, password: t.Optional[str] = None) -> str:
750+
password = password or random_id()
751+
if self.dialect == "postgres":
752+
return f"CREATE USER \"{username}\" WITH PASSWORD '{password}'"
753+
if self.dialect == "snowflake":
754+
return f"CREATE ROLE {username}"
755+
raise ValueError(f"User creation not supported for dialect: {self.dialect}")
756+
757+
def _create_user_or_role(self, username: str, password: t.Optional[str] = None) -> None:
758+
create_user_sql = self._get_create_user_or_role(username, password)
759+
self.engine_adapter.execute(create_user_sql)
760+
761+
@contextmanager
762+
def create_users_or_roles(self, *role_names: str) -> t.Iterator[t.Dict[str, str]]:
763+
created_users = []
764+
roles = {}
765+
766+
try:
767+
for role_name in role_names:
768+
user_name = normalize_identifiers(
769+
self.add_test_suffix(f"test_{role_name}"), dialect=self.dialect
770+
).sql(dialect=self.dialect)
771+
password = random_id()
772+
self._create_user_or_role(user_name, password)
773+
created_users.append(user_name)
774+
roles[role_name] = user_name
775+
776+
yield roles
777+
778+
finally:
779+
for user_name in created_users:
780+
self._cleanup_user_or_role(user_name)
781+
782+
def _cleanup_user_or_role(self, user_name: str) -> None:
783+
"""Helper function to clean up a PostgreSQL user and all their dependencies."""
784+
try:
785+
if self.dialect == "postgres":
786+
self.engine_adapter.execute(f"""
787+
SELECT pg_terminate_backend(pid)
788+
FROM pg_stat_activity
789+
WHERE usename = '{user_name}' AND pid <> pg_backend_pid()
790+
""")
791+
self.engine_adapter.execute(f'DROP OWNED BY "{user_name}"')
792+
self.engine_adapter.execute(f'DROP USER IF EXISTS "{user_name}"')
793+
elif self.dialect == "snowflake":
794+
self.engine_adapter.execute(f"DROP ROLE IF EXISTS {user_name}")
795+
except Exception:
796+
pass
797+
747798

748799
def wait_until(fn: t.Callable[..., bool], attempts=3, wait=5) -> None:
749800
current_attempt = 0

0 commit comments

Comments
 (0)