1 # SPDX-License-Identifier: GPL-3.0-or-later
 
   3 # This file is part of Nominatim. (https://nominatim.org)
 
   5 # Copyright (C) 2024 by the Nominatim developer community.
 
   6 # For a full list of authors see the git log.
 
   8 Implementation of query analysis for the ICU tokenizer.
 
  10 from typing import Tuple, Dict, List, Optional, NamedTuple, Iterator, Any, cast
 
  11 from collections import defaultdict
 
  15 from icu import Transliterator
 
  17 import sqlalchemy as sa
 
  19 from ..errors import UsageError
 
  20 from ..typing import SaRow
 
  21 from ..sql.sqlalchemy_types import Json
 
  22 from ..connection import SearchConnection
 
  23 from ..logging import log
 
  24 from . import query as qmod
 
  25 from ..query_preprocessing.config import QueryConfig
 
  26 from .query_analyzer_factory import AbstractQueryAnalyzer
 
  30     'W': qmod.TokenType.WORD,
 
  31     'w': qmod.TokenType.PARTIAL,
 
  32     'H': qmod.TokenType.HOUSENUMBER,
 
  33     'P': qmod.TokenType.POSTCODE,
 
  34     'C': qmod.TokenType.COUNTRY
 
  38 class QueryPart(NamedTuple):
 
  39     """ Normalized and transliterated form of a single term in the query.
 
  40         When the term came out of a split during the transliteration,
 
  41         the normalized string is the full word before transliteration.
 
  42         The word number keeps track of the word before transliteration
 
  43         and can be used to identify partial transliterated terms.
 
  50 QueryParts = List[QueryPart]
 
  51 WordDict = Dict[str, List[qmod.TokenRange]]
 
  54 def yield_words(terms: List[QueryPart], start: int) -> Iterator[Tuple[str, qmod.TokenRange]]:
 
  55     """ Return all combinations of words in the terms list after the
 
  59     for first in range(start, total):
 
  60         word = terms[first].token
 
  61         yield word, qmod.TokenRange(first, first + 1)
 
  62         for last in range(first + 1, min(first + 20, total)):
 
  63             word = ' '.join((word, terms[last].token))
 
  64             yield word, qmod.TokenRange(first, last + 1)
 
  67 @dataclasses.dataclass
 
  68 class ICUToken(qmod.Token):
 
  69     """ Specialised token for ICU tokenizer.
 
  72     info: Optional[Dict[str, Any]]
 
  74     def get_category(self) -> Tuple[str, str]:
 
  76         return self.info.get('class', ''), self.info.get('type', '')
 
  78     def rematch(self, norm: str) -> None:
 
  79         """ Check how well the token matches the given normalized string
 
  80             and add a penalty, if necessary.
 
  82         if not self.lookup_word:
 
  85         seq = difflib.SequenceMatcher(a=self.lookup_word, b=norm)
 
  87         for tag, afrom, ato, bfrom, bto in seq.get_opcodes():
 
  88             if tag in ('delete', 'insert') and (afrom == 0 or ato == len(self.lookup_word)):
 
  90             elif tag == 'replace':
 
  91                 distance += max((ato-afrom), (bto-bfrom))
 
  93                 distance += abs((ato-afrom) - (bto-bfrom))
 
  94         self.penalty += (distance/len(self.lookup_word))
 
  97     def from_db_row(row: SaRow) -> 'ICUToken':
 
  98         """ Create a ICUToken from the row of the word table.
 
 100         count = 1 if row.info is None else row.info.get('count', 1)
 
 101         addr_count = 1 if row.info is None else row.info.get('addr_count', 1)
 
 106         elif row.type == 'W':
 
 107             if len(row.word_token) == 1 and row.word_token == row.word:
 
 108                 penalty = 0.2 if row.word.isdigit() else 0.3
 
 109         elif row.type == 'H':
 
 110             penalty = sum(0.1 for c in row.word_token if c != ' ' and not c.isdigit())
 
 111             if all(not c.isdigit() for c in row.word_token):
 
 112                 penalty += 0.2 * (len(row.word_token) - 1)
 
 113         elif row.type == 'C':
 
 114             if len(row.word_token) == 1:
 
 118             lookup_word = row.word
 
 120             lookup_word = row.info.get('lookup', row.word)
 
 122             lookup_word = lookup_word.split('@', 1)[0]
 
 124             lookup_word = row.word_token
 
 126         return ICUToken(penalty=penalty, token=row.word_id, count=max(1, count),
 
 127                         lookup_word=lookup_word,
 
 128                         word_token=row.word_token, info=row.info,
 
 129                         addr_count=max(1, addr_count))
 
 132 class ICUQueryAnalyzer(AbstractQueryAnalyzer):
 
 133     """ Converter for query strings into a tokenized query
 
 134         using the tokens created by a ICU tokenizer.
 
 136     def __init__(self, conn: SearchConnection) -> None:
 
 139     async def setup(self) -> None:
 
 140         """ Set up static data structures needed for the analysis.
 
 142         async def _make_normalizer() -> Any:
 
 143             rules = await self.conn.get_property('tokenizer_import_normalisation')
 
 144             return Transliterator.createFromRules("normalization", rules)
 
 146         self.normalizer = await self.conn.get_cached_value('ICUTOK', 'normalizer',
 
 149         async def _make_transliterator() -> Any:
 
 150             rules = await self.conn.get_property('tokenizer_import_transliteration')
 
 151             return Transliterator.createFromRules("transliteration", rules)
 
 153         self.transliterator = await self.conn.get_cached_value('ICUTOK', 'transliterator',
 
 154                                                                _make_transliterator)
 
 156         await self._setup_preprocessing()
 
 158         if 'word' not in self.conn.t.meta.tables:
 
 159             sa.Table('word', self.conn.t.meta,
 
 160                      sa.Column('word_id', sa.Integer),
 
 161                      sa.Column('word_token', sa.Text, nullable=False),
 
 162                      sa.Column('type', sa.Text, nullable=False),
 
 163                      sa.Column('word', sa.Text),
 
 164                      sa.Column('info', Json))
 
 166     async def _setup_preprocessing(self) -> None:
 
 167         """ Load the rules for preprocessing and set up the handlers.
 
 170         rules = self.conn.config.load_sub_configuration('icu_tokenizer.yaml',
 
 171                                                         config='TOKENIZER_CONFIG')
 
 172         preprocessing_rules = rules.get('query-preprocessing', [])
 
 174         self.preprocessors = []
 
 176         for func in preprocessing_rules:
 
 177             if 'step' not in func:
 
 178                 raise UsageError("Preprocessing rule is missing the 'step' attribute.")
 
 179             if not isinstance(func['step'], str):
 
 180                 raise UsageError("'step' attribute must be a simple string.")
 
 182             module = self.conn.config.load_plugin_module(
 
 183                         func['step'], 'nominatim_api.query_preprocessing')
 
 184             self.preprocessors.append(
 
 185                 module.create(QueryConfig(func).set_normalizer(self.normalizer)))
 
 187     async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
 
 188         """ Analyze the given list of phrases and return the
 
 191         log().section('Analyze query (using ICU tokenizer)')
 
 192         for func in self.preprocessors:
 
 193             phrases = func(phrases)
 
 194         query = qmod.QueryStruct(phrases)
 
 196         log().var_dump('Normalized query', query.source)
 
 200         parts, words = self.split_query(query)
 
 201         log().var_dump('Transliterated query', lambda: _dump_transliterated(query, parts))
 
 203         for row in await self.lookup_in_db(list(words.keys())):
 
 204             for trange in words[row.word_token]:
 
 205                 token = ICUToken.from_db_row(row)
 
 207                     if row.info['op'] in ('in', 'near'):
 
 208                         if trange.start == 0:
 
 209                             query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
 
 211                         if trange.start == 0 and trange.end == query.num_token_slots():
 
 212                             query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
 
 214                             query.add_token(trange, qmod.TokenType.QUALIFIER, token)
 
 216                     query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)
 
 218         self.add_extra_tokens(query, parts)
 
 219         self.rerank_tokens(query, parts)
 
 221         log().table_dump('Word tokens', _dump_word_tokens(query))
 
 225     def normalize_text(self, text: str) -> str:
 
 226         """ Bring the given text into a normalized form. That is the
 
 227             standardized form search will work with. All information removed
 
 228             at this stage is inevitably lost.
 
 230         return cast(str, self.normalizer.transliterate(text))
 
 232     def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
 
 233         """ Transliterate the phrases and split them into tokens.
 
 235             Returns the list of transliterated tokens together with their
 
 236             normalized form and a dictionary of words for lookup together
 
 239         parts: QueryParts = []
 
 241         words = defaultdict(list)
 
 243         for phrase in query.source:
 
 244             query.nodes[-1].ptype = phrase.ptype
 
 245             for word in phrase.text.split(' '):
 
 246                 trans = self.transliterator.transliterate(word)
 
 248                     for term in trans.split(' '):
 
 250                             parts.append(QueryPart(term, word, wordnr))
 
 251                             query.add_node(qmod.BreakType.TOKEN, phrase.ptype)
 
 252                     query.nodes[-1].btype = qmod.BreakType.WORD
 
 254             query.nodes[-1].btype = qmod.BreakType.PHRASE
 
 256             for word, wrange in yield_words(parts, phrase_start):
 
 257                 words[word].append(wrange)
 
 259             phrase_start = len(parts)
 
 260         query.nodes[-1].btype = qmod.BreakType.END
 
 264     async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
 
 265         """ Return the token information from the database for the
 
 268         t = self.conn.t.meta.tables['word']
 
 269         return await self.conn.execute(t.select().where(t.c.word_token.in_(words)))
 
 271     def add_extra_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
 
 272         """ Add tokens to query that are not saved in the database.
 
 274         for part, node, i in zip(parts, query.nodes, range(1000)):
 
 275             if len(part.token) <= 4 and part[0].isdigit()\
 
 276                and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
 
 277                 query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
 
 278                                 ICUToken(penalty=0.5, token=0,
 
 279                                          count=1, addr_count=1, lookup_word=part.token,
 
 280                                          word_token=part.token, info=None))
 
 282     def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
 
 283         """ Add penalties to tokens that depend on presence of other token.
 
 285         for i, node, tlist in query.iter_token_lists():
 
 286             if tlist.ttype == qmod.TokenType.POSTCODE:
 
 287                 for repl in node.starting:
 
 288                     if repl.end == tlist.end and repl.ttype != qmod.TokenType.POSTCODE \
 
 289                        and (repl.ttype != qmod.TokenType.HOUSENUMBER
 
 290                             or len(tlist.tokens[0].lookup_word) > 4):
 
 291                         repl.add_penalty(0.39)
 
 292             elif (tlist.ttype == qmod.TokenType.HOUSENUMBER
 
 293                   and len(tlist.tokens[0].lookup_word) <= 3):
 
 294                 if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
 
 295                     for repl in node.starting:
 
 296                         if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER:
 
 297                             repl.add_penalty(0.5 - tlist.tokens[0].penalty)
 
 298             elif tlist.ttype not in (qmod.TokenType.COUNTRY, qmod.TokenType.PARTIAL):
 
 299                 norm = parts[i].normalized
 
 300                 for j in range(i + 1, tlist.end):
 
 301                     if parts[j - 1].word_number != parts[j].word_number:
 
 302                         norm += '  ' + parts[j].normalized
 
 303                 for token in tlist.tokens:
 
 304                     cast(ICUToken, token).rematch(norm)
 
 307 def _dump_transliterated(query: qmod.QueryStruct, parts: QueryParts) -> str:
 
 308     out = query.nodes[0].btype.value
 
 309     for node, part in zip(query.nodes[1:], parts):
 
 310         out += part.token + node.btype.value
 
 314 def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
 
 315     yield ['type', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
 
 316     for node in query.nodes:
 
 317         for tlist in node.starting:
 
 318             for token in tlist.tokens:
 
 319                 t = cast(ICUToken, token)
 
 320                 yield [tlist.ttype.name, t.token, t.word_token or '',
 
 321                        t.lookup_word or '', t.penalty, t.count, t.info]
 
 324 async def create_query_analyzer(conn: SearchConnection) -> AbstractQueryAnalyzer:
 
 325     """ Create and set up a new query analyzer for a database based
 
 326         on the ICU tokenizer.
 
 328     out = ICUQueryAnalyzer(conn)