1 # SPDX-License-Identifier: GPL-3.0-or-later
3 # This file is part of Nominatim. (https://nominatim.org)
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
"""
Data structures for more complex fields in abstract search descriptions.
"""
from typing import List, Tuple, Iterator, cast, Dict
import dataclasses

import sqlalchemy as sa

from nominatim.typing import SaFromClause, SaColumn, SaExpression
from nominatim.api.search.query import Token
from nominatim.utils.json_writer import JsonWriter
@dataclasses.dataclass
class WeightedStrings:
    """ A list of strings together with a penalty.

        `values` and `penalties` are parallel lists: the penalty at a
        given index belongs to the string at the same index.
    """
    values: List[str]
    penalties: List[float]

    def __bool__(self) -> bool:
        """ The object is truthy when it holds at least one string.
        """
        return bool(self.values)


    def __iter__(self) -> Iterator[Tuple[str, float]]:
        """ Iterate over (string, penalty) pairs in list order.
        """
        return iter(zip(self.values, self.penalties))


    def get_penalty(self, value: str, default: float = 1000.0) -> float:
        """ Get the penalty for the given value. Returns the given default
            if the value does not exist.
        """
        try:
            return self.penalties[self.values.index(value)]
        except ValueError:
            # list.index() reports a missing value with ValueError.
            return default
@dataclasses.dataclass
class WeightedCategories:
    """ A list of class/type tuples together with a penalty.

        `values` and `penalties` are parallel lists: the penalty at a
        given index belongs to the (class, type) tuple at the same index.
    """
    values: List[Tuple[str, str]]
    penalties: List[float]

    def __bool__(self) -> bool:
        """ The object is truthy when it holds at least one category.
        """
        return bool(self.values)


    def __iter__(self) -> Iterator[Tuple[Tuple[str, str], float]]:
        """ Iterate over ((class, type), penalty) pairs in list order.
        """
        return iter(zip(self.values, self.penalties))


    def get_penalty(self, value: Tuple[str, str], default: float = 1000.0) -> float:
        """ Get the penalty for the given value. Returns the given default
            if the value does not exist.
        """
        try:
            return self.penalties[self.values.index(value)]
        except ValueError:
            # list.index() reports a missing value with ValueError.
            return default


    def sql_restrict(self, table: SaFromClause) -> SaExpression:
        """ Return an SQLAlchemy expression that restricts the
            class and type columns of the given table to the values
            in the list.
            Must not be used with an empty list.
        """
        # Contract check: an empty list would yield an OR over nothing.
        assert self.values
        if len(self.values) == 1:
            # A single category needs no surrounding OR.
            return sa.and_(table.c.class_ == self.values[0][0],
                           table.c.type == self.values[0][1])

        return sa.or_(*(sa.and_(table.c.class_ == c, table.c.type == t)
                        for c, t in self.values))
@dataclasses.dataclass(order=True)
class RankedTokens:
    """ List of tokens together with the penalty of using it.

        Instances are orderable; `penalty` is the first field and
        therefore the primary sort key.
    """
    penalty: float
    tokens: List[int]

    def with_token(self, t: Token, transition_penalty: float) -> 'RankedTokens':
        """ Create a new RankedTokens list with the given token appended.
            The tokens penalty as well as the given transition penalty
            are added to the overall penalty.

            The current object is left unchanged.
        """
        return RankedTokens(self.penalty + t.penalty + transition_penalty,
                            self.tokens + [t.token])
@dataclasses.dataclass
class FieldRanking:
    """ A list of rankings to be applied sequentially until one matches.
        The matched ranking determines the penalty. If none matches a
        default penalty is applied.
    """
    column: str
    default: float
    rankings: List[RankedTokens]

    def normalize_penalty(self) -> float:
        """ Reduce the default and ranking penalties, such that the minimum
            penalty is 0. Return the penalty that was subtracted.

            When the minimum is not positive, all penalties are left
            untouched and the minimum is returned as is.
        """
        if self.rankings:
            min_penalty = min(self.default, min(r.penalty for r in self.rankings))
        else:
            # min() over an empty sequence would raise; only the default counts.
            min_penalty = self.default
        if min_penalty > 0.0:
            self.default -= min_penalty
            for ranking in self.rankings:
                ranking.penalty -= min_penalty
        return min_penalty


    def sql_penalty(self, table: SaFromClause) -> SaColumn:
        """ Create an SQL expression for the rankings.

            The rankings are serialised as a JSON array with one
            `[penalty, [token, ...]]` entry per ranking and handed to the
            SQL function `weigh_search` together with the column and the
            default penalty.
            Must not be used with an empty ranking list.
        """
        assert self.rankings

        rout = JsonWriter().start_array()
        for rank in self.rankings:
            rout.start_array().value(rank.penalty).next()
            # Inner array holding the tokens of this ranking.
            rout.start_array()
            for token in rank.tokens:
                rout.value(token).next()
            rout.end_array()
            rout.end_array().next()
        rout.end_array()

        # NOTE(review): calling the writer is assumed to yield the JSON text.
        return sa.func.weigh_search(table.c[self.column], rout(), self.default)
