]> git.openstreetmap.org Git - nominatim.git/blob - test/bdd/utils/checks.py
move database setup to generic conftest.py
[nominatim.git] / test / bdd / utils / checks.py
1 # SPDX-License-Identifier: GPL-2.0-only
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2025 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Helper functions to compare expected values.
9 """
10 import json
11 import re
12 import math
13
14 from psycopg import sql as pysql
15 from psycopg.rows import dict_row, tuple_row
16 from .geometry_alias import ALIASES
17
18 COMPARATOR_TERMS = {
19     'exactly': lambda exp, act: exp == act,
20     'more than': lambda exp, act: act > exp,
21     'less than': lambda exp, act: act < exp,
22 }
23
24
25 def _pretty(obj):
26     return json.dumps(obj, sort_keys=True, indent=2)
27
28
29 def within_box(value, expect):
30     coord = [float(x) for x in expect.split(',')]
31
32     if isinstance(value, str):
33         value = value.split(',')
34     value = list(map(float, value))
35
36     if len(value) == 2:
37         return coord[0] <= value[0] <= coord[2] \
38                and coord[1] <= value[1] <= coord[3]
39
40     if len(value) == 4:
41         return value[0] >= coord[0] and value[1] <= coord[1] \
42                and value[2] >= coord[2] and value[3] <= coord[3]
43
44     raise ValueError("Not a coordinate or bbox.")
45
46
47 COMPARISON_FUNCS = {
48     None: lambda val, exp: str(val) == exp,
49     'i': lambda val, exp: str(val).lower() == exp.lower(),
50     'fm': lambda val, exp: re.fullmatch(exp, val) is not None,
51     'dict': lambda val, exp: val is None if exp == '-' else (val == eval('{' + exp + '}')),
52     'in_box': within_box
53 }
54
55 OSM_TYPE = {'node': 'n', 'way': 'w', 'relation': 'r',
56             'N': 'n', 'W': 'w', 'R': 'r'}
57
58
59 class ResultAttr:
60     """ Returns the given attribute as a string.
61
62         The key parameter determines how the value is formatted before
63         returning. To refer to sub attributes, use '+' to add more keys
64         (e.g. 'name+ref' will access obj['name']['ref']). A '!' introduces
65         a formatting suffix. If no suffix is given, the value will be
66         converted using the str() function.
67
68         Available formatters:
69
70         !:...   - use a formatting expression according to Python Mini Format Spec
71         !i      - make case-insensitive comparison
72         !fm     - consider comparison string a regular expression and match full value
73         !wkt    - convert the expected value to a WKT string before comparing
74         !in_box - the expected value is a comma-separated bbox description
75     """
76
77     def __init__(self, obj, key, grid=None):
78         self.grid = grid
79         self.obj = obj
80         if '!' in key:
81             self.key, self.fmt = key.rsplit('!', 1)
82         else:
83             self.key = key
84             self.fmt = None
85
86         if self.key == 'object':
87             assert 'osm_id' in obj
88             assert 'osm_type' in obj
89             self.subobj = OSM_TYPE[obj['osm_type']] + str(obj['osm_id'])
90             self.fmt = 'i'
91         else:
92             done = ''
93             self.subobj = self.obj
94             for sub in self.key.split('+'):
95                 done += f"[{sub}]"
96                 assert sub in self.subobj, \
97                     f"Missing attribute {done}. Full object:\n{_pretty(self.obj)}"
98                 self.subobj = self.subobj[sub]
99
100     def __eq__(self, other):
101         if not isinstance(other, str):
102             raise NotImplementedError()
103
104         # work around bad quoting by pytest-bdd
105         other = other.replace(r'\\', '\\')
106
107         if self.fmt in COMPARISON_FUNCS:
108             return COMPARISON_FUNCS[self.fmt](self.subobj, other)
109
110         if self.fmt.startswith(':'):
111             return other == f"{{{self.fmt}}}".format(self.subobj)
112
113         if self.fmt == 'wkt':
114             return self.compare_wkt(self.subobj, other)
115
116         raise RuntimeError(f"Unknown format string '{self.fmt}'.")
117
118     def __repr__(self):
119         k = self.key.replace('+', '][')
120         if self.fmt:
121             k += '!' + self.fmt
122         return f"result[{k}]({self.subobj})"
123
124     def compare_wkt(self, value, expected):
125         """ Compare a WKT value against a compact geometry format.
126             The function understands the following formats:
127
128               country:<country code>
129                  Point geometry guaranteed to be in the given country
130               <P>
131                  Point geometry
132               <P>,...,<P>
133                  Line geometry
134               (<P>,...,<P>)
135                  Polygon geometry
136
137            <P> may either be a coordinate of the form '<x> <y>' or a single
138            number. In the latter case it must refer to a point in
139            a previously defined grid.
140         """
141         m = re.fullmatch(r'(POINT)\(([0-9. -]*)\)', value) \
142             or re.fullmatch(r'(LINESTRING)\(([0-9,. -]*)\)', value) \
143             or re.fullmatch(r'(POLYGON)\(\(([0-9,. -]*)\)\)', value)
144         if not m:
145             return False
146
147         converted = [list(map(float, pt.split(' ', 1)))
148                      for pt in map(str.strip, m[2].split(','))]
149
150         if expected.startswith('country:'):
151             ccode = geom[8:].upper()
152             assert ccode in ALIASES, f"Geometry error: unknown country {ccode}"
153             return m[1] == 'POINT' and \
154                 all(math.isclose(p1, p2) for p1, p2 in zip(converted[0], ALIASES[ccode]))
155
156         if ',' not in expected:
157             return m[1] == 'POINT' and \
158                 all(math.isclose(p1, p2) for p1, p2 in zip(converted[0], self.get_point(expected)))
159
160         if '(' not in expected:
161             return m[1] == 'LINESTRING' and \
162                 all(math.isclose(p1[0], p2[0]) and math.isclose(p1[1], p2[1]) for p1, p2 in
163                     zip(converted, (self.get_point(p) for p in expected.split(','))))
164
165         if m[1] != 'POLYGON':
166             return False
167
168         # Polygon comparison is tricky because the polygons don't necessarily
169         # end at the same point or have the same winding order.
170         # Brute force all possible variants of the expected polygon
171         exp_coords = [self.get_point(p) for p in expected[1:-1].split(',')]
172         if exp_coords[0] != exp_coords[-1]:
173             raise RuntimeError(f"Invalid polygon {expected}. "
174                                "First and last point need to be the same")
175         for line in (exp_coords[:-1], exp_coords[-1:0:-1]):
176             for i in range(len(line)):
177                 if all(math.isclose(p1[0], p2[0]) and math.isclose(p1[1], p2[1]) for p1, p2 in
178                        zip(converted, line[i:] + line[:i])):
179                     return True
180
181         return False
182
183     def get_point(self, pt):
184         pt = pt.strip()
185         if ' ' in pt:
186             return list(map(float, pt.split(' ', 1)))
187
188         assert self.grid
189
190         return self.grid.get(pt)
191
192
193 def check_table_content(conn, tablename, data, grid=None, exact=False):
194     lines = set(range(1, len(data)))
195
196     cols = []
197     for col in data[0]:
198         if col == 'object':
199             cols.extend(('osm_id', 'osm_type'))
200         elif '!' in col:
201             name, fmt = col.rsplit('!', 1)
202             if fmt == 'wkt':
203                 cols.append(f"ST_AsText({name}) as {name}")
204             else:
205                 cols.append(name.split('+')[0])
206         else:
207             cols.append(col.split('+')[0])
208
209     with conn.cursor(row_factory=dict_row) as cur:
210         cur.execute(pysql.SQL(f"SELECT {','.join(cols)} FROM")
211                     + pysql.Identifier(tablename))
212
213         table_content = ''
214         for row in cur:
215             table_content += '\n' + str(row)
216             for i in lines:
217                 for col, value in zip(data[0], data[i]):
218                     if ResultAttr(row, col, grid=grid) != value:
219                         break
220                 else:
221                     lines.remove(i)
222                     break
223             else:
224                 assert not exact, f"Unexpected row in table {tablename}: {row}"
225
226         assert not lines, \
227                "Rows not found:\n" \
228                + '\n'.join(str(data[i]) for i in lines) \
229                + "\nTable content:\n" \
230                + table_content
231
232
233 def check_table_has_lines(conn, tablename, osm_type, osm_id, osm_class):
234     sql = pysql.SQL("""SELECT count(*) FROM {}
235                        WHERE osm_type = %s and osm_id = %s""").format(pysql.Identifier(tablename))
236     params = [osm_type, int(osm_id)]
237     if osm_class:
238         sql += pysql.SQL(' AND class = %s')
239         params.append(osm_class)
240
241     with conn.cursor(row_factory=tuple_row) as cur:
242         assert cur.execute(sql, params).fetchone()[0] == 0