]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/api/search/db_search_builder.py
NearSearch needs to inherit penalty from inner search
[nominatim.git] / nominatim / api / search / db_search_builder.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Conversion from token assignment to an abstract DB search.
9 """
10 from typing import Optional, List, Tuple, Iterator, Dict
11 import heapq
12
13 from nominatim.api.types import SearchDetails, DataLayer
14 from nominatim.api.search.query import QueryStruct, Token, TokenType, TokenRange, BreakType
15 from nominatim.api.search.token_assignment import TokenAssignment
16 import nominatim.api.search.db_search_fields as dbf
17 import nominatim.api.search.db_searches as dbs
18
19
def wrap_near_search(categories: List[Tuple[str, str]],
                     search: dbs.AbstractSearch) -> dbs.NearSearch:
    """ Wrap the given search in a near search that looks for places
        of one of the given categories close to the inner results.
    """
    # Carry the inner search's penalty over to the wrapper so the
    # combined search ranks like the search it wraps.
    weights = [0.0] * len(categories)
    return dbs.NearSearch(penalty=search.penalty,
                          categories=dbf.WeightedCategories(categories, weights),
                          search=search)
29
30
def build_poi_search(category: List[Tuple[str, str]],
                     countries: Optional[List[str]]) -> dbs.PoiSearch:
    """ Build a search for places of the given category, optionally
        restricted to the given list of countries.
    """
    # Countries get equal (zero) weights; an empty WeightedStrings
    # means no country restriction at all.
    country_weights = (dbf.WeightedStrings(countries, [0.0] * len(countries))
                       if countries else dbf.WeightedStrings([], []))

    class _PoiData(dbf.SearchData):
        penalty = 0.0
        qualifiers = dbf.WeightedCategories(category, [0.0] * len(category))
        countries = country_weights

    return dbs.PoiSearch(_PoiData())
47
48
class SearchBuilder:
    """ Build the abstract search queries from token assignments.

        Takes the tokenized query together with the user-supplied
        search parameters and produces all possible abstract DB
        searches for a given token assignment.
    """

    def __init__(self, query: QueryStruct, details: SearchDetails) -> None:
        self.query = query      # tokenized query the assignments refer to
        self.details = details  # user-supplied search parameters

    @property
    def configured_for_country(self) -> bool:
        """ Return true if the search details are configured to
            allow countries in the result.
        """
        # Countries can only appear when rank 4 (the address rank of
        # countries) is inside the requested rank range and the
        # address layer is enabled.
        return self.details.min_rank <= 4 and self.details.max_rank >= 4 \
               and self.details.layer_enabled(DataLayer.ADDRESS)

    @property
    def configured_for_postcode(self) -> bool:
        """ Return true if the search details are configured to
            allow postcodes in the result.
        """
        # Postcode results need the requested rank range to overlap
        # the rank range used for postcodes and the address layer on.
        return self.details.min_rank <= 5 and self.details.max_rank >= 11\
               and self.details.layer_enabled(DataLayer.ADDRESS)

    @property
    def configured_for_housenumbers(self) -> bool:
        """ Return true if the search details are configured to
            allow addresses in the result.
        """
        # Housenumber-level objects have address rank 30.
        return self.details.max_rank >= 30 \
               and self.details.layer_enabled(DataLayer.ADDRESS)

    def build(self, assignment: TokenAssignment) -> Iterator[dbs.AbstractSearch]:
        """ Yield all possible abstract searches for the given token assignment.
        """
        sdata = self.get_search_data(assignment)
        if sdata is None:
            return

        near_items = self.get_near_items(assignment)
        if near_items is not None and not near_items:
            return # impossible combination of near items and category parameter

        if assignment.name is None:
            if near_items and not sdata.postcodes:
                # Pure category search: near items act as qualifiers.
                sdata.qualifiers = near_items
                near_items = None
                builder = self.build_poi_search(sdata)
            elif assignment.housenumber:
                hnr_tokens = self.query.get_tokens(assignment.housenumber,
                                                   TokenType.HOUSENUMBER)
                builder = self.build_housenumber_search(sdata, hnr_tokens, assignment.address)
            else:
                builder = self.build_special_search(sdata, assignment.address,
                                                    bool(near_items))
        else:
            builder = self.build_name_search(sdata, assignment.name, assignment.address,
                                             bool(near_items))

        if near_items:
            # Wrap each inner search in a NearSearch. The minimum
            # category penalty and the inner search's penalty are
            # moved onto the wrapper, so the NearSearch carries the
            # full penalty of the combined search while the inner
            # search restarts at zero.
            penalty = min(near_items.penalties)
            near_items.penalties = [p - penalty for p in near_items.penalties]
            for search in builder:
                search_penalty = search.penalty
                search.penalty = 0.0
                yield dbs.NearSearch(penalty + assignment.penalty + search_penalty,
                                     near_items, search)
        else:
            for search in builder:
                search.penalty += assignment.penalty
                yield search

    def build_poi_search(self, sdata: dbf.SearchData) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search query for a simple category search.
            This kind of search requires an additional geographic constraint.
        """
        # Without a name, a POI search is only feasible when the area
        # is restricted by a bounded viewbox or a near point.
        if not sdata.housenumbers \
           and ((self.details.viewbox and self.details.bounded_viewbox) or self.details.near):
            yield dbs.PoiSearch(sdata)

    def build_special_search(self, sdata: dbf.SearchData,
                             address: List[TokenRange],
                             is_category: bool) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search queries for searches that do not involve
            a named place.
        """
        if sdata.qualifiers:
            # No special searches over qualifiers supported.
            return

        if sdata.countries and not address and not sdata.postcodes \
           and self.configured_for_country:
            yield dbs.CountrySearch(sdata)

        if sdata.postcodes and (is_category or self.configured_for_postcode):
            # A postcode search without a country restriction is less
            # certain, hence the base penalty of 0.1.
            penalty = 0.0 if sdata.countries else 0.1
            if address:
                # Restrict the postcode search by the address partials.
                sdata.lookups = [dbf.FieldLookup('nameaddress_vector',
                                                 [t.token for r in address
                                                  for t in self.query.get_partials_list(r)],
                                                 'restrict')]
                penalty += 0.2
            yield dbs.PostcodeSearch(penalty, sdata)

    def build_housenumber_search(self, sdata: dbf.SearchData, hnrs: List[Token],
                                 address: List[TokenRange]) -> Iterator[dbs.AbstractSearch]:
        """ Build a simple address search for special entries where the
            housenumber is the main name token.
        """
        sdata.lookups = [dbf.FieldLookup('name_vector', [t.token for t in hnrs], 'lookup_any')]

        partials = [t for trange in address
                       for t in self.query.get_partials_list(trange)]

        if len(partials) != 1 or partials[0].count < 10000:
            # Address partials are rare enough to look up directly.
            sdata.lookups.append(dbf.FieldLookup('nameaddress_vector',
                                                 [t.token for t in partials], 'lookup_all'))
        else:
            # A single, very frequent partial: fall back to the full
            # word tokens of the address instead.
            sdata.lookups.append(
                dbf.FieldLookup('nameaddress_vector',
                                [t.token for t
                                 in self.query.get_tokens(address[0], TokenType.WORD)],
                                'lookup_any'))

        # Housenumbers are already handled via the lookup above.
        sdata.housenumbers = dbf.WeightedStrings([], [])
        yield dbs.PlaceSearch(0.05, sdata, sum(t.count for t in hnrs))

    def build_name_search(self, sdata: dbf.SearchData,
                          name: TokenRange, address: List[TokenRange],
                          is_category: bool) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search queries for simple name or address searches.
        """
        if is_category or not sdata.housenumbers or self.configured_for_housenumbers:
            ranking = self.get_name_ranking(name)
            name_penalty = ranking.normalize_penalty()
            if ranking.rankings:
                sdata.rankings.append(ranking)
            # One PlaceSearch per lookup strategy for the name/address.
            for penalty, count, lookup in self.yield_lookups(name, address):
                sdata.lookups = lookup
                yield dbs.PlaceSearch(penalty + name_penalty, sdata, count)

    def yield_lookups(self, name: TokenRange, address: List[TokenRange])\
                          -> Iterator[Tuple[float, int, List[dbf.FieldLookup]]]:
        """ Yield all variants how the given name and address should best
            be searched for. This takes into account how frequent the terms
            are and tries to find a lookup that optimizes index use.
        """
        penalty = 0.0 # extra penalty
        name_partials = self.query.get_partials_list(name)
        name_tokens = [t.token for t in name_partials]

        addr_partials = [t for r in address for t in self.query.get_partials_list(r)]
        addr_tokens = [t.token for t in addr_partials]

        partials_indexed = all(t.is_indexed for t in name_partials) \
                           and all(t.is_indexed for t in addr_partials)
        # Expected number of matches: rarest partial, halved for each
        # additional name partial.
        exp_count = min(t.count for t in name_partials) / (2**(len(name_partials) - 1))

        if (len(name_partials) > 3 or exp_count < 8000) and partials_indexed:
            yield penalty, exp_count, dbf.lookup_by_names(name_tokens, addr_tokens)
            return

        # Partial term too frequent. Try looking up by rare full names first.
        name_fulls = self.query.get_tokens(name, TokenType.WORD)
        fulls_count = sum(t.count for t in name_fulls)
        # At this point drop unindexed partials from the address.
        # This might yield wrong results, nothing we can do about that.
        if not partials_indexed:
            addr_tokens = [t.token for t in addr_partials if t.is_indexed]
            penalty += 1.2 * sum(t.penalty for t in addr_partials if not t.is_indexed)
        # Any of the full names applies with all of the partials from the address
        yield penalty, fulls_count / (2**len(addr_partials)),\
              dbf.lookup_by_any_name([t.token for t in name_fulls], addr_tokens,
                                     'restrict' if fulls_count < 10000 else 'lookup_all')

        # To catch remaining results, lookup by name and address
        # We only do this if there is a reasonable number of results expected.
        exp_count = exp_count / (2**len(addr_partials)) if addr_partials else exp_count
        if exp_count < 10000 and all(t.is_indexed for t in name_partials):
            lookup = [dbf.FieldLookup('name_vector', name_tokens, 'lookup_all')]
            if addr_tokens:
                lookup.append(dbf.FieldLookup('nameaddress_vector', addr_tokens, 'lookup_all'))
            # Penalise lookups with very few terms to prefer the more
            # specific strategies yielded above.
            penalty += 0.35 * max(0, 5 - len(name_partials) - len(addr_tokens))
            yield penalty, exp_count, lookup

    def get_name_ranking(self, trange: TokenRange) -> dbf.FieldRanking:
        """ Create a ranking expression for a name term in the given range.
        """
        name_fulls = self.query.get_tokens(trange, TokenType.WORD)
        ranks = [dbf.RankedTokens(t.penalty, [t.token]) for t in name_fulls]
        ranks.sort(key=lambda r: r.penalty)
        # Fallback, sum of penalty for partials
        name_partials = self.query.get_partials_list(trange)
        default = sum(t.penalty for t in name_partials) + 0.2
        return dbf.FieldRanking('name_vector', default, ranks)

    def get_addr_ranking(self, trange: TokenRange) -> dbf.FieldRanking:
        """ Create a list of ranking expressions for an address term
            for the given ranges.
        """
        # Best-first expansion of token chains over the range. The heap
        # key is the negative chain length, so longer (more specific)
        # chains are expanded first.
        todo: List[Tuple[int, int, dbf.RankedTokens]] = []
        heapq.heappush(todo, (0, trange.start, dbf.RankedTokens(0.0, [])))
        ranks: List[dbf.RankedTokens] = []

        while todo: # pylint: disable=too-many-nested-blocks
            neglen, pos, rank = heapq.heappop(todo)
            for tlist in self.query.nodes[pos].starting:
                if tlist.ttype in (TokenType.PARTIAL, TokenType.WORD):
                    if tlist.end < trange.end:
                        # Chain continues: add the penalty for crossing
                        # the break at the end of this token list.
                        chgpenalty = PENALTY_WORDCHANGE[self.query.nodes[tlist.end].btype]
                        if tlist.ttype == TokenType.PARTIAL:
                            penalty = rank.penalty + chgpenalty \
                                      + max(t.penalty for t in tlist.tokens)
                            heapq.heappush(todo, (neglen - 1, tlist.end,
                                                  dbf.RankedTokens(penalty, rank.tokens)))
                        else:
                            for t in tlist.tokens:
                                heapq.heappush(todo, (neglen - 1, tlist.end,
                                                      rank.with_token(t, chgpenalty)))
                    elif tlist.end == trange.end:
                        # Chain complete: record the finished ranking.
                        if tlist.ttype == TokenType.PARTIAL:
                            ranks.append(dbf.RankedTokens(rank.penalty
                                                          + max(t.penalty for t in tlist.tokens),
                                                          rank.tokens))
                        else:
                            ranks.extend(rank.with_token(t, 0.0) for t in tlist.tokens)
                        if len(ranks) >= 10:
                            # Too many variants, bail out and only add
                            # worst-case fallback: sum of penalties of partials
                            name_partials = self.query.get_partials_list(trange)
                            default = sum(t.penalty for t in name_partials) + 0.2
                            ranks.append(dbf.RankedTokens(rank.penalty + default, []))
                            # Bail out of outer loop
                            todo.clear()
                            break

        # The shortest chain becomes the default (fallback) penalty.
        # NOTE(review): assumes at least one complete chain was found;
        # an empty `ranks` would raise IndexError here.
        ranks.sort(key=lambda r: len(r.tokens))
        default = ranks[0].penalty + 0.3
        del ranks[0]
        ranks.sort(key=lambda r: r.penalty)

        return dbf.FieldRanking('nameaddress_vector', default, ranks)

    def get_search_data(self, assignment: TokenAssignment) -> Optional[dbf.SearchData]:
        """ Collect the tokens for the non-name search fields in the
            assignment.

            Returns None when the assignment is incompatible with the
            user-supplied country or category restrictions.
        """
        sdata = dbf.SearchData()
        sdata.penalty = assignment.penalty
        if assignment.country:
            tokens = self.query.get_tokens(assignment.country, TokenType.COUNTRY)
            if self.details.countries:
                # Restrict to the countries requested per parameter.
                tokens = [t for t in tokens if t.lookup_word in self.details.countries]
                if not tokens:
                    return None
            sdata.set_strings('countries', tokens)
        elif self.details.countries:
            # No country in the query; apply the parameter restriction.
            sdata.countries = dbf.WeightedStrings(self.details.countries,
                                                  [0.0] * len(self.details.countries))
        if assignment.housenumber:
            sdata.set_strings('housenumbers',
                              self.query.get_tokens(assignment.housenumber,
                                                    TokenType.HOUSENUMBER))
        if assignment.postcode:
            sdata.set_strings('postcodes',
                              self.query.get_tokens(assignment.postcode,
                                                    TokenType.POSTCODE))
        if assignment.qualifier:
            tokens = self.query.get_tokens(assignment.qualifier, TokenType.QUALIFIER)
            if self.details.categories:
                # Qualifier must intersect the requested categories.
                tokens = [t for t in tokens if t.get_category() in self.details.categories]
                if not tokens:
                    return None
            sdata.set_qualifiers(tokens)
        elif self.details.categories:
            # No qualifier in the query; apply the parameter restriction.
            sdata.qualifiers = dbf.WeightedCategories(self.details.categories,
                                                      [0.0] * len(self.details.categories))

        if assignment.address:
            sdata.set_ranking([self.get_addr_ranking(r) for r in assignment.address])
        else:
            sdata.rankings = []

        return sdata

    def get_near_items(self, assignment: TokenAssignment) -> Optional[dbf.WeightedCategories]:
        """ Collect tokens for near items search or use the categories
            requested per parameter.
            Returns None if no category search is requested.
        """
        if assignment.near_item:
            tokens: Dict[Tuple[str, str], float] = {}
            for t in self.query.get_tokens(assignment.near_item, TokenType.NEAR_ITEM):
                cat = t.get_category()
                # The category of a near search will be that of near_item.
                # Thus, if search is restricted to a category parameter,
                # the two sets must intersect.
                # Keep only the lowest penalty per category.
                if (not self.details.categories or cat in self.details.categories)\
                   and t.penalty < tokens.get(cat, 1000.0):
                    tokens[cat] = t.penalty
            return dbf.WeightedCategories(list(tokens.keys()), list(tokens.values()))

        return None
365
366
# Penalty applied in get_addr_ranking() when a token chain continues
# across a break of the given type: continuing over a word, part or
# token break is increasingly less likely to match the user's intent.
PENALTY_WORDCHANGE = {
    BreakType.START: 0.0,
    BreakType.END: 0.0,
    BreakType.PHRASE: 0.0,
    BreakType.WORD: 0.1,
    BreakType.PART: 0.2,
    BreakType.TOKEN: 0.4
}