#! /usr/bin/env python3
#-----------------------------------------------------------------------------
# nominatim - [description]
#-----------------------------------------------------------------------------
#
# Indexing tool for the Nominatim database.
#
# Based on C version by Brian Quinion
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
#-----------------------------------------------------------------------------

from argparse import ArgumentParser, RawDescriptionHelpFormatter, ArgumentTypeError
import logging
import sys
import re
import getpass
import select
from datetime import datetime

import psycopg2
from psycopg2.extras import wait_select

log = logging.getLogger()

def make_connection(options, asynchronous=False):
    return psycopg2.connect(dbname=options.dbname, user=options.user,
                            password=options.password, host=options.host,
                            port=options.port, async_=asynchronous)

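# Note: with async_=True, psycopg2 returns a connection that must be polled
# until it is ready before it can be used; DBConnection.wait() below does
# that via psycopg2.extras.wait_select(). The async_ keyword assumes a
# reasonably recent psycopg2 (older releases spelled the argument `async`).
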

class RankRunner(object):
    """ Returns SQL commands for indexing one rank within the placex table.
    """

    def __init__(self, rank):
        self.rank = rank

    def name(self):
        return "rank {}".format(self.rank)

    def sql_index_sectors(self):
        return """SELECT geometry_sector, count(*) FROM placex
                  WHERE rank_search = {} and indexed_status > 0
                  GROUP BY geometry_sector
                  ORDER BY geometry_sector""".format(self.rank)

    def sql_nosector_places(self):
        return """SELECT place_id FROM placex
                  WHERE indexed_status > 0 and rank_search = {}
                  ORDER BY geometry_sector""".format(self.rank)

    def sql_sector_places(self):
        return """SELECT place_id FROM placex
                  WHERE indexed_status > 0 and rank_search = {}
                        and geometry_sector = %s""".format(self.rank)

    def sql_index_place(self):
        return "UPDATE placex SET indexed_status = 0 WHERE place_id = %s"

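# Illustrative example of the generated SQL: RankRunner(26).sql_sector_places()
# expands to roughly
#   SELECT place_id FROM placex
#   WHERE indexed_status > 0 and rank_search = 26 and geometry_sector = %s
# with %s bound to a sector id when the query is executed.
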

class InterpolationRunner(object):
    """ Returns SQL commands for indexing the address interpolation table
        location_property_osmline.
    """

    def name(self):
        return "interpolation lines (location_property_osmline)"

    def sql_index_sectors(self):
        return """SELECT geometry_sector, count(*) FROM location_property_osmline
                  WHERE indexed_status > 0
                  GROUP BY geometry_sector
                  ORDER BY geometry_sector"""

    def sql_nosector_places(self):
        return """SELECT place_id FROM location_property_osmline
                  WHERE indexed_status > 0
                  ORDER BY geometry_sector"""

    def sql_sector_places(self):
        return """SELECT place_id FROM location_property_osmline
                  WHERE indexed_status > 0 and geometry_sector = %s
                  ORDER BY geometry_sector"""

    def sql_index_place(self):
        return """UPDATE location_property_osmline
                  SET indexed_status = 0 WHERE place_id = %s"""

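# RankRunner and InterpolationRunner implement the same informal interface
# (name, sql_index_sectors, sql_nosector_places, sql_sector_places,
# sql_index_place); Indexer.index() below relies only on these methods, so
# further tables could in principle be indexed by adding a similar runner.
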

class DBConnection(object):
    """ A single non-blocking database connection.
    """

    def __init__(self, options):
        self.options = options
        self.current_query = None
        self.current_params = None

        self.conn = None
        self.connect()

    def connect(self):
        if self.conn is not None:
            self.cursor.close()
            self.conn.close()

        self.conn = make_connection(self.options, asynchronous=True)
        self.wait()

        self.cursor = self.conn.cursor()

    def wait(self):
        """ Block until any pending operation is done.
        """
        wait_select(self.conn)
        self.current_query = None

    def perform(self, sql, args=None):
        """ Send SQL query to the server. Returns immediately without
            blocking.
        """
        self.current_query = sql
        self.current_params = args
        self.cursor.execute(sql, args)

    def fileno(self):
        """ File descriptor to wait for. (Makes this class select()able.)
        """
        return self.conn.fileno()

    def is_done(self):
        """ Check if the connection is available for a new query.

            Also checks if the previous query has run into a deadlock.
            If so, then the previous query is repeated.
        """
        if self.current_query is None:
            return True

        try:
            if self.conn.poll() == psycopg2.extensions.POLL_OK:
                self.current_query = None
                return True
        except psycopg2.extensions.TransactionRollbackError as e:
            if e.pgcode == '40P01':
                log.info("Deadlock detected (params = {}), retry.".format(self.current_params))
                self.cursor.execute(self.current_query, self.current_params)
            else:
                raise

        return False

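# Rough usage pattern for DBConnection, as exercised by Indexer below
# (sketch only, not additional API):
#
#   conn = DBConnection(options)
#   conn.perform(index_sql, (place_id,))
#   ready, _, _ = select.select([conn], [], [])  # works because of fileno()
#   if conn.is_done():
#       ...  # connection is free for the next perform()
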

class Indexer(object):
    """ Main indexing routine.
    """

    def __init__(self, options):
        self.minrank = max(0, options.minrank)
        self.maxrank = min(30, options.maxrank)
        self.conn = make_connection(options)
        self.threads = [DBConnection(options) for _ in range(options.threads)]

    def run(self):
        """ Run indexing over the entire database.
        """
        log.warning("Starting indexing rank ({} to {}) using {} threads".format(
                 self.minrank, self.maxrank, len(self.threads)))

        # Index all ranks below the maximum first, then the interpolation
        # lines (only when indexing up to rank 30), and finally the maximum
        # rank itself.
        for rank in range(self.minrank, self.maxrank):
            self.index(RankRunner(rank))

        if self.maxrank == 30:
            self.index(InterpolationRunner())

        self.index(RankRunner(self.maxrank))

    def index(self, obj):
        """ Index a single rank or table. `obj` describes the SQL to use
            for indexing.
        """
        log.warning("Starting {}".format(obj.name()))

        cur = self.conn.cursor(name='main')
        cur.execute(obj.sql_index_sectors())

        total_tuples = 0
        for r in cur:
            total_tuples += r[1]
        log.debug("Total number of rows: {}".format(total_tuples))

        cur.scroll(0, mode='absolute')

        next_thread = self.find_free_thread()
        done_tuples = 0
        rank_start_time = datetime.now()

        sector_sql = obj.sql_sector_places()
        index_sql = obj.sql_index_place()
        # Once fewer than about 1000 places per connection remain, stop
        # iterating by sector and hand out the rest in one go.
        min_grouped_tuples = total_tuples - len(self.threads) * 1000

        next_info = 100 if log.isEnabledFor(logging.INFO) else total_tuples + 1

        for r in cur:
            sector = r[0]

            # Should we do the remaining ones together?
            do_all = done_tuples > min_grouped_tuples

            pcur = self.conn.cursor(name='places')

            if do_all:
                pcur.execute(obj.sql_nosector_places())
            else:
                pcur.execute(sector_sql, (sector, ))

            for place in pcur:
                place_id = place[0]
                log.debug("Processing place {}".format(place_id))
                thread = next(next_thread)

                thread.perform(index_sql, (place_id,))
                done_tuples += 1

                if done_tuples >= next_info:
                    now = datetime.now()
                    done_time = (now - rank_start_time).total_seconds()
                    tuples_per_sec = done_tuples / done_time
                    log.info("Done {} in {} @ {:.3f} per second - {} ETA (seconds): {:.2f}"
                           .format(done_tuples, int(done_time),
                                   tuples_per_sec, obj.name(),
                                   (total_tuples - done_tuples)/tuples_per_sec))
                    next_info += int(tuples_per_sec)

            pcur.close()

            if do_all:
                break

        cur.close()

        for t in self.threads:
            t.wait()

        rank_end_time = datetime.now()
        diff_seconds = (rank_end_time-rank_start_time).total_seconds()

        log.warning("Done {}/{} in {} @ {:.3f} per second - FINISHED {}\n".format(
                 done_tuples, total_tuples, int(diff_seconds),
                 done_tuples/diff_seconds, obj.name()))

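    # Design note (an interpretation, not stated in the original): places are
    # handed out sector by sector, so consecutive updates touch geographically
    # nearby rows; the final do_all batch processes the small remainder without
    # sector grouping so the threads are not left idle at the end of a rank.
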
    def find_free_thread(self):
        """ Generator that returns the next connection that is free for
            sending a query.
        """
        ready = self.threads
        command_stat = 0

        while True:
            for thread in ready:
                if thread.is_done():
                    command_stat += 1
                    yield thread

            # Refresh the connections occasionally to avoid potential
            # memory leaks in PostgreSQL.
            if command_stat > 100000:
                for t in self.threads:
                    while not t.is_done():
                        wait_select(t.conn)
                    t.connect()
                command_stat = 0
                ready = self.threads
            else:
                ready, _, _ = select.select(self.threads, [], [])

        assert False, "Unreachable code"


def nominatim_arg_parser():
    """ Setup the command-line parser for the tool.
    """
    def h(s):
        return re.sub(r"\s\s+", " ", s)

    p = ArgumentParser(description="Indexing tool for Nominatim.",
                       formatter_class=RawDescriptionHelpFormatter)

    p.add_argument('-d', '--database',
                   dest='dbname', action='store', default='nominatim',
                   help='Name of the PostgreSQL database to connect to.')
    p.add_argument('-U', '--username',
                   dest='user', action='store',
                   help='PostgreSQL user name.')
    p.add_argument('-W', '--password',
                   dest='password_prompt', action='store_true',
                   help='Force password prompt.')
    p.add_argument('-H', '--host',
                   dest='host', action='store',
                   help='PostgreSQL server hostname or socket location.')
    p.add_argument('-P', '--port',
                   dest='port', action='store',
                   help='PostgreSQL server port.')
    p.add_argument('-r', '--minrank',
                   dest='minrank', type=int, metavar='RANK', default=0,
                   help='Minimum/starting rank.')
    p.add_argument('-R', '--maxrank',
                   dest='maxrank', type=int, metavar='RANK', default=30,
                   help='Maximum/finishing rank.')
    p.add_argument('-t', '--threads',
                   dest='threads', type=int, metavar='NUM', default=1,
                   help='Number of threads to create for indexing.')
    p.add_argument('-v', '--verbose',
                   dest='loglevel', action='count', default=0,
                   help='Increase verbosity.')

    return p

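# Example invocation (illustrative values only):
#   ./nominatim.py -d nominatim -U www-data -t 4 -r 4 -R 30 -vv
# indexes ranks 4 to 30 of the `nominatim` database using four parallel
# connections with DEBUG-level progress output.
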
if __name__ == '__main__':
    logging.basicConfig(stream=sys.stderr, format='%(levelname)s: %(message)s')

    options = nominatim_arg_parser().parse_args(sys.argv[1:])

    # Default is WARNING; each -v lowers the threshold by one level
    # (INFO, then DEBUG).
    log.setLevel(max(3 - options.loglevel, 0) * 10)

    options.password = None
    if options.password_prompt:
        password = getpass.getpass("Database password: ")
        options.password = password

    Indexer(options).run()