]> git.openstreetmap.org Git - nominatim.git/blob - test/bdd/test_db.py
release 5.2.0.post9
[nominatim.git] / test / bdd / test_db.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2025 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Collector for BDD import acceptance tests.
9
10 These tests check the Nominatim import chain after the osm2pgsql import.
11 """
12 import asyncio
13 import re
14 from collections import defaultdict
15
16 import psycopg
17
18 import pytest
19 from pytest_bdd import when, then, given
20 from pytest_bdd.parsers import re as step_parse
21
22 from utils.place_inserter import PlaceColumn
23 from utils.checks import check_table_content
24 from utils.geometry_alias import ALIASES
25
26 from nominatim_db.config import Configuration
27 from nominatim_db import cli
28 from nominatim_db.tools.database_import import load_data, create_table_triggers
29 from nominatim_db.tools.postcodes import update_postcodes
30 from nominatim_db.tokenizer import factory as tokenizer_factory
31
32
33 def _rewrite_placeid_field(field, new_field, datatable, place_ids):
34     try:
35         oidx = datatable[0].index(field)
36         datatable[0][oidx] = new_field
37         for line in datatable[1:]:
38             line[oidx] = None if line[oidx] == '-' else place_ids[line[oidx]]
39     except ValueError:
40         pass
41
42
43 def _collect_place_ids(conn):
44     pids = {}
45     with conn.cursor() as cur:
46         for row in cur.execute('SELECT place_id, osm_type, osm_id, class FROM placex'):
47             pids[f"{row[1]}{row[2]}"] = row[0]
48             pids[f"{row[1]}{row[2]}:{row[3]}"] = row[0]
49
50     return pids
51
52
53 @pytest.fixture
54 def test_config_env(pytestconfig):
55     dbname = pytestconfig.getini('nominatim_test_db')
56
57     config = Configuration(None).get_os_env()
58     config['NOMINATIM_DATABASE_DSN'] = f"pgsql:dbname={dbname}"
59     config['NOMINATIM_LANGUAGES'] = 'en,de,fr,ja'
60     config['NOMINATIM_USE_US_TIGER_DATA'] = 'yes'
61     if pytestconfig.option.NOMINATIM_TOKENIZER is not None:
62         config['NOMINATIM_TOKENIZER'] = pytestconfig.option.NOMINATIM_TOKENIZER
63
64     return config
65
66
67 @pytest.fixture
68 def update_config(def_config):
69     """ Prepare the database for being updatable and return the config.
70     """
71     cli.nominatim(['refresh', '--functions'], def_config.environ)
72
73     return def_config
74
75
76 @given(step_parse('the (?P<named>named )?places'), target_fixture=None)
77 def import_places(db_conn, named, datatable, node_grid):
78     """ Insert todo rows into the place table.
79         When 'named' is given, then a random name will be generated for all
80         objects.
81     """
82     with db_conn.cursor() as cur:
83         for row in datatable[1:]:
84             PlaceColumn(node_grid).add_row(datatable[0], row, named is not None).db_insert(cur)
85
86
87 @given(step_parse('the entrances'), target_fixture=None)
88 def import_place_entrances(db_conn, datatable, node_grid):
89     """ Insert todo rows into the place_entrance table.
90     """
91     with db_conn.cursor() as cur:
92         for row in datatable[1:]:
93             data = PlaceColumn(node_grid).add_row(datatable[0], row, False)
94             assert data.columns['osm_type'] == 'N'
95
96             cur.execute("""INSERT INTO place_entrance (osm_id, type, extratags, geometry)
97                            VALUES (%s, %s, %s, {})""".format(data.get_wkt()),
98                         (data.columns['osm_id'], data.columns['type'],
99                          data.columns.get('extratags')))
100
101
102 @given(step_parse('the postcodes'), target_fixture=None)
103 def import_place_postcode(db_conn, datatable, node_grid):
104     """ Insert todo rows into the place_postcode table. If a row for the
105         requested object already exists it is overwritten.
106     """
107     with db_conn.cursor() as cur:
108         for row in datatable[1:]:
109             data = defaultdict(lambda: None)
110             data.update((k, v) for k, v in zip(datatable[0], row))
111
112             if data['centroid'].startswith('country:'):
113                 ccode = data['centroid'][8:].upper()
114                 data['centroid'] = 'srid=4326;POINT({} {})'.format(*ALIASES[ccode])
115             else:
116                 data['centroid'] = f"srid=4326;{node_grid.geometry_to_wkt(data['centroid'])}"
117
118             data['osm_type'] = data['osm'][0]
119             data['osm_id'] = data['osm'][1:]
120
121             if 'geometry' in data:
122                 geom = f"'srid=4326;{node_grid.geometry_to_wkt(data['geometry'])}'::geometry"
123             else:
124                 geom = 'null'
125
126             cur.execute(""" DELETE FROM place_postcode
127                             WHERE osm_type = %(osm_type)s and osm_id = %(osm_id)s""",
128                         data)
129             cur.execute(f"""INSERT INTO place_postcode
130                             (osm_type, osm_id, country_code, postcode, centroid, geometry)
131                             VALUES (%(osm_type)s, %(osm_id)s,
132                                     %(country)s, %(postcode)s,
133                                     %(centroid)s, {geom})""", data)
134     db_conn.commit()
135
136
137 @given('the ways', target_fixture=None)
138 def import_ways(db_conn, datatable):
139     """ Import raw ways into the osm2pgsql way middle table.
140     """
141     with db_conn.cursor() as cur:
142         id_idx = datatable[0].index('id')
143         node_idx = datatable[0].index('nodes')
144         for line in datatable[1:]:
145             tags = psycopg.types.json.Json(
146                 {k[5:]: v for k, v in zip(datatable[0], line)
147                  if k.startswith("tags+")})
148             nodes = [int(x) for x in line[node_idx].split(',')]
149
150             cur.execute("INSERT INTO planet_osm_ways (id, nodes, tags) VALUES (%s, %s, %s)",
151                         (line[id_idx], nodes, tags))
152
153
154 @given('the relations', target_fixture=None)
155 def import_rels(db_conn, datatable):
156     """ Import raw relations into the osm2pgsql relation middle table.
157     """
158     with db_conn.cursor() as cur:
159         id_idx = datatable[0].index('id')
160         memb_idx = datatable[0].index('members')
161         for line in datatable[1:]:
162             tags = psycopg.types.json.Json(
163                 {k[5:]: v for k, v in zip(datatable[0], line)
164                  if k.startswith("tags+")})
165             members = []
166             if line[memb_idx]:
167                 for member in line[memb_idx].split(','):
168                     m = re.fullmatch(r'\s*([RWN])(\d+)(?::(\S+))?\s*', member)
169                     if not m:
170                         raise ValueError(f'Illegal member {member}.')
171                     members.append({'ref': int(m[2]), 'role': m[3] or '', 'type': m[1]})
172
173             cur.execute('INSERT INTO planet_osm_rels (id, tags, members) VALUES (%s, %s, %s)',
174                         (int(line[id_idx]), tags, psycopg.types.json.Json(members)))
175
176
177 @when('importing', target_fixture='place_ids')
178 def do_import(db_conn, def_config):
179     """ Run a reduced version of the Nominatim import.
180     """
181     create_table_triggers(db_conn, def_config)
182     asyncio.run(load_data(def_config.get_libpq_dsn(), 1))
183     tokenizer = tokenizer_factory.get_tokenizer_for_db(def_config)
184     update_postcodes(def_config.get_libpq_dsn(), None, tokenizer)
185     cli.nominatim(['index', '-q'], def_config.environ)
186
187     return _collect_place_ids(db_conn)
188
189
190 @when('updating places', target_fixture='place_ids')
191 def do_update(db_conn, update_config, node_grid, datatable):
192     """ Update the place table with the given data. Also runs all triggers
193         related to updates and reindexes the new data.
194     """
195     with db_conn.cursor() as cur:
196         for row in datatable[1:]:
197             PlaceColumn(node_grid).add_row(datatable[0], row, False).db_insert(cur)
198         cur.execute('SELECT flush_deleted_places()')
199     db_conn.commit()
200
201     cli.nominatim(['index', '-q'], update_config.environ)
202
203     return _collect_place_ids(db_conn)
204
205
206 @when('updating entrances', target_fixture=None)
207 def update_place_entrances(db_conn, datatable, node_grid):
208     """ Update rows in the place_entrance table.
209     """
210     with db_conn.cursor() as cur:
211         for row in datatable[1:]:
212             data = PlaceColumn(node_grid).add_row(datatable[0], row, False)
213             assert data.columns['osm_type'] == 'N'
214
215             cur.execute("DELETE FROM place_entrance WHERE osm_id = %s",
216                         (data.columns['osm_id'],))
217             cur.execute("""INSERT INTO place_entrance (osm_id, type, extratags, geometry)
218                            VALUES (%s, %s, %s, {})""".format(data.get_wkt()),
219                         (data.columns['osm_id'], data.columns['type'],
220                          data.columns.get('extratags')))
221     db_conn.commit()
222
223
224 @when('refreshing postcodes')
225 def do_postcode_update(update_config):
226     """ Recompute the postcode centroids.
227     """
228     cli.nominatim(['refresh', '--postcodes'], update_config.environ)
229
230
231 @when(step_parse(r'marking for delete (?P<otype>[NRW])(?P<oid>\d+)'),
232       converters={'oid': int})
233 def do_delete_place(db_conn, update_config, node_grid, otype, oid):
234     """ Remove the given place from the database.
235     """
236     with db_conn.cursor() as cur:
237         cur.execute('TRUNCATE place_to_be_deleted')
238         cur.execute('DELETE FROM place WHERE osm_type = %s and osm_id = %s',
239                     (otype, oid))
240         cur.execute('SELECT flush_deleted_places()')
241         if otype == 'N':
242             cur.execute('DELETE FROM place_entrance WHERE osm_id = %s',
243                         (oid, ))
244         cur.execute('DELETE FROM place_postcode WHERE osm_type = %s and osm_id = %s',
245                     (otype, oid))
246     db_conn.commit()
247
248     cli.nominatim(['index', '-q'], update_config.environ)
249
250
251 @then(step_parse(r'(?P<table>\w+) contains(?P<exact> exactly)?'))
252 def then_check_table_content(db_conn, place_ids, datatable, node_grid, table, exact):
253     _rewrite_placeid_field('object', 'place_id', datatable, place_ids)
254     _rewrite_placeid_field('parent_place_id', 'parent_place_id', datatable, place_ids)
255     _rewrite_placeid_field('linked_place_id', 'linked_place_id', datatable, place_ids)
256     if table == 'place_addressline':
257         _rewrite_placeid_field('address', 'address_place_id', datatable, place_ids)
258
259     for i, title in enumerate(datatable[0]):
260         if title.startswith('addr+'):
261             datatable[0][i] = f"address+{title[5:]}"
262
263     check_table_content(db_conn, table, datatable, grid=node_grid, exact=bool(exact))
264
265
266 @then(step_parse(r'(DISABLED?P<table>placex?) has no entry for (?P<oid>[NRW]\d+(?::\S+)?)'))
267 def then_check_place_missing_lines(db_conn, place_ids, table, oid):
268     assert oid in place_ids
269
270     sql = pysql.SQL("""SELECT count(*) FROM {}
271                        WHERE place_id = %s""").format(pysql.Identifier(tablename))
272
273     with conn.cursor(row_factory=tuple_row) as cur:
274         assert cur.execute(sql, [place_ids[oid]]).fetchone()[0] == 0
275
276
277 @then(step_parse(r'W(?P<oid>\d+) expands to interpolation'),
278       converters={'oid': int})
279 def then_check_interpolation_table(db_conn, node_grid, place_ids, oid, datatable):
280     with db_conn.cursor() as cur:
281         cur.execute('SELECT count(*) FROM location_property_osmline WHERE osm_id = %s',
282                     [oid])
283         assert cur.fetchone()[0] == len(datatable) - 1
284
285     converted = [['osm_id', 'startnumber', 'endnumber', 'linegeo!wkt']]
286     start_idx = datatable[0].index('start') if 'start' in datatable[0] else None
287     end_idx = datatable[0].index('end') if 'end' in datatable[0] else None
288     geom_idx = datatable[0].index('geometry') if 'geometry' in datatable[0] else None
289     converted = [['osm_id']]
290     for val, col in zip((start_idx, end_idx, geom_idx),
291                         ('startnumber', 'endnumber', 'linegeo!wkt')):
292         if val is not None:
293             converted[0].append(col)
294
295     for line in datatable[1:]:
296         convline = [oid]
297         for val in (start_idx, end_idx):
298             if val is not None:
299                 convline.append(line[val])
300         if geom_idx is not None:
301             convline.append(line[geom_idx])
302         converted.append(convline)
303
304     _rewrite_placeid_field('parent_place_id', 'parent_place_id', converted, place_ids)
305
306     check_table_content(db_conn, 'location_property_osmline', converted, grid=node_grid)
307
308
309 @then(step_parse(r'W(?P<oid>\d+) expands to no interpolation'),
310       converters={'oid': int})
311 def then_check_interpolation_table_negative(db_conn, oid):
312     with db_conn.cursor() as cur:
313         cur.execute("""SELECT count(*) FROM location_property_osmline
314                        WHERE osm_id = %s and startnumber is not null""",
315                     [oid])
316         assert cur.fetchone()[0] == 0
317
318
319 if pytest.version_tuple >= (8, 0, 0):
320     PYTEST_BDD_SCENARIOS = ['features/db']
321 else:
322     from pytest_bdd import scenarios
323     scenarios('features/db')