]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/api/search/token_assignment.py
disallow category tokens in the middle of a query string
[nominatim.git] / nominatim / api / search / token_assignment.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2023 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Create query interpretations where each vertex in the query is assigned
9 a specific function (expressed as a token type).
10 """
11 from typing import Optional, List, Iterator
12 import dataclasses
13
14 import nominatim.api.search.query as qmod
15 from nominatim.api.logging import log
16
17 # pylint: disable=too-many-return-statements,too-many-branches
18
@dataclasses.dataclass
class TypedRange:
    """ Pairs a span of the query with the token type assigned to it.
    """
    # The token type that the span has been assigned.
    ttype: qmod.TokenType
    # The span of token positions covered.
    trange: qmod.TokenRange
25
26
# Penalty charged whenever a new typed range is started at a break of the
# given type (i.e. when the token type changes at that break).
PENALTY_TOKENCHANGE = {
    # Changing type at the start/end of the query or at a phrase break is free.
    qmod.BreakType.START: 0.0,
    qmod.BreakType.END: 0.0,
    qmod.BreakType.PHRASE: 0.0,
    # Increasing penalties for changes at word, part and token breaks.
    qmod.BreakType.WORD: 0.1,
    qmod.BreakType.PART: 0.2,
    qmod.BreakType.TOKEN: 0.4,
}

# An ordered sequence of typed ranges covering (a prefix of) the query.
TypedRangeSeq = List[TypedRange]
37
@dataclasses.dataclass
class TokenAssignment: # pylint: disable=too-many-instance-attributes
    """ One candidate interpretation of a tokenized query: which token
        range plays which role (name, address parts, housenumber, ...).
    """
    # Accumulated penalty of this interpretation.
    penalty: float = 0.0
    # Range used as the name of the searched object, if any.
    name: Optional[qmod.TokenRange] = None
    # Ranges used as address parts, in query order.
    address: List[qmod.TokenRange] = dataclasses.field(default_factory=list)
    housenumber: Optional[qmod.TokenRange] = None
    postcode: Optional[qmod.TokenRange] = None
    country: Optional[qmod.TokenRange] = None
    near_item: Optional[qmod.TokenRange] = None
    qualifier: Optional[qmod.TokenRange] = None


    @staticmethod
    def from_ranges(ranges: TypedRangeSeq) -> 'TokenAssignment':
        """ Build an assignment from a sequence of typed ranges.

            Partial-name ranges are appended, in order, to `address`.
            Housenumber, postcode, country, near-item and qualifier ranges
            are stored in the field of the same name; a later range of the
            same type overwrites an earlier one. Ranges of any other type
            are ignored. `name` is left unset.
        """
        single_value_fields = {
            qmod.TokenType.HOUSENUMBER: 'housenumber',
            qmod.TokenType.POSTCODE: 'postcode',
            qmod.TokenType.COUNTRY: 'country',
            qmod.TokenType.NEAR_ITEM: 'near_item',
            qmod.TokenType.QUALIFIER: 'qualifier',
        }

        result = TokenAssignment()
        for typed in ranges:
            if typed.ttype == qmod.TokenType.PARTIAL:
                result.address.append(typed.trange)
                continue
            target = single_value_fields.get(typed.ttype)
            if target is not None:
                setattr(result, target, typed.trange)
        return result
72
73
class _TokenSequence:
    """ Working state used to put together the token assignments.

        Represents an intermediate state while traversing the tokenized
        query.
    """
    def __init__(self, seq: TypedRangeSeq,
                 direction: int = 0, penalty: float = 0.0) -> None:
        # Typed ranges collected so far, in query order.
        self.seq = seq
        # Reading direction: 1 = left-to-right, -1 = right-to-left,
        # 0 = not decided yet (both readings are still possible).
        self.direction = direction
        # Accumulated penalty of this sequence.
        self.penalty = penalty


    def __str__(self) -> str:
        seq = ''.join(f'[{r.trange.start} - {r.trange.end}: {r.ttype.name}]' for r in self.seq)
        return f'{seq} (dir: {self.direction}, penalty: {self.penalty})'


    @property
    def end_pos(self) -> int:
        """ Return the index of the global end of the current sequence.
        """
        return self.seq[-1].trange.end if self.seq else 0


    def has_types(self, *ttypes: qmod.TokenType) -> bool:
        """ Check if the current sequence contains any typed ranges of
            the given types.
        """
        return any(s.ttype in ttypes for s in self.seq)


    def is_final(self) -> bool:
        """ Return true when the sequence cannot be extended by any
            form of token anymore.
        """
        # A country or a near item (category) may only start the query or
        # end it. Once one of them follows another term, nothing may come
        # after it.
        return len(self.seq) > 1 and \
               self.seq[-1].ttype in (qmod.TokenType.COUNTRY, qmod.TokenType.NEAR_ITEM)


    def appendable(self, ttype: qmod.TokenType) -> Optional[int]:
        """ Check if the given token type is appendable to the existing sequence.

            Returns None if the token type is not appendable, otherwise the
            new direction of the sequence after adding such a type. The
            token is not added.
        """
        if ttype == qmod.TokenType.WORD:
            return None

        if not self.seq:
            # Append unconditionally to the empty list
            if ttype == qmod.TokenType.COUNTRY:
                return -1
            if ttype in (qmod.TokenType.HOUSENUMBER, qmod.TokenType.QUALIFIER):
                return 1
            return self.direction

        # Name tokens are always acceptable and don't change direction
        if ttype == qmod.TokenType.PARTIAL:
            # qualifiers cannot appear in the middle of the query. They need
            # to be near the next phrase.
            if self.direction == -1 \
               and any(t.ttype == qmod.TokenType.QUALIFIER for t in self.seq[:-1]):
                return None
            return self.direction

        # Other tokens may only appear once
        if self.has_types(ttype):
            return None

        if ttype == qmod.TokenType.HOUSENUMBER:
            if self.direction == 1:
                if len(self.seq) == 1 and self.seq[0].ttype == qmod.TokenType.QUALIFIER:
                    return None
                if len(self.seq) > 2 \
                   or self.has_types(qmod.TokenType.POSTCODE, qmod.TokenType.COUNTRY):
                    return None # direction left-to-right: housenumber must come before anything
            elif self.direction == -1 \
                 or self.has_types(qmod.TokenType.POSTCODE, qmod.TokenType.COUNTRY):
                return -1 # force direction right-to-left if after other terms

            return self.direction

        if ttype == qmod.TokenType.POSTCODE:
            if self.direction == -1:
                if self.has_types(qmod.TokenType.HOUSENUMBER, qmod.TokenType.QUALIFIER):
                    return None
                return -1
            if self.direction == 1:
                return None if self.has_types(qmod.TokenType.COUNTRY) else 1
            if self.has_types(qmod.TokenType.HOUSENUMBER, qmod.TokenType.QUALIFIER):
                return 1
            return self.direction

        if ttype == qmod.TokenType.COUNTRY:
            return None if self.direction == -1 else 1

        if ttype == qmod.TokenType.NEAR_ITEM:
            return self.direction

        if ttype == qmod.TokenType.QUALIFIER:
            if self.direction == 1:
                if (len(self.seq) == 1
                    and self.seq[0].ttype in (qmod.TokenType.PARTIAL, qmod.TokenType.NEAR_ITEM)) \
                   or (len(self.seq) == 2
                       and self.seq[0].ttype == qmod.TokenType.NEAR_ITEM
                       and self.seq[1].ttype == qmod.TokenType.PARTIAL):
                    return 1
                return None
            if self.direction == -1:
                return -1

            # Look past a leading near item so that the rules below apply
            # to the remainder of the sequence.
            tempseq = self.seq[1:] if self.seq[0].ttype == qmod.TokenType.NEAR_ITEM else self.seq
            if len(tempseq) == 0:
                return 1
            # A qualifier may not directly follow a lone housenumber
            # (mirrors the left-to-right rule above). This must look at
            # tempseq: a plain [HOUSENUMBER] sequence always has direction 1
            # and never gets here.
            if len(tempseq) == 1 and tempseq[0].ttype == qmod.TokenType.HOUSENUMBER:
                return None
            if len(tempseq) > 1 or self.has_types(qmod.TokenType.POSTCODE, qmod.TokenType.COUNTRY):
                return -1
            return 0

        return None


    def advance(self, ttype: qmod.TokenType, end_pos: int,
                btype: qmod.BreakType) -> Optional['_TokenSequence']:
        """ Return a new token sequence state with the given token type
            extended.

            Returns None when the type cannot be appended. A token of the
            same type as the last range extends that range, unless it
            follows a phrase break. Otherwise a new range is started, which
            costs the PENALTY_TOKENCHANGE penalty for the break type.
        """
        newdir = self.appendable(ttype)
        if newdir is None:
            return None

        if not self.seq:
            newseq = [TypedRange(ttype, qmod.TokenRange(0, end_pos))]
            new_penalty = 0.0
        else:
            last = self.seq[-1]
            if btype != qmod.BreakType.PHRASE and last.ttype == ttype:
                # extend the existing range
                newseq = self.seq[:-1] + [TypedRange(ttype, last.trange.replace_end(end_pos))]
                new_penalty = 0.0
            else:
                # start a new range
                newseq = list(self.seq) + [TypedRange(ttype,
                                                      qmod.TokenRange(last.trange.end, end_pos))]
                new_penalty = PENALTY_TOKENCHANGE[btype]

        return _TokenSequence(newseq, newdir, self.penalty + new_penalty)


    def _adapt_penalty_from_priors(self, priors: int, new_dir: int) -> bool:
        """ Adjust the sequence for the number of partial terms (`priors`)
            found on one side of the housenumber.

            Exactly two such terms add a penalty of 1.0. More than two are
            only acceptable while the direction is still undecided; the
            direction is then fixed to `new_dir`. Returns False when the
            sequence must be rejected.
        """
        if priors == 2:
            self.penalty += 1.0
        elif priors > 2:
            if self.direction == 0:
                self.direction = new_dir
            else:
                return False

        return True


    def recheck_sequence(self) -> bool:
        """ Check that the sequence is a fully valid token assignment
            and adapt direction and penalties further if necessary.

            This function catches some impossible assignments that need
            forward context and can therefore not be excluded when building
            the assignment.
        """
        # housenumbers may not be further than 2 words from the beginning.
        # If there are two words in front, give it a penalty.
        hnrpos = next((i for i, tr in enumerate(self.seq)
                       if tr.ttype == qmod.TokenType.HOUSENUMBER),
                      None)
        if hnrpos is not None:
            if self.direction != -1:
                priors = sum(1 for t in self.seq[:hnrpos] if t.ttype == qmod.TokenType.PARTIAL)
                if not self._adapt_penalty_from_priors(priors, -1):
                    return False
            if self.direction != 1:
                priors = sum(1 for t in self.seq[hnrpos+1:] if t.ttype == qmod.TokenType.PARTIAL)
                if not self._adapt_penalty_from_priors(priors, 1):
                    return False
            if any(t.ttype == qmod.TokenType.NEAR_ITEM for t in self.seq):
                self.penalty += 1.0

        return True


    def _get_assignments_postcode(self, base: TokenAssignment,
                                  query_len: int)  -> Iterator[TokenAssignment]:
        """ Yield possible assignments of Postcode searches with an
            address component.

            Only applies when the postcode is at the start or the end of the
            query. As a side effect this fixes `self.direction` for the
            address readings that follow.
        """
        assert base.postcode is not None

        if (base.postcode.start == 0 and self.direction != -1)\
           or (base.postcode.end == query_len and self.direction != 1):
            log().comment('postcode search')
            # <address>,<postcode> should give preference to address search
            if base.postcode.start == 0:
                penalty = self.penalty
                self.direction = -1 # name searches are only possible backwards
            else:
                penalty = self.penalty + 0.1
                self.direction = 1 # name searches are only possible forwards
            yield dataclasses.replace(base, penalty=penalty)


    def _get_assignments_address_forward(self, base: TokenAssignment,
                                         query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
        """ Yield possible assignments of address searches with
            left-to-right reading.
        """
        first = base.address[0]

        log().comment('first word = name')
        yield dataclasses.replace(base, penalty=self.penalty,
                                  name=first, address=base.address[1:])

        # To paraphrase:
        #  * if another name term comes after the first one and before the
        #    housenumber
        #  * a qualifier comes after the name
        #  * the containing phrase is strictly typed
        if (base.housenumber and first.end < base.housenumber.start)\
           or (base.qualifier and base.qualifier > first)\
           or (query.nodes[first.start].ptype != qmod.PhraseType.NONE):
            return

        penalty = self.penalty

        # Penalty for:
        #  * <name>, <street>, <housenumber> , ...
        #  * queries that are comma-separated
        if (base.housenumber and base.housenumber > first) or len(query.source) > 1:
            penalty += 0.25

        for i in range(first.start + 1, first.end):
            name, addr = first.split(i)
            log().comment(f'split first word = name ({i - first.start})')
            yield dataclasses.replace(base, name=name, address=[addr] + base.address[1:],
                                      penalty=penalty + PENALTY_TOKENCHANGE[query.nodes[i].btype])


    def _get_assignments_address_backward(self, base: TokenAssignment,
                                          query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
        """ Yield possible assignments of address searches with
            right-to-left reading.
        """
        last = base.address[-1]

        if self.direction == -1 or len(base.address) > 1:
            log().comment('last word = name')
            yield dataclasses.replace(base, penalty=self.penalty,
                                      name=last, address=base.address[:-1])

        # To paraphrase:
        #  * if another name term comes before the last one and after the
        #    housenumber
        #  * a qualifier comes before the name
        #  * the containing phrase is strictly typed
        if (base.housenumber and last.start > base.housenumber.end)\
           or (base.qualifier and base.qualifier < last)\
           or (query.nodes[last.start].ptype != qmod.PhraseType.NONE):
            return

        penalty = self.penalty
        if base.housenumber and base.housenumber < last:
            penalty += 0.4
        if len(query.source) > 1:
            penalty += 0.25

        for i in range(last.start + 1, last.end):
            addr, name = last.split(i)
            log().comment(f'split last word = name ({i - last.start})')
            yield dataclasses.replace(base, name=name, address=base.address[:-1] + [addr],
                                      penalty=penalty + PENALTY_TOKENCHANGE[query.nodes[i].btype])


    def get_assignments(self, query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
        """ Yield possible assignments for the current sequence.

            This function splits up general name assignments into name
            and address and yields all possible variants of that.
            Nothing is yielded when the address covers more than 50 tokens.
        """
        base = TokenAssignment.from_ranges(self.seq)

        num_addr_tokens = sum(t.end - t.start for t in base.address)
        if num_addr_tokens > 50:
            return

        # Postcode search (postcode-only search is covered in next case)
        if base.postcode is not None and base.address:
            yield from self._get_assignments_postcode(base, query.num_token_slots())

        # Postcode or country-only search
        if not base.address:
            if not base.housenumber and (base.postcode or base.country or base.near_item):
                log().comment('postcode/country search')
                yield dataclasses.replace(base, penalty=self.penalty)
        else:
            # <postcode>,<address> should give preference to postcode search
            if base.postcode and base.postcode.start == 0:
                self.penalty += 0.1

            # Left-to-right reading of the address
            if self.direction != -1:
                yield from self._get_assignments_address_forward(base, query)

            # Right-to-left reading of the address
            if self.direction != 1:
                yield from self._get_assignments_address_backward(base, query)

            # variant for special housenumber searches
            if base.housenumber:
                yield dataclasses.replace(base, penalty=self.penalty)
395
396
def yield_token_assignments(query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
    """ Enumerate the possible assignments of word types to the word
        positions of a tokenized query.

        The traversal is depth-first over the tokens listed in the query.
        Each partial sequence is extended by every token that starts where
        the sequence ends. A sequence that reaches the end of the query
        and passes the final recheck is expanded into concrete assignments.

        The penalties of the results include transitions from one word type
        to another, but not transitions within the same type.
    """
    # Typed (non-free) first phrases force a left-to-right reading.
    start_direction = 0 if query.source[0].ptype == qmod.PhraseType.NONE else 1
    stack = [_TokenSequence([], direction=start_direction)]

    while stack:
        current = stack.pop()
        node = query.nodes[current.end_pos]

        for tlist in node.starting:
            candidate = current.advance(tlist.ttype, tlist.end, node.btype)
            if candidate is None:
                continue

            if candidate.end_pos != query.num_token_slots():
                # Incomplete: keep exploring unless nothing may follow.
                if not candidate.is_final():
                    stack.append(candidate)
            elif candidate.recheck_sequence():
                log().var_dump('Assignment', candidate)
                yield from candidate.get_assignments(query)