|
| 1 | +# Copyright 2025 Collate |
| 2 | +# Licensed under the Collate Community License, Version 1.0 (the "License"); |
| 3 | +# you may not use this file except in compliance with the License. |
| 4 | +# You may obtain a copy of the License at |
| 5 | +# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE |
| 6 | +# Unless required by applicable law or agreed to in writing, software |
| 7 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 8 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 9 | +# See the License for the specific language governing permissions and |
| 10 | +# limitations under the License. |
| 11 | + |
| 12 | +""" |
| 13 | +Athena profiler interface with struct column support. |
| 14 | +
|
| 15 | +Athena (Presto SQL) accesses struct fields via dot notation: |
| 16 | + SELECT "my_struct"."my_field" FROM table |
| 17 | +
|
| 18 | +This interface flattens STRUCT columns into their leaf fields |
| 19 | +so they can be profiled individually, and patches the Athena compiler |
| 20 | +to quote each dot-separated segment individually. |
| 21 | +""" |
| 22 | +from typing import List, Optional |
| 23 | + |
| 24 | +from pyathena.sqlalchemy.compiler import AthenaStatementCompiler |
| 25 | +from sqlalchemy import Column |
| 26 | +from sqlalchemy.sql.compiler import SQLCompiler |
| 27 | + |
| 28 | +from metadata.generated.schema.entity.data.table import Column as OMColumn |
| 29 | +from metadata.generated.schema.entity.data.table import ColumnName, DataType |
| 30 | +from metadata.generated.schema.entity.services.databaseService import ( |
| 31 | + DatabaseServiceType, |
| 32 | +) |
| 33 | +from metadata.profiler.interface.sqlalchemy.profiler_interface import ( |
| 34 | + SQAProfilerInterface, |
| 35 | +) |
| 36 | +from metadata.profiler.orm.converter.base import build_orm_col |
| 37 | + |
| 38 | + |
def _visit_column_with_struct_quoting(self, column, *args, **kwargs):
    """Compile a column reference, quoting each dot-separated segment.

    Delegates to ``SQLCompiler.visit_column`` and, when the column name
    contains a dot (e.g. "address.street"), rewrites the rendered name so
    that every segment is quoted individually:
        "address"."street"
    instead of a single identifier ("address.street").

    Quoting each segment keeps reserved words in struct field names valid.

    Args:
        self: the statement compiler this function is attached to; its
            ``preparer`` is used to quote identifiers.
        column: the column element being compiled.
        *args: forwarded unchanged to ``SQLCompiler.visit_column``.
        **kwargs: forwarded unchanged to ``SQLCompiler.visit_column``.

    Returns:
        The compiled SQL text for the column reference. Names without a
        dot, or names that cannot be located in the base output, are
        returned exactly as the base compiler rendered them.
    """
    result = SQLCompiler.visit_column(self, column, *args, **kwargs)
    col_name = str(getattr(column, "name", ""))
    if "." not in col_name:
        return result

    quote = self.preparer.quote_identifier
    quoted = ".".join(quote(segment) for segment in col_name.split("."))

    # The base compiler is expected to render the column name at the end of
    # the string, after any schema/table qualifier, so anchor on the suffix
    # rather than the first occurrence (which could sit inside the qualifier).
    # The fully quoted form is tried first so that its surrounding quotes are
    # replaced too instead of leaving a dangling quote behind.
    for rendered in (quote(col_name), col_name):
        if result.endswith(rendered):
            return result[: len(result) - len(rendered)] + quoted
    return result
| 59 | + |
| 60 | + |
class AthenaProfilerInterface(SQAProfilerInterface):
    """Athena profiler interface with struct column flattening"""

    def __init__(self, service_connection_config, **kwargs):
        """Initialise the base interface and install the struct-aware compiler hook.

        Args:
            service_connection_config: connection configuration forwarded to
                ``SQAProfilerInterface``.
            **kwargs: remaining keyword arguments forwarded to the base class.
        """
        super().__init__(service_connection_config=service_connection_config, **kwargs)
        # Assigned on the compiler class itself, so it applies to every
        # AthenaStatementCompiler from here on, not just this interface.
        AthenaStatementCompiler.visit_column = _visit_column_with_struct_quoting

    def _get_struct_columns(
        self, columns: Optional[List[OMColumn]], parent: str
    ) -> List[Column]:
        """Recursively flatten struct children into leaf columns.

        Column names are set to plain dot notation (e.g. "address.street")
        for OM API profile matching. The compiler patch in __init__ handles
        quoting each segment at SQL generation time.

        The incoming column entities are left untouched: each leaf is renamed
        on a copy, so calling this repeatedly on the same table entity does
        not keep prepending the parent path.

        Args:
            columns: child columns of a STRUCT column (``None`` is treated as
                an empty list)
            parent: dot-separated parent path (e.g. "address" or "address.geo")

        Returns:
            list of SQLAlchemy Column objects for each leaf field
        """
        columns_list: List[Column] = []
        for col in columns or []:
            path = f"{parent}.{col.name.root}"
            if col.dataType == DataType.STRUCT:
                # Nested struct: descend with the extended path.
                columns_list.extend(self._get_struct_columns(col.children, path))
                continue

            # Rename a copy rather than the entity column itself.
            leaf = col.model_copy(update={"name": ColumnName(path)})
            sqa_col = build_orm_col(
                # NOTE(review): fixed non-zero index for every leaf; presumably
                # keeps leaves from being treated as the first column - confirm
                # against build_orm_col.
                idx=1,
                col=leaf,
                table_service_type=DatabaseServiceType.Athena,
                _quote=False,
            )
            sqa_col._set_parent(  # pylint: disable=protected-access
                self.table.__table__
            )
            columns_list.append(sqa_col)
        return columns_list

    def get_columns(self) -> List[Column]:
        """Get columns from table, flattening STRUCT columns into leaf fields.

        Returns:
            SQLAlchemy Column objects attached to ``self.table``: one per
            non-struct top-level column, plus one per leaf field of every
            STRUCT column.
        """
        columns: List[Column] = []
        for idx, column_obj in enumerate(self.table_entity.columns):
            if column_obj.dataType == DataType.STRUCT:
                columns.extend(
                    self._get_struct_columns(column_obj.children, column_obj.name.root)
                )
            else:
                col = build_orm_col(idx, column_obj, DatabaseServiceType.Athena)
                col._set_parent(  # pylint: disable=protected-access
                    self.table.__table__
                )
                columns.append(col)
        return columns
0 commit comments