git.openstreetmap.org Git - nominatim.git/blob - nominatim/db/sqlalchemy_types.py
make reverse API work with sqlite
[nominatim.git] / nominatim / db / sqlalchemy_types.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Custom types for SQLAlchemy.
9 """
10 from __future__ import annotations
11 from typing import Callable, Any, cast
12 import sys
13
14 import sqlalchemy as sa
15 from sqlalchemy.ext.compiler import compiles
16 from sqlalchemy import types
17
18 from nominatim.typing import SaColumn, SaBind
19
20 #pylint: disable=all
21
class Geometry_DistanceSpheroid(sa.sql.expression.FunctionElement[float]):
    """ Distance in meters between two geometries.

        Rendered as ST_DistanceSpheroid() with an explicit WGS 84 spheroid
        for PostgreSQL and as Distance(..., true) for SQLite.
    """
    name = 'Geometry_DistanceSpheroid'
    type = sa.Float()
    inherit_cache = True
28
29
@compiles(Geometry_DistanceSpheroid) # type: ignore[no-untyped-call, misc]
def _default_distance_spheroid(element: SaColumn,
                               compiler: 'sa.Compiled', **kw: Any) -> str:
    """ Default (PostGIS) rendering: ST_DistanceSpheroid() with the
        WGS 84 spheroid given inline as a literal.
    """
    geom_args = compiler.process(element.clauses, **kw)
    wgs84 = 'SPHEROID["WGS 84",6378137,298.257223563, AUTHORITY["EPSG","7030"]]'
    return f"ST_DistanceSpheroid({geom_args}, '{wgs84}')"
36
37
@compiles(Geometry_DistanceSpheroid, 'sqlite') # type: ignore[no-untyped-call, misc]
def _spatialite_distance_spheroid(element: SaColumn,
                                  compiler: 'sa.Compiled', **kw: Any) -> str:
    """ SpatiaLite rendering: Distance() with the third argument set to
        true (SpatiaLite's use-ellipsoid flag).
    """
    geom_args = compiler.process(element.clauses, **kw)
    return f"Distance({geom_args}, true)"
42
43
class Geometry_IsLineLike(sa.sql.expression.FunctionElement[bool]):
    """ Boolean test: is the geometry a LineString or a MultiLineString?
    """
    name = 'Geometry_IsLineLike'
    type = sa.Boolean()
    inherit_cache = True
50
51
@compiles(Geometry_IsLineLike) # type: ignore[no-untyped-call, misc]
def _default_is_line_like(element: SaColumn,
                          compiler: 'sa.Compiled', **kw: Any) -> str:
    """ Default (PostGIS) rendering, comparing against the 'ST_'-prefixed
        type names.
    """
    geom_sql = compiler.process(element.clauses, **kw)
    return f"ST_GeometryType({geom_sql}) IN ('ST_LineString', 'ST_MultiLineString')"
57
58
@compiles(Geometry_IsLineLike, 'sqlite') # type: ignore[no-untyped-call, misc]
def _sqlite_is_line_like(element: SaColumn,
                         compiler: 'sa.Compiled', **kw: Any) -> str:
    """ SQLite rendering, comparing against upper-case type names.
    """
    geom_sql = compiler.process(element.clauses, **kw)
    return f"ST_GeometryType({geom_sql}) IN ('LINESTRING', 'MULTILINESTRING')"
64
65
class Geometry_IsAreaLike(sa.sql.expression.FunctionElement[bool]):
    """ Check if the geometry is a polygon or multipolygon.
    """
    type = sa.Boolean()
    # Like the other function elements in this module, the name equals the
    # class name (it previously duplicated 'Geometry_IsLineLike').
    name = 'Geometry_IsAreaLike'
    inherit_cache = True
72
73
@compiles(Geometry_IsAreaLike) # type: ignore[no-untyped-call, misc]
def _default_is_area_like(element: SaColumn,
                          compiler: 'sa.Compiled', **kw: Any) -> str:
    """ Default (PostGIS) rendering, comparing against the 'ST_'-prefixed
        polygon type names.
    """
    geom_sql = compiler.process(element.clauses, **kw)
    return f"ST_GeometryType({geom_sql}) IN ('ST_Polygon', 'ST_MultiPolygon')"
79
80
@compiles(Geometry_IsAreaLike, 'sqlite') # type: ignore[no-untyped-call, misc]
def _sqlite_is_area_like(element: SaColumn,
                         compiler: 'sa.Compiled', **kw: Any) -> str:
    """ SQLite rendering, comparing against upper-case polygon type names.
    """
    geom_sql = compiler.process(element.clauses, **kw)
    return f"ST_GeometryType({geom_sql}) IN ('POLYGON', 'MULTIPOLYGON')"
86
87
class Geometry_IntersectsBbox(sa.sql.expression.FunctionElement[bool]):
    """ Boolean test on two geometries: do their bounding boxes overlap?
    """
    name = 'Geometry_IntersectsBbox'
    type = sa.Boolean()
    inherit_cache = True
94
95
@compiles(Geometry_IntersectsBbox) # type: ignore[no-untyped-call, misc]
def _default_intersects(element: SaColumn,
                        compiler: 'sa.Compiled', **kw: Any) -> str:
    """ Default (PostGIS) rendering: the `&&` operator between exactly
        two geometry arguments.
    """
    left, right = element.clauses
    left_sql = compiler.process(left, **kw)
    right_sql = compiler.process(right, **kw)
    return f"{left_sql} && {right_sql}"
101
102
@compiles(Geometry_IntersectsBbox, 'sqlite') # type: ignore[no-untyped-call, misc]
def _sqlite_intersects(element: SaColumn,
                       compiler: 'sa.Compiled', **kw: Any) -> str:
    """ SQLite rendering of the bounding-box intersection test.

        SpatiaLite's MbrIntersects() returns an integer: 1 for true,
        0 for false and -1 when it cannot decide (e.g. NULL input).
        A bare -1 would be truthy in SQLite. The result is therefore
        compared against 1, matching the SQLite ST_DWithin rendering in
        this module and PostgreSQL, where `NULL && x` does not match.
    """
    return "MbrIntersects(%s) = 1" % compiler.process(element.clauses, **kw)
107
108
class Geometry(types.UserDefinedType): # type: ignore[type-arg]
    """ Simplified type decorator for PostGIS geometry. This type
        only supports geometries in 4326 projection.

        Python values are bound as WKT text and converted with
        ST_GeomFromText(); selected columns are returned through
        ST_AsEWKB().
    """
    cache_ok = True

    def __init__(self, subtype: str = 'Geometry'):
        # Geometry subtype used in the column definition, e.g. 'Point'
        # gives GEOMETRY(Point, 4326).
        self.subtype = subtype


    def get_col_spec(self) -> str:
        """ Return the column type definition, always with SRID 4326.
        """
        return f'GEOMETRY({self.subtype}, 4326)'


    def bind_processor(self, dialect: 'sa.Dialect') -> Callable[[Any], str]:
        """ Return a converter that turns bound values into WKT text.
            Strings are passed through unchanged; any other value must
            provide a `to_wkt()` method.
        """
        def process(value: Any) -> str:
            if isinstance(value, str):
                return value

            return cast(str, value.to_wkt())
        return process


    def result_processor(self, dialect: 'sa.Dialect', coltype: object) -> Callable[[Any], str]:
        """ Return a converter that passes string results through unchanged.
        """
        def process(value: Any) -> str:
            # NOTE(review): a NULL result (None) fails this assertion unless
            # Python runs with -O; confirm geometry results are never NULL.
            assert isinstance(value, str)
            return value
        return process


    def column_expression(self, col: SaColumn) -> SaColumn:
        """ Wrap selected geometry columns in ST_AsEWKB().
        """
        return sa.func.ST_AsEWKB(col)


    def bind_expression(self, bindvalue: SaBind) -> SaColumn:
        """ Build the geometry from the bound WKT text with SRID 4326.
        """
        return sa.func.ST_GeomFromText(bindvalue, sa.text('4326'), type_=self)


    class comparator_factory(types.UserDefinedType.Comparator): # type: ignore[type-arg]
        """ Geometry-specific operations available on expressions of this type.
        """

        def intersects(self, other: SaColumn) -> 'sa.Operators':
            """ Test whether the bounding boxes of both geometries intersect.
            """
            return Geometry_IntersectsBbox(self, other)


        def is_line_like(self) -> SaColumn:
            """ Test whether the geometry is a line or multiline.
            """
            return Geometry_IsLineLike(self)


        def is_area(self) -> SaColumn:
            """ Test whether the geometry is a polygon or multipolygon.
            """
            return Geometry_IsAreaLike(self)


        def ST_DWithin(self, other: SaColumn, distance: SaColumn) -> SaColumn:
            """ Test whether `other` lies within `distance` of this geometry.
            """
            return sa.func.ST_DWithin(self, other, distance, type_=sa.Boolean)


        def ST_DWithin_no_index(self, other: SaColumn, distance: SaColumn) -> SaColumn:
            """ Same as `ST_DWithin`, but this geometry is wrapped in
                coalesce(NULL, ...), presumably so the planner does not use an
                index on the column (as the method name suggests).
            """
            return sa.func.ST_DWithin(sa.func.coalesce(sa.null(), self),
                                      other, distance, type_=sa.Boolean)


        def ST_Intersects_no_index(self, other: SaColumn) -> 'sa.Operators':
            """ Same as `intersects`, but this geometry is wrapped in
                coalesce(NULL, ...), presumably so the planner does not use an
                index on the column (as the method name suggests).

                This goes through Geometry_IntersectsBbox instead of a raw
                `&&` operator, so it is also rendered correctly for SQLite
                and the result is typed as Boolean.
            """
            return Geometry_IntersectsBbox(sa.func.coalesce(sa.null(), self), other)


        def ST_Distance(self, other: SaColumn) -> SaColumn:
            """ Distance to `other` as computed by ST_Distance().
            """
            return sa.func.ST_Distance(self, other, type_=sa.Float)


        def ST_Contains(self, other: SaColumn) -> SaColumn:
            """ Test whether this geometry contains `other`.
            """
            return sa.func.ST_Contains(self, other, type_=sa.Boolean)


        def ST_CoveredBy(self, other: SaColumn) -> SaColumn:
            """ Test whether this geometry is covered by `other`.
            """
            return sa.func.ST_CoveredBy(self, other, type_=sa.Boolean)


        def ST_ClosestPoint(self, other: SaColumn) -> SaColumn:
            """ Closest point on this geometry to `other`. If
                ST_ClosestPoint() yields NULL, `other` is returned instead.
            """
            return sa.func.coalesce(sa.func.ST_ClosestPoint(self, other, type_=Geometry),
                                    other)


        def ST_Buffer(self, other: SaColumn) -> SaColumn:
            """ Buffer this geometry by `other` (the buffer distance).
            """
            return sa.func.ST_Buffer(self, other, type_=Geometry)


        def ST_Expand(self, other: SaColumn) -> SaColumn:
            """ Expand this geometry by `other` using ST_Expand().
            """
            return sa.func.ST_Expand(self, other, type_=Geometry)


        def ST_Collect(self) -> SaColumn:
            """ Collect the geometries with ST_Collect().
            """
            return sa.func.ST_Collect(self, type_=Geometry)


        def ST_Centroid(self) -> SaColumn:
            """ Centroid of this geometry.
            """
            return sa.func.ST_Centroid(self, type_=Geometry)


        def ST_LineInterpolatePoint(self, other: SaColumn) -> SaColumn:
            """ Point located at fraction `other` along this line.
            """
            return sa.func.ST_LineInterpolatePoint(self, other, type_=Geometry)


        def ST_LineLocatePoint(self, other: SaColumn) -> SaColumn:
            """ Fraction along this line of the point closest to `other`.
            """
            return sa.func.ST_LineLocatePoint(self, other, type_=sa.Float)


        def distance_spheroid(self, other: SaColumn) -> SaColumn:
            """ Distance to `other` in meters (see Geometry_DistanceSpheroid).
            """
            return Geometry_DistanceSpheroid(self, other)
217
218
@compiles(Geometry, 'sqlite') # type: ignore[no-untyped-call, misc]
def get_col_spec(self: Any, *args: Any, **kwargs: Any) -> str:
    """ Column type used for the Geometry type when compiling for SQLite.

        Unlike the PostgreSQL definition, neither the subtype nor the
        SRID is included.
    """
    plain_spec = 'GEOMETRY'
    return plain_spec
222
223
# Functions that are rendered under a different name on SQLite.
# Each entry: (generic function name, SQLAlchemy return type, SQLite name).
SQLITE_FUNCTION_ALIAS = (
    ('ST_AsEWKB', sa.Text, 'AsEWKB'),
    ('ST_GeomFromEWKT', Geometry, 'GeomFromEWKT'),
    ('ST_AsGeoJSON', sa.Text, 'AsGeoJSON'),
    ('ST_AsKML', sa.Text, 'AsKML'),
    ('ST_AsSVG', sa.Text, 'AsSVG'),
    ('ST_LineLocatePoint', sa.Float, 'ST_Line_Locate_Point'),
    # Interpolating along a line yields a point geometry, not a number.
    # The type is registered for every dialect, so it must be Geometry.
    ('ST_LineInterpolatePoint', Geometry, 'ST_Line_Interpolate_Point'),
)
233
def _add_function_alias(func: str, ftype: type, alias: str) -> None:
    """ Register `func` as a generic SQLAlchemy function returning `ftype`
        and make it render as `alias(...)` when compiled for SQLite.
    """
    class_attrs = {
        "type": ftype(),
        "name": func,
        "identifier": func,
        "inherit_cache": True,
    }
    generic_cls = type(func, (sa.sql.functions.GenericFunction, ), class_attrs)

    sqlite_template = alias + "(%s)"

    def _render_for_sqlite(element: Any, compiler: Any, **kw: Any) -> Any:
        rendered_args = compiler.process(element.clauses, **kw)
        return sqlite_template % rendered_args

    compiles(generic_cls, 'sqlite')(_render_for_sqlite) # type: ignore[no-untyped-call]
247
# Register every (name, type, SQLite name) entry of SQLITE_FUNCTION_ALIAS.
for alias in SQLITE_FUNCTION_ALIAS:
    _add_function_alias(func=alias[0], ftype=alias[1], alias=alias[2])
250
251
class ST_DWithin(sa.sql.functions.GenericFunction[bool]):
    """ Generic ST_DWithin(geom1, geom2, distance) function with a boolean
        result. The SQLite rendering is registered separately in this module.
    """
    name = 'ST_DWithin'
    type = sa.Boolean()
    inherit_cache = True
256
257
@compiles(ST_DWithin, 'sqlite') # type: ignore[no-untyped-call, misc]
def _sqlite_dwithin(element: SaColumn, compiler: 'sa.Compiled', **kw: Any) -> str:
    """ SQLite rendering of ST_DWithin(geom1, geom2, dist).

        A bounding-box prefilter (geom1's MBR against geom2 expanded by dist)
        is combined with the exact ST_Distance() comparison.
    """
    geom1, geom2, dist = list(element.clauses)
    # Each argument is compiled separately for every place it appears in the SQL.
    return "(MbrIntersects(%s, ST_Expand(%s, %s)) = 1 AND ST_Distance(%s, %s) <= %s)" % (
        compiler.process(geom1, **kw), compiler.process(geom2, **kw),
        compiler.process(dist, **kw),
        compiler.process(geom1, **kw), compiler.process(geom2, **kw),
        compiler.process(dist, **kw))


# Backward-compatible alias for the former, misleading name of this function.
default_json_array_each = _sqlite_dwithin