]> git.openstreetmap.org Git - nominatim.git/commitdiff
switch to threading
authorSarah Hoffmann <lonvia@denofr.de>
Sun, 19 Jan 2020 20:56:20 +0000 (21:56 +0100)
committerSarah Hoffmann <lonvia@denofr.de>
Fri, 24 Jan 2020 21:06:30 +0000 (22:06 +0100)
nominatim/nominatim.py

index 6b25cf5c0623217c7cddebf3de468f2b0ec69658..619070604e1be63af9e9e881bbb910117bf1d6f4 100644 (file)
@@ -30,7 +30,8 @@ import getpass
 from datetime import datetime
 import psycopg2
 from psycopg2.extras import wait_select
-import select
+import threading
+from queue import Queue
 
 log = logging.getLogger()
 
@@ -39,53 +40,44 @@ def make_connection(options, asynchronous=False):
                             password=options.password, host=options.host,
                             port=options.port, async_=asynchronous)
 
-class IndexingThread(object):
+class IndexingThread(threading.Thread):
 
-    def __init__(self, thread_num, options):
-        log.debug("Creating thread {}".format(thread_num))
-        self.thread_num = thread_num
-        self.conn = make_connection(options, asynchronous=True)
-        self.wait()
+    def __init__(self, queue, barrier, options):
+        super().__init__()
+        self.conn = make_connection(options)
+        self.conn.autocommit = True
 
         self.cursor = self.conn.cursor()
         self.perform("SET lc_messages TO 'C'")
-        self.wait()
         self.perform(InterpolationRunner.prepare())
-        self.wait()
         self.perform(RankRunner.prepare())
-        self.wait()
-
-        self.current_query = None
-        self.current_params = None
+        self.queue = queue
+        self.barrier = barrier
 
-    def wait(self):
-        wait_select(self.conn)
-        self.current_query = None
+    def run(self):
+        sql = None
+        while True:
+            item = self.queue.get()
+            if item is None:
+                break
+            elif isinstance(item, str):
+                sql = item
+                self.barrier.wait()
+            else:
+                self.perform(sql, (item,))
 
     def perform(self, sql, args=None):
-        self.current_query = sql
-        self.current_params = args
-        self.cursor.execute(sql, args)
-
-    def fileno(self):
-        return self.conn.fileno()
-
-    def is_done(self):
-        if self.current_query is None:
-            return True
-
-        try:
-            if self.conn.poll() == psycopg2.extensions.POLL_OK:
-                self.current_query = None
-                return True
-        except psycopg2.extensions.TransactionRollbackError as e:
-            if e.pgcode is None:
-                raise RuntimeError("Postgres exception has no error code")
-            if e.pgcode == '40P01':
-                log.info("Deadlock detected, retry.")
-                self.cursor.execute(self.current_query, self.current_params)
-            else:
-                raise
+        while True:
+            try:
+                self.cursor.execute(sql, args)
+                return
+            except psycopg2.extensions.TransactionRollbackError as e:
+                if e.pgcode is None:
+                    raise RuntimeError("Postgres exception has no error code")
+                if e.pgcode == '40P01':
+                    log.info("Deadlock detected, retry.")
+                else:
+                    raise
 
 
 
@@ -96,11 +88,12 @@ class Indexer(object):
         self.conn = make_connection(options)
 
         self.threads = []
-        self.poll = select.poll()
+        self.queue = Queue(maxsize=1000)
+        self.barrier = threading.Barrier(options.threads + 1)
         for i in range(options.threads):
-            t = IndexingThread(i, options)
+            t = IndexingThread(self.queue, self.barrier, options)
             self.threads.append(t)
-            self.poll.register(t, select.EPOLLIN)
+            t.start()
 
     def run(self):
         log.info("Starting indexing rank ({} to {}) using {} threads".format(
@@ -114,9 +107,20 @@ class Indexer(object):
             self.index(InterpolationRunner())
             self.index(RankRunner(30))
 
+        self.queue_all(None)
+        for t in self.threads:
+            t.join()
+
+    def queue_all(self, item):
+        for t in self.threads:
+            self.queue.put(item)
+
     def index(self, obj):
         log.info("Starting {}".format(obj.name()))
 
+        self.queue_all(obj.sql_index_place())
+        self.barrier.wait()
+
         cur = self.conn.cursor(name="main")
         cur.execute(obj.sql_index_sectors())
 
@@ -127,7 +131,6 @@ class Indexer(object):
 
         cur.scroll(0, mode='absolute')
 
-        next_thread = self.find_free_thread()
         done_tuples = 0
         rank_start_time = datetime.now()
         for r in cur:
@@ -146,9 +149,8 @@ class Indexer(object):
             for place in pcur:
                 place_id = place[0]
                 log.debug("Processing place {}".format(place_id))
-                thread = next(next_thread)
 
-                thread.perform(obj.sql_index_place(), (place_id,))
+                self.queue.put(place_id)
                 done_tuples += 1
 
             pcur.close()
@@ -158,8 +160,8 @@ class Indexer(object):
 
         cur.close()
 
-        for t in self.threads:
-            t.wait()
+        self.queue_all("")
+        self.barrier.wait()
 
         rank_end_time = datetime.now()
         diff_seconds = (rank_end_time-rank_start_time).total_seconds()
@@ -168,22 +170,6 @@ class Indexer(object):
                  done_tuples, int(diff_seconds),
                  done_tuples/diff_seconds, obj.name()))
 
-    def find_free_thread(self):
-        thread_lookup = { t.fileno() : t for t in self.threads}
-
-        done_fids = [ t.fileno() for t in self.threads ]
-
-        while True:
-            for fid in done_fids:
-                thread = thread_lookup[fid]
-                if thread.is_done():
-                    yield thread
-                else:
-                    print("not good", fid)
-
-            done_fids = [ x[0] for x in self.poll.poll()]
-
-        assert(False, "Unreachable code")
 
 class RankRunner(object):