]> git.openstreetmap.org Git - nominatim.git/blob - test/python/conftest.py
0d1cd2f37168f8446182d030d2e26e4427922ead
[nominatim.git] / test / python / conftest.py
1 import itertools
2 import sys
3 from pathlib import Path
4
5 import psycopg2
6 import psycopg2.extras
7 import pytest
8 import tempfile
9
# Root of the source tree: this file lives in test/python/, so three '..'
# steps up from conftest.py reach the repository root.
SRC_DIR = Path(__file__) / '..' / '..' / '..'

# always test against the source: prepend the checkout to sys.path so that
# the `nominatim` imports below resolve to this tree before any other copy.
sys.path.insert(0, str(SRC_DIR.resolve()))
14
15 from nominatim.config import Configuration
16 from nominatim.db import connection
17 from nominatim.db.sql_preprocessor import SQLPreprocessor
18
class _TestingCursor(psycopg2.extras.DictCursor):
    """ DictCursor with a handful of convenience helpers that keep
        assertions in tests short.
    """

    def scalar(self, sql, params=None):
        """ Run `sql` (with optional `params`) and return the first column
            of its only result row. Fails with an AssertionError unless
            exactly one row comes back.
        """
        self.execute(sql, params)
        assert self.rowcount == 1
        only_row = self.fetchone()
        return only_row[0]

    def row_set(self, sql, params=None):
        """ Run `sql` (with optional `params`) and collect every result row
            into a set of plain tuples.
        """
        self.execute(sql, params)
        return {tuple(record) for record in self}

    def table_exists(self, table):
        """ Return True when pg_tables lists exactly one table named `table`.
        """
        num = self.scalar("""SELECT count(*) FROM pg_tables
                             WHERE tablename = %s""", (table, ))
        return num == 1

    def table_rows(self, table):
        """ Return the row count of `table`. The name is concatenated into
            the query as-is, so only pass trusted table names.
        """
        count_sql = 'SELECT count(*) FROM ' + table
        return self.scalar(count_sql)
50
51
def _execute_on_maintenance_db(*statements):
    """ Run the given SQL statements, in order, against the 'postgres'
        database with autocommit enabled. The connection is always closed,
        even when one of the statements fails.
    """
    conn = psycopg2.connect(database='postgres')
    try:
        # CREATE/DROP DATABASE cannot run inside a transaction block,
        # hence autocommit (equivalent to the former set_isolation_level(0)).
        conn.autocommit = True
        with conn.cursor() as cur:
            for statement in statements:
                cur.execute(statement)
    finally:
        conn.close()


@pytest.fixture
def temp_db(monkeypatch):
    """ Create an empty database for the test. The database name is also
        exported into NOMINATIM_DATABASE_DSN.

        A database of the same name left over from an earlier run is dropped
        first; the database is dropped again during teardown.

        Yields the name of the database.
    """
    name = 'test_nominatim_python_unittest'

    _execute_on_maintenance_db('DROP DATABASE IF EXISTS {}'.format(name),
                               'CREATE DATABASE {}'.format(name))

    monkeypatch.setenv('NOMINATIM_DATABASE_DSN', 'dbname=' + name)

    yield name

    _execute_on_maintenance_db('DROP DATABASE IF EXISTS {}'.format(name))
78
79
@pytest.fixture
def dsn(temp_db):
    """ Connection string ('dbname=<name>') pointing at the test database.
    """
    return 'dbname={}'.format(temp_db)
83
84
@pytest.fixture
def temp_db_with_extensions(temp_db):
    """ The test database from `temp_db` with the hstore and postgis
        extensions installed. Returns the database name.
    """
    conn = psycopg2.connect(database=temp_db)
    try:
        with conn.cursor() as cur:
            cur.execute('CREATE EXTENSION hstore; CREATE EXTENSION postgis;')
        conn.commit()
    finally:
        # Close even when extension creation fails, so no connection leaks.
        conn.close()

    return temp_db
94
@pytest.fixture
def temp_db_conn(temp_db):
    """ Nominatim database connection to the test database, opened via
        `connection.connect` and used as a context manager around the test.
    """
    test_dsn = 'dbname=' + temp_db
    with connection.connect(test_dsn) as db_conn:
        yield db_conn
