1 # SPDX-License-Identifier: GPL-3.0-or-later
3 # This file is part of Nominatim. (https://nominatim.org)
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
"""
Custom type for an array of integers.
"""
10 from typing import Any, List, cast, Optional
12 import sqlalchemy as sa
13 from sqlalchemy.ext.compiler import compiles
14 from sqlalchemy.dialects.postgresql import ARRAY
16 from nominatim.typing import SaDialect, SaColumn
class IntList(sa.types.TypeDecorator[Any]):
    """ A list of integers saved as a text of comma-separated numbers.

        None is passed through unchanged in both directions. An empty
        list is saved as an empty string and read back as an empty list.
    """
    impl = sa.types.Unicode
    # The type has no state besides the arguments of the underlying
    # Unicode type, so statements using it may be cached.
    cache_ok = True

    def process_bind_param(self, value: Optional[Any], dialect: SaDialect) -> Optional[str]:
        """ Convert a list of integers into its comma-separated text form.
        """
        if value is None:
            return None

        assert isinstance(value, list)
        return ','.join(map(str, value))

    def process_result_value(self, value: Optional[Any],
                             dialect: SaDialect) -> Optional[List[int]]:
        """ Convert the comma-separated text read from the database back
            into a list of integers.
        """
        if value is None:
            return None

        # An empty list is saved as an empty string. Splitting that would
        # yield [''] and int('') raises a ValueError.
        return [int(v) for v in value.split(',')] if value else []

    def copy(self, **kw: Any) -> 'IntList':
        """ Return a new IntList that uses the same length for the
            underlying text type.
        """
        return IntList(self.impl.length)
class IntArray(sa.types.TypeDecorator[Any]):
    """ Dialect-independent list of integers.

        PostgreSQL uses a native integer array. All other dialects fall
        back to IntList, a text of comma-separated numbers.
    """
    # TypeDecorator requires a class-level 'impl'. It is also the type
    # used for every dialect without a special case below.
    impl = IntList
    cache_ok = True

    def load_dialect_impl(self, dialect: SaDialect) -> sa.types.TypeEngine[Any]:
        """ Return the type that implements the integer list for the
            given dialect.
        """
        if dialect.name == 'postgresql':
            return ARRAY(sa.Integer())

        return IntList()


    # SQLAlchemy looks up the comparator class under this exact attribute
    # name, so the class cannot follow the usual naming convention.
    # pylint: disable=invalid-name
    class comparator_factory(sa.types.UserDefinedType.Comparator): # type: ignore[type-arg]
        """ Additional SQL operators for columns of type IntArray.
        """

        def __add__(self, other: SaColumn) -> 'sa.ColumnOperators':
            """ Concatenate the array with the given array. If one of the
                operands is null, the value of the other will be returned.
            """
            return sa.func.array_cat(self, other, type_=IntArray)


        def contains(self, other: SaColumn, **kwargs: Any) -> 'sa.ColumnOperators':
            """ Return true if the array contains all the values of the
                argument array.
            """
            return cast('sa.ColumnOperators', self.op('@>', is_comparison=True)(other))


        def overlaps(self, other: SaColumn) -> 'sa.Operators':
            """ Return true if at least one value of the argument is contained
                in the array.
            """
            return self.op('&&', is_comparison=True)(other)
class ArrayAgg(sa.sql.functions.GenericFunction[Any]):
    """ Aggregate function to collect elements in an array.

        The function is registered under the identifier 'ArrayAgg' and
        rendered as 'array_agg' in SQL unless a dialect-specific
        compilation is defined for it.
    """
    # Results are converted like any other IntArray value.
    type = IntArray()
    identifier = 'ArrayAgg'
    # Without an explicit name the class name would be used as the SQL
    # function name.
    name = 'array_agg'
    inherit_cache = True
@compiles(ArrayAgg, 'sqlite') # type: ignore[no-untyped-call, misc]
def sqlite_array_agg(element: ArrayAgg, compiler: 'sa.Compiled', **kw: Any) -> str:
    """ Render ArrayAgg for SQLite as a group_concat() call which joins
        the aggregated values with commas.
    """
    arguments = compiler.process(element.clauses, **kw)
    return f"group_concat({arguments}, ',')"