]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/api/search/db_search_fields.py
324a7acc2cafe5a553dc60fdb6f5ca1b948568ae
[nominatim.git] / nominatim / api / search / db_search_fields.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Data structures for more complex fields in abstract search descriptions.
9 """
10 from typing import List, Tuple, Iterator, cast, Dict
11 import dataclasses
12
13 import sqlalchemy as sa
14
15 from nominatim.typing import SaFromClause, SaColumn, SaExpression
16 from nominatim.api.search.query import Token
17 from nominatim.utils.json_writer import JsonWriter
18
19 @dataclasses.dataclass
20 class WeightedStrings:
21     """ A list of strings together with a penalty.
22     """
23     values: List[str]
24     penalties: List[float]
25
26     def __bool__(self) -> bool:
27         return bool(self.values)
28
29
30     def __iter__(self) -> Iterator[Tuple[str, float]]:
31         return iter(zip(self.values, self.penalties))
32
33
34     def get_penalty(self, value: str, default: float = 1000.0) -> float:
35         """ Get the penalty for the given value. Returns the given default
36             if the value does not exist.
37         """
38         try:
39             return self.penalties[self.values.index(value)]
40         except ValueError:
41             pass
42         return default
43
44
45 @dataclasses.dataclass
46 class WeightedCategories:
47     """ A list of class/type tuples together with a penalty.
48     """
49     values: List[Tuple[str, str]]
50     penalties: List[float]
51
52     def __bool__(self) -> bool:
53         return bool(self.values)
54
55
56     def __iter__(self) -> Iterator[Tuple[Tuple[str, str], float]]:
57         return iter(zip(self.values, self.penalties))
58
59
60     def get_penalty(self, value: Tuple[str, str], default: float = 1000.0) -> float:
61         """ Get the penalty for the given value. Returns the given default
62             if the value does not exist.
63         """
64         try:
65             return self.penalties[self.values.index(value)]
66         except ValueError:
67             pass
68         return default
69
70
71     def sql_restrict(self, table: SaFromClause) -> SaExpression:
72         """ Return an SQLAlcheny expression that restricts the
73             class and type columns of the given table to the values
74             in the list.
75             Must not be used with an empty list.
76         """
77         assert self.values
78         if len(self.values) == 1:
79             return sa.and_(table.c.class_ == self.values[0][0],
80                            table.c.type == self.values[0][1])
81
82         return sa.or_(*(sa.and_(table.c.class_ == c, table.c.type == t)
83                         for c, t in self.values))
84
85
86 @dataclasses.dataclass(order=True)
87 class RankedTokens:
88     """ List of tokens together with the penalty of using it.
89     """
90     penalty: float
91     tokens: List[int]
92
93     def with_token(self, t: Token, transition_penalty: float) -> 'RankedTokens':
94         """ Create a new RankedTokens list with the given token appended.
95             The tokens penalty as well as the given transision penalty
96             are added to the overall penalty.
97         """
98         return RankedTokens(self.penalty + t.penalty + transition_penalty,
99                             self.tokens + [t.token])
100
101
102 @dataclasses.dataclass
103 class FieldRanking:
104     """ A list of rankings to be applied sequentially until one matches.
105         The matched ranking determines the penalty. If none matches a
106         default penalty is applied.
107     """
108     column: str
109     default: float
110     rankings: List[RankedTokens]
111
112     def normalize_penalty(self) -> float:
113         """ Reduce the default and ranking penalties, such that the minimum
114             penalty is 0. Return the penalty that was subtracted.
115         """
116         if self.rankings:
117             min_penalty = min(self.default, min(r.penalty for r in self.rankings))
118         else:
119             min_penalty = self.default
120         if min_penalty > 0.0:
121             self.default -= min_penalty
122             for ranking in self.rankings:
123                 ranking.penalty -= min_penalty
124         return min_penalty
125
126
127     def sql_penalty(self, table: SaFromClause) -> SaColumn:
128         """ Create an SQL expression for the rankings.
129         """
130         assert self.rankings
131
132         rout = JsonWriter().start_array()
133         for rank in self.rankings:
134             rout.start_array().value(rank.penalty).next()
135             rout.start_array()
136             for token in rank.tokens:
137                 rout.value(token).next()
138             rout.end_array()
139             rout.end_array().next()
140         rout.end_array()
141
142         return sa.func.weigh_search(table.c[self.column], rout(), self.default)
143
144
145 @dataclasses.dataclass
146 class FieldLookup:
147     """ A list of tokens to be searched for. The column names the database
148         column to search in and the lookup_type the operator that is applied.
149         'lookup_all' requires all tokens to match. 'lookup_any' requires
150         one of the tokens to match. 'restrict' requires to match all tokens
151         but avoids the use of indexes.
152     """
153     column: str
154     tokens: List[int]
155     lookup_type: str
156
157     def sql_condition(self, table: SaFromClause) -> SaColumn:
158         """ Create an SQL expression for the given match condition.
159         """
160         col = table.c[self.column]
161         if self.lookup_type == 'lookup_all':
162             return col.contains(self.tokens)
163         if self.lookup_type == 'lookup_any':
164             return cast(SaColumn, col.overlaps(self.tokens))
165
166         return sa.func.coalesce(sa.null(), col).contains(self.tokens) # pylint: disable=not-callable
167
168
169 class SearchData:
170     """ Search fields derived from query and token assignment
171         to be used with the SQL queries.
172     """
173     penalty: float
174
175     lookups: List[FieldLookup] = []
176     rankings: List[FieldRanking]
177
178     housenumbers: WeightedStrings = WeightedStrings([], [])
179     postcodes: WeightedStrings = WeightedStrings([], [])
180     countries: WeightedStrings = WeightedStrings([], [])
181
182     qualifiers: WeightedCategories = WeightedCategories([], [])
183
184
185     def set_strings(self, field: str, tokens: List[Token]) -> None:
186         """ Set on of the WeightedStrings properties from the given
187             token list. Adapt the global penalty, so that the
188             minimum penalty is 0.
189         """
190         if tokens:
191             min_penalty = min(t.penalty for t in tokens)
192             self.penalty += min_penalty
193             wstrs = WeightedStrings([t.lookup_word for t in tokens],
194                                     [t.penalty - min_penalty for t in tokens])
195
196             setattr(self, field, wstrs)
197
198
199     def set_qualifiers(self, tokens: List[Token]) -> None:
200         """ Set the qulaifier field from the given tokens.
201         """
202         if tokens:
203             categories: Dict[Tuple[str, str], float] = {}
204             min_penalty = 1000.0
205             for t in tokens:
206                 if t.penalty < min_penalty:
207                     min_penalty = t.penalty
208                 cat = t.get_category()
209                 if t.penalty < categories.get(cat, 1000.0):
210                     categories[cat] = t.penalty
211             self.penalty += min_penalty
212             self.qualifiers = WeightedCategories(list(categories.keys()),
213                                                  list(categories.values()))
214
215
216     def set_ranking(self, rankings: List[FieldRanking]) -> None:
217         """ Set the list of rankings and normalize the ranking.
218         """
219         self.rankings = []
220         for ranking in rankings:
221             if ranking.rankings:
222                 self.penalty += ranking.normalize_penalty()
223                 self.rankings.append(ranking)
224             else:
225                 self.penalty += ranking.default
226
227
228 def lookup_by_names(name_tokens: List[int], addr_tokens: List[int]) -> List[FieldLookup]:
229     """ Create a lookup list where name tokens are looked up via index
230         and potential address tokens are used to restrict the search further.
231     """
232     lookup = [FieldLookup('name_vector', name_tokens, 'lookup_all')]
233     if addr_tokens:
234         lookup.append(FieldLookup('nameaddress_vector', addr_tokens, 'restrict'))
235
236     return lookup
237
238
239 def lookup_by_any_name(name_tokens: List[int], addr_tokens: List[int],
240                        lookup_type: str) -> List[FieldLookup]:
241     """ Create a lookup list where name tokens are looked up via index
242         and only one of the name tokens must be present.
243         Potential address tokens are used to restrict the search further.
244     """
245     lookup = [FieldLookup('name_vector', name_tokens, 'lookup_any')]
246     if addr_tokens:
247         lookup.append(FieldLookup('nameaddress_vector', addr_tokens, lookup_type))
248
249     return lookup
250
251
252 def lookup_by_addr(name_tokens: List[int], addr_tokens: List[int]) -> List[FieldLookup]:
253     """ Create a lookup list where address tokens are looked up via index
254         and the name tokens are only used to restrict the search further.
255     """
256     return [FieldLookup('name_vector', name_tokens, 'restrict'),
257             FieldLookup('nameaddress_vector', addr_tokens, 'lookup_all')]