# nominatim/indexer/indexer.py
1 """
2 Main work horse for indexing (computing addresses) the database.
3 """
4 import logging
5 import select
6 import time
7
8 import psycopg2.extras
9
10 from nominatim.indexer.progress import ProgressLogger
11 from nominatim.indexer import runners
12 from nominatim.db.async_connection import DBConnection
13 from nominatim.db.connection import connect
14
15 LOG = logging.getLogger()
16
17 class WorkerPool:
18     """ A pool of asynchronous database connections.
19
20         The pool may be used as a context manager.
21     """
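
    # Usage sketch (the DSN is hypothetical; any libpq connection
    # string works):
    #
    #   with WorkerPool('dbname=nominatim', pool_size=4) as pool:
    #       worker = pool.next_free_worker()
    #       ...  # send a command on the worker connection
    #       pool.finish_all()
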
    REOPEN_CONNECTIONS_AFTER = 100000

    def __init__(self, dsn, pool_size):
        self.threads = [DBConnection(dsn) for _ in range(pool_size)]
        self.free_workers = self._yield_free_worker()


    def finish_all(self):
        """ Wait for all connections to finish.
        """
        for thread in self.threads:
            while not thread.is_done():
                thread.wait()

        self.free_workers = self._yield_free_worker()

    def close(self):
        """ Close all connections and clear the pool.
        """
        for thread in self.threads:
            thread.close()
        self.threads = []
        self.free_workers = None


    def next_free_worker(self):
        """ Get the next free connection.
        """
        return next(self.free_workers)


    def _yield_free_worker(self):
        """ Generator that yields the next connection that is ready
            to receive a command. After roughly REOPEN_CONNECTIONS_AFTER
            issued commands, all connections are drained and reopened.
        """
        ready = self.threads
        command_stat = 0
        while True:
            for thread in ready:
                if thread.is_done():
                    command_stat += 1
                    yield thread

            if command_stat > self.REOPEN_CONNECTIONS_AFTER:
                # Enough commands processed: wait for all pending work
                # to finish, then reconnect every thread.
                for thread in self.threads:
                    while not thread.is_done():
                        thread.wait()
                    thread.connect()
                ready = self.threads
                command_stat = 0
            else:
                # Block until at least one connection becomes writable,
                # i.e. is likely done with its previous command.
                _, ready, _ = select.select([], self.threads, [])


    def __enter__(self):
        return self


    def __exit__(self, exc_type, exc_value, traceback):
        self.close()


class Indexer:
    """ Main indexing routine.
    """
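
    # Usage sketch ('tokenizer' is assumed to be a tokenizer object
    # providing the name_analyzer() context manager used below; the DSN
    # is hypothetical):
    #
    #   indexer = Indexer('dbname=nominatim', tokenizer, num_threads=4)
    #   indexer.index_full()
    #   indexer.update_status_table()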

    def __init__(self, dsn, tokenizer, num_threads):
        self.dsn = dsn
        self.tokenizer = tokenizer
        self.num_threads = num_threads


    def index_full(self, analyse=True):
        """ Index the complete database. This will first index boundaries
            followed by all other objects. When `analyse` is True, then the
            database will be analysed at the appropriate places to
            ensure that database statistics are updated.
        """
        with connect(self.dsn) as conn:
            conn.autocommit = True

            if analyse:
                def _analyze():
                    with conn.cursor() as cur:
                        cur.execute('ANALYZE')
            else:
                def _analyze():
                    pass

            self.index_by_rank(0, 4)
            _analyze()

            self.index_boundaries(0, 30)
            _analyze()

            self.index_by_rank(5, 25)
            _analyze()

            self.index_by_rank(26, 30)
            _analyze()

            self.index_postcodes()
            _analyze()


    def index_boundaries(self, minrank, maxrank):
        """ Index only administrative boundaries within the given rank range.
        """
        LOG.warning("Starting indexing boundaries using %s threads",
                    self.num_threads)

        with self.tokenizer.name_analyzer() as analyzer:
            for rank in range(max(minrank, 4), min(maxrank, 26)):
                self._index(runners.BoundaryRunner(rank, analyzer))

    def index_by_rank(self, minrank, maxrank):
        """ Index all entries of placex in the given rank range (inclusive)
            in order of their address rank.

            When rank 30 is requested then also interpolations and
            places with address rank 0 will be indexed.
        """
        maxrank = min(maxrank, 30)
        LOG.warning("Starting indexing rank (%i to %i) using %i threads",
                    minrank, maxrank, self.num_threads)

        with self.tokenizer.name_analyzer() as analyzer:
            for rank in range(max(1, minrank), maxrank):
                self._index(runners.RankRunner(rank, analyzer))

            if maxrank == 30:
                self._index(runners.RankRunner(0, analyzer))
                self._index(runners.InterpolationRunner(analyzer), 20)
                self._index(runners.RankRunner(30, analyzer), 20)
            else:
                self._index(runners.RankRunner(maxrank, analyzer))


    def index_postcodes(self):
        """ Index the entries of the location_postcode table.
        """
        LOG.warning("Starting indexing postcodes using %s threads", self.num_threads)

        self._index(runners.PostcodeRunner(), 20)


    def update_status_table(self):
        """ Update the status in the status table to 'indexed'.
        """
        with connect(self.dsn) as conn:
            with conn.cursor() as cur:
                cur.execute('UPDATE import_status SET indexed = true')

            conn.commit()

    def _index(self, runner, batch=1):
        """ Index a single rank or table. `runner` describes the SQL to use
            for indexing. `batch` describes the number of objects that
            should be processed with a single SQL statement.
        """
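        # Rough flow of this method: a server-side cursor ('places')
        # streams the ids of the objects to index, a separate 'fetcher'
        # connection loads the full place details for the next batch
        # asynchronously, and the WorkerPool indexes the current batch
        # in parallel.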
        LOG.warning("Starting %s (using batch size %s)", runner.name(), batch)

        with connect(self.dsn) as conn:
            psycopg2.extras.register_hstore(conn)
            with conn.cursor() as cur:
                total_tuples = cur.scalar(runner.sql_count_objects())
                LOG.debug("Total number of rows: %i", total_tuples)

                # The hstore OIDs need to be fetched manually because
                # register_hstore cannot look them up on the asynchronous
                # connections used below.
                hstore_oid = cur.scalar("SELECT 'hstore'::regtype::oid")
                hstore_array_oid = cur.scalar("SELECT 'hstore[]'::regtype::oid")

            conn.commit()

            progress = ProgressLogger(runner.name(), total_tuples)

            fetcher_wait = 0
            pool_wait = 0

            if total_tuples > 0:
                with conn.cursor(name='places') as cur:
                    cur.execute(runner.sql_get_objects())

                    fetcher = DBConnection(self.dsn, cursor_factory=psycopg2.extras.DictCursor)
                    psycopg2.extras.register_hstore(fetcher.conn,
                                                    oid=hstore_oid,
                                                    array_oid=hstore_array_oid)

                    with WorkerPool(self.dsn, self.num_threads) as pool:
                        places = self._fetch_next_batch(cur, fetcher, runner)
                        while places is not None:
                            if not places:
                                t0 = time.time()
                                fetcher.wait()
                                fetcher_wait += time.time() - t0
                                places = fetcher.cursor.fetchall()
                            # asynchronously get the next batch
                            next_places = self._fetch_next_batch(cur, fetcher, runner)

                            # and index the current batch
                            for idx in range(0, len(places), batch):
                                t0 = time.time()
                                worker = pool.next_free_worker()
                                pool_wait += time.time() - t0
                                part = places[idx:idx+batch]
                                LOG.debug("Processing places: %s", str(part))
                                runner.index_places(worker, part)
                                progress.add(len(part))

                            places = next_places

                        pool.finish_all()

                    fetcher.wait()
                    fetcher.close()

                conn.commit()

        progress.done()
        LOG.warning("Wait time: fetcher: %.2fs, pool: %.2fs",
                    fetcher_wait, pool_wait)


    def _fetch_next_batch(self, cur, fetcher, runner):
        """ Get the next batch of objects to index from the places cursor.

            Returns None when there is nothing left to index, an empty
            list when the batch is still being loaded asynchronously via
            the fetcher connection, and the list of rows otherwise.
        """
        ids = cur.fetchmany(100)

        if not ids:
            return None

        if not hasattr(runner, 'get_place_details'):
            return ids

        runner.get_place_details(fetcher, ids)
        return []
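

# The runner objects handed to Indexer._index() come from
# nominatim.indexer.runners. Judging purely from how they are used above,
# a minimal runner can be sketched like this (hypothetical class, not the
# real implementation; the SQL and method bodies are assumptions):
#
#   class ExampleRunner:
#       def name(self):
#           return 'example objects'    # label used for logging
#
#       def sql_count_objects(self):
#           return 'SELECT count(*) FROM placex WHERE indexed_status > 0'
#
#       def sql_get_objects(self):
#           return 'SELECT place_id FROM placex WHERE indexed_status > 0'
#
#       def index_places(self, worker, places):
#           ...  # send the indexing SQL for `places` on the async worker
#
#       # Optional. When present, _fetch_next_batch() calls it to load
#       # full place details asynchronously on the fetcher connection.
#       def get_place_details(self, worker, ids):
#           ...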