1 # SPDX-License-Identifier: GPL-2.0-only
 
   3 # This file is part of Nominatim. (https://nominatim.org)
 
   5 # Copyright (C) 2022 by the Nominatim developer community.
 
   6 # For a full list of authors see the git log.
 
   8 Specialised connection and cursor functions.
 
  10 from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload, Tuple, Iterable
 
  16 import psycopg2.extensions
 
  17 import psycopg2.extras
 
  18 from psycopg2 import sql as pysql
 
  20 from nominatim.typing import SysEnv, Query, T_cursor
 
  21 from nominatim.errors import UsageError
 
  23 LOG = logging.getLogger()
 
  25 class Cursor(psycopg2.extras.DictCursor):
 
  26     """ A cursor returning dict-like objects and providing specialised
 
  29     # pylint: disable=arguments-renamed,arguments-differ
 
  30     def execute(self, query: Query, args: Any = None) -> None:
 
  31         """ Query execution that logs the SQL query when debugging is enabled.
 
  33         if LOG.isEnabledFor(logging.DEBUG):
 
  34             LOG.debug(self.mogrify(query, args).decode('utf-8'))
 
  36         super().execute(query, args)
 
  39     def execute_values(self, sql: Query, argslist: Iterable[Tuple[Any, ...]],
 
  40                        template: Optional[Query] = None) -> None:
 
  41         """ Wrapper for the psycopg2 convenience function to execute
 
  42             SQL for a list of values.
 
  44         LOG.debug("SQL execute_values(%s, %s)", sql, argslist)
 
  46         psycopg2.extras.execute_values(self, sql, argslist, template=template)
 
  49     def scalar(self, sql: Query, args: Any = None) -> Any:
 
  50         """ Execute query that returns a single value. The value is returned.
 
  51             If the query yields more than one row, a ValueError is raised.
 
  53         self.execute(sql, args)
 
  55         if self.rowcount != 1:
 
  56             raise RuntimeError("Query did not return a single row.")
 
  58         result = self.fetchone()
 
  59         assert result is not None
 
  64     def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
 
  65         """ Drop the table with the given name.
 
  66             Set `if_exists` to False if a non-existent table should raise
 
  67             an exception instead of just being ignored. If 'cascade' is set
 
  68             to True then all dependent tables are deleted as well.
 
  77         self.execute(pysql.SQL(sql).format(pysql.Identifier(name)))
 
  80 class Connection(psycopg2.extensions.connection):
 
  81     """ A connection that provides the specialised cursor by default and
 
  82         adds convenience functions for administrating the database.
 
  84     @overload # type: ignore[override]
 
  85     def cursor(self) -> Cursor:
 
  89     def cursor(self, name: str) -> Cursor:
 
  93     def cursor(self, cursor_factory: Callable[..., T_cursor]) -> T_cursor:
 
  96     def cursor(self, cursor_factory  = Cursor, **kwargs): # type: ignore
 
  97         """ Return a new cursor. By default the specialised cursor is returned.
 
  99         return super().cursor(cursor_factory=cursor_factory, **kwargs)
 
 102     def table_exists(self, table: str) -> bool:
 
 103         """ Check that a table with the given name exists in the database.
 
 105         with self.cursor() as cur:
 
 106             num = cur.scalar("""SELECT count(*) FROM pg_tables
 
 107                                 WHERE tablename = %s and schemaname = 'public'""", (table, ))
 
 108             return num == 1 if isinstance(num, int) else False
 
 111     def table_has_column(self, table: str, column: str) -> bool:
 
 112         """ Check if the table 'table' exists and has a column with name 'column'.
 
 114         with self.cursor() as cur:
 
 115             has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns
 
 116                                        WHERE table_name = %s
 
 117                                              and column_name = %s""",
 
 119             return has_column > 0 if isinstance(has_column, int) else False
 
 122     def index_exists(self, index: str, table: Optional[str] = None) -> bool:
 
 123         """ Check that an index with the given name exists in the database.
 
 124             If table is not None then the index must relate to the given
 
 127         with self.cursor() as cur:
 
 128             cur.execute("""SELECT tablename FROM pg_indexes
 
 129                            WHERE indexname = %s and schemaname = 'public'""", (index, ))
 
 130             if cur.rowcount == 0:
 
 133             if table is not None:
 
 135                 if row is None or not isinstance(row[0], str):
 
 137                 return row[0] == table
 
 142     def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
 
 143         """ Drop the table with the given name.
 
 144             Set `if_exists` to False if a non-existent table should raise
 
 145             an exception instead of just being ignored.
 
 147         with self.cursor() as cur:
 
 148             cur.drop_table(name, if_exists, cascade)
 
 152     def server_version_tuple(self) -> Tuple[int, int]:
 
 153         """ Return the server version as a tuple of (major, minor).
 
 154             Converts correctly for pre-10 and post-10 PostgreSQL versions.
 
 156         version = self.server_version
 
 158             return (int(version / 10000), int((version % 10000) / 100))
 
 160         return (int(version / 10000), version % 10000)
 
 163     def postgis_version_tuple(self) -> Tuple[int, int]:
 
 164         """ Return the postgis version installed in the database as a
 
 165             tuple of (major, minor). Assumes that the PostGIS extension
 
 166             has been installed already.
 
 168         with self.cursor() as cur:
 
 169             version = cur.scalar('SELECT postgis_lib_version()')
 
 171         version_parts = version.split('.')
 
 172         if len(version_parts) < 2:
 
 173             raise UsageError(f"Error fetching Postgis version. Bad format: {version}")
 
 175         return (int(version_parts[0]), int(version_parts[1]))
 
 178     def extension_loaded(self, extension_name: str) -> bool:
 
 179         """ Return True if the hstore extension is loaded in the database.
 
 181         with self.cursor() as cur:
 
 182             cur.execute('SELECT extname FROM pg_extension WHERE extname = %s', (extension_name, ))
 
 183             return cur.rowcount > 0
 
 186 class ConnectionContext(ContextManager[Connection]):
 
 187     """ Context manager of the connection that also provides direct access
 
 188         to the underlying connection.
 
 190     connection: Connection
 
 192 def connect(dsn: str) -> ConnectionContext:
 
 193     """ Open a connection to the database using the specialised connection
 
 194         factory. The returned object may be used in conjunction with 'with'.
 
 195         When used outside a context manager, use the `connection` attribute
 
 196         to get the connection.
 
 199         conn = psycopg2.connect(dsn, connection_factory=Connection)
 
 200         ctxmgr = cast(ConnectionContext, contextlib.closing(conn))
 
 201         ctxmgr.connection = conn
 
 203     except psycopg2.OperationalError as err:
 
 204         raise UsageError(f"Cannot connect to database: {err}") from err
 
 207 # Translation from PG connection string parameters to PG environment variables.
 
 208 # Derived from https://www.postgresql.org/docs/current/libpq-envars.html.
 
 209 _PG_CONNECTION_STRINGS = {
 
 211     'hostaddr': 'PGHOSTADDR',
 
 213     'dbname': 'PGDATABASE',
 
 215     'password': 'PGPASSWORD',
 
 216     'passfile': 'PGPASSFILE',
 
 217     'channel_binding': 'PGCHANNELBINDING',
 
 218     'service': 'PGSERVICE',
 
 219     'options': 'PGOPTIONS',
 
 220     'application_name': 'PGAPPNAME',
 
 221     'sslmode': 'PGSSLMODE',
 
 222     'requiressl': 'PGREQUIRESSL',
 
 223     'sslcompression': 'PGSSLCOMPRESSION',
 
 224     'sslcert': 'PGSSLCERT',
 
 225     'sslkey': 'PGSSLKEY',
 
 226     'sslrootcert': 'PGSSLROOTCERT',
 
 227     'sslcrl': 'PGSSLCRL',
 
 228     'requirepeer': 'PGREQUIREPEER',
 
 229     'ssl_min_protocol_version': 'PGSSLMINPROTOCOLVERSION',
 
 230     'ssl_max_protocol_version': 'PGSSLMAXPROTOCOLVERSION',
 
 231     'gssencmode': 'PGGSSENCMODE',
 
 232     'krbsrvname': 'PGKRBSRVNAME',
 
 233     'gsslib': 'PGGSSLIB',
 
 234     'connect_timeout': 'PGCONNECT_TIMEOUT',
 
 235     'target_session_attrs': 'PGTARGETSESSIONATTRS',
 
 239 def get_pg_env(dsn: str,
 
 240                base_env: Optional[SysEnv] = None) -> Dict[str, str]:
 
 241     """ Return a copy of `base_env` with the environment variables for
 
 242         PostgresSQL set up from the given database connection string.
 
 243         If `base_env` is None, then the OS environment is used as a base
 
 246     env = dict(base_env if base_env is not None else os.environ)
 
 248     for param, value in psycopg2.extensions.parse_dsn(dsn).items():
 
 249         if param in _PG_CONNECTION_STRINGS:
 
 250             env[_PG_CONNECTION_STRINGS[param]] = value
 
 252             LOG.error("Unknown connection parameter '%s' ignored.", param)