101
102
@pytest.fixture
def temp_db_cursor(temp_db):
    """ Connection and cursor towards the test database. The connection will
        be in auto-commit mode.

        Yields a `_TestingCursor`, which adds helpers such as `scalar`,
        `row_set`, `table_exists` and `table_rows`.
    """
    conn = psycopg2.connect('dbname=' + temp_db)
    try:
        # Equivalent to the former set_isolation_level(0).
        conn.autocommit = True
        with conn.cursor(cursor_factory=_TestingCursor) as cur:
            yield cur
    finally:
        # Runs even if the generator is closed at the yield or the cursor
        # context raises, so the connection is never leaked.
        conn.close()
113
114
@pytest.fixture
def table_factory(temp_db_cursor):
    """ Factory for creating, and optionally filling, tables in the test
        database. Statements run through the auto-commit `temp_db_cursor`.
    """
    def mk_table(name, definition='id INT', content=None):
        """ Create table `name` with column `definition`.

            `content` selects the rows to insert:
              * None - no rows;
              * a string - used verbatim as the body of `VALUES (...)`;
              * any other iterable - each item is converted with str() and
                becomes one row of the VALUES list.
            Empty content inserts nothing; it previously produced an invalid
            `VALUES ()` statement.

            Table name, definition and content are formatted into the SQL
            as-is, so this is only meant for trusted test input.
        """
        temp_db_cursor.execute('CREATE TABLE {} ({})'.format(name, definition))
        if content is None:
            return

        if not isinstance(content, str):
            content = '),('.join(str(x) for x in content)

        if content:
            temp_db_cursor.execute("INSERT INTO {} VALUES ({})".format(name, content))

    return mk_table
125
126
@pytest.fixture
def def_config():
    """ Configuration object created with `None` as its first argument and
        the `settings` directory of the source tree.
    """
    settings_dir = SRC_DIR.resolve() / 'settings'
    return Configuration(None, settings_dir)
130
@pytest.fixture
def src_dir():
    """ Resolved, absolute path of the source tree root.
    """
    root_dir = SRC_DIR.resolve()
    return root_dir