@dataclasses.dataclass
class FieldLookup:
    """ A list of tokens to be searched for. The column names the database
        column to search in and the lookup_type the operator that is applied.
        'lookup_all' requires all tokens to match. 'lookup_any' requires
        one of the tokens to match. 'restrict' requires to match all tokens
        but avoids the use of indexes.
    """
    column: str
    tokens: List[int]
    lookup_type: str

    def sql_condition(self, table: SaFromClause) -> SaColumn:
        """ Create an SQL expression for the given match condition.

            Any lookup_type other than 'lookup_all' and 'lookup_any' is
            handled like 'restrict'.
        """
        col = table.c[self.column]
        if self.lookup_type == 'lookup_all':
            return col.contains(self.tokens)
        if self.lookup_type == 'lookup_any':
            return cast(SaColumn, col.overlaps(self.tokens))

        # 'restrict': same containment test, but on the column wrapped in
        # coalesce() - presumably what keeps the index from being used.
        return sa.func.coalesce(sa.null(), col).contains(self.tokens) # pylint: disable=not-callable
class SearchData:
    """ Search fields derived from query and token assignment
        to be used with the SQL queries.
    """
    # Global penalty. Declared without a default, so it has to be set
    # on the instance before any of the setters below is called.
    penalty: float

    # The values assigned below are class-level defaults shared by all
    # instances. The setters rebind the attributes instead of mutating them.
    lookups: List[FieldLookup] = []
    rankings: List[FieldRanking]

    housenumbers: WeightedStrings = WeightedStrings([], [])
    postcodes: WeightedStrings = WeightedStrings([], [])
    countries: WeightedStrings = WeightedStrings([], [])

    qualifiers: WeightedCategories = WeightedCategories([], [])


    def set_strings(self, field: str, tokens: List[Token]) -> None:
        """ Set one of the WeightedStrings properties from the given
            token list. Adapt the global penalty, so that the
            minimum penalty is 0.

            `field` is the name of the attribute to replace. An empty
            token list leaves the object unchanged.
        """
        if not tokens:
            # min() over an empty sequence would raise.
            return

        min_penalty = min(t.penalty for t in tokens)
        self.penalty += min_penalty
        wstrs = WeightedStrings([t.lookup_word for t in tokens],
                                [t.penalty - min_penalty for t in tokens])

        setattr(self, field, wstrs)


    def set_qualifiers(self, tokens: List[Token]) -> None:
        """ Set the qualifier field from the given tokens.

            Each category is kept once with the lowest penalty of its
            tokens. The lowest penalty overall is added to the global
            penalty. An empty token list leaves the object unchanged.
        """
        if not tokens:
            return

        categories: Dict[Tuple[str, str], float] = {}
        min_penalty = 1000.0
        for t in tokens:
            if t.penalty < min_penalty:
                min_penalty = t.penalty
            cat = t.get_category()
            if t.penalty < categories.get(cat, 1000.0):
                categories[cat] = t.penalty
        self.penalty += min_penalty
        self.qualifiers = WeightedCategories(list(categories.keys()),
                                             list(categories.values()))


    def set_ranking(self, rankings: List[FieldRanking]) -> None:
        """ Set the list of rankings and normalize the ranking.

            Rankings without any entries are not stored; only their
            default penalty is added to the global penalty.
        """
        # `rankings` has no class-level default, so start from a fresh list.
        self.rankings = []
        for ranking in rankings:
            if ranking.rankings:
                self.penalty += ranking.normalize_penalty()
                self.rankings.append(ranking)
            else:
                self.penalty += ranking.default
def lookup_by_names(name_tokens: List[int], addr_tokens: List[int]) -> List[FieldLookup]:
    """ Create a lookup list where name tokens are looked up via index
        and potential address tokens are used to restrict the search further.
    """
    lookup = [FieldLookup('name_vector', name_tokens, 'lookup_all')]
    if addr_tokens:
        # Address tokens are optional; only add the restriction when present.
        lookup.append(FieldLookup('nameaddress_vector', addr_tokens, 'restrict'))

    return lookup
def lookup_by_any_name(name_tokens: List[int], addr_tokens: List[int],
                       lookup_type: str) -> List[FieldLookup]:
    """ Create a lookup list where name tokens are looked up via index
        and only one of the name tokens must be present.
        Potential address tokens are used to restrict the search further.

        `lookup_type` is the operator applied to the address tokens.
    """
    lookup = [FieldLookup('name_vector', name_tokens, 'lookup_any')]
    if addr_tokens:
        # Address tokens are optional; only add the restriction when present.
        lookup.append(FieldLookup('nameaddress_vector', addr_tokens, lookup_type))

    return lookup
def lookup_by_addr(name_tokens: List[int], addr_tokens: List[int]) -> List[FieldLookup]:
    """ Create a lookup list where address tokens are looked up via index
        and the name tokens are only used to restrict the search further.

        Both lookups are always included, even for empty token lists.
    """
    return [FieldLookup('name_vector', name_tokens, 'restrict'),
            FieldLookup('nameaddress_vector', addr_tokens, 'lookup_all')]