]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/db/connection.py
Merge pull request #2655 from lonvia/migration-internal-country-name
[nominatim.git] / nominatim / db / connection.py
1 # SPDX-License-Identifier: GPL-2.0-only
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2022 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Specialised connection and cursor functions.
9 """
10 import contextlib
11 import logging
12 import os
13
14 import psycopg2
15 import psycopg2.extensions
16 import psycopg2.extras
17 from psycopg2 import sql as pysql
18
19 from nominatim.errors import UsageError
20
21 LOG = logging.getLogger()
22
23 class _Cursor(psycopg2.extras.DictCursor):
24     """ A cursor returning dict-like objects and providing specialised
25         execution functions.
26     """
27
28     def execute(self, query, args=None): # pylint: disable=W0221
29         """ Query execution that logs the SQL query when debugging is enabled.
30         """
31         LOG.debug(self.mogrify(query, args).decode('utf-8'))
32
33         super().execute(query, args)
34
35
36     def execute_values(self, sql, argslist, template=None):
37         """ Wrapper for the psycopg2 convenience function to execute
38             SQL for a list of values.
39         """
40         LOG.debug("SQL execute_values(%s, %s)", sql, argslist)
41
42         psycopg2.extras.execute_values(self, sql, argslist, template=template)
43
44
45     def scalar(self, sql, args=None):
46         """ Execute query that returns a single value. The value is returned.
47             If the query yields more than one row, a ValueError is raised.
48         """
49         self.execute(sql, args)
50
51         if self.rowcount != 1:
52             raise RuntimeError("Query did not return a single row.")
53
54         return self.fetchone()[0]
55
56
57     def drop_table(self, name, if_exists=True, cascade=False):
58         """ Drop the table with the given name.
59             Set `if_exists` to False if a non-existant table should raise
60             an exception instead of just being ignored. If 'cascade' is set
61             to True then all dependent tables are deleted as well.
62         """
63         sql = 'DROP TABLE '
64         if if_exists:
65             sql += 'IF EXISTS '
66         sql += '{}'
67         if cascade:
68             sql += ' CASCADE'
69
70         self.execute(pysql.SQL(sql).format(pysql.Identifier(name)))
71
72
73 class _Connection(psycopg2.extensions.connection):
74     """ A connection that provides the specialised cursor by default and
75         adds convenience functions for administrating the database.
76     """
77
78     def cursor(self, cursor_factory=_Cursor, **kwargs):
79         """ Return a new cursor. By default the specialised cursor is returned.
80         """
81         return super().cursor(cursor_factory=cursor_factory, **kwargs)
82
83
84     def table_exists(self, table):
85         """ Check that a table with the given name exists in the database.
86         """
87         with self.cursor() as cur:
88             num = cur.scalar("""SELECT count(*) FROM pg_tables
89                                 WHERE tablename = %s and schemaname = 'public'""", (table, ))
90             return num == 1
91
92
93     def table_has_column(self, table, column):
94         """ Check if the table 'table' exists and has a column with name 'column'.
95         """
96         with self.cursor() as cur:
97             has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns
98                                        WHERE table_name = %s
99                                              and column_name = %s""",
100                                     (table, column))
101             return has_column > 0
102
103
104     def index_exists(self, index, table=None):
105         """ Check that an index with the given name exists in the database.
106             If table is not None then the index must relate to the given
107             table.
108         """
109         with self.cursor() as cur:
110             cur.execute("""SELECT tablename FROM pg_indexes
111                            WHERE indexname = %s and schemaname = 'public'""", (index, ))
112             if cur.rowcount == 0:
113                 return False
114
115             if table is not None:
116                 row = cur.fetchone()
117                 return row[0] == table
118
119         return True
120
121
122     def drop_table(self, name, if_exists=True, cascade=False):
123         """ Drop the table with the given name.
124             Set `if_exists` to False if a non-existant table should raise
125             an exception instead of just being ignored.
126         """
127         with self.cursor() as cur:
128             cur.drop_table(name, if_exists, cascade)
129         self.commit()
130
131
132     def server_version_tuple(self):
133         """ Return the server version as a tuple of (major, minor).
134             Converts correctly for pre-10 and post-10 PostgreSQL versions.
135         """
136         version = self.server_version
137         if version < 100000:
138             return (int(version / 10000), (version % 10000) / 100)
139
140         return (int(version / 10000), version % 10000)
141
142
143     def postgis_version_tuple(self):
144         """ Return the postgis version installed in the database as a
145             tuple of (major, minor). Assumes that the PostGIS extension
146             has been installed already.
147         """
148         with self.cursor() as cur:
149             version = cur.scalar('SELECT postgis_lib_version()')
150
151         return tuple((int(x) for x in version.split('.')[:2]))
152
153
154 def connect(dsn):
155     """ Open a connection to the database using the specialised connection
156         factory. The returned object may be used in conjunction with 'with'.
157         When used outside a context manager, use the `connection` attribute
158         to get the connection.
159     """
160     try:
161         conn = psycopg2.connect(dsn, connection_factory=_Connection)
162         ctxmgr = contextlib.closing(conn)
163         ctxmgr.connection = conn
164         return ctxmgr
165     except psycopg2.OperationalError as err:
166         raise UsageError("Cannot connect to database: {}".format(err)) from err
167
168
169 # Translation from PG connection string parameters to PG environment variables.
170 # Derived from https://www.postgresql.org/docs/current/libpq-envars.html.
171 _PG_CONNECTION_STRINGS = {
172     'host': 'PGHOST',
173     'hostaddr': 'PGHOSTADDR',
174     'port': 'PGPORT',
175     'dbname': 'PGDATABASE',
176     'user': 'PGUSER',
177     'password': 'PGPASSWORD',
178     'passfile': 'PGPASSFILE',
179     'channel_binding': 'PGCHANNELBINDING',
180     'service': 'PGSERVICE',
181     'options': 'PGOPTIONS',
182     'application_name': 'PGAPPNAME',
183     'sslmode': 'PGSSLMODE',
184     'requiressl': 'PGREQUIRESSL',
185     'sslcompression': 'PGSSLCOMPRESSION',
186     'sslcert': 'PGSSLCERT',
187     'sslkey': 'PGSSLKEY',
188     'sslrootcert': 'PGSSLROOTCERT',
189     'sslcrl': 'PGSSLCRL',
190     'requirepeer': 'PGREQUIREPEER',
191     'ssl_min_protocol_version': 'PGSSLMINPROTOCOLVERSION',
192     'ssl_max_protocol_version': 'PGSSLMAXPROTOCOLVERSION',
193     'gssencmode': 'PGGSSENCMODE',
194     'krbsrvname': 'PGKRBSRVNAME',
195     'gsslib': 'PGGSSLIB',
196     'connect_timeout': 'PGCONNECT_TIMEOUT',
197     'target_session_attrs': 'PGTARGETSESSIONATTRS',
198 }
199
200
201 def get_pg_env(dsn, base_env=None):
202     """ Return a copy of `base_env` with the environment variables for
203         PostgresSQL set up from the given database connection string.
204         If `base_env` is None, then the OS environment is used as a base
205         environment.
206     """
207     env = dict(base_env if base_env is not None else os.environ)
208
209     for param, value in psycopg2.extensions.parse_dsn(dsn).items():
210         if param in _PG_CONNECTION_STRINGS:
211             env[_PG_CONNECTION_STRINGS[param]] = value
212         else:
213             LOG.error("Unknown connection parameter '%s' ignored.", param)
214
215     return env