From 2ef0e20a3f90e92d2aa926623fe012c8a708271a Mon Sep 17 00:00:00 2001
From: Sarah Hoffmann
Date: Fri, 11 Apr 2025 13:38:34 +0200
Subject: [PATCH] reorganise token reranking

As the reranking is about changing penalties in the presence of other
tokens, change the data structure to have the other tokens readily
available.
---
 src/nominatim_api/search/icu_tokenizer.py | 53 ++++++++++++++---------
 src/nominatim_api/search/query.py         | 16 ++++---
 2 files changed, 43 insertions(+), 26 deletions(-)

diff --git a/src/nominatim_api/search/icu_tokenizer.py b/src/nominatim_api/search/icu_tokenizer.py
index 85b0a9f0..35171344 100644
--- a/src/nominatim_api/search/icu_tokenizer.py
+++ b/src/nominatim_api/search/icu_tokenizer.py
@@ -267,27 +267,38 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
     def rerank_tokens(self, query: qmod.QueryStruct) -> None:
         """ Add penalties to tokens that depend on presence of other token.
         """
-        for i, node, tlist in query.iter_token_lists():
-            if tlist.ttype == qmod.TOKEN_POSTCODE:
-                tlen = len(cast(ICUToken, tlist.tokens[0]).word_token)
-                for repl in node.starting:
-                    if repl.end == tlist.end and repl.ttype != qmod.TOKEN_POSTCODE \
-                       and (repl.ttype != qmod.TOKEN_HOUSENUMBER or tlen > 4):
-                        repl.add_penalty(0.39)
-            elif (tlist.ttype == qmod.TOKEN_HOUSENUMBER
-                  and len(tlist.tokens[0].lookup_word) <= 3):
-                if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
-                    for repl in node.starting:
-                        if repl.end == tlist.end and repl.ttype != qmod.TOKEN_HOUSENUMBER:
-                            repl.add_penalty(0.5 - tlist.tokens[0].penalty)
-            elif tlist.ttype != qmod.TOKEN_COUNTRY:
-                norm = ' '.join(n.term_normalized for n in query.nodes[i + 1:tlist.end + 1]
-                                if n.btype != qmod.BREAK_TOKEN)
-                if not norm:
-                    # Can happen when the token only covers a partial term
-                    norm = query.nodes[i + 1].term_normalized
-                for token in tlist.tokens:
-                    cast(ICUToken, token).rematch(norm)
+        for start, end, tlist in query.iter_tokens_by_edge():
+            if len(tlist) > 1:
+                # If it looks like a postcode, give preference.
+                if qmod.TOKEN_POSTCODE in tlist:
+                    for ttype, tokens in tlist.items():
+                        if ttype != qmod.TOKEN_POSTCODE and \
+                           (ttype != qmod.TOKEN_HOUSENUMBER or
+                            start + 1 > end or
+                            len(query.nodes[end].term_lookup) > 4):
+                            for token in tokens:
+                                token.penalty += 0.39
+
+                # If it looks like a simple housenumber, prefer that.
+                if qmod.TOKEN_HOUSENUMBER in tlist:
+                    hnr_lookup = tlist[qmod.TOKEN_HOUSENUMBER][0].lookup_word
+                    if len(hnr_lookup) <= 3 and any(c.isdigit() for c in hnr_lookup):
+                        penalty = 0.5 - tlist[qmod.TOKEN_HOUSENUMBER][0].penalty
+                        for ttype, tokens in tlist.items():
+                            if ttype != qmod.TOKEN_HOUSENUMBER:
+                                for token in tokens:
+                                    token.penalty += penalty
+
+            # rerank tokens against the normalized form
+            norm = ' '.join(n.term_normalized for n in query.nodes[start + 1:end + 1]
+                            if n.btype != qmod.BREAK_TOKEN)
+            if not norm:
+                # Can happen when the token only covers a partial term
+                norm = query.nodes[start + 1].term_normalized
+            for ttype, tokens in tlist.items():
+                if ttype != qmod.TOKEN_COUNTRY:
+                    for token in tokens:
+                        cast(ICUToken, token).rematch(norm)
 
 
 def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
diff --git a/src/nominatim_api/search/query.py b/src/nominatim_api/search/query.py
index 45bab88e..652c3986 100644
--- a/src/nominatim_api/search/query.py
+++ b/src/nominatim_api/search/query.py
@@ -183,10 +183,10 @@ class QueryNode:
     """ Penalty for the break at this node.
     """
     term_lookup: str
-    """ Transliterated term following this node.
+    """ Transliterated term ending at this node.
     """
     term_normalized: str
-    """ Normalised form of term following this node.
+    """ Normalised form of term ending at this node.
         When the token resulted from a split during transliteration,
         then this string contains the complete source term.
     """
@@ -307,12 +307,18 @@ class QueryStruct:
         """
         return (n.partial for n in self.nodes[trange.start:trange.end]
                 if n.partial is not None)
-    def iter_token_lists(self) -> Iterator[Tuple[int, QueryNode, TokenList]]:
-        """ Iterator over all token lists except partial tokens in the query.
+    def iter_tokens_by_edge(self) -> Iterator[Tuple[int, int, Dict[TokenType, List[Token]]]]:
+        """ Iterator over all tokens except partial ones, grouped by edge.
+
+            Returns the start and end node indexes and a dictionary
+            of lists of tokens by token type.
         """
         for i, node in enumerate(self.nodes):
+            by_end: Dict[int, Dict[TokenType, List[Token]]] = defaultdict(dict)
             for tlist in node.starting:
-                yield i, node, tlist
+                by_end[tlist.end][tlist.ttype] = tlist.tokens
+            for end, endlist in by_end.items():
+                yield i, end, endlist
 
     def find_lookup_word_by_id(self, token: int) -> str:
         """ Find the first token with the given token ID and return
--
2.39.5