]> git.openstreetmap.org Git - nominatim.git/blob - test/bdd/utils/checks.py
reorganise token reranking
[nominatim.git] / test / bdd / utils / checks.py
1 # SPDX-License-Identifier: GPL-2.0-only
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2025 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Helper functions to compare expected values.
9 """
10 import json
11 import re
12 import math
13
14 from psycopg import sql as pysql
15 from psycopg.rows import dict_row
16 from .geometry_alias import ALIASES
17
18
19 COMPARATOR_TERMS = {
20     'exactly': lambda exp, act: exp == act,
21     'more than': lambda exp, act: act > exp,
22     'less than': lambda exp, act: act < exp,
23 }
24
25
26 def _pretty(obj):
27     return json.dumps(obj, sort_keys=True, indent=2)
28
29
30 def _pt_close(p1, p2):
31     return math.isclose(p1[0], p2[0], abs_tol=1e-07) \
32            and math.isclose(p1[1], p2[1], abs_tol=1e-07)
33
34
35 def within_box(value, expect):
36     coord = [float(x) for x in expect.split(',')]
37
38     if isinstance(value, str):
39         if value.startswith('POINT'):
40             value = value[6:-1].split(' ')
41         else:
42             value = value.split(',')
43     value = list(map(float, value))
44
45     if len(value) == 2:
46         return coord[0] <= value[0] <= coord[2] \
47                and coord[1] <= value[1] <= coord[3]
48
49     if len(value) == 4:
50         return value[0] >= coord[0] and value[1] <= coord[1] \
51                and value[2] >= coord[2] and value[3] <= coord[3]
52
53     raise ValueError("Not a coordinate or bbox.")
54
55
56 COMPARISON_FUNCS = {
57     None: lambda val, exp: str(val) == exp,
58     'i': lambda val, exp: str(val).lower() == exp.lower(),
59     'fm': lambda val, exp: re.fullmatch(exp, val) is not None,
60     'dict': lambda val, exp: val is None if exp == '-' else (val == eval('{' + exp + '}')),
61     'in_box': within_box
62 }
63
64 OSM_TYPE = {'node': 'n', 'way': 'w', 'relation': 'r',
65             'N': 'n', 'W': 'w', 'R': 'r'}
66
67
68 class ResultAttr:
69     """ Returns the given attribute as a string.
70
71         The key parameter determines how the value is formatted before
72         returning. To refer to sub attributes, use '+' to add more keys
73         (e.g. 'name+ref' will access obj['name']['ref']). A '!' introduces
74         a formatting suffix. If no suffix is given, the value will be
75         converted using the str() function.
76
77         Available formatters:
78
79         !:...   - use a formatting expression according to Python Mini Format Spec
80         !i      - make case-insensitive comparison
81         !fm     - consider comparison string a regular expression and match full value
82         !wkt    - convert the expected value to a WKT string before comparing
83         !in_box - the expected value is a comma-separated bbox description
84     """
85
86     def __init__(self, obj, key, grid=None):
87         self.grid = grid
88         self.obj = obj
89         if '!' in key:
90             self.key, self.fmt = key.rsplit('!', 1)
91         else:
92             self.key = key
93             self.fmt = None
94
95         if self.key == 'object':
96             assert 'osm_id' in obj
97             assert 'osm_type' in obj
98             self.subobj = OSM_TYPE[obj['osm_type']] + str(obj['osm_id'])
99             self.fmt = 'i'
100         else:
101             done = ''
102             self.subobj = self.obj
103             for sub in self.key.split('+'):
104                 done += f"[{sub}]"
105                 assert sub in self.subobj, \
106                     f"Missing attribute {done}. Full object:\n{_pretty(self.obj)}"
107                 self.subobj = self.subobj[sub]
108
109     def __eq__(self, other):
110         # work around bad quoting by pytest-bdd
111         if not isinstance(other, str):
112             return self.subobj == other
113
114         other = other.replace(r'\\', '\\')
115
116         if self.fmt in COMPARISON_FUNCS:
117             return COMPARISON_FUNCS[self.fmt](self.subobj, other)
118
119         if self.fmt.startswith(':'):
120             return other == f"{{{self.fmt}}}".format(self.subobj)
121
122         if self.fmt == 'wkt':
123             return self.compare_wkt(self.subobj, other)
124
125         raise RuntimeError(f"Unknown format string '{self.fmt}'.")
126
127     def __repr__(self):
128         k = self.key.replace('+', '][')
129         if self.fmt:
130             k += '!' + self.fmt
131         return f"result[{k}]({self.subobj})"
132
133     def compare_wkt(self, value, expected):
134         """ Compare a WKT value against a compact geometry format.
135             The function understands the following formats:
136
137               country:<country code>
138                  Point geometry guaranteed to be in the given country
139               <P>
140                  Point geometry
141               <P>,...,<P>
142                  Line geometry
143               (<P>,...,<P>)
144                  Polygon geometry
145
146            <P> may either be a coordinate of the form '<x> <y>' or a single
147            number. In the latter case it must refer to a point in
148            a previously defined grid.
149         """
150         m = re.fullmatch(r'(POINT)\(([0-9. -]*)\)', value) \
151             or re.fullmatch(r'(LINESTRING)\(([0-9,. -]*)\)', value) \
152             or re.fullmatch(r'(POLYGON)\(\(([0-9,. -]*)\)\)', value)
153         if not m:
154             return False
155
156         converted = [list(map(float, pt.split(' ', 1)))
157                      for pt in map(str.strip, m[2].split(','))]
158
159         if expected.startswith('country:'):
160             ccode = expected[8:].upper()
161             assert ccode in ALIASES, f"Geometry error: unknown country {ccode}"
162             return m[1] == 'POINT' and _pt_close(converted[0], ALIASES[ccode])
163
164         if ',' not in expected:
165             return m[1] == 'POINT' and _pt_close(converted[0], self.get_point(expected))
166
167         if '(' not in expected:
168             return m[1] == 'LINESTRING' and \
169                 all(_pt_close(p1, p2) for p1, p2 in
170                     zip(converted, (self.get_point(p) for p in expected.split(','))))
171
172         if m[1] != 'POLYGON':
173             return False
174
175         # Polygon comparison is tricky because the polygons don't necessarily
176         # end at the same point or have the same winding order.
177         # Brute force all possible variants of the expected polygon
178         exp_coords = [self.get_point(p) for p in expected[1:-1].split(',')]
179         if exp_coords[0] != exp_coords[-1]:
180             raise RuntimeError(f"Invalid polygon {expected}. "
181                                "First and last point need to be the same")
182         for line in (exp_coords[:-1], exp_coords[-1:0:-1]):
183             for i in range(len(line)):
184                 if all(_pt_close(p1, p2) for p1, p2 in
185                        zip(converted, line[i:] + line[:i])):
186                     return True
187
188         return False
189
190     def get_point(self, pt):
191         pt = pt.strip()
192         if ' ' in pt:
193             return list(map(float, pt.split(' ', 1)))
194
195         assert self.grid
196
197         return self.grid.get(pt)
198
199
200 def check_table_content(conn, tablename, data, grid=None, exact=False):
201     lines = set(range(1, len(data)))
202
203     cols = []
204     for col in data[0]:
205         if col == 'object':
206             cols.extend(('osm_id', 'osm_type'))
207         elif '!' in col:
208             name, fmt = col.rsplit('!', 1)
209             if fmt in ('wkt', 'in_box'):
210                 cols.append(f"ST_AsText({name}) as {name}")
211             else:
212                 cols.append(name.split('+')[0])
213         else:
214             cols.append(col.split('+')[0])
215
216     with conn.cursor(row_factory=dict_row) as cur:
217         cur.execute(pysql.SQL(f"SELECT {','.join(cols)} FROM")
218                     + pysql.Identifier(tablename))
219
220         table_content = ''
221         for row in cur:
222             table_content += '\n' + str(row)
223             for i in lines:
224                 for col, value in zip(data[0], data[i]):
225                     if ResultAttr(row, col, grid=grid) != (None if value == '-' else value):
226                         break
227                 else:
228                     lines.remove(i)
229                     break
230             else:
231                 assert not exact, f"Unexpected row in table {tablename}: {row}"
232
233         assert not lines, \
234                "Rows not found:\n" \
235                + '\n'.join(str(data[i]) for i in lines) \
236                + "\nTable content:\n" \
237                + table_content