134
@pytest.fixture
def tmp_phplib_dir():
    """ Temporary directory containing an empty `admin` subdirectory.
        The whole directory is removed again after the test.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        phplib = Path(tmpdir)
        admin_dir = phplib / 'admin'
        admin_dir.mkdir()

        yield phplib
141
@pytest.fixture
def status_table(temp_db_conn):
    """ Create empty versions of the import status table (import_status)
        and the status logging table (import_osmosis_log), then commit.
    """
    ddl_statements = (
        """CREATE TABLE import_status (
                           lastimportdate timestamp with time zone NOT NULL,
                           sequence_id integer,
                           indexed boolean
                       )""",
        """CREATE TABLE import_osmosis_log (
                           batchend timestamp,
                           batchseq integer,
                           batchsize bigint,
                           starttime timestamp,
                           endtime timestamp,
                           event text
                           )""",
    )

    with temp_db_conn.cursor() as cur:
        for ddl in ddl_statements:
            cur.execute(ddl)
    temp_db_conn.commit()
162
163
@pytest.fixture
def place_table(temp_db_with_extensions, temp_db_conn):
    """ Create an empty version of the place table.

        Depends on `temp_db_with_extensions` because the table uses hstore
        and PostGIS Geometry columns. The creation is committed.
    """
    with temp_db_conn.cursor() as cur:
        cur.execute("""CREATE TABLE place (
                           osm_id int8 NOT NULL,
                           osm_type char(1) NOT NULL,
                           class text NOT NULL,
                           type text NOT NULL,
                           name hstore,
                           admin_level smallint,
                           address hstore,
                           extratags hstore,
                           geometry Geometry(Geometry,4326) NOT NULL)""")
    temp_db_conn.commit()
180
181
@pytest.fixture
def place_row(place_table, temp_db_cursor):
    """ Factory for inserting rows into the place table, which is created
        as a prerequisite of this fixture.

        A falsy `osm_id` is replaced by the next value of a counter starting
        at 1001; a falsy `geom` defaults to 'SRID=4326;POINT(0 0)'.
    """
    id_source = itertools.count(1001)

    def _insert(osm_type='N', osm_id=None, cls='amenity', typ='cafe', names=None,
                admin_level=None, address=None, extratags=None, geom=None):
        if not osm_id:
            osm_id = next(id_source)
        if not geom:
            geom = 'SRID=4326;POINT(0 0)'

        values = (osm_id, osm_type, cls, typ, names,
                  admin_level, address, extratags, geom)
        temp_db_cursor.execute("INSERT INTO place VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)",
                               values)

    return _insert
196
@pytest.fixture
def placex_table(temp_db_with_extensions, temp_db_conn):
    """ Create an empty version of the placex table.

        Depends on `temp_db_with_extensions` because the table uses hstore
        and PostGIS Geometry columns. The creation is committed.
    """
    with temp_db_conn.cursor() as cur:
        cur.execute("""CREATE TABLE placex (
                           place_id BIGINT,
                           parent_place_id BIGINT,
                           linked_place_id BIGINT,
                           importance FLOAT,
                           indexed_date TIMESTAMP,
                           geometry_sector INTEGER,
                           rank_address SMALLINT,
                           rank_search SMALLINT,
                           partition SMALLINT,
                           indexed_status SMALLINT,
                           osm_id int8,
                           osm_type char(1),
                           class text,
                           type text,
                           name hstore,
                           admin_level smallint,
                           address hstore,
                           extratags hstore,
                           geometry Geometry(Geometry,4326),
                           wikipedia TEXT,
                           country_code varchar(2),
                           housenumber TEXT,
                           postcode TEXT,
                           centroid GEOMETRY(Geometry, 4326))""")
    temp_db_conn.commit()
228
229
@pytest.fixture
def osmline_table(temp_db_with_extensions, temp_db_conn):
    """ Create an empty version of the location_property_osmline table.

        Depends on `temp_db_with_extensions` because the table uses HSTORE
        and GEOMETRY columns. The creation is committed.
    """
    with temp_db_conn.cursor() as cur:
        cur.execute("""CREATE TABLE location_property_osmline (
                           place_id BIGINT,
                           osm_id BIGINT,
                           parent_place_id BIGINT,
                           geometry_sector INTEGER,
                           indexed_date TIMESTAMP,
                           startnumber INTEGER,
                           endnumber INTEGER,
                           partition SMALLINT,
                           indexed_status SMALLINT,
                           linegeo GEOMETRY,
                           interpolationtype TEXT,
                           address HSTORE,
                           postcode TEXT,
                           country_code VARCHAR(2))""")
    temp_db_conn.commit()
249
250
@pytest.fixture
def word_table(temp_db, temp_db_conn):
    """ Create an empty version of the word table and commit.

        Only plain column types are used, so the database without extensions
        (`temp_db`) is sufficient here.
    """
    with temp_db_conn.cursor() as cur:
        cur.execute("""CREATE TABLE word (
                           word_id INTEGER,
                           word_token text,
                           word text,
                           class text,
                           type text,
                           country_code varchar(2),
                           search_name_count INTEGER,
                           operator TEXT)""")
    temp_db_conn.commit()
264
265
@pytest.fixture
def osm2pgsql_options(temp_db):
    """ Option dictionary for running osm2pgsql against the test database.

        'echo' stands in for the osm2pgsql binary (presumably so that no real
        import is run), and all tablespace names are left empty.
    """
    empty_tablespaces = {'slim_data': '', 'slim_index': '',
                         'main_data': '', 'main_index': ''}

    return {'osm2pgsql': 'echo',
            'osm2pgsql_cache': 10,
            'osm2pgsql_style': 'style.file',
            'threads': 1,
            'dsn': 'dbname=' + temp_db,
            'flatnode_file': '',
            'tablespaces': empty_tablespaces}
276
@pytest.fixture
def sql_preprocessor(temp_db_conn, tmp_path, def_config, monkeypatch, table_factory):
    """ SQLPreprocessor bound to the test database, the default configuration
        and a per-test temporary directory.

        Before construction, NOMINATIM_DATABASE_MODULE_PATH is set to '.' and
        a country_name table with partitions 0, 1 and 2 is created.
    """
    monkeypatch.setenv('NOMINATIM_DATABASE_MODULE_PATH', '.')
    table_factory('country_name', 'partition INT', (0, 1, 2))

    preprocessor = SQLPreprocessor(temp_db_conn, def_config, tmp_path)
    return preprocessor