2 Specialised connection and cursor functions.
 
   9 import psycopg2.extensions
 
  10 import psycopg2.extras
 
  12 from nominatim.errors import UsageError
 
  14 LOG = logging.getLogger()
 
  16 class _Cursor(psycopg2.extras.DictCursor):
 
  17     """ A cursor returning dict-like objects and providing specialised
 
  21     def execute(self, query, args=None): # pylint: disable=W0221
 
  22         """ Query execution that logs the SQL query when debugging is enabled.
 
  24         LOG.debug(self.mogrify(query, args).decode('utf-8'))
 
  26         super().execute(query, args)
 
  28     def scalar(self, sql, args=None):
 
  29         """ Execute query that returns a single value. The value is returned.
 
  30             If the query yields more than one row, a ValueError is raised.
 
  32         self.execute(sql, args)
 
  34         if self.rowcount != 1:
 
  35             raise RuntimeError("Query did not return a single row.")
 
  37         return self.fetchone()[0]
 
  40 class _Connection(psycopg2.extensions.connection):
 
  41     """ A connection that provides the specialised cursor by default and
 
  42         adds convenience functions for administrating the database.
 
  45     def cursor(self, cursor_factory=_Cursor, **kwargs):
 
  46         """ Return a new cursor. By default the specialised cursor is returned.
 
  48         return super().cursor(cursor_factory=cursor_factory, **kwargs)
 
  51     def table_exists(self, table):
 
  52         """ Check that a table with the given name exists in the database.
 
  54         with self.cursor() as cur:
 
  55             num = cur.scalar("""SELECT count(*) FROM pg_tables
 
  56                                 WHERE tablename = %s and schemaname = 'public'""", (table, ))
 
  60     def index_exists(self, index, table=None):
 
  61         """ Check that an index with the given name exists in the database.
 
  62             If table is not None then the index must relate to the given
 
  65         with self.cursor() as cur:
 
  66             cur.execute("""SELECT tablename FROM pg_indexes
 
  67                            WHERE indexname = %s and schemaname = 'public'""", (index, ))
 
  73                 return row[0] == table
 
  78     def drop_table(self, name, if_exists=True):
 
  79         """ Drop the table with the given name.
 
  80             Set `if_exists` to False if a non-existant table should raise
 
  81             an exception instead of just being ignored.
 
  83         with self.cursor() as cur:
 
  84             cur.execute("""DROP TABLE {} "{}"
 
  85                         """.format('IF EXISTS' if if_exists else '', name))
 
  89     def server_version_tuple(self):
 
  90         """ Return the server version as a tuple of (major, minor).
 
  91             Converts correctly for pre-10 and post-10 PostgreSQL versions.
 
  93         version = self.server_version
 
  95             return (int(version / 10000), (version % 10000) / 100)
 
  97         return (int(version / 10000), version % 10000)
 
 100     def postgis_version_tuple(self):
 
 101         """ Return the postgis version installed in the database as a
 
 102             tuple of (major, minor). Assumes that the PostGIS extension
 
 103             has been installed already.
 
 105         with self.cursor() as cur:
 
 106             version = cur.scalar('SELECT postgis_lib_version()')
 
 108         return tuple((int(x) for x in version.split('.')[:2]))
 
 112     """ Open a connection to the database using the specialised connection
 
 113         factory. The returned object may be used in conjunction with 'with'.
 
 114         When used outside a context manager, use the `connection` attribute
 
 115         to get the connection.
 
 118         conn = psycopg2.connect(dsn, connection_factory=_Connection)
 
 119         ctxmgr = contextlib.closing(conn)
 
 120         ctxmgr.connection = conn
 
 122     except psycopg2.OperationalError as err:
 
 123         raise UsageError("Cannot connect to database: {}".format(err)) from err
 
 126 # Translation from PG connection string parameters to PG environment variables.
 
 127 # Derived from https://www.postgresql.org/docs/current/libpq-envars.html.
 
 128 _PG_CONNECTION_STRINGS = {
 
 130     'hostaddr': 'PGHOSTADDR',
 
 132     'dbname': 'PGDATABASE',
 
 134     'password': 'PGPASSWORD',
 
 135     'passfile': 'PGPASSFILE',
 
 136     'channel_binding': 'PGCHANNELBINDING',
 
 137     'service': 'PGSERVICE',
 
 138     'options': 'PGOPTIONS',
 
 139     'application_name': 'PGAPPNAME',
 
 140     'sslmode': 'PGSSLMODE',
 
 141     'requiressl': 'PGREQUIRESSL',
 
 142     'sslcompression': 'PGSSLCOMPRESSION',
 
 143     'sslcert': 'PGSSLCERT',
 
 144     'sslkey': 'PGSSLKEY',
 
 145     'sslrootcert': 'PGSSLROOTCERT',
 
 146     'sslcrl': 'PGSSLCRL',
 
 147     'requirepeer': 'PGREQUIREPEER',
 
 148     'ssl_min_protocol_version': 'PGSSLMINPROTOCOLVERSION',
 
 149     'ssl_max_protocol_version': 'PGSSLMAXPROTOCOLVERSION',
 
 150     'gssencmode': 'PGGSSENCMODE',
 
 151     'krbsrvname': 'PGKRBSRVNAME',
 
 152     'gsslib': 'PGGSSLIB',
 
 153     'connect_timeout': 'PGCONNECT_TIMEOUT',
 
 154     'target_session_attrs': 'PGTARGETSESSIONATTRS',
 
 158 def get_pg_env(dsn, base_env=None):
 
 159     """ Return a copy of `base_env` with the environment variables for
 
 160         PostgresSQL set up from the given database connection string.
 
 161         If `base_env` is None, then the OS environment is used as a base
 
 164     env = dict(base_env if base_env is not None else os.environ)
 
 166     for param, value in psycopg2.extensions.parse_dsn(dsn).items():
 
 167         if param in _PG_CONNECTION_STRINGS:
 
 168             env[_PG_CONNECTION_STRINGS[param]] = value
 
 170             LOG.error("Unknown connection parameter '%s' ignored.", param)