]> git.openstreetmap.org Git - nominatim.git/commitdiff
enable all API tests for sqlite and port missing features
authorSarah Hoffmann <lonvia@denofr.de>
Wed, 6 Dec 2023 19:56:21 +0000 (20:56 +0100)
committerSarah Hoffmann <lonvia@denofr.de>
Thu, 7 Dec 2023 08:32:02 +0000 (09:32 +0100)
15 files changed:
nominatim/api/core.py
nominatim/api/search/db_search_lookups.py
nominatim/api/search/db_searches.py
nominatim/db/sqlalchemy_functions.py
nominatim/db/sqlalchemy_types/int_array.py
nominatim/db/sqlalchemy_types/key_value.py
nominatim/db/sqlite_functions.py [new file with mode: 0644]
nominatim/tools/convert_sqlite.py
test/python/api/conftest.py
test/python/api/search/test_search_country.py
test/python/api/search/test_search_near.py
test/python/api/search/test_search_places.py
test/python/api/search/test_search_poi.py
test/python/api/search/test_search_postcode.py
test/python/api/test_api_search.py

index b2624227586160c72924e80e25dc887f4150f8aa..f975f44aaec39b3e358329f67c0c0da70f702017 100644 (file)
@@ -19,6 +19,7 @@ import sqlalchemy.ext.asyncio as sa_asyncio
 from nominatim.errors import UsageError
 from nominatim.db.sqlalchemy_schema import SearchTables
 from nominatim.db.async_core_library import PGCORE_LIB, PGCORE_ERROR
 from nominatim.errors import UsageError
 from nominatim.db.sqlalchemy_schema import SearchTables
 from nominatim.db.async_core_library import PGCORE_LIB, PGCORE_ERROR
+import nominatim.db.sqlite_functions
 from nominatim.config import Configuration
 from nominatim.api.connection import SearchConnection
 from nominatim.api.status import get_status, StatusResult
 from nominatim.config import Configuration
 from nominatim.api.connection import SearchConnection
 from nominatim.api.status import get_status, StatusResult
@@ -122,6 +123,7 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
                 @sa.event.listens_for(engine.sync_engine, "connect")
                 def _on_sqlite_connect(dbapi_con: Any, _: Any) -> None:
                     dbapi_con.run_async(lambda conn: conn.enable_load_extension(True))
                 @sa.event.listens_for(engine.sync_engine, "connect")
                 def _on_sqlite_connect(dbapi_con: Any, _: Any) -> None:
                     dbapi_con.run_async(lambda conn: conn.enable_load_extension(True))
+                    nominatim.db.sqlite_functions.install_custom_functions(dbapi_con)
                     cursor = dbapi_con.cursor()
                     cursor.execute("SELECT load_extension('mod_spatialite')")
                     cursor.execute('SELECT SetDecimalPrecision(7)')
                     cursor = dbapi_con.cursor()
                     cursor.execute("SELECT load_extension('mod_spatialite')")
                     cursor.execute('SELECT SetDecimalPrecision(7)')
index 3e307235b850b954a24a5221273fa5c930785053..aa5cef5f47e491d68fa6b69961f303fcb3b8dcb0 100644 (file)
@@ -26,18 +26,38 @@ class LookupAll(LookupType):
     inherit_cache = True
 
     def __init__(self, table: SaFromClause, column: str, tokens: List[int]) -> None:
     inherit_cache = True
 
     def __init__(self, table: SaFromClause, column: str, tokens: List[int]) -> None:
-        super().__init__(getattr(table.c, column),
+        super().__init__(table.c.place_id, getattr(table.c, column), column,
                          sa.type_coerce(tokens, IntArray))
 
 
 @compiles(LookupAll) # type: ignore[no-untyped-call, misc]
 def _default_lookup_all(element: LookupAll,
                         compiler: 'sa.Compiled', **kw: Any) -> str:
                          sa.type_coerce(tokens, IntArray))
 
 
 @compiles(LookupAll) # type: ignore[no-untyped-call, misc]
 def _default_lookup_all(element: LookupAll,
                         compiler: 'sa.Compiled', **kw: Any) -> str:
-    col, tokens = list(element.clauses)
+    _, col, _, tokens = list(element.clauses)
     return "(%s @> %s)" % (compiler.process(col, **kw),
                            compiler.process(tokens, **kw))
 
 
     return "(%s @> %s)" % (compiler.process(col, **kw),
                            compiler.process(tokens, **kw))
 
 
+@compiles(LookupAll, 'sqlite') # type: ignore[no-untyped-call, misc]
+def _sqlite_lookup_all(element: LookupAll,
+                        compiler: 'sa.Compiled', **kw: Any) -> str:
+    place, col, colname, tokens = list(element.clauses)
+    return "(%s IN (SELECT CAST(value as bigint) FROM"\
+           " (SELECT array_intersect_fuzzy(places) as p FROM"\
+           "   (SELECT places FROM reverse_search_name"\
+           "   WHERE word IN (SELECT value FROM json_each('[' || %s || ']'))"\
+           "     AND column = %s"\
+           "   ORDER BY length(places)) as x) as u,"\
+           " json_each('[' || u.p || ']'))"\
+           " AND array_contains(%s, %s))"\
+             % (compiler.process(place, **kw),
+                compiler.process(tokens, **kw),
+                compiler.process(colname, **kw),
+                compiler.process(col, **kw),
+                compiler.process(tokens, **kw)
+                )
+
+
 
 class LookupAny(LookupType):
     """ Find all entries that contain at least one of the given tokens.
 
 class LookupAny(LookupType):
     """ Find all entries that contain at least one of the given tokens.
@@ -46,17 +66,28 @@ class LookupAny(LookupType):
     inherit_cache = True
 
     def __init__(self, table: SaFromClause, column: str, tokens: List[int]) -> None:
     inherit_cache = True
 
     def __init__(self, table: SaFromClause, column: str, tokens: List[int]) -> None:
-        super().__init__(getattr(table.c, column),
+        super().__init__(table.c.place_id, getattr(table.c, column), column,
                          sa.type_coerce(tokens, IntArray))
 
                          sa.type_coerce(tokens, IntArray))
 
-
 @compiles(LookupAny) # type: ignore[no-untyped-call, misc]
 def _default_lookup_any(element: LookupAny,
                         compiler: 'sa.Compiled', **kw: Any) -> str:
 @compiles(LookupAny) # type: ignore[no-untyped-call, misc]
 def _default_lookup_any(element: LookupAny,
                         compiler: 'sa.Compiled', **kw: Any) -> str:
-    col, tokens = list(element.clauses)
+    _, col, _, tokens = list(element.clauses)
     return "(%s && %s)" % (compiler.process(col, **kw),
                            compiler.process(tokens, **kw))
 
     return "(%s && %s)" % (compiler.process(col, **kw),
                            compiler.process(tokens, **kw))
 
+@compiles(LookupAny, 'sqlite') # type: ignore[no-untyped-call, misc]
+def _sqlite_lookup_any(element: LookupAny,
+                        compiler: 'sa.Compiled', **kw: Any) -> str:
+    place, _, colname, tokens = list(element.clauses)
+    return "%s IN (SELECT CAST(value as bigint) FROM"\
+           " (SELECT array_union(places) as p FROM reverse_search_name"\
+           "   WHERE word IN (SELECT value FROM json_each('[' || %s || ']'))"\
+           "     AND column = %s) as u,"\
+           " json_each('[' || u.p || ']'))" % (compiler.process(place, **kw),
+                                               compiler.process(tokens, **kw),
+                                               compiler.process(colname, **kw))
+
 
 
 class Restrict(LookupType):
 
 
 class Restrict(LookupType):
