]> git.openstreetmap.org Git - nominatim.git/blob - test/bdd/utils/checks.py
Merge pull request #3991 from lonvia/interpolation-on-addresses
[nominatim.git] / test / bdd / utils / checks.py
1 # SPDX-License-Identifier: GPL-2.0-only
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2025 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Helper functions to compare expected values.
9 """
10 import ast
11 import collections.abc
12 import json
13 import re
14 import math
15
16 from psycopg import sql as pysql
17 from psycopg.rows import dict_row
18 from .geometry_alias import ALIASES
19
20
21 COMPARATOR_TERMS = {
22     'exactly': lambda exp, act: exp == act,
23     'more than': lambda exp, act: act > exp,
24     'less than': lambda exp, act: act < exp,
25 }
26
27
28 def _pretty(obj):
29     return json.dumps(obj, sort_keys=True, indent=2)
30
31
32 def _pt_close(p1, p2):
33     return math.isclose(p1[0], p2[0], abs_tol=1e-07) \
34            and math.isclose(p1[1], p2[1], abs_tol=1e-07)
35
36
37 def within_box(value, expect):
38     coord = [float(x) for x in expect.split(',')]
39
40     if isinstance(value, str):
41         if value.startswith('POINT'):
42             value = value[6:-1].split(' ')
43         else:
44             value = value.split(',')
45     value = list(map(float, value))
46
47     if len(value) == 2:
48         return coord[0] <= value[0] <= coord[2] \
49                and coord[1] <= value[1] <= coord[3]
50
51     if len(value) == 4:
52         return value[0] >= coord[0] and value[1] <= coord[1] \
53                and value[2] >= coord[2] and value[3] <= coord[3]
54
55     raise ValueError("Not a coordinate or bbox.")
56
57
58 COMPARISON_FUNCS = {
59     None: lambda val, exp: str(val) == exp,
60     'i': lambda val, exp: str(val).lower() == exp.lower(),
61     'fm': lambda val, exp: re.fullmatch(exp, val) is not None,
62     'dict': lambda val, exp: (val is None if exp == '-'
63                               else (val == ast.literal_eval('{' + exp + '}'))),
64     'in_box': within_box
65 }
66
67 OSM_TYPE = {'node': 'n', 'way': 'w', 'relation': 'r',
68             'N': 'n', 'W': 'w', 'R': 'r'}
69
70
71 class ResultAttr:
72     """ Returns the given attribute as a string.
73
74         The key parameter determines how the value is formatted before
75         returning. To refer to sub attributes, use '+' to add more keys
76         (e.g. 'name+ref' will access obj['name']['ref']). A '!' introduces
77         a formatting suffix. If no suffix is given, the value will be
78         converted using the str() function.
79
80         Available formatters:
81
82         !:...   - use a formatting expression according to Python Mini Format Spec
83         !i      - make case-insensitive comparison
84         !fm     - consider comparison string a regular expression and match full value
85         !wkt    - convert the expected value to a WKT string before comparing
86         !in_box - the expected value is a comma-separated bbox description
87     """
88
89     def __init__(self, obj, key, grid=None):
90         self.grid = grid
91         self.obj = obj
92         if '!' in key:
93             self.key, self.fmt = key.rsplit('!', 1)
94         else:
95             self.key = key
96             self.fmt = None
97
98         if self.key == 'object':
99             assert 'osm_id' in obj
100             assert 'osm_type' in obj
101             self.subobj = OSM_TYPE[obj['osm_type']] + str(obj['osm_id'])
102             self.fmt = 'i'
103         else:
104             done = ''
105             self.subobj = self.obj
106             for sub in self.key.split('+'):
107                 done += f"[{sub}]"
108                 if isinstance(self.subobj, collections.abc.Sequence) and sub.isdigit():
109                     sub = int(sub)
110                     assert sub < len(self.subobj), \
111                         f"Out of bound index {done}. Full object:\n{_pretty(self.obj)}"
112                 else:
113                     assert sub in self.subobj, \
114                         f"Missing attribute {done}. Full object:\n{_pretty(self.obj)}"
115                 self.subobj = self.subobj[sub]
116
117     def __eq__(self, other):
118         # work around bad quoting by pytest-bdd
119         if not isinstance(other, str):
120             return self.subobj == other
121
122         other = other.replace(r'\\', '\\')
123
124         if self.fmt in COMPARISON_FUNCS:
125             return COMPARISON_FUNCS[self.fmt](self.subobj, other)
126
127         if self.fmt.startswith(':'):
128             return other == f"{{{self.fmt}}}".format(self.subobj)
129
130         if self.fmt == 'wkt':
131             return self.compare_wkt(self.subobj, other)
132
133         raise RuntimeError(f"Unknown format string '{self.fmt}'.")
134
135     def __repr__(self):
136         k = self.key.replace('+', '][')
137         if self.fmt:
138             k += '!' + self.fmt
139         return f"result[{k}]({self.subobj})"
140
141     def compare_wkt(self, value, expected):
142         """ Compare a WKT value against a compact geometry format.
143             The function understands the following formats:
144
145               country:<country code>
146                  Point geometry guaranteed to be in the given country
147               <P>
148                  Point geometry
149               <P>,...,<P>
150                  Line geometry
151               (<P>,...,<P>)
152                  Polygon geometry
153
154            <P> may either be a coordinate of the form '<x> <y>' or a single
155            number. In the latter case it must refer to a point in
156            a previously defined grid.
157         """
158         m = re.fullmatch(r'(POINT)\(([0-9. -]*)\)', value) \
159             or re.fullmatch(r'(LINESTRING)\(([0-9,. -]*)\)', value) \
160             or re.fullmatch(r'(POLYGON)\(\(([0-9,. -]*)\)\)', value)
161         if not m:
162             return False
163
164         converted = [list(map(float, pt.split(' ', 1)))
165                      for pt in map(str.strip, m[2].split(','))]
166
167         if expected.startswith('country:'):
168             ccode = expected[8:].upper()
169             assert ccode in ALIASES, f"Geometry error: unknown country {ccode}"
170             return m[1] == 'POINT' and _pt_close(converted[0], ALIASES[ccode])
171
172         if ',' not in expected:
173             return m[1] == 'POINT' and _pt_close(converted[0], self.get_point(expected))
174
175         if '(' not in expected:
176             return m[1] == 'LINESTRING' and \
177                 all(_pt_close(p1, p2) for p1, p2 in
178                     zip(converted, (self.get_point(p) for p in expected.split(','))))
179
180         if m[1] != 'POLYGON':
181             return False
182
183         # Polygon comparison is tricky because the polygons don't necessarily
184         # end at the same point or have the same winding order.
185         # Brute force all possible variants of the expected polygon
186         exp_coords = [self.get_point(p) for p in expected[1:-1].split(',')]
187         if exp_coords[0] != exp_coords[-1]:
188             raise RuntimeError(f"Invalid polygon {expected}. "
189                                "First and last point need to be the same")
190         for line in (exp_coords[:-1], exp_coords[-1:0:-1]):
191             for i in range(len(line)):
192                 if all(_pt_close(p1, p2) for p1, p2 in
193                        zip(converted, line[i:] + line[:i])):
194                     return True
195
196         return False
197
198     def get_point(self, pt):
199         pt = pt.strip()
200         if ' ' in pt:
201             return list(map(float, pt.split(' ', 1)))
202
203         assert self.grid
204
205         return self.grid.get(pt)
206
207
208 def check_table_content(conn, tablename, data, grid=None, exact=False):
209     lines = set(range(1, len(data)))
210
211     cols = []
212     for col in data[0]:
213         if col == 'object':
214             cols.extend(('osm_id', 'osm_type'))
215         elif '!' in col:
216             name, fmt = col.rsplit('!', 1)
217             if fmt in ('wkt', 'in_box'):
218                 cols.append(f"ST_AsText({name}) as {name}")
219             else:
220                 cols.append(name.split('+')[0])
221         else:
222             cols.append(col.split('+')[0])
223
224     with conn.cursor(row_factory=dict_row) as cur:
225         cur.execute(pysql.SQL(f"SELECT {','.join(cols)} FROM")
226                     + pysql.Identifier(tablename))
227
228         table_content = ''
229         for row in cur:
230             table_content += '\n' + str(row)
231             for i in lines:
232                 for col, value in zip(data[0], data[i]):
233                     if ResultAttr(row, col, grid=grid) != (None if value == '-' else value):
234                         break
235                 else:
236                     lines.remove(i)
237                     break
238             else:
239                 assert not exact, f"Unexpected row in table {tablename}: {row}"
240
241         assert not lines, \
242                "Rows not found:\n" \
243                + '\n'.join(str(data[i]) for i in lines) \
244                + "\nTable content:\n" \
245                + table_content