"""
Main workhorse for indexing (computing addresses) the database.
"""
import logging
import select

import psycopg2.extras

from nominatim.indexer.progress import ProgressLogger
from nominatim.indexer import runners
from nominatim.db.async_connection import DBConnection
from nominatim.db.connection import connect

LOG = logging.getLogger()

class WorkerPool:
    """ A pool of asynchronous database connections.

        The pool may be used as a context manager.
    """
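    # A minimal usage sketch (the DSN is hypothetical; it assumes the
    # asynchronous perform() method of DBConnection for sending SQL):
    #
    #   with WorkerPool('dbname=nominatim', pool_size=4) as pool:
    #       for sql in statements:
    #           pool.next_free_worker().perform(sql)
    #       pool.finish_all()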
    # Reconnect all workers after this many commands have been handed out.
    REOPEN_CONNECTIONS_AFTER = 100000

    def __init__(self, dsn, pool_size):
        self.threads = [DBConnection(dsn) for _ in range(pool_size)]
        self.free_workers = self._yield_free_worker()

27
    def finish_all(self):
        """ Wait for all connections to finish.
        """
        for thread in self.threads:
            while not thread.is_done():
                thread.wait()

        self.free_workers = self._yield_free_worker()

    def close(self):
        """ Close all connections and clear the pool.
        """
        for thread in self.threads:
            thread.close()
        self.threads = []
        self.free_workers = None


    def next_free_worker(self):
        """ Get the next free connection.
        """
        return next(self.free_workers)


    def _yield_free_worker(self):
        ready = self.threads
        command_stat = 0
        while True:
            # Hand out connections that have finished their previous
            # command.
            for thread in ready:
                if thread.is_done():
                    command_stat += 1
                    yield thread

            if command_stat > self.REOPEN_CONNECTIONS_AFTER:
                # Wait for all workers to finish and reconnect them.
                for thread in self.threads:
                    while not thread.is_done():
                        thread.wait()
                    thread.connect()
                ready = self.threads
                command_stat = 0
            else:
                # Block until at least one connection is ready for the
                # next command.
                _, ready, _ = select.select([], self.threads, [])

71
    def __enter__(self):
        return self


    def __exit__(self, exc_type, exc_value, traceback):
        self.close()


class Indexer:
    """ Main indexing routine.
    """

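    # Typical use (a sketch; `tokenizer` is assumed to come from the
    # project's tokenizer factory and `dsn` to be a libpq connection
    # string):
    #
    #   indexer = Indexer(dsn, tokenizer, num_threads=4)
    #   indexer.index_full()
    #   indexer.update_status_table()
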
    def __init__(self, dsn, tokenizer, num_threads):
        self.dsn = dsn
        self.tokenizer = tokenizer
        self.num_threads = num_threads


    def index_full(self, analyse=True):
        """ Index the complete database. This will first index boundaries
            followed by all other objects. When `analyse` is True, the
            database will be analysed at the appropriate places to
            ensure that database statistics are updated.
        """
        with connect(self.dsn) as conn:
            conn.autocommit = True

            if analyse:
                def _analyze():
                    with conn.cursor() as cur:
                        cur.execute('ANALYZE')
            else:
                def _analyze():
                    pass

            self.index_by_rank(0, 4)
            _analyze()

            self.index_boundaries(0, 30)
            _analyze()

            self.index_by_rank(5, 25)
            _analyze()

            self.index_by_rank(26, 30)
            _analyze()

            self.index_postcodes()
            _analyze()


    def index_boundaries(self, minrank, maxrank):
        """ Index only administrative boundaries within the given rank range.
        """
        LOG.warning("Starting indexing boundaries using %s threads",
                    self.num_threads)

        with self.tokenizer.name_analyzer() as analyzer:
            for rank in range(max(minrank, 4), min(maxrank, 26)):
                self._index(runners.BoundaryRunner(rank, analyzer))

    def index_by_rank(self, minrank, maxrank):
        """ Index all entries of placex in the given rank range (inclusive)
            in order of their address rank.

            When rank 30 is requested, interpolations and places with
            address rank 0 will be indexed as well.
        """
        maxrank = min(maxrank, 30)
        LOG.warning("Starting indexing rank (%i to %i) using %i threads",
                    minrank, maxrank, self.num_threads)

        with self.tokenizer.name_analyzer() as analyzer:
            for rank in range(max(1, minrank), maxrank):
                self._index(runners.RankRunner(rank, analyzer))

            if maxrank == 30:
                self._index(runners.RankRunner(0, analyzer))
                self._index(runners.InterpolationRunner(analyzer), 20)
                self._index(runners.RankRunner(30, analyzer), 20)
            else:
                self._index(runners.RankRunner(maxrank, analyzer))


    def index_postcodes(self):
        """ Index the entries of the location_postcode table.
        """
        LOG.warning("Starting indexing postcodes using %s threads", self.num_threads)

        self._index(runners.PostcodeRunner(), 20)


    def update_status_table(self):
        """ Update the status in the status table to 'indexed'.
        """
        with connect(self.dsn) as conn:
            with conn.cursor() as cur:
                cur.execute('UPDATE import_status SET indexed = true')

            conn.commit()

    def _index(self, runner, batch=1):
        """ Index a single rank or table. `runner` describes the SQL to use
            for indexing. `batch` describes the number of objects that
            should be processed with a single SQL statement.
        """
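        # Overview of the connections used below: a server-side cursor
        # streams the ids of the objects to index, a dedicated
        # asynchronous connection prefetches place details for runners
        # that need them, and a pool of workers executes the actual
        # update statements.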
        LOG.warning("Starting %s (using batch size %s)", runner.name(), batch)

        with connect(self.dsn) as conn:
            psycopg2.extras.register_hstore(conn)
            with conn.cursor() as cur:
                total_tuples = cur.scalar(runner.sql_count_objects())
                LOG.debug("Total number of rows: %i", total_tuples)

                hstore_oid = cur.scalar("SELECT 'hstore'::regtype::oid")
                hstore_array_oid = cur.scalar("SELECT 'hstore[]'::regtype::oid")

            conn.commit()

            progress = ProgressLogger(runner.name(), total_tuples)

            if total_tuples > 0:
                # Use a named (server-side) cursor, so that the ids can
                # be fetched in batches instead of all at once.
                with conn.cursor(name='places') as cur:
                    cur.execute(runner.sql_get_objects())

                    fetcher = DBConnection(self.dsn)
                    psycopg2.extras.register_hstore(fetcher.conn,
                                                    oid=hstore_oid,
                                                    array_oid=hstore_array_oid)

                    with WorkerPool(self.dsn, self.num_threads) as pool:
                        places = self._fetch_next_batch(cur, fetcher, runner)
                        while places is not None:
                            if not places:
                                # An empty list means the batch is still
                                # being fetched asynchronously, wait for it.
                                fetcher.wait()
                                places = fetcher.cursor.fetchall()

                            # asynchronously get the next batch
                            next_places = self._fetch_next_batch(cur, fetcher, runner)

                            # and index the current batch
                            for idx in range(0, len(places), batch):
                                worker = pool.next_free_worker()
                                part = places[idx:idx+batch]
                                LOG.debug("Processing places: %s", str(part))
                                runner.index_places(worker, part)
                                progress.add(len(part))

                            places = next_places

                        pool.finish_all()

                    fetcher.wait()
                    fetcher.close()

                conn.commit()

        progress.done()


    def _fetch_next_batch(self, cur, fetcher, runner):
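        """ Query the next batch of ids to index.

            Returns None when all ids have been processed, the list of
            ids when the runner needs no extra place details, and an
            empty list when the details for the batch are still being
            fetched asynchronously on `fetcher`.
        """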
        ids = cur.fetchmany(1000)

        if not ids:
            return None

        if not hasattr(runner, 'get_place_details'):
            return ids

        runner.get_place_details(fetcher, ids)
        return []
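
    # Note: _index() relies only on this duck-typed runner interface
    # (see nominatim.indexer.runners for the concrete implementations):
    #
    #   name()                - human-readable name for logging
    #   sql_count_objects()   - SQL counting the rows to index
    #   sql_get_objects()     - SQL selecting the ids of those rows
    #   index_places(worker, places) - issue the indexing SQL on `worker`
    #   get_place_details(fetcher, ids) - optional asynchronous lookup of
    #                                     extra data for a batch of ids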