]> git.openstreetmap.org Git - nominatim.git/commitdiff
add type annotations for legacy tokenizer
authorSarah Hoffmann <lonvia@denofr.de>
Fri, 15 Jul 2022 20:52:26 +0000 (22:52 +0200)
committerSarah Hoffmann <lonvia@denofr.de>
Mon, 18 Jul 2022 07:47:57 +0000 (09:47 +0200)
nominatim/db/connection.py
nominatim/tokenizer/base.py
nominatim/tokenizer/legacy_tokenizer.py

index 25ddcba4c887d46039ba765ff57952f9288852ba..685ac6cbfe112e2f6814a8e074573a4909259f37 100644 (file)
@@ -7,7 +7,7 @@
 """
 Specialised connection and cursor functions.
 """
-from typing import List, Optional, Any, Callable, ContextManager, Dict, cast, overload, Tuple
+from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload, Tuple, Iterable
 import contextlib
 import logging
 import os
@@ -36,7 +36,7 @@ class Cursor(psycopg2.extras.DictCursor):
         super().execute(query, args)
 
 
-    def execute_values(self, sql: Query, argslist: List[Any],
+    def execute_values(self, sql: Query, argslist: Iterable[Tuple[Any, ...]],
                        template: Optional[str] = None) -> None:
         """ Wrapper for the psycopg2 convenience function to execute
             SQL for a list of values.
index 6484ff6a06a798b206c9c8e54bf947cf2da4b1d4..afcb0864214eb6a2c7e2c6378b0dca0c7843c4e3 100644 (file)
@@ -9,7 +9,7 @@ Abstract class defintions for tokenizers. These base classes are here
 mainly for documentation purposes.
 """
 from abc import ABC, abstractmethod
-from typing import List, Tuple, Dict, Any
+from typing import List, Tuple, Dict, Any, Optional
 from pathlib import Path
 
 from typing_extensions import Protocol
@@ -187,7 +187,7 @@ class AbstractTokenizer(ABC):
 
 
     @abstractmethod
-    def check_database(self, config: Configuration) -> str:
+    def check_database(self, config: Configuration) -> Optional[str]:
         """ Check that the database is set up correctly and ready for being
             queried.
 
index 36fd5722441a12e92d24434d5dc1317497dd27bc..848d619116abde7aeca34d78cf3d330db69eab64 100644 (file)
@@ -7,8 +7,10 @@
 """
 Tokenizer implementing normalisation as used before Nominatim 4.
 """
+from typing import Optional, Sequence, List, Tuple, Mapping, Any, Callable, cast, Dict, Set
 from collections import OrderedDict
 import logging
+from pathlib import Path
 import re
 import shutil
 from textwrap import dedent
@@ -17,10 +19,12 @@ from icu import Transliterator
 import psycopg2
 import psycopg2.extras
 
-from nominatim.db.connection import connect
+from nominatim.db.connection import connect, Connection
+from nominatim.config import Configuration
 from nominatim.db import properties
 from nominatim.db import utils as db_utils
 from nominatim.db.sql_preprocessor import SQLPreprocessor
+from nominatim.data.place_info import PlaceInfo
 from nominatim.errors import UsageError
 from nominatim.tokenizer.base import AbstractAnalyzer, AbstractTokenizer
 
@@ -29,13 +33,13 @@ DBCFG_MAXWORDFREQ = "tokenizer_maxwordfreq"
 
 LOG = logging.getLogger()
 
-def create(dsn, data_dir):
+def create(dsn: str, data_dir: Path) -> 'LegacyTokenizer':
     """ Create a new instance of the tokenizer provided by this module.
     """
     return LegacyTokenizer(dsn, data_dir)
 
 
-def _install_module(config_module_path, src_dir, module_dir):
+def _install_module(config_module_path: str, src_dir: Path, module_dir: Path) -> str:
     """ Copies the PostgreSQL normalisation module into the project
         directory if necessary. For historical reasons the module is
         saved in the '/module' subdirectory and not with the other tokenizer
@@ -52,7 +56,7 @@ def _install_module(config_module_path, src_dir, module_dir):
     # Compatibility mode for builddir installations.
     if module_dir.exists() and src_dir.samefile(module_dir):
         LOG.info('Running from build directory. Leaving database module as is.')
-        return module_dir
+        return str(module_dir)
 
     # In any other case install the module in the project directory.
     if not module_dir.exists():
@@ -64,10 +68,10 @@ def _install_module(config_module_path, src_dir, module_dir):
 
     LOG.info('Database module installed at %s', str(destfile))
 
-    return module_dir
+    return str(module_dir)
 
 
-def _check_module(module_dir, conn):
+def _check_module(module_dir: str, conn: Connection) -> None:
     """ Try to use the PostgreSQL module to confirm that it is correctly
         installed and accessible from PostgreSQL.
     """
@@ -89,13 +93,13 @@ class LegacyTokenizer(AbstractTokenizer):
         calls to the database.
     """
 
-    def __init__(self, dsn, data_dir):
+    def __init__(self, dsn: str, data_dir: Path) -> None:
         self.dsn = dsn
         self.data_dir = data_dir
-        self.normalization = None
+        self.normalization: Optional[str] = None
 
 
-    def init_new_db(self, config, init_db=True):
+    def init_new_db(self, config: Configuration, init_db: bool = True) -> None:
         """ Set up a new tokenizer for the database.
 
             This copies all necessary data in the project directory to make
@@ -119,7 +123,7 @@ class LegacyTokenizer(AbstractTokenizer):
             self._init_db_tables(config)
 
 
-    def init_from_project(self, config):
+    def init_from_project(self, config: Configuration) -> None:
         """ Initialise the tokenizer from the project directory.
         """
         with connect(self.dsn) as conn:
@@ -132,7 +136,7 @@ class LegacyTokenizer(AbstractTokenizer):
 
         self._install_php(config, overwrite=False)
 
-    def finalize_import(self, config):
+    def finalize_import(self, config: Configuration) -> None:
         """ Do any required postprocessing to make the tokenizer data ready
             for use.
         """
@@ -141,7 +145,7 @@ class LegacyTokenizer(AbstractTokenizer):
             sqlp.run_sql_file(conn, 'tokenizer/legacy_tokenizer_indices.sql')
 
 
-    def update_sql_functions(self, config):
+    def update_sql_functions(self, config: Configuration) -> None:
         """ Reimport the SQL functions for this tokenizer.
         """
         with connect(self.dsn) as conn:
@@ -154,7 +158,7 @@ class LegacyTokenizer(AbstractTokenizer):
                               modulepath=modulepath)
 
 
