From 4634ad0720ce97973b48adbe21b55ce1e6b2c8a7 Mon Sep 17 00:00:00 2001 From: Sarah Hoffmann Date: Wed, 9 Jul 2025 20:35:15 +0200 Subject: [PATCH] rebalance word transition penalties --- src/nominatim_api/search/db_search_builder.py | 43 ++++++++----------- src/nominatim_api/search/icu_tokenizer.py | 5 +-- src/nominatim_api/search/query.py | 23 +++++++++- src/nominatim_api/search/token_assignment.py | 17 +++++--- 4 files changed, 52 insertions(+), 36 deletions(-) diff --git a/src/nominatim_api/search/db_search_builder.py b/src/nominatim_api/search/db_search_builder.py index 9cb263fd..34f6b6c2 100644 --- a/src/nominatim_api/search/db_search_builder.py +++ b/src/nominatim_api/search/db_search_builder.py @@ -282,10 +282,14 @@ class SearchBuilder: """ Create a ranking expression for a name term in the given range. """ name_fulls = self.query.get_tokens(trange, qmod.TOKEN_WORD) - ranks = [dbf.RankedTokens(t.penalty, [t.token]) for t in name_fulls] + full_word_penalty = self.query.get_in_word_penalty(trange) + ranks = [dbf.RankedTokens(t.penalty + full_word_penalty, [t.token]) + for t in name_fulls] ranks.sort(key=lambda r: r.penalty) # Fallback, sum of penalty for partials - default = sum(t.penalty for t in self.query.iter_partials(trange)) + 0.2 + default = sum(t.penalty for t in self.query.iter_partials(trange)) + default += sum(n.word_break_penalty + for n in self.query.nodes[trange.start + 1:trange.end]) return dbf.FieldRanking(db_field, default, ranks) def get_addr_ranking(self, trange: qmod.TokenRange) -> dbf.FieldRanking: @@ -303,7 +307,7 @@ class SearchBuilder: if partial is not None: if pos + 1 < trange.end: penalty = rank.penalty + partial.penalty \ - + PENALTY_WORDCHANGE[self.query.nodes[pos + 1].btype] + + self.query.nodes[pos + 1].word_break_penalty heapq.heappush(todo, (neglen - 1, pos + 1, dbf.RankedTokens(penalty, rank.tokens))) else: @@ -313,7 +317,9 @@ class SearchBuilder: for tlist in self.query.nodes[pos].starting: if tlist.ttype == qmod.TOKEN_WORD: 
if tlist.end < trange.end: - chgpenalty = PENALTY_WORDCHANGE[self.query.nodes[tlist.end].btype] + chgpenalty = self.query.nodes[tlist.end].word_break_penalty \ + + self.query.get_in_word_penalty( + qmod.TokenRange(pos, tlist.end)) for t in tlist.tokens: heapq.heappush(todo, (neglen - 1, tlist.end, rank.with_token(t, chgpenalty))) @@ -323,7 +329,9 @@ class SearchBuilder: if len(ranks) >= 10: # Too many variants, bail out and only add # Worst-case Fallback: sum of penalty of partials - default = sum(t.penalty for t in self.query.iter_partials(trange)) + 0.2 + default = sum(t.penalty for t in self.query.iter_partials(trange)) + default += sum(n.word_break_penalty + for n in self.query.nodes[trange.start + 1:trange.end]) ranks.append(dbf.RankedTokens(rank.penalty + default, [])) # Bail out of outer loop break @@ -346,6 +354,7 @@ class SearchBuilder: if not tokens: return None sdata.set_strings('countries', tokens) + sdata.penalty += self.query.get_in_word_penalty(assignment.country) elif self.details.countries: sdata.countries = dbf.WeightedStrings(self.details.countries, [0.0] * len(self.details.countries)) @@ -353,29 +362,24 @@ class SearchBuilder: sdata.set_strings('housenumbers', self.query.get_tokens(assignment.housenumber, qmod.TOKEN_HOUSENUMBER)) + sdata.penalty += self.query.get_in_word_penalty(assignment.housenumber) if assignment.postcode: sdata.set_strings('postcodes', self.query.get_tokens(assignment.postcode, qmod.TOKEN_POSTCODE)) + sdata.penalty += self.query.get_in_word_penalty(assignment.postcode) if assignment.qualifier: tokens = self.get_qualifier_tokens(assignment.qualifier) if not tokens: return None sdata.set_qualifiers(tokens) + sdata.penalty += self.query.get_in_word_penalty(assignment.qualifier) elif self.details.categories: sdata.qualifiers = dbf.WeightedCategories(self.details.categories, [0.0] * len(self.details.categories)) if assignment.address: - if not assignment.name and assignment.housenumber: - # housenumber search: the first item 
needs to be handled like - # a name in ranking or penalties are not comparable with - # normal searches. - sdata.set_ranking([self.get_name_ranking(assignment.address[0], - db_field='nameaddress_vector')] - + [self.get_addr_ranking(r) for r in assignment.address[1:]]) - else: - sdata.set_ranking([self.get_addr_ranking(r) for r in assignment.address]) + sdata.set_ranking([self.get_addr_ranking(r) for r in assignment.address]) else: sdata.rankings = [] @@ -421,14 +425,3 @@ class SearchBuilder: return dbf.WeightedCategories(list(tokens.keys()), list(tokens.values())) return None - - -PENALTY_WORDCHANGE = { - qmod.BREAK_START: 0.0, - qmod.BREAK_END: 0.0, - qmod.BREAK_PHRASE: 0.0, - qmod.BREAK_SOFT_PHRASE: 0.0, - qmod.BREAK_WORD: 0.1, - qmod.BREAK_PART: 0.2, - qmod.BREAK_TOKEN: 0.4 -} diff --git a/src/nominatim_api/search/icu_tokenizer.py b/src/nominatim_api/search/icu_tokenizer.py index 15a5e2ab..2bb9ce93 100644 --- a/src/nominatim_api/search/icu_tokenizer.py +++ b/src/nominatim_api/search/icu_tokenizer.py @@ -47,6 +47,7 @@ PENALTY_BREAK = { qmod.BREAK_TOKEN: 0.4 } + @dataclasses.dataclass class ICUToken(qmod.Token): """ Specialised token for ICU tokenizer. @@ -232,9 +233,7 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer): if trans: for term in trans.split(' '): if term: - query.add_node(qmod.BREAK_TOKEN, phrase.ptype, - PENALTY_IN_TOKEN_BREAK[qmod.BREAK_TOKEN], - term, word) + query.add_node(qmod.BREAK_TOKEN, phrase.ptype, term, word) query.nodes[-1].btype = breakchar query.nodes[-1].btype = qmod.BREAK_END diff --git a/src/nominatim_api/search/query.py b/src/nominatim_api/search/query.py index e89ed2dc..3ce9db21 100644 --- a/src/nominatim_api/search/query.py +++ b/src/nominatim_api/search/query.py @@ -214,6 +214,19 @@ class QueryNode: types of tokens spanning over the gap. """ + @property + def word_break_penalty(self) -> float: + """ Penalty to apply when a word ends at this node. 
+ """ + return max(0, self.penalty) + + @property + def word_continuation_penalty(self) -> float: + """ Penalty to apply when a word continues over this node + (i.e. is a multi-term word). + """ + return max(0, -self.penalty) + def name_address_ratio(self) -> float: """ Return the propability that the partial token belonging to this node forms part of a name (as opposed of part of the address). @@ -273,7 +286,8 @@ class QueryStruct: self.source = source self.dir_penalty = 0.0 self.nodes: List[QueryNode] = \ - [QueryNode(BREAK_START, source[0].ptype if source else PHRASE_ANY)] + [QueryNode(BREAK_START, source[0].ptype if source else PHRASE_ANY, + 0.0, '', '')] def num_token_slots(self) -> int: """ Return the length of the query in vertice steps. @@ -338,6 +352,13 @@ class QueryStruct: assert ttype != TOKEN_PARTIAL return self.nodes[trange.start].get_tokens(trange.end, ttype) or [] + def get_in_word_penalty(self, trange: TokenRange) -> float: + """ Gets the sum of penalties for all token transitions + within the given range. + """ + return sum(n.word_continuation_penalty + for n in self.nodes[trange.start + 1:trange.end]) + def iter_partials(self, trange: TokenRange) -> Iterator[Token]: """ Iterate over the partial tokens between the given nodes. Missing partials are ignored. diff --git a/src/nominatim_api/search/token_assignment.py b/src/nominatim_api/search/token_assignment.py index 85c411b9..798ee546 100644 --- a/src/nominatim_api/search/token_assignment.py +++ b/src/nominatim_api/search/token_assignment.py @@ -182,7 +182,7 @@ class _TokenSequence: return None def advance(self, ttype: qmod.TokenType, end_pos: int, - btype: qmod.BreakType) -> Optional['_TokenSequence']: + force_break: bool, break_penalty: float) -> Optional['_TokenSequence']: """ Return a new token sequence state with the given token type extended. 
""" @@ -195,7 +195,7 @@ class _TokenSequence: new_penalty = 0.0 else: last = self.seq[-1] - if btype != qmod.BREAK_PHRASE and last.ttype == ttype: + if not force_break and last.ttype == ttype: # extend the existing range newseq = self.seq[:-1] + [TypedRange(ttype, last.trange.replace_end(end_pos))] new_penalty = 0.0 @@ -203,7 +203,7 @@ class _TokenSequence: # start a new range newseq = list(self.seq) + [TypedRange(ttype, qmod.TokenRange(last.trange.end, end_pos))] - new_penalty = PENALTY_TOKENCHANGE[btype] + new_penalty = break_penalty return _TokenSequence(newseq, newdir, self.penalty + new_penalty) @@ -307,7 +307,7 @@ class _TokenSequence: name, addr = first.split(i) log().comment(f'split first word = name ({i - first.start})') yield dataclasses.replace(base, name=name, address=[addr] + base.address[1:], - penalty=penalty + PENALTY_TOKENCHANGE[query.nodes[i].btype]) + penalty=penalty + query.nodes[i].word_break_penalty) def _get_assignments_address_backward(self, base: TokenAssignment, query: qmod.QueryStruct) -> Iterator[TokenAssignment]: @@ -352,7 +352,7 @@ class _TokenSequence: addr, name = last.split(i) log().comment(f'split last word = name ({i - last.start})') yield dataclasses.replace(base, name=name, address=base.address[:-1] + [addr], - penalty=penalty + PENALTY_TOKENCHANGE[query.nodes[i].btype]) + penalty=penalty + query.nodes[i].word_break_penalty) def get_assignments(self, query: qmod.QueryStruct) -> Iterator[TokenAssignment]: """ Yield possible assignments for the current sequence. 
@@ -412,12 +412,15 @@ def yield_token_assignments(query: qmod.QueryStruct) -> Iterator[TokenAssignment for tlist in node.starting: yield from _append_state_to_todo( query, todo, - state.advance(tlist.ttype, tlist.end, node.btype)) + state.advance(tlist.ttype, tlist.end, + True, node.word_break_penalty)) if node.partial is not None: yield from _append_state_to_todo( query, todo, - state.advance(qmod.TOKEN_PARTIAL, state.end_pos + 1, node.btype)) + state.advance(qmod.TOKEN_PARTIAL, state.end_pos + 1, + node.btype == qmod.BREAK_PHRASE, + node.word_break_penalty)) def _append_state_to_todo(query: qmod.QueryStruct, todo: List[_TokenSequence], -- 2.39.5