git.openstreetmap.org Git - nominatim.git/commitdiff
Merge remote-tracking branch 'upstream/master'
author Sarah Hoffmann <lonvia@denofr.de>
Mon, 18 Dec 2023 15:00:08 +0000 (16:00 +0100)
committer Sarah Hoffmann <lonvia@denofr.de>
Mon, 18 Dec 2023 15:00:08 +0000 (16:00 +0100)
45 files changed:
.github/workflows/ci-tests.yml
docs/customize/SQLite.md [new file with mode: 0644]
docs/mkdocs.yml
lib-sql/functions/importance.sql
lib-sql/functions/ranking.sql
nominatim/api/core.py
nominatim/api/logging.py
nominatim/api/reverse.py
nominatim/api/search/db_search_builder.py
nominatim/api/search/db_search_fields.py
nominatim/api/search/db_search_lookups.py [new file with mode: 0644]
nominatim/api/search/db_searches.py
nominatim/api/search/icu_tokenizer.py
nominatim/clicmd/args.py
nominatim/clicmd/convert.py
nominatim/clicmd/refresh.py
nominatim/clicmd/setup.py
nominatim/db/connection.py
nominatim/db/sqlalchemy_functions.py
nominatim/db/sqlalchemy_schema.py
nominatim/db/sqlalchemy_types/__init__.py [new file with mode: 0644]
nominatim/db/sqlalchemy_types/geometry.py [moved from nominatim/db/sqlalchemy_types.py with 85% similarity]
nominatim/db/sqlalchemy_types/int_array.py [new file with mode: 0644]
nominatim/db/sqlalchemy_types/json.py [new file with mode: 0644]
nominatim/db/sqlalchemy_types/key_value.py [new file with mode: 0644]
nominatim/db/sqlite_functions.py [new file with mode: 0644]
nominatim/tools/convert_sqlite.py
nominatim/tools/database_import.py
nominatim/typing.py
nominatim/utils/json_writer.py
test/bdd/api/search/geocodejson.feature
test/bdd/api/search/language.feature
test/bdd/api/search/params.feature
test/bdd/api/search/postcode.feature
test/bdd/api/search/queries.feature
test/bdd/api/search/simple.feature
test/bdd/api/search/structured.feature
test/python/api/conftest.py
test/python/api/search/test_db_search_builder.py
test/python/api/search/test_search_country.py
test/python/api/search/test_search_near.py
test/python/api/search/test_search_places.py
test/python/api/search/test_search_poi.py
test/python/api/search/test_search_postcode.py
test/python/api/test_api_search.py

index 1dade3bcfa28c3b4a43e2b73aa5882ba66238d33..42c03edc17d9e76fd20463b6294c25ce1fa730bc 100644 (file)
@@ -105,7 +105,7 @@ jobs:
               if: matrix.flavour != 'oldstuff'
 
             - name: Install newer pytest-asyncio
-              run: pip3 install -U pytest-asyncio
+              run: pip3 install -U pytest-asyncio==0.21.1
               if: matrix.flavour == 'ubuntu-20'
 
             - name: Install test prerequisites (from pip for Ubuntu 18)
@@ -349,3 +349,95 @@ jobs:
             - name: Clean up database (reverse-only import)
               run: nominatim refresh --postcodes --word-tokens
               working-directory: /home/nominatim/nominatim-project
+
+    install-no-superuser:
+      runs-on: ubuntu-latest
+      needs: create-archive
+
+      strategy:
+          matrix:
+              name: [Ubuntu-22]
+              include:
+                  - name: Ubuntu-22
+                    image: "ubuntu:22.04"
+                    ubuntu: 22
+                    install_mode: install-apache
+
+      container:
+          image: ${{ matrix.image }}
+          env:
+              LANG: en_US.UTF-8
+
+      defaults:
+          run:
+              shell: sudo -Hu nominatim bash --noprofile --norc -eo pipefail {0}
+
+      steps:
+          - name: Prepare container (Ubuntu)
+            run: |
+                export APT_LISTCHANGES_FRONTEND=none
+                export DEBIAN_FRONTEND=noninteractive
+                apt-get update -qq
+                apt-get install -y git sudo wget
+                ln -snf /usr/share/zoneinfo/$CONTAINER_TIMEZONE /etc/localtime && echo $CONTAINER_TIMEZONE > /etc/timezone
+            shell: bash
+
+          - name: Setup import user
+            run: |
+                useradd -m nominatim
+                echo 'nominatim   ALL=(ALL:ALL) NOPASSWD: ALL' > /etc/sudoers.d/nominatim
+                echo "/home/nominatim/Nominatim/vagrant/Install-on-${OS}.sh no $INSTALL_MODE" > /home/nominatim/vagrant.sh
+            shell: bash
+            env:
+              OS: ${{ matrix.name }}
+              INSTALL_MODE: ${{ matrix.install_mode }}
+
+          - uses: actions/download-artifact@v3
+            with:
+                name: full-source
+                path: /home/nominatim
+
+          - name: Install Nominatim
+            run: |
+              export USERNAME=nominatim
+              export USERHOME=/home/nominatim
+              export NOSYSTEMD=yes
+              export HAVE_SELINUX=no
+              tar xf nominatim-src.tar.bz2
+              . vagrant.sh
+            working-directory: /home/nominatim
+
+          - name: Prepare import environment
+            run: |
+                mv Nominatim/test/testdb/apidb-test-data.pbf test.pbf
+                mv Nominatim/settings/flex-base.lua flex-base.lua
+                mv Nominatim/settings/import-extratags.lua import-extratags.lua
+                mv Nominatim/settings/taginfo.lua taginfo.lua
+                rm -rf Nominatim
+                mkdir data-env-reverse
+            working-directory: /home/nominatim
+
+          - name: Prepare Database
+            run: |
+                nominatim import --prepare-database
+            working-directory: /home/nominatim/nominatim-project
+
+          - name: Create import user
+            run: |
+                sudo -u postgres createuser -S osm-import
+                sudo -u postgres psql -c "ALTER USER \"osm-import\" WITH PASSWORD 'osm-import';"
+            working-directory: /home/nominatim/nominatim-project
+
+          - name: Grant import user rights
+            run: |
+                sudo -u postgres psql -c "GRANT INSERT, UPDATE ON ALL TABLES IN SCHEMA public TO \"osm-import\";"
+            working-directory: /home/nominatim/nominatim-project
+
+          - name: Run import
+            run: |
+                NOMINATIM_DATABASE_DSN="pgsql:host=127.0.0.1;dbname=nominatim;user=osm-import;password=osm-import" nominatim import --continue import-from-file --osm-file ../test.pbf
+            working-directory: /home/nominatim/nominatim-project
+
+          - name: Check full import
+            run: nominatim admin --check-database
+            working-directory: /home/nominatim/nominatim-project
\ No newline at end of file
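
Taken together, the new job exercises an import split between a privileged account and a restricted import role. Condensed to its shell steps (a sketch; user names and the DSN mirror the job above):

    # 1. Create the empty database as the privileged account:
    nominatim import --prepare-database

    # 2. Set up a restricted role that performs the data import:
    sudo -u postgres createuser -S osm-import
    sudo -u postgres psql -c "ALTER USER \"osm-import\" WITH PASSWORD 'osm-import';"
    sudo -u postgres psql -c "GRANT INSERT, UPDATE ON ALL TABLES IN SCHEMA public TO \"osm-import\";"

    # 3. Continue the import under that role:
    NOMINATIM_DATABASE_DSN="pgsql:host=127.0.0.1;dbname=nominatim;user=osm-import;password=osm-import" \
        nominatim import --continue import-from-file --osm-file test.pbf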
diff --git a/docs/customize/SQLite.md b/docs/customize/SQLite.md
new file mode 100644 (file)
index 0000000..9614fea
--- /dev/null
@@ -0,0 +1,55 @@
+A Nominatim database can be converted into an SQLite database and used as
+a read-only source for geocoding queries. This section describes how to
+create and use an SQLite database.
+
+!!! danger
+    This feature is in an experimental state at the moment. Use at your own
+    risk.
+
+## Installing prerequisites
+
+To use an SQLite database, you need to install:
+
+* SQLite (>= 3.30)
+* Spatialite (> 5.0.0)
+
+On Ubuntu/Debian, you can run:
+
+    sudo apt install sqlite3 libsqlite3-mod-spatialite libspatialite7
+
+## Creating a new SQLite database
+
+Nominatim cannot import directly into an SQLite database. Instead you have to
+first create a geocoding database in PostgreSQL by running a
+[regular Nominatim import](../admin/Import.md).
+
+Once this is done, the database can be converted to SQLite with
+
+    nominatim convert -o mydb.sqlite
+
+This will create a database where all geocoding functions are available.
+Depending on what functions you need, the database can be made smaller:
+
+* `--without-reverse` omits indexes only needed for reverse geocoding
+* `--without-search` omits tables and indexes used for forward search
+* `--without-details` leaves out extra information only available in the
+  details API
+
+## Using an SQLite database
+
+Once you have created the database, you can use it by simply pointing the
+database DSN to the SQLite file:
+
+    NOMINATIM_DATABASE_DSN=sqlite:dbname=mydb.sqlite
+
+Please note that SQLite support is only available for the Python frontend. To
+use the test server with an SQLite database, you therefore need to switch
+the frontend engine:
+
+    nominatim serve --engine falcon
+
+You need to install falcon or starlette for this, depending on which engine
+you choose.
+
+The CLI query commands and the library interface already use the new Python
+frontend and therefore work right out of the box.
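
As a quick end-to-end sketch (assuming a project directory whose `.env` carries the DSN above, queried through the CLI of the Python frontend):

    cd my-project-dir
    echo "NOMINATIM_DATABASE_DSN=sqlite:dbname=mydb.sqlite" > .env
    nominatim search --query "Berlin"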
index 3301356d71577f08f16bec798d3539549c4137e8..f332640ff98f404080df41124fd0bf8b444d5c59 100644 (file)
@@ -40,6 +40,7 @@ nav:
         - 'Special Phrases': 'customize/Special-Phrases.md'
         - 'External data: US housenumbers from TIGER': 'customize/Tiger.md'
         - 'External data: Postcodes': 'customize/Postcodes.md'
+        - 'Conversion to SQLite': 'customize/SQLite.md'
     - 'Library Guide':
         - 'Getting Started': 'library/Getting-Started.md'
         - 'Nominatim API class': 'library/NominatimAPI.md'
index 44e8bc8b8e0b31e8d0ff837d059a119deb5c3af1..6c089d824b99565a72ad9de6202c5fd1f6ae5b2b 100644 (file)
@@ -62,10 +62,6 @@ BEGIN
   WHILE langs[i] IS NOT NULL LOOP
     wiki_article := extratags->(case when langs[i] in ('english','country') THEN 'wikipedia' ELSE 'wikipedia:'||langs[i] END);
     IF wiki_article is not null THEN
-      wiki_article := regexp_replace(wiki_article,E'^(.*?)([a-z]{2,3}).wikipedia.org/wiki/',E'\\2:');
-      wiki_article := regexp_replace(wiki_article,E'^(.*?)([a-z]{2,3}).wikipedia.org/w/index.php\\?title=',E'\\2:');
-      wiki_article := regexp_replace(wiki_article,E'^(.*?)/([a-z]{2,3})/wiki/',E'\\2:');
-      --wiki_article := regexp_replace(wiki_article,E'^(.*?)([a-z]{2,3})[=:]',E'\\2:');
       wiki_article := replace(wiki_article,' ','_');
       IF strpos(wiki_article, ':') IN (3,4) THEN
         wiki_article_language := lower(trim(split_part(wiki_article, ':', 1)));
index 0b18954cedb985ab71430b20762958f7571dd6da..97a0cde38e2b6aa6ef8aebf63af2611894502fbd 100644 (file)
@@ -287,21 +287,19 @@ LANGUAGE plpgsql IMMUTABLE;
 
 
 CREATE OR REPLACE FUNCTION weigh_search(search_vector INT[],
-                                        term_vectors TEXT[],
-                                        weight_vectors FLOAT[],
+                                        rankings TEXT,
                                         def_weight FLOAT)
   RETURNS FLOAT
   AS $$
 DECLARE
-  pos INT := 1;
-  terms TEXT;
+  rank JSON;
 BEGIN
-  FOREACH terms IN ARRAY term_vectors
+  FOR rank IN
+    SELECT * FROM json_array_elements(rankings::JSON)
   LOOP
-    IF search_vector @> terms::INTEGER[] THEN
-      RETURN weight_vectors[pos];
+    IF true = ALL(SELECT x::int = ANY(search_vector) FROM json_array_elements_text(rank->1) as x) THEN
+      RETURN (rank->>0)::float;
     END IF;
-    pos := pos + 1;
   END LOOP;
   RETURN def_weight;
 END;
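
The rankings argument thus changes from two parallel arrays to a single JSON string of `[penalty, [token, ...]]` pairs. An illustration with made-up token ids:

    SELECT weigh_search(ARRAY[1, 2, 3],
                        '[[0.3, [1, 2]], [0.7, [9]]]',
                        1.0);
    -- returns 0.3: the first ranking fires because tokens 1 and 2
    -- are both contained in the search vector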
index 44ac91606fef90a746bb26d06b2a9fc6da0e61e4..1c0c4423fcae0c6a82d07c39c690bd1b14ffbb47 100644 (file)
@@ -19,6 +19,7 @@ import sqlalchemy.ext.asyncio as sa_asyncio
 from nominatim.errors import UsageError
 from nominatim.db.sqlalchemy_schema import SearchTables
 from nominatim.db.async_core_library import PGCORE_LIB, PGCORE_ERROR
+import nominatim.db.sqlite_functions
 from nominatim.config import Configuration
 from nominatim.api.connection import SearchConnection
 from nominatim.api.status import get_status, StatusResult
@@ -84,6 +85,14 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
             extra_args: Dict[str, Any] = {'future': True,
                                           'echo': self.config.get_bool('DEBUG_SQL')}
 
+            if self.config.get_int('API_POOL_SIZE') == 0:
+                extra_args['poolclass'] = sa.pool.NullPool
+            else:
+                extra_args['poolclass'] = sa.pool.QueuePool
+                extra_args['max_overflow'] = 0
+                extra_args['pool_size'] = self.config.get_int('API_POOL_SIZE')
+
+
             is_sqlite = self.config.DATABASE_DSN.startswith('sqlite:')
 
             if is_sqlite:
@@ -92,6 +101,10 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
                 dburl = sa.engine.URL.create('sqlite+aiosqlite',
                                              database=params.get('dbname'))
 
+                if not ('NOMINATIM_DATABASE_RW' in self.config.environ
+                        and self.config.get_bool('DATABASE_RW')) \
+                   and not Path(params.get('dbname', '')).is_file():
+                    raise UsageError(f"SQlite database '{params.get('dbname')}' does not exist.")
             else:
                 dsn = self.config.get_database_params()
                 query = {k: v for k, v in dsn.items()
@@ -105,39 +118,40 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
                            host=dsn.get('host'),
                            port=int(dsn['port']) if 'port' in dsn else None,
                            query=query)
-                extra_args['max_overflow'] = 0
-                extra_args['pool_size'] = self.config.get_int('API_POOL_SIZE')
 
             engine = sa_asyncio.create_async_engine(dburl, **extra_args)
 
-            try:
-                async with engine.begin() as conn:
-                    result = await conn.scalar(sa.text('SHOW server_version_num'))
-                    server_version = int(result)
-            except (PGCORE_ERROR, sa.exc.OperationalError):
+            if is_sqlite:
                 server_version = 0
 
-            if server_version >= 110000 and not is_sqlite:
-                @sa.event.listens_for(engine.sync_engine, "connect")
-                def _on_connect(dbapi_con: Any, _: Any) -> None:
-                    cursor = dbapi_con.cursor()
-                    cursor.execute("SET jit_above_cost TO '-1'")
-                    cursor.execute("SET max_parallel_workers_per_gather TO '0'")
-                # Make sure that all connections get the new settings
-                await self.close()
-
-            if is_sqlite:
                 @sa.event.listens_for(engine.sync_engine, "connect")
                 def _on_sqlite_connect(dbapi_con: Any, _: Any) -> None:
                     dbapi_con.run_async(lambda conn: conn.enable_load_extension(True))
+                    nominatim.db.sqlite_functions.install_custom_functions(dbapi_con)
                     cursor = dbapi_con.cursor()
                     cursor.execute("SELECT load_extension('mod_spatialite')")
                     cursor.execute('SELECT SetDecimalPrecision(7)')
                     dbapi_con.run_async(lambda conn: conn.enable_load_extension(False))
+            else:
+                try:
+                    async with engine.begin() as conn:
+                        result = await conn.scalar(sa.text('SHOW server_version_num'))
+                        server_version = int(result)
+                except (PGCORE_ERROR, sa.exc.OperationalError):
+                    server_version = 0
+
+                if server_version >= 110000:
+                    @sa.event.listens_for(engine.sync_engine, "connect")
+                    def _on_connect(dbapi_con: Any, _: Any) -> None:
+                        cursor = dbapi_con.cursor()
+                        cursor.execute("SET jit_above_cost TO '-1'")
+                        cursor.execute("SET max_parallel_workers_per_gather TO '0'")
+                    # Make sure that all connections get the new settings
+                    await engine.dispose()
 
             self._property_cache['DB:server_version'] = server_version
 
-            self._tables = SearchTables(sa.MetaData(), engine.name) # pylint: disable=no-member
+            self._tables = SearchTables(sa.MetaData()) # pylint: disable=no-member
             self._engine = engine
 
 
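The NOMINATIM_API_POOL_SIZE handling added above reduces to the following standalone sketch (using SQLAlchemy's default async pool instead of naming the pool class explicitly):

    import sqlalchemy as sa
    import sqlalchemy.ext.asyncio as sa_asyncio

    def make_engine(url: str, pool_size: int) -> sa_asyncio.AsyncEngine:
        if pool_size == 0:
            # pool size 0: no pooling, each request opens its own connection
            return sa_asyncio.create_async_engine(url, poolclass=sa.pool.NullPool)
        # otherwise: a bounded pool without overflow connections
        return sa_asyncio.create_async_engine(url, pool_size=pool_size,
                                              max_overflow=0)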
index 37ae7f5f04464241ad0e81062b56d125555cadff..e16e0bd2d3bdbcab64b7f8c074ddbbe72cc4843e 100644 (file)
@@ -90,26 +90,42 @@ class BaseLogger:
         params = dict(compiled.params)
         if isinstance(extra_params, Mapping):
             for k, v in extra_params.items():
-                params[k] = str(v)
+                if hasattr(v, 'to_wkt'):
+                    params[k] = v.to_wkt()
+                elif isinstance(v, (int, float)):
+                    params[k] = v
+                else:
+                    params[k] = str(v)
         elif isinstance(extra_params, Sequence) and extra_params:
             for k in extra_params[0]:
                 params[k] = f':{k}'
 
         sqlstr = str(compiled)
 
-        if sa.__version__.startswith('1'):
-            try:
-                sqlstr = re.sub(r'__\[POSTCOMPILE_[^]]*\]', '%s', sqlstr)
-                return sqlstr % tuple((repr(params.get(name, None))
-                                      for name in compiled.positiontup)) # type: ignore
-            except TypeError:
-                return sqlstr
-
-        # Fixes an odd issue with Python 3.7 where percentages are not
-        # quoted correctly.
-        sqlstr = re.sub(r'%(?!\()', '%%', sqlstr)
-        sqlstr = re.sub(r'__\[POSTCOMPILE_([^]]*)\]', r'%(\1)s', sqlstr)
-        return sqlstr % params
+        if conn.dialect.name == 'postgresql':
+            if sa.__version__.startswith('1'):
+                try:
+                    sqlstr = re.sub(r'__\[POSTCOMPILE_[^]]*\]', '%s', sqlstr)
+                    return sqlstr % tuple((repr(params.get(name, None))
+                                          for name in compiled.positiontup)) # type: ignore
+                except TypeError:
+                    return sqlstr
+
+            # Fixes an odd issue with Python 3.7 where percentages are not
+            # quoted correctly.
+            sqlstr = re.sub(r'%(?!\()', '%%', sqlstr)
+            sqlstr = re.sub(r'__\[POSTCOMPILE_([^]]*)\]', r'%(\1)s', sqlstr)
+            return sqlstr % params
+
+        assert conn.dialect.name == 'sqlite'
+
+        # params in positional order
+        pparams = (repr(params.get(name, None)) for name in compiled.positiontup) # type: ignore
+
+        sqlstr = re.sub(r'__\[POSTCOMPILE_([^]]*)\]', '?', sqlstr)
+        sqlstr = re.sub(r"\?", lambda m: next(pparams), sqlstr)
+
+        return sqlstr
 
 class HTMLLogger(BaseLogger):
     """ Logger that formats messages in HTML.
index fb4c0b23d0f2fd4790d942b31508126f39a2d379..df5c10f2669951d0f0bbcf778815724c23a719ba 100644 (file)
@@ -180,7 +180,7 @@ class ReverseGeocoder:
         diststr = sa.text(f"{distance}")
 
         sql: SaLambdaSelect = sa.lambda_stmt(lambda: _select_from_placex(t)
-                .where(t.c.geometry.ST_DWithin(WKT_PARAM, diststr))
+                .where(t.c.geometry.within_distance(WKT_PARAM, diststr))
                 .where(t.c.indexed_status == 0)
                 .where(t.c.linked_place_id == None)
                 .where(sa.or_(sa.not_(t.c.geometry.is_area()),
@@ -219,7 +219,7 @@ class ReverseGeocoder:
         t = self.conn.t.placex
 
         sql: SaLambdaSelect = sa.lambda_stmt(lambda: _select_from_placex(t)
-                .where(t.c.geometry.ST_DWithin(WKT_PARAM, 0.001))
+                .where(t.c.geometry.within_distance(WKT_PARAM, 0.001))
                 .where(t.c.parent_place_id == parent_place_id)
                 .where(sa.func.IsAddressPoint(t))
                 .where(t.c.indexed_status == 0)
@@ -241,7 +241,7 @@ class ReverseGeocoder:
                    sa.select(t,
                              t.c.linegeo.ST_Distance(WKT_PARAM).label('distance'),
                              _locate_interpolation(t))
-                     .where(t.c.linegeo.ST_DWithin(WKT_PARAM, distance))
+                     .where(t.c.linegeo.within_distance(WKT_PARAM, distance))
                      .where(t.c.startnumber != None)
                      .order_by('distance')
                      .limit(1))
@@ -275,7 +275,7 @@ class ReverseGeocoder:
             inner = sa.select(t,
                               t.c.linegeo.ST_Distance(WKT_PARAM).label('distance'),
                               _locate_interpolation(t))\
-                      .where(t.c.linegeo.ST_DWithin(WKT_PARAM, 0.001))\
+                      .where(t.c.linegeo.within_distance(WKT_PARAM, 0.001))\
                       .where(t.c.parent_place_id == parent_place_id)\
                       .order_by('distance')\
                       .limit(1)\
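
`within_distance()` is a comparator method on the custom `Geometry` type (moved to `nominatim/db/sqlalchemy_types/geometry.py` in this commit but not shown in full here). A stripped-down sketch of the pattern, assuming a PostGIS-style default rendering:

    import sqlalchemy as sa
    from sqlalchemy import types

    class Geometry(types.UserDefinedType):
        """ Minimal sketch of a geometry type with an index-friendly
            distance comparator; the SQLite rendering would be supplied
            by a dialect-specific compilation handler.
        """
        cache_ok = True

        def get_col_spec(self) -> str:
            return 'GEOMETRY'

        class comparator_factory(types.UserDefinedType.Comparator):
            def within_distance(self, other, distance):
                # ST_DWithin can use the spatial index on PostgreSQL
                return sa.func.ST_DWithin(self.expr, other, distance)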
index c755f2a74f8a16e2d53ca30503549040685d0046..fd8cc7af90ffb3aa71581aac842e602d82cc0d39 100644 (file)
@@ -15,6 +15,7 @@ from nominatim.api.search.query import QueryStruct, Token, TokenType, TokenRange
 from nominatim.api.search.token_assignment import TokenAssignment
 import nominatim.api.search.db_search_fields as dbf
 import nominatim.api.search.db_searches as dbs
+import nominatim.api.search.db_search_lookups as lookups
 
 
 def wrap_near_search(categories: List[Tuple[str, str]],
@@ -152,7 +153,7 @@ class SearchBuilder:
                 sdata.lookups = [dbf.FieldLookup('nameaddress_vector',
                                                  [t.token for r in address
                                                   for t in self.query.get_partials_list(r)],
-                                                 'restrict')]
+                                                 lookups.Restrict)]
                 penalty += 0.2
             yield dbs.PostcodeSearch(penalty, sdata)
 
@@ -162,7 +163,7 @@ class SearchBuilder:
         """ Build a simple address search for special entries where the
             housenumber is the main name token.
         """
-        sdata.lookups = [dbf.FieldLookup('name_vector', [t.token for t in hnrs], 'lookup_any')]
+        sdata.lookups = [dbf.FieldLookup('name_vector', [t.token for t in hnrs], lookups.LookupAny)]
         expected_count = sum(t.count for t in hnrs)
 
         partials = [t for trange in address
@@ -170,16 +171,16 @@ class SearchBuilder:
 
         if expected_count < 8000:
             sdata.lookups.append(dbf.FieldLookup('nameaddress_vector',
-                                                 [t.token for t in partials], 'restrict'))
+                                                 [t.token for t in partials], lookups.Restrict))
         elif len(partials) != 1 or partials[0].count < 10000:
             sdata.lookups.append(dbf.FieldLookup('nameaddress_vector',
-                                                 [t.token for t in partials], 'lookup_all'))
+                                                 [t.token for t in partials], lookups.LookupAll))
         else:
             sdata.lookups.append(
                 dbf.FieldLookup('nameaddress_vector',
                                 [t.token for t
                                  in self.query.get_tokens(address[0], TokenType.WORD)],
-                                'lookup_any'))
+                                lookups.LookupAny))
 
         sdata.housenumbers = dbf.WeightedStrings([], [])
         yield dbs.PlaceSearch(0.05, sdata, expected_count)
@@ -232,16 +233,16 @@ class SearchBuilder:
                 penalty += 1.2 * sum(t.penalty for t in addr_partials if not t.is_indexed)
             # Any of the full names applies with all of the partials from the address
             yield penalty, fulls_count / (2**len(addr_partials)),\
-                  dbf.lookup_by_any_name([t.token for t in name_fulls], addr_tokens,
-                                         'restrict' if fulls_count < 10000 else 'lookup_all')
+                  dbf.lookup_by_any_name([t.token for t in name_fulls],
+                                         addr_tokens, fulls_count > 10000)
 
         # To catch remaining results, lookup by name and address
         # We only do this if there is a reasonable number of results expected.
         exp_count = exp_count / (2**len(addr_partials)) if addr_partials else exp_count
         if exp_count < 10000 and all(t.is_indexed for t in name_partials):
-            lookup = [dbf.FieldLookup('name_vector', name_tokens, 'lookup_all')]
+            lookup = [dbf.FieldLookup('name_vector', name_tokens, lookups.LookupAll)]
             if addr_tokens:
-                lookup.append(dbf.FieldLookup('nameaddress_vector', addr_tokens, 'lookup_all'))
+                lookup.append(dbf.FieldLookup('nameaddress_vector', addr_tokens, lookups.LookupAll))
             penalty += 0.35 * max(0, 5 - len(name_partials) - len(addr_tokens))
             yield penalty, exp_count, lookup
 
index 59af826086db86027f2c808dee51824fb17e72ff..6947a565f80dad421dcd9398975284988121a254 100644 (file)
@@ -7,14 +7,16 @@
 """
 Data structures for more complex fields in abstract search descriptions.
 """
-from typing import List, Tuple, Iterator, cast, Dict
+from typing import List, Tuple, Iterator, Dict, Type
 import dataclasses
 
 import sqlalchemy as sa
-from sqlalchemy.dialects.postgresql import ARRAY
 
 from nominatim.typing import SaFromClause, SaColumn, SaExpression
 from nominatim.api.search.query import Token
+import nominatim.api.search.db_search_lookups as lookups
+from nominatim.utils.json_writer import JsonWriter
+
 
 @dataclasses.dataclass
 class WeightedStrings:
@@ -129,11 +131,17 @@ class FieldRanking:
         """
         assert self.rankings
 
-        return sa.func.weigh_search(table.c[self.column],
-                                    [f"{{{','.join((str(s) for s in r.tokens))}}}"
-                                     for r in self.rankings],
-                                    [r.penalty for r in self.rankings],
-                                    self.default)
+        rout = JsonWriter().start_array()
+        for rank in self.rankings:
+            rout.start_array().value(rank.penalty).next()
+            rout.start_array()
+            for token in rank.tokens:
+                rout.value(token).next()
+            rout.end_array()
+            rout.end_array().next()
+        rout.end_array()
+
+        return sa.func.weigh_search(table.c[self.column], rout(), self.default)
 
 
 @dataclasses.dataclass
@@ -146,19 +154,12 @@ class FieldLookup:
     """
     column: str
     tokens: List[int]
-    lookup_type: str
+    lookup_type: Type[lookups.LookupType]
 
     def sql_condition(self, table: SaFromClause) -> SaColumn:
         """ Create an SQL expression for the given match condition.
         """
-        col = table.c[self.column]
-        if self.lookup_type == 'lookup_all':
-            return col.contains(self.tokens)
-        if self.lookup_type == 'lookup_any':
-            return cast(SaColumn, col.overlap(self.tokens))
-
-        return sa.func.array_cat(col, sa.text('ARRAY[]::integer[]'),
-                                 type_=ARRAY(sa.Integer())).contains(self.tokens)
+        return self.lookup_type(table, self.column, self.tokens)
 
 
 class SearchData:
@@ -224,22 +225,23 @@ def lookup_by_names(name_tokens: List[int], addr_tokens: List[int]) -> List[Fiel
     """ Create a lookup list where name tokens are looked up via index
         and potential address tokens are used to restrict the search further.
     """
-    lookup = [FieldLookup('name_vector', name_tokens, 'lookup_all')]
+    lookup = [FieldLookup('name_vector', name_tokens, lookups.LookupAll)]
     if addr_tokens:
-        lookup.append(FieldLookup('nameaddress_vector', addr_tokens, 'restrict'))
+        lookup.append(FieldLookup('nameaddress_vector', addr_tokens, lookups.Restrict))
 
     return lookup
 
 
 def lookup_by_any_name(name_tokens: List[int], addr_tokens: List[int],
-                       lookup_type: str) -> List[FieldLookup]:
+                       use_index_for_addr: bool) -> List[FieldLookup]:
     """ Create a lookup list where name tokens are looked up via index
         and only one of the name tokens must be present.
         Potential address tokens are used to restrict the search further.
     """
-    lookup = [FieldLookup('name_vector', name_tokens, 'lookup_any')]
+    lookup = [FieldLookup('name_vector', name_tokens, lookups.LookupAny)]
     if addr_tokens:
-        lookup.append(FieldLookup('nameaddress_vector', addr_tokens, lookup_type))
+        lookup.append(FieldLookup('nameaddress_vector', addr_tokens,
+                                  lookups.LookupAll if use_index_for_addr else lookups.Restrict))
 
     return lookup
 
@@ -248,5 +250,5 @@ def lookup_by_addr(name_tokens: List[int], addr_tokens: List[int]) -> List[Field
     """ Create a lookup list where address tokens are looked up via index
         and the name tokens are only used to restrict the search further.
     """
-    return [FieldLookup('name_vector', name_tokens, 'restrict'),
-            FieldLookup('nameaddress_vector', addr_tokens, 'lookup_all')]
+    return [FieldLookup('name_vector', name_tokens, lookups.Restrict),
+            FieldLookup('nameaddress_vector', addr_tokens, lookups.LookupAll)]
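
`sql_penalty()` above now serializes its rankings with the incremental `JsonWriter` from `nominatim.utils.json_writer` (also touched in this commit). Run standalone, the same writing pattern yields output along these lines:

    from nominatim.utils.json_writer import JsonWriter

    out = JsonWriter().start_array()
    for penalty, tokens in [(0.2, [101, 102]), (0.4, [205])]:
        out.start_array().value(penalty).next()
        out.start_array()
        for token in tokens:
            out.value(token).next()
        out.end_array()
        out.end_array().next()
    out.end_array()

    print(out())   # roughly: [[0.2,[101,102]],[0.4,[205]]]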
diff --git a/nominatim/api/search/db_search_lookups.py b/nominatim/api/search/db_search_lookups.py
new file mode 100644 (file)
index 0000000..aa5cef5
--- /dev/null
@@ -0,0 +1,114 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Implementation of lookup functions for the search_name table.
+"""
+from typing import List, Any
+
+import sqlalchemy as sa
+from sqlalchemy.ext.compiler import compiles
+
+from nominatim.typing import SaFromClause
+from nominatim.db.sqlalchemy_types import IntArray
+
+# pylint: disable=consider-using-f-string
+
+LookupType = sa.sql.expression.FunctionElement[Any]
+
+class LookupAll(LookupType):
+    """ Find all entries in search_name table that contain all of
+        a given list of tokens using an index for the search.
+    """
+    inherit_cache = True
+
+    def __init__(self, table: SaFromClause, column: str, tokens: List[int]) -> None:
+        super().__init__(table.c.place_id, getattr(table.c, column), column,
+                         sa.type_coerce(tokens, IntArray))
+
+
+@compiles(LookupAll) # type: ignore[no-untyped-call, misc]
+def _default_lookup_all(element: LookupAll,
+                        compiler: 'sa.Compiled', **kw: Any) -> str:
+    _, col, _, tokens = list(element.clauses)
+    return "(%s @> %s)" % (compiler.process(col, **kw),
+                           compiler.process(tokens, **kw))
+
+
+@compiles(LookupAll, 'sqlite') # type: ignore[no-untyped-call, misc]
+def _sqlite_lookup_all(element: LookupAll,
+                        compiler: 'sa.Compiled', **kw: Any) -> str:
+    place, col, colname, tokens = list(element.clauses)
+    return "(%s IN (SELECT CAST(value as bigint) FROM"\
+           " (SELECT array_intersect_fuzzy(places) as p FROM"\
+           "   (SELECT places FROM reverse_search_name"\
+           "   WHERE word IN (SELECT value FROM json_each('[' || %s || ']'))"\
+           "     AND column = %s"\
+           "   ORDER BY length(places)) as x) as u,"\
+           " json_each('[' || u.p || ']'))"\
+           " AND array_contains(%s, %s))"\
+             % (compiler.process(place, **kw),
+                compiler.process(tokens, **kw),
+                compiler.process(colname, **kw),
+                compiler.process(col, **kw),
+                compiler.process(tokens, **kw)
+                )
+
+
+
+class LookupAny(LookupType):
+    """ Find all entries that contain at least one of the given tokens.
+        Use an index for the search.
+    """
+    inherit_cache = True
+
+    def __init__(self, table: SaFromClause, column: str, tokens: List[int]) -> None:
+        super().__init__(table.c.place_id, getattr(table.c, column), column,
+                         sa.type_coerce(tokens, IntArray))
+
+@compiles(LookupAny) # type: ignore[no-untyped-call, misc]
+def _default_lookup_any(element: LookupAny,
+                        compiler: 'sa.Compiled', **kw: Any) -> str:
+    _, col, _, tokens = list(element.clauses)
+    return "(%s && %s)" % (compiler.process(col, **kw),
+                           compiler.process(tokens, **kw))
+
+@compiles(LookupAny, 'sqlite') # type: ignore[no-untyped-call, misc]
+def _sqlite_lookup_any(element: LookupAny,
+                        compiler: 'sa.Compiled', **kw: Any) -> str:
+    place, _, colname, tokens = list(element.clauses)
+    return "%s IN (SELECT CAST(value as bigint) FROM"\
+           " (SELECT array_union(places) as p FROM reverse_search_name"\
+           "   WHERE word IN (SELECT value FROM json_each('[' || %s || ']'))"\
+           "     AND column = %s) as u,"\
+           " json_each('[' || u.p || ']'))" % (compiler.process(place, **kw),
+                                               compiler.process(tokens, **kw),
+                                               compiler.process(colname, **kw))
+
+
+
+class Restrict(LookupType):
+    """ Find all entries that contain all of the given tokens.
+        Do not use an index for the search.
+    """
+    inherit_cache = True
+
+    def __init__(self, table: SaFromClause, column: str, tokens: List[int]) -> None:
+        super().__init__(getattr(table.c, column),
+                         sa.type_coerce(tokens, IntArray))
+
+
+@compiles(Restrict) # type: ignore[no-untyped-call, misc]
+def _default_restrict(element: Restrict,
+                        compiler: 'sa.Compiled', **kw: Any) -> str:
+    arg1, arg2 = list(element.clauses)
+    return "(coalesce(null, %s) @> %s)" % (compiler.process(arg1, **kw),
+                                           compiler.process(arg2, **kw))
+
+@compiles(Restrict, 'sqlite') # type: ignore[no-untyped-call, misc]
+def _sqlite_restrict(element: Restrict,
+                        compiler: 'sa.Compiled', **kw: Any) -> str:
+    return "array_contains(%s)" % compiler.process(element.clauses, **kw)
index 232f816ef89609f050ea15e79f3651410222ef86..ee98100c637fbe3365b99c85951e07bf92c4055f 100644 (file)
@@ -11,7 +11,6 @@ from typing import List, Tuple, AsyncIterator, Dict, Any, Callable
 import abc
 
 import sqlalchemy as sa
-from sqlalchemy.dialects.postgresql import ARRAY, array_agg
 
 from nominatim.typing import SaFromClause, SaScalarSelect, SaColumn, \
                              SaExpression, SaSelect, SaLambdaSelect, SaRow, SaBind
@@ -19,7 +18,7 @@ from nominatim.api.connection import SearchConnection
 from nominatim.api.types import SearchDetails, DataLayer, GeometryFormat, Bbox
 import nominatim.api.results as nres
 from nominatim.api.search.db_search_fields import SearchData, WeightedCategories
-from nominatim.db.sqlalchemy_types import Geometry
+from nominatim.db.sqlalchemy_types import Geometry, IntArray
 
 #pylint: disable=singleton-comparison,not-callable
 #pylint: disable=too-many-branches,too-many-arguments,too-many-locals,too-many-statements
@@ -55,12 +54,29 @@ NEAR_PARAM: SaBind = sa.bindparam('near', type_=Geometry)
 NEAR_RADIUS_PARAM: SaBind = sa.bindparam('near_radius')
 COUNTRIES_PARAM: SaBind = sa.bindparam('countries')
 
-def _within_near(t: SaFromClause) -> Callable[[], SaExpression]:
-    return lambda: t.c.geometry.ST_DWithin(NEAR_PARAM, NEAR_RADIUS_PARAM)
+
+def filter_by_area(sql: SaSelect, t: SaFromClause,
+                   details: SearchDetails, avoid_index: bool = False) -> SaSelect:
+    """ Apply SQL statements for filtering by viewbox and near point,
+        if applicable.
+    """
+    if details.near is not None and details.near_radius is not None:
+        if details.near_radius < 0.1 and not avoid_index:
+            sql = sql.where(t.c.geometry.within_distance(NEAR_PARAM, NEAR_RADIUS_PARAM))
+        else:
+            sql = sql.where(t.c.geometry.ST_Distance(NEAR_PARAM) <= NEAR_RADIUS_PARAM)
+    if details.viewbox is not None and details.bounded_viewbox:
+        sql = sql.where(t.c.geometry.intersects(VIEWBOX_PARAM,
+                                                use_index=not avoid_index and
+                                                          details.viewbox.area < 0.2))
+
+    return sql
+
 
 def _exclude_places(t: SaFromClause) -> Callable[[], SaExpression]:
     return lambda: t.c.place_id.not_in(sa.bindparam('excluded'))
 
+
 def _select_placex(t: SaFromClause) -> SaSelect:
     return sa.select(t.c.place_id, t.c.osm_type, t.c.osm_id, t.c.name,
                      t.c.class_, t.c.type,
@@ -93,7 +109,7 @@ def _add_geometry_columns(sql: SaLambdaSelect, col: SaColumn, details: SearchDet
 
 def _make_interpolation_subquery(table: SaFromClause, inner: SaFromClause,
                                  numerals: List[int], details: SearchDetails) -> SaScalarSelect:
-    all_ids = array_agg(table.c.place_id) # type: ignore[no-untyped-call]
+    all_ids = sa.func.ArrayAgg(table.c.place_id)
     sql = sa.select(all_ids).where(table.c.parent_place_id == inner.c.place_id)
 
     if len(numerals) == 1:
@@ -117,9 +133,7 @@ def _filter_by_layer(table: SaFromClause, layers: DataLayer) -> SaColumn:
         orexpr.append(no_index(table.c.rank_address).between(1, 30))
     elif layers & DataLayer.ADDRESS:
         orexpr.append(no_index(table.c.rank_address).between(1, 29))
-        orexpr.append(sa.and_(no_index(table.c.rank_address) == 30,
-                              sa.or_(table.c.housenumber != None,
-                                     table.c.address.has_key('addr:housename'))))
+        orexpr.append(sa.func.IsAddressPoint(table))
     elif layers & DataLayer.POI:
         orexpr.append(sa.and_(no_index(table.c.rank_address) == 30,
                               table.c.class_.not_in(('place', 'building'))))
@@ -171,12 +185,21 @@ async def _get_placex_housenumbers(conn: SearchConnection,
         yield result
 
 
+def _int_list_to_subquery(inp: List[int]) -> 'sa.Subquery':
+    """ Create a subselect that returns the given list of integers
+        as rows in the column 'nr'.
+    """
+    vtab = sa.func.JsonArrayEach(sa.type_coerce(inp, sa.JSON))\
+               .table_valued(sa.column('value', type_=sa.JSON)) # type: ignore[no-untyped-call]
+    return sa.select(sa.cast(sa.cast(vtab.c.value, sa.Text), sa.Integer).label('nr')).subquery()
+
+
 async def _get_osmline(conn: SearchConnection, place_ids: List[int],
                        numerals: List[int],
                        details: SearchDetails) -> AsyncIterator[nres.SearchResult]:
     t = conn.t.osmline
-    values = sa.values(sa.Column('nr', sa.Integer()), name='housenumber')\
-               .data([(n,) for n in numerals])
+
+    values = _int_list_to_subquery(numerals)
     sql = sa.select(t.c.place_id, t.c.osm_id,
                     t.c.parent_place_id, t.c.address,
                     values.c.nr.label('housenumber'),
@@ -199,8 +222,7 @@ async def _get_tiger(conn: SearchConnection, place_ids: List[int],
                      numerals: List[int], osm_id: int,
                      details: SearchDetails) -> AsyncIterator[nres.SearchResult]:
     t = conn.t.tiger
-    values = sa.values(sa.Column('nr', sa.Integer()), name='housenumber')\
-               .data([(n,) for n in numerals])
+    values = _int_list_to_subquery(numerals)
     sql = sa.select(t.c.place_id, t.c.parent_place_id,
                     sa.literal('W').label('osm_type'),
                     sa.literal(osm_id).label('osm_id'),
@@ -295,7 +317,7 @@ class NearSearch(AbstractSearch):
 
         if table is None:
             # No classtype table available, do a simplified lookup in placex.
-            table = conn.t.placex.alias('inner')
+            table = conn.t.placex
             sql = sa.select(table.c.place_id,
                             sa.func.min(tgeom.c.centroid.ST_Distance(table.c.centroid))
                               .label('dist'))\
@@ -366,7 +388,7 @@ class PoiSearch(AbstractSearch):
                            .add_columns((-t.c.centroid.ST_Distance(NEAR_PARAM))
                                          .label('importance'))\
                            .where(t.c.linked_place_id == None) \
-                           .where(t.c.geometry.ST_DWithin(NEAR_PARAM, NEAR_RADIUS_PARAM)) \
+                           .where(t.c.geometry.within_distance(NEAR_PARAM, NEAR_RADIUS_PARAM)) \
                            .order_by(t.c.centroid.ST_Distance(NEAR_PARAM)) \
                            .limit(LIMIT_PARAM)
 
@@ -403,8 +425,8 @@ class PoiSearch(AbstractSearch):
 
                     if details.near and details.near_radius is not None:
                         sql = sql.order_by(table.c.centroid.ST_Distance(NEAR_PARAM))\
-                                 .where(table.c.centroid.ST_DWithin(NEAR_PARAM,
-                                                                    NEAR_RADIUS_PARAM))
+                                 .where(table.c.centroid.within_distance(NEAR_PARAM,
+                                                                         NEAR_RADIUS_PARAM))
 
                     if self.countries:
                         sql = sql.where(t.c.country_code.in_(self.countries.values))
@@ -449,11 +471,7 @@ class CountrySearch(AbstractSearch):
         if details.excluded:
             sql = sql.where(_exclude_places(t))
 
-        if details.viewbox is not None and details.bounded_viewbox:
-            sql = sql.where(lambda: t.c.geometry.intersects(VIEWBOX_PARAM))
-
-        if details.near is not None and details.near_radius is not None:
-            sql = sql.where(_within_near(t))
+        sql = filter_by_area(sql, t, details)
 
         results = nres.SearchResults()
         for row in await conn.execute(sql, _details_to_bind_params(details)):
@@ -486,18 +504,12 @@ class CountrySearch(AbstractSearch):
                 .where(tgrid.c.country_code.in_(self.countries.values))\
                 .group_by(tgrid.c.country_code)
 
-        if details.viewbox is not None and details.bounded_viewbox:
-            sql = sql.where(tgrid.c.geometry.intersects(VIEWBOX_PARAM))
-        if details.near is not None and details.near_radius is not None:
-            sql = sql.where(_within_near(tgrid))
+        sql = filter_by_area(sql, tgrid, details, avoid_index=True)
 
         sub = sql.subquery('grid')
 
         sql = sa.select(t.c.country_code,
-                        (t.c.name
-                         + sa.func.coalesce(t.c.derived_name,
-                                            sa.cast('', type_=conn.t.types.Composite))
-                        ).label('name'),
+                        t.c.name.merge(t.c.derived_name).label('name'),
                         sub.c.centroid, sub.c.bbox)\
                 .join(sub, t.c.country_code == sub.c.country_code)
 
@@ -545,19 +557,16 @@ class PostcodeSearch(AbstractSearch):
 
         penalty: SaExpression = sa.literal(self.penalty)
 
-        if details.viewbox is not None:
-            if details.bounded_viewbox:
-                sql = sql.where(t.c.geometry.intersects(VIEWBOX_PARAM))
-            else:
-                penalty += sa.case((t.c.geometry.intersects(VIEWBOX_PARAM), 0.0),
-                                   (t.c.geometry.intersects(VIEWBOX2_PARAM), 0.5),
-                                   else_=1.0)
+        if details.viewbox is not None and not details.bounded_viewbox:
+            penalty += sa.case((t.c.geometry.intersects(VIEWBOX_PARAM), 0.0),
+                               (t.c.geometry.intersects(VIEWBOX2_PARAM), 0.5),
+                               else_=1.0)
 
         if details.near is not None:
-            if details.near_radius is not None:
-                sql = sql.where(_within_near(t))
             sql = sql.order_by(t.c.geometry.ST_Distance(NEAR_PARAM))
 
+        sql = filter_by_area(sql, t, details)
+
         if self.countries:
             sql = sql.where(t.c.country_code.in_(self.countries.values))
 
@@ -566,13 +575,11 @@ class PostcodeSearch(AbstractSearch):
 
         if self.lookups:
             assert len(self.lookups) == 1
-            assert self.lookups[0].lookup_type == 'restrict'
             tsearch = conn.t.search_name
             sql = sql.where(tsearch.c.place_id == t.c.parent_place_id)\
-                     .where(sa.func.array_cat(tsearch.c.name_vector,
-                                              tsearch.c.nameaddress_vector,
-                                              type_=ARRAY(sa.Integer))
-                                    .contains(self.lookups[0].tokens))
+                     .where((tsearch.c.name_vector + tsearch.c.nameaddress_vector)
+                                     .contains(sa.type_coerce(self.lookups[0].tokens,
+                                                              IntArray)))
 
         for ranking in self.rankings:
             penalty += ranking.sql_penalty(conn.t.search_name)
@@ -637,11 +644,11 @@ class PlaceSearch(AbstractSearch):
             sql = sql.where(tsearch.c.address_rank > 9)
             tpc = conn.t.postcode
             pcs = self.postcodes.values
-            if self.expected_count > 1000:
+            if self.expected_count > 5000:
                 # Many results expected. Restrict by postcode.
                 sql = sql.where(sa.select(tpc.c.postcode)
                                   .where(tpc.c.postcode.in_(pcs))
-                                  .where(tsearch.c.centroid.ST_DWithin(tpc.c.geometry, 0.12))
+                                  .where(tsearch.c.centroid.within_distance(tpc.c.geometry, 0.12))
                                   .exists())
 
             # Less results, only have a preference for close postcodes
@@ -653,27 +660,26 @@ class PlaceSearch(AbstractSearch):
 
         if details.viewbox is not None:
             if details.bounded_viewbox:
-                if details.viewbox.area < 0.2:
-                    sql = sql.where(tsearch.c.centroid.intersects(VIEWBOX_PARAM))
-                else:
-                    sql = sql.where(tsearch.c.centroid.ST_Intersects_no_index(VIEWBOX_PARAM))
+                sql = sql.where(tsearch.c.centroid
+                                         .intersects(VIEWBOX_PARAM,
+                                                     use_index=details.viewbox.area < 0.2))
             elif self.expected_count >= 10000:
-                if details.viewbox.area < 0.5:
-                    sql = sql.where(tsearch.c.centroid.intersects(VIEWBOX2_PARAM))
-                else:
-                    sql = sql.where(tsearch.c.centroid.ST_Intersects_no_index(VIEWBOX2_PARAM))
+                sql = sql.where(tsearch.c.centroid
+                                         .intersects(VIEWBOX2_PARAM,
+                                                     use_index=details.viewbox.area < 0.5))
             else:
-                penalty += sa.case((t.c.geometry.intersects(VIEWBOX_PARAM), 0.0),
-                                   (t.c.geometry.intersects(VIEWBOX2_PARAM), 0.5),
+                penalty += sa.case((t.c.geometry.intersects(VIEWBOX_PARAM, use_index=False), 0.0),
+                                   (t.c.geometry.intersects(VIEWBOX2_PARAM, use_index=False), 0.5),
                                    else_=1.0)
 
         if details.near is not None:
             if details.near_radius is not None:
                 if details.near_radius < 0.1:
-                    sql = sql.where(tsearch.c.centroid.ST_DWithin(NEAR_PARAM, NEAR_RADIUS_PARAM))
+                    sql = sql.where(tsearch.c.centroid.within_distance(NEAR_PARAM,
+                                                                       NEAR_RADIUS_PARAM))
                 else:
-                    sql = sql.where(tsearch.c.centroid.ST_DWithin_no_index(NEAR_PARAM,
-                                                                           NEAR_RADIUS_PARAM))
+                    sql = sql.where(tsearch.c.centroid
+                                             .ST_Distance(NEAR_PARAM) < NEAR_RADIUS_PARAM)
             sql = sql.add_columns((-tsearch.c.centroid.ST_Distance(NEAR_PARAM))
                                       .label('importance'))
             sql = sql.order_by(sa.desc(sa.text('importance')))
@@ -692,10 +698,10 @@ class PlaceSearch(AbstractSearch):
             sql = sql.order_by(sa.text('accuracy'))
 
         if self.housenumbers:
-            hnr_regexp = f"\\m({'|'.join(self.housenumbers.values)})\\M"
+            hnr_list = '|'.join(self.housenumbers.values)
             sql = sql.where(tsearch.c.address_rank.between(16, 30))\
                      .where(sa.or_(tsearch.c.address_rank < 30,
-                                   t.c.housenumber.op('~*')(hnr_regexp)))
+                                   sa.func.RegexpWord(hnr_list, t.c.housenumber)))
 
             # Cross check for housenumbers, need to do that on a rather large
             # set. Worst case there are 40.000 main streets in OSM.
@@ -703,10 +709,10 @@ class PlaceSearch(AbstractSearch):
 
             # Housenumbers from placex
             thnr = conn.t.placex.alias('hnr')
-            pid_list = array_agg(thnr.c.place_id) # type: ignore[no-untyped-call]
+            pid_list = sa.func.ArrayAgg(thnr.c.place_id)
             place_sql = sa.select(pid_list)\
                           .where(thnr.c.parent_place_id == inner.c.place_id)\
-                          .where(thnr.c.housenumber.op('~*')(hnr_regexp))\
+                          .where(sa.func.RegexpWord(hnr_list, thnr.c.housenumber))\
                           .where(thnr.c.linked_place_id == None)\
                           .where(thnr.c.indexed_status == 0)
 
index 06a06f34984e32c033837bf68c362cf34ea272e6..ff1c3feed40069328d4fdc01aec77745356a70bf 100644 (file)
@@ -22,6 +22,7 @@ from nominatim.api.connection import SearchConnection
 from nominatim.api.logging import log
 from nominatim.api.search import query as qmod
 from nominatim.api.search.query_analyzer_factory import AbstractQueryAnalyzer
+from nominatim.db.sqlalchemy_types import Json
 
 
 DB_TO_TOKEN_TYPE = {
@@ -159,7 +160,7 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                      sa.Column('word_token', sa.Text, nullable=False),
                      sa.Column('type', sa.Text, nullable=False),
                      sa.Column('word', sa.Text),
-                     sa.Column('info', self.conn.t.types.Json))
+                     sa.Column('info', Json))
 
 
     async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
index eb3a3b6145f3c76dc569ebbcdbd6cbbc4fa48b37..433435bc6f37dd1735078f807e599a42d05fdaa7 100644 (file)
@@ -87,6 +87,7 @@ class NominatimArgs:
     offline: bool
     ignore_errors: bool
     index_noanalyse: bool
+    prepare_database: bool
 
     # Arguments to 'index'
     boundaries_only: bool
index 26b3fb1ffedc4dbf470a5e35b0fbf42a8e20c608..7ba77172bdf7853031fb29b8154ccae9870cf827 100644 (file)
@@ -76,7 +76,7 @@ class ConvertDB:
         group.add_argument('--reverse', action=WithAction, dest_set=self.options, default=True,
                            help='Enable/disable support for reverse and lookup API'
                                 ' (default: enabled)')
-        group.add_argument('--search', action=WithAction, dest_set=self.options, default=False,
-                           help='Enable/disable support for search API (default: disabled)')
+        group.add_argument('--search', action=WithAction, dest_set=self.options, default=True,
+                           help='Enable/disable support for search API (default: enabled)')
         group.add_argument('--details', action=WithAction, dest_set=self.options, default=True,
                            help='Enable/disable support for details API (default: enabled)')
index ea605ea09e5d970996f676a51a6ebd6f44cd989b..5e1b044e734336bf305f1c69c649bf18fda0fddc 100644 (file)
@@ -128,7 +128,7 @@ class UpdateRefresh:
             LOG.warning('Import secondary importance raster data from %s', args.project_dir)
             if refresh.import_secondary_importance(args.config.get_libpq_dsn(),
                                                 args.project_dir) > 0:
-                LOG.fatal('FATAL: Cannot update sendary importance raster data')
+                LOG.fatal('FATAL: Cannot update secondary importance raster data')
                 return 1
 
         if args.functions:
@@ -141,10 +141,10 @@ class UpdateRefresh:
         if args.wiki_data:
             data_path = Path(args.config.WIKIPEDIA_DATA_PATH
                              or args.project_dir)
-            LOG.warning('Import wikipdia article importance from %s', data_path)
+            LOG.warning('Import wikipedia article importance from %s', data_path)
             if refresh.import_wikipedia_articles(args.config.get_libpq_dsn(),
                                                  data_path) > 0:
-                LOG.fatal('FATAL: Wikipedia importance dump file not found')
+                LOG.fatal('FATAL: Wikipedia importance file not found in %s', data_path)
                 return 1
 
         # Attention: importance MUST come after wiki data import.
index 8464e151f4f1534034c442446a44bc7c27bce22c..3d212ff980994d7b6511bc234215e1d3201cdd5e 100644 (file)
@@ -40,13 +40,15 @@ class SetupAll:
 
     def add_args(self, parser: argparse.ArgumentParser) -> None:
         group_name = parser.add_argument_group('Required arguments')
-        group1 = group_name.add_mutually_exclusive_group(required=True)
+        group1 = group_name.add_argument_group()
         group1.add_argument('--osm-file', metavar='FILE', action='append',
                            help='OSM file to be imported'
-                                ' (repeat for importing multiple files)')
+                                ' (repeat for importing multiple files)',
+                                default=None)
         group1.add_argument('--continue', dest='continue_at',
-                           choices=['load-data', 'indexing', 'db-postprocess'],
-                           help='Continue an import that was interrupted')
+                           choices=['import-from-file', 'load-data', 'indexing', 'db-postprocess'],
+                           help='Continue an import that was interrupted',
+                           default=None)
         group2 = parser.add_argument_group('Optional arguments')
         group2.add_argument('--osm2pgsql-cache', metavar='SIZE', type=int,
                            help='Size of cache to be used by osm2pgsql (in MB)')
@@ -65,9 +67,11 @@ class SetupAll:
                            help='Continue import even when errors in SQL are present')
         group3.add_argument('--index-noanalyse', action='store_true',
                            help='Do not perform analyse operations during index (expert only)')
+        group3.add_argument('--prepare-database', action='store_true',
+                            help='Create the database but do not import any data')
 
 
-    def run(self, args: NominatimArgs) -> int: # pylint: disable=too-many-statements
+    def run(self, args: NominatimArgs) -> int: # pylint: disable=too-many-statements, too-many-branches
         from ..data import country_info
         from ..tools import database_import, refresh, postcodes, freeze
         from ..indexer.indexer import Indexer
@@ -76,43 +80,61 @@ class SetupAll:
 
         country_info.setup_country_config(args.config)
 
-        if args.continue_at is None:
+        if args.osm_file is None and args.continue_at is None and not args.prepare_database:
+            raise UsageError("No input files (use --osm-file).")
+
+        if args.osm_file is not None and args.continue_at not in ('import-from-file', None):
+            raise UsageError(f"Cannot use --continue {args.continue_at} and --osm-file together.")
+
+        if args.continue_at is not None and args.prepare_database:
+            raise UsageError(
+                "Cannot use --continue and --prepare-database together."
+            )
+
+
+        if args.prepare_database or args.continue_at is None:
+            LOG.warning('Creating database')
+            database_import.setup_database_skeleton(args.config.get_libpq_dsn(),
+                                                        rouser=args.config.DATABASE_WEBUSER)
+            if args.prepare_database:
+                return 0
+
+        if args.continue_at in (None, 'import-from-file'):
             files = args.get_osm_file_list()
             if not files:
                 raise UsageError("No input files (use --osm-file).")
 
-            LOG.warning('Creating database')
-            database_import.setup_database_skeleton(args.config.get_libpq_dsn(),
-                                                    rouser=args.config.DATABASE_WEBUSER)
-
-            LOG.warning('Setting up country tables')
-            country_info.setup_country_tables(args.config.get_libpq_dsn(),
-                                              args.config.lib_dir.data,
-                                              args.no_partitions)
-
-            LOG.warning('Importing OSM data file')
-            database_import.import_osm_data(files,
-                                            args.osm2pgsql_options(0, 1),
-                                            drop=args.no_updates,
-                                            ignore_errors=args.ignore_errors)
-
-            LOG.warning('Importing wikipedia importance data')
-            data_path = Path(args.config.WIKIPEDIA_DATA_PATH or args.project_dir)
-            if refresh.import_wikipedia_articles(args.config.get_libpq_dsn(),
-                                                 data_path) > 0:
-                LOG.error('Wikipedia importance dump file not found. '
-                          'Calculating importance values of locations will not '
-                          'use Wikipedia importance data.')
-
-            LOG.warning('Importing secondary importance raster data')
-            if refresh.import_secondary_importance(args.config.get_libpq_dsn(),
-                                                   args.project_dir) != 0:
-                LOG.error('Secondary importance file not imported. '
-                          'Falling back to default ranking.')
-
-            self._setup_tables(args.config, args.reverse_only)
-
-        if args.continue_at is None or args.continue_at == 'load-data':
+            if args.continue_at in ('import-from-file', None):
+                # Check if the correct plugins are installed
+                database_import.check_existing_database_plugins(args.config.get_libpq_dsn())
+                LOG.warning('Setting up country tables')
+                country_info.setup_country_tables(args.config.get_libpq_dsn(),
+                                                  args.config.lib_dir.data,
+                                                  args.no_partitions)
+
+                LOG.warning('Importing OSM data file')
+                database_import.import_osm_data(files,
+                                                args.osm2pgsql_options(0, 1),
+                                                drop=args.no_updates,
+                                                ignore_errors=args.ignore_errors)
+
+                LOG.warning('Importing wikipedia importance data')
+                data_path = Path(args.config.WIKIPEDIA_DATA_PATH or args.project_dir)
+                if refresh.import_wikipedia_articles(args.config.get_libpq_dsn(),
+                                                     data_path) > 0:
+                    LOG.error('Wikipedia importance dump file not found. '
+                              'Calculating importance values of locations will not '
+                              'use Wikipedia importance data.')
+
+                LOG.warning('Importing secondary importance raster data')
+                if refresh.import_secondary_importance(args.config.get_libpq_dsn(),
+                                                       args.project_dir) != 0:
+                    LOG.error('Secondary importance file not imported. '
+                              'Falling back to default ranking.')
+
+                self._setup_tables(args.config, args.reverse_only)
+
+        if args.continue_at in ('import-from-file', 'load-data', None):
             LOG.warning('Initialise tables')
             with connect(args.config.get_libpq_dsn()) as conn:
                 database_import.truncate_data_tables(conn)
@@ -123,12 +145,13 @@ class SetupAll:
         LOG.warning("Setting up tokenizer")
         tokenizer = self._get_tokenizer(args.continue_at, args.config)
 
-        if args.continue_at is None or args.continue_at == 'load-data':
+        if args.continue_at in ('import-from-file', 'load-data', None):
             LOG.warning('Calculate postcodes')
             postcodes.update_postcodes(args.config.get_libpq_dsn(),
                                        args.project_dir, tokenizer)
 
-        if args.continue_at is None or args.continue_at in ('load-data', 'indexing'):
+        if args.continue_at in ('import-from-file', 'load-data', 'indexing', None):
             LOG.warning('Indexing places')
             indexer = Indexer(args.config.get_libpq_dsn(), tokenizer, num_threads)
             indexer.index_full(analyse=not args.index_noanalyse)
@@ -185,7 +208,7 @@ class SetupAll:
         """
         from ..tokenizer import factory as tokenizer_factory
 
-        if continue_at is None or continue_at == 'load-data':
+        if continue_at in ('import-from-file', 'load-data', None):
             # (re)initialise the tokenizer data
             return tokenizer_factory.create_tokenizer(config)
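The changes above split the import into an optional two-step flow: --prepare-database
only creates the database skeleton (so a superuser can prepare the database and a
regular user can run the data import), and --continue import-from-file resumes from
there. A minimal sketch of the stage gating implemented by the membership checks in
run(); the helper and stage list below are illustrative, not part of the commit:

    # Hypothetical helper mirroring the `args.continue_at in (...)` checks.
    STAGES = ['import-from-file', 'load-data', 'indexing', 'db-postprocess']

    def stage_applies(continue_at, stage):
        """ A step runs when the import continues at that stage or at an
            earlier one; continue_at=None means a full import from scratch.
        """
        start = 0 if continue_at is None else STAGES.index(continue_at)
        return STAGES.index(stage) >= start

    assert stage_applies('load-data', 'indexing')
    assert not stage_applies('indexing', 'load-data')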
 
index fce897bc7250c814d6679188a5c6281a8ac3790c..82801ae7995c9d1e5527baec0d9dd89c85e70e4d 100644 (file)
@@ -174,6 +174,15 @@ class Connection(psycopg2.extensions.connection):
 
         return (int(version_parts[0]), int(version_parts[1]))
 
+
+    def extension_loaded(self, extension_name: str) -> bool:
+        """ Return True if the hstore extension is loaded in the database.
+        """
+        with self.cursor() as cur:
+            cur.execute('SELECT extname FROM pg_extension WHERE extname = %s', (extension_name, ))
+            return cur.rowcount > 0
+
+
 class ConnectionContext(ContextManager[Connection]):
     """ Context manager of the connection that also provides direct access
         to the underlying connection.
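A hedged usage sketch for the new extension_loaded() method, assuming that connect()
from nominatim.db.connection yields this Connection class (as it is used elsewhere
in this commit; the DSN is illustrative):

    from nominatim.db.connection import connect

    with connect('dbname=nominatim') as conn:
        for ext in ('postgis', 'hstore'):
            if not conn.extension_loaded(ext):
                print(f'extension {ext} is not loaded')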
index cb04f7626f08b97f2ee602900849e132f65f6272..e2437dd2e34c4ad4b5080558f8b4dee28ceb4cb1 100644 (file)
@@ -29,7 +29,7 @@ class PlacexGeometryReverseLookuppolygon(sa.sql.functions.GenericFunction[Any]):
 
 
 @compiles(PlacexGeometryReverseLookuppolygon) # type: ignore[no-untyped-call, misc]
-def _default_intersects(element: SaColumn,
+def _default_intersects(element: PlacexGeometryReverseLookuppolygon,
                         compiler: 'sa.Compiled', **kw: Any) -> str:
     return ("(ST_GeometryType(placex.geometry) in ('ST_Polygon', 'ST_MultiPolygon')"
             " AND placex.rank_address between 4 and 25"
@@ -40,7 +40,7 @@ def _default_intersects(element: SaColumn,
 
 
 @compiles(PlacexGeometryReverseLookuppolygon, 'sqlite') # type: ignore[no-untyped-call, misc]
-def _sqlite_intersects(element: SaColumn,
+def _sqlite_intersects(element: PlacexGeometryReverseLookuppolygon,
                        compiler: 'sa.Compiled', **kw: Any) -> str:
     return ("(ST_GeometryType(placex.geometry) in ('POLYGON', 'MULTIPOLYGON')"
             " AND placex.rank_address between 4 and 25"
@@ -61,7 +61,7 @@ class IntersectsReverseDistance(sa.sql.functions.GenericFunction[Any]):
 
 
 @compiles(IntersectsReverseDistance) # type: ignore[no-untyped-call, misc]
-def default_reverse_place_diameter(element: SaColumn,
+def default_reverse_place_diameter(element: IntersectsReverseDistance,
                                    compiler: 'sa.Compiled', **kw: Any) -> str:
     table = element.tablename
     return f"({table}.rank_address between 4 and 25"\
@@ -74,7 +74,7 @@ def default_reverse_place_diameter(element: SaColumn,
 
 
 @compiles(IntersectsReverseDistance, 'sqlite') # type: ignore[no-untyped-call, misc]
-def sqlite_reverse_place_diameter(element: SaColumn,
+def sqlite_reverse_place_diameter(element: IntersectsReverseDistance,
                                   compiler: 'sa.Compiled', **kw: Any) -> str:
     geom1, rank, geom2 = list(element.clauses)
     table = element.tablename
@@ -102,7 +102,7 @@ class IsBelowReverseDistance(sa.sql.functions.GenericFunction[Any]):
 
 
 @compiles(IsBelowReverseDistance) # type: ignore[no-untyped-call, misc]
-def default_is_below_reverse_distance(element: SaColumn,
+def default_is_below_reverse_distance(element: IsBelowReverseDistance,
                                       compiler: 'sa.Compiled', **kw: Any) -> str:
     dist, rank = list(element.clauses)
     return "%s < reverse_place_diameter(%s)" % (compiler.process(dist, **kw),
@@ -110,25 +110,13 @@ def default_is_below_reverse_distance(element: SaColumn,
 
 
 @compiles(IsBelowReverseDistance, 'sqlite') # type: ignore[no-untyped-call, misc]
-def sqlite_is_below_reverse_distance(element: SaColumn,
+def sqlite_is_below_reverse_distance(element: IsBelowReverseDistance,
                                      compiler: 'sa.Compiled', **kw: Any) -> str:
     dist, rank = list(element.clauses)
     return "%s < 14.0 * exp(-0.2 * %s) - 0.03" % (compiler.process(dist, **kw),
                                                   compiler.process(rank, **kw))
 
 
-def select_index_placex_geometry_reverse_lookupplacenode(table: str) -> 'sa.TextClause':
-    """ Create an expression with the necessary conditions over a placex
-        table that the index 'idx_placex_geometry_reverse_lookupPlaceNode'
-        can be used.
-    """
-    return sa.text(f"{table}.rank_address between 4 and 25"
-                   f" AND {table}.type != 'postcode'"
-                   f" AND {table}.name is not null"
-                   f" AND {table}.linked_place_id is null"
-                   f" AND {table}.osm_type = 'N'")
-
-
 class IsAddressPoint(sa.sql.functions.GenericFunction[Any]):
     name = 'IsAddressPoint'
     inherit_cache = True
@@ -139,7 +127,7 @@ class IsAddressPoint(sa.sql.functions.GenericFunction[Any]):
 
 
 @compiles(IsAddressPoint) # type: ignore[no-untyped-call, misc]
-def default_is_address_point(element: SaColumn,
+def default_is_address_point(element: IsAddressPoint,
                              compiler: 'sa.Compiled', **kw: Any) -> str:
     rank, hnr, name = list(element.clauses)
     return "(%s = 30 AND (%s IS NOT NULL OR %s ? 'addr:housename'))" % (
@@ -149,7 +137,7 @@ def default_is_address_point(element: SaColumn,
 
 
 @compiles(IsAddressPoint, 'sqlite') # type: ignore[no-untyped-call, misc]
-def sqlite_is_address_point(element: SaColumn,
+def sqlite_is_address_point(element: IsAddressPoint,
                             compiler: 'sa.Compiled', **kw: Any) -> str:
     rank, hnr, name = list(element.clauses)
     return "(%s = 30 AND coalesce(%s, json_extract(%s, '$.addr:housename')) IS NOT NULL)" % (
@@ -166,7 +154,7 @@ class CrosscheckNames(sa.sql.functions.GenericFunction[Any]):
     inherit_cache = True
 
 @compiles(CrosscheckNames) # type: ignore[no-untyped-call, misc]
-def compile_crosscheck_names(element: SaColumn,
+def compile_crosscheck_names(element: CrosscheckNames,
                              compiler: 'sa.Compiled', **kw: Any) -> str:
     arg1, arg2 = list(element.clauses)
     return "coalesce(avals(%s) && ARRAY(SELECT * FROM json_array_elements_text(%s)), false)" % (
@@ -174,7 +162,7 @@ def compile_crosscheck_names(element: SaColumn,
 
 
 @compiles(CrosscheckNames, 'sqlite') # type: ignore[no-untyped-call, misc]
-def compile_sqlite_crosscheck_names(element: SaColumn,
+def compile_sqlite_crosscheck_names(element: CrosscheckNames,
                                     compiler: 'sa.Compiled', **kw: Any) -> str:
     arg1, arg2 = list(element.clauses)
     return "EXISTS(SELECT *"\
@@ -191,15 +179,16 @@ class JsonArrayEach(sa.sql.functions.GenericFunction[Any]):
 
 
 @compiles(JsonArrayEach) # type: ignore[no-untyped-call, misc]
-def default_json_array_each(element: SaColumn, compiler: 'sa.Compiled', **kw: Any) -> str:
+def default_json_array_each(element: JsonArrayEach, compiler: 'sa.Compiled', **kw: Any) -> str:
     return "json_array_elements(%s)" % compiler.process(element.clauses, **kw)
 
 
 @compiles(JsonArrayEach, 'sqlite') # type: ignore[no-untyped-call, misc]
-def sqlite_json_array_each(element: SaColumn, compiler: 'sa.Compiled', **kw: Any) -> str:
+def sqlite_json_array_each(element: JsonArrayEach, compiler: 'sa.Compiled', **kw: Any) -> str:
     return "json_each(%s)" % compiler.process(element.clauses, **kw)
 
 
 class Greatest(sa.sql.functions.GenericFunction[Any]):
     """ Function to compute maximum of all its input parameters.
     """
@@ -208,5 +197,25 @@ class Greatest(sa.sql.functions.GenericFunction[Any]):
 
 
 @compiles(Greatest, 'sqlite') # type: ignore[no-untyped-call, misc]
-def sqlite_greatest(element: SaColumn, compiler: 'sa.Compiled', **kw: Any) -> str:
+def sqlite_greatest(element: Greatest, compiler: 'sa.Compiled', **kw: Any) -> str:
     return "max(%s)" % compiler.process(element.clauses, **kw)
+
+
+class RegexpWord(sa.sql.functions.GenericFunction[Any]):
+    """ Check if a full word is in a given string.
+    """
+    name = 'RegexpWord'
+    inherit_cache = True
+
+
+@compiles(RegexpWord, 'postgresql') # type: ignore[no-untyped-call, misc]
+def postgres_regexp_nocase(element: RegexpWord, compiler: 'sa.Compiled', **kw: Any) -> str:
+    arg1, arg2 = list(element.clauses)
+    return "%s ~* ('\\m(' || %s  || ')\\M')::text" % (compiler.process(arg2, **kw), compiler.process(arg1, **kw))
+
+
+@compiles(RegexpWord, 'sqlite') # type: ignore[no-untyped-call, misc]
+def sqlite_regexp_nocase(element: RegexpWord, compiler: 'sa.Compiled', **kw: Any) -> str:
+    arg1, arg2 = list(element.clauses)
+    return "regexp('\\b(' || %s  || ')\\b', %s)" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
index 7dd1e0ce0b046182b6224eab7b5ec16769719b96..0ec22b7e1fa322469a2ea75d38642c3b75f02aa8 100644 (file)
@@ -7,37 +7,10 @@
 """
 SQLAlchemy definitions for all tables used by the frontend.
 """
-from typing import Any
-
 import sqlalchemy as sa
-from sqlalchemy.dialects.postgresql import HSTORE, ARRAY, JSONB, array
-from sqlalchemy.dialects.sqlite import JSON as sqlite_json
 
 import nominatim.db.sqlalchemy_functions #pylint: disable=unused-import
-from nominatim.db.sqlalchemy_types import Geometry
-
-class PostgresTypes:
-    """ Type definitions for complex types as used in Postgres variants.
-    """
-    Composite = HSTORE
-    Json = JSONB
-    IntArray = ARRAY(sa.Integer()) #pylint: disable=invalid-name
-    to_array = array
-
-
-class SqliteTypes:
-    """ Type definitions for complex types as used in Postgres variants.
-    """
-    Composite = sqlite_json
-    Json = sqlite_json
-    IntArray = sqlite_json
-
-    @staticmethod
-    def to_array(arr: Any) -> Any:
-        """ Sqlite has no special conversion for arrays.
-        """
-        return arr
-
+from nominatim.db.sqlalchemy_types import Geometry, KeyValueStore, IntArray
 
 #pylint: disable=too-many-instance-attributes
 class SearchTables:
@@ -47,14 +20,7 @@ class SearchTables:
         Any data used for updates only will not be visible.
     """
 
-    def __init__(self, meta: sa.MetaData, engine_name: str) -> None:
-        if engine_name == 'postgresql':
-            self.types: Any = PostgresTypes
-        elif engine_name == 'sqlite':
-            self.types = SqliteTypes
-        else:
-            raise ValueError("Only 'postgresql' and 'sqlite' engines are supported.")
-
+    def __init__(self, meta: sa.MetaData) -> None:
         self.meta = meta
 
         self.import_status = sa.Table('import_status', meta,
@@ -80,9 +46,9 @@ class SearchTables:
             sa.Column('class', sa.Text, nullable=False, key='class_'),
             sa.Column('type', sa.Text, nullable=False),
             sa.Column('admin_level', sa.SmallInteger),
-            sa.Column('name', self.types.Composite),
-            sa.Column('address', self.types.Composite),
-            sa.Column('extratags', self.types.Composite),
+            sa.Column('name', KeyValueStore),
+            sa.Column('address', KeyValueStore),
+            sa.Column('extratags', KeyValueStore),
             sa.Column('geometry', Geometry, nullable=False),
             sa.Column('wikipedia', sa.Text),
             sa.Column('country_code', sa.String(2)),
@@ -118,14 +84,14 @@ class SearchTables:
             sa.Column('step', sa.SmallInteger),
             sa.Column('indexed_status', sa.SmallInteger),
             sa.Column('linegeo', Geometry),
-            sa.Column('address', self.types.Composite),
+            sa.Column('address', KeyValueStore),
             sa.Column('postcode', sa.Text),
             sa.Column('country_code', sa.String(2)))
 
         self.country_name = sa.Table('country_name', meta,
             sa.Column('country_code', sa.String(2)),
-            sa.Column('name', self.types.Composite),
-            sa.Column('derived_name', self.types.Composite),
+            sa.Column('name', KeyValueStore),
+            sa.Column('derived_name', KeyValueStore),
             sa.Column('partition', sa.Integer))
 
         self.country_grid = sa.Table('country_osm_grid', meta,
@@ -139,8 +105,8 @@ class SearchTables:
             sa.Column('importance', sa.Float),
             sa.Column('search_rank', sa.SmallInteger),
             sa.Column('address_rank', sa.SmallInteger),
-            sa.Column('name_vector', self.types.IntArray),
-            sa.Column('nameaddress_vector', self.types.IntArray),
+            sa.Column('name_vector', IntArray),
+            sa.Column('nameaddress_vector', IntArray),
             sa.Column('country_code', sa.String(2)),
             sa.Column('centroid', Geometry))
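With the dialect-aware types, a single table definition now serves both backends,
which is why SearchTables no longer takes an engine name. A minimal sketch (table
name illustrative):

    import sqlalchemy as sa
    from nominatim.db.sqlalchemy_types import IntArray, KeyValueStore

    meta = sa.MetaData()
    demo = sa.Table('demo', meta,
                    sa.Column('name', KeyValueStore),
                    sa.Column('name_vector', IntArray))
    # Compiles to HSTORE and INTEGER[] on PostgreSQL,
    # to JSON and a comma-separated TEXT list on SQLite.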
 
diff --git a/nominatim/db/sqlalchemy_types/__init__.py b/nominatim/db/sqlalchemy_types/__init__.py
new file mode 100644 (file)
index 0000000..dc41799
--- /dev/null
@@ -0,0 +1,17 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Module with custom types for SQLAlchemy
+"""
+
+# See also https://github.com/PyCQA/pylint/issues/6006
+# pylint: disable=useless-import-alias
+
+from .geometry import (Geometry as Geometry)
+from .int_array import (IntArray as IntArray)
+from .key_value import (KeyValueStore as KeyValueStore)
+from .json import (Json as Json)
similarity index 85%
rename from nominatim/db/sqlalchemy_types.py
rename to nominatim/db/sqlalchemy_types/geometry.py
index a36e8c462acfce3b4cc5e730b2eb5c008f1dfa14..0731b0b796337869c495b10cb7a112df4a4dcc8c 100644 (file)
@@ -28,7 +28,7 @@ class Geometry_DistanceSpheroid(sa.sql.expression.FunctionElement[float]):
 
 
 @compiles(Geometry_DistanceSpheroid) # type: ignore[no-untyped-call, misc]
-def _default_distance_spheroid(element: SaColumn,
+def _default_distance_spheroid(element: Geometry_DistanceSpheroid,
                                compiler: 'sa.Compiled', **kw: Any) -> str:
     return "ST_DistanceSpheroid(%s,"\
            " 'SPHEROID[\"WGS 84\",6378137,298.257223563, AUTHORITY[\"EPSG\",\"7030\"]]')"\
@@ -36,7 +36,7 @@ def _default_distance_spheroid(element: SaColumn,
 
 
 @compiles(Geometry_DistanceSpheroid, 'sqlite') # type: ignore[no-untyped-call, misc]
-def _spatialite_distance_spheroid(element: SaColumn,
+def _spatialite_distance_spheroid(element: Geometry_DistanceSpheroid,
                                   compiler: 'sa.Compiled', **kw: Any) -> str:
     return "COALESCE(Distance(%s, true), 0.0)" % compiler.process(element.clauses, **kw)
 
@@ -49,14 +49,14 @@ class Geometry_IsLineLike(sa.sql.expression.FunctionElement[Any]):
 
 
 @compiles(Geometry_IsLineLike) # type: ignore[no-untyped-call, misc]
-def _default_is_line_like(element: SaColumn,
+def _default_is_line_like(element: Geometry_IsLineLike,
                           compiler: 'sa.Compiled', **kw: Any) -> str:
     return "ST_GeometryType(%s) IN ('ST_LineString', 'ST_MultiLineString')" % \
                compiler.process(element.clauses, **kw)
 
 
 @compiles(Geometry_IsLineLike, 'sqlite') # type: ignore[no-untyped-call, misc]
-def _sqlite_is_line_like(element: SaColumn,
+def _sqlite_is_line_like(element: Geometry_IsLineLike,
                          compiler: 'sa.Compiled', **kw: Any) -> str:
     return "ST_GeometryType(%s) IN ('LINESTRING', 'MULTILINESTRING')" % \
                compiler.process(element.clauses, **kw)
@@ -70,14 +70,14 @@ class Geometry_IsAreaLike(sa.sql.expression.FunctionElement[Any]):
 
 
 @compiles(Geometry_IsAreaLike) # type: ignore[no-untyped-call, misc]
-def _default_is_area_like(element: SaColumn,
+def _default_is_area_like(element: Geometry_IsAreaLike,
                           compiler: 'sa.Compiled', **kw: Any) -> str:
     return "ST_GeometryType(%s) IN ('ST_Polygon', 'ST_MultiPolygon')" % \
                compiler.process(element.clauses, **kw)
 
 
 @compiles(Geometry_IsAreaLike, 'sqlite') # type: ignore[no-untyped-call, misc]
-def _sqlite_is_area_like(element: SaColumn,
+def _sqlite_is_area_like(element: Geometry_IsAreaLike,
                          compiler: 'sa.Compiled', **kw: Any) -> str:
     return "ST_GeometryType(%s) IN ('POLYGON', 'MULTIPOLYGON')" % \
                compiler.process(element.clauses, **kw)
@@ -91,14 +91,14 @@ class Geometry_IntersectsBbox(sa.sql.expression.FunctionElement[Any]):
 
 
 @compiles(Geometry_IntersectsBbox) # type: ignore[no-untyped-call, misc]
-def _default_intersects(element: SaColumn,
+def _default_intersects(element: Geometry_IntersectsBbox,
                         compiler: 'sa.Compiled', **kw: Any) -> str:
     arg1, arg2 = list(element.clauses)
     return "%s && %s" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
 
 
 @compiles(Geometry_IntersectsBbox, 'sqlite') # type: ignore[no-untyped-call, misc]
-def _sqlite_intersects(element: SaColumn,
+def _sqlite_intersects(element: Geometry_IntersectsBbox,
                        compiler: 'sa.Compiled', **kw: Any) -> str:
     return "MbrIntersects(%s) = 1" % compiler.process(element.clauses, **kw)
 
@@ -114,14 +114,14 @@ class Geometry_ColumnIntersectsBbox(sa.sql.expression.FunctionElement[Any]):
 
 
 @compiles(Geometry_ColumnIntersectsBbox) # type: ignore[no-untyped-call, misc]
-def default_intersects_column(element: SaColumn,
+def default_intersects_column(element: Geometry_ColumnIntersectsBbox,
                               compiler: 'sa.Compiled', **kw: Any) -> str:
     arg1, arg2 = list(element.clauses)
     return "%s && %s" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
 
 
 @compiles(Geometry_ColumnIntersectsBbox, 'sqlite') # type: ignore[no-untyped-call, misc]
-def spatialite_intersects_column(element: SaColumn,
+def spatialite_intersects_column(element: Geometry_ColumnIntersectsBbox,
                                  compiler: 'sa.Compiled', **kw: Any) -> str:
     arg1, arg2 = list(element.clauses)
     return "MbrIntersects(%s, %s) = 1 and "\
@@ -145,12 +145,12 @@ class Geometry_ColumnDWithin(sa.sql.expression.FunctionElement[Any]):
 
 
 @compiles(Geometry_ColumnDWithin) # type: ignore[no-untyped-call, misc]
-def default_dwithin_column(element: SaColumn,
+def default_dwithin_column(element: Geometry_ColumnDWithin,
                            compiler: 'sa.Compiled', **kw: Any) -> str:
     return "ST_DWithin(%s)" % compiler.process(element.clauses, **kw)
 
 @compiles(Geometry_ColumnDWithin, 'sqlite') # type: ignore[no-untyped-call, misc]
-def spatialite_dwithin_column(element: SaColumn,
+def spatialite_dwithin_column(element: Geometry_ColumnDWithin,
                               compiler: 'sa.Compiled', **kw: Any) -> str:
     geom1, geom2, dist = list(element.clauses)
     return "ST_Distance(%s, %s) < %s and "\
@@ -165,7 +165,6 @@ def spatialite_dwithin_column(element: SaColumn,
               compiler.process(dist, **kw))
 
 
-
 class Geometry(types.UserDefinedType): # type: ignore[type-arg]
     """ Simplified type decorator for PostGIS geometry. This type
         only supports geometries in 4326 projection.
@@ -206,7 +205,10 @@ class Geometry(types.UserDefinedType): # type: ignore[type-arg]
 
     class comparator_factory(types.UserDefinedType.Comparator): # type: ignore[type-arg]
 
-        def intersects(self, other: SaColumn) -> 'sa.Operators':
+        def intersects(self, other: SaColumn, use_index: bool = True) -> 'sa.Operators':
+            if not use_index:
+                return Geometry_IntersectsBbox(sa.func.coalesce(sa.null(), self.expr), other)
+
             if isinstance(self.expr, sa.Column):
                 return Geometry_ColumnIntersectsBbox(self.expr, other)
 
@@ -221,20 +223,11 @@ class Geometry(types.UserDefinedType): # type: ignore[type-arg]
             return Geometry_IsAreaLike(self)
 
 
-        def ST_DWithin(self, other: SaColumn, distance: SaColumn) -> SaColumn:
+        def within_distance(self, other: SaColumn, distance: SaColumn) -> SaColumn:
             if isinstance(self.expr, sa.Column):
                 return Geometry_ColumnDWithin(self.expr, other, distance)
 
-            return sa.func.ST_DWithin(self.expr, other, distance)
-
-
-        def ST_DWithin_no_index(self, other: SaColumn, distance: SaColumn) -> SaColumn:
-            return sa.func.ST_DWithin(sa.func.coalesce(sa.null(), self),
-                                      other, distance)
-
-
-        def ST_Intersects_no_index(self, other: SaColumn) -> 'sa.Operators':
-            return Geometry_IntersectsBbox(sa.func.coalesce(sa.null(), self), other)
+            return self.ST_Distance(other) < distance
 
 
         def ST_Distance(self, other: SaColumn) -> SaColumn:
@@ -313,18 +306,3 @@ def _add_function_alias(func: str, ftype: type, alias: str) -> None:
 
 for alias in SQLITE_FUNCTION_ALIAS:
     _add_function_alias(*alias)
-
-
-class ST_DWithin(sa.sql.functions.GenericFunction[Any]):
-    name = 'ST_DWithin'
-    inherit_cache = True
-
-
-@compiles(ST_DWithin, 'sqlite') # type: ignore[no-untyped-call, misc]
-def default_json_array_each(element: SaColumn, compiler: 'sa.Compiled', **kw: Any) -> str:
-    geom1, geom2, dist = list(element.clauses)
-    return "(MbrIntersects(%s, ST_Expand(%s, %s)) = 1 AND ST_Distance(%s, %s) <= %s)" % (
-        compiler.process(geom1, **kw), compiler.process(geom2, **kw),
-        compiler.process(dist, **kw),
-        compiler.process(geom1, **kw), compiler.process(geom2, **kw),
-        compiler.process(dist, **kw))
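The former ST_DWithin/ST_DWithin_no_index/ST_Intersects_no_index comparator methods
are folded into within_distance() and an intersects(..., use_index=False) flag. A
sketch that builds (but does not execute) such an expression; table and bind values
are illustrative:

    import sqlalchemy as sa
    from nominatim.db.sqlalchemy_types import Geometry

    meta = sa.MetaData()
    placex = sa.Table('placex', meta,
                      sa.Column('place_id', sa.BigInteger),
                      sa.Column('geometry', Geometry))
    pt = sa.bindparam('pt', 'POINT(10 10)', type_=Geometry)

    sql = (sa.select(placex.c.place_id)
             .where(placex.c.geometry.within_distance(pt, 0.1))
             .where(placex.c.geometry.intersects(pt, use_index=False)))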
diff --git a/nominatim/db/sqlalchemy_types/int_array.py b/nominatim/db/sqlalchemy_types/int_array.py
new file mode 100644 (file)
index 0000000..a31793f
--- /dev/null
@@ -0,0 +1,123 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Custom type for an array of integers.
+"""
+from typing import Any, List, cast, Optional
+
+import sqlalchemy as sa
+from sqlalchemy.ext.compiler import compiles
+from sqlalchemy.dialects.postgresql import ARRAY
+
+from nominatim.typing import SaDialect, SaColumn
+
+# pylint: disable=all
+
+class IntList(sa.types.TypeDecorator[Any]):
+    """ A list of integers saved as a text of comma-separated numbers.
+    """
+    impl = sa.types.Unicode
+    cache_ok = True
+
+    def process_bind_param(self, value: Optional[Any], dialect: 'sa.Dialect') -> Optional[str]:
+        if value is None:
+            return None
+
+        assert isinstance(value, list)
+        return ','.join(map(str, value))
+
+    def process_result_value(self, value: Optional[Any],
+                             dialect: SaDialect) -> Optional[List[int]]:
+        return [int(v) for v in value.split(',')] if value is not None else None
+
+    def copy(self, **kw: Any) -> 'IntList':
+        return IntList(self.impl.length)
+
+
+class IntArray(sa.types.TypeDecorator[Any]):
+    """ Dialect-independent list of integers.
+    """
+    impl = IntList
+    cache_ok = True
+
+    def load_dialect_impl(self, dialect: SaDialect) -> sa.types.TypeEngine[Any]:
+        if dialect.name == 'postgresql':
+            return ARRAY(sa.Integer()) #pylint: disable=invalid-name
+
+        return IntList()
+
+
+    class comparator_factory(sa.types.UserDefinedType.Comparator): # type: ignore[type-arg]
+
+        def __add__(self, other: SaColumn) -> 'sa.ColumnOperators':
+            """ Concate the array with the given array. If one of the
+                operants is null, the value of the other will be returned.
+            """
+            return ArrayCat(self.expr, other)
+
+
+        def contains(self, other: SaColumn, **kwargs: Any) -> 'sa.ColumnOperators':
+            """ Return true if the array contains all the value of the argument
+                array.
+            """
+            return ArrayContains(self.expr, other)
+
+
+class ArrayAgg(sa.sql.functions.GenericFunction[Any]):
+    """ Aggregate function to collect elements in an array.
+    """
+    type = IntArray()
+    identifier = 'ArrayAgg'
+    name = 'array_agg'
+    inherit_cache = True
+
+
+@compiles(ArrayAgg, 'sqlite') # type: ignore[no-untyped-call, misc]
+def sqlite_array_agg(element: ArrayAgg, compiler: 'sa.Compiled', **kw: Any) -> str:
+    return "group_concat(%s, ',')" % compiler.process(element.clauses, **kw)
+
+
+class ArrayContains(sa.sql.expression.FunctionElement[Any]):
+    """ Function to check if an array is fully contained in another.
+    """
+    name = 'ArrayContains'
+    inherit_cache = True
+
+
+@compiles(ArrayContains) # type: ignore[no-untyped-call, misc]
+def generic_array_contains(element: ArrayContains, compiler: 'sa.Compiled', **kw: Any) -> str:
+    arg1, arg2 = list(element.clauses)
+    return "(%s @> %s)" % (compiler.process(arg1, **kw),
+                           compiler.process(arg2, **kw))
+
+
+@compiles(ArrayContains, 'sqlite') # type: ignore[no-untyped-call, misc]
+def sqlite_array_contains(element: ArrayContains, compiler: 'sa.Compiled', **kw: Any) -> str:
+    return "array_contains(%s)" % compiler.process(element.clauses, **kw)
+
+
+class ArrayCat(sa.sql.expression.FunctionElement[Any]):
+    """ Function to check if an array is fully contained in another.
+    """
+    type = IntArray()
+    identifier = 'ArrayCat'
+    inherit_cache = True
+
+
+@compiles(ArrayCat) # type: ignore[no-untyped-call, misc]
+def generic_array_cat(element: ArrayCat, compiler: 'sa.Compiled', **kw: Any) -> str:
+    return "array_cat(%s)" % compiler.process(element.clauses, **kw)
+
+
+@compiles(ArrayCat, 'sqlite') # type: ignore[no-untyped-call, misc]
+def sqlite_array_cat(element: ArrayCat, compiler: 'sa.Compiled', **kw: Any) -> str:
+    arg1, arg2 = list(element.clauses)
+    return "(%s || ',' || %s)" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
+
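A pure-Python illustration of the storage format implemented above for SQLite:
IntList persists an array as comma-separated text, which is also why the SQLite
rendering of ArrayCat is a plain string concatenation with ',':

    values = [1, 22, 333]
    stored = ','.join(map(str, values))             # bound as '1,22,333'
    restored = [int(v) for v in stored.split(',')]  # read back as [1, 22, 333]
    assert restored == values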
diff --git a/nominatim/db/sqlalchemy_types/json.py b/nominatim/db/sqlalchemy_types/json.py
new file mode 100644 (file)
index 0000000..31635fd
--- /dev/null
@@ -0,0 +1,30 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Common json type for different dialects.
+"""
+from typing import Any
+
+import sqlalchemy as sa
+from sqlalchemy.dialects.postgresql import JSONB
+from sqlalchemy.dialects.sqlite import JSON as sqlite_json
+
+from nominatim.typing import SaDialect
+
+# pylint: disable=all
+
+class Json(sa.types.TypeDecorator[Any]):
+    """ Dialect-independent type for JSON.
+    """
+    impl = sa.types.JSON
+    cache_ok = True
+
+    def load_dialect_impl(self, dialect: SaDialect) -> sa.types.TypeEngine[Any]:
+        if dialect.name == 'postgresql':
+            return JSONB(none_as_null=True) # type: ignore[no-untyped-call]
+
+        return sqlite_json(none_as_null=True)
diff --git a/nominatim/db/sqlalchemy_types/key_value.py b/nominatim/db/sqlalchemy_types/key_value.py
new file mode 100644 (file)
index 0000000..937caa0
--- /dev/null
@@ -0,0 +1,62 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+A custom type that implements a simple key-value store of strings.
+"""
+from typing import Any
+
+import sqlalchemy as sa
+from sqlalchemy.ext.compiler import compiles
+from sqlalchemy.dialects.postgresql import HSTORE
+from sqlalchemy.dialects.sqlite import JSON as sqlite_json
+
+from nominatim.typing import SaDialect, SaColumn
+
+# pylint: disable=all
+
+class KeyValueStore(sa.types.TypeDecorator[Any]):
+    """ Dialect-independent type of a simple key-value store of strings.
+    """
+    impl = HSTORE
+    cache_ok = True
+
+    def load_dialect_impl(self, dialect: SaDialect) -> sa.types.TypeEngine[Any]:
+        if dialect.name == 'postgresql':
+            return HSTORE() # type: ignore[no-untyped-call]
+
+        return sqlite_json(none_as_null=True)
+
+
+    class comparator_factory(sa.types.UserDefinedType.Comparator): # type: ignore[type-arg]
+
+        def merge(self, other: SaColumn) -> 'sa.Operators':
+            """ Merge the values from the given KeyValueStore into this
+                one, overwriting values where necessary. When the argument
+                is null, nothing happens.
+            """
+            return KeyValueConcat(self.expr, other)
+
+
+class KeyValueConcat(sa.sql.expression.FunctionElement[Any]):
+    """ Return the merged key-value store from the input parameters.
+    """
+    type = KeyValueStore()
+    name = 'JsonConcat'
+    inherit_cache = True
+
+@compiles(KeyValueConcat) # type: ignore[no-untyped-call, misc]
+def default_json_concat(element: KeyValueConcat, compiler: 'sa.Compiled', **kw: Any) -> str:
+    arg1, arg2 = list(element.clauses)
+    return "(%s || coalesce(%s, ''::hstore))" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
+
+@compiles(KeyValueConcat, 'sqlite') # type: ignore[no-untyped-call, misc]
+def sqlite_json_concat(element: KeyValueConcat, compiler: 'sa.Compiled', **kw: Any) -> str:
+    arg1, arg2 = list(element.clauses)
+    return "json_patch(%s, coalesce(%s, '{}'))" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
+
+
+
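A sketch of what the new merge() operator renders to on each dialect (column names
illustrative):

    import sqlalchemy as sa
    from sqlalchemy.dialects import postgresql, sqlite
    from nominatim.db.sqlalchemy_types import KeyValueStore

    name = sa.column('name', KeyValueStore)
    extratags = sa.column('extratags', KeyValueStore)

    print(name.merge(extratags).compile(dialect=postgresql.dialect()))
    # (name || coalesce(extratags, ''::hstore))
    print(name.merge(extratags).compile(dialect=sqlite.dialect()))
    # json_patch(name, coalesce(extratags, '{}'))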
diff --git a/nominatim/db/sqlite_functions.py b/nominatim/db/sqlite_functions.py
new file mode 100644 (file)
index 0000000..2134ae4
--- /dev/null
@@ -0,0 +1,122 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2023 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Custom functions for SQLite.
+"""
+from typing import cast, Optional, Set, Any
+import json
+
+# pylint: disable=protected-access
+
+def weigh_search(search_vector: Optional[str], rankings: str, default: float) -> float:
+    """ Custom weight function for search results.
+    """
+    if search_vector is not None:
+        svec = [int(x) for x in search_vector.split(',')]
+        for rank in json.loads(rankings):
+            if all(r in svec for r in rank[1]):
+                return cast(float, rank[0])
+
+    return default
+
+
+class ArrayIntersectFuzzy:
+    """ Compute the array of common elements of all input integer arrays.
+        Very large input parameters may be ignored to speed up
+        computation. Therefore, the result is a superset of common elements.
+
+        Input and output arrays are given as comma-separated lists.
+    """
+    def __init__(self) -> None:
+        self.first = ''
+        self.values: Optional[Set[int]] = None
+
+    def step(self, value: Optional[str]) -> None:
+        """ Add the next array to the intersection.
+        """
+        if value is not None:
+            if not self.first:
+                self.first = value
+            elif len(value) < 10000000:
+                if self.values is None:
+                    self.values = {int(x) for x in self.first.split(',')}
+                self.values.intersection_update((int(x) for x in value.split(',')))
+
+    def finalize(self) -> str:
+        """ Return the final result.
+        """
+        if self.values is not None:
+            return ','.join(map(str, self.values))
+
+        return self.first
+
+
+class ArrayUnion:
+    """ Compute the set of all elements of the input integer arrays.
+
+        Input and output arrays are given as strings of comma-separated lists.
+    """
+    def __init__(self) -> None:
+        self.values: Optional[Set[str]] = None
+
+    def step(self, value: Optional[str]) -> None:
+        """ Add the next array to the union.
+        """
+        if value is not None:
+            if self.values is None:
+                self.values = set(value.split(','))
+            else:
+                self.values.update(value.split(','))
+
+    def finalize(self) -> str:
+        """ Return the final result.
+        """
+        return '' if self.values is None else ','.join(self.values)
+
+
+def array_contains(container: Optional[str], containee: Optional[str]) -> Optional[bool]:
+    """ Is the array 'containee' completely contained in array 'container'.
+    """
+    if container is None or containee is None:
+        return None
+
+    vset = container.split(',')
+    return all(v in vset for v in containee.split(','))
+
+
+def array_pair_contains(container1: Optional[str], container2: Optional[str],
+                        containee: Optional[str]) -> Optional[bool]:
+    """ Is the array 'containee' completely contained in the union of
+        array 'container1' and array 'container2'.
+    """
+    if container1 is None or container2 is None or containee is None:
+        return None
+
+    vset = container1.split(',') + container2.split(',')
+    return all(v in vset for v in containee.split(','))
+
+
+def install_custom_functions(conn: Any) -> None:
+    """ Install helper functions for Nominatim into the given SQLite
+        database connection.
+    """
+    conn.create_function('weigh_search', 3, weigh_search, deterministic=True)
+    conn.create_function('array_contains', 2, array_contains, deterministic=True)
+    conn.create_function('array_pair_contains', 3, array_pair_contains, deterministic=True)
+    _create_aggregate(conn, 'array_intersect_fuzzy', 1, ArrayIntersectFuzzy)
+    _create_aggregate(conn, 'array_union', 1, ArrayUnion)
+
+
+async def _make_aggregate(aioconn: Any, *args: Any) -> None:
+    await aioconn._execute(aioconn._conn.create_aggregate, *args)
+
+
+def _create_aggregate(conn: Any, name: str, nargs: int, aggregate: Any) -> None:
+    try:
+        conn.await_(_make_aggregate(conn._connection, name, nargs, aggregate))
+    except Exception as error: # pylint: disable=broad-exception-caught
+        conn._handle_exception(error)
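The scalar helpers are plain Python callables, so they can be exercised on a
standard sqlite3 connection in the same way install_custom_functions() registers
them:

    import sqlite3
    from nominatim.db.sqlite_functions import array_contains

    conn = sqlite3.connect(':memory:')
    conn.create_function('array_contains', 2, array_contains, deterministic=True)
    print(conn.execute("SELECT array_contains('1,2,3', '2,3')").fetchone()[0])  # 1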
index 0702e5d8c045185cbcfea0a0fd18e328de2cd8b2..3e5847107efbd5c10016e03316560581feff165b 100644 (file)
@@ -14,7 +14,8 @@ from pathlib import Path
 import sqlalchemy as sa
 
 from nominatim.typing import SaSelect
-from nominatim.db.sqlalchemy_types import Geometry
+from nominatim.db.sqlalchemy_types import Geometry, IntArray
+from nominatim.api.search.query_analyzer_factory import make_query_analyzer
 import nominatim.api as napi
 
 LOG = logging.getLogger()
@@ -27,11 +28,15 @@ async def convert(project_dir: Path, outfile: Path, options: Set[str]) -> None:
 
     try:
         outapi = napi.NominatimAPIAsync(project_dir,
-                                        {'NOMINATIM_DATABASE_DSN': f"sqlite:dbname={outfile}"})
+                                        {'NOMINATIM_DATABASE_DSN': f"sqlite:dbname={outfile}",
+                                         'NOMINATIM_DATABASE_RW': '1'})
 
-        async with api.begin() as src, outapi.begin() as dest:
-            writer = SqliteWriter(src, dest, options)
-            await writer.write()
+        try:
+            async with api.begin() as src, outapi.begin() as dest:
+                writer = SqliteWriter(src, dest, options)
+                await writer.write()
+        finally:
+            await outapi.close()
     finally:
         await api.close()
 
@@ -51,18 +56,24 @@ class SqliteWriter:
         """ Create the database structure and copy the data from
             the source database to the destination.
         """
+        LOG.warning('Setting up spatialite')
         await self.dest.execute(sa.select(sa.func.InitSpatialMetaData(True, 'WGS84')))
 
         await self.create_tables()
         await self.copy_data()
+        if 'search' in self.options:
+            await self.create_word_table()
         await self.create_indexes()
 
 
     async def create_tables(self) -> None:
         """ Set up the database tables.
         """
+        LOG.warning('Setting up tables')
         if 'search' not in self.options:
             self.dest.t.meta.remove(self.dest.t.search_name)
+        else:
+            await self.create_class_tables()
 
         await self.dest.connection.run_sync(self.dest.t.meta.create_all)
 
@@ -75,6 +86,41 @@ class SqliteWriter:
                                                       col.type.subtype.upper(), 'XY')))
 
 
+    async def create_class_tables(self) -> None:
+        """ Set up the table that serve class/type-specific geometries.
+        """
+        sql = sa.text("""SELECT tablename FROM pg_tables
+                         WHERE tablename LIKE 'place_classtype_%'""")
+        for res in await self.src.execute(sql):
+            for db in (self.src, self.dest):
+                sa.Table(res[0], db.t.meta,
+                         sa.Column('place_id', sa.BigInteger),
+                         sa.Column('centroid', Geometry))
+
+
+    async def create_word_table(self) -> None:
+        """ Create the word table.
+            This table needs the property information to determine the
+            correct format. It therefore needs to be created after all
+            other data has been copied.
+        """
+        await make_query_analyzer(self.src)
+        await make_query_analyzer(self.dest)
+        src = self.src.t.meta.tables['word']
+        dest = self.dest.t.meta.tables['word']
+
+        await self.dest.connection.run_sync(dest.create)
+
+        LOG.warning("Copying word table")
+        async_result = await self.src.connection.stream(sa.select(src))
+
+        async for partition in async_result.partitions(10000):
+            data = [{k: getattr(r, k) for k in r._fields} for r in partition]
+            await self.dest.execute(dest.insert(), data)
+
+        await self.dest.connection.run_sync(sa.Index('idx_word_woken', dest.c.word_token).create)
+
+
     async def copy_data(self) -> None:
         """ Copy data for all registered tables.
         """
@@ -87,6 +133,14 @@ class SqliteWriter:
                         for r in partition]
                 await self.dest.execute(table.insert(), data)
 
+        # Set up a minimal copy of pg_tables used to look up the class tables later.
+        pg_tables = sa.Table('pg_tables', self.dest.t.meta,
+                             sa.Column('schemaname', sa.Text, default='public'),
+                             sa.Column('tablename', sa.Text))
+        await self.dest.connection.run_sync(pg_tables.create)
+        data = [{'tablename': t} for t in self.dest.t.meta.tables]
+        await self.dest.execute(pg_tables.insert().values(data))
+
 
     async def create_indexes(self) -> None:
         """ Add indexes necessary for the frontend.
@@ -116,6 +170,22 @@ class SqliteWriter:
         await self.create_index('placex', 'parent_place_id')
         await self.create_index('placex', 'rank_address')
         await self.create_index('addressline', 'place_id')
+        await self.create_index('postcode', 'place_id')
+        await self.create_index('osmline', 'place_id')
+        await self.create_index('tiger', 'place_id')
+
+        if 'search' in self.options:
+            await self.create_spatial_index('postcode', 'geometry')
+            await self.create_spatial_index('search_name', 'centroid')
+            await self.create_index('search_name', 'place_id')
+            await self.create_index('osmline', 'parent_place_id')
+            await self.create_index('tiger', 'parent_place_id')
+            await self.create_search_index()
+
+            for t in self.dest.t.meta.tables:
+                if t.startswith('place_classtype_'):
+                    await self.dest.execute(sa.select(
+                      sa.func.CreateSpatialIndex(t, 'centroid')))
 
 
     async def create_spatial_index(self, table: str, column: str) -> None:
@@ -133,6 +203,36 @@ class SqliteWriter:
             sa.Index(f"idx_{table}_{column}", getattr(table.c, column)).create)
 
 
+    async def create_search_index(self) -> None:
+        """ Create the tables and indexes needed for word lookup.
+        """
+        LOG.warning("Creating reverse search table")
+        rsn = sa.Table('reverse_search_name', self.dest.t.meta,
+                       sa.Column('word', sa.Integer()),
+                       sa.Column('column', sa.Text()),
+                       sa.Column('places', IntArray))
+        await self.dest.connection.run_sync(rsn.create)
+
+        tsrc = self.src.t.search_name
+        for column in ('name_vector', 'nameaddress_vector'):
+            sql = sa.select(sa.func.unnest(getattr(tsrc.c, column)).label('word'),
+                            sa.func.ArrayAgg(tsrc.c.place_id).label('places'))\
+                    .group_by('word')
+
+            async_result = await self.src.connection.stream(sql)
+            async for partition in async_result.partitions(100):
+                data = []
+                for row in partition:
+                    row.places.sort()
+                    data.append({'word': row.word,
+                                 'column': column,
+                                 'places': row.places})
+                await self.dest.execute(rsn.insert(), data)
+
+        await self.dest.connection.run_sync(
+            sa.Index('idx_reverse_search_name_word', rsn.c.word).create)
+
+
     def select_from(self, table: str) -> SaSelect:
         """ Create the SQL statement to select the source columns and rows.
         """
index cb620d41fb8f31126fe69a622bf14130e38494d1..de7e6a4aa2018c06e7284b4120973351b8a04ea5 100644 (file)
@@ -23,7 +23,8 @@ from nominatim.db.async_connection import DBConnection
 from nominatim.db.sql_preprocessor import SQLPreprocessor
 from nominatim.tools.exec_utils import run_osm2pgsql
 from nominatim.errors import UsageError
-from nominatim.version import POSTGRESQL_REQUIRED_VERSION, POSTGIS_REQUIRED_VERSION
+from nominatim.version import POSTGRESQL_REQUIRED_VERSION, \
+                              POSTGIS_REQUIRED_VERSION
 
 LOG = logging.getLogger()
 
@@ -38,6 +39,25 @@ def _require_version(module: str, actual: Tuple[int, int], expected: Tuple[int,
         raise UsageError(f'{module} is too old.')
 
 
+def _require_loaded(extension_name: str, conn: Connection) -> None:
+    """ Check that the given extension is loaded. """
+    if not conn.extension_loaded(extension_name):
+        LOG.fatal('Required extension %s is not loaded.', extension_name)
+        raise UsageError(f'{extension_name} is not loaded.')
+
+
+def check_existing_database_plugins(dsn: str) -> None:
+    """ Check that the database has the required plugins installed."""
+    with connect(dsn) as conn:
+        _require_version('PostgreSQL server',
+                         conn.server_version_tuple(),
+                         POSTGRESQL_REQUIRED_VERSION)
+        _require_version('PostGIS',
+                         conn.postgis_version_tuple(),
+                         POSTGIS_REQUIRED_VERSION)
+        _require_loaded('hstore', conn)
+
+
 def setup_database_skeleton(dsn: str, rouser: Optional[str] = None) -> None:
     """ Create a new database for Nominatim and populate it with the
         essential extensions.
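A hedged usage sketch: this is the check the import command now runs before
continuing on a pre-created database (DSN illustrative):

    from nominatim.tools.database_import import check_existing_database_plugins

    # Raises UsageError when PostgreSQL or PostGIS are too old
    # or the hstore extension is missing.
    check_existing_database_plugins('dbname=nominatim')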
index 7274f1d396f8159b714c80fff14fd25b3455b345..62ecd8c3e169ce7340dca7c6eb6a83a7881cd3d5 100644 (file)
@@ -72,3 +72,4 @@ SaLabel: TypeAlias = 'sa.Label[Any]'
 SaFromClause: TypeAlias = 'sa.FromClause'
 SaSelectable: TypeAlias = 'sa.Selectable'
 SaBind: TypeAlias = 'sa.BindParameter[Any]'
+SaDialect: TypeAlias = 'sa.Dialect'
index bb642233e78d8c4234afb7e3bc54c4cd0f69cd8b..fcc355d5eee9fa2331ad47e47eb2a4ac18dac078 100644 (file)
@@ -76,8 +76,8 @@ class JsonWriter:
     def end_array(self) -> 'JsonWriter':
         """ Write the closing bracket of a JSON array.
         """
-        assert self.pending in (',', '[', '')
-        if self.pending == '[':
+        assert self.pending in (',', '[', ']', ')', '')
+        if self.pending not in (',', ''):
             self.data.write(self.pending)
         self.pending = ']'
         return self
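The relaxed assertion allows closing an array whose most recent pending token is
itself a closing bracket, e.g. an empty array nested in another one (start_array()
assumed as the counterpart of end_array()):

    from nominatim.utils.json_writer import JsonWriter

    writer = JsonWriter()
    writer.start_array().start_array().end_array().end_array()
    # Previously the outer end_array() tripped the assertion because
    # self.pending was ']' at that point.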
index b0ef92dacf6df1a1dd921ac7c7e3369e9dc35886..271ec10c16c81b991c7daf6d041f8ae68ad7077f 100644 (file)
@@ -1,3 +1,4 @@
+@SQLITE
 @APIDB
 Feature: Parameters for Search API
     Testing correctness of geocodejson output.
index b76adbef5bae66b9dc40ac8d5d0de1e254da90ee..fe14cdbe6c81e9967f95020bef74440f98ce4e81 100644 (file)
@@ -1,3 +1,4 @@
+@SQLITE
 @APIDB
 Feature: Localization of search results
 
index d5512f5b6640163a311bfad430f5b209ad604664..e667b690b0d8f2572498c6ceefbf776f86b8422e 100644 (file)
@@ -1,3 +1,4 @@
+@SQLITE
 @APIDB
 Feature: Search queries
     Testing different queries and parameters
index 81836efb57535221e7645ae041b3700cafe6df99..e372f449a95a882053d6c0ddaf92a5dbae95e751 100644 (file)
@@ -1,3 +1,4 @@
+@SQLITE
 @APIDB
 Feature: Searches with postcodes
     Various searches involving postcodes
index 847f1dbf02823aff211fdfa073b65be4a042380a..eba903ea3058ecee883778cd96f609f99766b6cc 100644 (file)
@@ -1,3 +1,4 @@
+@SQLITE
 @APIDB
 Feature: Search queries
     Generic search result correctness
index 11cd4801beb9b140b0535ef548dba9a1890db281..121271cdf1abd3eb04be9b9715c70de7715a15bb 100644 (file)
@@ -1,3 +1,4 @@
+@SQLITE
 @APIDB
 Feature: Simple Tests
     Simple tests for internal server errors and response format.
index 517c0eddd229c16e2a2d33057783d9f280a3dad7..a1dd5b83d4621b07dd670f3311dbf8d165cc0cce 100644 (file)
@@ -1,3 +1,4 @@
+@SQLITE
 @APIDB
 Feature: Structured search queries
     Testing correctness of results with
index cb7f324a393fa24e2ddb097e710b92662fc96bd2..05eaddf5fc0f182cfc501504e48b8865fdb9af95 100644 (file)
@@ -16,6 +16,7 @@ import sqlalchemy as sa
 
 import nominatim.api as napi
 from nominatim.db.sql_preprocessor import SQLPreprocessor
+from nominatim.api.search.query_analyzer_factory import make_query_analyzer
 from nominatim.tools import convert_sqlite
 import nominatim.api.logging as loglib
 
@@ -160,6 +161,22 @@ class APITester:
                                      """)))
 
 
+    def add_word_table(self, content):
+        data = [dict(zip(['word_id', 'word_token', 'type', 'word', 'info'], c))
+                for c in content]
+
+        async def _do_sql():
+            async with self.api._async_api.begin() as conn:
+                if 'word' not in conn.t.meta.tables:
+                    await make_query_analyzer(conn)
+                    word_table = conn.t.meta.tables['word']
+                    await conn.connection.run_sync(word_table.create)
+                if data:
+                    await conn.execute(conn.t.meta.tables['word'].insert(), data)
+
+        self.async_to_sync(_do_sql())
+
+
     async def exec_async(self, sql, *args, **kwargs):
         async with self.api._async_api.begin() as conn:
             return await conn.execute(sql, *args, **kwargs)
@@ -190,17 +207,40 @@ def apiobj(temp_db_with_extensions, temp_db_conn, monkeypatch):
 
 @pytest.fixture(params=['postgres_db', 'sqlite_db'])
 def frontend(request, event_loop, tmp_path):
+    testapis = []
     if request.param == 'sqlite_db':
         db = str(tmp_path / 'test_nominatim_python_unittest.sqlite')
 
         def mkapi(apiobj, options={'reverse'}):
+            apiobj.add_data('properties',
+                        [{'property': 'tokenizer', 'value': 'icu'},
+                         {'property': 'tokenizer_import_normalisation', 'value': ':: lower();'},
+                         {'property': 'tokenizer_import_transliteration', 'value': "'1' > '/1/'; 'ä' > 'ä '"},
+                        ])
+
+            async def _do_sql():
+                async with apiobj.api._async_api.begin() as conn:
+                    if 'word' in conn.t.meta.tables:
+                        return
+                    await make_query_analyzer(conn)
+                    word_table = conn.t.meta.tables['word']
+                    await conn.connection.run_sync(word_table.create)
+
+            apiobj.async_to_sync(_do_sql())
+
             event_loop.run_until_complete(convert_sqlite.convert(Path('/invalid'),
                                                                  db, options))
-            return napi.NominatimAPI(Path('/invalid'),
-                                     {'NOMINATIM_DATABASE_DSN': f"sqlite:dbname={db}",
-                                      'NOMINATIM_USE_US_TIGER_DATA': 'yes'})
+            outapi = napi.NominatimAPI(Path('/invalid'),
+                                       {'NOMINATIM_DATABASE_DSN': f"sqlite:dbname={db}",
+                                        'NOMINATIM_USE_US_TIGER_DATA': 'yes'})
+            testapis.append(outapi)
+
+            return outapi
     elif request.param == 'postgres_db':
         def mkapi(apiobj, options=None):
             return apiobj.api
 
-    return mkapi
+    yield mkapi
+
+    for api in testapis:
+        api.close()
index 87d75261528283574aae5d6a83b09d5645ac406e..d3aea90002740d7660e12a4b210bf4cb41344c60 100644 (file)
@@ -420,8 +420,8 @@ def test_infrequent_partials_in_name():
     assert len(search.lookups) == 2
     assert len(search.rankings) == 2
 
-    assert set((l.column, l.lookup_type) for l in search.lookups) == \
-            {('name_vector', 'lookup_all'), ('nameaddress_vector', 'restrict')}
+    assert set((l.column, l.lookup_type.__name__) for l in search.lookups) == \
+            {('name_vector', 'LookupAll'), ('nameaddress_vector', 'Restrict')}
 
 
 def test_frequent_partials_in_name_and_address():
@@ -432,10 +432,10 @@ def test_frequent_partials_in_name_and_address():
     assert all(isinstance(s, dbs.PlaceSearch) for s in searches)
     searches.sort(key=lambda s: s.penalty)
 
-    assert set((l.column, l.lookup_type) for l in searches[0].lookups) == \
-            {('name_vector', 'lookup_any'), ('nameaddress_vector', 'restrict')}
-    assert set((l.column, l.lookup_type) for l in searches[1].lookups) == \
-            {('nameaddress_vector', 'lookup_all'), ('name_vector', 'lookup_all')}
+    assert set((l.column, l.lookup_type.__name__) for l in searches[0].lookups) == \
+            {('name_vector', 'LookupAny'), ('nameaddress_vector', 'Restrict')}
+    assert set((l.column, l.lookup_type.__name__) for l in searches[1].lookups) == \
+            {('nameaddress_vector', 'LookupAll'), ('name_vector', 'LookupAll')}
 
 
 def test_too_frequent_partials_in_name_and_address():
@@ -446,5 +446,5 @@ def test_too_frequent_partials_in_name_and_address():
     assert all(isinstance(s, dbs.PlaceSearch) for s in searches)
     searches.sort(key=lambda s: s.penalty)
 
-    assert set((l.column, l.lookup_type) for l in searches[0].lookups) == \
-            {('name_vector', 'lookup_any'), ('nameaddress_vector', 'restrict')}
+    assert set((l.column, l.lookup_type.__name__) for l in searches[0].lookups) == \
+            {('name_vector', 'LookupAny'), ('nameaddress_vector', 'Restrict')}
index 82b1d37fe30ba52c4d1bd90b0415a10099815893..dc87d313a0ff98b1e6b776600bc21889bd0f7c27 100644 (file)
@@ -15,7 +15,7 @@ from nominatim.api.search.db_searches import CountrySearch
 from nominatim.api.search.db_search_fields import WeightedStrings
 
 
-def run_search(apiobj, global_penalty, ccodes,
+def run_search(apiobj, frontend, global_penalty, ccodes,
                country_penalties=None, details=SearchDetails()):
     if country_penalties is None:
         country_penalties = [0.0] * len(ccodes)
@@ -25,15 +25,16 @@ def run_search(apiobj, global_penalty, ccodes,
         countries = WeightedStrings(ccodes, country_penalties)
 
     search = CountrySearch(MySearchData())
+    api = frontend(apiobj, options=['search'])
 
     async def run():
-        async with apiobj.api._async_api.begin() as conn:
+        async with api._async_api.begin() as conn:
             return await search.lookup(conn, details)
 
-    return apiobj.async_to_sync(run())
+    return api._loop.run_until_complete(run())
 
 
-def test_find_from_placex(apiobj):
+def test_find_from_placex(apiobj, frontend):
     apiobj.add_placex(place_id=55, class_='boundary', type='administrative',
                       rank_search=4, rank_address=4,
                       name={'name': 'Lolaland'},
@@ -41,32 +42,32 @@ def test_find_from_placex(apiobj):
                       centroid=(10, 10),
                       geometry='POLYGON((9.5 9.5, 9.5 10.5, 10.5 10.5, 10.5 9.5, 9.5 9.5))')
 
-    results = run_search(apiobj, 0.5, ['de', 'yw'], [0.0, 0.3])
+    results = run_search(apiobj, frontend, 0.5, ['de', 'yw'], [0.0, 0.3])
 
     assert len(results) == 1
     assert results[0].place_id == 55
     assert results[0].accuracy == 0.8
 
-def test_find_from_fallback_countries(apiobj):
+def test_find_from_fallback_countries(apiobj, frontend):
     apiobj.add_country('ro', 'POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))')
     apiobj.add_country_name('ro', {'name': 'România'})
 
-    results = run_search(apiobj, 0.0, ['ro'])
+    results = run_search(apiobj, frontend, 0.0, ['ro'])
 
     assert len(results) == 1
     assert results[0].names == {'name': 'România'}
 
 
-def test_find_none(apiobj):
-    assert len(run_search(apiobj, 0.0, ['xx'])) == 0
+def test_find_none(apiobj, frontend):
+    assert len(run_search(apiobj, frontend, 0.0, ['xx'])) == 0
 
 
 @pytest.mark.parametrize('coord,numres', [((0.5, 1), 1), ((10, 10), 0)])
-def test_find_near(apiobj, coord, numres):
+def test_find_near(apiobj, frontend, coord, numres):
     apiobj.add_country('ro', 'POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))')
     apiobj.add_country_name('ro', {'name': 'România'})
 
-    results = run_search(apiobj, 0.0, ['ro'],
+    results = run_search(apiobj, frontend, 0.0, ['ro'],
                          details=SearchDetails(near=napi.Point(*coord),
                                                near_radius=0.1))
 
@@ -92,8 +93,8 @@ class TestCountryParameters:
                                       napi.GeometryFormat.SVG,
                                       napi.GeometryFormat.TEXT])
     @pytest.mark.parametrize('cc', ['yw', 'ro'])
-    def test_return_geometries(self, apiobj, geom, cc):
-        results = run_search(apiobj, 0.5, [cc],
+    def test_return_geometries(self, apiobj, frontend, geom, cc):
+        results = run_search(apiobj, frontend, 0.5, [cc],
                              details=SearchDetails(geometry_output=geom))
 
         assert len(results) == 1
@@ -101,8 +102,8 @@ class TestCountryParameters:
 
 
     @pytest.mark.parametrize('pid,rids', [(76, [55]), (55, [])])
-    def test_exclude_place_id(self, apiobj, pid, rids):
-        results = run_search(apiobj, 0.5, ['yw', 'ro'],
+    def test_exclude_place_id(self, apiobj, frontend, pid, rids):
+        results = run_search(apiobj, frontend, 0.5, ['yw', 'ro'],
                              details=SearchDetails(excluded=[pid]))
 
         assert [r.place_id for r in results] == rids
@@ -110,8 +111,8 @@ class TestCountryParameters:
 
     @pytest.mark.parametrize('viewbox,rids', [((9, 9, 11, 11), [55]),
                                               ((-10, -10, -3, -3), [])])
-    def test_bounded_viewbox_in_placex(self, apiobj, viewbox, rids):
-        results = run_search(apiobj, 0.5, ['yw'],
+    def test_bounded_viewbox_in_placex(self, apiobj, frontend, viewbox, rids):
+        results = run_search(apiobj, frontend, 0.5, ['yw'],
                              details=SearchDetails.from_kwargs({'viewbox': viewbox,
                                                                 'bounded_viewbox': True}))
 
@@ -120,8 +121,8 @@ class TestCountryParameters:
 
     @pytest.mark.parametrize('viewbox,numres', [((0, 0, 1, 1), 1),
                                               ((-10, -10, -3, -3), 0)])
-    def test_bounded_viewbox_in_fallback(self, apiobj, viewbox, numres):
-        results = run_search(apiobj, 0.5, ['ro'],
+    def test_bounded_viewbox_in_fallback(self, apiobj, frontend, viewbox, numres):
+        results = run_search(apiobj, frontend, 0.5, ['ro'],
                              details=SearchDetails.from_kwargs({'viewbox': viewbox,
                                                                 'bounded_viewbox': True}))
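
Throughout this file, run_search gains a frontend parameter and drives the event loop via api._loop.run_until_complete() instead of the removed apiobj.async_to_sync() helper. The frontend fixture itself is presumably defined in the shared test fixtures (test/python/api/conftest.py, also touched by this commit); a hypothetical sketch of its shape, assuming only the attributes the tests actually use:

    import pytest

    @pytest.fixture
    def frontend(request):
        # Hypothetical factory: builds an API object restricted to the given
        # option set and exposing the private attributes used by run_search,
        # ._async_api and ._loop. The real fixture may differ substantially.
        def mkapi(apiobj, options=None):
            return apiobj.api
        return mkapi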
 
index 2a0acb745969a777a75856f8cc002ea7e33da91f..5b60dd51d59c9626906591591d1b326d73c3ddb1 100644 (file)
@@ -14,9 +14,10 @@ from nominatim.api.types import SearchDetails
 from nominatim.api.search.db_searches import NearSearch, PlaceSearch
 from nominatim.api.search.db_search_fields import WeightedStrings, WeightedCategories,\
                                                   FieldLookup, FieldRanking, RankedTokens
+from nominatim.api.search.db_search_lookups import LookupAll
 
 
-def run_search(apiobj, global_penalty, cat, cat_penalty=None, ccodes=[],
+def run_search(apiobj, frontend, global_penalty, cat, cat_penalty=None, ccodes=[],
                details=SearchDetails()):
 
     class PlaceSearchData:
@@ -25,7 +26,7 @@ def run_search(apiobj, global_penalty, cat, cat_penalty=None, ccodes=[],
         countries = WeightedStrings(ccodes, [0.0] * len(ccodes))
         housenumbers = WeightedStrings([], [])
         qualifiers = WeightedStrings([], [])
-        lookups = [FieldLookup('name_vector', [56], 'lookup_all')]
+        lookups = [FieldLookup('name_vector', [56], LookupAll)]
         rankings = []
 
     if ccodes is not None:
@@ -38,21 +39,23 @@ def run_search(apiobj, global_penalty, cat, cat_penalty=None, ccodes=[],
 
     near_search = NearSearch(0.1, WeightedCategories(cat, cat_penalty), place_search)
 
+    api = frontend(apiobj, options=['search'])
+
     async def run():
-        async with apiobj.api._async_api.begin() as conn:
+        async with api._async_api.begin() as conn:
             return await near_search.lookup(conn, details)
 
-    results = apiobj.async_to_sync(run())
+    results = api._loop.run_until_complete(run())
     results.sort(key=lambda r: r.accuracy)
 
     return results
 
 
-def test_no_results_inner_query(apiobj):
-    assert not run_search(apiobj, 0.4, [('this', 'that')])
+def test_no_results_inner_query(apiobj, frontend):
+    assert not run_search(apiobj, frontend, 0.4, [('this', 'that')])
 
 
-def test_no_appropriate_results_inner_query(apiobj):
+def test_no_appropriate_results_inner_query(apiobj, frontend):
     apiobj.add_placex(place_id=100, country_code='us',
                       centroid=(5.6, 4.3),
                       geometry='POLYGON((0.0 0.0, 10.0 0.0, 10.0 2.0, 0.0 2.0, 0.0 0.0))')
@@ -61,7 +64,7 @@ def test_no_appropriate_results_inner_query(apiobj):
     apiobj.add_placex(place_id=22, class_='amenity', type='bank',
                       centroid=(5.6001, 4.2994))
 
-    assert not run_search(apiobj, 0.4, [('amenity', 'bank')])
+    assert not run_search(apiobj, frontend, 0.4, [('amenity', 'bank')])
 
 
 class TestNearSearch:
@@ -78,18 +81,18 @@ class TestNearSearch:
                                centroid=(-10.3, 56.9))
 
 
-    def test_near_in_placex(self, apiobj):
+    def test_near_in_placex(self, apiobj, frontend):
         apiobj.add_placex(place_id=22, class_='amenity', type='bank',
                           centroid=(5.6001, 4.2994))
         apiobj.add_placex(place_id=23, class_='amenity', type='bench',
                           centroid=(5.6001, 4.2994))
 
-        results = run_search(apiobj, 0.1, [('amenity', 'bank')])
+        results = run_search(apiobj, frontend, 0.1, [('amenity', 'bank')])
 
         assert [r.place_id for r in results] == [22]
 
 
-    def test_multiple_types_near_in_placex(self, apiobj):
+    def test_multiple_types_near_in_placex(self, apiobj, frontend):
         apiobj.add_placex(place_id=22, class_='amenity', type='bank',
                           importance=0.002,
                           centroid=(5.6001, 4.2994))
@@ -97,13 +100,13 @@ class TestNearSearch:
                           importance=0.001,
                           centroid=(5.6001, 4.2994))
 
-        results = run_search(apiobj, 0.1, [('amenity', 'bank'),
-                                           ('amenity', 'bench')])
+        results = run_search(apiobj, frontend, 0.1, [('amenity', 'bank'),
+                                                     ('amenity', 'bench')])
 
         assert [r.place_id for r in results] == [22, 23]
 
 
-    def test_near_in_classtype(self, apiobj):
+    def test_near_in_classtype(self, apiobj, frontend):
         apiobj.add_placex(place_id=22, class_='amenity', type='bank',
                           centroid=(5.6, 4.34))
         apiobj.add_placex(place_id=23, class_='amenity', type='bench',
@@ -111,13 +114,13 @@ class TestNearSearch:
         apiobj.add_class_type_table('amenity', 'bank')
         apiobj.add_class_type_table('amenity', 'bench')
 
-        results = run_search(apiobj, 0.1, [('amenity', 'bank')])
+        results = run_search(apiobj, frontend, 0.1, [('amenity', 'bank')])
 
         assert [r.place_id for r in results] == [22]
 
 
     @pytest.mark.parametrize('cc,rid', [('us', 22), ('mx', 23)])
-    def test_restrict_by_country(self, apiobj, cc, rid):
+    def test_restrict_by_country(self, apiobj, frontend, cc, rid):
         apiobj.add_placex(place_id=22, class_='amenity', type='bank',
                           centroid=(5.6001, 4.2994),
                           country_code='us')
@@ -131,13 +134,13 @@ class TestNearSearch:
                           centroid=(-10.3001, 56.9),
                           country_code='us')
 
-        results = run_search(apiobj, 0.1, [('amenity', 'bank')], ccodes=[cc, 'fr'])
+        results = run_search(apiobj, frontend, 0.1, [('amenity', 'bank')], ccodes=[cc, 'fr'])
 
         assert [r.place_id for r in results] == [rid]
 
 
     @pytest.mark.parametrize('excluded,rid', [(22, 122), (122, 22)])
-    def test_exclude_place_by_id(self, apiobj, excluded, rid):
+    def test_exclude_place_by_id(self, apiobj, frontend, excluded, rid):
         apiobj.add_placex(place_id=22, class_='amenity', type='bank',
                           centroid=(5.6001, 4.2994),
                           country_code='us')
@@ -146,7 +149,7 @@ class TestNearSearch:
                           country_code='us')
 
 
-        results = run_search(apiobj, 0.1, [('amenity', 'bank')],
+        results = run_search(apiobj, frontend, 0.1, [('amenity', 'bank')],
                              details=SearchDetails(excluded=[excluded]))
 
         assert [r.place_id for r in results] == [rid]
@@ -154,12 +157,12 @@ class TestNearSearch:
 
     @pytest.mark.parametrize('layer,rids', [(napi.DataLayer.POI, [22]),
                                             (napi.DataLayer.MANMADE, [])])
-    def test_with_layer(self, apiobj, layer, rids):
+    def test_with_layer(self, apiobj, frontend, layer, rids):
         apiobj.add_placex(place_id=22, class_='amenity', type='bank',
                           centroid=(5.6001, 4.2994),
                           country_code='us')
 
-        results = run_search(apiobj, 0.1, [('amenity', 'bank')],
+        results = run_search(apiobj, frontend, 0.1, [('amenity', 'bank')],
                              details=SearchDetails(layers=layer))
 
         assert [r.place_id for r in results] == rids
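
For reference, the migration pattern repeated through this file: the third argument of FieldLookup changes from a string tag to a class imported from the new lookup module, as in the helper above:

    from nominatim.api.search.db_search_fields import FieldLookup
    from nominatim.api.search.db_search_lookups import LookupAll

    # all listed tokens (here: 56) must match in the name_vector column
    lookup = FieldLookup('name_vector', [56], LookupAll)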
index 8a363e97735b585aee1372ea6d87d05a3a12a17e..c446a35f88c8ecb533053c682663689a1d9de689 100644 (file)
@@ -16,8 +16,11 @@ from nominatim.api.types import SearchDetails
 from nominatim.api.search.db_searches import PlaceSearch
 from nominatim.api.search.db_search_fields import WeightedStrings, WeightedCategories,\
                                                   FieldLookup, FieldRanking, RankedTokens
+from nominatim.api.search.db_search_lookups import LookupAll, LookupAny, Restrict
 
-def run_search(apiobj, global_penalty, lookup, ranking, count=2,
+APIOPTIONS = ['search']
+
+def run_search(apiobj, frontend, global_penalty, lookup, ranking, count=2,
                hnrs=[], pcs=[], ccodes=[], quals=[],
                details=SearchDetails()):
     class MySearchData:
@@ -31,11 +34,16 @@ def run_search(apiobj, global_penalty, lookup, ranking, count=2,
 
     search = PlaceSearch(0.0, MySearchData(), count)
 
+    if frontend is None:
+        api = apiobj
+    else:
+        api = frontend(apiobj, options=APIOPTIONS)
+
     async def run():
-        async with apiobj.api._async_api.begin() as conn:
+        async with api._async_api.begin() as conn:
             return await search.lookup(conn, details)
 
-    results = apiobj.async_to_sync(run())
+    results = api._loop.run_until_complete(run())
     results.sort(key=lambda r: r.accuracy)
 
     return results
@@ -55,64 +63,64 @@ class TestNameOnlySearches:
                                centroid=(-10.3, 56.9))
 
 
-    @pytest.mark.parametrize('lookup_type', ['lookup_all', 'restrict'])
+    @pytest.mark.parametrize('lookup_type', [LookupAll, Restrict])
     @pytest.mark.parametrize('rank,res', [([10], [100, 101]),
                                           ([20], [101, 100])])
-    def test_lookup_all_match(self, apiobj, lookup_type, rank, res):
+    def test_lookup_all_match(self, apiobj, frontend, lookup_type, rank, res):
         lookup = FieldLookup('name_vector', [1,2], lookup_type)
         ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, rank)])
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking])
+        results = run_search(apiobj, frontend, 0.1, [lookup], [ranking])
 
         assert [r.place_id for r in results] == res
 
 
-    @pytest.mark.parametrize('lookup_type', ['lookup_all', 'restrict'])
-    def test_lookup_all_partial_match(self, apiobj, lookup_type):
+    @pytest.mark.parametrize('lookup_type', [LookupAll, Restrict])
+    def test_lookup_all_partial_match(self, apiobj, frontend, lookup_type):
         lookup = FieldLookup('name_vector', [1,20], lookup_type)
         ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])])
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking])
+        results = run_search(apiobj, frontend, 0.1, [lookup], [ranking])
 
         assert len(results) == 1
         assert results[0].place_id == 101
 
     @pytest.mark.parametrize('rank,res', [([10], [100, 101]),
                                           ([20], [101, 100])])
-    def test_lookup_any_match(self, apiobj, rank, res):
-        lookup = FieldLookup('name_vector', [11,21], 'lookup_any')
+    def test_lookup_any_match(self, apiobj, frontend, rank, res):
+        lookup = FieldLookup('name_vector', [11,21], LookupAny)
         ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, rank)])
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking])
+        results = run_search(apiobj, frontend, 0.1, [lookup], [ranking])
 
         assert [r.place_id for r in results] == res
 
 
-    def test_lookup_any_partial_match(self, apiobj):
-        lookup = FieldLookup('name_vector', [20], 'lookup_all')
+    def test_lookup_any_partial_match(self, apiobj, frontend):
+        lookup = FieldLookup('name_vector', [20], LookupAll)
         ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])])
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking])
+        results = run_search(apiobj, frontend, 0.1, [lookup], [ranking])
 
         assert len(results) == 1
         assert results[0].place_id == 101
 
 
     @pytest.mark.parametrize('cc,res', [('us', 100), ('mx', 101)])
-    def test_lookup_restrict_country(self, apiobj, cc, res):
-        lookup = FieldLookup('name_vector', [1,2], 'lookup_all')
+    def test_lookup_restrict_country(self, apiobj, frontend, cc, res):
+        lookup = FieldLookup('name_vector', [1,2], LookupAll)
         ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [10])])
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking], ccodes=[cc])
+        results = run_search(apiobj, frontend, 0.1, [lookup], [ranking], ccodes=[cc])
 
         assert [r.place_id for r in results] == [res]
 
 
-    def test_lookup_restrict_placeid(self, apiobj):
-        lookup = FieldLookup('name_vector', [1,2], 'lookup_all')
+    def test_lookup_restrict_placeid(self, apiobj, frontend):
+        lookup = FieldLookup('name_vector', [1,2], LookupAll)
         ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [10])])
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking],
+        results = run_search(apiobj, frontend, 0.1, [lookup], [ranking],
                              details=SearchDetails(excluded=[101]))
 
         assert [r.place_id for r in results] == [100]
@@ -122,28 +130,28 @@ class TestNameOnlySearches:
                                       napi.GeometryFormat.KML,
                                       napi.GeometryFormat.SVG,
                                       napi.GeometryFormat.TEXT])
-    def test_return_geometries(self, apiobj, geom):
-        lookup = FieldLookup('name_vector', [20], 'lookup_all')
+    def test_return_geometries(self, apiobj, frontend, geom):
+        lookup = FieldLookup('name_vector', [20], LookupAll)
         ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])])
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking],
+        results = run_search(apiobj, frontend, 0.1, [lookup], [ranking],
                              details=SearchDetails(geometry_output=geom))
 
         assert geom.name.lower() in results[0].geometry
 
 
     @pytest.mark.parametrize('factor,npoints', [(0.0, 3), (1.0, 2)])
-    def test_return_simplified_geometry(self, apiobj, factor, npoints):
+    def test_return_simplified_geometry(self, apiobj, frontend, factor, npoints):
         apiobj.add_placex(place_id=333, country_code='us',
                           centroid=(9.0, 9.0),
                           geometry='LINESTRING(8.9 9.0, 9.0 9.0, 9.1 9.0)')
         apiobj.add_search_name(333, names=[55], country_code='us',
                                centroid=(5.6, 4.3))
 
-        lookup = FieldLookup('name_vector', [55], 'lookup_all')
+        lookup = FieldLookup('name_vector', [55], LookupAll)
         ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])])
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking],
+        results = run_search(apiobj, frontend, 0.1, [lookup], [ranking],
                              details=SearchDetails(geometry_output=napi.GeometryFormat.GEOJSON,
                                                    geometry_simplification=factor))
 
@@ -157,50 +165,52 @@ class TestNameOnlySearches:
 
     @pytest.mark.parametrize('viewbox', ['5.0,4.0,6.0,5.0', '5.7,4.0,6.0,5.0'])
     @pytest.mark.parametrize('wcount,rids', [(2, [100, 101]), (20000, [100])])
-    def test_prefer_viewbox(self, apiobj, viewbox, wcount, rids):
-        lookup = FieldLookup('name_vector', [1, 2], 'lookup_all')
+    def test_prefer_viewbox(self, apiobj, frontend, viewbox, wcount, rids):
+        lookup = FieldLookup('name_vector', [1, 2], LookupAll)
         ranking = FieldRanking('name_vector', 0.2, [RankedTokens(0.0, [21])])
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking])
+        api = frontend(apiobj, options=APIOPTIONS)
+        results = run_search(api, None, 0.1, [lookup], [ranking])
         assert [r.place_id for r in results] == [101, 100]
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking], count=wcount,
+        results = run_search(api, None, 0.1, [lookup], [ranking], count=wcount,
                              details=SearchDetails.from_kwargs({'viewbox': viewbox}))
         assert [r.place_id for r in results] == rids
 
 
     @pytest.mark.parametrize('viewbox', ['5.0,4.0,6.0,5.0', '5.55,4.27,5.62,4.31'])
-    def test_force_viewbox(self, apiobj, viewbox):
-        lookup = FieldLookup('name_vector', [1, 2], 'lookup_all')
+    def test_force_viewbox(self, apiobj, frontend, viewbox):
+        lookup = FieldLookup('name_vector', [1, 2], LookupAll)
 
         details=SearchDetails.from_kwargs({'viewbox': viewbox,
                                            'bounded_viewbox': True})
 
-        results = run_search(apiobj, 0.1, [lookup], [], details=details)
+        results = run_search(apiobj, frontend, 0.1, [lookup], [], details=details)
         assert [r.place_id for r in results] == [100]
 
 
-    def test_prefer_near(self, apiobj):
-        lookup = FieldLookup('name_vector', [1, 2], 'lookup_all')
+    def test_prefer_near(self, apiobj, frontend):
+        lookup = FieldLookup('name_vector', [1, 2], LookupAll)
         ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])])
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking])
+        api = frontend(apiobj, options=APIOPTIONS)
+        results = run_search(api, None, 0.1, [lookup], [ranking])
         assert [r.place_id for r in results] == [101, 100]
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking],
+        results = run_search(api, None, 0.1, [lookup], [ranking],
                              details=SearchDetails.from_kwargs({'near': '5.6,4.3'}))
         results.sort(key=lambda r: -r.importance)
         assert [r.place_id for r in results] == [100, 101]
 
 
     @pytest.mark.parametrize('radius', [0.09, 0.11])
-    def test_force_near(self, apiobj, radius):
-        lookup = FieldLookup('name_vector', [1, 2], 'lookup_all')
+    def test_force_near(self, apiobj, frontend, radius):
+        lookup = FieldLookup('name_vector', [1, 2], LookupAll)
 
         details=SearchDetails.from_kwargs({'near': '5.6,4.3',
                                            'near_radius': radius})
 
-        results = run_search(apiobj, 0.1, [lookup], [], details=details)
+        results = run_search(apiobj, frontend, 0.1, [lookup], [], details=details)
 
         assert [r.place_id for r in results] == [100]
 
@@ -241,72 +251,72 @@ class TestStreetWithHousenumber:
     @pytest.mark.parametrize('hnr,res', [('20', [91, 1]), ('20 a', [1]),
                                          ('21', [2]), ('22', [2, 92]),
                                          ('24', [93]), ('25', [])])
-    def test_lookup_by_single_housenumber(self, apiobj, hnr, res):
-        lookup = FieldLookup('name_vector', [1,2], 'lookup_all')
+    def test_lookup_by_single_housenumber(self, apiobj, frontend, hnr, res):
+        lookup = FieldLookup('name_vector', [1,2], LookupAll)
         ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=[hnr])
+        results = run_search(apiobj, frontend, 0.1, [lookup], [ranking], hnrs=[hnr])
 
         assert [r.place_id for r in results] == res + [1000, 2000]
 
 
     @pytest.mark.parametrize('cc,res', [('es', [2, 1000]), ('pt', [92, 2000])])
-    def test_lookup_with_country_restriction(self, apiobj, cc, res):
-        lookup = FieldLookup('name_vector', [1,2], 'lookup_all')
+    def test_lookup_with_country_restriction(self, apiobj, frontend, cc, res):
+        lookup = FieldLookup('name_vector', [1,2], LookupAll)
         ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'],
+        results = run_search(apiobj, frontend, 0.1, [lookup], [ranking], hnrs=['22'],
                              ccodes=[cc])
 
         assert [r.place_id for r in results] == res
 
 
-    def test_lookup_exclude_housenumber_placeid(self, apiobj):
-        lookup = FieldLookup('name_vector', [1,2], 'lookup_all')
+    def test_lookup_exclude_housenumber_placeid(self, apiobj, frontend):
+        lookup = FieldLookup('name_vector', [1,2], LookupAll)
         ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'],
+        results = run_search(apiobj, frontend, 0.1, [lookup], [ranking], hnrs=['22'],
                              details=SearchDetails(excluded=[92]))
 
         assert [r.place_id for r in results] == [2, 1000, 2000]
 
 
-    def test_lookup_exclude_street_placeid(self, apiobj):
-        lookup = FieldLookup('name_vector', [1,2], 'lookup_all')
+    def test_lookup_exclude_street_placeid(self, apiobj, frontend):
+        lookup = FieldLookup('name_vector', [1,2], LookupAll)
         ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'],
+        results = run_search(apiobj, frontend, 0.1, [lookup], [ranking], hnrs=['22'],
                              details=SearchDetails(excluded=[1000]))
 
         assert [r.place_id for r in results] == [2, 92, 2000]
 
 
-    def test_lookup_only_house_qualifier(self, apiobj):
-        lookup = FieldLookup('name_vector', [1,2], 'lookup_all')
+    def test_lookup_only_house_qualifier(self, apiobj, frontend):
+        lookup = FieldLookup('name_vector', [1,2], LookupAll)
         ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'],
+        results = run_search(apiobj, frontend, 0.1, [lookup], [ranking], hnrs=['22'],
                              quals=[('place', 'house')])
 
         assert [r.place_id for r in results] == [2, 92]
 
 
-    def test_lookup_only_street_qualifier(self, apiobj):
-        lookup = FieldLookup('name_vector', [1,2], 'lookup_all')
+    def test_lookup_only_street_qualifier(self, apiobj, frontend):
+        lookup = FieldLookup('name_vector', [1,2], LookupAll)
         ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'],
+        results = run_search(apiobj, frontend, 0.1, [lookup], [ranking], hnrs=['22'],
                              quals=[('highway', 'residential')])
 
         assert [r.place_id for r in results] == [1000, 2000]
 
 
     @pytest.mark.parametrize('rank,found', [(26, True), (27, False), (30, False)])
-    def test_lookup_min_rank(self, apiobj, rank, found):
-        lookup = FieldLookup('name_vector', [1,2], 'lookup_all')
+    def test_lookup_min_rank(self, apiobj, frontend, rank, found):
+        lookup = FieldLookup('name_vector', [1,2], LookupAll)
         ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
 
-        results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'],
+        results = run_search(apiobj, frontend, 0.1, [lookup], [ranking], hnrs=['22'],
                              details=SearchDetails(min_rank=rank))
 
         assert [r.place_id for r in results] == ([2, 92, 1000, 2000] if found else [2, 92])
@@ -316,17 +326,17 @@ class TestStreetWithHousenumber:
                                       napi.GeometryFormat.KML,
                                       napi.GeometryFormat.SVG,
                                       napi.GeometryFormat.TEXT])
-    def test_return_geometries(self, apiobj, geom):
-        lookup = FieldLookup('name_vector', [1, 2], 'lookup_all')
+    def test_return_geometries(self, apiobj, frontend, geom):
+        lookup = FieldLookup('name_vector', [1, 2], LookupAll)
 
-        results = run_search(apiobj, 0.1, [lookup], [], hnrs=['20', '21', '22'],
+        results = run_search(apiobj, frontend, 0.1, [lookup], [], hnrs=['20', '21', '22'],
                              details=SearchDetails(geometry_output=geom))
 
         assert results
         assert all(geom.name.lower() in r.geometry for r in results)
 
 
-def test_very_large_housenumber(apiobj):
+def test_very_large_housenumber(apiobj, frontend):
     apiobj.add_placex(place_id=93, class_='place', type='house',
                       parent_place_id=2000,
                       housenumber='2467463524544', country_code='pt')
@@ -337,9 +347,9 @@ def test_very_large_housenumber(apiobj):
                            search_rank=26, address_rank=26,
                            country_code='pt')
 
-    lookup = FieldLookup('name_vector', [1, 2], 'lookup_all')
+    lookup = FieldLookup('name_vector', [1, 2], LookupAll)
 
-    results = run_search(apiobj, 0.1, [lookup], [], hnrs=['2467463524544'],
+    results = run_search(apiobj, frontend, 0.1, [lookup], [], hnrs=['2467463524544'],
                          details=SearchDetails())
 
     assert results
@@ -347,7 +357,7 @@ def test_very_large_housenumber(apiobj):
 
 
 @pytest.mark.parametrize('wcount,rids', [(2, [990, 991]), (30000, [990])])
-def test_name_and_postcode(apiobj, wcount, rids):
+def test_name_and_postcode(apiobj, frontend, wcount, rids):
     apiobj.add_placex(place_id=990, class_='highway', type='service',
                       rank_search=27, rank_address=27,
                       postcode='11225',
@@ -365,9 +375,9 @@ def test_name_and_postcode(apiobj, wcount, rids):
     apiobj.add_postcode(place_id=100, country_code='ch', postcode='11225',
                         geometry='POINT(10 10)')
 
-    lookup = FieldLookup('name_vector', [111], 'lookup_all')
+    lookup = FieldLookup('name_vector', [111], LookupAll)
 
-    results = run_search(apiobj, 0.1, [lookup], [], pcs=['11225'], count=wcount,
+    results = run_search(apiobj, frontend, 0.1, [lookup], [], pcs=['11225'], count=wcount,
                          details=SearchDetails())
 
     assert results
@@ -397,10 +407,10 @@ class TestInterpolations:
 
 
     @pytest.mark.parametrize('hnr,res', [('21', [992]), ('22', []), ('23', [991])])
-    def test_lookup_housenumber(self, apiobj, hnr, res):
-        lookup = FieldLookup('name_vector', [111], 'lookup_all')
+    def test_lookup_housenumber(self, apiobj, frontend, hnr, res):
+        lookup = FieldLookup('name_vector', [111], LookupAll)
 
-        results = run_search(apiobj, 0.1, [lookup], [], hnrs=[hnr])
+        results = run_search(apiobj, frontend, 0.1, [lookup], [], hnrs=[hnr])
 
         assert [r.place_id for r in results] == res + [990]
 
@@ -409,10 +419,10 @@ class TestInterpolations:
                                       napi.GeometryFormat.KML,
                                       napi.GeometryFormat.SVG,
                                       napi.GeometryFormat.TEXT])
-    def test_osmline_with_geometries(self, apiobj, geom):
-        lookup = FieldLookup('name_vector', [111], 'lookup_all')
+    def test_osmline_with_geometries(self, apiobj, frontend, geom):
+        lookup = FieldLookup('name_vector', [111], LookupAll)
 
-        results = run_search(apiobj, 0.1, [lookup], [], hnrs=['21'],
+        results = run_search(apiobj, frontend, 0.1, [lookup], [], hnrs=['21'],
                              details=SearchDetails(geometry_output=geom))
 
         assert results[0].place_id == 992
@@ -445,10 +455,10 @@ class TestTiger:
 
 
     @pytest.mark.parametrize('hnr,res', [('21', [992]), ('22', []), ('23', [991])])
-    def test_lookup_housenumber(self, apiobj, hnr, res):
-        lookup = FieldLookup('name_vector', [111], 'lookup_all')
+    def test_lookup_housenumber(self, apiobj, frontend, hnr, res):
+        lookup = FieldLookup('name_vector', [111], LookupAll)
 
-        results = run_search(apiobj, 0.1, [lookup], [], hnrs=[hnr])
+        results = run_search(apiobj, frontend, 0.1, [lookup], [], hnrs=[hnr])
 
         assert [r.place_id for r in results] == res + [990]
 
@@ -457,10 +467,10 @@ class TestTiger:
                                       napi.GeometryFormat.KML,
                                       napi.GeometryFormat.SVG,
                                       napi.GeometryFormat.TEXT])
-    def test_tiger_with_geometries(self, apiobj, geom):
-        lookup = FieldLookup('name_vector', [111], 'lookup_all')
+    def test_tiger_with_geometries(self, apiobj, frontend, geom):
+        lookup = FieldLookup('name_vector', [111], LookupAll)
 
-        results = run_search(apiobj, 0.1, [lookup], [], hnrs=['21'],
+        results = run_search(apiobj, frontend, 0.1, [lookup], [], hnrs=['21'],
                              details=SearchDetails(geometry_output=geom))
 
         assert results[0].place_id == 992
@@ -512,10 +522,10 @@ class TestLayersRank30:
                                            (napi.DataLayer.NATURAL, [227]),
                                            (napi.DataLayer.MANMADE | napi.DataLayer.NATURAL, [225, 227]),
                                            (napi.DataLayer.MANMADE | napi.DataLayer.RAILWAY, [225, 226])])
-    def test_layers_rank30(self, apiobj, layer, res):
-        lookup = FieldLookup('name_vector', [34], 'lookup_any')
+    def test_layers_rank30(self, apiobj, frontend, layer, res):
+        lookup = FieldLookup('name_vector', [34], LookupAny)
 
-        results = run_search(apiobj, 0.1, [lookup], [],
+        results = run_search(apiobj, frontend, 0.1, [lookup], [],
                              details=SearchDetails(layers=layer))
 
         assert [r.place_id for r in results] == res
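
Note the frontend=None escape hatch added to run_search in this file: tests that issue several searches against the same database, such as test_prefer_viewbox and test_prefer_near above, build the API once and pass it through directly:

    # build the API once, then reuse it for consecutive searches
    api = frontend(apiobj, options=APIOPTIONS)
    results = run_search(api, None, 0.1, [lookup], [ranking])
    results = run_search(api, None, 0.1, [lookup], [ranking], count=wcount,
                         details=SearchDetails.from_kwargs({'viewbox': viewbox}))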
index b80c075200f4aa3e2a08cb3b2a9b84d569bf0d38..a0b578baffbc07c2b03b7b06dafc4ca5c7e3b007 100644 (file)
@@ -15,7 +15,7 @@ from nominatim.api.search.db_searches import PoiSearch
 from nominatim.api.search.db_search_fields import WeightedStrings, WeightedCategories
 
 
-def run_search(apiobj, global_penalty, poitypes, poi_penalties=None,
+def run_search(apiobj, frontend, global_penalty, poitypes, poi_penalties=None,
                ccodes=[], details=SearchDetails()):
     if poi_penalties is None:
         poi_penalties = [0.0] * len(poitypes)
@@ -27,16 +27,18 @@ def run_search(apiobj, global_penalty, poitypes, poi_penalties=None,
 
     search = PoiSearch(MySearchData())
 
+    api = frontend(apiobj, options=['search'])
+
     async def run():
-        async with apiobj.api._async_api.begin() as conn:
+        async with api._async_api.begin() as conn:
             return await search.lookup(conn, details)
 
-    return apiobj.async_to_sync(run())
+    return api._loop.run_until_complete(run())
 
 
 @pytest.mark.parametrize('coord,pid', [('34.3, 56.100021', 2),
                                        ('5.0, 4.59933', 1)])
-def test_simple_near_search_in_placex(apiobj, coord, pid):
+def test_simple_near_search_in_placex(apiobj, frontend, coord, pid):
     apiobj.add_placex(place_id=1, class_='highway', type='bus_stop',
                       centroid=(5.0, 4.6))
     apiobj.add_placex(place_id=2, class_='highway', type='bus_stop',
@@ -44,7 +46,7 @@ def test_simple_near_search_in_placex(apiobj, coord, pid):
 
     details = SearchDetails.from_kwargs({'near': coord, 'near_radius': 0.001})
 
-    results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5], details=details)
+    results = run_search(apiobj, frontend, 0.1, [('highway', 'bus_stop')], [0.5], details=details)
 
     assert [r.place_id for r in results] == [pid]
 
@@ -52,7 +54,7 @@ def test_simple_near_search_in_placex(apiobj, coord, pid):
 @pytest.mark.parametrize('coord,pid', [('34.3, 56.100021', 2),
                                        ('34.3, 56.4', 2),
                                        ('5.0, 4.59933', 1)])
-def test_simple_near_search_in_classtype(apiobj, coord, pid):
+def test_simple_near_search_in_classtype(apiobj, frontend, coord, pid):
     apiobj.add_placex(place_id=1, class_='highway', type='bus_stop',
                       centroid=(5.0, 4.6))
     apiobj.add_placex(place_id=2, class_='highway', type='bus_stop',
@@ -61,7 +63,7 @@ def test_simple_near_search_in_classtype(apiobj, coord, pid):
 
     details = SearchDetails.from_kwargs({'near': coord, 'near_radius': 0.5})
 
-    results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5], details=details)
+    results = run_search(apiobj, frontend, 0.1, [('highway', 'bus_stop')], [0.5], details=details)
 
     assert [r.place_id for r in results] == [pid]
 
@@ -83,25 +85,25 @@ class TestPoiSearchWithRestrictions:
             self.args = {'near': '34.3, 56.100021', 'near_radius': 0.001}
 
 
-    def test_unrestricted(self, apiobj):
-        results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5],
+    def test_unrestricted(self, apiobj, frontend):
+        results = run_search(apiobj, frontend, 0.1, [('highway', 'bus_stop')], [0.5],
                              details=SearchDetails.from_kwargs(self.args))
 
         assert [r.place_id for r in results] == [1, 2]
 
 
-    def test_restict_country(self, apiobj):
-        results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5],
+    def test_restict_country(self, apiobj, frontend):
+        results = run_search(apiobj, frontend, 0.1, [('highway', 'bus_stop')], [0.5],
                              ccodes=['de', 'nz'],
                              details=SearchDetails.from_kwargs(self.args))
 
         assert [r.place_id for r in results] == [2]
 
 
-    def test_restrict_by_viewbox(self, apiobj):
+    def test_restrict_by_viewbox(self, apiobj, frontend):
         args = {'bounded_viewbox': True, 'viewbox': '34.299,56.0,34.3001,56.10001'}
         args.update(self.args)
-        results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5],
+        results = run_search(apiobj, frontend, 0.1, [('highway', 'bus_stop')], [0.5],
                              ccodes=['de', 'nz'],
                              details=SearchDetails.from_kwargs(args))
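
The same conversion applies here as in the other search test modules; the recurring pattern is that the synchronous wrapper opens a connection on the async engine of the frontend-built API and runs its private loop to completion:

    api = frontend(apiobj, options=['search'])

    async def run():
        # open a connection on the API's async engine and run the search
        async with api._async_api.begin() as conn:
            return await search.lookup(conn, details)

    results = api._loop.run_until_complete(run())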
 
index e7153f38bf8b6147a268d5b051422d0d8d415680..6976b6a592ac5921fbb283c84c411065500793bb 100644 (file)
@@ -15,7 +15,7 @@ from nominatim.api.search.db_searches import PostcodeSearch
 from nominatim.api.search.db_search_fields import WeightedStrings, FieldLookup, \
                                                   FieldRanking, RankedTokens
 
-def run_search(apiobj, global_penalty, pcs, pc_penalties=None,
+def run_search(apiobj, frontend, global_penalty, pcs, pc_penalties=None,
                ccodes=[], lookup=[], ranking=[], details=SearchDetails()):
     if pc_penalties is None:
         pc_penalties = [0.0] * len(pcs)
@@ -29,28 +29,30 @@ def run_search(apiobj, global_penalty, pcs, pc_penalties=None,
 
     search = PostcodeSearch(0.0, MySearchData())
 
+    api = frontend(apiobj, options=['search'])
+
     async def run():
-        async with apiobj.api._async_api.begin() as conn:
+        async with api._async_api.begin() as conn:
             return await search.lookup(conn, details)
 
-    return apiobj.async_to_sync(run())
+    return api._loop.run_until_complete(run())
 
 
-def test_postcode_only_search(apiobj):
+def test_postcode_only_search(apiobj, frontend):
     apiobj.add_postcode(place_id=100, country_code='ch', postcode='12345')
     apiobj.add_postcode(place_id=101, country_code='pl', postcode='12 345')
 
-    results = run_search(apiobj, 0.3, ['12345', '12 345'], [0.0, 0.1])
+    results = run_search(apiobj, frontend, 0.3, ['12345', '12 345'], [0.0, 0.1])
 
     assert len(results) == 2
     assert [r.place_id for r in results] == [100, 101]
 
 
-def test_postcode_with_country(apiobj):
+def test_postcode_with_country(apiobj, frontend):
     apiobj.add_postcode(place_id=100, country_code='ch', postcode='12345')
     apiobj.add_postcode(place_id=101, country_code='pl', postcode='12 345')
 
-    results = run_search(apiobj, 0.3, ['12345', '12 345'], [0.0, 0.1],
+    results = run_search(apiobj, frontend, 0.3, ['12345', '12 345'], [0.0, 0.1],
                          ccodes=['de', 'pl'])
 
     assert len(results) == 1
@@ -81,30 +83,30 @@ class TestPostcodeSearchWithAddress:
                                country_code='pl')
 
 
-    def test_lookup_both(self, apiobj):
+    def test_lookup_both(self, apiobj, frontend):
         lookup = FieldLookup('name_vector', [1,2], 'restrict')
         ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
 
-        results = run_search(apiobj, 0.1, ['12345'], lookup=[lookup], ranking=[ranking])
+        results = run_search(apiobj, frontend, 0.1, ['12345'], lookup=[lookup], ranking=[ranking])
 
         assert [r.place_id for r in results] == [100, 101]
 
 
-    def test_restrict_by_name(self, apiobj):
+    def test_restrict_by_name(self, apiobj, frontend):
         lookup = FieldLookup('name_vector', [10], 'restrict')
 
-        results = run_search(apiobj, 0.1, ['12345'], lookup=[lookup])
+        results = run_search(apiobj, frontend, 0.1, ['12345'], lookup=[lookup])
 
         assert [r.place_id for r in results] == [100]
 
 
     @pytest.mark.parametrize('coord,place_id', [((16.5, 5), 100),
                                                 ((-45.1, 7.004), 101)])
-    def test_lookup_near(self, apiobj, coord, place_id):
+    def test_lookup_near(self, apiobj, frontend, coord, place_id):
         lookup = FieldLookup('name_vector', [1,2], 'restrict')
         ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
 
-        results = run_search(apiobj, 0.1, ['12345'],
+        results = run_search(apiobj, frontend, 0.1, ['12345'],
                              lookup=[lookup], ranking=[ranking],
                              details=SearchDetails(near=napi.Point(*coord),
                                                    near_radius=0.6))
@@ -116,8 +118,8 @@ class TestPostcodeSearchWithAddress:
                                       napi.GeometryFormat.KML,
                                       napi.GeometryFormat.SVG,
                                       napi.GeometryFormat.TEXT])
-    def test_return_geometries(self, apiobj, geom):
-        results = run_search(apiobj, 0.1, ['12345'],
+    def test_return_geometries(self, apiobj, frontend, geom):
+        results = run_search(apiobj, frontend, 0.1, ['12345'],
                              details=SearchDetails(geometry_output=geom))
 
         assert results
@@ -126,8 +128,8 @@ class TestPostcodeSearchWithAddress:
 
     @pytest.mark.parametrize('viewbox, rids', [('-46,6,-44,8', [101,100]),
                                                ('16,4,18,6', [100,101])])
-    def test_prefer_viewbox(self, apiobj, viewbox, rids):
-        results = run_search(apiobj, 0.1, ['12345'],
+    def test_prefer_viewbox(self, apiobj, frontend, viewbox, rids):
+        results = run_search(apiobj, frontend, 0.1, ['12345'],
                              details=SearchDetails.from_kwargs({'viewbox': viewbox}))
 
         assert [r.place_id for r in results] == rids
@@ -135,8 +137,8 @@ class TestPostcodeSearchWithAddress:
 
     @pytest.mark.parametrize('viewbox, rid', [('-46,6,-44,8', 101),
                                                ('16,4,18,6', 100)])
-    def test_restrict_to_viewbox(self, apiobj, viewbox, rid):
-        results = run_search(apiobj, 0.1, ['12345'],
+    def test_restrict_to_viewbox(self, apiobj, frontend, viewbox, rid):
+        results = run_search(apiobj, frontend, 0.1, ['12345'],
                              details=SearchDetails.from_kwargs({'viewbox': viewbox,
                                                                 'bounded_viewbox': True}))
 
@@ -145,16 +147,16 @@ class TestPostcodeSearchWithAddress:
 
     @pytest.mark.parametrize('coord,rids', [((17.05, 5), [100, 101]),
                                             ((-45, 7.1), [101, 100])])
-    def test_prefer_near(self, apiobj, coord, rids):
-        results = run_search(apiobj, 0.1, ['12345'],
+    def test_prefer_near(self, apiobj, frontend, coord, rids):
+        results = run_search(apiobj, frontend, 0.1, ['12345'],
                              details=SearchDetails(near=napi.Point(*coord)))
 
         assert [r.place_id for r in results] == rids
 
 
     @pytest.mark.parametrize('pid,rid', [(100, 101), (101, 100)])
-    def test_exclude(self, apiobj, pid, rid):
-        results = run_search(apiobj, 0.1, ['12345'],
+    def test_exclude(self, apiobj, frontend, pid, rid):
+        results = run_search(apiobj, frontend, 0.1, ['12345'],
                              details=SearchDetails(excluded=[pid]))
 
         assert [r.place_id for r in results] == [rid]
index aa263d24dd67a8a8974004870a36123061c68d93..22dbaa2642d63a99ad227246c2c135139c087bcc 100644 (file)
@@ -19,6 +19,8 @@ import sqlalchemy as sa
 import nominatim.api as napi
 import nominatim.api.logging as loglib
 
+API_OPTIONS = {'search'}
+
 @pytest.fixture(autouse=True)
 def setup_icu_tokenizer(apiobj):
     """ Setup the propoerties needed for using the ICU tokenizer.
@@ -30,66 +32,62 @@ def setup_icu_tokenizer(apiobj):
                     ])
 
 
-def test_search_no_content(apiobj, table_factory):
-    table_factory('word',
-                  definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB')
+def test_search_no_content(apiobj, frontend):
+    apiobj.add_word_table([])
 
-    assert apiobj.api.search('foo') == []
+    api = frontend(apiobj, options=API_OPTIONS)
+    assert api.search('foo') == []
 
 
-def test_search_simple_word(apiobj, table_factory):
-    table_factory('word',
-                  definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB',
-                  content=[(55, 'test', 'W', 'test', None),
+def test_search_simple_word(apiobj, frontend):
+    apiobj.add_word_table([(55, 'test', 'W', 'test', None),
                            (2, 'test', 'w', 'test', None)])
 
     apiobj.add_placex(place_id=444, class_='place', type='village',
                       centroid=(1.3, 0.7))
     apiobj.add_search_name(444, names=[2, 55])
 
-    results = apiobj.api.search('TEST')
+    api = frontend(apiobj, options=API_OPTIONS)
+    results = api.search('TEST')
 
     assert [r.place_id for r in results] == [444]
 
 
 @pytest.mark.parametrize('logtype', ['text', 'html'])
-def test_search_with_debug(apiobj, table_factory, logtype):
-    table_factory('word',
-                  definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB',
-                  content=[(55, 'test', 'W', 'test', None),
+def test_search_with_debug(apiobj, frontend, logtype):
+    apiobj.add_word_table([(55, 'test', 'W', 'test', None),
                            (2, 'test', 'w', 'test', None)])
 
     apiobj.add_placex(place_id=444, class_='place', type='village',
                       centroid=(1.3, 0.7))
     apiobj.add_search_name(444, names=[2, 55])
 
+    api = frontend(apiobj, options=API_OPTIONS)
     loglib.set_log_output(logtype)
-    results = apiobj.api.search('TEST')
+    results = api.search('TEST')
 
     assert loglib.get_and_disable()
 
 
-def test_address_no_content(apiobj, table_factory):
-    table_factory('word',
-                  definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB')
+def test_address_no_content(apiobj, frontend):
+    apiobj.add_word_table([])
 
-    assert apiobj.api.search_address(amenity='hotel',
-                                     street='Main St 34',
-                                     city='Happyville',
-                                     county='Wideland',
-                                     state='Praerie',
-                                     postalcode='55648',
-                                     country='xx') == []
+    api = frontend(apiobj, options=API_OPTIONS)
+    assert api.search_address(amenity='hotel',
+                              street='Main St 34',
+                              city='Happyville',
+                              county='Wideland',
+                              state='Praerie',
+                              postalcode='55648',
+                              country='xx') == []
 
 
 @pytest.mark.parametrize('atype,address,search', [('street', 26, 26),
                                                   ('city', 16, 18),
                                                   ('county', 12, 12),
                                                   ('state', 8, 8)])
-def test_address_simple_places(apiobj, table_factory, atype, address, search):
-    table_factory('word',
-                  definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB',
-                  content=[(55, 'test', 'W', 'test', None),
+def test_address_simple_places(apiobj, frontend, atype, address, search):
+    apiobj.add_word_table([(55, 'test', 'W', 'test', None),
                            (2, 'test', 'w', 'test', None)])
 
     apiobj.add_placex(place_id=444,
@@ -97,53 +95,51 @@ def test_address_simple_places(apiobj, table_factory, atype, address, search):
                       centroid=(1.3, 0.7))
     apiobj.add_search_name(444, names=[2, 55], address_rank=address, search_rank=search)
 
-    results = apiobj.api.search_address(**{atype: 'TEST'})
+    api = frontend(apiobj, options=API_OPTIONS)
+    results = api.search_address(**{atype: 'TEST'})
 
     assert [r.place_id for r in results] == [444]
 
 
-def test_address_country(apiobj, table_factory):
-    table_factory('word',
-                  definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB',
-                  content=[(None, 'ro', 'C', 'ro', None)])
+def test_address_country(apiobj, frontend):
+    apiobj.add_word_table([(None, 'ro', 'C', 'ro', None)])
     apiobj.add_country('ro', 'POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))')
     apiobj.add_country_name('ro', {'name': 'România'})
 
-    assert len(apiobj.api.search_address(country='ro')) == 1
+    api = frontend(apiobj, options=API_OPTIONS)
+    assert len(api.search_address(country='ro')) == 1
 
 
-def test_category_no_categories(apiobj, table_factory):
-    table_factory('word',
-                  definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB')
+def test_category_no_categories(apiobj, frontend):
+    apiobj.add_word_table([])
 
-    assert apiobj.api.search_category([], near_query='Berlin') == []
+    api = frontend(apiobj, options=API_OPTIONS)
+    assert api.search_category([], near_query='Berlin') == []
 
 
-def test_category_no_content(apiobj, table_factory):
-    table_factory('word',
-                  definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB')
+def test_category_no_content(apiobj, frontend):
+    apiobj.add_word_table([])
 
-    assert apiobj.api.search_category([('amenity', 'restaurant')]) == []
+    api = frontend(apiobj, options=API_OPTIONS)
+    assert api.search_category([('amenity', 'restaurant')]) == []
 
 
-def test_category_simple_restaurant(apiobj, table_factory):
-    table_factory('word',
-                  definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB')
+def test_category_simple_restaurant(apiobj, frontend):
+    apiobj.add_word_table([])
 
     apiobj.add_placex(place_id=444, class_='amenity', type='restaurant',
                       centroid=(1.3, 0.7))
     apiobj.add_search_name(444, names=[2, 55], address_rank=16, search_rank=18)
 
-    results = apiobj.api.search_category([('amenity', 'restaurant')],
-                                         near=(1.3, 0.701), near_radius=0.015)
+    api = frontend(apiobj, options=API_OPTIONS)
+    results = api.search_category([('amenity', 'restaurant')],
+                                  near=(1.3, 0.701), near_radius=0.015)
 
     assert [r.place_id for r in results] == [444]
 
 
-def test_category_with_search_phrase(apiobj, table_factory):
-    table_factory('word',
-                  definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB',
-                  content=[(55, 'test', 'W', 'test', None),
+def test_category_with_search_phrase(apiobj, frontend):
+    apiobj.add_word_table([(55, 'test', 'W', 'test', None),
                            (2, 'test', 'w', 'test', None)])
 
     apiobj.add_placex(place_id=444, class_='place', type='village',
@@ -153,7 +149,7 @@ def test_category_with_search_phrase(apiobj, table_factory):
     apiobj.add_placex(place_id=95, class_='amenity', type='restaurant',
                       centroid=(1.3, 0.7003))
 
-    results = apiobj.api.search_category([('amenity', 'restaurant')],
-                                         near_query='TEST')
+    api = frontend(apiobj, options=API_OPTIONS)
+    results = api.search_category([('amenity', 'restaurant')], near_query='TEST')
 
     assert [r.place_id for r in results] == [95]
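
Finally, the word-table boilerplate that every test here previously built inline via table_factory moves behind a new apiobj.add_word_table() helper. A sketch of what the helper plausibly does, based on the definition string removed above (the real implementation lands in the shared test fixtures, presumably conftest.py, outside this excerpt):

    def add_word_table(self, content):
        # same schema the tests previously created inline via table_factory
        self.table_factory('word',
                           definition='word_id INT, word_token TEXT, '
                                      'type TEXT, word TEXT, info JSONB',
                           content=content)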