]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/api/search/db_search_builder.py
Merge remote-tracking branch 'upstream/master'
[nominatim.git] / nominatim / api / search / db_search_builder.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Conversion from token assignment to an abstract DB search.
9 """
10 from typing import Optional, List, Tuple, Iterator, Dict
11 import heapq
12
13 from nominatim.api.types import SearchDetails, DataLayer
14 from nominatim.api.search.query import QueryStruct, Token, TokenType, TokenRange, BreakType
15 from nominatim.api.search.token_assignment import TokenAssignment
16 import nominatim.api.search.db_search_fields as dbf
17 import nominatim.api.search.db_searches as dbs
18
19
def wrap_near_search(categories: List[Tuple[str, str]],
                     search: dbs.AbstractSearch) -> dbs.NearSearch:
    """ Wrap the given search into a near-place search over the given
        categories.

        The inner search's penalty is lifted onto the wrapping search;
        every category participates with an equal zero weight.
    """
    weighted_cats = dbf.WeightedCategories(categories, [0.0] * len(categories))
    return dbs.NearSearch(penalty=search.penalty,
                          categories=weighted_cats,
                          search=search)
29
30
def build_poi_search(category: List[Tuple[str, str]],
                     countries: Optional[List[str]]) -> dbs.PoiSearch:
    """ Create a new search for places of the given category, optionally
        restricted to the given countries.
    """
    # An empty or missing country list yields an empty (unrestricted) filter.
    country_list = countries or []
    ccs = dbf.WeightedStrings(country_list, [0.0] * len(country_list))

    class _PoiData(dbf.SearchData):
        # Neutral weights: the category alone drives this search.
        penalty = 0.0
        qualifiers = dbf.WeightedCategories(category, [0.0] * len(category))
        countries = ccs

    return dbs.PoiSearch(_PoiData())
47
48
class SearchBuilder:
    """ Build the abstract search queries from token assignments.

        Each token assignment is one interpretation of the incoming query.
        A single assignment may still be executable as several different
        kinds of database searches, so build() yields every candidate,
        each carrying its own penalty.
    """

    def __init__(self, query: QueryStruct, details: SearchDetails) -> None:
        self.query = query      # the parsed and tokenized search query
        self.details = details  # request-wide parameters (filters, viewbox, ...)


    @property
    def configured_for_country(self) -> bool:
        """ Return true if the search details are configured to
            allow countries in the result.
        """
        # Countries are address rank 4 objects.
        return self.details.min_rank <= 4 and self.details.max_rank >= 4 \
               and self.details.layer_enabled(DataLayer.ADDRESS)


    @property
    def configured_for_postcode(self) -> bool:
        """ Return true if the search details are configured to
            allow postcodes in the result.
        """
        # Postcode objects fall into the address rank range 5-11.
        return self.details.min_rank <= 5 and self.details.max_rank >= 11\
               and self.details.layer_enabled(DataLayer.ADDRESS)


    @property
    def configured_for_housenumbers(self) -> bool:
        """ Return true if the search details are configured to
            allow addresses in the result.
        """
        # Objects with house numbers have address rank 30.
        return self.details.max_rank >= 30 \
               and self.details.layer_enabled(DataLayer.ADDRESS)


    def build(self, assignment: TokenAssignment) -> Iterator[dbs.AbstractSearch]:
        """ Yield all possible abstract searches for the given token assignment.
        """
        sdata = self.get_search_data(assignment)
        if sdata is None:
            return

        near_items = self.get_near_items(assignment)
        if near_items is not None and not near_items:
            return # impossible combination of near items and category parameter

        if assignment.name is None:
            if near_items and not sdata.postcodes:
                # Pure category query: run it as a qualifier search
                # instead of wrapping everything into a near search.
                sdata.qualifiers = near_items
                near_items = None
                builder = self.build_poi_search(sdata)
            elif assignment.housenumber:
                hnr_tokens = self.query.get_tokens(assignment.housenumber,
                                                   TokenType.HOUSENUMBER)
                builder = self.build_housenumber_search(sdata, hnr_tokens, assignment.address)
            else:
                builder = self.build_special_search(sdata, assignment.address,
                                                    bool(near_items))
        else:
            builder = self.build_name_search(sdata, assignment.name, assignment.address,
                                             bool(near_items))

        if near_items:
            # Wrap each inner search into a near search. The minimum
            # near-item penalty is moved onto the outer search so the
            # relative weighting between the categories is preserved.
            penalty = min(near_items.penalties)
            near_items.penalties = [p - penalty for p in near_items.penalties]
            for search in builder:
                search_penalty = search.penalty
                search.penalty = 0.0
                yield dbs.NearSearch(penalty + assignment.penalty + search_penalty,
                                     near_items, search)
        else:
            for search in builder:
                search.penalty += assignment.penalty
                yield search


    def build_poi_search(self, sdata: dbf.SearchData) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search query for a simple category search.
            This kind of search requires an additional geographic constraint.
        """
        # Only emitted when a bounded viewbox or a near point restricts
        # the area that needs to be scanned.
        if not sdata.housenumbers \
           and ((self.details.viewbox and self.details.bounded_viewbox) or self.details.near):
            yield dbs.PoiSearch(sdata)


    def build_special_search(self, sdata: dbf.SearchData,
                             address: List[TokenRange],
                             is_category: bool) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search queries for searches that do not involve
            a named place.
        """
        if sdata.qualifiers:
            # No special searches over qualifiers supported.
            return

        if sdata.countries and not address and not sdata.postcodes \
           and self.configured_for_country:
            yield dbs.CountrySearch(sdata)

        if sdata.postcodes and (is_category or self.configured_for_postcode):
            # Slightly prefer postcode searches with a country restriction.
            penalty = 0.0 if sdata.countries else 0.1
            if address:
                # Filter the postcode via the partial terms of the address.
                sdata.lookups = [dbf.FieldLookup('nameaddress_vector',
                                                 [t.token for r in address
                                                  for t in self.query.get_partials_list(r)],
                                                 'restrict')]
                penalty += 0.2
            yield dbs.PostcodeSearch(penalty, sdata)


    def build_housenumber_search(self, sdata: dbf.SearchData, hnrs: List[Token],
                                 address: List[TokenRange]) -> Iterator[dbs.AbstractSearch]:
        """ Build a simple address search for special entries where the
            housenumber is the main name token.
        """
        # The housenumber tokens take the role of the name.
        sdata.lookups = [dbf.FieldLookup('name_vector', [t.token for t in hnrs], 'lookup_any')]
        expected_count = sum(t.count for t in hnrs)

        partials = [t for trange in address
                       for t in self.query.get_partials_list(trange)]

        if expected_count < 8000:
            # Few expected hits for the housenumber itself: the address
            # terms only need to filter the result.
            sdata.lookups.append(dbf.FieldLookup('nameaddress_vector',
                                                 [t.token for t in partials], 'restrict'))
        elif len(partials) != 1 or partials[0].count < 10000:
            # Frequent housenumber: look up over all address partials.
            sdata.lookups.append(dbf.FieldLookup('nameaddress_vector',
                                                 [t.token for t in partials], 'lookup_all'))
        else:
            # Housenumber and the single address partial are both very
            # frequent: fall back to the full-word tokens of the address.
            sdata.lookups.append(
                dbf.FieldLookup('nameaddress_vector',
                                [t.token for t
                                 in self.query.get_tokens(address[0], TokenType.WORD)],
                                'lookup_any'))

        # The housenumbers serve as the name here, so they must not be
        # applied as an additional housenumber filter as well.
        sdata.housenumbers = dbf.WeightedStrings([], [])
        yield dbs.PlaceSearch(0.05, sdata, expected_count)


    def build_name_search(self, sdata: dbf.SearchData,
                          name: TokenRange, address: List[TokenRange],
                          is_category: bool) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search queries for simple name or address searches.
        """
        if is_category or not sdata.housenumbers or self.configured_for_housenumbers:
            ranking = self.get_name_ranking(name)
            name_penalty = ranking.normalize_penalty()
            if ranking.rankings:
                sdata.rankings.append(ranking)
            # One PlaceSearch per lookup strategy; each reuses sdata with
            # only the lookups swapped out.
            for penalty, count, lookup in self.yield_lookups(name, address):
                sdata.lookups = lookup
                yield dbs.PlaceSearch(penalty + name_penalty, sdata, count)


    def yield_lookups(self, name: TokenRange, address: List[TokenRange])\
                          -> Iterator[Tuple[float, int, List[dbf.FieldLookup]]]:
        """ Yield all variants how the given name and address should best
            be searched for. This takes into account how frequent the terms
            are and tries to find a lookup that optimizes index use.
        """
        penalty = 0.0 # extra penalty
        name_partials = self.query.get_partials_list(name)
        name_tokens = [t.token for t in name_partials]

        addr_partials = [t for r in address for t in self.query.get_partials_list(r)]
        addr_tokens = [t.token for t in addr_partials]

        partials_indexed = all(t.is_indexed for t in name_partials) \
                           and all(t.is_indexed for t in addr_partials)
        # Result estimate: assume every additional partial roughly halves
        # the hit count of the rarest one.
        exp_count = min(t.count for t in name_partials) / (2**(len(name_partials) - 1))

        if (len(name_partials) > 3 or exp_count < 8000) and partials_indexed:
            yield penalty, exp_count, dbf.lookup_by_names(name_tokens, addr_tokens)
            return

        # Partial term too frequent. Try looking up by rare full names first.
        name_fulls = self.query.get_tokens(name, TokenType.WORD)
        fulls_count = sum(t.count for t in name_fulls)
        # At this point drop unindexed partials from the address.
        # This might yield wrong results, nothing we can do about that.
        if not partials_indexed:
            addr_tokens = [t.token for t in addr_partials if t.is_indexed]
            penalty += 1.2 * sum(t.penalty for t in addr_partials if not t.is_indexed)
        # Any of the full names applies with all of the partials from the address
        yield penalty, fulls_count / (2**len(addr_partials)),\
              dbf.lookup_by_any_name([t.token for t in name_fulls], addr_tokens,
                                     'restrict' if fulls_count < 10000 else 'lookup_all')

        # To catch remaining results, lookup by name and address
        # We only do this if there is a reasonable number of results expected.
        exp_count = exp_count / (2**len(addr_partials)) if addr_partials else exp_count
        if exp_count < 10000 and all(t.is_indexed for t in name_partials):
            lookup = [dbf.FieldLookup('name_vector', name_tokens, 'lookup_all')]
            if addr_tokens:
                lookup.append(dbf.FieldLookup('nameaddress_vector', addr_tokens, 'lookup_all'))
            # Penalise lookups with very few terms; they tend to be unspecific.
            penalty += 0.35 * max(0, 5 - len(name_partials) - len(addr_tokens))
            yield penalty, exp_count, lookup


    def get_name_ranking(self, trange: TokenRange) -> dbf.FieldRanking:
        """ Create a ranking expression for a name term in the given range.
        """
        name_fulls = self.query.get_tokens(trange, TokenType.WORD)
        ranks = [dbf.RankedTokens(t.penalty, [t.token]) for t in name_fulls]
        ranks.sort(key=lambda r: r.penalty)
        # Fallback, sum of penalty for partials
        name_partials = self.query.get_partials_list(trange)
        default = sum(t.penalty for t in name_partials) + 0.2
        return dbf.FieldRanking('name_vector', default, ranks)


    def get_addr_ranking(self, trange: TokenRange) -> dbf.FieldRanking:
        """ Create a list of ranking expressions for an address term
            for the given ranges.
        """
        # Heap entries: (negated number of steps taken, position in the
        # query, ranking built so far). The negated length makes paths
        # that already consumed more of the range pop first.
        todo: List[Tuple[int, int, dbf.RankedTokens]] = []
        heapq.heappush(todo, (0, trange.start, dbf.RankedTokens(0.0, [])))
        ranks: List[dbf.RankedTokens] = []

        while todo: # pylint: disable=too-many-nested-blocks
            neglen, pos, rank = heapq.heappop(todo)
            for tlist in self.query.nodes[pos].starting:
                if tlist.ttype in (TokenType.PARTIAL, TokenType.WORD):
                    if tlist.end < trange.end:
                        # Token ends inside the range: extend the path,
                        # charging a penalty for the word change at the break.
                        chgpenalty = PENALTY_WORDCHANGE[self.query.nodes[tlist.end].btype]
                        if tlist.ttype == TokenType.PARTIAL:
                            penalty = rank.penalty + chgpenalty \
                                      + max(t.penalty for t in tlist.tokens)
                            heapq.heappush(todo, (neglen - 1, tlist.end,
                                                  dbf.RankedTokens(penalty, rank.tokens)))
                        else:
                            for t in tlist.tokens:
                                heapq.heappush(todo, (neglen - 1, tlist.end,
                                                      rank.with_token(t, chgpenalty)))
                    elif tlist.end == trange.end:
                        # Token exactly completes the range: emit a ranking.
                        if tlist.ttype == TokenType.PARTIAL:
                            ranks.append(dbf.RankedTokens(rank.penalty
                                                          + max(t.penalty for t in tlist.tokens),
                                                          rank.tokens))
                        else:
                            ranks.extend(rank.with_token(t, 0.0) for t in tlist.tokens)
                        if len(ranks) >= 10:
                            # Too many variants, bail out and only add
                            # Worst-case Fallback: sum of penalty of partials
                            name_partials = self.query.get_partials_list(trange)
                            default = sum(t.penalty for t in name_partials) + 0.2
                            ranks.append(dbf.RankedTokens(rank.penalty + default, []))
                            # Bail out of outer loop
                            todo.clear()
                            break

        # The ranking with the fewest tokens becomes the default penalty
        # (plus a surcharge); the rest are handed out as explicit rankings.
        ranks.sort(key=lambda r: len(r.tokens))
        default = ranks[0].penalty + 0.3
        del ranks[0]
        ranks.sort(key=lambda r: r.penalty)

        return dbf.FieldRanking('nameaddress_vector', default, ranks)


    def get_search_data(self, assignment: TokenAssignment) -> Optional[dbf.SearchData]:
        """ Collect the tokens for the non-name search fields in the
            assignment.

            Returns None if the assignment is incompatible with the
            country or category filters from the search details.
        """
        sdata = dbf.SearchData()
        sdata.penalty = assignment.penalty
        if assignment.country:
            tokens = self.query.get_tokens(assignment.country, TokenType.COUNTRY)
            if self.details.countries:
                # Query countries must intersect the country filter.
                tokens = [t for t in tokens if t.lookup_word in self.details.countries]
                if not tokens:
                    return None
            sdata.set_strings('countries', tokens)
        elif self.details.countries:
            sdata.countries = dbf.WeightedStrings(self.details.countries,
                                                  [0.0] * len(self.details.countries))
        if assignment.housenumber:
            sdata.set_strings('housenumbers',
                              self.query.get_tokens(assignment.housenumber,
                                                    TokenType.HOUSENUMBER))
        if assignment.postcode:
            sdata.set_strings('postcodes',
                              self.query.get_tokens(assignment.postcode,
                                                    TokenType.POSTCODE))
        if assignment.qualifier:
            tokens = self.query.get_tokens(assignment.qualifier, TokenType.QUALIFIER)
            if self.details.categories:
                # Qualifier tokens must intersect the category filter.
                tokens = [t for t in tokens if t.get_category() in self.details.categories]
                if not tokens:
                    return None
            sdata.set_qualifiers(tokens)
        elif self.details.categories:
            sdata.qualifiers = dbf.WeightedCategories(self.details.categories,
                                                      [0.0] * len(self.details.categories))

        if assignment.address:
            sdata.set_ranking([self.get_addr_ranking(r) for r in assignment.address])
        else:
            sdata.rankings = []

        return sdata


    def get_near_items(self, assignment: TokenAssignment) -> Optional[dbf.WeightedCategories]:
        """ Collect tokens for near items search or use the categories
            requested per parameter.
            Returns None if no category search is requested.
        """
        if assignment.near_item:
            # Keep the cheapest penalty seen for each category.
            tokens: Dict[Tuple[str, str], float] = {}
            for t in self.query.get_tokens(assignment.near_item, TokenType.NEAR_ITEM):
                cat = t.get_category()
                # The category of a near search will be that of near_item.
                # Thus, if search is restricted to a category parameter,
                # the two sets must intersect.
                if (not self.details.categories or cat in self.details.categories)\
                   and t.penalty < tokens.get(cat, 1000.0):
                    tokens[cat] = t.penalty
            return dbf.WeightedCategories(list(tokens.keys()), list(tokens.values()))

        return None
369
370
# Penalty applied in get_addr_ranking() when a token chain continues
# across a node with the given break type. Values grow as the break
# gets tighter (word < part < token).
PENALTY_WORDCHANGE: Dict[BreakType, float] = {
    BreakType.START: 0.0,
    BreakType.END: 0.0,
    BreakType.PHRASE: 0.0,
    BreakType.WORD: 0.1,
    BreakType.PART: 0.2,
    BreakType.TOKEN: 0.4,
}