]> git.openstreetmap.org Git - nominatim.git/commitdiff
add function to set up libpq environment
authorSarah Hoffmann <lonvia@denofr.de>
Tue, 23 Feb 2021 13:11:11 +0000 (14:11 +0100)
committerSarah Hoffmann <lonvia@denofr.de>
Thu, 25 Feb 2021 17:42:54 +0000 (18:42 +0100)
Instead of parsing the DSN for each external libpq program we
are going to execute, provide a function that feeds them all
necessary parameters through the environment.

osm2pgsql is the first user.

nominatim/db/connection.py
nominatim/tools/exec_utils.py
test/python/test_db_connection.py

index 6bd81a2ff53d2025e164169a12ddd9bbaa88edab..68e988f6a4c57023e23549c86946f1df28a2ee9a 100644 (file)
@@ -3,6 +3,7 @@ Specialised connection and cursor functions.
 """
 import contextlib
 import logging
+import os
 
 import psycopg2
 import psycopg2.extensions
@@ -10,6 +11,8 @@ import psycopg2.extras
 
 from ..errors import UsageError
 
+LOG = logging.getLogger()
+
 class _Cursor(psycopg2.extras.DictCursor):
     """ A cursor returning dict-like objects and providing specialised
         execution functions.
@@ -18,8 +21,7 @@ class _Cursor(psycopg2.extras.DictCursor):
     def execute(self, query, args=None): # pylint: disable=W0221
         """ Query execution that logs the SQL query when debugging is enabled.
         """
-        logger = logging.getLogger()
-        logger.debug(self.mogrify(query, args).decode('utf-8'))
+        LOG.debug(self.mogrify(query, args).decode('utf-8'))
 
         super().execute(query, args)
 
@@ -96,3 +98,52 @@ def connect(dsn):
         return ctxmgr
     except psycopg2.OperationalError as err:
         raise UsageError("Cannot connect to database: {}".format(err)) from err
+
+
+# Translation from PG connection string parameters to PG environment variables.
+# Derived from https://www.postgresql.org/docs/current/libpq-envars.html.
+_PG_CONNECTION_STRINGS = {
+    'host': 'PGHOST',
+    'hostaddr': 'PGHOSTADDR',
+    'port': 'PGPORT',
+    'dbname': 'PGDATABASE',
+    'user': 'PGUSER',
+    'password': 'PGPASSWORD',
+    'passfile': 'PGPASSFILE',
+    'channel_binding': 'PGCHANNELBINDING',
+    'service': 'PGSERVICE',
+    'options': 'PGOPTIONS',
+    'application_name': 'PGAPPNAME',
+    'sslmode': 'PGSSLMODE',
+    'requiressl': 'PGREQUIRESSL',
+    'sslcompression': 'PGSSLCOMPRESSION',
+    'sslcert': 'PGSSLCERT',
+    'sslkey': 'PGSSLKEY',
+    'sslrootcert': 'PGSSLROOTCERT',
+    'sslcrl': 'PGSSLCRL',
+    'requirepeer': 'PGREQUIREPEER',
+    'ssl_min_protocol_version': 'PGSSLMINPROTOCOLVERSION',
+    'ssl_min_protocol_version': 'PGSSLMAXPROTOCOLVERSION',
+    'gssencmode': 'PGGSSENCMODE',
+    'krbsrvname': 'PGKRBSRVNAME',
+    'gsslib': 'PGGSSLIB',
+    'connect_timeout': 'PGCONNECT_TIMEOUT',
+    'target_session_attrs': 'PGTARGETSESSIONATTRS',
+}
+
+
+def get_pg_env(dsn, base_env=None):
+    """ Return a copy of `base_env` with the environment variables for
+        PostgresSQL set up from the given database connection string.
+        If `base_env` is None, then the OS environment is used as a base
+        environment.
+    """
+    env = base_env if base_env is not None else os.environ
+
+    for param, value in psycopg2.extensions.parse_dsn(dsn).items():
+        if param in _PG_CONNECTION_STRINGS:
+            env[_PG_CONNECTION_STRINGS[param]] = value
+        else:
+            LOG.error("Unknown connection parameter '%s' ignored.", param)
+
+    return env
index f373f347dd23936fd155edda465373fcb09e42d4..004a821c5f14d789af393cc93e298f1a5519cf16 100644 (file)
@@ -10,6 +10,7 @@ from urllib.parse import urlencode
 from psycopg2.extensions import parse_dsn
 
 from ..version import NOMINATIM_VERSION
+from ..db.connection import get_pg_env
 
 LOG = logging.getLogger()
 
@@ -100,7 +101,7 @@ def run_php_server(server_address, base_dir):
 def run_osm2pgsql(options):
     """ Run osm2pgsql with the given options.
     """
-    env = os.environ
+    env = get_pg_env(options['dsn'])
     cmd = [options['osm2pgsql'],
            '--hstore', '--latlon', '--slim',
            '--with-forward-dependencies', 'false',
@@ -116,17 +117,6 @@ def run_osm2pgsql(options):
     if options['flatnode_file']:
         cmd.extend(('--flat-nodes', options['flatnode_file']))
 
-    dsn = parse_dsn(options['dsn'])
-    if 'password' in dsn:
-        env['PGPASSWORD'] = dsn['password']
-    if 'dbname' in dsn:
-        cmd.extend(('-d', dsn['dbname']))
-    if 'user' in dsn:
-        cmd.extend(('--username', dsn['user']))
-    for param in ('host', 'port'):
-        if param in dsn:
-            cmd.extend(('--' + param, dsn[param]))
-
     if options.get('disable_jit', False):
         env['PGOPTIONS'] = '-c jit=off -c max_parallel_workers_per_gather=0'
 
index 846ef864db853f102d3c02667ba28272eb5e2dc7..fd5da754284be7408c3f1711d1589601adc9b5b8 100644 (file)
@@ -3,7 +3,7 @@ Tests for specialised conenction and cursor classes.
 """
 import pytest
 
-from nominatim.db.connection import connect
+from nominatim.db.connection import connect, get_pg_env
 
 @pytest.fixture
 def db(temp_db):
@@ -48,3 +48,24 @@ def test_cursor_scalar_many_rows(db):
     with db.cursor() as cur:
         with pytest.raises(RuntimeError):
             cur.scalar('SELECT * FROM pg_tables')
+
+
+def test_get_pg_env_add_variable(monkeypatch):
+    monkeypatch.delenv('PGPASSWORD', raising=False)
+    env = get_pg_env('user=fooF')
+
+    assert env['PGUSER'] == 'fooF'
+    assert 'PGPASSWORD' not in env
+
+
+def test_get_pg_env_overwrite_variable(monkeypatch):
+    monkeypatch.setenv('PGUSER', 'some default')
+    env = get_pg_env('user=overwriter')
+
+    assert env['PGUSER'] == 'overwriter'
+
+
+def test_get_pg_env_ignore_unknown():
+    env = get_pg_env('tty=stuff', base_env={})
+
+    assert env == {}