-    def check_database(self, _):
+    def check_database(self, _: Configuration) -> Optional[str]:
         """ Check that the tokenizer is set up correctly.
         """
         hint = """\
@@ -181,7 +185,7 @@ class LegacyTokenizer(AbstractTokenizer):
         return None
 
 
-    def migrate_database(self, config):
+    def migrate_database(self, config: Configuration) -> None:
         """ Initialise the project directory of an existing database for
             use with this tokenizer.
 
@@ -198,7 +202,7 @@ class LegacyTokenizer(AbstractTokenizer):
             self._save_config(conn, config)
 
 
-    def update_statistics(self):
+    def update_statistics(self) -> None:
         """ Recompute the frequency of full words.
         """
         with connect(self.dsn) as conn:
@@ -218,13 +222,13 @@ class LegacyTokenizer(AbstractTokenizer):
             conn.commit()
 
 
-    def update_word_tokens(self):
+    def update_word_tokens(self) -> None:
         """ No house-keeping implemented for the legacy tokenizer.
         """
         LOG.info("No tokenizer clean-up available.")
 
 
-    def name_analyzer(self):
+    def name_analyzer(self) -> 'LegacyNameAnalyzer':
         """ Create a new analyzer for tokenizing names and queries
             using this tokinzer. Analyzers are context managers and should
             be used accordingly:
@@ -244,7 +248,7 @@ class LegacyTokenizer(AbstractTokenizer):
         return LegacyNameAnalyzer(self.dsn, normalizer)
 
 
-    def _install_php(self, config, overwrite=True):
+    def _install_php(self, config: Configuration, overwrite: bool = True) -> None:
         """ Install the php script for the tokenizer.
         """
         php_file = self.data_dir / "tokenizer.php"
@@ -258,7 +262,7 @@ class LegacyTokenizer(AbstractTokenizer):
                 """), encoding='utf-8')
 
 
-    def _init_db_tables(self, config):
+    def _init_db_tables(self, config: Configuration) -> None:
         """ Set up the word table and fill it with pre-computed word
             frequencies.
         """
@@ -271,10 +275,12 @@ class LegacyTokenizer(AbstractTokenizer):
         db_utils.execute_file(self.dsn, config.lib_dir.data / 'words.sql')
 
 
-    def _save_config(self, conn, config):
+    def _save_config(self, conn: Connection, config: Configuration) -> None:
         """ Save the configuration that needs to remain stable for the given
             database as database properties.
         """
+        assert self.normalization is not None
+
         properties.set_property(conn, DBCFG_NORMALIZATION, self.normalization)
         properties.set_property(conn, DBCFG_MAXWORDFREQ, config.MAX_WORD_FREQUENCY)
 
