]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/db/connection.py
type annotations for DB connection
[nominatim.git] / nominatim / db / connection.py
1 # SPDX-License-Identifier: GPL-2.0-only
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2022 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Specialised connection and cursor functions.
9 """
10 from typing import Union, List, Optional, Any, Callable, ContextManager, Mapping, cast, TypeVar, overload, Tuple, Sequence
11 import contextlib
12 import logging
13 import os
14
15 import psycopg2
16 import psycopg2.extensions
17 import psycopg2.extras
18 from psycopg2 import sql as pysql
19
20 from nominatim.errors import UsageError
21
22 Query = Union[str, bytes, pysql.Composable]
23 T = TypeVar('T', bound=psycopg2.extensions.cursor)
24
25 LOG = logging.getLogger()
26
27 class _Cursor(psycopg2.extras.DictCursor):
28     """ A cursor returning dict-like objects and providing specialised
29         execution functions.
30     """
31     # pylint: disable=arguments-renamed,arguments-differ
32     def execute(self, query: Query, args: Any = None) -> None:
33         """ Query execution that logs the SQL query when debugging is enabled.
34         """
35         if LOG.isEnabledFor(logging.DEBUG):
36             LOG.debug(self.mogrify(query, args).decode('utf-8')) # type: ignore
37
38         super().execute(query, args)
39
40
41     def execute_values(self, sql: Query, argslist: List[Any], template: Optional[str] = None) -> None:
42         """ Wrapper for the psycopg2 convenience function to execute
43             SQL for a list of values.
44         """
45         LOG.debug("SQL execute_values(%s, %s)", sql, argslist)
46
47         psycopg2.extras.execute_values(self, sql, argslist, template=template)
48
49
50     def scalar(self, sql: Query, args: Any = None) -> Any:
51         """ Execute query that returns a single value. The value is returned.
52             If the query yields more than one row, a ValueError is raised.
53         """
54         self.execute(sql, args)
55
56         if self.rowcount != 1:
57             raise RuntimeError("Query did not return a single row.")
58
59         result = self.fetchone() # type: ignore
60         assert result is not None
61
62         return result[0]
63
64
65     def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
66         """ Drop the table with the given name.
67             Set `if_exists` to False if a non-existant table should raise
68             an exception instead of just being ignored. If 'cascade' is set
69             to True then all dependent tables are deleted as well.
70         """
71         sql = 'DROP TABLE '
72         if if_exists:
73             sql += 'IF EXISTS '
74         sql += '{}'
75         if cascade:
76             sql += ' CASCADE'
77
78         self.execute(pysql.SQL(sql).format(pysql.Identifier(name))) # type: ignore
79
80
81 class _Connection(psycopg2.extensions.connection):
82     """ A connection that provides the specialised cursor by default and
83         adds convenience functions for administrating the database.
84     """
85     @overload # type: ignore[override]
86     def cursor(self) -> _Cursor:
87         ...
88
89     @overload
90     def cursor(self, name: str) -> _Cursor:
91         ...
92
93     @overload
94     def cursor(self, cursor_factory: Callable[..., T]) -> T:
95         ...
96
97     def cursor(self, cursor_factory  = _Cursor, **kwargs): # type: ignore
98         """ Return a new cursor. By default the specialised cursor is returned.
99         """
100         return super().cursor(cursor_factory=cursor_factory, **kwargs)
101
102
103     def table_exists(self, table: str) -> bool:
104         """ Check that a table with the given name exists in the database.
105         """
106         with self.cursor() as cur:
107             num = cur.scalar("""SELECT count(*) FROM pg_tables
108                                 WHERE tablename = %s and schemaname = 'public'""", (table, ))
109             return num == 1 if isinstance(num, int) else False
110
111
112     def table_has_column(self, table: str, column: str) -> bool:
113         """ Check if the table 'table' exists and has a column with name 'column'.
114         """
115         with self.cursor() as cur:
116             has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns
117                                        WHERE table_name = %s
118                                              and column_name = %s""",
119                                     (table, column))
120             return has_column > 0 if isinstance(has_column, int) else False
121
122
123     def index_exists(self, index: str, table: Optional[str] = None) -> bool:
124         """ Check that an index with the given name exists in the database.
125             If table is not None then the index must relate to the given
126             table.
127         """
128         with self.cursor() as cur:
129             cur.execute("""SELECT tablename FROM pg_indexes
130                            WHERE indexname = %s and schemaname = 'public'""", (index, ))
131             if cur.rowcount == 0:
132                 return False
133
134             if table is not None:
135                 row = cur.fetchone() # type: ignore
136                 if row is None or not isinstance(row[0], str):
137                     return False
138                 return row[0] == table
139
140         return True
141
142
143     def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
144         """ Drop the table with the given name.
145             Set `if_exists` to False if a non-existant table should raise
146             an exception instead of just being ignored.
147         """
148         with self.cursor() as cur:
149             cur.drop_table(name, if_exists, cascade)
150         self.commit()
151
152
153     def server_version_tuple(self) -> Tuple[int, int]:
154         """ Return the server version as a tuple of (major, minor).
155             Converts correctly for pre-10 and post-10 PostgreSQL versions.
156         """
157         version = self.server_version
158         if version < 100000:
159             return (int(version / 10000), int((version % 10000) / 100))
160
161         return (int(version / 10000), version % 10000)
162
163
164     def postgis_version_tuple(self) -> Tuple[int, int]:
165         """ Return the postgis version installed in the database as a
166             tuple of (major, minor). Assumes that the PostGIS extension
167             has been installed already.
168         """
169         with self.cursor() as cur:
170             version = cur.scalar('SELECT postgis_lib_version()')
171
172         version_parts = version.split('.')
173         if len(version_parts) < 2:
174             raise UsageError(f"Error fetching Postgis version. Bad format: {version}")
175
176         return (int(version_parts[0]), int(version_parts[1]))
177
178 class _ConnectionContext(ContextManager[_Connection]):
179     connection: _Connection
180
181 def connect(dsn: str) -> _ConnectionContext:
182     """ Open a connection to the database using the specialised connection
183         factory. The returned object may be used in conjunction with 'with'.
184         When used outside a context manager, use the `connection` attribute
185         to get the connection.
186     """
187     try:
188         conn = psycopg2.connect(dsn, connection_factory=_Connection)
189         ctxmgr = cast(_ConnectionContext, contextlib.closing(conn))
190         ctxmgr.connection = cast(_Connection, conn)
191         return ctxmgr
192     except psycopg2.OperationalError as err:
193         raise UsageError(f"Cannot connect to database: {err}") from err
194
195
196 # Translation from PG connection string parameters to PG environment variables.
197 # Derived from https://www.postgresql.org/docs/current/libpq-envars.html.
198 _PG_CONNECTION_STRINGS = {
199     'host': 'PGHOST',
200     'hostaddr': 'PGHOSTADDR',
201     'port': 'PGPORT',
202     'dbname': 'PGDATABASE',
203     'user': 'PGUSER',
204     'password': 'PGPASSWORD',
205     'passfile': 'PGPASSFILE',
206     'channel_binding': 'PGCHANNELBINDING',
207     'service': 'PGSERVICE',
208     'options': 'PGOPTIONS',
209     'application_name': 'PGAPPNAME',
210     'sslmode': 'PGSSLMODE',
211     'requiressl': 'PGREQUIRESSL',
212     'sslcompression': 'PGSSLCOMPRESSION',
213     'sslcert': 'PGSSLCERT',
214     'sslkey': 'PGSSLKEY',
215     'sslrootcert': 'PGSSLROOTCERT',
216     'sslcrl': 'PGSSLCRL',
217     'requirepeer': 'PGREQUIREPEER',
218     'ssl_min_protocol_version': 'PGSSLMINPROTOCOLVERSION',
219     'ssl_max_protocol_version': 'PGSSLMAXPROTOCOLVERSION',
220     'gssencmode': 'PGGSSENCMODE',
221     'krbsrvname': 'PGKRBSRVNAME',
222     'gsslib': 'PGGSSLIB',
223     'connect_timeout': 'PGCONNECT_TIMEOUT',
224     'target_session_attrs': 'PGTARGETSESSIONATTRS',
225 }
226
227
228 def get_pg_env(dsn: str,
229                base_env: Optional[Mapping[str, Optional[str]]] = None) -> Mapping[str, Optional[str]]:
230     """ Return a copy of `base_env` with the environment variables for
231         PostgresSQL set up from the given database connection string.
232         If `base_env` is None, then the OS environment is used as a base
233         environment.
234     """
235     env = dict(base_env if base_env is not None else os.environ)
236
237     for param, value in psycopg2.extensions.parse_dsn(dsn).items(): # type: ignore
238         if param in _PG_CONNECTION_STRINGS:
239             env[_PG_CONNECTION_STRINGS[param]] = value
240         else:
241             LOG.error("Unknown connection parameter '%s' ignored.", param)
242
243     return env