git.openstreetmap.org Git - nominatim.git/commitdiff
convert connect() into a context manager
author Sarah Hoffmann <lonvia@denofr.de>
Tue, 23 Feb 2021 09:11:21 +0000 (10:11 +0100)
committer Sarah Hoffmann <lonvia@denofr.de>
Thu, 25 Feb 2021 17:42:54 +0000 (18:42 +0100)
12 files changed:
nominatim/clicmd/admin.py
nominatim/clicmd/freeze.py
nominatim/clicmd/index.py
nominatim/clicmd/refresh.py
nominatim/clicmd/replication.py
nominatim/db/connection.py
nominatim/tools/check_database.py
test/python/conftest.py
test/python/test_db_connection.py
test/python/test_tools_admin.py
test/python/test_tools_check_database.py
test/python/test_tools_refresh_create_functions.py

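diff --git a/nominatim/clicmd/admin.py b/nominatim/clicmd/admin.py
index e58635756b2de0e2649b84877e69f4d558a37b68..fd9382ebb4ed5704f1d538c4c9527ba673d103d8 100644
--- a/nominatim/clicmd/admin.py
+++ b/nominatim/clicmd/admin.py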
@@ -54,9 +54,8 @@ class AdminFuncs:
         if args.analyse_indexing:
             LOG.warning('Analysing performance of indexing function')
             from ..tools import admin
-            conn = connect(args.config.get_libpq_dsn())
-            admin.analyse_indexing(conn, osm_id=args.osm_id, place_id=args.place_id)
-            conn.close()
+            with connect(args.config.get_libpq_dsn()) as conn:
+                admin.analyse_indexing(conn, osm_id=args.osm_id, place_id=args.place_id)
 
         return 0
 
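diff --git a/nominatim/clicmd/freeze.py b/nominatim/clicmd/freeze.py
index 8bca04b995f1240782ab27c92e4326db2f1ad9f3..1b311e97f2f7ee1081c4d1706501b9591c08704c 100644
--- a/nominatim/clicmd/freeze.py
+++ b/nominatim/clicmd/freeze.py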
@@ -29,9 +29,8 @@ class SetupFreeze:
     def run(args):
         from ..tools import freeze
 
-        conn = connect(args.config.get_libpq_dsn())
-        freeze.drop_update_tables(conn)
+        with connect(args.config.get_libpq_dsn()) as conn:
+            freeze.drop_update_tables(conn)
         freeze.drop_flatnode_file(args.config.FLATNODE_FILE)
-        conn.close()
 
         return 0
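diff --git a/nominatim/clicmd/index.py b/nominatim/clicmd/index.py
index ca3f9deedf4e0546e211679336474a1615ef524d..96a69396e42027afcc54be56793c6a3aaf7b2725 100644
--- a/nominatim/clicmd/index.py
+++ b/nominatim/clicmd/index.py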
@@ -51,8 +51,7 @@ class UpdateIndex:
 
         if not args.no_boundaries and not args.boundaries_only \
            and args.minrank == 0 and args.maxrank == 30:
-            conn = connect(args.config.get_libpq_dsn())
-            status.set_indexed(conn, True)
-            conn.close()
+            with connect(args.config.get_libpq_dsn()) as conn:
+                status.set_indexed(conn, True)
 
         return 0
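diff --git a/nominatim/clicmd/refresh.py b/nominatim/clicmd/refresh.py
index ffbe628b8ff0ec6cbf16d2c16d871ff14392cb0e..f68e185ac2281ae98bf9053211c7b7efb3cd7d2b 100644
--- a/nominatim/clicmd/refresh.py
+++ b/nominatim/clicmd/refresh.py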
@@ -50,29 +50,25 @@ class UpdateRefresh:
 
         if args.postcodes:
             LOG.warning("Update postcodes centroid")
-            conn = connect(args.config.get_libpq_dsn())
-            refresh.update_postcodes(conn, args.sqllib_dir)
-            conn.close()
+            with connect(args.config.get_libpq_dsn()) as conn:
+                refresh.update_postcodes(conn, args.sqllib_dir)
 
         if args.word_counts:
             LOG.warning('Recompute frequency of full-word search terms')
-            conn = connect(args.config.get_libpq_dsn())
-            refresh.recompute_word_counts(conn, args.sqllib_dir)
-            conn.close()
+            with connect(args.config.get_libpq_dsn()) as conn:
+                refresh.recompute_word_counts(conn, args.sqllib_dir)
 
         if args.address_levels:
             cfg = Path(args.config.ADDRESS_LEVEL_CONFIG)
             LOG.warning('Updating address levels from %s', cfg)
-            conn = connect(args.config.get_libpq_dsn())
-            refresh.load_address_levels_from_file(conn, cfg)
-            conn.close()
+            with connect(args.config.get_libpq_dsn()) as conn:
+                refresh.load_address_levels_from_file(conn, cfg)
 
         if args.functions:
             LOG.warning('Create functions')
-            conn = connect(args.config.get_libpq_dsn())
-            refresh.create_functions(conn, args.config, args.sqllib_dir,
-                                     args.diffs, args.enable_debug_statements)
-            conn.close()
+            with connect(args.config.get_libpq_dsn()) as conn:
+                refresh.create_functions(conn, args.config, args.sqllib_dir,
+                                         args.diffs, args.enable_debug_statements)
 
         if args.wiki_data:
             run_legacy_script('setup.php', '--import-wikipedia-articles',
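diff --git a/nominatim/clicmd/replication.py b/nominatim/clicmd/replication.py
index 2a19e6cdad8c74c594be404bb9507c119187eaaa..e766be2be7848ed9e34c6f095942cdfdd3b33932 100644
--- a/nominatim/clicmd/replication.py
+++ b/nominatim/clicmd/replication.py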
@@ -62,13 +62,12 @@ class UpdateReplication:
         from ..tools import replication, refresh
 
         LOG.warning("Initialising replication updates")
-        conn = connect(args.config.get_libpq_dsn())
-        replication.init_replication(conn, base_url=args.config.REPLICATION_URL)
-        if args.update_functions:
-            LOG.warning("Create functions")
-            refresh.create_functions(conn, args.config, args.sqllib_dir,
-                                     True, False)
-        conn.close()
+        with connect(args.config.get_libpq_dsn()) as conn:
+            replication.init_replication(conn, base_url=args.config.REPLICATION_URL)
+            if args.update_functions:
+                LOG.warning("Create functions")
+                refresh.create_functions(conn, args.config, args.sqllib_dir,
+                                         True, False)
         return 0
 
 
@@ -76,10 +75,8 @@ class UpdateReplication:
     def _check_for_updates(args):
         from ..tools import replication
 
-        conn = connect(args.config.get_libpq_dsn())
-        ret = replication.check_for_updates(conn, base_url=args.config.REPLICATION_URL)
-        conn.close()
-        return ret
+        with connect(args.config.get_libpq_dsn()) as conn:
+            return replication.check_for_updates(conn, base_url=args.config.REPLICATION_URL)
 
     @staticmethod
     def _report_update(batchdate, start_import, start_index):
@@ -122,13 +119,12 @@ class UpdateReplication:
             recheck_interval = args.config.get_int('REPLICATION_RECHECK_INTERVAL')
 
         while True:
-            conn = connect(args.config.get_libpq_dsn())
-            start = dt.datetime.now(dt.timezone.utc)
-            state = replication.update(conn, params)
-            if state is not replication.UpdateState.NO_CHANGES:
-                status.log_status(conn, start, 'import')
-            batchdate, _, _ = status.get_status(conn)
-            conn.close()
+            with connect(args.config.get_libpq_dsn()) as conn:
+                start = dt.datetime.now(dt.timezone.utc)
+                state = replication.update(conn, params)
+                if state is not replication.UpdateState.NO_CHANGES:
+                    status.log_status(conn, start, 'import')
+                batchdate, _, _ = status.get_status(conn)
 
             if state is not replication.UpdateState.NO_CHANGES and args.do_index:
                 index_start = dt.datetime.now(dt.timezone.utc)
@@ -137,10 +133,9 @@ class UpdateReplication:
                 indexer.index_boundaries(0, 30)
                 indexer.index_by_rank(0, 30)
 
-                conn = connect(args.config.get_libpq_dsn())
-                status.set_indexed(conn, True)
-                status.log_status(conn, index_start, 'index')
-                conn.close()
+                with connect(args.config.get_libpq_dsn()) as conn:
+                    status.set_indexed(conn, True)
+                    status.log_status(conn, index_start, 'index')
             else:
                 index_start = None
 
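diff --git a/nominatim/db/connection.py b/nominatim/db/connection.py
index b941f46f56c63c74444506dec457bf09a8999c07..6bd81a2ff53d2025e164169a12ddd9bbaa88edab 100644
--- a/nominatim/db/connection.py
+++ b/nominatim/db/connection.py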
@@ -1,6 +1,7 @@
 """
 Specialised connection and cursor functions.
 """
+import contextlib
 import logging
 
 import psycopg2
@@ -84,9 +85,14 @@ class _Connection(psycopg2.extensions.connection):
 
 def connect(dsn):
     """ Open a connection to the database using the specialised connection
-        factory.
+        factory. The returned object may be used in conjunction with 'with'.
+        When used outside a context manager, use the `connection` attribute
+        to get the connection.
     """
     try:
-        return psycopg2.connect(dsn, connection_factory=_Connection)
+        conn = psycopg2.connect(dsn, connection_factory=_Connection)
+        ctxmgr = contextlib.closing(conn)
+        ctxmgr.connection = conn
+        return ctxmgr
     except psycopg2.OperationalError as err:
         raise UsageError("Cannot connect to database: {}".format(err)) from err
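diff --git a/nominatim/tools/check_database.py b/nominatim/tools/check_database.py
index 7b8da200b5a8598416bb9ef78f76dc3eda8347a5..d8ab08ccd7a8593bdf149235a9e91c36482bebc7 100644
--- a/nominatim/tools/check_database.py
+++ b/nominatim/tools/check_database.py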
@@ -60,7 +60,7 @@ def check_database(config):
     """ Run a number of checks on the database and return the status.
     """
     try:
-        conn = connect(config.get_libpq_dsn())
+        conn = connect(config.get_libpq_dsn()).connection
     except UsageError as err:
         conn = _BadConnection(str(err))
 
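diff --git a/test/python/conftest.py b/test/python/conftest.py
index 72a56dcff581bb123ee29855589352cf3eeee47b..0e0e808cb51f953e871e199b83e7732fb1caba92 100644
--- a/test/python/conftest.py
+++ b/test/python/conftest.py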
@@ -85,9 +85,8 @@ def temp_db_with_extensions(temp_db):
 def temp_db_conn(temp_db):
     """ Connection to the test database.
     """
-    conn = connection.connect('dbname=' + temp_db)
-    yield conn
-    conn.close()
+    with connection.connect('dbname=' + temp_db) as conn:
+        yield conn
 
 
 @pytest.fixture
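diff --git a/test/python/test_db_connection.py b/test/python/test_db_connection.py
index 11ad691aa64e3ab7bef6b182936a2ad57529d670..846ef864db853f102d3c02667ba28272eb5e2dc7 100644
--- a/test/python/test_db_connection.py
+++ b/test/python/test_db_connection.py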
@@ -7,9 +7,8 @@ from nominatim.db.connection import connect
 
 @pytest.fixture
 def db(temp_db):
-    conn = connect('dbname=' + temp_db)
-    yield conn
-    conn.close()
+    with connect('dbname=' + temp_db) as conn:
+        yield conn
 
 
 def test_connection_table_exists(db, temp_db_cursor):
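diff --git a/test/python/test_tools_admin.py b/test/python/test_tools_admin.py
index a40a17dbb950f2f9fbf1f82d19bbaa164cfd81d8..36c7d6ff0365ecc9e43a4c0a5ecd3efe6e994276 100644
--- a/test/python/test_tools_admin.py
+++ b/test/python/test_tools_admin.py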
@@ -9,9 +9,8 @@ from nominatim.tools import admin
 
 @pytest.fixture
 def db(temp_db, placex_table):
-    conn = connect('dbname=' + temp_db)
-    yield conn
-    conn.close()
+    with connect('dbname=' + temp_db) as conn:
+        yield conn
 
 def test_analyse_indexing_no_objects(db):
     with pytest.raises(UsageError):
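diff --git a/test/python/test_tools_check_database.py b/test/python/test_tools_check_database.py
index 0b5c23a6d05e0e00d1492b76019122686abb0d56..3787c3be16e9351fd4ccedb5a7d98e698a3b358b 100644
--- a/test/python/test_tools_check_database.py
+++ b/test/python/test_tools_check_database.py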
@@ -10,6 +10,10 @@ def test_check_database_unknown_db(def_config, monkeypatch):
     assert 1 == chkdb.check_database(def_config)
 
 
+def test_check_database_fatal_test(def_config, temp_db):
+    assert 1 == chkdb.check_database(def_config)
+
+
 def test_check_conection_good(temp_db_conn, def_config):
     assert chkdb.check_connection(temp_db_conn, def_config) == chkdb.CheckState.OK
 
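diff --git a/test/python/test_tools_refresh_create_functions.py b/test/python/test_tools_refresh_create_functions.py
index d219d74864f4af628e40c9a66738213ecd854ad4..ac2f221126f79c4da85249a8e5facb9c7bcd6f68 100644
--- a/test/python/test_tools_refresh_create_functions.py
+++ b/test/python/test_tools_refresh_create_functions.py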
@@ -11,9 +11,8 @@ SQL_DIR = (Path(__file__) / '..' / '..' / '..' / 'lib-sql').resolve()
 
 @pytest.fixture
 def db(temp_db):
-    conn = connect('dbname=' + temp_db)
-    yield conn
-    conn.close()
+    with connect('dbname=' + temp_db) as conn:
+        yield conn
 
 @pytest.fixture
 def db_with_tables(db):