@@ -287,8 +293,8 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
         normalization.
     """
 
-    def __init__(self, dsn, normalizer):
-        self.conn = connect(dsn).connection
+    def __init__(self, dsn: str, normalizer: Any):
+        self.conn: Optional[Connection] = connect(dsn).connection
         self.conn.autocommit = True
         self.normalizer = normalizer
         psycopg2.extras.register_hstore(self.conn)
@@ -296,7 +302,7 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
         self._cache = _TokenCache(self.conn)
 
 
-    def close(self):
+    def close(self) -> None:
         """ Free all resources used by the analyzer.
         """
         if self.conn:
@@ -304,7 +310,7 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
             self.conn = None
 
 
-    def get_word_token_info(self, words):
+    def get_word_token_info(self, words: Sequence[str]) -> List[Tuple[str, str, int]]:
         """ Return token information for the given list of words.
             If a word starts with # it is assumed to be a full name
             otherwise is a partial name.
@@ -315,6 +321,7 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
             The function is used for testing and debugging only
             and not necessarily efficient.
         """
+        assert self.conn is not None
         with self.conn.cursor() as cur:
             cur.execute("""SELECT t.term, word_token, word_id
                            FROM word, (SELECT unnest(%s::TEXT[]) as term) t
@@ -330,14 +337,14 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
             return [(r[0], r[1], r[2]) for r in cur]
 
 
-    def normalize(self, phrase):
+    def normalize(self, phrase: str) -> str:
         """ Normalize the given phrase, i.e. remove all properties that
             are irrelevant for search.
         """
-        return self.normalizer.transliterate(phrase)
+        return cast(str, self.normalizer.transliterate(phrase))
 
 
-    def normalize_postcode(self, postcode):
+    def normalize_postcode(self, postcode: str) -> str:
         """ Convert the postcode to a standardized form.
 
             This function must yield exactly the same result as the SQL function
@@ -346,10 +353,12 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
         return postcode.strip().upper()
 
 
-    def update_postcodes_from_db(self):
+    def update_postcodes_from_db(self) -> None:
         """ Update postcode tokens in the word table from the location_postcode
             table.
         """
+        assert self.conn is not None
+
         with self.conn.cursor() as cur:
             # This finds us the rows in location_postcode and word that are
             # missing in the other table.
@@ -383,9 +392,12 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
 
 
 
-    def update_special_phrases(self, phrases, should_replace):
+    def update_special_phrases(self, phrases: Sequence[Tuple[str, str, str, str]],
+                               should_replace: bool) -> None:
         """ Replace the search index for special phrases with the new phrases.
         """
+        assert self.conn is not None
+
         norm_phrases = set(((self.normalize(p[0]), p[1], p[2], p[3])
                             for p in phrases))
 
@@ -422,9 +434,11 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
                  len(norm_phrases), len(to_add), len(to_delete))
 
 
-    def add_country_names(self, country_code, names):
+    def add_country_names(self, country_code: str, names: Mapping[str, str]) -> None:
         """ Add names for the given country to the search index.
         """
+        assert self.conn is not None
+
         with self.conn.cursor() as cur:
             cur.execute(
                 """INSERT INTO word (word_id, word_token, country_code)
