]> git.openstreetmap.org Git - nominatim.git/blob - test/bdd/test_db.py
prepare release 5.3.2.post5
[nominatim.git] / test / bdd / test_db.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2026 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Collector for BDD import acceptance tests.
9
10 These tests check the Nominatim import chain after the osm2pgsql import.
11 """
12 import asyncio
13 import re
14 from collections import defaultdict
15
16 import psycopg
17 import psycopg.sql as pysql
18
19 import pytest
20 from pytest_bdd import when, then, given
21 from pytest_bdd.parsers import re as step_parse
22
23 from utils.place_inserter import PlaceColumn
24 from utils.checks import check_table_content
25 from utils.geometry_alias import ALIASES
26
27 from nominatim_db.config import Configuration
28 from nominatim_db import cli
29 from nominatim_db.tools.database_import import load_data, create_table_triggers
30 from nominatim_db.tools.postcodes import update_postcodes
31 from nominatim_db.tokenizer import factory as tokenizer_factory
32
33
34 def _rewrite_placeid_field(field, new_field, datatable, place_ids):
35     try:
36         oidx = datatable[0].index(field)
37         datatable[0][oidx] = new_field
38         for line in datatable[1:]:
39             line[oidx] = None if line[oidx] == '-' else place_ids[line[oidx]]
40     except ValueError:
41         pass
42
43
44 def _collect_place_ids(conn):
45     pids = {}
46     with conn.cursor() as cur:
47         for row in cur.execute('SELECT place_id, osm_type, osm_id, class FROM placex'):
48             pids[f"{row[1]}{row[2]}"] = row[0]
49             pids[f"{row[1]}{row[2]}:{row[3]}"] = row[0]
50
51     return pids
52
53
@pytest.fixture
def row_factory(db_conn):
    """ Provide a function for inserting a single row into a table.

        The returned function takes the table name followed by the column
        values as keyword arguments. Values are handled as follows:

        * a tuple: the first element is used as raw SQL in place of the
          placeholder, the second element is sent as query parameter
        * a psycopg `SQL` or `Literal` object: put into the VALUES list
          as is, no query parameter is sent
        * anything else: sent as a regular query parameter

        The insert is committed right away.
    """
    def _insert_row(table, **data):
        value_sql = []
        params = []
        for value in data.values():
            if isinstance(value, tuple):
                value_sql.append(pysql.SQL(value[0]))
                params.append(value[1])
            elif isinstance(value, (pysql.Literal, pysql.SQL)):
                value_sql.append(value)
            else:
                value_sql.append(pysql.Placeholder())
                params.append(value)

        column_sql = [pysql.Identifier(name) for name in data]
        template = pysql.SQL("INSERT INTO {table} ({columns}) VALUES({values})")
        statement = template.format(table=pysql.Identifier(table),
                                    columns=pysql.SQL(',').join(column_sql),
                                    values=pysql.SQL(',').join(value_sql))

        db_conn.execute(statement, params)
        db_conn.commit()

    return _insert_row
80
81
@pytest.fixture
def test_config_env(pytestconfig):
    """ Return the environment settings used for the tests.

        Starts from the OS environment as reported by an unconfigured
        Configuration object, then points the DSN to the test database
        named in the 'nominatim_test_db' ini option and fixes the
        language list and the Tiger data switch. The tokenizer is only
        overridden when it was set through the pytest options.
    """
    tokenizer = pytestconfig.option.NOMINATIM_TOKENIZER
    test_db = pytestconfig.getini('nominatim_test_db')

    env = Configuration(None).get_os_env()
    env['NOMINATIM_DATABASE_DSN'] = f"pgsql:dbname={test_db}"
    env['NOMINATIM_LANGUAGES'] = 'en,de,fr,ja'
    env['NOMINATIM_USE_US_TIGER_DATA'] = 'yes'
    if tokenizer is not None:
        env['NOMINATIM_TOKENIZER'] = tokenizer

    return env
94
95
@pytest.fixture
def update_config(def_config):
    """ Run 'nominatim refresh --functions' with the environment of the
        default configuration and hand back that configuration.

        Steps that update the database request this fixture instead of
        the plain default configuration.
    """
    refresh_args = ['refresh', '--functions']
    cli.nominatim(refresh_args, def_config.environ)

    return def_config
103
104
@given(step_parse('the (?P<named>named )?places'), target_fixture=None)
def import_places(db_conn, named, datatable, node_grid):
    """ Insert one row per datatable line through PlaceColumn.db_insert().
        The 'named' variant of the step passes True as the third argument
        of PlaceColumn.add_row() so that the objects receive a generated
        name.
    """
    with_name = named is not None

    with db_conn.cursor() as cur:
        for row in datatable[1:]:
            place = PlaceColumn(node_grid).add_row(datatable[0], row, with_name)
            place.db_insert(cur)
114
115
@given(step_parse('the entrances'), target_fixture=None)
def import_place_entrances(row_factory, datatable, node_grid):
    """ Insert one row per datatable line into the place_entrance table.
        Only node objects are accepted.
    """
    for row in datatable[1:]:
        entrance = PlaceColumn(node_grid).add_row(datatable[0], row, False)
        assert entrance.columns['osm_type'] == 'N'

        # The geometry expression is put into the statement as raw SQL.
        row_factory('place_entrance',
                    osm_id=entrance.columns['osm_id'],
                    type=entrance.columns['type'],
                    extratags=entrance.columns.get('extratags'),
                    geometry=pysql.SQL(entrance.get_wkt()))
130
131
@given(step_parse('the interpolations'), target_fixture=None)
def import_place_interpolations(row_factory, datatable, node_grid):
    """ Insert one row per datatable line into the place_interpolation
        table. Only way objects are accepted. The 'nodes' column must
        contain a comma-separated list of integer IDs.
    """
    for row in datatable[1:]:
        interpol = PlaceColumn(node_grid).add_row(datatable[0], row, False)
        assert interpol.columns['osm_type'] == 'W'

        node_ids = [int(x) for x in interpol.columns['nodes'].split(',')]

        # The geometry expression is put into the statement as raw SQL.
        row_factory('place_interpolation',
                    osm_id=interpol.columns['osm_id'],
                    type=interpol.columns['type'],
                    address=interpol.columns.get('address'),
                    nodes=node_ids,
                    geometry=pysql.SQL(interpol.get_wkt()))
147
148
@given(step_parse('the postcodes'), target_fixture=None)
def import_place_postcode(db_conn, datatable, node_grid):
    """ Insert one row per datatable line into the place_postcode table.
        An already existing row for the same OSM object is removed first.

        The 'centroid' column either contains 'country:<code>', which is
        resolved through the geometry aliases, or a geometry description
        that is converted by the node grid. The 'geometry' column is
        optional; without it a NULL geometry is written.
    """
    delete_sql = """ DELETE FROM place_postcode
                            WHERE osm_type = %(osm_type)s and osm_id = %(osm_id)s"""

    with db_conn.cursor() as cur:
        for row in datatable[1:]:
            # Columns missing from the datatable resolve to None (NULL).
            data = defaultdict(lambda: None, zip(datatable[0], row))

            centroid = data['centroid']
            if centroid.startswith('country:'):
                coords = ALIASES[centroid[8:].upper()]
                data['centroid'] = 'srid=4326;POINT({} {})'.format(*coords)
            else:
                data['centroid'] = f"srid=4326;{node_grid.geometry_to_wkt(centroid)}"

            osm_ref = data['osm']
            data['osm_type'] = osm_ref[0]
            data['osm_id'] = osm_ref[1:]

            # The geometry goes into the statement as SQL text,
            # not as a query parameter.
            geom = 'null'
            if 'geometry' in data:
                geom = f"'srid=4326;{node_grid.geometry_to_wkt(data['geometry'])}'::geometry"

            insert_sql = f"""INSERT INTO place_postcode
                            (osm_type, osm_id, country_code, postcode, centroid, geometry)
                            VALUES (%(osm_type)s, %(osm_id)s,
                                    %(country)s, %(postcode)s,
                                    %(centroid)s, {geom})"""

            cur.execute(delete_sql, data)
            cur.execute(insert_sql, data)
    db_conn.commit()
182
183
@given('the ways', target_fixture=None)
def import_ways(row_factory, datatable):
    """ Insert one row per datatable line into the planet_osm_ways table.

        The datatable must have the columns 'id' and 'nodes' (a
        comma-separated list of integer IDs). Columns named 'tags+<key>'
        are collected into the JSON tags column under '<key>'.
    """
    header = datatable[0]
    id_idx = header.index('id')
    node_idx = header.index('nodes')

    for line in datatable[1:]:
        tags = {key[5:]: value for key, value in zip(header, line)
                if key.startswith("tags+")}
        node_ids = [int(x) for x in line[node_idx].split(',')]

        row_factory('planet_osm_ways',
                    id=line[id_idx],
                    nodes=node_ids,
                    tags=psycopg.types.json.Json(tags))
197
198
def _parse_rel_members(spec):
    """ Convert a comma-separated member list into a list of member
        dictionaries with the keys 'ref', 'role' and 'type'.

        Each member has the form '<N|W|R><id>' with an optional ':<role>'
        suffix. A member without role gets the empty string as role.
        Raises a ValueError when a member has a different format.
    """
    parsed = []
    for member in spec.split(','):
        m = re.fullmatch(r'\s*([RWN])(\d+)(?::(\S+))?\s*', member)
        if not m:
            raise ValueError(f'Illegal member {member}.')
        parsed.append({'ref': int(m[2]), 'role': m[3] or '', 'type': m[1]})

    return parsed


@given('the relations', target_fixture=None)
def import_rels(row_factory, datatable):
    """ Insert one row per datatable line into the planet_osm_rels table.
        The members of relations tagged type=associatedStreet are written
        to the place_associated_street table in addition.

        The datatable must have the columns 'id' and 'members'. Columns
        named 'tags+<key>' are collected into the JSON tags column.
    """
    header = datatable[0]
    id_idx = header.index('id')
    memb_idx = header.index('members')

    for line in datatable[1:]:
        tags = {key[5:]: value for key, value in zip(header, line)
                if key.startswith("tags+")}
        member_spec = line[memb_idx]
        members = _parse_rel_members(member_spec) if member_spec else []

        row_factory('planet_osm_rels',
                    id=int(line[id_idx]),
                    tags=psycopg.types.json.Json(tags),
                    members=psycopg.types.json.Json(members))

        if tags.get('type') != 'associatedStreet':
            continue

        # Mirror the member list into the dedicated street relation table.
        for member in members:
            row_factory('place_associated_street',
                        relation_id=int(line[id_idx]),
                        member_type=member['type'],
                        member_id=member['ref'],
                        member_role=member['role'])
229
230
@when('importing', target_fixture='place_ids')
def do_import(db_conn, def_config):
    """ Run a reduced version of the Nominatim import: create the table
        triggers, load the data, update the postcodes and index.

        Returns the lookup table of place IDs of the resulting placex
        table, which is made available as the 'place_ids' fixture.
    """
    create_table_triggers(db_conn, def_config)

    data_loader = load_data(def_config.get_libpq_dsn(), 1)
    asyncio.run(data_loader)

    db_tokenizer = tokenizer_factory.get_tokenizer_for_db(def_config)
    update_postcodes(def_config.get_libpq_dsn(), None, db_tokenizer)

    index_args = ['index', '-q']
    cli.nominatim(index_args, def_config.environ)

    return _collect_place_ids(db_conn)
242
243
@when('updating places', target_fixture='place_ids')
def do_update(db_conn, update_config, node_grid, datatable):
    """ Insert the datatable lines through PlaceColumn.db_insert(), call
        flush_deleted_places() and reindex.

        Returns the refreshed lookup table of place IDs, which replaces
        the 'place_ids' fixture.
    """
    header = datatable[0]

    with db_conn.cursor() as cur:
        for row in datatable[1:]:
            place = PlaceColumn(node_grid).add_row(header, row, False)
            place.db_insert(cur)
        cur.execute('SELECT flush_deleted_places()')
    db_conn.commit()

    index_args = ['index', '-q']
    cli.nominatim(index_args, update_config.environ)

    return _collect_place_ids(db_conn)
258
259
@when('updating relations', target_fixture=None)
def do_update_relations(db_conn, update_config, datatable):
    """ Replace the rows in place_associated_street for each relation of
        the datatable and reindex afterwards.

        All existing rows of the relation are deleted. The member list is
        only inserted again when the relation is tagged with
        type=associatedStreet and has a non-empty 'members' column.
        Raises a ValueError when a member is not of the form
        '<N|W|R><id>' with an optional ':<role>' suffix.
    """
    header = datatable[0]
    id_idx = header.index('id')
    memb_idx = header.index('members')

    with db_conn.cursor() as cur:
        for line in datatable[1:]:
            relation_id = int(line[id_idx])
            tags = {key[5:]: value for key, value in zip(header, line)
                    if key.startswith("tags+")}

            # Always drop the previous member list first.
            cur.execute('DELETE FROM place_associated_street WHERE relation_id = %s',
                        (relation_id,))

            is_street_relation = tags.get('type') == 'associatedStreet'
            if not (is_street_relation and line[memb_idx]):
                continue

            for member in line[memb_idx].split(','):
                m = re.fullmatch(r'\s*([RWN])(\d+)(?::(\S+))?\s*', member)
                if not m:
                    raise ValueError(f'Illegal member {member}.')
                member_row = (relation_id, m[1], int(m[2]), m[3] or '')
                cur.execute(
                    """INSERT INTO place_associated_street
                               (relation_id, member_type, member_id, member_role)
                               VALUES (%s, %s, %s, %s)""",
                    member_row)
    db_conn.commit()

    index_args = ['index', '-q']
    cli.nominatim(index_args, update_config.environ)
289
290
@when('updating entrances', target_fixture=None)
def update_place_entrances(db_conn, datatable, node_grid):
    """ Replace rows in the place_entrance table: for each datatable line
        the rows with the same osm_id are deleted and a new row is
        inserted. Only node objects are accepted.
    """
    header = datatable[0]

    with db_conn.cursor() as cur:
        for row in datatable[1:]:
            entrance = PlaceColumn(node_grid).add_row(header, row, False)
            assert entrance.columns['osm_type'] == 'N'

            osm_id = entrance.columns['osm_id']
            cur.execute("DELETE FROM place_entrance WHERE osm_id = %s", (osm_id,))

            # The geometry expression is formatted into the SQL text.
            insert_sql = """INSERT INTO place_entrance (osm_id, type, extratags, geometry)
                           VALUES (%s, %s, %s, {})""".format(entrance.get_wkt())
            cur.execute(insert_sql,
                        (osm_id, entrance.columns['type'],
                         entrance.columns.get('extratags')))
    db_conn.commit()
307
308
@when('updating interpolations', target_fixture=None)
def update_place_interpolations(db_conn, row_factory, update_config, datatable, node_grid):
    """ Insert the datatable lines into the place_interpolation table,
        call flush_deleted_places() and reindex with a minimum rank
        of 30. Only way objects are accepted.
    """
    header = datatable[0]

    for row in datatable[1:]:
        interpol = PlaceColumn(node_grid).add_row(header, row, False)
        assert interpol.columns['osm_type'] == 'W'

        row_factory('place_interpolation',
                    osm_id=interpol.columns['osm_id'],
                    type=interpol.columns['type'],
                    address=interpol.columns.get('address'),
                    nodes=list(map(int, interpol.columns['nodes'].split(','))),
                    geometry=pysql.SQL(interpol.get_wkt()))

    db_conn.execute('SELECT flush_deleted_places()')
    db_conn.commit()

    index_args = ['index', '-q', '--minrank', '30']
    cli.nominatim(index_args, update_config.environ)
329
330
@when('refreshing postcodes')
def do_postcode_update(update_config):
    """ Run 'nominatim refresh --postcodes' on the test database.
    """
    refresh_args = ['refresh', '--postcodes']
    cli.nominatim(refresh_args, update_config.environ)
336
337
@when(step_parse(r'marking for delete (?P<otype>[NRW])(?P<oid>\d+)'),
      converters={'oid': int})
def do_delete_place(db_conn, update_config, node_grid, otype, oid):
    """ Delete the given OSM object from the import tables and reindex.

        The object is removed from place, place_interpolation and
        place_postcode, and additionally from place_entrance when it is
        a node. The order of the statements is kept as is because
        flush_deleted_places() runs in between the deletes.
    """
    osm_key = (otype, oid)
    id_only = (oid, )

    with db_conn.cursor() as cur:
        cur.execute('TRUNCATE place_to_be_deleted')
        cur.execute('DELETE FROM place WHERE osm_type = %s and osm_id = %s',
                    osm_key)
        cur.execute('DELETE FROM place_interpolation WHERE osm_id = %s',
                    id_only)
        cur.execute('SELECT flush_deleted_places()')
        if otype == 'N':
            cur.execute('DELETE FROM place_entrance WHERE osm_id = %s',
                        id_only)
        cur.execute('DELETE FROM place_postcode WHERE osm_type = %s and osm_id = %s',
                    osm_key)
    db_conn.commit()

    index_args = ['index', '-q']
    cli.nominatim(index_args, update_config.environ)
358
359
@then(step_parse(r'(?P<table>\w+) contains(?P<exact> exactly)?'))
def then_check_table_content(db_conn, place_ids, datatable, node_grid, table, exact):
    """ Compare the content of the given table against the datatable.

        Before the check, object references in the columns 'object',
        'parent_place_id' and 'linked_place_id' (and 'address' for the
        place_addressline table) are replaced with place IDs, and column
        titles starting with 'addr+' are expanded to 'address+'.
    """
    placeid_columns = [('object', 'place_id'),
                       ('parent_place_id', 'parent_place_id'),
                       ('linked_place_id', 'linked_place_id')]
    if table == 'place_addressline':
        placeid_columns.append(('address', 'address_place_id'))

    for field, new_field in placeid_columns:
        _rewrite_placeid_field(field, new_field, datatable, place_ids)

    # Rewrite the header row in place.
    datatable[0][:] = [f"address+{title[5:]}" if title.startswith('addr+') else title
                       for title in datatable[0]]

    check_table_content(db_conn, table, datatable, grid=node_grid, exact=bool(exact))
373
374
@then(step_parse(r'(?P<table>placex?) has no entry for (?P<oid>[NRW]\d+(?::\S+)?)'))
def then_check_place_missing_lines(db_conn, place_ids, table, oid):
    """ Check that the given table has no row with the place ID that
        belongs to the given object.

        The object must be known in the 'place_ids' lookup table. It is
        referenced as '<N|W|R><id>' with an optional ':<class>' suffix.
    """
    # Imported locally: the module only imports psycopg and psycopg.sql.
    from psycopg.rows import tuple_row

    assert oid in place_ids

    sql = pysql.SQL("""SELECT count(*) FROM {}
                       WHERE place_id = %s""").format(pysql.Identifier(table))

    # Request tuple rows explicitly as the count is read by position.
    with db_conn.cursor(row_factory=tuple_row) as cur:
        assert cur.execute(sql, [place_ids[oid]]).fetchone()[0] == 0
384
385
@then(step_parse(r'W(?P<oid>\d+) expands to interpolation'),
      converters={'oid': int})
def then_check_interpolation_table(db_conn, node_grid, place_ids, oid, datatable):
    """ Check the rows of location_property_osmline for the given way.

        The table must contain exactly one row per datatable line for the
        way. The optional datatable columns 'start', 'end' and 'geometry'
        are then compared against the table columns startnumber,
        endnumber and linegeo.
    """
    with db_conn.cursor() as cur:
        cur.execute('SELECT count(*) FROM location_property_osmline WHERE osm_id = %s',
                    [oid])
        assert cur.fetchone()[0] == len(datatable) - 1

    # Map the position of each datatable column in use to the title
    # that is handed to the table check.
    header = datatable[0]
    column_map = [(header.index(src), dest)
                  for src, dest in (('start', 'startnumber'),
                                    ('end', 'endnumber'),
                                    ('geometry', 'linegeo!wkt'))
                  if src in header]

    converted = [['osm_id'] + [dest for _, dest in column_map]]
    converted.extend([oid] + [line[idx] for idx, _ in column_map]
                     for line in datatable[1:])

    # NOTE(review): the converted table never has a 'parent_place_id'
    # column, so this call currently changes nothing.
    _rewrite_placeid_field('parent_place_id', 'parent_place_id', converted, place_ids)

    check_table_content(db_conn, 'location_property_osmline', converted, grid=node_grid)
416
417
@then(step_parse(r'W(?P<oid>\d+) expands to no interpolation'),
      converters={'oid': int})
def then_check_interpolation_table_negative(db_conn, oid):
    """ Check that location_property_osmline has no row with a start
        number for the given way.
    """
    query = """SELECT count(*) FROM location_property_osmline
                       WHERE osm_id = %s and startnumber is not null"""

    with db_conn.cursor() as cur:
        cur.execute(query, [oid])
        num_rows = cur.fetchone()[0]

    assert num_rows == 0
426
427
# Register the feature files for the database tests. Before pytest 8 the
# scenarios are registered directly through pytest_bdd.
# NOTE(review): PYTEST_BDD_SCENARIOS is presumably picked up by the test
# setup outside of this file - confirm.
if pytest.version_tuple < (8, 0, 0):
    from pytest_bdd import scenarios
    scenarios('features/db')
else:
    PYTEST_BDD_SCENARIOS = ['features/db']