git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/indexer/indexer.py
factor out async connection handling into separate class
[nominatim.git] / nominatim / indexer / indexer.py
index ebc9803870f3748cec0e28fa8a8f3943ec62964c..a064b28580168cca6cb26a32bc1e398ea3f43442 100644 (file)
@@ -4,39 +4,83 @@ Main work horse for indexing (computing addresses) the database.
 import logging
 import select
 
 import logging
 import select
 
-import psycopg2
-
 from nominatim.indexer.progress import ProgressLogger
 from nominatim.indexer import runners
 from nominatim.db.async_connection import DBConnection
 from nominatim.indexer.progress import ProgressLogger
 from nominatim.indexer import runners
 from nominatim.db.async_connection import DBConnection
+from nominatim.db.connection import connect
 
 LOG = logging.getLogger()
 
 
 LOG = logging.getLogger()
 
class WorkerPool:
    """ A pool of asynchronous database connections.

        The pool may be used as a context manager; leaving the context
        closes all connections without waiting for pending queries
        (call `finish_all()` first for that).
    """
    # Once more than this many connections have been handed out since the
    # last refresh, all connections are drained and reopened.
    REOPEN_CONNECTIONS_AFTER = 100000


    def __init__(self, dsn, pool_size):
        """ Open `pool_size` asynchronous connections to the database
            described by `dsn`.
        """
        self.threads = [DBConnection(dsn) for _ in range(pool_size)]
        self.free_workers = self._yield_free_worker()


    def finish_all(self):
        """ Wait for all connections to finish their current query.

            Afterwards the free-worker generator is restarted, so the
            command count used for reopening connections starts over.
        """
        for thread in self.threads:
            while not thread.is_done():
                thread.wait()

        self.free_workers = self._yield_free_worker()


    def close(self):
        """ Close all connections and clear the pool.

            No workers can be requested from the pool after this call.
        """
        for thread in self.threads:
            thread.close()
        self.threads = []
        self.free_workers = None


    def next_free_worker(self):
        """ Return the next connection that is free to accept a query.

            Does not return until one of the connections reports itself
            as done.
        """
        return next(self.free_workers)


    def _yield_free_worker(self):
        """ Generator yielding connections that are ready for a new query.

            When more than REOPEN_CONNECTIONS_AFTER connections have been
            handed out since the last refresh, all connections are waited
            for and reconnected at the end of the current round.
        """
        ready = self.threads
        command_stat = 0
        while True:
            for thread in ready:
                if thread.is_done():
                    command_stat += 1
                    yield thread

            if command_stat > self.REOPEN_CONNECTIONS_AFTER:
                # Refresh the connections occasionally to avoid potential
                # memory leaks in PostgreSQL.
                for thread in self.threads:
                    while not thread.is_done():
                        thread.wait()
                    thread.connect()
                ready = self.threads
                # Restart the count; without this every following round
                # would drain and reopen all connections again.
                command_stat = 0
            else:
                # Wait until at least one connection socket is writable,
                # then re-check the returned connections with is_done().
                _, ready, _ = select.select([], self.threads, [])


    def __enter__(self):
        return self


    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
+
+
+class Indexer:
+    """ Main indexing routine.
+    """
+
+    def __init__(self, dsn, num_threads):
+        self.dsn = dsn
+        self.num_threads = num_threads
 
 
     def index_full(self, analyse=True):
 
 
     def index_full(self, analyse=True):
@@ -45,31 +89,31 @@ class Indexer:
             database will be analysed at the appropriate places to
             ensure that database statistics are updated.
         """
             database will be analysed at the appropriate places to
             ensure that database statistics are updated.
         """
-        with psycopg2.connect(self.dsn) as conn:
+        with connect(self.dsn) as conn:
             conn.autocommit = True
 
             if analyse:
             conn.autocommit = True
 
             if analyse:
-                def _analyse():
+                def _analyze():
                     with conn.cursor() as cur:
                     with conn.cursor() as cur:
-                        cur.execute('ANALYSE')
+                        cur.execute('ANALYZE')
             else:
             else:
-                def _analyse():
+                def _analyze():
                     pass
 
             self.index_by_rank(0, 4)
                     pass
 
             self.index_by_rank(0, 4)
-            _analyse()
+            _analyze()
 
             self.index_boundaries(0, 30)
 
             self.index_boundaries(0, 30)
-            _analyse()
+            _analyze()
 
             self.index_by_rank(5, 25)
 
             self.index_by_rank(5, 25)
-            _analyse()
+            _analyze()
 
             self.index_by_rank(26, 30)
 
             self.index_by_rank(26, 30)
-            _analyse()
+            _analyze()
 
             self.index_postcodes()
 
             self.index_postcodes()
-            _analyse()
+            _analyze()
 
 
     def index_boundaries(self, minrank, maxrank):
 
 
     def index_boundaries(self, minrank, maxrank):
@@ -78,13 +122,8 @@ class Indexer:
         LOG.warning("Starting indexing boundaries using %s threads",
                     self.num_threads)
 
         LOG.warning("Starting indexing boundaries using %s threads",
                     self.num_threads)
 
-        self._setup_connections()
-
-        try:
-            for rank in range(max(minrank, 4), min(maxrank, 26)):
-                self.index(runners.BoundaryRunner(rank))
-        finally:
-            self._close_connections()
+        for rank in range(max(minrank, 4), min(maxrank, 26)):
+            self._index(runners.BoundaryRunner(rank))
 
     def index_by_rank(self, minrank, maxrank):
         """ Index all entries of placex in the given rank range (inclusive)
 
     def index_by_rank(self, minrank, maxrank):
         """ Index all entries of placex in the given rank range (inclusive)
