git.openstreetmap.org Git - nominatim.git/commitdiff
factor out SQL for filtering by location
author Sarah Hoffmann <lonvia@denofr.de>
Wed, 6 Dec 2023 09:55:21 +0000 (10:55 +0100)
committer Sarah Hoffmann <lonvia@denofr.de>
Thu, 7 Dec 2023 08:31:00 +0000 (09:31 +0100)
Also improves the decision of whether an index is used or not.

nominatim/api/search/db_searches.py

index 35c12746597b1726a0d6531fab45a357b65e5c4c..a5461d6ad105fd4845955ffdeb6947dddcc12845 100644 (file)
@@ -55,12 +55,29 @@ NEAR_PARAM: SaBind = sa.bindparam('near', type_=Geometry)
 NEAR_RADIUS_PARAM: SaBind = sa.bindparam('near_radius')
 COUNTRIES_PARAM: SaBind = sa.bindparam('countries')
 
-def _within_near(t: SaFromClause) -> Callable[[], SaExpression]:
-    return lambda: t.c.geometry.within_distance(NEAR_PARAM, NEAR_RADIUS_PARAM)
+
+def filter_by_area(sql: SaSelect, t: SaFromClause,
+                   details: SearchDetails, avoid_index: bool = False) -> SaSelect:
+    """ Apply SQL statements for filtering by viewbox and near point,
+        if applicable.
+    """
+    if details.near is not None and details.near_radius is not None:
+        if details.near_radius < 0.1 and not avoid_index:
+            sql = sql.where(t.c.geometry.within_distance(NEAR_PARAM, NEAR_RADIUS_PARAM))
+        else:
+            sql = sql.where(t.c.geometry.ST_Distance(NEAR_PARAM) <= NEAR_RADIUS_PARAM)
+    if details.viewbox is not None and details.bounded_viewbox:
+        sql = sql.where(t.c.geometry.intersects(VIEWBOX_PARAM,
+                                                use_index=not avoid_index and
+                                                          details.viewbox.area < 0.2))
+
+    return sql
+
 
 def _exclude_places(t: SaFromClause) -> Callable[[], SaExpression]:
     return lambda: t.c.place_id.not_in(sa.bindparam('excluded'))
 
+
 def _select_placex(t: SaFromClause) -> SaSelect:
     return sa.select(t.c.place_id, t.c.osm_type, t.c.osm_id, t.c.name,
                      t.c.class_, t.c.type,
@@ -449,11 +466,7 @@ class CountrySearch(AbstractSearch):
         if details.excluded:
             sql = sql.where(_exclude_places(t))
 
-        if details.viewbox is not None and details.bounded_viewbox:
-            sql = sql.where(lambda: t.c.geometry.intersects(VIEWBOX_PARAM))
-
-        if details.near is not None and details.near_radius is not None:
-            sql = sql.where(_within_near(t))
+        sql = filter_by_area(sql, t, details)
 
         results = nres.SearchResults()
         for row in await conn.execute(sql, _details_to_bind_params(details)):
@@ -486,10 +499,7 @@ class CountrySearch(AbstractSearch):
                 .where(tgrid.c.country_code.in_(self.countries.values))\
                 .group_by(tgrid.c.country_code)
 
-        if details.viewbox is not None and details.bounded_viewbox:
-            sql = sql.where(tgrid.c.geometry.intersects(VIEWBOX_PARAM))
-        if details.near is not None and details.near_radius is not None:
-            sql = sql.where(_within_near(tgrid))
+        sql = filter_by_area(sql, tgrid, details, avoid_index=True)
 
         sub = sql.subquery('grid')
 
@@ -542,19 +552,16 @@ class PostcodeSearch(AbstractSearch):
 
         penalty: SaExpression = sa.literal(self.penalty)
 
-        if details.viewbox is not None:
-            if details.bounded_viewbox:
-                sql = sql.where(t.c.geometry.intersects(VIEWBOX_PARAM))
-            else:
-                penalty += sa.case((t.c.geometry.intersects(VIEWBOX_PARAM), 0.0),
-                                   (t.c.geometry.intersects(VIEWBOX2_PARAM), 0.5),
-                                   else_=1.0)
+        if details.viewbox is not None and not details.bounded_viewbox:
+            penalty += sa.case((t.c.geometry.intersects(VIEWBOX_PARAM), 0.0),
+                               (t.c.geometry.intersects(VIEWBOX2_PARAM), 0.5),
+                               else_=1.0)
 
         if details.near is not None:
-            if details.near_radius is not None:
-                sql = sql.where(_within_near(t))
             sql = sql.order_by(t.c.geometry.ST_Distance(NEAR_PARAM))
 
+        sql = filter_by_area(sql, t, details)
+
         if self.countries:
             sql = sql.where(t.c.country_code.in_(self.countries.values))