From 4a9253a0a98808fecabd269002ea98dc0c43eb24 Mon Sep 17 00:00:00 2001
From: Sarah Hoffmann
Date: Wed, 9 Jul 2025 15:36:11 +0200
Subject: [PATCH] simplify QueryNode penalty and initial assignment

---
 src/nominatim_api/search/icu_tokenizer.py    | 40 ++++++++++++--------
 src/nominatim_api/search/query.py            | 30 +++++----------
 src/nominatim_api/search/token_assignment.py | 10 -----
 3 files changed, 33 insertions(+), 47 deletions(-)

diff --git a/src/nominatim_api/search/icu_tokenizer.py b/src/nominatim_api/search/icu_tokenizer.py
index 35171344..15a5e2ab 100644
--- a/src/nominatim_api/search/icu_tokenizer.py
+++ b/src/nominatim_api/search/icu_tokenizer.py
@@ -37,17 +37,16 @@ DB_TO_TOKEN_TYPE = {
     'C': qmod.TOKEN_COUNTRY
 }
 
-PENALTY_IN_TOKEN_BREAK = {
-    qmod.BREAK_START: 0.5,
-    qmod.BREAK_END: 0.5,
-    qmod.BREAK_PHRASE: 0.5,
-    qmod.BREAK_SOFT_PHRASE: 0.5,
-    qmod.BREAK_WORD: 0.1,
-    qmod.BREAK_PART: 0.0,
-    qmod.BREAK_TOKEN: 0.0
+PENALTY_BREAK = {
+    qmod.BREAK_START: -0.5,
+    qmod.BREAK_END: -0.5,
+    qmod.BREAK_PHRASE: -0.5,
+    qmod.BREAK_SOFT_PHRASE: -0.5,
+    qmod.BREAK_WORD: 0.0,
+    qmod.BREAK_PART: 0.2,
+    qmod.BREAK_TOKEN: 0.4
 }
 
-
 @dataclasses.dataclass
 class ICUToken(qmod.Token):
     """ Specialised token for ICU tokenizer.
@@ -78,13 +77,13 @@ class ICUToken(qmod.Token):
             self.penalty += (distance/len(self.lookup_word))
 
     @staticmethod
-    def from_db_row(row: SaRow, base_penalty: float = 0.0) -> 'ICUToken':
+    def from_db_row(row: SaRow) -> 'ICUToken':
        """ Create a ICUToken from the row of the word table.
        """
        count = 1 if row.info is None else row.info.get('count', 1)
        addr_count = 1 if row.info is None else row.info.get('addr_count', 1)
 
-        penalty = base_penalty
+        penalty = 0.0
        if row.type == 'w':
            penalty += 0.3
        elif row.type == 'W':
@@ -174,11 +173,14 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
         self.split_query(query)
         log().var_dump('Transliterated query', lambda: query.get_transliterated_query())
 
-        words = query.extract_words(base_penalty=PENALTY_IN_TOKEN_BREAK[qmod.BREAK_WORD])
+        words = query.extract_words()
 
         for row in await self.lookup_in_db(list(words.keys())):
             for trange in words[row.word_token]:
-                token = ICUToken.from_db_row(row, trange.penalty or 0.0)
+                # Create a new token for each position because the token
+                # penalty can vary depending on the position in the query.
+                # (See rerank_tokens() below.)
+                token = ICUToken.from_db_row(row)
                 if row.type == 'S':
                     if row.info['op'] in ('in', 'near'):
                         if trange.start == 0:
@@ -200,6 +202,7 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                                                      lookup_word=pc, word_token=term,
                                                      info=None))
 
         self.rerank_tokens(query)
+        self.compute_break_penalties(query)
 
         log().table_dump('Word tokens', _dump_word_tokens(query))
 
@@ -232,10 +235,9 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                     for term in trans.split(' '):
                         if term:
-                            query.add_node(qmod.BREAK_TOKEN, phrase.ptype, PENALTY_IN_TOKEN_BREAK[qmod.BREAK_TOKEN], term, word)
-                    query.nodes[-1].adjust_break(breakchar,
-                                                 PENALTY_IN_TOKEN_BREAK[breakchar])
+                            query.add_node(qmod.BREAK_TOKEN, phrase.ptype, term, word)
+                    query.nodes[-1].btype = breakchar
 
-        query.nodes[-1].adjust_break(qmod.BREAK_END, PENALTY_IN_TOKEN_BREAK[qmod.BREAK_END])
+        query.nodes[-1].btype = qmod.BREAK_END
 
     async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
         """ Return the token information from the database for the
@@ -300,6 +302,12 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                 for token in tokens:
                     cast(ICUToken, token).rematch(norm)
 
+    def compute_break_penalties(self, query: qmod.QueryStruct) -> None:
+        """ Set the break penalties for the nodes in the query.
+        """
+        for node in query.nodes:
+            node.penalty = PENALTY_BREAK[node.btype]
+
 
 def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
     yield ['type', 'from', 'to', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
diff --git a/src/nominatim_api/search/query.py b/src/nominatim_api/search/query.py
index 092bd586..e89ed2dc 100644
--- a/src/nominatim_api/search/query.py
+++ b/src/nominatim_api/search/query.py
@@ -191,7 +191,9 @@ class QueryNode:
     ptype: PhraseType
 
     penalty: float
-    """ Penalty for the break at this node.
+    """ Penalty for having a word break at this position. The penalty
+        may be negative when a word break is more likely than continuing
+        the word after the node.
     """
     term_lookup: str
     """ Transliterated term ending at this node.
@@ -221,12 +223,6 @@ class QueryNode:
 
         return self.partial.count / (self.partial.count + self.partial.addr_count)
 
-    def adjust_break(self, btype: BreakType, penalty: float) -> None:
-        """ Change the break type and penalty for this node.
-        """
-        self.btype = btype
-        self.penalty = penalty
-
     def has_tokens(self, end: int, *ttypes: TokenType) -> bool:
         """ Check if there are tokens of the given types ending at the
             given node.
@@ -277,8 +273,7 @@ class QueryStruct:
         self.source = source
         self.dir_penalty = 0.0
         self.nodes: List[QueryNode] = \
-            [QueryNode(BREAK_START, source[0].ptype if source else PHRASE_ANY,
-                       0.0, '', '')]
+            [QueryNode(BREAK_START, source[0].ptype if source else PHRASE_ANY)]
 
     def num_token_slots(self) -> int:
         """ Return the length of the query in vertice steps.
@@ -286,13 +281,12 @@ class QueryStruct:
         return len(self.nodes) - 1
 
     def add_node(self, btype: BreakType, ptype: PhraseType,
-                 break_penalty: float = 0.0,
                  term_lookup: str = '', term_normalized: str = '') -> None:
         """ Append a new break node with the given break type.
            The phrase type denotes the type for any tokens starting
            at the node.
        """
-        self.nodes.append(QueryNode(btype, ptype, break_penalty, term_lookup, term_normalized))
+        self.nodes.append(QueryNode(btype, ptype, 0.0, term_lookup, term_normalized))
 
     def add_token(self, trange: TokenRange, ttype: TokenType, token: Token) -> None:
         """ Add a token to the query. 'start' and 'end' are the indexes of the
@@ -386,17 +380,14 @@ class QueryStruct:
         """
         return ''.join(''.join((n.term_lookup, n.btype)) for n in self.nodes)
 
-    def extract_words(self, base_penalty: float = 0.0,
-                      start: int = 0,
+    def extract_words(self, start: int = 0,
                       endpos: Optional[int] = None) -> Dict[str, List[TokenRange]]:
         """ Add all combinations of words that can be formed from the terms
             between the given start and endnode. The terms are joined with
             spaces for each break. Words can never go across a BREAK_PHRASE.
 
             The functions returns a dictionary of possible words with their
-            position within the query and a penalty. The penalty is computed
-            from the base_penalty plus the penalty for each node the word
-            crosses.
+            position within the query.
""" if endpos is None: endpos = len(self.nodes) @@ -405,16 +396,13 @@ class QueryStruct: for first, first_node in enumerate(self.nodes[start + 1:endpos], start): word = first_node.term_lookup - penalty = base_penalty - words[word].append(TokenRange(first, first + 1, penalty=penalty)) + words[word].append(TokenRange(first, first + 1)) if first_node.btype != BREAK_PHRASE: - penalty += first_node.penalty max_last = min(first + 20, endpos) for last, last_node in enumerate(self.nodes[first + 2:max_last], first + 2): word = ' '.join((word, last_node.term_lookup)) - words[word].append(TokenRange(first, last, penalty=penalty)) + words[word].append(TokenRange(first, last)) if last_node.btype == BREAK_PHRASE: break - penalty += last_node.penalty return words diff --git a/src/nominatim_api/search/token_assignment.py b/src/nominatim_api/search/token_assignment.py index 4247158c..85c411b9 100644 --- a/src/nominatim_api/search/token_assignment.py +++ b/src/nominatim_api/search/token_assignment.py @@ -23,16 +23,6 @@ class TypedRange: trange: qmod.TokenRange -PENALTY_TOKENCHANGE = { - qmod.BREAK_START: 0.0, - qmod.BREAK_END: 0.0, - qmod.BREAK_PHRASE: 0.0, - qmod.BREAK_SOFT_PHRASE: 0.0, - qmod.BREAK_WORD: 0.1, - qmod.BREAK_PART: 0.2, - qmod.BREAK_TOKEN: 0.4 -} - TypedRangeSeq = List[TypedRange] -- 2.39.5