]> git.openstreetmap.org Git - nominatim.git/commitdiff
add typing information for place_info and country_info
authorSarah Hoffmann <lonvia@denofr.de>
Thu, 7 Jul 2022 15:31:20 +0000 (17:31 +0200)
committerSarah Hoffmann <lonvia@denofr.de>
Mon, 18 Jul 2022 07:47:57 +0000 (09:47 +0200)
nominatim/data/country_info.py
nominatim/data/place_info.py
nominatim/tokenizer/base.py

index d754b4ddb029365b22d2cc7a77ccaeefc49a2719..ada763257847b6961edc3ee1813da36a4757c0ee 100644 (file)
@@ -7,13 +7,17 @@
 """
 Functions for importing and managing static country information.
 """
+from typing import Dict, Any, Iterable, Tuple, Optional, Container
+from pathlib import Path
 import psycopg2.extras
 
 from nominatim.db import utils as db_utils
-from nominatim.db.connection import connect
+from nominatim.db.connection import connect, Connection
 from nominatim.errors import UsageError
+from nominatim.config import Configuration
+from nominatim.tokenizer.base import AbstractTokenizer
 
-def _flatten_name_list(names):
+def _flatten_name_list(names: Any) -> Dict[str, str]:
     if names is None:
         return {}
 
@@ -41,11 +45,11 @@ class _CountryInfo:
     """ Caches country-specific properties from the configuration file.
     """
 
-    def __init__(self):
-        self._info = {}
+    def __init__(self) -> None:
+        self._info: Dict[str, Dict[str, Any]] = {}
 
 
-    def load(self, config):
+    def load(self, config: Configuration) -> None:
         """ Load the country properties from the configuration files,
             if they are not loaded yet.
         """
@@ -61,12 +65,12 @@ class _CountryInfo:
                 prop['names'] = _flatten_name_list(prop.get('names'))
 
 
-    def items(self):
+    def items(self) -> Iterable[Tuple[str, Dict[str, Any]]]:
         """ Return tuples of (country_code, property dict) as iterable.
         """
         return self._info.items()
 
-    def get(self, country_code):
+    def get(self, country_code: str) -> Dict[str, Any]:
         """ Get country information for the country with the given country code.
         """
         return self._info.get(country_code, {})
@@ -76,7 +80,7 @@ class _CountryInfo:
 _COUNTRY_INFO = _CountryInfo()
 
 
-def setup_country_config(config):
+def setup_country_config(config: Configuration) -> None:
     """ Load country properties from the configuration file.
         Needs to be called before using any other functions in this
         file.
@@ -84,7 +88,7 @@ def setup_country_config(config):
     _COUNTRY_INFO.load(config)
 
 
-def iterate(prop=None):
+def iterate(prop: Optional[str] = None) -> Iterable[Tuple[str, Dict[str, Any]]]:
     """ Iterate over country code and properties.
 
         When `prop` is None, all countries are returned with their complete
@@ -100,7 +104,7 @@ def iterate(prop=None):
     return ((c, p[prop]) for c, p in _COUNTRY_INFO.items() if prop in p)
 
 
-def setup_country_tables(dsn, sql_dir, ignore_partitions=False):
+def setup_country_tables(dsn: str, sql_dir: Path, ignore_partitions: bool = False) -> None:
     """ Create and populate the tables with basic static data that provides
         the background for geocoding. Data is assumed to not yet exist.
     """
@@ -112,7 +116,7 @@ def setup_country_tables(dsn, sql_dir, ignore_partitions=False):
             if ignore_partitions:
                 partition = 0
             else:
-                partition = props.get('partition')
+                partition = props.get('partition', 0)
             lang = props['languages'][0] if len(
                 props['languages']) == 1 else None
 
@@ -135,13 +139,14 @@ def setup_country_tables(dsn, sql_dir, ignore_partitions=False):
         conn.commit()
 
 
-def create_country_names(conn, tokenizer, languages=None):
+def create_country_names(conn: Connection, tokenizer: AbstractTokenizer,
+                         languages: Optional[Container[str]] = None) -> None:
     """ Add default country names to search index. `languages` is a comma-
         separated list of language codes as used in OSM. If `languages` is not
         empty then only name translations for the given languages are added
         to the index.
     """
-    def _include_key(key):
+    def _include_key(key: str) -> bool:
         return ':' not in key or not languages or \
                key[key.index(':') + 1:] in languages
 
index d2ba3979260fcfa2fd411771b844c3d67807c440..96912a61e36176900f7b557fd0e70a838af58784 100644 (file)
@@ -8,18 +8,19 @@
 Wrapper around place information the indexer gets from the database and hands to
 the tokenizer.
 """
+from typing import Optional, Mapping, Any
 
 class PlaceInfo:
     """ Data class containing all information the tokenizer gets about a
         place it should process the names for.
     """
 
-    def __init__(self, info):
+    def __init__(self, info: Mapping[str, Any]) -> None:
         self._info = info
 
 
     @property
-    def name(self):
+    def name(self) -> Optional[Mapping[str, str]]:
         """ A dictionary with the names of the place or None if the place
             has no names.
         """
@@ -27,7 +28,7 @@ class PlaceInfo:
 
 
     @property
-    def address(self):
+    def address(self) -> Optional[Mapping[str, str]]:
         """ A dictionary with the address elements of the place
             or None if no address information is available.
         """
@@ -35,7 +36,7 @@ class PlaceInfo:
 
 
     @property
-    def country_code(self):
+    def country_code(self) -> Optional[str]:
         """ The country code of the country the place is in. Guaranteed
             to be a two-letter lower-case string or None, if no country
             could be found.
@@ -44,20 +45,20 @@ class PlaceInfo:
 
 
     @property
-    def rank_address(self):
+    def rank_address(self) -> int:
         """ The computed rank address before rank correction.
         """
-        return self._info.get('rank_address')
+        return self._info.get('rank_address', 0)
 
 
-    def is_a(self, key, value):
+    def is_a(self, key: str, value: str) -> bool:
         """ Check if the place's primary tag corresponds to the given
             key and value.
         """
         return self._info.get('class') == key and self._info.get('type') == value
 
 
-    def is_country(self):
+    def is_country(self) -> bool:
         """ Check if the place is a valid country boundary.
         """
         return self.rank_address == 4 \
index 70a54bfdc28141e62e0e5e63e5e52342044b5a14..5a3d3b1276aa7c6bf3bb36787967d519fe70aa79 100644 (file)
@@ -28,7 +28,7 @@ class AbstractAnalyzer(ABC):
         return self
 
 
-    def __exit__(self, exc_type, exc_value, traceback) -> None:
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
         self.close()
 
 
@@ -95,7 +95,7 @@ class AbstractAnalyzer(ABC):
 
 
     @abstractmethod
-    def add_country_names(self, country_code: str, names: Dict[str, str]):
+    def add_country_names(self, country_code: str, names: Dict[str, str]) -> None:
         """ Add the given names to the tokenizer's list of country tokens.
 
             Arguments: