]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/indexer/indexer.py
make index() function private
[nominatim.git] / nominatim / indexer / indexer.py
index fa40334b7851b617f6069c0932c3f8c6b0a310d8..7b826d96182eb69100b84339a6c4df75079bb5fe 100644 (file)
@@ -13,12 +13,6 @@ from nominatim.db.async_connection import DBConnection
 LOG = logging.getLogger()
 
 
-def _analyse_db_if(conn, condition):
-    if condition:
-        with conn.cursor() as cur:
-            cur.execute('ANALYSE')
-
-
 class Indexer:
     """ Main indexing routine.
     """
@@ -51,26 +45,31 @@ class Indexer:
             database will be analysed at the appropriate places to
             ensure that database statistics are updated.
         """
-        conn = psycopg2.connect(self.dsn)
-        conn.autocommit = True
+        with psycopg2.connect(self.dsn) as conn:
+            conn.autocommit = True
+
+            if analyse:
+                def _analyse():
+                    with conn.cursor() as cur:
+                        cur.execute('ANALYSE')
+            else:
+                def _analyse():
+                    pass
 
-        try:
             self.index_by_rank(0, 4)
-            _analyse_db_if(conn, analyse)
+            _analyse()
 
             self.index_boundaries(0, 30)
-            _analyse_db_if(conn, analyse)
+            _analyse()
 
             self.index_by_rank(5, 25)
-            _analyse_db_if(conn, analyse)
+            _analyse()
 
             self.index_by_rank(26, 30)
-            _analyse_db_if(conn, analyse)
+            _analyse()
 
             self.index_postcodes()
-            _analyse_db_if(conn, analyse)
-        finally:
-            conn.close()
+            _analyse()
 
 
     def index_boundaries(self, minrank, maxrank):
@@ -83,7 +82,7 @@ class Indexer:
 
         try:
             for rank in range(max(minrank, 4), min(maxrank, 26)):
-                self.index(runners.BoundaryRunner(rank))
+                self._index(runners.BoundaryRunner(rank))
         finally:
             self._close_connections()
 
@@ -102,14 +101,14 @@ class Indexer:
 
         try:
             for rank in range(max(1, minrank), maxrank):
-                self.index(runners.RankRunner(rank))
+                self._index(runners.RankRunner(rank))
 
             if maxrank == 30:
-                self.index(runners.RankRunner(0))
-                self.index(runners.InterpolationRunner(), 20)
-                self.index(runners.RankRunner(30), 20)
+                self._index(runners.RankRunner(0))
+                self._index(runners.InterpolationRunner(), 20)
+                self._index(runners.RankRunner(30), 20)
             else:
-                self.index(runners.RankRunner(maxrank))
+                self._index(runners.RankRunner(maxrank))
         finally:
             self._close_connections()
 
@@ -122,7 +121,7 @@ class Indexer:
         self._setup_connections()
 
         try:
-            self.index(runners.PostcodeRunner(), 20)
+            self._index(runners.PostcodeRunner(), 20)
         finally:
             self._close_connections()
 
@@ -139,26 +138,26 @@ class Indexer:
         finally:
             conn.close()
 
-    def index(self, obj, batch=1):
-        """ Index a single rank or table. `obj` describes the SQL to use
+    def _index(self, runner, batch=1):
+        """ Index a single rank or table. `runner` describes the SQL to use
             for indexing. `batch` describes the number of objects that
             should be processed with a single SQL statement
         """
-        LOG.warning("Starting %s (using batch size %s)", obj.name(), batch)
+        LOG.warning("Starting %s (using batch size %s)", runner.name(), batch)
 
         cur = self.conn.cursor()
-        cur.execute(obj.sql_count_objects())
+        cur.execute(runner.sql_count_objects())
 
         total_tuples = cur.fetchone()[0]
         LOG.debug("Total number of rows: %i", total_tuples)
 
         cur.close()
 
-        progress = ProgressLogger(obj.name(), total_tuples)
+        progress = ProgressLogger(runner.name(), total_tuples)
 
         if total_tuples > 0:
             cur = self.conn.cursor(name='places')
-            cur.execute(obj.sql_get_objects())
+            cur.execute(runner.sql_get_objects())
 
             next_thread = self.find_free_thread()
             while True:
@@ -169,7 +168,7 @@ class Indexer:
                 LOG.debug("Processing places: %s", str(places))
                 thread = next(next_thread)
 
-                thread.perform(obj.sql_index_place(places))
+                thread.perform(runner.sql_index_place(places))
                 progress.add(len(places))
 
             cur.close()