1 # SPDX-License-Identifier: GPL-2.0-only
 
   3 # This file is part of Nominatim. (https://nominatim.org)
 
   5 # Copyright (C) 2022 by the Nominatim developer community.
 
   6 # For a full list of authors see the git log.
 
   8 Specialised connection and cursor functions.
 
  15 import psycopg2.extensions
 
  16 import psycopg2.extras
 
  17 from psycopg2 import sql as pysql
 
  19 from nominatim.errors import UsageError
 
  21 LOG = logging.getLogger()
 
  23 class _Cursor(psycopg2.extras.DictCursor):
 
  24     """ A cursor returning dict-like objects and providing specialised
 
  28     def execute(self, query, args=None): # pylint: disable=W0221
 
  29         """ Query execution that logs the SQL query when debugging is enabled.
 
  31         LOG.debug(self.mogrify(query, args).decode('utf-8'))
 
  33         super().execute(query, args)
 
  36     def execute_values(self, sql, argslist, template=None):
 
  37         """ Wrapper for the psycopg2 convenience function to execute
 
  38             SQL for a list of values.
 
  40         LOG.debug("SQL execute_values(%s, %s)", sql, argslist)
 
  42         psycopg2.extras.execute_values(self, sql, argslist, template=template)
 
  45     def scalar(self, sql, args=None):
 
  46         """ Execute query that returns a single value. The value is returned.
 
  47             If the query yields more than one row, a ValueError is raised.
 
  49         self.execute(sql, args)
 
  51         if self.rowcount != 1:
 
  52             raise RuntimeError("Query did not return a single row.")
 
  54         return self.fetchone()[0]
 
  57     def drop_table(self, name, if_exists=True, cascade=False):
 
  58         """ Drop the table with the given name.
 
  59             Set `if_exists` to False if a non-existant table should raise
 
  60             an exception instead of just being ignored. If 'cascade' is set
 
  61             to True then all dependent tables are deleted as well.
 
  70         self.execute(pysql.SQL(sql).format(pysql.Identifier(name)))
 
  73 class _Connection(psycopg2.extensions.connection):
 
  74     """ A connection that provides the specialised cursor by default and
 
  75         adds convenience functions for administrating the database.
 
  78     def cursor(self, cursor_factory=_Cursor, **kwargs):
 
  79         """ Return a new cursor. By default the specialised cursor is returned.
 
  81         return super().cursor(cursor_factory=cursor_factory, **kwargs)
 
  84     def table_exists(self, table):
 
  85         """ Check that a table with the given name exists in the database.
 
  87         with self.cursor() as cur:
 
  88             num = cur.scalar("""SELECT count(*) FROM pg_tables
 
  89                                 WHERE tablename = %s and schemaname = 'public'""", (table, ))
 
  93     def index_exists(self, index, table=None):
 
  94         """ Check that an index with the given name exists in the database.
 
  95             If table is not None then the index must relate to the given
 
  98         with self.cursor() as cur:
 
  99             cur.execute("""SELECT tablename FROM pg_indexes
 
 100                            WHERE indexname = %s and schemaname = 'public'""", (index, ))
 
 101             if cur.rowcount == 0:
 
 104             if table is not None:
 
 106                 return row[0] == table
 
 111     def drop_table(self, name, if_exists=True, cascade=False):
 
 112         """ Drop the table with the given name.
 
 113             Set `if_exists` to False if a non-existant table should raise
 
 114             an exception instead of just being ignored.
 
 116         with self.cursor() as cur:
 
 117             cur.drop_table(name, if_exists, cascade)
 
 121     def server_version_tuple(self):
 
 122         """ Return the server version as a tuple of (major, minor).
 
 123             Converts correctly for pre-10 and post-10 PostgreSQL versions.
 
 125         version = self.server_version
 
 127             return (int(version / 10000), (version % 10000) / 100)
 
 129         return (int(version / 10000), version % 10000)
 
 132     def postgis_version_tuple(self):
 
 133         """ Return the postgis version installed in the database as a
 
 134             tuple of (major, minor). Assumes that the PostGIS extension
 
 135             has been installed already.
 
 137         with self.cursor() as cur:
 
 138             version = cur.scalar('SELECT postgis_lib_version()')
 
 140         return tuple((int(x) for x in version.split('.')[:2]))
 
 144     """ Open a connection to the database using the specialised connection
 
 145         factory. The returned object may be used in conjunction with 'with'.
 
 146         When used outside a context manager, use the `connection` attribute
 
 147         to get the connection.
 
 150         conn = psycopg2.connect(dsn, connection_factory=_Connection)
 
 151         ctxmgr = contextlib.closing(conn)
 
 152         ctxmgr.connection = conn
 
 154     except psycopg2.OperationalError as err:
 
 155         raise UsageError("Cannot connect to database: {}".format(err)) from err
 
 158 # Translation from PG connection string parameters to PG environment variables.
 
 159 # Derived from https://www.postgresql.org/docs/current/libpq-envars.html.
 
 160 _PG_CONNECTION_STRINGS = {
 
 162     'hostaddr': 'PGHOSTADDR',
 
 164     'dbname': 'PGDATABASE',
 
 166     'password': 'PGPASSWORD',
 
 167     'passfile': 'PGPASSFILE',
 
 168     'channel_binding': 'PGCHANNELBINDING',
 
 169     'service': 'PGSERVICE',
 
 170     'options': 'PGOPTIONS',
 
 171     'application_name': 'PGAPPNAME',
 
 172     'sslmode': 'PGSSLMODE',
 
 173     'requiressl': 'PGREQUIRESSL',
 
 174     'sslcompression': 'PGSSLCOMPRESSION',
 
 175     'sslcert': 'PGSSLCERT',
 
 176     'sslkey': 'PGSSLKEY',
 
 177     'sslrootcert': 'PGSSLROOTCERT',
 
 178     'sslcrl': 'PGSSLCRL',
 
 179     'requirepeer': 'PGREQUIREPEER',
 
 180     'ssl_min_protocol_version': 'PGSSLMINPROTOCOLVERSION',
 
 181     'ssl_max_protocol_version': 'PGSSLMAXPROTOCOLVERSION',
 
 182     'gssencmode': 'PGGSSENCMODE',
 
 183     'krbsrvname': 'PGKRBSRVNAME',
 
 184     'gsslib': 'PGGSSLIB',
 
 185     'connect_timeout': 'PGCONNECT_TIMEOUT',
 
 186     'target_session_attrs': 'PGTARGETSESSIONATTRS',
 
 190 def get_pg_env(dsn, base_env=None):
 
 191     """ Return a copy of `base_env` with the environment variables for
 
 192         PostgresSQL set up from the given database connection string.
 
 193         If `base_env` is None, then the OS environment is used as a base
 
 196     env = dict(base_env if base_env is not None else os.environ)
 
 198     for param, value in psycopg2.extensions.parse_dsn(dsn).items():
 
 199         if param in _PG_CONNECTION_STRINGS:
 
 200             env[_PG_CONNECTION_STRINGS[param]] = value
 
 202             LOG.error("Unknown connection parameter '%s' ignored.", param)