@@ -76,3 +107,8 @@ def _default_restrict(element: Restrict,
     arg1, arg2 = list(element.clauses)
     return "(coalesce(null, %s) @> %s)" % (compiler.process(arg1, **kw),
                                            compiler.process(arg2, **kw))
     arg1, arg2 = list(element.clauses)
     return "(coalesce(null, %s) @> %s)" % (compiler.process(arg1, **kw),
                                            compiler.process(arg2, **kw))
+
+@compiles(Restrict, 'sqlite') # type: ignore[no-untyped-call, misc]
+def _sqlite_restrict(element: Restrict,
+                        compiler: 'sa.Compiled', **kw: Any) -> str:
+    return "array_contains(%s)" % compiler.process(element.clauses, **kw)
index c56554fdc0f3899fcad113540971cc5042aad027..ee98100c637fbe3365b99c85951e07bf92c4055f 100644 (file)
@@ -11,7 +11,6 @@ from typing import List, Tuple, AsyncIterator, Dict, Any, Callable
 import abc
 
 import sqlalchemy as sa
 import abc
 
 import sqlalchemy as sa
-from sqlalchemy.dialects.postgresql import array_agg
 
 from nominatim.typing import SaFromClause, SaScalarSelect, SaColumn, \
                              SaExpression, SaSelect, SaLambdaSelect, SaRow, SaBind
 
 from nominatim.typing import SaFromClause, SaScalarSelect, SaColumn, \
                              SaExpression, SaSelect, SaLambdaSelect, SaRow, SaBind
@@ -19,7 +18,7 @@ from nominatim.api.connection import SearchConnection
 from nominatim.api.types import SearchDetails, DataLayer, GeometryFormat, Bbox
 import nominatim.api.results as nres
 from nominatim.api.search.db_search_fields import SearchData, WeightedCategories
 from nominatim.api.types import SearchDetails, DataLayer, GeometryFormat, Bbox
 import nominatim.api.results as nres
 from nominatim.api.search.db_search_fields import SearchData, WeightedCategories
-from nominatim.db.sqlalchemy_types import Geometry
+from nominatim.db.sqlalchemy_types import Geometry, IntArray
 
 #pylint: disable=singleton-comparison,not-callable
 #pylint: disable=too-many-branches,too-many-arguments,too-many-locals,too-many-statements
 
 #pylint: disable=singleton-comparison,not-callable
 #pylint: disable=too-many-branches,too-many-arguments,too-many-locals,too-many-statements
@@ -110,7 +109,7 @@ def _add_geometry_columns(sql: SaLambdaSelect, col: SaColumn, details: SearchDet
 
 def _make_interpolation_subquery(table: SaFromClause, inner: SaFromClause,
                                  numerals: List[int], details: SearchDetails) -> SaScalarSelect:
 
 def _make_interpolation_subquery(table: SaFromClause, inner: SaFromClause,
                                  numerals: List[int], details: SearchDetails) -> SaScalarSelect:
-    all_ids = array_agg(table.c.place_id) # type: ignore[no-untyped-call]
+    all_ids = sa.func.ArrayAgg(table.c.place_id)
     sql = sa.select(all_ids).where(table.c.parent_place_id == inner.c.place_id)
 
     if len(numerals) == 1:
     sql = sa.select(all_ids).where(table.c.parent_place_id == inner.c.place_id)
 
     if len(numerals) == 1:
@@ -134,9 +133,7 @@ def _filter_by_layer(table: SaFromClause, layers: DataLayer) -> SaColumn:
         orexpr.append(no_index(table.c.rank_address).between(1, 30))
     elif layers & DataLayer.ADDRESS:
         orexpr.append(no_index(table.c.rank_address).between(1, 29))
         orexpr.append(no_index(table.c.rank_address).between(1, 30))
     elif layers & DataLayer.ADDRESS:
         orexpr.append(no_index(table.c.rank_address).between(1, 29))
-        orexpr.append(sa.and_(no_index(table.c.rank_address) == 30,
-                              sa.or_(table.c.housenumber != None,
-                                     table.c.address.has_key('addr:housename'))))
+        orexpr.append(sa.func.IsAddressPoint(table))
     elif layers & DataLayer.POI:
         orexpr.append(sa.and_(no_index(table.c.rank_address) == 30,
                               table.c.class_.not_in(('place', 'building'))))
     elif layers & DataLayer.POI:
         orexpr.append(sa.and_(no_index(table.c.rank_address) == 30,
                               table.c.class_.not_in(('place', 'building'))))
@@ -188,12 +185,21 @@ async def _get_placex_housenumbers(conn: SearchConnection,
         yield result
 
 
         yield result
 
 
+def _int_list_to_subquery(inp: List[int]) -> 'sa.Subquery':
+    """ Create a subselect that returns the given list of integers
+        as rows in the column 'nr'.
+    """
+    vtab = sa.func.JsonArrayEach(sa.type_coerce(inp, sa.JSON))\
+               .table_valued(sa.column('value', type_=sa.JSON)) # type: ignore[no-untyped-call]
+    return sa.select(sa.cast(sa.cast(vtab.c.value, sa.Text), sa.Integer).label('nr')).subquery()
+
+
 async def _get_osmline(conn: SearchConnection, place_ids: List[int],
                        numerals: List[int],
                        details: SearchDetails) -> AsyncIterator[nres.SearchResult]:
     t = conn.t.osmline
 async def _get_osmline(conn: SearchConnection, place_ids: List[int],
                        numerals: List[int],
                        details: SearchDetails) -> AsyncIterator[nres.SearchResult]:
     t = conn.t.osmline
-    values = sa.values(sa.Column('nr', sa.Integer()), name='housenumber')\
-               .data([(n,) for n in numerals])
+
+    values = _int_list_to_subquery(numerals)
     sql = sa.select(t.c.place_id, t.c.osm_id,
                     t.c.parent_place_id, t.c.address,
                     values.c.nr.label('housenumber'),
     sql = sa.select(t.c.place_id, t.c.osm_id,
                     t.c.parent_place_id, t.c.address,
                     values.c.nr.label('housenumber'),
@@ -216,8 +222,7 @@ async def _get_tiger(conn: SearchConnection, place_ids: List[int],
                      numerals: List[int], osm_id: int,
                      details: SearchDetails) -> AsyncIterator[nres.SearchResult]:
     t = conn.t.tiger
                      numerals: List[int], osm_id: int,
                      details: SearchDetails) -> AsyncIterator[nres.SearchResult]:
     t = conn.t.tiger
-    values = sa.values(sa.Column('nr', sa.Integer()), name='housenumber')\
-               .data([(n,) for n in numerals])
+    values = _int_list_to_subquery(numerals)
     sql = sa.select(t.c.place_id, t.c.parent_place_id,
                     sa.literal('W').label('osm_type'),
                     sa.literal(osm_id).label('osm_id'),
     sql = sa.select(t.c.place_id, t.c.parent_place_id,
                     sa.literal('W').label('osm_type'),
                     sa.literal(osm_id).label('osm_id'),
@@ -573,7 +578,8 @@ class PostcodeSearch(AbstractSearch):
             tsearch = conn.t.search_name
             sql = sql.where(tsearch.c.place_id == t.c.parent_place_id)\
                      .where((tsearch.c.name_vector + tsearch.c.nameaddress_vector)
             tsearch = conn.t.search_name
             sql = sql.where(tsearch.c.place_id == t.c.parent_place_id)\
                      .where((tsearch.c.name_vector + tsearch.c.nameaddress_vector)
-                                     .contains(self.lookups[0].tokens))
+                                     .contains(sa.type_coerce(self.lookups[0].tokens,
+                                                              IntArray)))
 
         for ranking in self.rankings:
             penalty += ranking.sql_penalty(conn.t.search_name)
 
         for ranking in self.rankings:
             penalty += ranking.sql_penalty(conn.t.search_name)
@@ -692,10 +698,10 @@ class PlaceSearch(AbstractSearch):
             sql = sql.order_by(sa.text('accuracy'))
 
         if self.housenumbers:
             sql = sql.order_by(sa.text('accuracy'))
 
         if self.housenumbers:
-            hnr_regexp = f"\\m({'|'.join(self.housenumbers.values)})\\M"
+            hnr_list = '|'.join(self.housenumbers.values)
             sql = sql.where(tsearch.c.address_rank.between(16, 30))\
                      .where(sa.or_(tsearch.c.address_rank < 30,
             sql = sql.where(tsearch.c.address_rank.between(16, 30))\
                      .where(sa.or_(tsearch.c.address_rank < 30,
-                                   t.c.housenumber.op('~*')(hnr_regexp)))
+                                   sa.func.RegexpWord(hnr_list, t.c.housenumber)))
 
             # Cross check for housenumbers, need to do that on a rather large
             # set. Worst case there are 40.000 main streets in OSM.
 
             # Cross check for housenumbers, need to do that on a rather large
             # set. Worst case there are 40.000 main streets in OSM.
@@ -703,10 +709,10 @@ class PlaceSearch(AbstractSearch):
 
             # Housenumbers from placex
             thnr = conn.t.placex.alias('hnr')
 
             # Housenumbers from placex
             thnr = conn.t.placex.alias('hnr')
-            pid_list = array_agg(thnr.c.place_id) # type: ignore[no-untyped-call]
+            pid_list = sa.func.ArrayAgg(thnr.c.place_id)
             place_sql = sa.select(pid_list)\
                           .where(thnr.c.parent_place_id == inner.c.place_id)\
             place_sql = sa.select(pid_list)\
                           .where(thnr.c.parent_place_id == inner.c.place_id)\
-                          .where(thnr.c.housenumber.op('~*')(hnr_regexp))\
+                          .where(sa.func.RegexpWord(hnr_list, thnr.c.housenumber))\
                           .where(thnr.c.linked_place_id == None)\
                           .where(thnr.c.indexed_status == 0)
 
                           .where(thnr.c.linked_place_id == None)\
                           .where(thnr.c.indexed_status == 0)
 
index 5872401cca449d47790cb25e9641699122767e7b..e2437dd2e34c4ad4b5080558f8b4dee28ceb4cb1 100644 (file)
@@ -188,6 +188,7 @@ def sqlite_json_array_each(element: JsonArrayEach, compiler: 'sa.Compiled', **kw
     return "json_each(%s)" % compiler.process(element.clauses, **kw)
 
 
     return "json_each(%s)" % compiler.process(element.clauses, **kw)
 
 
+
 class Greatest(sa.sql.functions.GenericFunction[Any]):
     """ Function to compute maximum of all its input parameters.
     """
 class Greatest(sa.sql.functions.GenericFunction[Any]):
     """ Function to compute maximum of all its input parameters.
     """
@@ -198,3 +199,23 @@ class Greatest(sa.sql.functions.GenericFunction[Any]):
 @compiles(Greatest, 'sqlite') # type: ignore[no-untyped-call, misc]
 def sqlite_greatest(element: Greatest, compiler: 'sa.Compiled', **kw: Any) -> str:
     return "max(%s)" % compiler.process(element.clauses, **kw)
 @compiles(Greatest, 'sqlite') # type: ignore[no-untyped-call, misc]
 def sqlite_greatest(element: Greatest, compiler: 'sa.Compiled', **kw: Any) -> str:
     return "max(%s)" % compiler.process(element.clauses, **kw)
+
+
+
+class RegexpWord(sa.sql.functions.GenericFunction[Any]):
+    """ Check if a full word is in a given string.
+    """
+    name = 'RegexpWord'
+    inherit_cache = True
+
+
+@compiles(RegexpWord, 'postgresql') # type: ignore[no-untyped-call, misc]
+def postgres_regexp_nocase(element: RegexpWord, compiler: 'sa.Compiled', **kw: Any) -> str:
+    arg1, arg2 = list(element.clauses)
+    return "%s ~* ('\\m(' || %s  || ')\\M')::text" % (compiler.process(arg2, **kw), compiler.process(arg1, **kw))
+
+
+@compiles(RegexpWord, 'sqlite') # type: ignore[no-untyped-call, misc]
+def sqlite_regexp_nocase(element: RegexpWord, compiler: 'sa.Compiled', **kw: Any) -> str:
+    arg1, arg2 = list(element.clauses)
+    return "regexp('\\b(' || %s  || ')\\b', %s)" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
index 499376cb85ca59d44119f2bcb4b4e17eeedd2f3f..a31793f3f523a686d8698854487e6d779f331ab8 100644 (file)
@@ -57,22 +57,16 @@ class IntArray(sa.types.TypeDecorator[Any]):
             """ Concate the array with the given array. If one of the
                 operants is null, the value of the other will be returned.
             """
             """ Concate the array with the given array. If one of the
                 operants is null, the value of the other will be returned.
             """
-            return sa.func.array_cat(self, other, type_=IntArray)
+            return ArrayCat(self.expr, other)
 
 
         def contains(self, other: SaColumn, **kwargs: Any) -> 'sa.ColumnOperators':
             """ Return true if the array contains all the value of the argument
                 array.
             """
 
 
         def contains(self, other: SaColumn, **kwargs: Any) -> 'sa.ColumnOperators':
             """ Return true if the array contains all the value of the argument
                 array.
             """
-            return cast('sa.ColumnOperators', self.op('@>', is_comparison=True)(other))
+            return ArrayContains(self.expr, other)
 
 
 
 
-        def overlaps(self, other: SaColumn) -> 'sa.Operators':
-            """ Return true if at least one value of the argument is contained
-                in the array.
-            """
-            return self.op('&&', is_comparison=True)(other)
-
 
 class ArrayAgg(sa.sql.functions.GenericFunction[Any]):
     """ Aggregate function to collect elements in an array.
 
 class ArrayAgg(sa.sql.functions.GenericFunction[Any]):
     """ Aggregate function to collect elements in an array.
@@ -82,6 +76,48 @@ class ArrayAgg(sa.sql.functions.GenericFunction[Any]):
     name = 'array_agg'
     inherit_cache = True
 
     name = 'array_agg'
     inherit_cache = True
 
+
 @compiles(ArrayAgg, 'sqlite') # type: ignore[no-untyped-call, misc]
 def sqlite_array_agg(element: ArrayAgg, compiler: 'sa.Compiled', **kw: Any) -> str:
     return "group_concat(%s, ',')" % compiler.process(element.clauses, **kw)
 @compiles(ArrayAgg, 'sqlite') # type: ignore[no-untyped-call, misc]
 def sqlite_array_agg(element: ArrayAgg, compiler: 'sa.Compiled', **kw: Any) -> str:
     return "group_concat(%s, ',')" % compiler.process(element.clauses, **kw)
+
+
+
+class ArrayContains(sa.sql.expression.FunctionElement[Any]):
+    """ Function to check if an array is fully contained in another.
+    """
+    name = 'ArrayContains'
+    inherit_cache = True
+
+
+@compiles(ArrayContains) # type: ignore[no-untyped-call, misc]
+def generic_array_contains(element: ArrayContains, compiler: 'sa.Compiled', **kw: Any) -> str:
+    arg1, arg2 = list(element.clauses)
+    return "(%s @> %s)" % (compiler.process(arg1, **kw),
+                           compiler.process(arg2, **kw))
+
+
+@compiles(ArrayContains, 'sqlite') # type: ignore[no-untyped-call, misc]
+def sqlite_array_contains(element: ArrayContains, compiler: 'sa.Compiled', **kw: Any) -> str:
+    return "array_contains(%s)" % compiler.process(element.clauses, **kw)
+
+
+
+class ArrayCat(sa.sql.expression.FunctionElement[Any]):
+    """ Function to check if an array is fully contained in another.
+    """
+    type = IntArray()
+    identifier = 'ArrayCat'
+    inherit_cache = True
+
+
+@compiles(ArrayCat) # type: ignore[no-untyped-call, misc]
+def generic_array_cat(element: ArrayCat, compiler: 'sa.Compiled', **kw: Any) -> str:
+    return "array_cat(%s)" % compiler.process(element.clauses, **kw)
+
+
+@compiles(ArrayCat, 'sqlite') # type: ignore[no-untyped-call, misc]
+def sqlite_array_cat(element: ArrayCat, compiler: 'sa.Compiled', **kw: Any) -> str:
+    arg1, arg2 = list(element.clauses)
+    return "(%s || ',' || %s)" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
+
index 4f2d824aff8ed49ec72178df07fc4ea1116fe791..937caa021b1058b9f3757db587dc589e0a5ca69c 100644 (file)
@@ -10,6 +10,7 @@ A custom type that implements a simple key-value store of strings.
 from typing import Any
 
 import sqlalchemy as sa
 from typing import Any
 
 import sqlalchemy as sa
+from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.dialects.postgresql import HSTORE
 from sqlalchemy.dialects.sqlite import JSON as sqlite_json
 
 from sqlalchemy.dialects.postgresql import HSTORE
 from sqlalchemy.dialects.sqlite import JSON as sqlite_json
 
@@ -37,11 +38,25 @@ class KeyValueStore(sa.types.TypeDecorator[Any]):
                 one, overwriting values where necessary. When the argument
                 is null, nothing happens.
             """
                 one, overwriting values where necessary. When the argument
                 is null, nothing happens.
             """
-            return self.op('||')(sa.func.coalesce(other,
-                                                  sa.type_coerce('', KeyValueStore)))
+            return KeyValueConcat(self.expr, other)
+
+
+class KeyValueConcat(sa.sql.expression.FunctionElement[Any]):
+    """ Return the merged key-value store from the input parameters.
+    """
+    type = KeyValueStore()
+    name = 'JsonConcat'
+    inherit_cache = True
+
+@compiles(KeyValueConcat) # type: ignore[no-untyped-call, misc]
+def default_json_concat(element: KeyValueConcat, compiler: 'sa.Compiled', **kw: Any) -> str:
+    arg1, arg2 = list(element.clauses)
+    return "(%s || coalesce(%s, ''::hstore))" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
+
+@compiles(KeyValueConcat, 'sqlite') # type: ignore[no-untyped-call, misc]
+def sqlite_json_concat(element: KeyValueConcat, compiler: 'sa.Compiled', **kw: Any) -> str:
+    arg1, arg2 = list(element.clauses)
+    return "json_patch(%s, coalesce(%s, '{}'))" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
+
 
 
 
 
-        def has_key(self, key: SaColumn) -> 'sa.Operators':
-            """ Return true if the key is cotained in the store.
-            """
-            return self.op('?', is_comparison=True)(key)
diff --git a/nominatim/db/sqlite_functions.py b/nominatim/db/sqlite_functions.py
new file mode 100644 (file)
index 0000000..2134ae4
--- /dev/null
@@ -0,0 +1,122 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Custom functions for SQLite.
+"""
+from typing import cast, Optional, Set, Any
+import json
+
+# pylint: disable=protected-access
+
+def weigh_search(search_vector: Optional[str], rankings: str, default: float) -> float:
+    """ Custom weight function for search results.
+    """
+    if search_vector is not None:
+        svec = [int(x) for x in search_vector.split(',')]
+        for rank in json.loads(rankings):
+            if all(r in svec for r in rank[1]):
+                return cast(float, rank[0])
+
+    return default
+
+
+class ArrayIntersectFuzzy:
+    """ Compute the array of common elements of all input integer arrays.
+        Very large input paramenters may be ignored to speed up
+        computation. Therefore, the result is a superset of common elements.
+
+        Input and output arrays are given as comma-separated lists.
+    """
+    def __init__(self) -> None:
+        self.first = ''
+        self.values: Optional[Set[int]] = None
+
+    def step(self, value: Optional[str]) -> None:
+        """ Add the next array to the intersection.
+        """
+        if value is not None:
+            if not self.first:
+                self.first = value
+            elif len(value) < 10000000:
+                if self.values is None:
+                    self.values = {int(x) for x in self.first.split(',')}
+                self.values.intersection_update((int(x) for x in value.split(',')))
+
+    def finalize(self) -> str:
+        """ Return the final result.
+        """
+        if self.values is not None:
+            return ','.join(map(str, self.values))
+
+        return self.first
+
+
+class ArrayUnion:
+    """ Compute the set of all elements of the input integer arrays.
+
+        Input and output arrays are given as strings of comma-separated lists.
+    """
+    def __init__(self) -> None:
+        self.values: Optional[Set[str]] = None
+
+    def step(self, value: Optional[str]) -> None:
+        """ Add the next array to the union.
+        """
+        if value is not None:
+            if self.values is None:
+                self.values = set(value.split(','))
+            else:
+                self.values.update(value.split(','))
+
+    def finalize(self) -> str:
+        """ Return the final result.
+        """
+        return '' if self.values is None else ','.join(self.values)
+
+
+def array_contains(container: Optional[str], containee: Optional[str]) -> Optional[bool]:
+    """ Is the array 'containee' completely contained in array 'container'.
+    """
+    if container is None or containee is None:
+        return None
+
+    vset = container.split(',')
+    return all(v in vset for v in containee.split(','))
+
+
+def array_pair_contains(container1: Optional[str], container2: Optional[str],
+                        containee: Optional[str]) -> Optional[bool]:
+    """ Is the array 'containee' completely contained in the union of
+        array 'container1' and array 'container2'.
+    """
+    if container1 is None or container2 is None or containee is None:
+        return None
+
+    vset = container1.split(',') + container2.split(',')
+    return all(v in vset for v in containee.split(','))
+
+
+def install_custom_functions(conn: Any) -> None:
+    """ Install helper functions for Nominatim into the given SQLite
+        database connection.
+    """
+    conn.create_function('weigh_search', 3, weigh_search, deterministic=True)
+    conn.create_function('array_contains', 2, array_contains, deterministic=True)
+    conn.create_function('array_pair_contains', 3, array_pair_contains, deterministic=True)
+    _create_aggregate(conn, 'array_intersect_fuzzy', 1, ArrayIntersectFuzzy)
+    _create_aggregate(conn, 'array_union', 1, ArrayUnion)
+
+
+async def _make_aggregate(aioconn: Any, *args: Any) -> None:
+    await aioconn._execute(aioconn._conn.create_aggregate, *args)
+
+
+def _create_aggregate(conn: Any, name: str, nargs: int, aggregate: Any) -> None:
+    try:
+        conn.await_(_make_aggregate(conn._connection, name, nargs, aggregate))
+    except Exception as error: # pylint: disable=broad-exception-caught
+        conn._handle_exception(error)
index d9e39ba37402b7a9dcc7455fa1665ce978b82eb9..16139c5fbcf6d41a55dd97f06db1fd8912da4b2a 100644 (file)
@@ -205,15 +205,15 @@ class SqliteWriter:
     async def create_search_index(self) -> None:
         """ Create the tables and indexes needed for word lookup.
         """
     async def create_search_index(self) -> None:
         """ Create the tables and indexes needed for word lookup.
         """
+        LOG.warning("Creating reverse search table")
+        rsn = sa.Table('reverse_search_name', self.dest.t.meta,
+                       sa.Column('word', sa.Integer()),
+                       sa.Column('column', sa.Text()),
+                       sa.Column('places', IntArray))
+        await self.dest.connection.run_sync(rsn.create)
+
         tsrc = self.src.t.search_name
         for column in ('name_vector', 'nameaddress_vector'):
         tsrc = self.src.t.search_name
         for column in ('name_vector', 'nameaddress_vector'):
-            table_name = f'reverse_search_{column}'
-            LOG.warning("Creating reverse search %s", table_name)
-            rsn = sa.Table(table_name, self.dest.t.meta,
-                           sa.Column('word', sa.Integer()),
-                           sa.Column('places', IntArray))
-            await self.dest.connection.run_sync(rsn.create)
-
             sql = sa.select(sa.func.unnest(getattr(tsrc.c, column)).label('word'),
                             sa.func.ArrayAgg(tsrc.c.place_id).label('places'))\
                     .group_by('word')
             sql = sa.select(sa.func.unnest(getattr(tsrc.c, column)).label('word'),
                             sa.func.ArrayAgg(tsrc.c.place_id).label('places'))\
                     .group_by('word')
@@ -224,11 +224,12 @@ class SqliteWriter:
                 for row in partition:
                     row.places.sort()
                     data.append({'word': row.word,
                 for row in partition:
                     row.places.sort()
                     data.append({'word': row.word,
+                                 'column': column,
                                  'places': row.places})
                 await self.dest.execute(rsn.insert(), data)
 
                                  'places': row.places})
                 await self.dest.execute(rsn.insert(), data)
 
-            await self.dest.connection.run_sync(
-                sa.Index(f'idx_reverse_search_{column}_word', rsn.c.word).create)
+        await self.dest.connection.run_sync(
+            sa.Index('idx_reverse_search_name_word', rsn.c.word).create)
 
 
     def select_from(self, table: str) -> SaSelect:
 
 
     def select_from(self, table: str) -> SaSelect:
index 91a3107fbcc3cc6692d5a3a515c4cf908f951dff..05eaddf5fc0f182cfc501504e48b8865fdb9af95 100644 (file)
@@ -16,6 +16,7 @@ import sqlalchemy as sa
 
 import nominatim.api as napi
 from nominatim.db.sql_preprocessor import SQLPreprocessor
 
 import nominatim.api as napi
 from nominatim.db.sql_preprocessor import SQLPreprocessor
+from nominatim.api.search.query_analyzer_factory import make_query_analyzer
 from nominatim.tools import convert_sqlite
 import nominatim.api.logging as loglib
 
 from nominatim.tools import convert_sqlite
 import nominatim.api.logging as loglib
 
@@ -160,6 +161,22 @@ class APITester:
                                      """)))
 
 
                                      """)))
 
 
+    def add_word_table(self, content):
+        data = [dict(zip(['word_id', 'word_token', 'type', 'word', 'info'], c))
+                for c in content]
+
+        async def _do_sql():
+            async with self.api._async_api.begin() as conn:
+                if 'word' not in conn.t.meta.tables:
+                    await make_query_analyzer(conn)
+                    word_table = conn.t.meta.tables['word']
+                    await conn.connection.run_sync(word_table.create)
+                if data:
+                    await conn.execute(conn.t.meta.tables['word'].insert(), data)
+
+        self.async_to_sync(_do_sql())
+
+
     async def exec_async(self, sql, *args, **kwargs):
         async with self.api._async_api.begin() as conn:
             return await conn.execute(sql, *args, **kwargs)
     async def exec_async(self, sql, *args, **kwargs):
         async with self.api._async_api.begin() as conn:
             return await conn.execute(sql, *args, **kwargs)
@@ -195,6 +212,22 @@ def frontend(request, event_loop, tmp_path):
         db = str(tmp_path / 'test_nominatim_python_unittest.sqlite')
 
         def mkapi(apiobj, options={'reverse'}):
         db = str(tmp_path / 'test_nominatim_python_unittest.sqlite')
 
         def mkapi(apiobj, options={'reverse'}):
+            apiobj.add_data('properties',
+                        [{'property': 'tokenizer', 'value': 'icu'},
+                         {'property': 'tokenizer_import_normalisation', 'value': ':: lower();'},
+                         {'property': 'tokenizer_import_transliteration', 'value': "'1' > '/1/'; 'ä' > 'ä '"},
+                        ])
+
+            async def _do_sql():
+                async with apiobj.api._async_api.begin() as conn:
+                    if 'word' in conn.t.meta.tables:
+                        return
+                    await make_query_analyzer(conn)
+                    word_table = conn.t.meta.tables['word']
+                    await conn.connection.run_sync(word_table.create)
+
+            apiobj.async_to_sync(_do_sql())
+
             event_loop.run_until_complete(convert_sqlite.convert(Path('/invalid'),
                                                                  db, options))
             outapi = napi.NominatimAPI(Path('/invalid'),
             event_loop.run_until_complete(convert_sqlite.convert(Path('/invalid'),
                                                                  db, options))
             outapi = napi.NominatimAPI(Path('/invalid'),
index 82b1d37fe30ba52c4d1bd90b0415a10099815893..dc87d313a0ff98b1e6b776600bc21889bd0f7c27 100644 (file)
@@ -15,7 +15,7 @@ from nominatim.api.search.db_searches import CountrySearch
 from nominatim.api.search.db_search_fields import WeightedStrings
 
 
 from nominatim.api.search.db_search_fields import WeightedStrings
 
 
-def run_search(apiobj, global_penalty, ccodes,
+def run_search(apiobj, frontend, global_penalty, ccodes,
                country_penalties=None, details=SearchDetails()):
     if country_penalties is None:
         country_penalties = [0.0] * len(ccodes)
                country_penalties=None, details=SearchDetails()):
     if country_penalties is None:
         country_penalties = [0.0] * len(ccodes)
@@ -25,15 +25,16 @@ def run_search(apiobj, global_penalty, ccodes,
         countries = WeightedStrings(ccodes, country_penalties)
 
     search = CountrySearch(MySearchData())
         countries = WeightedStrings(ccodes, country_penalties)
 
     search = CountrySearch(MySearchData())
+    api = frontend(apiobj, options=['search'])
 
     async def run():
 
     async def run():
-        async with apiobj.api._async_api.begin() as conn:
+        async with api._async_api.begin() as conn:
             return await search.lookup(conn, details)
 
             return await search.lookup(conn, details)
 
-    return apiobj.async_to_sync(run())
+    return api._loop.run_until_complete(run())
 
 
 
 
-def test_find_from_placex(apiobj):
+def test_find_from_placex(apiobj, frontend):
     apiobj.add_placex(place_id=55, class_='boundary', type='administrative',
                       rank_search=4, rank_address=4,
                       name={'name': 'Lolaland'},
     apiobj.add_placex(place_id=55, class_='boundary', type='administrative',
                       rank_search=4, rank_address=4,
                       name={'name': 'Lolaland'},
@@ -41,32 +42,32 @@ def test_find_from_placex(apiobj):
                       centroid=(10, 10),
                       geometry='POLYGON((9.5 9.5, 9.5 10.5, 10.5 10.5, 10.5 9.5, 9.5 9.5))')
 
                       centroid=(10, 10),
                       geometry='POLYGON((9.5 9.5, 9.5 10.5, 10.5 10.5, 10.5 9.5, 9.5 9.5))')
 
-    results = run_search(apiobj, 0.5, ['de', 'yw'], [0.0, 0.3])
+    results = run_search(apiobj, frontend, 0.5, ['de', 'yw'], [0.0, 0.3])
 
     assert len(results) == 1
     assert results[0].place_id == 55
     assert results[0].accuracy == 0.8
 
 
     assert len(results) == 1
     assert results[0].place_id == 55
     assert results[0].accuracy == 0.8
 
-def test_find_from_fallback_countries(apiobj):
+def test_find_from_fallback_countries(apiobj, frontend):
     apiobj.add_country('ro', 'POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))')
     apiobj.add_country_name('ro', {'name': 'România'})
 
     apiobj.add_country('ro', 'POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))')
     apiobj.add_country_name('ro', {'name': 'România'})
 
-    results = run_search(apiobj, 0.0, ['ro'])
+    results = run_search(apiobj, frontend, 0.0, ['ro'])
 
     assert len(results) == 1
     assert results[0].names == {'name': 'România'}
 
 
 
     assert len(results) == 1
     assert results[0].names == {'name': 'România'}
 
 
-def test_find_none(apiobj):
-    assert len(run_search(apiobj, 0.0, ['xx'])) == 0
+def test_find_none(apiobj, frontend):
+    assert len(run_search(apiobj, frontend, 0.0, ['xx'])) == 0
 
 
 @pytest.mark.parametrize('coord,numres', [((0.5, 1), 1), ((10, 10), 0)])
 
 
 @pytest.mark.parametrize('coord,numres', [((0.5, 1), 1), ((10, 10), 0)])
-def test_find_near(apiobj, coord, numres):
+def test_find_near(apiobj, frontend, coord, numres):
     apiobj.add_country('ro', 'POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))')
     apiobj.add_country_name('ro', {'name': 'România'})
 
     apiobj.add_country('ro', 'POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))')
     apiobj.add_country_name('ro', {'name': 'România'})
 
-    results = run_search(apiobj, 0.0, ['ro'],
+    results = run_search(apiobj, frontend, 0.0, ['ro'],
                          details=SearchDetails(near=napi.Point(*coord),
                                                near_radius=0.1))
 
                          details=SearchDetails(near=napi.Point(*coord),
                                                near_radius=0.1))
 
@@ -92,8 +93,8 @@ class TestCountryParameters:
                                       napi.GeometryFormat.SVG,
                                       napi.GeometryFormat.TEXT])
     @pytest.mark.parametrize('cc', ['yw', 'ro'])
                                       napi.GeometryFormat.SVG,
                                       napi.GeometryFormat.TEXT])
     @pytest.mark.parametrize('cc', ['yw', 'ro'])
-    def test_return_geometries(self, apiobj, geom, cc):
-        results = run_search(apiobj, 0.5, [cc],
+    def test_return_geometries(self, apiobj, frontend, geom, cc):
+        results = run_search(apiobj, frontend, 0.5, [cc],
                              details=SearchDetails(geometry_output=geom))
 
         assert len(results) == 1
                              details=SearchDetails(geometry_output=geom))
 
         assert len(results) == 1
@@ -101,8 +102,8 @@ class TestCountryParameters:
 
 
     @pytest.mark.parametrize('pid,rids', [(76, [55]), (55, [])])
 
 
     @pytest.mark.parametrize('pid,rids', [(76, [55]), (55, [])])
-    def test_exclude_place_id(self, apiobj, pid, rids):
-        results = run_search(apiobj, 0.5, ['yw', 'ro'],
+    def test_exclude_place_id(self, apiobj, frontend, pid, rids):
+        results = run_search(apiobj, frontend, 0.5, ['yw', 'ro'],
                              details=SearchDetails(excluded=[pid]))
 
         assert [r.place_id for r in results] == rids
                              details=SearchDetails(excluded=[pid]))
 
         assert [r.place_id for r in results] == rids
@@ -110,8 +111,8 @@ class TestCountryParameters:
 
     @pytest.mark.parametrize('viewbox,rids', [((9, 9, 11, 11), [55]),
                                               ((-10, -10, -3, -3), [])])
 
     @pytest.mark.parametrize('viewbox,rids', [((9, 9, 11, 11), [55]),
                                               ((-10, -10, -3, -3), [])])
-    def test_bounded_viewbox_in_placex(self, apiobj, viewbox, rids):
-        results = run_search(apiobj, 0.5, ['yw'],
+    def test_bounded_viewbox_in_placex(self, apiobj, frontend, viewbox, rids):
+        results = run_search(apiobj, frontend, 0.5, ['yw'],
                              details=SearchDetails.from_kwargs({'viewbox': viewbox,
                                                                 'bounded_viewbox': True}))
 
                              details=SearchDetails.from_kwargs({'viewbox': viewbox,
                                                                 'bounded_viewbox': True}))
 
@@ -120,8 +121,8 @@ class TestCountryParameters:
 
     @pytest.mark.parametrize('viewbox,numres', [((0, 0, 1, 1), 1),
                                               ((-10, -10, -3, -3), 0)])
 
     @pytest.mark.parametrize('viewbox,numres', [((0, 0, 1, 1), 1),
                                               ((-10, -10, -3, -3), 0)])
-    def test_bounded_viewbox_in_fallback(self, apiobj, viewbox, numres):
-        results = run_search(apiobj, 0.5, ['ro'],
+    def test_bounded_viewbox_in_fallback(self, apiobj, frontend, viewbox, numres):
+        results = run_search(apiobj, frontend, 0.5, ['ro'],
                              details=SearchDetails.from_kwargs({'viewbox': viewbox,
                                                                 'bounded_viewbox': True}))
 
                              details=SearchDetails.from_kwargs({'viewbox': viewbox,
                                                                 'bounded_viewbox': True}))
 
index c0caa9ae6af336a2be15599ce02dbe489df98e6d..5b60dd51d59c9626906591591d1b326d73c3ddb1 100644 (file)
@@ -17,7 +17,7 @@ from nominatim.api.search.db_search_fields import WeightedStrings, WeightedCateg
 from nominatim.api.search.db_search_lookups import LookupAll
 
 
 from nominatim.api.search.db_search_lookups import LookupAll
 
 
-def run_search(apiobj, global_penalty, cat, cat_penalty=None, ccodes=[],
+def run_search(apiobj, frontend, global_penalty, cat, cat_penalty=None, ccodes=[],
                details=SearchDetails()):
 
     class PlaceSearchData:
                details=SearchDetails()):
 
     class PlaceSearchData:
@@ -39,21 +39,23 @@ def run_search(apiobj, global_penalty, cat, cat_penalty=None, ccodes=[],
 
     near_search = NearSearch(0.1, WeightedCategories(cat, cat_penalty), place_search)
 
 
     near_search = NearSearch(0.1, WeightedCategories(cat, cat_penalty), place_search)
 
+    api = frontend(apiobj, options=['search'])
+
     async def run():
     async def run():
-        async with apiobj.api._async_api.begin() as conn:
+        async with api._async_api.begin() as conn:
             return await near_search.lookup(conn, details)
 
             return await near_search.lookup(conn, details)
 
-    results = apiobj.async_to_sync(run())
+    results = api._loop.run_until_complete(run())
     results.sort(key=lambda r: r.accuracy)
 
     return results
 
 
     results.sort(key=lambda r: r.accuracy)
 
     return results
 
 
-def test_no_results_inner_query(apiobj):
-    assert not run_search(apiobj, 0.4, [('this', 'that')])
+def test_no_results_inner_query(apiobj, frontend):
+    assert not run_search(apiobj, frontend, 0.4, [('this', 'that')])
 
 
 
 
-def test_no_appropriate_results_inner_query(apiobj):
+def test_no_appropriate_results_inner_query(apiobj, frontend):
     apiobj.add_placex(place_id=100, country_code='us',
                       centroid=(5.6, 4.3),
                       geometry='POLYGON((0.0 0.0, 10.0 0.0, 10.0 2.0, 0.0 2.0, 0.0 0.0))')
     apiobj.add_placex(place_id=100, country_code='us',
                       centroid=(5.6, 4.3),
                       geometry='POLYGON((0.0 0.0, 10.0 0.0, 10.0 2.0, 0.0 2.0, 0.0 0.0))')
@@ -62,7 +64,7 @@ def test_no_appropriate_results_inner_query(apiobj):
     apiobj.add_placex(place_id=22, class_='amenity', type='bank',
                       centroid=(5.6001, 4.2994))
 
     apiobj.add_placex(place_id=22, class_='amenity', type='bank',
                       centroid=(5.6001, 4.2994))
 
-    assert not run_search(apiobj, 0.4, [('amenity', 'bank')])
+    assert not run_search(apiobj, frontend, 0.4, [('amenity', 'bank')])
 
 
 class TestNearSearch:
 
 
 class TestNearSearch:
@@ -79,18 +81,18 @@ class TestNearSearch:
                                centroid=(-10.3, 56.9))
 
 
                                centroid=(-10.3, 56.9))
 
 
-    def test_near_in_placex(self, apiobj):
+    def test_near_in_placex(self, apiobj, frontend):
         apiobj.add_placex(place_id=22, class_='amenity', type='bank',
                           centroid=(5.6001, 4.2994))
         apiobj.add_placex(place_id=23, class_='amenity', type='bench',
                           centroid=(5.6001, 4.2994))
 
         apiobj.add_placex(place_id=22, class_='amenity', type='bank',
                           centroid=(5.6001, 4.2994))
         apiobj.add_placex(place_id=23, class_='amenity', type='bench',
                           centroid=(5.6001, 4.2994))
 
-        results = run_search(apiobj, 0.1, [('amenity', 'bank')])
+        results = run_search(apiobj, frontend, 0.1, [('amenity', 'bank')])
 
         assert [r.place_id for r in results] == [22]
 
 
 
         assert [r.place_id for r in results] == [22]
 
 
-    def test_multiple_types_near_in_placex(self, apiobj):
+    def test_multiple_types_near_in_placex(self, apiobj, frontend):
         apiobj.add_placex(place_id=22, class_='amenity', type='bank',
                           importance=0.002,
                           centroid=(5.6001, 4.2994))
         apiobj.add_placex(place_id=22, class_='amenity', type='bank',
                           importance=0.002,
                           centroid=(5.6001, 4.2994))
@@ -98,13 +100,13 @@ class TestNearSearch:
                           importance=0.001,
                           centroid=(5.6001, 4.2994))
 
                           importance=0.001,
                           centroid=(5.6001, 4.2994))
 
-        results = run_search(apiobj, 0.1, [('amenity', 'bank'),
-                                           ('amenity', 'bench')])
+        results = run_search(apiobj, frontend, 0.1, [('amenity', 'bank'),
+                                                     ('amenity', 'bench')])
 
         assert [r.place_id for r in results] == [22, 23]
 
 
 
         assert [r.place_id for r in results] == [22, 23]
 
 
-    def test_near_in_classtype(self, apiobj):
+    def test_near_in_classtype(self, apiobj, frontend):
         apiobj.add_placex(place_id=22, class_='amenity', type='bank',
                           centroid=(5.6, 4.34))
         apiobj.add_placex(place_id=23, class_='amenity', type='bench',
         apiobj.add_placex(place_id=22, class_='amenity', type='bank',
                           centroid=(5.6, 4.34))
         apiobj.add_placex(place_id=23, class_='amenity', type='bench',
@@ -112,13 +114,13 @@ class TestNearSearch:
         apiobj.add_class_type_table('amenity', 'bank')
         apiobj.add_class_type_table('amenity', 'bench')
 
         apiobj.add_class_type_table('amenity', 'bank')
         apiobj.add_class_type_table('amenity', 'bench')
 
-        results = run_search(apiobj, 0.1, [('amenity', 'bank')])
+        results = run_search(apiobj, frontend, 0.1, [('amenity', 'bank')])
 
         assert [r.place_id for r in results] == [22]
 
 
     @pytest.mark.parametrize('cc,rid', [('us', 22), ('mx', 23)])
 
         assert [r.place_id for r in results] == [22]
 
 
     @pytest.mark.parametrize('cc,rid', [('us', 22), ('mx', 23)])
-    def test_restrict_by_country(self, apiobj, cc, rid):
+    def test_restrict_by_country(self, apiobj, frontend, cc, rid):
         apiobj.add_placex(place_id=22, class_='amenity', type='bank',
                           centroid=(5.6001, 4.2994),
                           country_code='us')
         apiobj.add_placex(place_id=22, class_='amenity', type='bank',
                           centroid=(5.6001, 4.2994),
                           country_code='us')
@@ -132,13 +134,13 @@ class TestNearSearch:
                           centroid=(-10.3001, 56.9),
                           country_code='us')
 
                           centroid=(-10.3001, 56.9),
                           country_code='us')
 
-        results = run_search(apiobj, 0.1, [('amenity', 'bank')], ccodes=[cc, 'fr'])
+        results = run_search(apiobj, frontend, 0.1, [('amenity', 'bank')], ccodes=[cc, 'fr'])
 
         assert [r.place_id for r in results] == [rid]
 
 
     @pytest.mark.parametrize('excluded,rid', [(22, 122), (122, 22)])
 
         assert [r.place_id for r in results] == [rid]
 
 
     @pytest.mark.parametrize('excluded,rid', [(22, 122), (122, 22)])
-    def test_exclude_place_by_id(self, apiobj, excluded, rid):
+    def test_exclude_place_by_id(self, apiobj, frontend, excluded, rid):
         apiobj.add_placex(place_id=22, class_='amenity', type='bank',
                           centroid=(5.6001, 4.2994),
                           country_code='us')
         apiobj.add_placex(place_id=22, class_='amenity', type='bank',
                           centroid=(5.6001, 4.2994),
                           country_code='us')
@@ -147,7 +149,7 @@ class TestNearSearch:
                           country_code='us')
 
 
                           country_code='us')
 
 
-        results = run_search(apiobj, 0.1, [('amenity', 'bank')],
+        results = run_search(apiobj, frontend, 0.1, [('amenity', 'bank')],
                              details=SearchDetails(excluded=[excluded]))
 
         assert [r.place_id for r in results] == [rid]
                              details=SearchDetails(excluded=[excluded]))
 
         assert [r.place_id for r in results] == [rid]
@@ -155,12 +157,12 @@ class TestNearSearch:
 
     @pytest.mark.parametrize('layer,rids', [(napi.DataLayer.POI, [22]),
                                             (napi.DataLayer.MANMADE, [])])
 
     @pytest.mark.parametrize('layer,rids', [(napi.DataLayer.POI, [22]),
                                             (napi.DataLayer.MANMADE, [])])
-    def test_with_layer(self, apiobj, layer, rids):
+    def test_with_layer(self, apiobj, frontend, layer, rids):
         apiobj.add_placex(place_id=22, class_='amenity', type='bank',
                           centroid=(5.6001, 4.2994),
                           country_code='us')
 
         apiobj.add_placex(place_id=22, class_='amenity', type='bank',
                           centroid=(5.6001, 4.2994),
                           country_code='us')
 
-        results = run_search(apiobj, 0.1, [('amenity', 'bank')],
+        results = run_search(apiobj, frontend, 0.1, [('amenity', 'bank')],
                              details=SearchDetails(layers=layer))
 
         assert [r.place_id for r in results] == rids
                              details=SearchDetails(layers=layer))
 
         assert [r.place_id for r in results] == rids
index 44e4098dada62713b87816246ea5929778030cce..c446a35f88c8ecb533053c682663689a1d9de689 100644 (file)
@@ -18,7 +18,9 @@ from nominatim.api.search.db_search_fields import WeightedStrings, WeightedCateg
                                                   FieldLookup, FieldRanking, RankedTokens
 from nominatim.api.search.db_search_lookups import LookupAll, LookupAny, Restrict
 
                                                   FieldLookup, FieldRanking, RankedTokens
 from nominatim.api.search.db_search_lookups import LookupAll, LookupAny, Restrict
 
-def run_search(apiobj, global_penalty, lookup, ranking, count=2,
+APIOPTIONS = ['search']
+
+def run_search(apiobj, frontend, global_penalty, lookup, ranking, count=2,
                hnrs=[], pcs=[], ccodes=[], quals=[],
                details=SearchDetails()):
     class MySearchData:
                hnrs=[], pcs=[], ccodes=[], quals=[],
                details=SearchDetails()):
     class MySearchData:
@@ -32,11 +34,16 @@ def run_search(apiobj, global_penalty, lookup, ranking, count=2,
 
     search = PlaceSearch(0.0, MySearchData(), count)
 
 
     search = PlaceSearch(0.0, MySearchData(), count)
 
+    if frontend is None:
+        api = apiobj
+    else:
+        api = frontend(apiobj, options=APIOPTIONS)
+
     async def run():
     async def run():
-        async with apiobj.api._async_api.begin() as conn:
+        async with api._async_api.begin() as conn:
             return await search.lookup(conn, details)
 
             return await search.lookup(conn, details)
 
-    results = apiobj.async_to_sync(run())
+    results = api._loop.run_until_complete(run())
     results.sort(key=lambda r: r.accuracy)
 
     return results
     results.sort(key=lambda r: r.accuracy)
 
     return results
@@ -59,61 +66,61 @@ class TestNameOnlySearches:
     @pytest.mark.parametrize('lookup_type', [LookupAll, Restrict])
     @pytest.mark.parametrize('rank,res', [([10], [100, 101]),
                                           ([20], [101, 100])])
     @pytest.mark.parametrize('lookup_type', [LookupAll, Restrict])
     @pytest.mark.parametrize('rank,res', [([10], [100, 101]),
                                           ([20], [101, 100])])
-    def test_lookup_all_match(self, apiobj, lookup_type, rank, res):
+    def test_lookup_all_match(self, apiobj, frontend, lookup_type, rank, res):
         lookup = FieldLookup('name_vector', [1,2], lookup_type)
         ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, rank)])
 
         lookup = FieldLookup('name_vector', [1,2], lookup_type)
         ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, rank)])
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking])
+        results = run_search(apiobj, frontend, 0.1, [lookup], [ranking])
 
         assert [r.place_id for r in results] == res
 
 
     @pytest.mark.parametrize('lookup_type', [LookupAll, Restrict])
 
         assert [r.place_id for r in results] == res
 
 
     @pytest.mark.parametrize('lookup_type', [LookupAll, Restrict])
-    def test_lookup_all_partial_match(self, apiobj, lookup_type):
+    def test_lookup_all_partial_match(self, apiobj, frontend, lookup_type):
         lookup = FieldLookup('name_vector', [1,20], lookup_type)
         ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])])
 
         lookup = FieldLookup('name_vector', [1,20], lookup_type)
         ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])])
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking])
+        results = run_search(apiobj, frontend, 0.1, [lookup], [ranking])
 
         assert len(results) == 1
         assert results[0].place_id == 101
 
     @pytest.mark.parametrize('rank,res', [([10], [100, 101]),
                                           ([20], [101, 100])])
 
         assert len(results) == 1
         assert results[0].place_id == 101
 
     @pytest.mark.parametrize('rank,res', [([10], [100, 101]),
                                           ([20], [101, 100])])
-    def test_lookup_any_match(self, apiobj, rank, res):
+    def test_lookup_any_match(self, apiobj, frontend, rank, res):
         lookup = FieldLookup('name_vector', [11,21], LookupAny)
         ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, rank)])
 
         lookup = FieldLookup('name_vector', [11,21], LookupAny)
         ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, rank)])
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking])
+        results = run_search(apiobj, frontend, 0.1, [lookup], [ranking])
 
         assert [r.place_id for r in results] == res
 
 
 
         assert [r.place_id for r in results] == res
 
 
-    def test_lookup_any_partial_match(self, apiobj):
+    def test_lookup_any_partial_match(self, apiobj, frontend):
         lookup = FieldLookup('name_vector', [20], LookupAll)
         ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])])
 
         lookup = FieldLookup('name_vector', [20], LookupAll)
         ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])])
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking])
+        results = run_search(apiobj, frontend, 0.1, [lookup], [ranking])
 
         assert len(results) == 1
         assert results[0].place_id == 101
 
 
     @pytest.mark.parametrize('cc,res', [('us', 100), ('mx', 101)])
 
         assert len(results) == 1
         assert results[0].place_id == 101
 
 
     @pytest.mark.parametrize('cc,res', [('us', 100), ('mx', 101)])
-    def test_lookup_restrict_country(self, apiobj, cc, res):
+    def test_lookup_restrict_country(self, apiobj, frontend, cc, res):
         lookup = FieldLookup('name_vector', [1,2], LookupAll)
         ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [10])])
 
         lookup = FieldLookup('name_vector', [1,2], LookupAll)
         ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [10])])
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking], ccodes=[cc])
+        results = run_search(apiobj, frontend, 0.1, [lookup], [ranking], ccodes=[cc])
 
         assert [r.place_id for r in results] == [res]
 
 
 
         assert [r.place_id for r in results] == [res]
 
 
-    def test_lookup_restrict_placeid(self, apiobj):
+    def test_lookup_restrict_placeid(self, apiobj, frontend):
         lookup = FieldLookup('name_vector', [1,2], LookupAll)
         ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [10])])
 
         lookup = FieldLookup('name_vector', [1,2], LookupAll)
         ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [10])])
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking],
+        results = run_search(apiobj, frontend, 0.1, [lookup], [ranking],
                              details=SearchDetails(excluded=[101]))
 
         assert [r.place_id for r in results] == [100]
                              details=SearchDetails(excluded=[101]))
 
         assert [r.place_id for r in results] == [100]
@@ -123,18 +130,18 @@ class TestNameOnlySearches:
                                       napi.GeometryFormat.KML,
                                       napi.GeometryFormat.SVG,
                                       napi.GeometryFormat.TEXT])
                                       napi.GeometryFormat.KML,
                                       napi.GeometryFormat.SVG,
                                       napi.GeometryFormat.TEXT])
-    def test_return_geometries(self, apiobj, geom):
+    def test_return_geometries(self, apiobj, frontend, geom):
         lookup = FieldLookup('name_vector', [20], LookupAll)
         ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])])
 
         lookup = FieldLookup('name_vector', [20], LookupAll)
         ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])])
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking],
+        results = run_search(apiobj, frontend, 0.1, [lookup], [ranking],
                              details=SearchDetails(geometry_output=geom))
 
         assert geom.name.lower() in results[0].geometry
 
 
     @pytest.mark.parametrize('factor,npoints', [(0.0, 3), (1.0, 2)])
                              details=SearchDetails(geometry_output=geom))
 
         assert geom.name.lower() in results[0].geometry
 
 
     @pytest.mark.parametrize('factor,npoints', [(0.0, 3), (1.0, 2)])
-    def test_return_simplified_geometry(self, apiobj, factor, npoints):
+    def test_return_simplified_geometry(self, apiobj, frontend, factor, npoints):
         apiobj.add_placex(place_id=333, country_code='us',
                           centroid=(9.0, 9.0),
                           geometry='LINESTRING(8.9 9.0, 9.0 9.0, 9.1 9.0)')
         apiobj.add_placex(place_id=333, country_code='us',
                           centroid=(9.0, 9.0),
                           geometry='LINESTRING(8.9 9.0, 9.0 9.0, 9.1 9.0)')
@@ -144,7 +151,7 @@ class TestNameOnlySearches:
         lookup = FieldLookup('name_vector', [55], LookupAll)
         ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])])
 
         lookup = FieldLookup('name_vector', [55], LookupAll)
         ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])])
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking],
+        results = run_search(apiobj, frontend, 0.1, [lookup], [ranking],
                              details=SearchDetails(geometry_output=napi.GeometryFormat.GEOJSON,
                                                    geometry_simplification=factor))
 
                              details=SearchDetails(geometry_output=napi.GeometryFormat.GEOJSON,
                                                    geometry_simplification=factor))
 
@@ -158,50 +165,52 @@ class TestNameOnlySearches:
 
     @pytest.mark.parametrize('viewbox', ['5.0,4.0,6.0,5.0', '5.7,4.0,6.0,5.0'])
     @pytest.mark.parametrize('wcount,rids', [(2, [100, 101]), (20000, [100])])
 
     @pytest.mark.parametrize('viewbox', ['5.0,4.0,6.0,5.0', '5.7,4.0,6.0,5.0'])
     @pytest.mark.parametrize('wcount,rids', [(2, [100, 101]), (20000, [100])])
-    def test_prefer_viewbox(self, apiobj, viewbox, wcount, rids):
+    def test_prefer_viewbox(self, apiobj, frontend, viewbox, wcount, rids):
         lookup = FieldLookup('name_vector', [1, 2], LookupAll)
         ranking = FieldRanking('name_vector', 0.2, [RankedTokens(0.0, [21])])
 
         lookup = FieldLookup('name_vector', [1, 2], LookupAll)
         ranking = FieldRanking('name_vector', 0.2, [RankedTokens(0.0, [21])])
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking])
+        api = frontend(apiobj, options=APIOPTIONS)
+        results = run_search(api, None, 0.1, [lookup], [ranking])
         assert [r.place_id for r in results] == [101, 100]
 
         assert [r.place_id for r in results] == [101, 100]
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking], count=wcount,
+        results = run_search(api, None, 0.1, [lookup], [ranking], count=wcount,
                              details=SearchDetails.from_kwargs({'viewbox': viewbox}))
         assert [r.place_id for r in results] == rids
 
 
     @pytest.mark.parametrize('viewbox', ['5.0,4.0,6.0,5.0', '5.55,4.27,5.62,4.31'])
                              details=SearchDetails.from_kwargs({'viewbox': viewbox}))
         assert [r.place_id for r in results] == rids
 
 
     @pytest.mark.parametrize('viewbox', ['5.0,4.0,6.0,5.0', '5.55,4.27,5.62,4.31'])
-    def test_force_viewbox(self, apiobj, viewbox):
+    def test_force_viewbox(self, apiobj, frontend, viewbox):
         lookup = FieldLookup('name_vector', [1, 2], LookupAll)
 
         details=SearchDetails.from_kwargs({'viewbox': viewbox,
                                            'bounded_viewbox': True})
 
         lookup = FieldLookup('name_vector', [1, 2], LookupAll)
 
         details=SearchDetails.from_kwargs({'viewbox': viewbox,
                                            'bounded_viewbox': True})
 
-        results = run_search(apiobj, 0.1, [lookup], [], details=details)
+        results = run_search(apiobj, frontend, 0.1, [lookup], [], details=details)
         assert [r.place_id for r in results] == [100]
 
 
         assert [r.place_id for r in results] == [100]
 
 
-    def test_prefer_near(self, apiobj):
+    def test_prefer_near(self, apiobj, frontend):
         lookup = FieldLookup('name_vector', [1, 2], LookupAll)
         ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])])
 
         lookup = FieldLookup('name_vector', [1, 2], LookupAll)
         ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])])
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking])
+        api = frontend(apiobj, options=APIOPTIONS)
+        results = run_search(api, None, 0.1, [lookup], [ranking])
         assert [r.place_id for r in results] == [101, 100]
 
         assert [r.place_id for r in results] == [101, 100]
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking],
+        results = run_search(api, None, 0.1, [lookup], [ranking],
                              details=SearchDetails.from_kwargs({'near': '5.6,4.3'}))
         results.sort(key=lambda r: -r.importance)
         assert [r.place_id for r in results] == [100, 101]
 
 
     @pytest.mark.parametrize('radius', [0.09, 0.11])
                              details=SearchDetails.from_kwargs({'near': '5.6,4.3'}))
         results.sort(key=lambda r: -r.importance)
         assert [r.place_id for r in results] == [100, 101]
 
 
     @pytest.mark.parametrize('radius', [0.09, 0.11])
-    def test_force_near(self, apiobj, radius):
+    def test_force_near(self, apiobj, frontend, radius):
         lookup = FieldLookup('name_vector', [1, 2], LookupAll)
 
         details=SearchDetails.from_kwargs({'near': '5.6,4.3',
                                            'near_radius': radius})
 
         lookup = FieldLookup('name_vector', [1, 2], LookupAll)
 
         details=SearchDetails.from_kwargs({'near': '5.6,4.3',
                                            'near_radius': radius})
 
-        results = run_search(apiobj, 0.1, [lookup], [], details=details)
+        results = run_search(apiobj, frontend, 0.1, [lookup], [], details=details)
 
         assert [r.place_id for r in results] == [100]
 
 
         assert [r.place_id for r in results] == [100]
 
@@ -242,72 +251,72 @@ class TestStreetWithHousenumber:
     @pytest.mark.parametrize('hnr,res', [('20', [91, 1]), ('20 a', [1]),
                                          ('21', [2]), ('22', [2, 92]),
                                          ('24', [93]), ('25', [])])
     @pytest.mark.parametrize('hnr,res', [('20', [91, 1]), ('20 a', [1]),
                                          ('21', [2]), ('22', [2, 92]),
                                          ('24', [93]), ('25', [])])
-    def test_lookup_by_single_housenumber(self, apiobj, hnr, res):
+    def test_lookup_by_single_housenumber(self, apiobj, frontend, hnr, res):
         lookup = FieldLookup('name_vector', [1,2], LookupAll)
         ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
 
         lookup = FieldLookup('name_vector', [1,2], LookupAll)
         ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=[hnr])
+        results = run_search(apiobj, frontend, 0.1, [lookup], [ranking], hnrs=[hnr])
 
         assert [r.place_id for r in results] == res + [1000, 2000]
 
 
     @pytest.mark.parametrize('cc,res', [('es', [2, 1000]), ('pt', [92, 2000])])
 
         assert [r.place_id for r in results] == res + [1000, 2000]
 
 
     @pytest.mark.parametrize('cc,res', [('es', [2, 1000]), ('pt', [92, 2000])])
-    def test_lookup_with_country_restriction(self, apiobj, cc, res):
+    def test_lookup_with_country_restriction(self, apiobj, frontend, cc, res):
         lookup = FieldLookup('name_vector', [1,2], LookupAll)
         ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
 
         lookup = FieldLookup('name_vector', [1,2], LookupAll)
         ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'],
+        results = run_search(apiobj, frontend, 0.1, [lookup], [ranking], hnrs=['22'],
                              ccodes=[cc])
 
         assert [r.place_id for r in results] == res
 
 
                              ccodes=[cc])
 
         assert [r.place_id for r in results] == res
 
 
-    def test_lookup_exclude_housenumber_placeid(self, apiobj):
+    def test_lookup_exclude_housenumber_placeid(self, apiobj, frontend):
         lookup = FieldLookup('name_vector', [1,2], LookupAll)
         ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
 
         lookup = FieldLookup('name_vector', [1,2], LookupAll)
         ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'],
+        results = run_search(apiobj, frontend, 0.1, [lookup], [ranking], hnrs=['22'],
                              details=SearchDetails(excluded=[92]))
 
         assert [r.place_id for r in results] == [2, 1000, 2000]
 
 
                              details=SearchDetails(excluded=[92]))
 
         assert [r.place_id for r in results] == [2, 1000, 2000]
 
 
-    def test_lookup_exclude_street_placeid(self, apiobj):
+    def test_lookup_exclude_street_placeid(self, apiobj, frontend):
         lookup = FieldLookup('name_vector', [1,2], LookupAll)
         ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
 
         lookup = FieldLookup('name_vector', [1,2], LookupAll)
         ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'],
+        results = run_search(apiobj, frontend, 0.1, [lookup], [ranking], hnrs=['22'],
                              details=SearchDetails(excluded=[1000]))
 
         assert [r.place_id for r in results] == [2, 92, 2000]
 
 
                              details=SearchDetails(excluded=[1000]))
 
         assert [r.place_id for r in results] == [2, 92, 2000]
 
 
-    def test_lookup_only_house_qualifier(self, apiobj):
+    def test_lookup_only_house_qualifier(self, apiobj, frontend):
         lookup = FieldLookup('name_vector', [1,2], LookupAll)
         ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
 
         lookup = FieldLookup('name_vector', [1,2], LookupAll)
         ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'],
+        results = run_search(apiobj, frontend, 0.1, [lookup], [ranking], hnrs=['22'],
                              quals=[('place', 'house')])
 
         assert [r.place_id for r in results] == [2, 92]
 
 
                              quals=[('place', 'house')])
 
         assert [r.place_id for r in results] == [2, 92]
 
 
-    def test_lookup_only_street_qualifier(self, apiobj):
+    def test_lookup_only_street_qualifier(self, apiobj, frontend):
         lookup = FieldLookup('name_vector', [1,2], LookupAll)
         ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
 
         lookup = FieldLookup('name_vector', [1,2], LookupAll)
         ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'],
+        results = run_search(apiobj, frontend, 0.1, [lookup], [ranking], hnrs=['22'],
                              quals=[('highway', 'residential')])
 
         assert [r.place_id for r in results] == [1000, 2000]
 
 
     @pytest.mark.parametrize('rank,found', [(26, True), (27, False), (30, False)])
                              quals=[('highway', 'residential')])
 
         assert [r.place_id for r in results] == [1000, 2000]
 
 
     @pytest.mark.parametrize('rank,found', [(26, True), (27, False), (30, False)])
-    def test_lookup_min_rank(self, apiobj, rank, found):
+    def test_lookup_min_rank(self, apiobj, frontend, rank, found):
         lookup = FieldLookup('name_vector', [1,2], LookupAll)
         ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
 
         lookup = FieldLookup('name_vector', [1,2], LookupAll)
         ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'],
+        results = run_search(apiobj, frontend, 0.1, [lookup], [ranking], hnrs=['22'],
                              details=SearchDetails(min_rank=rank))
 
         assert [r.place_id for r in results] == ([2, 92, 1000, 2000] if found else [2, 92])
                              details=SearchDetails(min_rank=rank))
 
         assert [r.place_id for r in results] == ([2, 92, 1000, 2000] if found else [2, 92])
@@ -317,17 +326,17 @@ class TestStreetWithHousenumber:
                                       napi.GeometryFormat.KML,
                                       napi.GeometryFormat.SVG,
                                       napi.GeometryFormat.TEXT])
                                       napi.GeometryFormat.KML,
                                       napi.GeometryFormat.SVG,
                                       napi.GeometryFormat.TEXT])
-    def test_return_geometries(self, apiobj, geom):
+    def test_return_geometries(self, apiobj, frontend, geom):
         lookup = FieldLookup('name_vector', [1, 2], LookupAll)
 
         lookup = FieldLookup('name_vector', [1, 2], LookupAll)
 
-        results = run_search(apiobj, 0.1, [lookup], [], hnrs=['20', '21', '22'],
+        results = run_search(apiobj, frontend, 0.1, [lookup], [], hnrs=['20', '21', '22'],
                              details=SearchDetails(geometry_output=geom))
 
         assert results
         assert all(geom.name.lower() in r.geometry for r in results)
 
 
                              details=SearchDetails(geometry_output=geom))
 
         assert results
         assert all(geom.name.lower() in r.geometry for r in results)
 
 
-def test_very_large_housenumber(apiobj):
+def test_very_large_housenumber(apiobj, frontend):
     apiobj.add_placex(place_id=93, class_='place', type='house',
                       parent_place_id=2000,
                       housenumber='2467463524544', country_code='pt')
     apiobj.add_placex(place_id=93, class_='place', type='house',
                       parent_place_id=2000,
                       housenumber='2467463524544', country_code='pt')
@@ -340,7 +349,7 @@ def test_very_large_housenumber(apiobj):
 
     lookup = FieldLookup('name_vector', [1, 2], LookupAll)
 
 
     lookup = FieldLookup('name_vector', [1, 2], LookupAll)
 
-    results = run_search(apiobj, 0.1, [lookup], [], hnrs=['2467463524544'],
+    results = run_search(apiobj, frontend, 0.1, [lookup], [], hnrs=['2467463524544'],
                          details=SearchDetails())
 
     assert results
                          details=SearchDetails())
 
     assert results
@@ -348,7 +357,7 @@ def test_very_large_housenumber(apiobj):
 
 
 @pytest.mark.parametrize('wcount,rids', [(2, [990, 991]), (30000, [990])])
 
 
 @pytest.mark.parametrize('wcount,rids', [(2, [990, 991]), (30000, [990])])
-def test_name_and_postcode(apiobj, wcount, rids):
+def test_name_and_postcode(apiobj, frontend, wcount, rids):
     apiobj.add_placex(place_id=990, class_='highway', type='service',
                       rank_search=27, rank_address=27,
                       postcode='11225',
     apiobj.add_placex(place_id=990, class_='highway', type='service',
                       rank_search=27, rank_address=27,
                       postcode='11225',
@@ -368,7 +377,7 @@ def test_name_and_postcode(apiobj, wcount, rids):
 
     lookup = FieldLookup('name_vector', [111], LookupAll)
 
 
     lookup = FieldLookup('name_vector', [111], LookupAll)
 
-    results = run_search(apiobj, 0.1, [lookup], [], pcs=['11225'], count=wcount,
+    results = run_search(apiobj, frontend, 0.1, [lookup], [], pcs=['11225'], count=wcount,
                          details=SearchDetails())
 
     assert results
                          details=SearchDetails())
 
     assert results
@@ -398,10 +407,10 @@ class TestInterpolations:
 
 
     @pytest.mark.parametrize('hnr,res', [('21', [992]), ('22', []), ('23', [991])])
 
 
     @pytest.mark.parametrize('hnr,res', [('21', [992]), ('22', []), ('23', [991])])
-    def test_lookup_housenumber(self, apiobj, hnr, res):
+    def test_lookup_housenumber(self, apiobj, frontend, hnr, res):
         lookup = FieldLookup('name_vector', [111], LookupAll)
 
         lookup = FieldLookup('name_vector', [111], LookupAll)
 
-        results = run_search(apiobj, 0.1, [lookup], [], hnrs=[hnr])
+        results = run_search(apiobj, frontend, 0.1, [lookup], [], hnrs=[hnr])
 
         assert [r.place_id for r in results] == res + [990]
 
 
         assert [r.place_id for r in results] == res + [990]
 
@@ -410,10 +419,10 @@ class TestInterpolations:
                                       napi.GeometryFormat.KML,
                                       napi.GeometryFormat.SVG,
                                       napi.GeometryFormat.TEXT])
                                       napi.GeometryFormat.KML,
                                       napi.GeometryFormat.SVG,
                                       napi.GeometryFormat.TEXT])
-    def test_osmline_with_geometries(self, apiobj, geom):
+    def test_osmline_with_geometries(self, apiobj, frontend, geom):
         lookup = FieldLookup('name_vector', [111], LookupAll)
 
         lookup = FieldLookup('name_vector', [111], LookupAll)
 
-        results = run_search(apiobj, 0.1, [lookup], [], hnrs=['21'],
+        results = run_search(apiobj, frontend, 0.1, [lookup], [], hnrs=['21'],
                              details=SearchDetails(geometry_output=geom))
 
         assert results[0].place_id == 992
                              details=SearchDetails(geometry_output=geom))
 
         assert results[0].place_id == 992
@@ -446,10 +455,10 @@ class TestTiger:
 
 
     @pytest.mark.parametrize('hnr,res', [('21', [992]), ('22', []), ('23', [991])])
 
 
     @pytest.mark.parametrize('hnr,res', [('21', [992]), ('22', []), ('23', [991])])
-    def test_lookup_housenumber(self, apiobj, hnr, res):
+    def test_lookup_housenumber(self, apiobj, frontend, hnr, res):
         lookup = FieldLookup('name_vector', [111], LookupAll)
 
         lookup = FieldLookup('name_vector', [111], LookupAll)
 
-        results = run_search(apiobj, 0.1, [lookup], [], hnrs=[hnr])
+        results = run_search(apiobj, frontend, 0.1, [lookup], [], hnrs=[hnr])
 
         assert [r.place_id for r in results] == res + [990]
 
 
         assert [r.place_id for r in results] == res + [990]
 
@@ -458,10 +467,10 @@ class TestTiger:
                                       napi.GeometryFormat.KML,
                                       napi.GeometryFormat.SVG,
                                       napi.GeometryFormat.TEXT])
                                       napi.GeometryFormat.KML,
                                       napi.GeometryFormat.SVG,
                                       napi.GeometryFormat.TEXT])
-    def test_tiger_with_geometries(self, apiobj, geom):
+    def test_tiger_with_geometries(self, apiobj, frontend, geom):
         lookup = FieldLookup('name_vector', [111], LookupAll)
 
         lookup = FieldLookup('name_vector', [111], LookupAll)
 
-        results = run_search(apiobj, 0.1, [lookup], [], hnrs=['21'],
+        results = run_search(apiobj, frontend, 0.1, [lookup], [], hnrs=['21'],
                              details=SearchDetails(geometry_output=geom))
 
         assert results[0].place_id == 992
                              details=SearchDetails(geometry_output=geom))
 
         assert results[0].place_id == 992
@@ -513,10 +522,10 @@ class TestLayersRank30:
                                            (napi.DataLayer.NATURAL, [227]),
                                            (napi.DataLayer.MANMADE | napi.DataLayer.NATURAL, [225, 227]),
                                            (napi.DataLayer.MANMADE | napi.DataLayer.RAILWAY, [225, 226])])
                                            (napi.DataLayer.NATURAL, [227]),
                                            (napi.DataLayer.MANMADE | napi.DataLayer.NATURAL, [225, 227]),
                                            (napi.DataLayer.MANMADE | napi.DataLayer.RAILWAY, [225, 226])])
-    def test_layers_rank30(self, apiobj, layer, res):
+    def test_layers_rank30(self, apiobj, frontend, layer, res):
         lookup = FieldLookup('name_vector', [34], LookupAny)
 
         lookup = FieldLookup('name_vector', [34], LookupAny)
 
-        results = run_search(apiobj, 0.1, [lookup], [],
+        results = run_search(apiobj, frontend, 0.1, [lookup], [],
                              details=SearchDetails(layers=layer))
 
         assert [r.place_id for r in results] == res
                              details=SearchDetails(layers=layer))
 
         assert [r.place_id for r in results] == res
index b80c075200f4aa3e2a08cb3b2a9b84d569bf0d38..a0b578baffbc07c2b03b7b06dafc4ca5c7e3b007 100644 (file)
@@ -15,7 +15,7 @@ from nominatim.api.search.db_searches import PoiSearch
 from nominatim.api.search.db_search_fields import WeightedStrings, WeightedCategories
 
 
 from nominatim.api.search.db_search_fields import WeightedStrings, WeightedCategories
 
 
-def run_search(apiobj, global_penalty, poitypes, poi_penalties=None,
+def run_search(apiobj, frontend, global_penalty, poitypes, poi_penalties=None,
                ccodes=[], details=SearchDetails()):
     if poi_penalties is None:
         poi_penalties = [0.0] * len(poitypes)
                ccodes=[], details=SearchDetails()):
     if poi_penalties is None:
         poi_penalties = [0.0] * len(poitypes)
@@ -27,16 +27,18 @@ def run_search(apiobj, global_penalty, poitypes, poi_penalties=None,
 
     search = PoiSearch(MySearchData())
 
 
     search = PoiSearch(MySearchData())
 
+    api = frontend(apiobj, options=['search'])
+
     async def run():
     async def run():
-        async with apiobj.api._async_api.begin() as conn:
+        async with api._async_api.begin() as conn:
             return await search.lookup(conn, details)
 
             return await search.lookup(conn, details)
 
-    return apiobj.async_to_sync(run())
+    return api._loop.run_until_complete(run())
 
 
 @pytest.mark.parametrize('coord,pid', [('34.3, 56.100021', 2),
                                        ('5.0, 4.59933', 1)])
 
 
 @pytest.mark.parametrize('coord,pid', [('34.3, 56.100021', 2),
                                        ('5.0, 4.59933', 1)])
-def test_simple_near_search_in_placex(apiobj, coord, pid):
+def test_simple_near_search_in_placex(apiobj, frontend, coord, pid):
     apiobj.add_placex(place_id=1, class_='highway', type='bus_stop',
                       centroid=(5.0, 4.6))
     apiobj.add_placex(place_id=2, class_='highway', type='bus_stop',
     apiobj.add_placex(place_id=1, class_='highway', type='bus_stop',
                       centroid=(5.0, 4.6))
     apiobj.add_placex(place_id=2, class_='highway', type='bus_stop',
@@ -44,7 +46,7 @@ def test_simple_near_search_in_placex(apiobj, coord, pid):
 
     details = SearchDetails.from_kwargs({'near': coord, 'near_radius': 0.001})
 
 
     details = SearchDetails.from_kwargs({'near': coord, 'near_radius': 0.001})
 
-    results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5], details=details)
+    results = run_search(apiobj, frontend, 0.1, [('highway', 'bus_stop')], [0.5], details=details)
 
     assert [r.place_id for r in results] == [pid]
 
 
     assert [r.place_id for r in results] == [pid]
 
@@ -52,7 +54,7 @@ def test_simple_near_search_in_placex(apiobj, coord, pid):
 @pytest.mark.parametrize('coord,pid', [('34.3, 56.100021', 2),
                                        ('34.3, 56.4', 2),
                                        ('5.0, 4.59933', 1)])
 @pytest.mark.parametrize('coord,pid', [('34.3, 56.100021', 2),
                                        ('34.3, 56.4', 2),
                                        ('5.0, 4.59933', 1)])
-def test_simple_near_search_in_classtype(apiobj, coord, pid):
+def test_simple_near_search_in_classtype(apiobj, frontend, coord, pid):
     apiobj.add_placex(place_id=1, class_='highway', type='bus_stop',
                       centroid=(5.0, 4.6))
     apiobj.add_placex(place_id=2, class_='highway', type='bus_stop',
     apiobj.add_placex(place_id=1, class_='highway', type='bus_stop',
                       centroid=(5.0, 4.6))
     apiobj.add_placex(place_id=2, class_='highway', type='bus_stop',
@@ -61,7 +63,7 @@ def test_simple_near_search_in_classtype(apiobj, coord, pid):
 
     details = SearchDetails.from_kwargs({'near': coord, 'near_radius': 0.5})
 
 
     details = SearchDetails.from_kwargs({'near': coord, 'near_radius': 0.5})
 
-    results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5], details=details)
+    results = run_search(apiobj, frontend, 0.1, [('highway', 'bus_stop')], [0.5], details=details)
 
     assert [r.place_id for r in results] == [pid]
 
 
     assert [r.place_id for r in results] == [pid]
 
@@ -83,25 +85,25 @@ class TestPoiSearchWithRestrictions:
             self.args = {'near': '34.3, 56.100021', 'near_radius': 0.001}
 
 
             self.args = {'near': '34.3, 56.100021', 'near_radius': 0.001}
 
 
-    def test_unrestricted(self, apiobj):
-        results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5],
+    def test_unrestricted(self, apiobj, frontend):
+        results = run_search(apiobj, frontend, 0.1, [('highway', 'bus_stop')], [0.5],
                              details=SearchDetails.from_kwargs(self.args))
 
         assert [r.place_id for r in results] == [1, 2]
 
 
                              details=SearchDetails.from_kwargs(self.args))
 
         assert [r.place_id for r in results] == [1, 2]
 
 
-    def test_restict_country(self, apiobj):
-        results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5],
+    def test_restict_country(self, apiobj, frontend):
+        results = run_search(apiobj, frontend, 0.1, [('highway', 'bus_stop')], [0.5],
                              ccodes=['de', 'nz'],
                              details=SearchDetails.from_kwargs(self.args))
 
         assert [r.place_id for r in results] == [2]
 
 
                              ccodes=['de', 'nz'],
                              details=SearchDetails.from_kwargs(self.args))
 
         assert [r.place_id for r in results] == [2]
 
 
-    def test_restrict_by_viewbox(self, apiobj):
+    def test_restrict_by_viewbox(self, apiobj, frontend):
         args = {'bounded_viewbox': True, 'viewbox': '34.299,56.0,34.3001,56.10001'}
         args.update(self.args)
         args = {'bounded_viewbox': True, 'viewbox': '34.299,56.0,34.3001,56.10001'}
         args.update(self.args)
-        results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5],
+        results = run_search(apiobj, frontend, 0.1, [('highway', 'bus_stop')], [0.5],
                              ccodes=['de', 'nz'],
                              details=SearchDetails.from_kwargs(args))
 
                              ccodes=['de', 'nz'],
                              details=SearchDetails.from_kwargs(args))
 
index e7153f38bf8b6147a268d5b051422d0d8d415680..6976b6a592ac5921fbb283c84c411065500793bb 100644 (file)
@@ -15,7 +15,7 @@ from nominatim.api.search.db_searches import PostcodeSearch
 from nominatim.api.search.db_search_fields import WeightedStrings, FieldLookup, \
                                                   FieldRanking, RankedTokens
 
 from nominatim.api.search.db_search_fields import WeightedStrings, FieldLookup, \
                                                   FieldRanking, RankedTokens
 
-def run_search(apiobj, global_penalty, pcs, pc_penalties=None,
+def run_search(apiobj, frontend, global_penalty, pcs, pc_penalties=None,
                ccodes=[], lookup=[], ranking=[], details=SearchDetails()):
     if pc_penalties is None:
         pc_penalties = [0.0] * len(pcs)
                ccodes=[], lookup=[], ranking=[], details=SearchDetails()):
     if pc_penalties is None:
         pc_penalties = [0.0] * len(pcs)
@@ -29,28 +29,30 @@ def run_search(apiobj, global_penalty, pcs, pc_penalties=None,
 
     search = PostcodeSearch(0.0, MySearchData())
 
 
     search = PostcodeSearch(0.0, MySearchData())
 
+    api = frontend(apiobj, options=['search'])
+
     async def run():
     async def run():
-        async with apiobj.api._async_api.begin() as conn:
+        async with api._async_api.begin() as conn:
             return await search.lookup(conn, details)
 
             return await search.lookup(conn, details)
 
-    return apiobj.async_to_sync(run())
+    return api._loop.run_until_complete(run())
 
 
 
 
-def test_postcode_only_search(apiobj):
+def test_postcode_only_search(apiobj, frontend):
     apiobj.add_postcode(place_id=100, country_code='ch', postcode='12345')
     apiobj.add_postcode(place_id=101, country_code='pl', postcode='12 345')
 
     apiobj.add_postcode(place_id=100, country_code='ch', postcode='12345')
     apiobj.add_postcode(place_id=101, country_code='pl', postcode='12 345')
 
-    results = run_search(apiobj, 0.3, ['12345', '12 345'], [0.0, 0.1])
+    results = run_search(apiobj, frontend, 0.3, ['12345', '12 345'], [0.0, 0.1])
 
     assert len(results) == 2
     assert [r.place_id for r in results] == [100, 101]
 
 
 
     assert len(results) == 2
     assert [r.place_id for r in results] == [100, 101]
 
 
-def test_postcode_with_country(apiobj):
+def test_postcode_with_country(apiobj, frontend):
     apiobj.add_postcode(place_id=100, country_code='ch', postcode='12345')
     apiobj.add_postcode(place_id=101, country_code='pl', postcode='12 345')
 
     apiobj.add_postcode(place_id=100, country_code='ch', postcode='12345')
     apiobj.add_postcode(place_id=101, country_code='pl', postcode='12 345')
 
-    results = run_search(apiobj, 0.3, ['12345', '12 345'], [0.0, 0.1],
+    results = run_search(apiobj, frontend, 0.3, ['12345', '12 345'], [0.0, 0.1],
                          ccodes=['de', 'pl'])
 
     assert len(results) == 1
                          ccodes=['de', 'pl'])
 
     assert len(results) == 1
@@ -81,30 +83,30 @@ class TestPostcodeSearchWithAddress:
                                country_code='pl')
 
 
                                country_code='pl')
 
 
-    def test_lookup_both(self, apiobj):
+    def test_lookup_both(self, apiobj, frontend):
         lookup = FieldLookup('name_vector', [1,2], 'restrict')
         ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
 
         lookup = FieldLookup('name_vector', [1,2], 'restrict')
         ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
 
-        results = run_search(apiobj, 0.1, ['12345'], lookup=[lookup], ranking=[ranking])
+        results = run_search(apiobj, frontend, 0.1, ['12345'], lookup=[lookup], ranking=[ranking])
 
         assert [r.place_id for r in results] == [100, 101]
 
 
 
         assert [r.place_id for r in results] == [100, 101]
 
 
-    def test_restrict_by_name(self, apiobj):
+    def test_restrict_by_name(self, apiobj, frontend):
         lookup = FieldLookup('name_vector', [10], 'restrict')
 
         lookup = FieldLookup('name_vector', [10], 'restrict')
 
-        results = run_search(apiobj, 0.1, ['12345'], lookup=[lookup])
+        results = run_search(apiobj, frontend, 0.1, ['12345'], lookup=[lookup])
 
         assert [r.place_id for r in results] == [100]
 
 
     @pytest.mark.parametrize('coord,place_id', [((16.5, 5), 100),
                                                 ((-45.1, 7.004), 101)])
 
         assert [r.place_id for r in results] == [100]
 
 
     @pytest.mark.parametrize('coord,place_id', [((16.5, 5), 100),
                                                 ((-45.1, 7.004), 101)])
-    def test_lookup_near(self, apiobj, coord, place_id):
+    def test_lookup_near(self, apiobj, frontend, coord, place_id):
         lookup = FieldLookup('name_vector', [1,2], 'restrict')
         ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
 
         lookup = FieldLookup('name_vector', [1,2], 'restrict')
         ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
 
-        results = run_search(apiobj, 0.1, ['12345'],
+        results = run_search(apiobj, frontend, 0.1, ['12345'],
                              lookup=[lookup], ranking=[ranking],
                              details=SearchDetails(near=napi.Point(*coord),
                                                    near_radius=0.6))
                              lookup=[lookup], ranking=[ranking],
                              details=SearchDetails(near=napi.Point(*coord),
                                                    near_radius=0.6))
@@ -116,8 +118,8 @@ class TestPostcodeSearchWithAddress:
                                       napi.GeometryFormat.KML,
                                       napi.GeometryFormat.SVG,
                                       napi.GeometryFormat.TEXT])
                                       napi.GeometryFormat.KML,
                                       napi.GeometryFormat.SVG,
                                       napi.GeometryFormat.TEXT])
-    def test_return_geometries(self, apiobj, geom):
-        results = run_search(apiobj, 0.1, ['12345'],
+    def test_return_geometries(self, apiobj, frontend, geom):
+        results = run_search(apiobj, frontend, 0.1, ['12345'],
                              details=SearchDetails(geometry_output=geom))
 
         assert results
                              details=SearchDetails(geometry_output=geom))
 
         assert results
@@ -126,8 +128,8 @@ class TestPostcodeSearchWithAddress:
 
     @pytest.mark.parametrize('viewbox, rids', [('-46,6,-44,8', [101,100]),
                                                ('16,4,18,6', [100,101])])
 
     @pytest.mark.parametrize('viewbox, rids', [('-46,6,-44,8', [101,100]),
                                                ('16,4,18,6', [100,101])])
-    def test_prefer_viewbox(self, apiobj, viewbox, rids):
-        results = run_search(apiobj, 0.1, ['12345'],
+    def test_prefer_viewbox(self, apiobj, frontend, viewbox, rids):
+        results = run_search(apiobj, frontend, 0.1, ['12345'],
                              details=SearchDetails.from_kwargs({'viewbox': viewbox}))
 
         assert [r.place_id for r in results] == rids
                              details=SearchDetails.from_kwargs({'viewbox': viewbox}))
 
         assert [r.place_id for r in results] == rids
@@ -135,8 +137,8 @@ class TestPostcodeSearchWithAddress:
 
     @pytest.mark.parametrize('viewbox, rid', [('-46,6,-44,8', 101),
                                                ('16,4,18,6', 100)])
 
     @pytest.mark.parametrize('viewbox, rid', [('-46,6,-44,8', 101),
                                                ('16,4,18,6', 100)])
-    def test_restrict_to_viewbox(self, apiobj, viewbox, rid):
-        results = run_search(apiobj, 0.1, ['12345'],
+    def test_restrict_to_viewbox(self, apiobj, frontend, viewbox, rid):
+        results = run_search(apiobj, frontend, 0.1, ['12345'],
                              details=SearchDetails.from_kwargs({'viewbox': viewbox,
                                                                 'bounded_viewbox': True}))
 
                              details=SearchDetails.from_kwargs({'viewbox': viewbox,
                                                                 'bounded_viewbox': True}))
 
@@ -145,16 +147,16 @@ class TestPostcodeSearchWithAddress:
 
     @pytest.mark.parametrize('coord,rids', [((17.05, 5), [100, 101]),
                                             ((-45, 7.1), [101, 100])])
 
     @pytest.mark.parametrize('coord,rids', [((17.05, 5), [100, 101]),
                                             ((-45, 7.1), [101, 100])])
-    def test_prefer_near(self, apiobj, coord, rids):
-        results = run_search(apiobj, 0.1, ['12345'],
+    def test_prefer_near(self, apiobj, frontend, coord, rids):
+        results = run_search(apiobj, frontend, 0.1, ['12345'],
                              details=SearchDetails(near=napi.Point(*coord)))
 
         assert [r.place_id for r in results] == rids
 
 
     @pytest.mark.parametrize('pid,rid', [(100, 101), (101, 100)])
                              details=SearchDetails(near=napi.Point(*coord)))
 
         assert [r.place_id for r in results] == rids
 
 
     @pytest.mark.parametrize('pid,rid', [(100, 101), (101, 100)])
-    def test_exclude(self, apiobj, pid, rid):
-        results = run_search(apiobj, 0.1, ['12345'],
+    def test_exclude(self, apiobj, frontend, pid, rid):
+        results = run_search(apiobj, frontend, 0.1, ['12345'],
                              details=SearchDetails(excluded=[pid]))
 
         assert [r.place_id for r in results] == [rid]
                              details=SearchDetails(excluded=[pid]))
 
         assert [r.place_id for r in results] == [rid]
index aa263d24dd67a8a8974004870a36123061c68d93..22dbaa2642d63a99ad227246c2c135139c087bcc 100644 (file)
@@ -19,6 +19,8 @@ import sqlalchemy as sa
 import nominatim.api as napi
 import nominatim.api.logging as loglib
 
 import nominatim.api as napi
 import nominatim.api.logging as loglib
 
+API_OPTIONS = {'search'}
+
 @pytest.fixture(autouse=True)
 def setup_icu_tokenizer(apiobj):
     """ Setup the propoerties needed for using the ICU tokenizer.
 @pytest.fixture(autouse=True)
 def setup_icu_tokenizer(apiobj):
     """ Setup the propoerties needed for using the ICU tokenizer.
@@ -30,66 +32,62 @@ def setup_icu_tokenizer(apiobj):
                     ])
 
 
                     ])
 
 
-def test_search_no_content(apiobj, table_factory):
-    table_factory('word',
-                  definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB')
+def test_search_no_content(apiobj, frontend):
+    apiobj.add_word_table([])
 
 
-    assert apiobj.api.search('foo') == []
+    api = frontend(apiobj, options=API_OPTIONS)
+    assert api.search('foo') == []
 
 
 
 
-def test_search_simple_word(apiobj, table_factory):
-    table_factory('word',
-                  definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB',
-                  content=[(55, 'test', 'W', 'test', None),
+def test_search_simple_word(apiobj, frontend):
+    apiobj.add_word_table([(55, 'test', 'W', 'test', None),
                            (2, 'test', 'w', 'test', None)])
 
     apiobj.add_placex(place_id=444, class_='place', type='village',
                       centroid=(1.3, 0.7))
     apiobj.add_search_name(444, names=[2, 55])
 
                            (2, 'test', 'w', 'test', None)])
 
     apiobj.add_placex(place_id=444, class_='place', type='village',
                       centroid=(1.3, 0.7))
     apiobj.add_search_name(444, names=[2, 55])
 
-    results = apiobj.api.search('TEST')
+    api = frontend(apiobj, options=API_OPTIONS)
+    results = api.search('TEST')
 
     assert [r.place_id for r in results] == [444]
 
 
 @pytest.mark.parametrize('logtype', ['text', 'html'])
 
     assert [r.place_id for r in results] == [444]
 
 
 @pytest.mark.parametrize('logtype', ['text', 'html'])
-def test_search_with_debug(apiobj, table_factory, logtype):
-    table_factory('word',
-                  definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB',
-                  content=[(55, 'test', 'W', 'test', None),
+def test_search_with_debug(apiobj, frontend, logtype):
+    apiobj.add_word_table([(55, 'test', 'W', 'test', None),
                            (2, 'test', 'w', 'test', None)])
 
     apiobj.add_placex(place_id=444, class_='place', type='village',
                       centroid=(1.3, 0.7))
     apiobj.add_search_name(444, names=[2, 55])
 
                            (2, 'test', 'w', 'test', None)])
 
     apiobj.add_placex(place_id=444, class_='place', type='village',
                       centroid=(1.3, 0.7))
     apiobj.add_search_name(444, names=[2, 55])
 
+    api = frontend(apiobj, options=API_OPTIONS)
     loglib.set_log_output(logtype)
     loglib.set_log_output(logtype)
-    results = apiobj.api.search('TEST')
+    results = api.search('TEST')
 
     assert loglib.get_and_disable()
 
 
 
     assert loglib.get_and_disable()
 
 
-def test_address_no_content(apiobj, table_factory):
-    table_factory('word',
-                  definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB')
+def test_address_no_content(apiobj, frontend):
+    apiobj.add_word_table([])
 
 
-    assert apiobj.api.search_address(amenity='hotel',
-                                     street='Main St 34',
-                                     city='Happyville',
-                                     county='Wideland',
-                                     state='Praerie',
-                                     postalcode='55648',
-                                     country='xx') == []
+    api = frontend(apiobj, options=API_OPTIONS)
+    assert api.search_address(amenity='hotel',
+                              street='Main St 34',
+                              city='Happyville',
+                              county='Wideland',
+                              state='Praerie',
+                              postalcode='55648',
+                              country='xx') == []
 
 
 @pytest.mark.parametrize('atype,address,search', [('street', 26, 26),
                                                   ('city', 16, 18),
                                                   ('county', 12, 12),
                                                   ('state', 8, 8)])
 
 
 @pytest.mark.parametrize('atype,address,search', [('street', 26, 26),
                                                   ('city', 16, 18),
                                                   ('county', 12, 12),
                                                   ('state', 8, 8)])
-def test_address_simple_places(apiobj, table_factory, atype, address, search):
-    table_factory('word',
-                  definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB',
-                  content=[(55, 'test', 'W', 'test', None),
+def test_address_simple_places(apiobj, frontend, atype, address, search):
+    apiobj.add_word_table([(55, 'test', 'W', 'test', None),
                            (2, 'test', 'w', 'test', None)])
 
     apiobj.add_placex(place_id=444,
                            (2, 'test', 'w', 'test', None)])
 
     apiobj.add_placex(place_id=444,
@@ -97,53 +95,51 @@ def test_address_simple_places(apiobj, table_factory, atype, address, search):
                       centroid=(1.3, 0.7))
     apiobj.add_search_name(444, names=[2, 55], address_rank=address, search_rank=search)
 
                       centroid=(1.3, 0.7))
     apiobj.add_search_name(444, names=[2, 55], address_rank=address, search_rank=search)
 
-    results = apiobj.api.search_address(**{atype: 'TEST'})
+    api = frontend(apiobj, options=API_OPTIONS)
+    results = api.search_address(**{atype: 'TEST'})
 
     assert [r.place_id for r in results] == [444]
 
 
 
     assert [r.place_id for r in results] == [444]
 
 
-def test_address_country(apiobj, table_factory):
-    table_factory('word',
-                  definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB',
-                  content=[(None, 'ro', 'C', 'ro', None)])
+def test_address_country(apiobj, frontend):
+    apiobj.add_word_table([(None, 'ro', 'C', 'ro', None)])
     apiobj.add_country('ro', 'POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))')
     apiobj.add_country_name('ro', {'name': 'România'})
 
     apiobj.add_country('ro', 'POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))')
     apiobj.add_country_name('ro', {'name': 'România'})
 
-    assert len(apiobj.api.search_address(country='ro')) == 1
+    api = frontend(apiobj, options=API_OPTIONS)
+    assert len(api.search_address(country='ro')) == 1
 
 
 
 
-def test_category_no_categories(apiobj, table_factory):
-    table_factory('word',
-                  definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB')
+def test_category_no_categories(apiobj, frontend):
+    apiobj.add_word_table([])
 
 
-    assert apiobj.api.search_category([], near_query='Berlin') == []
+    api = frontend(apiobj, options=API_OPTIONS)
+    assert api.search_category([], near_query='Berlin') == []
 
 
 
 
-def test_category_no_content(apiobj, table_factory):
-    table_factory('word',
-                  definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB')
+def test_category_no_content(apiobj, frontend):
+    apiobj.add_word_table([])
 
 
-    assert apiobj.api.search_category([('amenity', 'restaurant')]) == []
+    api = frontend(apiobj, options=API_OPTIONS)
+    assert api.search_category([('amenity', 'restaurant')]) == []
 
 
 
 
-def test_category_simple_restaurant(apiobj, table_factory):
-    table_factory('word',
-                  definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB')
+def test_category_simple_restaurant(apiobj, frontend):
+    apiobj.add_word_table([])
 
     apiobj.add_placex(place_id=444, class_='amenity', type='restaurant',
                       centroid=(1.3, 0.7))
     apiobj.add_search_name(444, names=[2, 55], address_rank=16, search_rank=18)
 
 
     apiobj.add_placex(place_id=444, class_='amenity', type='restaurant',
                       centroid=(1.3, 0.7))
     apiobj.add_search_name(444, names=[2, 55], address_rank=16, search_rank=18)
 
-    results = apiobj.api.search_category([('amenity', 'restaurant')],
-                                         near=(1.3, 0.701), near_radius=0.015)
+    api = frontend(apiobj, options=API_OPTIONS)
+    results = api.search_category([('amenity', 'restaurant')],
+                                  near=(1.3, 0.701), near_radius=0.015)
 
     assert [r.place_id for r in results] == [444]
 
 
 
     assert [r.place_id for r in results] == [444]
 
 
-def test_category_with_search_phrase(apiobj, table_factory):
-    table_factory('word',
-                  definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB',
-                  content=[(55, 'test', 'W', 'test', None),
+def test_category_with_search_phrase(apiobj, frontend):
+    apiobj.add_word_table([(55, 'test', 'W', 'test', None),
                            (2, 'test', 'w', 'test', None)])
 
     apiobj.add_placex(place_id=444, class_='place', type='village',
                            (2, 'test', 'w', 'test', None)])
 
     apiobj.add_placex(place_id=444, class_='place', type='village',
@@ -153,7 +149,7 @@ def test_category_with_search_phrase(apiobj, table_factory):
     apiobj.add_placex(place_id=95, class_='amenity', type='restaurant',
                       centroid=(1.3, 0.7003))
 
     apiobj.add_placex(place_id=95, class_='amenity', type='restaurant',
                       centroid=(1.3, 0.7003))
 
-    results = apiobj.api.search_category([('amenity', 'restaurant')],
-                                         near_query='TEST')
+    api = frontend(apiobj, options=API_OPTIONS)
+    results = api.search_category([('amenity', 'restaurant')], near_query='TEST')
 
     assert [r.place_id for r in results] == [95]
 
     assert [r.place_id for r in results] == [95]