@@ -97,20 +136,15 @@ class Indexer:
         LOG.warning("Starting indexing rank (%i to %i) using %i threads",
                     minrank, maxrank, self.num_threads)
 
         LOG.warning("Starting indexing rank (%i to %i) using %i threads",
                     minrank, maxrank, self.num_threads)
 
-        self._setup_connections()
-
-        try:
-            for rank in range(max(1, minrank), maxrank):
-                self.index(runners.RankRunner(rank))
+        for rank in range(max(1, minrank), maxrank):
+            self._index(runners.RankRunner(rank))
 
 
-            if maxrank == 30:
-                self.index(runners.RankRunner(0))
-                self.index(runners.InterpolationRunner(), 20)
-                self.index(runners.RankRunner(30), 20)
-            else:
-                self.index(runners.RankRunner(maxrank))
-        finally:
-            self._close_connections()
+        if maxrank == 30:
+            self._index(runners.RankRunner(0))
+            self._index(runners.InterpolationRunner(), 20)
+            self._index(runners.RankRunner(30), 20)
+        else:
+            self._index(runners.RankRunner(maxrank))
 
 
     def index_postcodes(self):
 
 
     def index_postcodes(self):
@@ -118,89 +152,52 @@ class Indexer:
         """
         LOG.warning("Starting indexing postcodes using %s threads", self.num_threads)
 
         """
         LOG.warning("Starting indexing postcodes using %s threads", self.num_threads)
 
-        self._setup_connections()
+        self._index(runners.PostcodeRunner(), 20)
 
 
-        try:
-            self.index(runners.PostcodeRunner(), 20)
-        finally:
-            self._close_connections()
 
     def update_status_table(self):
         """ Update the status in the status table to 'indexed'.
         """
 
     def update_status_table(self):
         """ Update the status in the status table to 'indexed'.
         """
-        conn = psycopg2.connect(self.dsn)
-
-        try:
+        with connect(self.dsn) as conn:
             with conn.cursor() as cur:
                 cur.execute('UPDATE import_status SET indexed = true')
 
             conn.commit()
             with conn.cursor() as cur:
                 cur.execute('UPDATE import_status SET indexed = true')
 
             conn.commit()
-        finally:
-            conn.close()
 
 
-    def index(self, obj, batch=1):
-        """ Index a single rank or table. `obj` describes the SQL to use
+    def _index(self, runner, batch=1):
+        """ Index a single rank or table. `runner` describes the SQL to use
             for indexing. `batch` describes the number of objects that
             should be processed with a single SQL statement
         """
             for indexing. `batch` describes the number of objects that
             should be processed with a single SQL statement
         """
-        LOG.warning("Starting %s (using batch size %s)", obj.name(), batch)
-
-        cur = self.conn.cursor()
-        cur.execute(obj.sql_count_objects())
+        LOG.warning("Starting %s (using batch size %s)", runner.name(), batch)
 
 
-        total_tuples = cur.fetchone()[0]
-        LOG.debug("Total number of rows: %i", total_tuples)
+        with connect(self.dsn) as conn:
+            with conn.cursor() as cur:
+                total_tuples = cur.scalar(runner.sql_count_objects())
+                LOG.debug("Total number of rows: %i", total_tuples)
 
 
-        cur.close()
+            conn.commit()
 
 
-        progress = ProgressLogger(obj.name(), total_tuples)
+            progress = ProgressLogger(runner.name(), total_tuples)
 
 
-        if total_tuples > 0:
-            cur = self.conn.cursor(name='places')
-            cur.execute(obj.sql_get_objects())
+            if total_tuples > 0:
+                with conn.cursor(name='places') as cur:
+                    cur.execute(runner.sql_get_objects())
 
 
-            next_thread = self.find_free_thread()
-            while True:
-                places = [p[0] for p in cur.fetchmany(batch)]
-                if not places:
-                    break
+                    with WorkerPool(self.dsn, self.num_threads) as pool:
+                        while True:
+                            places = [p[0] for p in cur.fetchmany(batch)]
+                            if not places:
+                                break
 
 
-                LOG.debug("Processing places: %s", str(places))
-                thread = next(next_thread)
+                            LOG.debug("Processing places: %s", str(places))
+                            worker = pool.next_free_worker()
 
 
-                thread.perform(obj.sql_index_place(places))
-                progress.add(len(places))
+                            worker.perform(runner.sql_index_place(places))
+                            progress.add(len(places))
 
 
-            cur.close()
+                        pool.finish_all()
 
 
-            for thread in self.threads:
-                thread.wait()
+                conn.commit()
 
         progress.done()
 
         progress.done()
-
-    def find_free_thread(self):
-        """ Generator that returns the next connection that is free for
-            sending a query.
-        """
-        ready = self.threads
-        command_stat = 0
-
-        while True:
-            for thread in ready:
-                if thread.is_done():
-                    command_stat += 1
-                    yield thread
-
-            # refresh the connections occasionaly to avoid potential
-            # memory leaks in Postgresql.
-            if command_stat > 100000:
-                for thread in self.threads:
-                    while not thread.is_done():
-                        thread.wait()
-                    thread.connect()
-                command_stat = 0
-                ready = self.threads
-            else:
-                ready, _, _ = select.select(self.threads, [], [])
-
-        assert False, "Unreachable code"