]> git.openstreetmap.org Git - nominatim.git/blob - test/bdd/conftest.py
Merge pull request #3863 from lonvia/improve-bdd-test-names
[nominatim.git] / test / bdd / conftest.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2025 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Fixtures for BDD test steps
9 """
10 import sys
11 import json
12 import re
13 from pathlib import Path
14
15 import psycopg
16 from psycopg import sql as pysql
17
18 # always test against the source
19 SRC_DIR = (Path(__file__) / '..' / '..' / '..').resolve()
20 sys.path.insert(0, str(SRC_DIR / 'src'))
21
22 import pytest
23 from pytest_bdd.parsers import re as step_parse
24 from pytest_bdd import given, when, then, scenario
25 from pytest_bdd.feature import get_features
26
27 pytest.register_assert_rewrite('utils')
28
29 from utils.api_runner import APIRunner
30 from utils.api_result import APIResult
31 from utils.checks import ResultAttr, COMPARATOR_TERMS
32 from utils.geometry_alias import ALIASES
33 from utils.grid import Grid
34 from utils.db import DBManager
35
36 from nominatim_db.config import Configuration
37 from nominatim_db.data.country_info import setup_country_config
38
39
40 def _strlist(inp):
41     return [s.strip() for s in inp.split(',')]
42
43
44 def _pretty_json(inp):
45     return json.dumps(inp, indent=2)
46
47
48 def pytest_addoption(parser, pluginmanager):
49     parser.addoption('--nominatim-purge', dest='NOMINATIM_PURGE', action='store_true',
50                      help='Force recreation of test databases from scratch.')
51     parser.addoption('--nominatim-keep-db', dest='NOMINATIM_KEEP_DB', action='store_true',
52                      help='Do not drop the database after tests are finished.')
53     parser.addoption('--nominatim-api-engine', dest='NOMINATIM_API_ENGINE',
54                      default='falcon',
55                      help='Chose the API engine to use when sending requests.')
56     parser.addoption('--nominatim-tokenizer', dest='NOMINATIM_TOKENIZER',
57                      metavar='TOKENIZER',
58                      help='Use the specified tokenizer for importing data into '
59                           'a Nominatim database.')
60
61     parser.addini('nominatim_test_db', default='test_nominatim',
62                   help='Name of the database used for running a single test.')
63     parser.addini('nominatim_api_test_db', default='test_api_nominatim',
64                   help='Name of the database for storing API test data.')
65     parser.addini('nominatim_template_db', default='test_template_nominatim',
66                   help='Name of database used as a template for test databases.')
67
68
69 @pytest.fixture
70 def datatable():
71     """ Default fixture for datatables, so that their presence can be optional.
72     """
73     return None
74
75
76 @pytest.fixture
77 def node_grid():
78     """ Default fixture for node grids. Nothing set.
79     """
80     return Grid([[]], None, None)
81
82
83 @pytest.fixture(scope='session', autouse=True)
84 def setup_country_info():
85     setup_country_config(Configuration(None))
86
87
88 @pytest.fixture(scope='session')
89 def template_db(pytestconfig):
90     """ Create a template database containing the extensions and base data
91         needed by Nominatim. Using the template instead of doing the full
92         setup can speed up the tests.
93
94         The template database will only be created if it does not exist yet
95         or a purge has been explicitly requested.
96     """
97     dbm = DBManager(purge=pytestconfig.option.NOMINATIM_PURGE)
98
99     template_db = pytestconfig.getini('nominatim_template_db')
100
101     template_config = Configuration(
102         None, environ={'NOMINATIM_DATABASE_DSN': f"pgsql:dbname={template_db}"})
103
104     dbm.setup_template_db(template_config)
105
106     return template_db
107
108
109 @pytest.fixture
110 def def_config(pytestconfig):
111     dbname = pytestconfig.getini('nominatim_test_db')
112
113     return Configuration(None,
114                          environ={'NOMINATIM_DATABASE_DSN': f"pgsql:dbname={dbname}"})
115
116
117 @pytest.fixture
118 def db(template_db, pytestconfig):
119     """ Set up an empty database for use with osm2pgsql.
120     """
121     dbm = DBManager(purge=pytestconfig.option.NOMINATIM_PURGE)
122
123     dbname = pytestconfig.getini('nominatim_test_db')
124
125     dbm.create_db_from_template(dbname, template_db)
126
127     yield dbname
128
129     if not pytestconfig.option.NOMINATIM_KEEP_DB:
130         dbm.drop_db(dbname)
131
132
133 @pytest.fixture
134 def db_conn(db, def_config):
135     with psycopg.connect(def_config.get_libpq_dsn()) as conn:
136         info = psycopg.types.TypeInfo.fetch(conn, "hstore")
137         psycopg.types.hstore.register_hstore(info, conn)
138         yield conn
139
140
141 @when(step_parse(r'reverse geocoding (?P<lat>[\d.-]*),(?P<lon>[\d.-]*)'),
142       target_fixture='nominatim_result')
143 def reverse_geocode_via_api(test_config_env, pytestconfig, datatable, lat, lon):
144     runner = APIRunner(test_config_env, pytestconfig.option.NOMINATIM_API_ENGINE)
145     api_response = runner.run_step('reverse',
146                                    {'lat': float(lat), 'lon': float(lon)},
147                                    datatable, 'jsonv2', {})
148
149     assert api_response.status == 200
150     assert api_response.headers['content-type'] == 'application/json; charset=utf-8'
151
152     result = APIResult('json', 'reverse', api_response.body)
153     assert result.is_simple()
154
155     assert isinstance(result.result['lat'], str)
156     assert isinstance(result.result['lon'], str)
157     result.result['centroid'] = f"POINT({result.result['lon']} {result.result['lat']})"
158
159     return result
160
161
162 @when(step_parse(r'reverse geocoding at node (?P<node>[\d]+)'),
163       target_fixture='nominatim_result')
164 def reverse_geocode_via_api_and_grid(test_config_env, pytestconfig, node_grid, datatable, node):
165     coords = node_grid.get(node)
166     if coords is None:
167         raise ValueError('Unknown node id')
168
169     return reverse_geocode_via_api(test_config_env, pytestconfig, datatable, coords[1], coords[0])
170
171
172 @when(step_parse(r'geocoding(?: "(?P<query>.*)")?'),
173       target_fixture='nominatim_result')
174 def forward_geocode_via_api(test_config_env, pytestconfig, datatable, query):
175     runner = APIRunner(test_config_env, pytestconfig.option.NOMINATIM_API_ENGINE)
176
177     params = {'addressdetails': '1'}
178     if query:
179         params['q'] = query
180
181     api_response = runner.run_step('search', params, datatable, 'jsonv2', {})
182
183     assert api_response.status == 200
184     assert api_response.headers['content-type'] == 'application/json; charset=utf-8'
185
186     result = APIResult('json', 'search', api_response.body)
187     assert not result.is_simple()
188
189     for res in result.result:
190         assert isinstance(res['lat'], str)
191         assert isinstance(res['lon'], str)
192         res['centroid'] = f"POINT({res['lon']} {res['lat']})"
193
194     return result
195
196
197 @then(step_parse(r'(?P<op>[a-z ]+) (?P<num>\d+) results? (?:are|is) returned'),
198       converters={'num': int})
199 def check_number_of_results(nominatim_result, op, num):
200     assert not nominatim_result.is_simple()
201     assert COMPARATOR_TERMS[op](num, len(nominatim_result))
202
203
204 @then(step_parse('the result metadata contains'))
205 def check_metadata_for_fields(nominatim_result, datatable):
206     if datatable[0] == ['param', 'value']:
207         pairs = datatable[1:]
208     else:
209         pairs = zip(datatable[0], datatable[1])
210
211     for k, v in pairs:
212         assert ResultAttr(nominatim_result.meta, k) == v
213
214
215 @then(step_parse('the result metadata has no attributes (?P<attributes>.*)'),
216       converters={'attributes': _strlist})
217 def check_metadata_for_field_presence(nominatim_result, attributes):
218     assert all(a not in nominatim_result.meta for a in attributes), \
219         f"Unexpectedly have one of the attributes '{attributes}' in\n" \
220         f"{_pretty_json(nominatim_result.meta)}"
221
222
223 @then(step_parse(r'the result contains(?: in field (?P<field>\S+))?'))
224 def check_result_for_fields(nominatim_result, datatable, node_grid, field):
225     assert nominatim_result.is_simple()
226
227     if datatable[0] == ['param', 'value']:
228         pairs = datatable[1:]
229     else:
230         pairs = zip(datatable[0], datatable[1])
231
232     prefix = field + '+' if field else ''
233
234     for k, v in pairs:
235         assert ResultAttr(nominatim_result.result, prefix + k, grid=node_grid) == v
236
237
238 @then(step_parse('the result has attributes (?P<attributes>.*)'),
239       converters={'attributes': _strlist})
240 def check_result_for_field_presence(nominatim_result, attributes):
241     assert nominatim_result.is_simple()
242     assert all(a in nominatim_result.result for a in attributes)
243
244
245 @then(step_parse('the result has no attributes (?P<attributes>.*)'),
246       converters={'attributes': _strlist})
247 def check_result_for_field_absence(nominatim_result, attributes):
248     assert nominatim_result.is_simple()
249     assert all(a not in nominatim_result.result for a in attributes)
250
251
252 @then(step_parse(
253     r'the result contains array field (?P<field>\S+) where element (?P<num>\d+) contains'),
254       converters={'num': int})
255 def check_result_array_field_for_attributes(nominatim_result, datatable, field, num):
256     assert nominatim_result.is_simple()
257
258     if datatable[0] == ['param', 'value']:
259         pairs = datatable[1:]
260     else:
261         pairs = zip(datatable[0], datatable[1])
262
263     prefix = f"{field}+{num}+"
264
265     for k, v in pairs:
266         assert ResultAttr(nominatim_result.result, prefix + k) == v
267
268
269 @then(step_parse('the result set contains(?P<exact> exactly)?'))
270 def check_result_list_match(nominatim_result, datatable, exact):
271     assert not nominatim_result.is_simple()
272
273     result_set = set(range(len(nominatim_result.result)))
274
275     for row in datatable[1:]:
276         for idx in result_set:
277             for key, value in zip(datatable[0], row):
278                 if ResultAttr(nominatim_result.result[idx], key) != value:
279                     break
280             else:
281                 # found a match
282                 result_set.remove(idx)
283                 break
284         else:
285             assert False, f"Missing data row {row}. Full response:\n{nominatim_result}"
286
287     if exact:
288         assert not [nominatim_result.result[i] for i in result_set]
289
290
291 @then(step_parse('all results have attributes (?P<attributes>.*)'),
292       converters={'attributes': _strlist})
293 def check_all_results_for_field_presence(nominatim_result, attributes):
294     assert not nominatim_result.is_simple()
295     assert len(nominatim_result) > 0
296     for res in nominatim_result.result:
297         assert all(a in res for a in attributes), \
298             f"Missing one of the attributes '{attributes}' in\n{_pretty_json(res)}"
299
300
301 @then(step_parse('all results have no attributes (?P<attributes>.*)'),
302       converters={'attributes': _strlist})
303 def check_all_result_for_field_absence(nominatim_result, attributes):
304     assert not nominatim_result.is_simple()
305     assert len(nominatim_result) > 0
306     for res in nominatim_result.result:
307         assert all(a not in res for a in attributes), \
308             f"Unexpectedly have one of the attributes '{attributes}' in\n{_pretty_json(res)}"
309
310
311 @then(step_parse(r'all results contain(?: in field (?P<field>\S+))?'))
312 def check_all_results_contain(nominatim_result, datatable, node_grid, field):
313     assert not nominatim_result.is_simple()
314     assert len(nominatim_result) > 0
315
316     if datatable[0] == ['param', 'value']:
317         pairs = datatable[1:]
318     else:
319         pairs = zip(datatable[0], datatable[1])
320
321     prefix = field + '+' if field else ''
322
323     for k, v in pairs:
324         for r in nominatim_result.result:
325             assert ResultAttr(r, prefix + k, grid=node_grid) == v
326
327
328 @then(step_parse(r'result (?P<num>\d+) contains(?: in field (?P<field>\S+))?'),
329       converters={'num': int})
330 def check_specific_result_for_fields(nominatim_result, datatable, num, field):
331     assert not nominatim_result.is_simple()
332     assert len(nominatim_result) > num
333
334     if datatable[0] == ['param', 'value']:
335         pairs = datatable[1:]
336     else:
337         pairs = zip(datatable[0], datatable[1])
338
339     prefix = field + '+' if field else ''
340
341     for k, v in pairs:
342         assert ResultAttr(nominatim_result.result[num], prefix + k) == v
343
344
345 @given(step_parse(r'the (?P<step>[0-9.]+ )?grid(?: with origin (?P<origin>.*))?'),
346        target_fixture='node_grid')
347 def set_node_grid(datatable, step, origin):
348     if step is not None:
349         step = float(step)
350
351     if origin:
352         if ',' in origin:
353             coords = origin.split(',')
354             if len(coords) != 2:
355                 raise RuntimeError('Grid origin expects origin with x,y coordinates.')
356             origin = list(map(float, coords))
357         elif origin in ALIASES:
358             origin = ALIASES[origin]
359         else:
360             raise RuntimeError('Grid origin must be either coordinate or alias.')
361
362     return Grid(datatable, step, origin)
363
364
365 @then(step_parse('(?P<table>placex?) has no entry for '
366                  r'(?P<osm_type>[NRW])(?P<osm_id>\d+)(?::(?P<osm_class>\S+))?'),
367       converters={'osm_id': int})
368 def check_place_missing_lines(db_conn, table, osm_type, osm_id, osm_class):
369     sql = pysql.SQL("""SELECT count(*) FROM {}
370                        WHERE osm_type = %s and osm_id = %s""").format(pysql.Identifier(table))
371     params = [osm_type, int(osm_id)]
372     if osm_class:
373         sql += pysql.SQL(' AND class = %s')
374         params.append(osm_class)
375
376     with db_conn.cursor() as cur:
377         assert cur.execute(sql, params).fetchone()[0] == 0
378
379
380 if pytest.version_tuple >= (8, 0, 0):
381     def pytest_pycollect_makemodule(module_path, parent):
382         return BddTestCollector.from_parent(parent, path=module_path)
383
384
385 class BddTestCollector(pytest.Module):
386
387     def __init__(self, **kwargs):
388         super().__init__(**kwargs)
389
390     def collect(self):
391         for item in super().collect():
392             yield item
393
394         if hasattr(self.obj, 'PYTEST_BDD_SCENARIOS'):
395             for path in self.obj.PYTEST_BDD_SCENARIOS:
396                 for feature in get_features([str(Path(self.path.parent, path).resolve())]):
397                     yield FeatureFile.from_parent(self,
398                                                   name=str(Path(path, feature.rel_filename)),
399                                                   path=Path(feature.filename),
400                                                   feature=feature)
401
402
403 # borrowed from pytest-bdd: src/pytest_bdd/scenario.py
404 def make_python_name(string: str) -> str:
405     """Make python attribute name out of a given string."""
406     string = re.sub(r"\W", "", string.replace(" ", "_"))
407     return re.sub(r"^\d+_*", "", string).lower()
408
409
410 class FeatureFile(pytest.File):
411     class obj:
412         pass
413
414     def __init__(self, feature, **kwargs):
415         self.feature = feature
416         super().__init__(**kwargs)
417
418     def collect(self):
419         for sname, sobject in self.feature.scenarios.items():
420             class_name = f"L{sobject.line_number}"
421             test_name = "test_" + make_python_name(sname)
422
423             @scenario(self.feature.filename, sname)
424             def _test():
425                 pass
426
427             tclass = type(class_name, (),
428                           {test_name: staticmethod(_test)})
429             setattr(self.obj, class_name, tclass)
430
431             yield pytest.Class.from_parent(self, name=class_name, obj=tclass)