]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/db/sqlalchemy_types/int_array.py
499376cb85ca59d44119f2bcb4b4e17eeedd2f3f
[nominatim.git] / nominatim / db / sqlalchemy_types / int_array.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Custom type for an array of integers.
9 """
10 from typing import Any, List, cast, Optional
11
12 import sqlalchemy as sa
13 from sqlalchemy.ext.compiler import compiles
14 from sqlalchemy.dialects.postgresql import ARRAY
15
16 from nominatim.typing import SaDialect, SaColumn
17
18 # pylint: disable=all
19
20 class IntList(sa.types.TypeDecorator[Any]):
21     """ A list of integers saved as a text of comma-separated numbers.
22     """
23     impl = sa.types.Unicode
24     cache_ok = True
25
26     def process_bind_param(self, value: Optional[Any], dialect: 'sa.Dialect') -> Optional[str]:
27         if value is None:
28             return None
29
30         assert isinstance(value, list)
31         return ','.join(map(str, value))
32
33     def process_result_value(self, value: Optional[Any],
34                              dialect: SaDialect) -> Optional[List[int]]:
35         return [int(v) for v in value.split(',')] if value is not None else None
36
37     def copy(self, **kw: Any) -> 'IntList':
38         return IntList(self.impl.length)
39
40
41 class IntArray(sa.types.TypeDecorator[Any]):
42     """ Dialect-independent list of integers.
43     """
44     impl = IntList
45     cache_ok = True
46
47     def load_dialect_impl(self, dialect: SaDialect) -> sa.types.TypeEngine[Any]:
48         if dialect.name == 'postgresql':
49             return ARRAY(sa.Integer()) #pylint: disable=invalid-name
50
51         return IntList()
52
53
54     class comparator_factory(sa.types.UserDefinedType.Comparator): # type: ignore[type-arg]
55
56         def __add__(self, other: SaColumn) -> 'sa.ColumnOperators':
57             """ Concate the array with the given array. If one of the
58                 operants is null, the value of the other will be returned.
59             """
60             return sa.func.array_cat(self, other, type_=IntArray)
61
62
63         def contains(self, other: SaColumn, **kwargs: Any) -> 'sa.ColumnOperators':
64             """ Return true if the array contains all the value of the argument
65                 array.
66             """
67             return cast('sa.ColumnOperators', self.op('@>', is_comparison=True)(other))
68
69
70         def overlaps(self, other: SaColumn) -> 'sa.Operators':
71             """ Return true if at least one value of the argument is contained
72                 in the array.
73             """
74             return self.op('&&', is_comparison=True)(other)
75
76
77 class ArrayAgg(sa.sql.functions.GenericFunction[Any]):
78     """ Aggregate function to collect elements in an array.
79     """
80     type = IntArray()
81     identifier = 'ArrayAgg'
82     name = 'array_agg'
83     inherit_cache = True
84
85 @compiles(ArrayAgg, 'sqlite') # type: ignore[no-untyped-call, misc]
86 def sqlite_array_agg(element: ArrayAgg, compiler: 'sa.Compiled', **kw: Any) -> str:
87     return "group_concat(%s, ',')" % compiler.process(element.clauses, **kw)