1 # SPDX-License-Identifier: GPL-2.0-only
 
   3 # This file is part of Nominatim. (https://nominatim.org)
 
   5 # Copyright (C) 2022 by the Nominatim developer community.
 
   6 # For a full list of authors see the git log.
 
   8 Specialised connection and cursor functions.
 
  15 import psycopg2.extensions
 
  16 import psycopg2.extras
 
  17 from psycopg2 import sql as pysql
 
  19 from nominatim.errors import UsageError
 
  21 LOG = logging.getLogger()
 
  23 class _Cursor(psycopg2.extras.DictCursor):
 
  24     """ A cursor returning dict-like objects and providing specialised
 
  28     # pylint: disable=arguments-renamed,arguments-differ
 
  29     def execute(self, query, args=None):
 
  30         """ Query execution that logs the SQL query when debugging is enabled.
 
  32         LOG.debug(self.mogrify(query, args).decode('utf-8'))
 
  34         super().execute(query, args)
 
  37     def execute_values(self, sql, argslist, template=None):
 
  38         """ Wrapper for the psycopg2 convenience function to execute
 
  39             SQL for a list of values.
 
  41         LOG.debug("SQL execute_values(%s, %s)", sql, argslist)
 
  43         psycopg2.extras.execute_values(self, sql, argslist, template=template)
 
  46     def scalar(self, sql, args=None):
 
  47         """ Execute query that returns a single value. The value is returned.
 
  48             If the query yields more than one row, a ValueError is raised.
 
  50         self.execute(sql, args)
 
  52         if self.rowcount != 1:
 
  53             raise RuntimeError("Query did not return a single row.")
 
  55         return self.fetchone()[0]
 
  58     def drop_table(self, name, if_exists=True, cascade=False):
 
  59         """ Drop the table with the given name.
 
  60             Set `if_exists` to False if a non-existant table should raise
 
  61             an exception instead of just being ignored. If 'cascade' is set
 
  62             to True then all dependent tables are deleted as well.
 
  71         self.execute(pysql.SQL(sql).format(pysql.Identifier(name)))
 
  74 class _Connection(psycopg2.extensions.connection):
 
  75     """ A connection that provides the specialised cursor by default and
 
  76         adds convenience functions for administrating the database.
 
  79     def cursor(self, cursor_factory=_Cursor, **kwargs):
 
  80         """ Return a new cursor. By default the specialised cursor is returned.
 
  82         return super().cursor(cursor_factory=cursor_factory, **kwargs)
 
  85     def table_exists(self, table):
 
  86         """ Check that a table with the given name exists in the database.
 
  88         with self.cursor() as cur:
 
  89             num = cur.scalar("""SELECT count(*) FROM pg_tables
 
  90                                 WHERE tablename = %s and schemaname = 'public'""", (table, ))
 
  94     def table_has_column(self, table, column):
 
  95         """ Check if the table 'table' exists and has a column with name 'column'.
 
  97         with self.cursor() as cur:
 
  98             has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns
 
 100                                              and column_name = %s""",
 
 102             return has_column > 0
 
 105     def index_exists(self, index, table=None):
 
 106         """ Check that an index with the given name exists in the database.
 
 107             If table is not None then the index must relate to the given
 
 110         with self.cursor() as cur:
 
 111             cur.execute("""SELECT tablename FROM pg_indexes
 
 112                            WHERE indexname = %s and schemaname = 'public'""", (index, ))
 
 113             if cur.rowcount == 0:
 
 116             if table is not None:
 
 118                 return row[0] == table
 
 123     def drop_table(self, name, if_exists=True, cascade=False):
 
 124         """ Drop the table with the given name.
 
 125             Set `if_exists` to False if a non-existant table should raise
 
 126             an exception instead of just being ignored.
 
 128         with self.cursor() as cur:
 
 129             cur.drop_table(name, if_exists, cascade)
 
 133     def server_version_tuple(self):
 
 134         """ Return the server version as a tuple of (major, minor).
 
 135             Converts correctly for pre-10 and post-10 PostgreSQL versions.
 
 137         version = self.server_version
 
 139             return (int(version / 10000), (version % 10000) / 100)
 
 141         return (int(version / 10000), version % 10000)
 
 144     def postgis_version_tuple(self):
 
 145         """ Return the postgis version installed in the database as a
 
 146             tuple of (major, minor). Assumes that the PostGIS extension
 
 147             has been installed already.
 
 149         with self.cursor() as cur:
 
 150             version = cur.scalar('SELECT postgis_lib_version()')
 
 152         return tuple((int(x) for x in version.split('.')[:2]))
 
 156     """ Open a connection to the database using the specialised connection
 
 157         factory. The returned object may be used in conjunction with 'with'.
 
 158         When used outside a context manager, use the `connection` attribute
 
 159         to get the connection.
 
 162         conn = psycopg2.connect(dsn, connection_factory=_Connection)
 
 163         ctxmgr = contextlib.closing(conn)
 
 164         ctxmgr.connection = conn
 
 166     except psycopg2.OperationalError as err:
 
 167         raise UsageError(f"Cannot connect to database: {err}") from err
 
 170 # Translation from PG connection string parameters to PG environment variables.
 
 171 # Derived from https://www.postgresql.org/docs/current/libpq-envars.html.
 
 172 _PG_CONNECTION_STRINGS = {
 
 174     'hostaddr': 'PGHOSTADDR',
 
 176     'dbname': 'PGDATABASE',
 
 178     'password': 'PGPASSWORD',
 
 179     'passfile': 'PGPASSFILE',
 
 180     'channel_binding': 'PGCHANNELBINDING',
 
 181     'service': 'PGSERVICE',
 
 182     'options': 'PGOPTIONS',
 
 183     'application_name': 'PGAPPNAME',
 
 184     'sslmode': 'PGSSLMODE',
 
 185     'requiressl': 'PGREQUIRESSL',
 
 186     'sslcompression': 'PGSSLCOMPRESSION',
 
 187     'sslcert': 'PGSSLCERT',
 
 188     'sslkey': 'PGSSLKEY',
 
 189     'sslrootcert': 'PGSSLROOTCERT',
 
 190     'sslcrl': 'PGSSLCRL',
 
 191     'requirepeer': 'PGREQUIREPEER',
 
 192     'ssl_min_protocol_version': 'PGSSLMINPROTOCOLVERSION',
 
 193     'ssl_max_protocol_version': 'PGSSLMAXPROTOCOLVERSION',
 
 194     'gssencmode': 'PGGSSENCMODE',
 
 195     'krbsrvname': 'PGKRBSRVNAME',
 
 196     'gsslib': 'PGGSSLIB',
 
 197     'connect_timeout': 'PGCONNECT_TIMEOUT',
 
 198     'target_session_attrs': 'PGTARGETSESSIONATTRS',
 
 202 def get_pg_env(dsn, base_env=None):
 
 203     """ Return a copy of `base_env` with the environment variables for
 
 204         PostgresSQL set up from the given database connection string.
 
 205         If `base_env` is None, then the OS environment is used as a base
 
 208     env = dict(base_env if base_env is not None else os.environ)
 
 210     for param, value in psycopg2.extensions.parse_dsn(dsn).items():
 
 211         if param in _PG_CONNECTION_STRINGS:
 
 212             env[_PG_CONNECTION_STRINGS[param]] = value
 
 214             LOG.error("Unknown connection parameter '%s' ignored.", param)