1 # SPDX-License-Identifier: GPL-3.0-or-later
 
   3 # This file is part of Nominatim. (https://nominatim.org)
 
   5 # Copyright (C) 2023 by the Nominatim developer community.
 
   6 # For a full list of authors see the git log.
 
   8 Custom types for SQLAlchemy.
 
  10 from __future__ import annotations
 
  11 from typing import Callable, Any, cast
 
  14 import sqlalchemy as sa
 
  15 from sqlalchemy.ext.compiler import compiles
 
  16 from sqlalchemy import types
 
  18 from nominatim.typing import SaColumn, SaBind
 
  22 class Geometry_DistanceSpheroid(sa.sql.expression.FunctionElement[float]):
 
  23     """ Function to compute the spherical distance in meters.
 
  26     name = 'Geometry_DistanceSpheroid'
 
  30 @compiles(Geometry_DistanceSpheroid) # type: ignore[no-untyped-call, misc]
 
  31 def _default_distance_spheroid(element: SaColumn,
 
  32                                compiler: 'sa.Compiled', **kw: Any) -> str:
 
  33     return "ST_DistanceSpheroid(%s,"\
 
  34            " 'SPHEROID[\"WGS 84\",6378137,298.257223563, AUTHORITY[\"EPSG\",\"7030\"]]')"\
 
  35              % compiler.process(element.clauses, **kw)
 
  38 @compiles(Geometry_DistanceSpheroid, 'sqlite') # type: ignore[no-untyped-call, misc]
 
  39 def _spatialite_distance_spheroid(element: SaColumn,
 
  40                                   compiler: 'sa.Compiled', **kw: Any) -> str:
 
  41     return "COALESCE(Distance(%s, true), 0.0)" % compiler.process(element.clauses, **kw)
 
  44 class Geometry_IsLineLike(sa.sql.expression.FunctionElement[Any]):
 
  45     """ Check if the geometry is a line or multiline.
 
  47     name = 'Geometry_IsLineLike'
 
  51 @compiles(Geometry_IsLineLike) # type: ignore[no-untyped-call, misc]
 
  52 def _default_is_line_like(element: SaColumn,
 
  53                           compiler: 'sa.Compiled', **kw: Any) -> str:
 
  54     return "ST_GeometryType(%s) IN ('ST_LineString', 'ST_MultiLineString')" % \
 
  55                compiler.process(element.clauses, **kw)
 
  58 @compiles(Geometry_IsLineLike, 'sqlite') # type: ignore[no-untyped-call, misc]
 
  59 def _sqlite_is_line_like(element: SaColumn,
 
  60                          compiler: 'sa.Compiled', **kw: Any) -> str:
 
  61     return "ST_GeometryType(%s) IN ('LINESTRING', 'MULTILINESTRING')" % \
 
  62                compiler.process(element.clauses, **kw)
 
  65 class Geometry_IsAreaLike(sa.sql.expression.FunctionElement[Any]):
 
  66     """ Check if the geometry is a polygon or multipolygon.
 
  68     name = 'Geometry_IsLineLike'
 
  72 @compiles(Geometry_IsAreaLike) # type: ignore[no-untyped-call, misc]
 
  73 def _default_is_area_like(element: SaColumn,
 
  74                           compiler: 'sa.Compiled', **kw: Any) -> str:
 
  75     return "ST_GeometryType(%s) IN ('ST_Polygon', 'ST_MultiPolygon')" % \
 
  76                compiler.process(element.clauses, **kw)
 
  79 @compiles(Geometry_IsAreaLike, 'sqlite') # type: ignore[no-untyped-call, misc]
 
  80 def _sqlite_is_area_like(element: SaColumn,
 
  81                          compiler: 'sa.Compiled', **kw: Any) -> str:
 
  82     return "ST_GeometryType(%s) IN ('POLYGON', 'MULTIPOLYGON')" % \
 
  83                compiler.process(element.clauses, **kw)
 
  86 class Geometry_IntersectsBbox(sa.sql.expression.FunctionElement[Any]):
 
  87     """ Check if the bounding boxes of the given geometries intersect.
 
  89     name = 'Geometry_IntersectsBbox'
 
  93 @compiles(Geometry_IntersectsBbox) # type: ignore[no-untyped-call, misc]
 
  94 def _default_intersects(element: SaColumn,
 
  95                         compiler: 'sa.Compiled', **kw: Any) -> str:
 
  96     arg1, arg2 = list(element.clauses)
 
  97     return "%s && %s" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
 
 100 @compiles(Geometry_IntersectsBbox, 'sqlite') # type: ignore[no-untyped-call, misc]
 
 101 def _sqlite_intersects(element: SaColumn,
 
 102                        compiler: 'sa.Compiled', **kw: Any) -> str:
 
 103     return "MbrIntersects(%s) = 1" % compiler.process(element.clauses, **kw)
 
 106 class Geometry_ColumnIntersectsBbox(sa.sql.expression.FunctionElement[Any]):
 
 107     """ Check if the bounding box of the geometry intersects with the
 
 108         given table column, using the spatial index for the column.
 
 110         The index must exist or the query may return nothing.
 
 112     name = 'Geometry_ColumnIntersectsBbox'
 
 116 @compiles(Geometry_ColumnIntersectsBbox) # type: ignore[no-untyped-call, misc]
 
 117 def default_intersects_column(element: SaColumn,
 
 118                               compiler: 'sa.Compiled', **kw: Any) -> str:
 
 119     arg1, arg2 = list(element.clauses)
 
 120     return "%s && %s" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
 
 123 @compiles(Geometry_ColumnIntersectsBbox, 'sqlite') # type: ignore[no-untyped-call, misc]
 
 124 def spatialite_intersects_column(element: SaColumn,
 
 125                                  compiler: 'sa.Compiled', **kw: Any) -> str:
 
 126     arg1, arg2 = list(element.clauses)
 
 127     return "MbrIntersects(%s, %s) = 1 and "\
 
 128            "%s.ROWID IN (SELECT ROWID FROM SpatialIndex "\
 
 129                         "WHERE f_table_name = '%s' AND f_geometry_column = '%s' "\
 
 130                         "AND search_frame = %s)" %(
 
 131               compiler.process(arg1, **kw),
 
 132               compiler.process(arg2, **kw),
 
 133               arg1.table.name, arg1.table.name, arg1.name,
 
 134               compiler.process(arg2, **kw))
 
 137 class Geometry_ColumnDWithin(sa.sql.expression.FunctionElement[Any]):
 
 138     """ Check if the geometry is within the distance of the
 
 139         given table column, using the spatial index for the column.
 
 141         The index must exist or the query may return nothing.
 
 143     name = 'Geometry_ColumnDWithin'
 
 147 @compiles(Geometry_ColumnDWithin) # type: ignore[no-untyped-call, misc]
 
 148 def default_dwithin_column(element: SaColumn,
 
 149                            compiler: 'sa.Compiled', **kw: Any) -> str:
 
 150     return "ST_DWithin(%s)" % compiler.process(element.clauses, **kw)
 
 152 @compiles(Geometry_ColumnDWithin, 'sqlite') # type: ignore[no-untyped-call, misc]
 
 153 def spatialite_dwithin_column(element: SaColumn,
 
 154                               compiler: 'sa.Compiled', **kw: Any) -> str:
 
 155     geom1, geom2, dist = list(element.clauses)
 
 156     return "ST_Distance(%s, %s) < %s and "\
 
 157            "%s.ROWID IN (SELECT ROWID FROM SpatialIndex "\
 
 158                         "WHERE f_table_name = '%s' AND f_geometry_column = '%s' "\
 
 159                         "AND search_frame = ST_Expand(%s, %s))" %(
 
 160               compiler.process(geom1, **kw),
 
 161               compiler.process(geom2, **kw),
 
 162               compiler.process(dist, **kw),
 
 163               geom1.table.name, geom1.table.name, geom1.name,
 
 164               compiler.process(geom2, **kw),
 
 165               compiler.process(dist, **kw))
 
 169 class Geometry(types.UserDefinedType): # type: ignore[type-arg]
 
 170     """ Simplified type decorator for PostGIS geometry. This type
 
 171         only supports geometries in 4326 projection.
 
    def __init__(self, subtype: str = 'Geometry') -> None:
        """ Create the type. `subtype` is the geometry subtype that is
            written into the column specification.
        """
        self.subtype = subtype
 
 179     def get_col_spec(self) -> str:
 
 180         return f'GEOMETRY({self.subtype}, 4326)'
 
 183     def bind_processor(self, dialect: 'sa.Dialect') -> Callable[[Any], str]:
 
 184         def process(value: Any) -> str:
 
 185             if isinstance(value, str):
 
 188             return cast(str, value.to_wkt())
 
 192     def result_processor(self, dialect: 'sa.Dialect', coltype: object) -> Callable[[Any], str]:
 
 193         def process(value: Any) -> str:
 
 194             assert isinstance(value, str)
 
 199     def column_expression(self, col: SaColumn) -> SaColumn:
 
 200         return sa.func.ST_AsEWKB(col)
 
 203     def bind_expression(self, bindvalue: SaBind) -> SaColumn:
 
 204         return sa.func.ST_GeomFromText(bindvalue, sa.text('4326'), type_=self)
 
 207     class comparator_factory(types.UserDefinedType.Comparator): # type: ignore[type-arg]
 
 209         def intersects(self, other: SaColumn) -> 'sa.Operators':
 
 210             if isinstance(self.expr, sa.Column):
 
 211                 return Geometry_ColumnIntersectsBbox(self.expr, other)
 
 213             return Geometry_IntersectsBbox(self.expr, other)
 
 216         def is_line_like(self) -> SaColumn:
 
 217             return Geometry_IsLineLike(self)
 
 220         def is_area(self) -> SaColumn:
 
 221             return Geometry_IsAreaLike(self)
 
 224         def ST_DWithin(self, other: SaColumn, distance: SaColumn) -> SaColumn:
 
 225             if isinstance(self.expr, sa.Column):
 
 226                 return Geometry_ColumnDWithin(self.expr, other, distance)
 
 228             return sa.func.ST_DWithin(self.expr, other, distance)
 
 231         def ST_DWithin_no_index(self, other: SaColumn, distance: SaColumn) -> SaColumn:
 
 232             return sa.func.ST_DWithin(sa.func.coalesce(sa.null(), self),
 
 236         def ST_Intersects_no_index(self, other: SaColumn) -> 'sa.Operators':
 
 237             return Geometry_IntersectsBbox(sa.func.coalesce(sa.null(), self), other)
 
 240         def ST_Distance(self, other: SaColumn) -> SaColumn:
 
 241             return sa.func.ST_Distance(self, other, type_=sa.Float)
 
 244         def ST_Contains(self, other: SaColumn) -> SaColumn:
 
 245             return sa.func.ST_Contains(self, other, type_=sa.Boolean)
 
 248         def ST_CoveredBy(self, other: SaColumn) -> SaColumn:
 
 249             return sa.func.ST_CoveredBy(self, other, type_=sa.Boolean)
 
 252         def ST_ClosestPoint(self, other: SaColumn) -> SaColumn:
 
 253             return sa.func.coalesce(sa.func.ST_ClosestPoint(self, other, type_=Geometry),
 
 257         def ST_Buffer(self, other: SaColumn) -> SaColumn:
 
 258             return sa.func.ST_Buffer(self, other, type_=Geometry)
 
 261         def ST_Expand(self, other: SaColumn) -> SaColumn:
 
 262             return sa.func.ST_Expand(self, other, type_=Geometry)
 
 265         def ST_Collect(self) -> SaColumn:
 
 266             return sa.func.ST_Collect(self, type_=Geometry)
 
 269         def ST_Centroid(self) -> SaColumn:
 
 270             return sa.func.ST_Centroid(self, type_=Geometry)
 
 273         def ST_LineInterpolatePoint(self, other: SaColumn) -> SaColumn:
 
 274             return sa.func.ST_LineInterpolatePoint(self, other, type_=Geometry)
 
 277         def ST_LineLocatePoint(self, other: SaColumn) -> SaColumn:
 
 278             return sa.func.ST_LineLocatePoint(self, other, type_=sa.Float)
 
 281         def distance_spheroid(self, other: SaColumn) -> SaColumn:
 
 282             return Geometry_DistanceSpheroid(self, other)
 
 285 @compiles(Geometry, 'sqlite') # type: ignore[no-untyped-call]
 
 286 def get_col_spec(self, *args, **kwargs): # type: ignore[no-untyped-def]
 
 290 SQLITE_FUNCTION_ALIAS = (
 
 291     ('ST_AsEWKB', sa.Text, 'AsEWKB'),
 
 292     ('ST_GeomFromEWKT', Geometry, 'GeomFromEWKT'),
 
 293     ('ST_AsGeoJSON', sa.Text, 'AsGeoJSON'),
 
 294     ('ST_AsKML', sa.Text, 'AsKML'),
 
 295     ('ST_AsSVG', sa.Text, 'AsSVG'),
 
 296     ('ST_LineLocatePoint', sa.Float, 'ST_Line_Locate_Point'),
 
 297     ('ST_LineInterpolatePoint', sa.Float, 'ST_Line_Interpolate_Point'),
 
 300 def _add_function_alias(func: str, ftype: type, alias: str) -> None:
 
 301     _FuncDef = type(func, (sa.sql.functions.GenericFunction, ), {
 
 305         "inherit_cache": True})
 
 307     func_templ = f"{alias}(%s)"
 
 309     def _sqlite_impl(element: Any, compiler: Any, **kw: Any) -> Any:
 
 310         return func_templ % compiler.process(element.clauses, **kw)
 
 312     compiles(_FuncDef, 'sqlite')(_sqlite_impl) # type: ignore[no-untyped-call]
 
# Register every entry of SQLITE_FUNCTION_ALIAS. Each entry is unpacked
# into the (func, ftype, alias) arguments of _add_function_alias().
for alias in SQLITE_FUNCTION_ALIAS:
    _add_function_alias(*alias)
 
 318 class ST_DWithin(sa.sql.functions.GenericFunction[Any]):
 
 323 @compiles(ST_DWithin, 'sqlite') # type: ignore[no-untyped-call, misc]
 
 324 def default_json_array_each(element: SaColumn, compiler: 'sa.Compiled', **kw: Any) -> str:
 
 325     geom1, geom2, dist = list(element.clauses)
 
 326     return "(MbrIntersects(%s, ST_Expand(%s, %s)) = 1 AND ST_Distance(%s, %s) <= %s)" % (
 
 327         compiler.process(geom1, **kw), compiler.process(geom2, **kw),
 
 328         compiler.process(dist, **kw),
 
 329         compiler.process(geom1, **kw), compiler.process(geom2, **kw),
 
 330         compiler.process(dist, **kw))