]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/api/search/db_search_fields.py
9fcc2c4e521e9aa3ba55207edcf438af79a26949
[nominatim.git] / nominatim / api / search / db_search_fields.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Data structures for more complex fields in abstract search descriptions.
9 """
10 from typing import List, Tuple, cast
11 import dataclasses
12
13 import sqlalchemy as sa
14 from sqlalchemy.dialects.postgresql import ARRAY
15
16 from nominatim.typing import SaFromClause, SaColumn
17 from nominatim.api.search.query import Token
18
19 @dataclasses.dataclass
20 class WeightedStrings:
21     """ A list of strings together with a penalty.
22     """
23     values: List[str]
24     penalties: List[float]
25
26     def __bool__(self) -> bool:
27         return bool(self.values)
28
29
30 @dataclasses.dataclass
31 class WeightedCategories:
32     """ A list of class/type tuples together with a penalty.
33     """
34     values: List[Tuple[str, str]]
35     penalties: List[float]
36
37     def __bool__(self) -> bool:
38         return bool(self.values)
39
40
41 @dataclasses.dataclass(order=True)
42 class RankedTokens:
43     """ List of tokens together with the penalty of using it.
44     """
45     penalty: float
46     tokens: List[int]
47
48     def with_token(self, t: Token, transition_penalty: float) -> 'RankedTokens':
49         """ Create a new RankedTokens list with the given token appended.
50             The tokens penalty as well as the given transision penalty
51             are added to the overall penalty.
52         """
53         return RankedTokens(self.penalty + t.penalty + transition_penalty,
54                             self.tokens + [t.token])
55
56
57 @dataclasses.dataclass
58 class FieldRanking:
59     """ A list of rankings to be applied sequentially until one matches.
60         The matched ranking determines the penalty. If none matches a
61         default penalty is applied.
62     """
63     column: str
64     default: float
65     rankings: List[RankedTokens]
66
67     def normalize_penalty(self) -> float:
68         """ Reduce the default and ranking penalties, such that the minimum
69             penalty is 0. Return the penalty that was subtracted.
70         """
71         if self.rankings:
72             min_penalty = min(self.default, min(r.penalty for r in self.rankings))
73         else:
74             min_penalty = self.default
75         if min_penalty > 0.0:
76             self.default -= min_penalty
77             for ranking in self.rankings:
78                 ranking.penalty -= min_penalty
79         return min_penalty
80
81
82     def sql_penalty(self, table: SaFromClause) -> SaColumn:
83         """ Create an SQL expression for the rankings.
84         """
85         assert self.rankings
86
87         col = table.c[self.column]
88
89         return sa.case(*((col.contains(r.tokens),r.penalty) for r in self.rankings),
90                        else_=self.default)
91
92
93 @dataclasses.dataclass
94 class FieldLookup:
95     """ A list of tokens to be searched for. The column names the database
96         column to search in and the lookup_type the operator that is applied.
97         'lookup_all' requires all tokens to match. 'lookup_any' requires
98         one of the tokens to match. 'restrict' requires to match all tokens
99         but avoids the use of indexes.
100     """
101     column: str
102     tokens: List[int]
103     lookup_type: str
104
105     def sql_condition(self, table: SaFromClause) -> SaColumn:
106         """ Create an SQL expression for the given match condition.
107         """
108         col = table.c[self.column]
109         if self.lookup_type == 'lookup_all':
110             return col.contains(self.tokens)
111         if self.lookup_type == 'lookup_any':
112             return cast(SaColumn, col.overlap(self.tokens))
113
114         return sa.func.array_cat(col, sa.text('ARRAY[]::integer[]'),
115                                  type_=ARRAY(sa.Integer())).contains(self.tokens)
116
117
118 class SearchData:
119     """ Search fields derived from query and token assignment
120         to be used with the SQL queries.
121     """
122     penalty: float
123
124     lookups: List[FieldLookup] = []
125     rankings: List[FieldRanking]
126
127     housenumbers: WeightedStrings = WeightedStrings([], [])
128     postcodes: WeightedStrings = WeightedStrings([], [])
129     countries: WeightedStrings = WeightedStrings([], [])
130
131     qualifiers: WeightedCategories = WeightedCategories([], [])
132
133
134     def set_strings(self, field: str, tokens: List[Token]) -> None:
135         """ Set on of the WeightedStrings properties from the given
136             token list. Adapt the global penalty, so that the
137             minimum penalty is 0.
138         """
139         if tokens:
140             min_penalty = min(t.penalty for t in tokens)
141             self.penalty += min_penalty
142             wstrs = WeightedStrings([t.lookup_word for t in tokens],
143                                     [t.penalty - min_penalty for t in tokens])
144
145             setattr(self, field, wstrs)
146
147
148     def set_qualifiers(self, tokens: List[Token]) -> None:
149         """ Set the qulaifier field from the given tokens.
150         """
151         if tokens:
152             min_penalty = min(t.penalty for t in tokens)
153             self.penalty += min_penalty
154             self.qualifiers = WeightedCategories([t.get_category() for t in tokens],
155                                                  [t.penalty - min_penalty for t in tokens])
156
157
158     def set_ranking(self, rankings: List[FieldRanking]) -> None:
159         """ Set the list of rankings and normalize the ranking.
160         """
161         self.rankings = []
162         for ranking in rankings:
163             if ranking.rankings:
164                 self.penalty += ranking.normalize_penalty()
165                 self.rankings.append(ranking)
166             else:
167                 self.penalty += ranking.default