|
1 | 1 | """Necessary, otherwise uncategorised backend functionality.""" |
2 | 2 |
|
| 3 | +from datetime import date, datetime |
| 4 | +from decimal import Decimal |
3 | 5 | import sys |
4 | 6 | from typing import Type |
5 | 7 |
|
| 8 | +from dataclasses import is_dataclass |
6 | 9 | from pydantic import BaseModel, create_model |
7 | 10 |
|
8 | 11 | from dve.core_engine.type_hints import Messages |
| 12 | +from dve.core_engine.backends.base.utilities import _get_non_heterogenous_type |
| 13 | + |
| 14 | +import polars as pl # type: ignore |
| 15 | +from polars.datatypes.classes import DataTypeClass as PolarsType |
| 16 | +from typing import Any, ClassVar, Dict, Set, Union |
9 | 17 |
|
10 | 18 | # We need to rely on a Python typing implementation detail in Python <= 3.7. |
11 | 19 | if sys.version_info[:2] <= (3, 7): |
12 | 20 | # Crimes against typing. |
13 | 21 | from typing import _GenericAlias # type: ignore |
14 | 22 |
|
15 | | - from typing_extensions import get_args, get_origin |
| 23 | + from typing_extensions import Annotated, get_args, get_origin, get_type_hints |
16 | 24 | else: |
17 | | - from typing import get_args, get_origin |
| 25 | + from typing import Annotated, get_args, get_origin, get_type_hints |
| 26 | + |
| 27 | +PYTHON_TYPE_TO_POLARS_TYPE: Dict[type, PolarsType] = { |
| 28 | + # issue with decimal conversion at the moment... |
| 29 | + str: pl.Utf8, # type: ignore |
| 30 | + int: pl.Int64, # type: ignore |
| 31 | + bool: pl.Boolean, # type: ignore |
| 32 | + float: pl.Float64, # type: ignore |
| 33 | + bytes: pl.Binary, # type: ignore |
| 34 | + date: pl.Date, # type: ignore |
| 35 | + datetime: pl.Datetime, # type: ignore |
| 36 | + Decimal: pl.Utf8, # type: ignore |
| 37 | +} |
| 38 | +"""A mapping of Python types to the equivalent Polars types.""" |
18 | 39 |
|
19 | 40 |
|
20 | 41 | def stringify_type(type_: type) -> type: |
@@ -61,3 +82,94 @@ def dedup_messages(messages: Messages) -> Messages: |
61 | 82 |
|
62 | 83 | """ |
63 | 84 | return list(dict.fromkeys(messages)) |
| 85 | + |
| 86 | +def get_polars_type_from_annotation(type_annotation: Any) -> PolarsType: |
| 87 | + """Get a polars type from a Python type annotation. |
| 88 | +
|
| 89 | + Supported types are any of the following (this definition is recursive): |
| 90 | + - Supported basic Python types. These are: |
| 91 | + * `str`: pl.Utf8 |
| 92 | + * `int`: pl.Int64 |
| 93 | + * `bool`: pl.Boolean |
| 94 | + * `float`: pl.Float64 |
| 95 | + * `bytes`: pl.Binary |
| 96 | + * `datetime.date`: pl.Date |
| 97 | + * `datetime.datetime`: pl.Datetime |
| 98 | + * `decimal.Decimal`: pl.Decimal with precision of 38 and scale of 18 |
| 99 | + - A list of supported types (e.g. `List[str]` or `typing.List[str]`). |
| 100 | + This will return a pl.List type (variable length) |
| 101 | + - A `typing.Optional` type or a `typing.Union` of the type and `None` (e.g. |
| 102 | + `typing.Optional[str]`, `typing.Union[List[str], None]`). This will remove the |
| 103 | + 'optional' wrapper and return the inner type |
| 104 | + - A subclass of `typing.TypedDict` with values typed using supported types. This |
| 105 | + will parse the value types as Polars types and return a Polars Struct. |
| 106 | + - A dataclass or `pydantic.main.ModelMetaClass` with values typed using supported types. |
| 107 | + This will parse the field types as Polars types and return a Polars Struct. |
| 108 | + - Any supported type, with a `typing_extensions.Annotated` wrapper. |
| 109 | + - A `decimal.Decimal` wrapped with `typing_extensions.Annotated` with a `DecimalConfig` |
| 110 | + indicating precision and scale. This will return a Polars Decimal |
| 111 | + with the specfied scale and precision. |
| 112 | + - A `pydantic.types.condecimal` created type. |
| 113 | +
|
| 114 | + Any `ClassVar` types within `TypedDict`s, dataclasses, or `pydantic` models will be |
| 115 | + ignored. |
| 116 | +
|
| 117 | + """ |
| 118 | + type_origin = get_origin(type_annotation) |
| 119 | + |
| 120 | + # An `Optional` or `Union` type, check to ensure non-heterogenity. |
| 121 | + if type_origin is Union: |
| 122 | + python_type = _get_non_heterogenous_type(get_args(type_annotation)) |
| 123 | + return get_polars_type_from_annotation(python_type) |
| 124 | + |
| 125 | + # Type hint is e.g. `List[str]`, check to ensure non-heterogenity. |
| 126 | + if type_origin is list or (isinstance(type_origin, type) and issubclass(type_origin, list)): |
| 127 | + element_type = _get_non_heterogenous_type(get_args(type_annotation)) |
| 128 | + return pl.List(get_polars_type_from_annotation(element_type)) # type: ignore |
| 129 | + |
| 130 | + if type_origin is Annotated: |
| 131 | + python_type, *other_args = get_args(type_annotation) # pylint: disable=unused-variable |
| 132 | + return get_polars_type_from_annotation(python_type) |
| 133 | + # Ensure that we have a concrete type at this point. |
| 134 | + if not isinstance(type_annotation, type): |
| 135 | + raise ValueError(f"Unsupported type annotation {type_annotation!r}") |
| 136 | + |
| 137 | + if ( |
| 138 | + # Type hint is a dict subclass, but not dict. Possibly a `TypedDict`. |
| 139 | + (issubclass(type_annotation, dict) and type_annotation is not dict) |
| 140 | + # Type hint is a dataclass. |
| 141 | + or is_dataclass(type_annotation) |
| 142 | + # Type hint is a `pydantic` model. |
| 143 | + or (type_origin is None and issubclass(type_annotation, BaseModel)) |
| 144 | + ): |
| 145 | + fields: Dict[str, PolarsType] = {} |
| 146 | + for field_name, field_annotation in get_type_hints(type_annotation).items(): |
| 147 | + # Technically non-string keys are disallowed, but people are bad. |
| 148 | + if not isinstance(field_name, str): |
| 149 | + raise ValueError( |
| 150 | + f"Dictionary/Dataclass keys must be strings, got {type_annotation!r}" |
| 151 | + ) # pragma: no cover |
| 152 | + if get_origin(field_annotation) is ClassVar: |
| 153 | + continue |
| 154 | + |
| 155 | + fields[field_name] = get_polars_type_from_annotation(field_annotation) |
| 156 | + |
| 157 | + if not fields: |
| 158 | + raise ValueError( |
| 159 | + f"No type annotations in dict/dataclass type (got {type_annotation!r})" |
| 160 | + ) |
| 161 | + |
| 162 | + return pl.Struct(fields) # type: ignore |
| 163 | + |
| 164 | + if type_annotation is list: |
| 165 | + raise ValueError( |
| 166 | + f"List must have type annotation (e.g. `List[str]`), got {type_annotation!r}" |
| 167 | + ) |
| 168 | + if type_annotation is dict or type_origin is dict: |
| 169 | + raise ValueError(f"Dict must be `typing.TypedDict` subclass, got {type_annotation!r}") |
| 170 | + |
| 171 | + for type_ in type_annotation.mro(): |
| 172 | + polars_type = PYTHON_TYPE_TO_POLARS_TYPE.get(type_) |
| 173 | + if polars_type: |
| 174 | + return polars_type |
| 175 | + raise ValueError(f"No equivalent DuckDB type for {type_annotation!r}") |
0 commit comments