git.openstreetmap.org Git - nominatim.git/commitdiff
Revert "switch to threading"
author Sarah Hoffmann <lonvia@denofr.de>
Sun, 19 Jan 2020 20:56:37 +0000 (21:56 +0100)
committer Sarah Hoffmann <lonvia@denofr.de>
Fri, 24 Jan 2020 21:06:30 +0000 (22:06 +0100)
This reverts commit 8b1c2181be5aa5335c68d36a49cab9c4e2cd8bef.

nominatim/nominatim.py

index 619070604e1be63af9e9e881bbb910117bf1d6f4..6b25cf5c0623217c7cddebf3de468f2b0ec69658 100644 (file)
@@ -30,8 +30,7 @@ import getpass
 from datetime import datetime
 import psycopg2
 from psycopg2.extras import wait_select
-import threading
-from queue import Queue
+import select
 
 log = logging.getLogger()
 
@@ -40,44 +39,53 @@ def make_connection(options, asynchronous=False):
                             password=options.password, host=options.host,
                             port=options.port, async_=asynchronous)
 
-class IndexingThread(threading.Thread):
+class IndexingThread(object):
 
-    def __init__(self, queue, barrier, options):
-        super().__init__()
-        self.conn = make_connection(options)
-        self.conn.autocommit = True
+    def __init__(self, thread_num, options):
+        log.debug("Creating thread {}".format(thread_num))
+        self.thread_num = thread_num
+        self.conn = make_connection(options, asynchronous=True)
+        self.wait()
 
         self.cursor = self.conn.cursor()
         self.perform("SET lc_messages TO 'C'")
+        self.wait()
         self.perform(InterpolationRunner.prepare())
+        self.wait()
         self.perform(RankRunner.prepare())
-        self.queue = queue
-        self.barrier = barrier
+        self.wait()
 
-    def run(self):
-        sql = None
-        while True:
-            item = self.queue.get()
-            if item is None:
-                break
-            elif isinstance(item, str):
-                sql = item
-                self.barrier.wait()
-            else:
-                self.perform(sql, (item,))
+        self.current_query = None
+        self.current_params = None
+
+    def wait(self):
+        wait_select(self.conn)
+        self.current_query = None
 
     def perform(self, sql, args=None):
-        while True:
-            try:
-                self.cursor.execute(sql, args)
-                return
-            except psycopg2.extensions.TransactionRollbackError as e:
-                if e.pgcode is None:
-                    raise RuntimeError("Postgres exception has no error code")
-                if e.pgcode == '40P01':
-                    log.info("Deadlock detected, retry.")
-                else:
-                    raise
+        self.current_query = sql
+        self.current_params = args
+        self.cursor.execute(sql, args)
+
+    def fileno(self):
+        return self.conn.fileno()
+
+    def is_done(self):
+        if self.current_query is None:
+            return True
+
+        try:
+            if self.conn.poll() == psycopg2.extensions.POLL_OK:
+                self.current_query = None
+                return True
+        except psycopg2.extensions.TransactionRollbackError as e:
+            if e.pgcode is None:
+                raise RuntimeError("Postgres exception has no error code")
+            if e.pgcode == '40P01':
+                log.info("Deadlock detected, retry.")
+                self.cursor.execute(self.current_query, self.current_params)
+            else:
+                raise
 
 
 
@@ -88,12 +96,11 @@ class Indexer(object):
         self.conn = make_connection(options)
 
         self.threads = []
-        self.queue = Queue(maxsize=1000)
-        self.barrier = threading.Barrier(options.threads + 1)
+        self.poll = select.poll()
         for i in range(options.threads):
-            t = IndexingThread(self.queue, self.barrier, options)
+            t = IndexingThread(i, options)
             self.threads.append(t)
-            t.start()
+            self.poll.register(t, select.EPOLLIN)
 
     def run(self):
         log.info("Starting indexing rank ({} to {}) using {} threads".format(
@@ -107,20 +114,9 @@ class Indexer(object):
             self.index(InterpolationRunner())
             self.index(RankRunner(30))
 
-        self.queue_all(None)
-        for t in self.threads:
-            t.join()
-
-    def queue_all(self, item):
-        for t in self.threads:
-            self.queue.put(item)
-
     def index(self, obj):
         log.info("Starting {}".format(obj.name()))
 
-        self.queue_all(obj.sql_index_place())
-        self.barrier.wait()
-
         cur = self.conn.cursor(name="main")
         cur.execute(obj.sql_index_sectors())
 
@@ -131,6 +127,7 @@ class Indexer(object):
 
         cur.scroll(0, mode='absolute')
 
+        next_thread = self.find_free_thread()
         done_tuples = 0
         rank_start_time = datetime.now()
         for r in cur:
@@ -149,8 +146,9 @@ class Indexer(object):
             for place in pcur:
                 place_id = place[0]
                 log.debug("Processing place {}".format(place_id))
+                thread = next(next_thread)
 
-                self.queue.put(place_id)
+                thread.perform(obj.sql_index_place(), (place_id,))
                 done_tuples += 1
 
             pcur.close()
@@ -160,8 +158,8 @@ class Indexer(object):
 
         cur.close()
 
-        self.queue_all("")
-        self.barrier.wait()
+        for t in self.threads:
+            t.wait()
 
         rank_end_time = datetime.now()
         diff_seconds = (rank_end_time-rank_start_time).total_seconds()
@@ -170,6 +168,22 @@ class Indexer(object):
                  done_tuples, int(diff_seconds),
                  done_tuples/diff_seconds, obj.name()))
 
+    def find_free_thread(self):
+        thread_lookup = { t.fileno() : t for t in self.threads}
+
+        done_fids = [ t.fileno() for t in self.threads ]
+
+        while True:
+            for fid in done_fids:
+                thread = thread_lookup[fid]
+                if thread.is_done():
+                    yield thread
+                else:
+                    print("not good", fid)
+
+            done_fids = [ x[0] for x in self.poll.poll()]
+
+        assert(False, "Unreachable code")
 
 class RankRunner(object):