]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/api/search/db_search_builder.py
Merge pull request #3262 from lonvia/fix-category-search
[nominatim.git] / nominatim / api / search / db_search_builder.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
Conversion from token assignment to an abstract DB search.
9 """
10 from typing import Optional, List, Tuple, Iterator, Dict
11 import heapq
12
13 from nominatim.api.types import SearchDetails, DataLayer
14 from nominatim.api.search.query import QueryStruct, Token, TokenType, TokenRange, BreakType
15 from nominatim.api.search.token_assignment import TokenAssignment
16 import nominatim.api.search.db_search_fields as dbf
17 import nominatim.api.search.db_searches as dbs
18
19
20 def wrap_near_search(categories: List[Tuple[str, str]],
21                      search: dbs.AbstractSearch) -> dbs.NearSearch:
22     """ Create a new search that wraps the given search in a search
23         for near places of the given category.
24     """
25     return dbs.NearSearch(penalty=search.penalty,
26                           categories=dbf.WeightedCategories(categories,
27                                                             [0.0] * len(categories)),
28                           search=search)
29
30
def build_poi_search(category: List[Tuple[str, str]],
                     countries: Optional[List[str]]) -> dbs.PoiSearch:
    """ Create a new search for places of the given categories, optionally
        constrained to the given countries.

        All categories and countries are weighted with a zero penalty.
        Without countries (None or empty) no country restriction is set.
    """
    country_values = countries if countries else []
    weighted_countries = dbf.WeightedStrings(country_values, [0.0] * len(country_values))
    weighted_categories = dbf.WeightedCategories(category, [0.0] * len(category))

    class _PoiData(dbf.SearchData):
        penalty = 0.0
        qualifiers = weighted_categories
        countries = weighted_countries

    return dbs.PoiSearch(_PoiData())
47
48
class SearchBuilder:
    """ Build the abstract search queries from token assignments.

        The builder keeps the parsed query and the search details
        (the request parameters) and turns each TokenAssignment handed
        to build() into zero or more abstract searches.
    """

    def __init__(self, query: QueryStruct, details: SearchDetails) -> None:
        # Query whose token ranges the assignments refer to.
        self.query = query
        # Request parameters (rank range, layers, countries, categories,
        # viewbox, near point) that restrict the generated searches.
        self.details = details


    @property
    def configured_for_country(self) -> bool:
        """ Return true if the search details are configured to
            allow countries in the result.

            This requires rank 4 to lie within the requested rank range
            and the address layer to be enabled.
        """
        return self.details.min_rank <= 4 and self.details.max_rank >= 4 \
               and self.details.layer_enabled(DataLayer.ADDRESS)


    @property
    def configured_for_postcode(self) -> bool:
        """ Return true if the search details are configured to
            allow postcodes in the result.

            This requires min_rank to be at most 5, max_rank to be at
            least 11 and the address layer to be enabled.
        """
        return self.details.min_rank <= 5 and self.details.max_rank >= 11\
               and self.details.layer_enabled(DataLayer.ADDRESS)


    @property
    def configured_for_housenumbers(self) -> bool:
        """ Return true if the search details are configured to
            allow addresses in the result.

            This requires max_rank to be at least 30 and the address
            layer to be enabled.
        """
        return self.details.max_rank >= 30 \
               and self.details.layer_enabled(DataLayer.ADDRESS)


    def build(self, assignment: TokenAssignment) -> Iterator[dbs.AbstractSearch]:
        """ Yield all possible abstract searches for the given token assignment.

            Nothing is yielded when the assignment conflicts with the
            country or category restrictions of the search details.
            Assignments without a name become POI, housenumber or special
            (country/postcode) searches; all others become name searches.
            When there are near items, the resulting searches are wrapped
            into NearSearch objects. The exception is the plain POI case,
            where the near items are used directly as qualifiers.
        """
        sdata = self.get_search_data(assignment)
        if sdata is None:
            return

        near_items = self.get_near_items(assignment)
        # A near item was given but none of its categories survived the
        # category filter of the search details (presumably the returned
        # WeightedCategories is falsy when it holds no categories).
        if near_items is not None and not near_items:
            return # impossible combination of near items and category parameter

        if assignment.name is None:
            if near_items and not sdata.postcodes:
                # Pure category query: search the near items as qualifiers
                # of a POI search instead of wrapping a NearSearch around it.
                sdata.qualifiers = near_items
                near_items = None
                builder = self.build_poi_search(sdata)
            elif assignment.housenumber:
                hnr_tokens = self.query.get_tokens(assignment.housenumber,
                                                   TokenType.HOUSENUMBER)
                builder = self.build_housenumber_search(sdata, hnr_tokens, assignment.address)
            else:
                builder = self.build_special_search(sdata, assignment.address,
                                                    bool(near_items))
        else:
            builder = self.build_name_search(sdata, assignment.name, assignment.address,
                                             bool(near_items))

        # NOTE(review): get_search_data() already sets sdata.penalty to
        # assignment.penalty. If the search constructors in db_searches fold
        # sdata.penalty into their own penalty, the assignment penalty is
        # counted twice below -- confirm.
        if near_items:
            # Normalise the near-item penalties so that the best one is 0 and
            # move that minimum onto the outer NearSearch penalty instead.
            penalty = min(near_items.penalties)
            near_items.penalties = [p - penalty for p in near_items.penalties]
            for search in builder:
                yield dbs.NearSearch(penalty + assignment.penalty, near_items, search)
        else:
            for search in builder:
                search.penalty += assignment.penalty
                yield search


    def build_poi_search(self, sdata: dbf.SearchData) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search query for a simple category search.
            This kind of search requires an additional geographic constraint.

            A PoiSearch is only yielded when there are no housenumbers and
            the details contain either a bounded viewbox or a near point.
        """
        if not sdata.housenumbers \
           and ((self.details.viewbox and self.details.bounded_viewbox) or self.details.near):
            yield dbs.PoiSearch(sdata)


    def build_special_search(self, sdata: dbf.SearchData,
                             address: List[TokenRange],
                             is_category: bool) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search queries for searches that do not involve
            a named place.

            A CountrySearch is yielded for a pure country query, and a
            PostcodeSearch for queries with postcode tokens. Nothing is
            yielded when qualifiers are set.
        """
        if sdata.qualifiers:
            # No special searches over qualifiers supported.
            return

        if sdata.countries and not address and not sdata.postcodes \
           and self.configured_for_country:
            yield dbs.CountrySearch(sdata)

        if sdata.postcodes and (is_category or self.configured_for_postcode):
            # Postcodes without a country restriction are penalised slightly.
            penalty = 0.0 if sdata.countries else 0.1
            if address:
                # Use the address partials only to restrict the postcode
                # results; this variant gets an extra penalty.
                sdata.lookups = [dbf.FieldLookup('nameaddress_vector',
                                                 [t.token for r in address
                                                  for t in self.query.get_partials_list(r)],
                                                 'restrict')]
                penalty += 0.2
            yield dbs.PostcodeSearch(penalty, sdata)


    def build_housenumber_search(self, sdata: dbf.SearchData, hnrs: List[Token],
                                 address: List[TokenRange]) -> Iterator[dbs.AbstractSearch]:
        """ Build a simple address search for special entries where the
            housenumber is the main name token.

            The housenumber tokens are looked up in the name vector and the
            address terms in the address vector. The expected result count
            is the sum of the housenumber token counts.
        """
        sdata.lookups = [dbf.FieldLookup('name_vector', [t.token for t in hnrs], 'lookup_any')]

        partials = [t for trange in address
                       for t in self.query.get_partials_list(trange)]

        if len(partials) != 1 or partials[0].count < 10000:
            sdata.lookups.append(dbf.FieldLookup('nameaddress_vector',
                                                 [t.token for t in partials], 'lookup_all'))
        else:
            # A single, very frequent address partial: use the full-word
            # tokens of the (only) address range instead.
            sdata.lookups.append(
                dbf.FieldLookup('nameaddress_vector',
                                [t.token for t
                                 in self.query.get_tokens(address[0], TokenType.WORD)],
                                'lookup_any'))

        # The housenumber is already matched as the name, so drop the
        # separate housenumber restriction.
        sdata.housenumbers = dbf.WeightedStrings([], [])
        yield dbs.PlaceSearch(0.05, sdata, sum(t.count for t in hnrs))


    def build_name_search(self, sdata: dbf.SearchData,
                          name: TokenRange, address: List[TokenRange],
                          is_category: bool) -> Iterator[dbs.AbstractSearch]:
        """ Build abstract search queries for simple name or address searches.

            Nothing is yielded for housenumber queries when the details do
            not allow addresses (unless this is a category search). Otherwise
            one PlaceSearch is yielded per lookup variant from yield_lookups().
        """
        if is_category or not sdata.housenumbers or self.configured_for_housenumbers:
            ranking = self.get_name_ranking(name)
            name_penalty = ranking.normalize_penalty()
            if ranking.rankings:
                sdata.rankings.append(ranking)
            # sdata is shared by all variants; only its lookups change between
            # iterations. NOTE(review): this relies on PlaceSearch taking over
            # sdata.lookups at construction time -- confirm in db_searches.
            for penalty, count, lookup in self.yield_lookups(name, address):
                sdata.lookups = lookup
                yield dbs.PlaceSearch(penalty + name_penalty, sdata, count)


    def yield_lookups(self, name: TokenRange, address: List[TokenRange])\
                          -> Iterator[Tuple[float, int, List[dbf.FieldLookup]]]:
        """ Yield all variants how the given name and address should best
            be searched for. This takes into account how frequent the terms
            are and tries to find a lookup that optimizes index use.

            Each item is (extra penalty, expected result count, lookups).
            NOTE(review): the expected count is computed with true division
            and is therefore a float, despite the `int` in the annotation.
        """
        penalty = 0.0 # extra penalty
        name_partials = self.query.get_partials_list(name)
        name_tokens = [t.token for t in name_partials]

        addr_partials = [t for r in address for t in self.query.get_partials_list(r)]
        addr_tokens = [t.token for t in addr_partials]

        partials_indexed = all(t.is_indexed for t in name_partials) \
                           and all(t.is_indexed for t in addr_partials)
        # Estimate: the rarest partial's count, halved for each further partial.
        exp_count = min(t.count for t in name_partials) / (2**(len(name_partials) - 1))

        if (len(name_partials) > 3 or exp_count < 8000) and partials_indexed:
            yield penalty, exp_count, dbf.lookup_by_names(name_tokens, addr_tokens)
            return

        # Partial term too frequent. Try looking up by rare full names first.
        name_fulls = self.query.get_tokens(name, TokenType.WORD)
        fulls_count = sum(t.count for t in name_fulls)
        # At this point drop unindexed partials from the address.
        # This might yield wrong results, nothing we can do about that.
        # The penalty added here also carries over to the last variant below.
        if not partials_indexed:
            addr_tokens = [t.token for t in addr_partials if t.is_indexed]
            penalty += 1.2 * sum(t.penalty for t in addr_partials if not t.is_indexed)
        # Any of the full names applies with all of the partials from the address
        yield penalty, fulls_count / (2**len(addr_partials)),\
              dbf.lookup_by_any_name([t.token for t in name_fulls], addr_tokens,
                                     'restrict' if fulls_count < 10000 else 'lookup_all')

        # To catch remaining results, lookup by name and address
        # We only do this if there is a reasonable number of results expected.
        exp_count = exp_count / (2**len(addr_partials)) if addr_partials else exp_count
        if exp_count < 10000 and all(t.is_indexed for t in name_partials):
            lookup = [dbf.FieldLookup('name_vector', name_tokens, 'lookup_all')]
            if addr_tokens:
                lookup.append(dbf.FieldLookup('nameaddress_vector', addr_tokens, 'lookup_all'))
            # Queries with fewer than 5 terms get 0.35 per missing term.
            penalty += 0.35 * max(0, 5 - len(name_partials) - len(addr_tokens))
            yield penalty, exp_count, lookup


    def get_name_ranking(self, trange: TokenRange) -> dbf.FieldRanking:
        """ Create a ranking expression for a name term in the given range.

            Each full-word token becomes a ranking entry, sorted by
            penalty. The default penalty (used when no full word matches)
            is the sum of the partial penalties plus 0.2.
        """
        name_fulls = self.query.get_tokens(trange, TokenType.WORD)
        ranks = [dbf.RankedTokens(t.penalty, [t.token]) for t in name_fulls]
        ranks.sort(key=lambda r: r.penalty)
        # Fallback, sum of penalty for partials
        name_partials = self.query.get_partials_list(trange)
        default = sum(t.penalty for t in name_partials) + 0.2
        return dbf.FieldRanking('name_vector', default, ranks)


    def get_addr_ranking(self, trange: TokenRange) -> dbf.FieldRanking:
        """ Create a list of ranking expressions for an address term
            for the given ranges.

            All ways of covering the range with PARTIAL and WORD token lists
            are enumerated. Partials only add their penalty. Each full-word
            token contributes a token to the ranking variant (presumably
            appended by RankedTokens.with_token together with its penalty).
            Moving on past an inner node adds PENALTY_WORDCHANGE for that
            node's break type. Enumeration stops after 10 complete variants.
        """
        # Heap entries: (negated number of steps, node position, variant so far).
        # The most negative key is popped first, so longer paths are extended
        # before shorter ones.
        # NOTE(review): entries with equal (neglen, pos) fall back to comparing
        # the RankedTokens objects, which requires them to support '<' --
        # confirm in db_search_fields.
        todo: List[Tuple[int, int, dbf.RankedTokens]] = []
        heapq.heappush(todo, (0, trange.start, dbf.RankedTokens(0.0, [])))
        ranks: List[dbf.RankedTokens] = []

        while todo: # pylint: disable=too-many-nested-blocks
            neglen, pos, rank = heapq.heappop(todo)
            for tlist in self.query.nodes[pos].starting:
                if tlist.ttype in (TokenType.PARTIAL, TokenType.WORD):
                    if tlist.end < trange.end:
                        # Path continues: charge the transition at the inner node.
                        chgpenalty = PENALTY_WORDCHANGE[self.query.nodes[tlist.end].btype]
                        if tlist.ttype == TokenType.PARTIAL:
                            penalty = rank.penalty + chgpenalty \
                                      + max(t.penalty for t in tlist.tokens)
                            heapq.heappush(todo, (neglen - 1, tlist.end,
                                                  dbf.RankedTokens(penalty, rank.tokens)))
                        else:
                            # Every full-word token forks a separate variant.
                            for t in tlist.tokens:
                                heapq.heappush(todo, (neglen - 1, tlist.end,
                                                      rank.with_token(t, chgpenalty)))
                    elif tlist.end == trange.end:
                        # Path reaches the end of the range: record the variant.
                        if tlist.ttype == TokenType.PARTIAL:
                            ranks.append(dbf.RankedTokens(rank.penalty
                                                          + max(t.penalty for t in tlist.tokens),
                                                          rank.tokens))
                        else:
                            ranks.extend(rank.with_token(t, 0.0) for t in tlist.tokens)
                        if len(ranks) >= 10:
                            # Too many variants, bail out and only add
                            # Worst-case Fallback: sum of penalty of partials
                            name_partials = self.query.get_partials_list(trange)
                            default = sum(t.penalty for t in name_partials) + 0.2
                            ranks.append(dbf.RankedTokens(rank.penalty + default, []))
                            # Bail out of outer loop
                            todo.clear()
                            break

        # The variant with the fewest tokens (an all-partial path has none)
        # serves as the default penalty and is removed from the list.
        # Assumes at least one complete path was found; ranks[0] raises
        # IndexError otherwise.
        ranks.sort(key=lambda r: len(r.tokens))
        default = ranks[0].penalty + 0.3
        del ranks[0]
        ranks.sort(key=lambda r: r.penalty)

        return dbf.FieldRanking('nameaddress_vector', default, ranks)


    def get_search_data(self, assignment: TokenAssignment) -> Optional[dbf.SearchData]:
        """ Collect the tokens for the non-name search fields in the
            assignment.

            Returns None when the country or qualifier tokens of the
            assignment do not match the country or category restrictions
            of the search details. Restrictions from the details are used
            directly when the assignment has no such tokens.
        """
        sdata = dbf.SearchData()
        sdata.penalty = assignment.penalty
        if assignment.country:
            tokens = self.query.get_tokens(assignment.country, TokenType.COUNTRY)
            if self.details.countries:
                tokens = [t for t in tokens if t.lookup_word in self.details.countries]
                if not tokens:
                    return None
            sdata.set_strings('countries', tokens)
        elif self.details.countries:
            sdata.countries = dbf.WeightedStrings(self.details.countries,
                                                  [0.0] * len(self.details.countries))
        if assignment.housenumber:
            sdata.set_strings('housenumbers',
                              self.query.get_tokens(assignment.housenumber,
                                                    TokenType.HOUSENUMBER))
        if assignment.postcode:
            sdata.set_strings('postcodes',
                              self.query.get_tokens(assignment.postcode,
                                                    TokenType.POSTCODE))
        if assignment.qualifier:
            tokens = self.query.get_tokens(assignment.qualifier, TokenType.QUALIFIER)
            if self.details.categories:
                tokens = [t for t in tokens if t.get_category() in self.details.categories]
                if not tokens:
                    return None
            sdata.set_qualifiers(tokens)
        elif self.details.categories:
            sdata.qualifiers = dbf.WeightedCategories(self.details.categories,
                                                      [0.0] * len(self.details.categories))

        if assignment.address:
            sdata.set_ranking([self.get_addr_ranking(r) for r in assignment.address])
        else:
            sdata.rankings = []

        return sdata


    def get_near_items(self, assignment: TokenAssignment) -> Optional[dbf.WeightedCategories]:
        """ Collect tokens for near items search or use the categories
            requested per parameter.
            Returns None if no category search is requested.

            For each category only the lowest token penalty is kept. The
            result may be empty when the category parameter excludes all
            near-item categories.
        """
        if assignment.near_item:
            tokens: Dict[Tuple[str, str], float] = {}
            for t in self.query.get_tokens(assignment.near_item, TokenType.NEAR_ITEM):
                cat = t.get_category()
                # The category of a near search will be that of near_item.
                # Thus, if search is restricted to a category parameter,
                # the two sets must intersect.
                if (not self.details.categories or cat in self.details.categories)\
                   and t.penalty < tokens.get(cat, 1000.0):
                    tokens[cat] = t.penalty
            return dbf.WeightedCategories(list(tokens.keys()), list(tokens.values()))

        return None
359             return dbf.WeightedCategories(list(tokens.keys()), list(tokens.values()))
360
361         return None
362
363
364 PENALTY_WORDCHANGE = {
365     BreakType.START: 0.0,
366     BreakType.END: 0.0,
367     BreakType.PHRASE: 0.0,
368     BreakType.WORD: 0.1,
369     BreakType.PART: 0.2,
370     BreakType.TOKEN: 0.4
371 }