git.openstreetmap.org Git - nominatim.git/commitdiff
add address counts to tokens
author Sarah Hoffmann <lonvia@denofr.de>
Fri, 15 Mar 2024 09:54:13 +0000 (10:54 +0100)
committer Sarah Hoffmann <lonvia@denofr.de>
Mon, 18 Mar 2024 10:25:48 +0000 (11:25 +0100)
nominatim/api/search/icu_tokenizer.py
nominatim/api/search/legacy_tokenizer.py
nominatim/api/search/query.py
nominatim/tokenizer/base.py
nominatim/tokenizer/legacy_tokenizer.py
test/python/api/search/test_api_search_query.py
test/python/api/search/test_db_search_builder.py
test/python/api/search/test_token_assignment.py
test/python/cli/conftest.py
test/python/tokenizer/test_icu.py

diff --git a/nominatim/api/search/icu_tokenizer.py b/nominatim/api/search/icu_tokenizer.py
index 1c2565d1ad60c80df1f1ecb78b216439b8d98224..05ec7690c8ac0a34d8436fd08e641f4cb19bd680 100644 (file)
@@ -97,6 +97,7 @@ class ICUToken(qmod.Token):
         """ Create a ICUToken from the row of the word table.
         """
         count = 1 if row.info is None else row.info.get('count', 1)
+        addr_count = 1 if row.info is None else row.info.get('addr_count', 1)
 
         penalty = 0.0
         if row.type == 'w':
@@ -123,7 +124,8 @@ class ICUToken(qmod.Token):
 
         return ICUToken(penalty=penalty, token=row.word_id, count=count,
                         lookup_word=lookup_word, is_indexed=True,
-                        word_token=row.word_token, info=row.info)
+                        word_token=row.word_token, info=row.info,
+                        addr_count=addr_count)
 
 
 
@@ -257,7 +259,7 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
             if len(part.token) <= 4 and part[0].isdigit()\
                and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
                 query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
-                                ICUToken(0.5, 0, 1, part.token, True, part.token, None))
+                                ICUToken(0.5, 0, 1, 1, part.token, True, part.token, None))
 
 
     def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
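
A minimal standalone sketch of the fallback logic introduced above: both frequencies are read from the word table's JSON info column and default to 1 when the column or the key is missing. FakeWordRow and extract_counts are illustrative stand-ins; only the handling of info mirrors the diff.

    from dataclasses import dataclass
    from typing import Any, Dict, Optional, Tuple

    @dataclass
    class FakeWordRow:
        # hypothetical stand-in for a row of the word table
        word_id: int
        word_token: str
        info: Optional[Dict[str, Any]] = None

    def extract_counts(row: FakeWordRow) -> Tuple[int, int]:
        # Rows without the new key still yield a usable value of 1.
        count = 1 if row.info is None else row.info.get('count', 1)
        addr_count = 1 if row.info is None else row.info.get('addr_count', 1)
        return count, addr_count

    assert extract_counts(FakeWordRow(1, 'foo')) == (1, 1)
    assert extract_counts(FakeWordRow(2, 'bar', {'count': 7})) == (7, 1)
    assert extract_counts(FakeWordRow(3, 'baz', {'count': 7, 'addr_count': 3})) == (7, 3)
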
diff --git a/nominatim/api/search/legacy_tokenizer.py b/nominatim/api/search/legacy_tokenizer.py
index 86d42a543d20ce8429770d16451e61f8b7ea1e4b..bd17706e5dff7c3fc5fd5f1d37eafe234fd809f1 100644 (file)
@@ -210,6 +210,7 @@ class LegacyQueryAnalyzer(AbstractQueryAnalyzer):
 
         return LegacyToken(penalty=penalty, token=row.word_id,
                            count=row.search_name_count or 1,
+                           addr_count=1, # not supported
                            lookup_word=lookup_word,
                            word_token=row.word_token.strip(),
                            category=(rowclass, row.type) if rowclass is not None else None,
@@ -226,7 +227,7 @@ class LegacyQueryAnalyzer(AbstractQueryAnalyzer):
             if len(part) <= 4 and part.isdigit()\
                and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
                 query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
-                                LegacyToken(penalty=0.5, token=0, count=1,
+                                LegacyToken(penalty=0.5, token=0, count=1, addr_count=1,
                                             lookup_word=part, word_token=part,
                                             category=None, country=None,
                                             operator=None, is_indexed=True))
diff --git a/nominatim/api/search/query.py b/nominatim/api/search/query.py
index bd91c2ece72eac23dbda429e88979f826888b0c1..a0d7add1b70118e32d628b4894a893386d09d996 100644 (file)
@@ -99,10 +99,10 @@ class Token(ABC):
     penalty: float
     token: int
     count: int
+    addr_count: int
     lookup_word: str
     is_indexed: bool
 
-    addr_count: int = 1
 
     @abstractmethod
     def get_category(self) -> Tuple[str, str]:
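
The bare annotations above read like a dataclass declaration, so moving addr_count between count and lookup_word and dropping its default also changes positional construction: callers must now supply the address frequency as the fourth argument, which is the extra literal 1 that appears in the constructor calls elsewhere in this commit. A toy sketch of the new field order (ToyToken is hypothetical, not the real class):

    from dataclasses import dataclass

    @dataclass
    class ToyToken:
        penalty: float
        token: int
        count: int
        addr_count: int      # now required, no default
        lookup_word: str
        is_indexed: bool

    # count=1 and addr_count=1 are passed back to back, mirroring the
    # updated positional calls in the tokenizers and tests.
    tok = ToyToken(0.5, 0, 1, 1, 'foo', True)
    assert tok.addr_count == 1
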
diff --git a/nominatim/tokenizer/base.py b/nominatim/tokenizer/base.py
index 29bcc8e196cf29cf4ef110252fbad0a5f26da2b7..12c826eb21b19da1b1ad989d4cc5ede9f0c699cf 100644 (file)
@@ -201,7 +201,7 @@ class AbstractTokenizer(ABC):
 
 
     @abstractmethod
-    def update_statistics(self, config: Configuration) -> None:
+    def update_statistics(self, config: Configuration, threads: int = 1) -> None:
         """ Recompute any tokenizer statistics necessary for efficient lookup.
             This function is meant to be called from time to time by the user
             to improve performance. However, the tokenizer must not depend on
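
A rough sketch of how an implementation might honour the new threads parameter; DemoTokenizer is purely illustrative and does not reflect how Nominatim's tokenizers actually recompute their statistics.

    from concurrent.futures import ThreadPoolExecutor

    class DemoTokenizer:
        """ Toy stand-in showing the widened update_statistics() signature. """

        def __init__(self, word_occurrences):
            self.word_occurrences = word_occurrences   # word -> list of hits
            self.stats = {}

        def update_statistics(self, config, threads: int = 1) -> None:
            # Recompute per-word totals, optionally spreading the work
            # over several worker threads.
            def recount(item):
                word, hits = item
                return word, len(hits)

            with ThreadPoolExecutor(max_workers=max(1, threads)) as pool:
                self.stats = dict(pool.map(recount, self.word_occurrences.items()))

    tok = DemoTokenizer({'hello': [1, 2], 'bye': [3]})
    tok.update_statistics(config=None, threads=2)
    assert tok.stats == {'hello': 2, 'bye': 1}
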
diff --git a/nominatim/tokenizer/legacy_tokenizer.py b/nominatim/tokenizer/legacy_tokenizer.py
index f3a00839aa2f0302ba611f0f0454f7fdf818c02e..93808cc39f3407458bb2d570d2a8740128f2c168 100644 (file)
@@ -210,7 +210,7 @@ class LegacyTokenizer(AbstractTokenizer):
             self._save_config(conn, config)
 
 
-    def update_statistics(self, _: Configuration) -> None:
+    def update_statistics(self, config: Configuration, threads: int = 1) -> None:
         """ Recompute the frequency of full words.
         """
         with connect(self.dsn) as conn:
diff --git a/test/python/api/search/test_api_search_query.py b/test/python/api/search/test_api_search_query.py
index fe850ce902930a817981bd42c6c549fc5bd91ec3..bfdceb4165fc984451e6ca8266a15554cc0cb2b8 100644 (file)
@@ -18,7 +18,8 @@ class MyToken(query.Token):
 
 
 def mktoken(tid: int):
-    return MyToken(3.0, tid, 1, 'foo', True)
+    return MyToken(penalty=3.0, token=tid, count=1, addr_count=1,
+                   lookup_word='foo', is_indexed=True)
 
 
 @pytest.mark.parametrize('ptype,ttype', [('NONE', 'WORD'),
diff --git a/test/python/api/search/test_db_search_builder.py b/test/python/api/search/test_db_search_builder.py
index d3aea90002740d7660e12a4b210bf4cb41344c60..68f71298c6b64f10a846796562bd658fdfdf7cc3 100644 (file)
@@ -31,7 +31,9 @@ def make_query(*args):
         for end, ttype, tinfo in tlist:
             for tid, word in tinfo:
                 q.add_token(TokenRange(start, end), ttype,
-                            MyToken(0.5 if ttype == TokenType.PARTIAL else 0.0, tid, 1, word, True))
+                            MyToken(penalty=0.5 if ttype == TokenType.PARTIAL else 0.0,
+                                    token=tid, count=1, addr_count=1,
+                                    lookup_word=word, is_indexed=True))
 
 
     return q
@@ -395,14 +397,14 @@ def make_counted_searches(name_part, name_full, address_part, address_full,
     q.add_node(BreakType.END, PhraseType.NONE)
 
     q.add_token(TokenRange(0, 1), TokenType.PARTIAL,
-                MyToken(0.5, 1, name_part, 'name_part', True))
+                MyToken(0.5, 1, name_part, 1, 'name_part', True))
     q.add_token(TokenRange(0, 1), TokenType.WORD,
-                MyToken(0, 101, name_full, 'name_full', True))
+                MyToken(0, 101, name_full, 1, 'name_full', True))
     for i in range(num_address_parts):
         q.add_token(TokenRange(i + 1, i + 2), TokenType.PARTIAL,
-                    MyToken(0.5, 2, address_part, 'address_part', True))
+                    MyToken(0.5, 2, address_part, 1, 'address_part', True))
         q.add_token(TokenRange(i + 1, i + 2), TokenType.WORD,
-                    MyToken(0, 102, address_full, 'address_full', True))
+                    MyToken(0, 102, address_full, 1, 'address_full', True))
 
     builder = SearchBuilder(q, SearchDetails())
 
diff --git a/test/python/api/search/test_token_assignment.py b/test/python/api/search/test_token_assignment.py
index 54e8af14cc27fb466e58a99a5d3d7ef28657e1f6..cde8495d0bb2ce557cc9d6ecd2de24721d454f3b 100644 (file)
@@ -19,7 +19,8 @@ class MyToken(Token):
 
 def make_query(*args):
     q = QueryStruct([Phrase(args[0][1], '')])
-    dummy = MyToken(3.0, 45, 1, 'foo', True)
+    dummy = MyToken(penalty=3.0, token=45, count=1, addr_count=1,
+                    lookup_word='foo', is_indexed=True)
 
     for btype, ptype, _ in args[1:]:
         q.add_node(btype, ptype)
diff --git a/test/python/cli/conftest.py b/test/python/cli/conftest.py
index 1bb393fb240613d1ed85f04d36269733c61469c8..28aba597e7de38d324ebadaa1e6ef67e62b84b82 100644 (file)
@@ -32,16 +32,16 @@ class DummyTokenizer:
         self.update_statistics_called = False
         self.update_word_tokens_called = False
 
-    def update_sql_functions(self, *args):
+    def update_sql_functions(self, *args, **kwargs):
         self.update_sql_functions_called = True
 
-    def finalize_import(self, *args):
+    def finalize_import(self, *args, **kwargs):
         self.finalize_import_called = True
 
-    def update_statistics(self, *args):
+    def update_statistics(self, *args, **kwargs):
         self.update_statistics_called = True
 
-    def update_word_tokens(self, *args):
+    def update_word_tokens(self, *args, **kwargs):
         self.update_word_tokens_called = True
 
 
diff --git a/test/python/tokenizer/test_icu.py b/test/python/tokenizer/test_icu.py
index aa1afe160ca9010630b7b36502d14b173a453003..9f6eae62e3467900e11829a421a3bbdef623e211 100644 (file)
@@ -227,16 +227,20 @@ def test_update_statistics_reverse_only(word_table, tokenizer_factory, test_conf
 def test_update_statistics(word_table, table_factory, temp_db_cursor,
                            tokenizer_factory, test_config):
     word_table.add_full_word(1000, 'hello')
+    word_table.add_full_word(1001, 'bye')
     table_factory('search_name',
-                  'place_id BIGINT, name_vector INT[]',
-                  [(12, [1000])])
+                  'place_id BIGINT, name_vector INT[], nameaddress_vector INT[]',
+                  [(12, [1000], [1001])])
     tok = tokenizer_factory()
 
     tok.update_statistics(test_config)
 
     assert temp_db_cursor.scalar("""SELECT count(*) FROM word
-                                    WHERE type = 'W' and
-                                          (info->>'count')::int > 0""") > 0
+                                    WHERE type = 'W' and word_id = 1000 and
+                                          (info->>'count')::int > 0""") == 1
+    assert temp_db_cursor.scalar("""SELECT count(*) FROM word
+                                    WHERE type = 'W' and word_id = 1001 and
+                                          (info->>'addr_count')::int > 0""") == 1
 
 
 def test_normalize_postcode(analyzer):
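
A pure-Python sketch of the bookkeeping the updated test expects: occurrences in name_vector feed the 'count' statistic and occurrences in nameaddress_vector feed the new 'addr_count' statistic. The real figures are computed by the tokenizer against the database; the fixture data below is copied from the test, everything else is illustrative.

    from collections import Counter

    # (place_id, name_vector, nameaddress_vector) as in the test fixture
    search_name = [(12, [1000], [1001])]

    name_counts = Counter(w for _, names, _ in search_name for w in names)
    addr_counts = Counter(w for _, _, addrs in search_name for w in addrs)

    word_info = {
        1000: {'count': name_counts[1000]},
        1001: {'addr_count': addr_counts[1001]},
    }
    assert word_info[1000]['count'] > 0
    assert word_info[1001]['addr_count'] > 0
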