]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/db/connection.py
1eb3599bee64a8cdfbb8c08972138fc6c47e282a
[nominatim.git] / nominatim / db / connection.py
1 """
2 Specialised connection and cursor functions.
3 """
4 import contextlib
5 import logging
6 import os
7
8 import psycopg2
9 import psycopg2.extensions
10 import psycopg2.extras
11 from psycopg2 import sql as pysql
12
13 from nominatim.errors import UsageError
14
15 LOG = logging.getLogger()
16
17 class _Cursor(psycopg2.extras.DictCursor):
18     """ A cursor returning dict-like objects and providing specialised
19         execution functions.
20     """
21
22     def execute(self, query, args=None): # pylint: disable=W0221
23         """ Query execution that logs the SQL query when debugging is enabled.
24         """
25         LOG.debug(self.mogrify(query, args).decode('utf-8'))
26
27         super().execute(query, args)
28
29     def scalar(self, sql, args=None):
30         """ Execute query that returns a single value. The value is returned.
31             If the query yields more than one row, a ValueError is raised.
32         """
33         self.execute(sql, args)
34
35         if self.rowcount != 1:
36             raise RuntimeError("Query did not return a single row.")
37
38         return self.fetchone()[0]
39
40
41     def drop_table(self, name, if_exists=True, cascade=False):
42         """ Drop the table with the given name.
43             Set `if_exists` to False if a non-existant table should raise
44             an exception instead of just being ignored. If 'cascade' is set
45             to True then all dependent tables are deleted as well.
46         """
47         sql = 'DROP TABLE '
48         if if_exists:
49             sql += 'IF EXISTS '
50         sql += '{}'
51         if cascade:
52             sql += ' CASCADE'
53
54         self.execute(pysql.SQL(sql).format(pysql.Identifier(name)))
55
56
57 class _Connection(psycopg2.extensions.connection):
58     """ A connection that provides the specialised cursor by default and
59         adds convenience functions for administrating the database.
60     """
61
62     def cursor(self, cursor_factory=_Cursor, **kwargs):
63         """ Return a new cursor. By default the specialised cursor is returned.
64         """
65         return super().cursor(cursor_factory=cursor_factory, **kwargs)
66
67
68     def table_exists(self, table):
69         """ Check that a table with the given name exists in the database.
70         """
71         with self.cursor() as cur:
72             num = cur.scalar("""SELECT count(*) FROM pg_tables
73                                 WHERE tablename = %s and schemaname = 'public'""", (table, ))
74             return num == 1
75
76
77     def index_exists(self, index, table=None):
78         """ Check that an index with the given name exists in the database.
79             If table is not None then the index must relate to the given
80             table.
81         """
82         with self.cursor() as cur:
83             cur.execute("""SELECT tablename FROM pg_indexes
84                            WHERE indexname = %s and schemaname = 'public'""", (index, ))
85             if cur.rowcount == 0:
86                 return False
87
88             if table is not None:
89                 row = cur.fetchone()
90                 return row[0] == table
91
92         return True
93
94
95     def drop_table(self, name, if_exists=True, cascade=False):
96         """ Drop the table with the given name.
97             Set `if_exists` to False if a non-existant table should raise
98             an exception instead of just being ignored.
99         """
100         with self.cursor() as cur:
101             cur.drop_table(name, if_exists, cascade)
102         self.commit()
103
104
105     def server_version_tuple(self):
106         """ Return the server version as a tuple of (major, minor).
107             Converts correctly for pre-10 and post-10 PostgreSQL versions.
108         """
109         version = self.server_version
110         if version < 100000:
111             return (int(version / 10000), (version % 10000) / 100)
112
113         return (int(version / 10000), version % 10000)
114
115
116     def postgis_version_tuple(self):
117         """ Return the postgis version installed in the database as a
118             tuple of (major, minor). Assumes that the PostGIS extension
119             has been installed already.
120         """
121         with self.cursor() as cur:
122             version = cur.scalar('SELECT postgis_lib_version()')
123
124         return tuple((int(x) for x in version.split('.')[:2]))
125
126
127 def connect(dsn):
128     """ Open a connection to the database using the specialised connection
129         factory. The returned object may be used in conjunction with 'with'.
130         When used outside a context manager, use the `connection` attribute
131         to get the connection.
132     """
133     try:
134         conn = psycopg2.connect(dsn, connection_factory=_Connection)
135         ctxmgr = contextlib.closing(conn)
136         ctxmgr.connection = conn
137         return ctxmgr
138     except psycopg2.OperationalError as err:
139         raise UsageError("Cannot connect to database: {}".format(err)) from err
140
141
142 # Translation from PG connection string parameters to PG environment variables.
143 # Derived from https://www.postgresql.org/docs/current/libpq-envars.html.
144 _PG_CONNECTION_STRINGS = {
145     'host': 'PGHOST',
146     'hostaddr': 'PGHOSTADDR',
147     'port': 'PGPORT',
148     'dbname': 'PGDATABASE',
149     'user': 'PGUSER',
150     'password': 'PGPASSWORD',
151     'passfile': 'PGPASSFILE',
152     'channel_binding': 'PGCHANNELBINDING',
153     'service': 'PGSERVICE',
154     'options': 'PGOPTIONS',
155     'application_name': 'PGAPPNAME',
156     'sslmode': 'PGSSLMODE',
157     'requiressl': 'PGREQUIRESSL',
158     'sslcompression': 'PGSSLCOMPRESSION',
159     'sslcert': 'PGSSLCERT',
160     'sslkey': 'PGSSLKEY',
161     'sslrootcert': 'PGSSLROOTCERT',
162     'sslcrl': 'PGSSLCRL',
163     'requirepeer': 'PGREQUIREPEER',
164     'ssl_min_protocol_version': 'PGSSLMINPROTOCOLVERSION',
165     'ssl_max_protocol_version': 'PGSSLMAXPROTOCOLVERSION',
166     'gssencmode': 'PGGSSENCMODE',
167     'krbsrvname': 'PGKRBSRVNAME',
168     'gsslib': 'PGGSSLIB',
169     'connect_timeout': 'PGCONNECT_TIMEOUT',
170     'target_session_attrs': 'PGTARGETSESSIONATTRS',
171 }
172
173
174 def get_pg_env(dsn, base_env=None):
175     """ Return a copy of `base_env` with the environment variables for
176         PostgresSQL set up from the given database connection string.
177         If `base_env` is None, then the OS environment is used as a base
178         environment.
179     """
180     env = dict(base_env if base_env is not None else os.environ)
181
182     for param, value in psycopg2.extensions.parse_dsn(dsn).items():
183         if param in _PG_CONNECTION_STRINGS:
184             env[_PG_CONNECTION_STRINGS[param]] = value
185         else:
186             LOG.error("Unknown connection parameter '%s' ignored.", param)
187
188     return env