@@ -436,12 +450,14 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
                 """, (country_code, list(names.values()), country_code))
 
 
-    def process_place(self, place):
+    def process_place(self, place: PlaceInfo) -> Mapping[str, Any]:
         """ Determine tokenizer information about the given place.
 
             Returns a JSON-serialisable structure that will be handed into
             the database via the token_info field.
         """
+        assert self.conn is not None
+
         token_info = _TokenInfo(self._cache)
 
         names = place.name
@@ -450,6 +466,7 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
             token_info.add_names(self.conn, names)
 
             if place.is_country():
+                assert place.country_code is not None
                 self.add_country_names(place.country_code, names)
 
         address = place.address
@@ -459,7 +476,8 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
         return token_info.data
 
 
-    def _process_place_address(self, token_info, address):
+    def _process_place_address(self, token_info: '_TokenInfo', address: Mapping[str, str]) -> None:
+        assert self.conn is not None
         hnrs = []
         addr_terms = []
 
@@ -491,12 +509,12 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
 class _TokenInfo:
     """ Collect token information to be sent back to the database.
     """
-    def __init__(self, cache):
+    def __init__(self, cache: '_TokenCache') -> None:
         self.cache = cache
-        self.data = {}
+        self.data: Dict[str, Any] = {}
 
 
-    def add_names(self, conn, names):
+    def add_names(self, conn: Connection, names: Mapping[str, str]) -> None:
         """ Add token information for the names of the place.
         """
         with conn.cursor() as cur:
@@ -505,7 +523,7 @@ class _TokenInfo:
                                             (names, ))
 
 
-    def add_housenumbers(self, conn, hnrs):
+    def add_housenumbers(self, conn: Connection, hnrs: Sequence[str]) -> None:
         """ Extract housenumber information from the address.
         """
         if len(hnrs) == 1:
@@ -516,7 +534,7 @@ class _TokenInfo:
                 return
 
         # split numbers if necessary
-        simple_list = []
+        simple_list: List[str] = []
         for hnr in hnrs:
             simple_list.extend((x.strip() for x in re.split(r'[;,]', hnr)))
 
@@ -525,49 +543,53 @@ class _TokenInfo:
 
         with conn.cursor() as cur:
             cur.execute("SELECT * FROM create_housenumbers(%s)", (simple_list, ))
-            self.data['hnr_tokens'], self.data['hnr'] = cur.fetchone()
+            self.data['hnr_tokens'], self.data['hnr'] = \
+                cur.fetchone() # type: ignore[no-untyped-call]
 
 
-    def set_postcode(self, postcode):
+    def set_postcode(self, postcode: str) -> None:
         """ Set or replace the postcode token with the given value.
         """
         self.data['postcode'] = postcode
 
-    def add_street(self, conn, street):
+    def add_street(self, conn: Connection, street: str) -> None:
         """ Add addr:street match terms.
         """
-        def _get_street(name):
+        def _get_street(name: str) -> List[int]:
             with conn.cursor() as cur:
-                return cur.scalar("SELECT word_ids_from_name(%s)::text", (name, ))
+                return cast(List[int],
+                            cur.scalar("SELECT word_ids_from_name(%s)::text", (name, )))
 
         tokens = self.cache.streets.get(street, _get_street)
         if tokens:
             self.data['street'] = tokens
 
 
-    def add_place(self, conn, place):
+    def add_place(self, conn: Connection, place: str) -> None:
         """ Add addr:place search and match terms.
         """
-        def _get_place(name):
+        def _get_place(name: str) -> Tuple[List[int], List[int]]:
             with conn.cursor() as cur:
                 cur.execute("""SELECT make_keywords(hstore('name' , %s))::text,
                                       word_ids_from_name(%s)::text""",
                             (name, name))
-                return cur.fetchone()
+                return cast(Tuple[List[int], List[int]],
+                            cur.fetchone()) # type: ignore[no-untyped-call]
 
         self.data['place_search'], self.data['place_match'] = \
             self.cache.places.get(place, _get_place)
 
 
-    def add_address_terms(self, conn, terms):
+    def add_address_terms(self, conn: Connection, terms: Sequence[Tuple[str, str]]) -> None:
         """ Add additional address terms.
         """
-        def _get_address_term(name):
+        def _get_address_term(name: str) -> Tuple[List[int], List[int]]:
             with conn.cursor() as cur:
                 cur.execute("""SELECT addr_ids_from_name(%s)::text,
                                       word_ids_from_name(%s)::text""",
                             (name, name))
-                return cur.fetchone()
+                return cast(Tuple[List[int], List[int]],
+                            cur.fetchone()) # type: ignore[no-untyped-call]
 
         tokens = {}
         for key, value in terms:
@@ -584,13 +606,12 @@ class _LRU:
         produce the item when there is a cache miss.
     """
 
-    def __init__(self, maxsize=128, init_data=None):
-        self.data = init_data or OrderedDict()
+    def __init__(self, maxsize: int = 128):
+        self.data: 'OrderedDict[str, Any]' = OrderedDict()
         self.maxsize = maxsize
-        if init_data is not None and len(init_data) > maxsize:
-            self.maxsize = len(init_data)
 
-    def get(self, key, generator):
+
+    def get(self, key: str, generator: Callable[[str], Any]) -> Any:
         """ Get the item with the given key from the cache. If nothing
             is found in the cache, generate the value through the
             generator function and store it in the cache.
@@ -613,7 +634,7 @@ class _TokenCache:
         This cache is not thread-safe and needs to be instantiated per
         analyzer.
     """
-    def __init__(self, conn):
+    def __init__(self, conn: Connection):
         # various LRU caches
         self.streets = _LRU(maxsize=256)
         self.places = _LRU(maxsize=128)
@@ -623,18 +644,18 @@ class _TokenCache:
         with conn.cursor() as cur:
             cur.execute("""SELECT i, ARRAY[getorcreate_housenumber_id(i::text)]::text
                            FROM generate_series(1, 100) as i""")
-            self._cached_housenumbers = {str(r[0]): r[1] for r in cur}
+            self._cached_housenumbers: Dict[str, str] = {str(r[0]): r[1] for r in cur}
 
         # For postcodes remember the ones that have already been added
-        self.postcodes = set()
+        self.postcodes: Set[str] = set()
 
-    def get_housenumber(self, number):
+    def get_housenumber(self, number: str) -> Optional[str]:
         """ Get a housenumber token from the cache.
         """
         return self._cached_housenumbers.get(number)
 
 
-    def add_postcode(self, conn, postcode):
+    def add_postcode(self, conn: Connection, postcode: str) -> None:
         """ Make sure the given postcode is in the database.
         """
         if postcode not in self.postcodes: