]> git.openstreetmap.org Git - nominatim.git/blob - nominatim/nominatim.py
use generator for thread choice
[nominatim.git] / nominatim / nominatim.py
1 #! /usr/bin/env python
2 #-----------------------------------------------------------------------------
3 # nominatim - [description]
4 #-----------------------------------------------------------------------------
5 #
6 # Indexing tool for the Nominatim database.
7 #
8 # Based on C version by Brian Quinion
9 #
10 # This program is free software; you can redistribute it and/or
11 # modify it under the terms of the GNU General Public License
12 # as published by the Free Software Foundation; either version 2
13 # of the License, or (at your option) any later version.
14 #
15 # This program is distributed in the hope that it will be useful,
16 # but WITHOUT ANY WARRANTY; without even the implied warranty of
17 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18 # GNU General Public License for more details.
19 #
20 # You should have received a copy of the GNU General Public License
21 # along with this program; if not, write to the Free Software
22 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
23 #-----------------------------------------------------------------------------
24
25 from argparse import ArgumentParser, RawDescriptionHelpFormatter, ArgumentTypeError
26 import logging
27 import sys
28 import re
29 import getpass
30 from datetime import datetime
31 import psycopg2
32 from psycopg2.extras import wait_select
33 import select
34
# Logger shared by all code in this file. getLogger() is called without a
# name, so this is the root logger.
log = logging.getLogger()
36
def make_connection(options, asynchronous=False):
    """ Open a new database connection with psycopg2.

        The connection parameters are taken from the attributes `dbname`,
        `user`, `password`, `host` and `port` of `options`. The value of
        `asynchronous` is handed to psycopg2 as its `async_` flag.
    """
    connect_args = {
        'dbname': options.dbname,
        'user': options.user,
        'password': options.password,
        'host': options.host,
        'port': options.port,
        'async_': asynchronous,
    }
    return psycopg2.connect(**connect_args)
41
class IndexingThread(object):
    """ A worker that sends indexing statements over its own database
        connection.

        The connection is opened with `asynchronous=True`, so `perform()`
        only submits a statement. Completion has to be checked with
        `is_done()` or awaited with `wait()`. Both retry the current
        statement when the database reports a deadlock.
    """

    # SQLSTATE that PostgreSQL reports for a detected deadlock.
    DEADLOCK_PGCODE = '40P01'

    def __init__(self, thread_num, options):
        """ Connect to the database and prepare the indexing statements.

            `thread_num` identifies the worker, `options` carries the
            connection parameters used by make_connection().
        """
        log.debug("Creating thread {}".format(thread_num))
        self.thread_num = thread_num
        # Statement currently in flight (None when idle) and its parameters.
        # Set up first because wait() and perform() below rely on them.
        self.current_query = None
        self.current_params = None

        self.conn = make_connection(options, asynchronous=True)
        self.wait()

        self.cursor = self.conn.cursor()
        self.perform("SET lc_messages TO 'C'")
        self.wait()
        self.perform(InterpolationRunner.prepare())
        self.wait()
        self.perform(RankRunner.prepare())
        self.wait()

    def _is_deadlock(self, error):
        """ Return True when `error` carries the deadlock error code.

            Raises RuntimeError when the exception has no error code at all.
        """
        if error.pgcode is None:
            raise RuntimeError("Postgres exception has no error code")
        return error.pgcode == self.DEADLOCK_PGCODE

    def _retry_current_query(self):
        """ Send the statement that is currently in flight once more.
        """
        log.info("Deadlock detected, retry.")
        self.cursor.execute(self.current_query, self.current_params)

    def wait(self):
        """ Block until the pending operation on the connection is done.

            A statement that was rolled back because of a deadlock is sent
            again and awaited anew. Any other error is passed on.
        """
        while True:
            try:
                wait_select(self.conn)
            except psycopg2.extensions.TransactionRollbackError as e:
                # Without a statement in flight there is nothing to retry.
                if self.current_query is None or not self._is_deadlock(e):
                    raise
                self._retry_current_query()
            else:
                self.current_query = None
                return

    def perform(self, sql, args=None):
        """ Send `sql` with parameters `args` without waiting for the result.

            Statement and parameters are remembered so that they can be
            repeated after a deadlock.
        """
        self.current_query = sql
        self.current_params = args
        self.cursor.execute(sql, args)

    def fileno(self):
        """ Return the file descriptor of the connection, which allows the
            object to be registered with select.poll().
        """
        return self.conn.fileno()

    def is_done(self):
        """ Check without blocking whether the worker is free.

            Returns True when no statement is in flight or the pending one
            has completed, False otherwise. A statement rolled back because
            of a deadlock is sent again and the worker reported as busy.
        """
        if self.current_query is None:
            return True

        try:
            if self.conn.poll() == psycopg2.extensions.POLL_OK:
                self.current_query = None
                return True
        except psycopg2.extensions.TransactionRollbackError as e:
            if not self._is_deadlock(e):
                raise
            self._retry_current_query()

        return False
89
90
91
class Indexer(object):
    """ Runs the indexing: reads the places to be indexed over a
        synchronous connection and distributes them to a pool of
        IndexingThread workers.
    """

    def __init__(self, options):
        """ Open the main connection and create `options.threads` workers.
        """
        self.options = options
        self.conn = make_connection(options)

        self.threads = []
        self.poll = select.poll()
        for i in range(options.threads):
            t = IndexingThread(i, options)
            self.threads.append(t)
            # POLLIN is the event flag that belongs to select.poll().
            self.poll.register(t, select.POLLIN)

    def _rank_range(self):
        """ Return the ranks below 30 to be indexed, from `minrank` up to
            and including `maxrank`. Rank 30 is handled separately by run().
        """
        return range(self.options.minrank,
                     min(self.options.maxrank, 29) + 1)

    def run(self):
        """ Index all requested ranks in ascending order.

            When `maxrank` is 30 or more, the interpolation lines and then
            rank 30 are indexed after the lower ranks.
        """
        log.info("Starting indexing rank ({} to {}) using {} threads".format(
                 self.options.minrank, self.options.maxrank,
                 self.options.threads))

        for rank in self._rank_range():
            self.index(RankRunner(rank))

        if self.options.maxrank >= 30:
            self.index(InterpolationRunner())
            self.index(RankRunner(30))

    def index(self, obj):
        """ Index everything that the runner `obj` selects.

            `obj` supplies the SQL: the sectors with their number of
            pending rows, the places of one sector, all pending places and
            the statement that indexes a single place. Sectors are handled
            one after the other until the estimated remainder is small
            enough to be done in a single pass.
        """
        log.info("Starting {}".format(obj.name()))

        # Fetch the complete sector list up front. It is needed for the
        # total as well as for the iteration below.
        cur = self.conn.cursor(name="main")
        try:
            cur.execute(obj.sql_index_sectors())
            sectors = cur.fetchall()
        finally:
            cur.close()

        total_tuples = sum(row[1] for row in sectors)
        log.debug("Total number of rows; {}".format(total_tuples))

        next_thread = self.find_free_thread()
        done_tuples = 0
        rank_start_time = datetime.now()
        for row in sectors:
            # Should we do the remaining ones together?
            do_all = total_tuples - done_tuples < len(self.threads) * 1000

            if do_all:
                done_tuples += self._index_places(obj, next_thread,
                                                  obj.sql_nosector_places())
                break

            done_tuples += self._index_places(obj, next_thread,
                                              obj.sql_sector_places(),
                                              (row[0], ))

        # Make sure every worker has finished its last statement.
        for t in self.threads:
            t.wait()

        rank_end_time = datetime.now()
        diff_seconds = (rank_end_time-rank_start_time).total_seconds()

        log.info("Done {} in {} @ {} per second - FINISHED {}\n".format(
                 done_tuples, int(diff_seconds),
                 self._rate(done_tuples, diff_seconds), obj.name()))

    def _index_places(self, obj, next_thread, sql, args=None):
        """ Hand every place selected by `sql` to a free worker and return
            the number of places handed out.
        """
        count = 0
        pcur = self.conn.cursor(name='places')
        try:
            pcur.execute(sql, args)

            for place in pcur:
                place_id = place[0]
                log.debug("Processing place {}".format(place_id))
                thread = next(next_thread)

                thread.perform(obj.sql_index_place(), (place_id,))
                count += 1
        finally:
            pcur.close()

        return count

    @staticmethod
    def _rate(count, seconds):
        """ Return `count` per second, or 0.0 when no time has elapsed.
        """
        if seconds > 0:
            return count / seconds
        return 0.0

    def find_free_thread(self):
        """ Generator that yields workers which are ready for a statement.

            All workers are checked once at the start. After that the
            generator blocks in poll() and checks those workers whose
            connection reported an event. It never terminates.
        """
        thread_lookup = {t.fileno(): t for t in self.threads}

        ready_fids = [t.fileno() for t in self.threads]

        while True:
            for fid in ready_fids:
                thread = thread_lookup[fid]
                if thread.is_done():
                    yield thread
                else:
                    log.debug("Thread {} is still busy".format(
                              thread.thread_num))

            ready_fids = [x[0] for x in self.poll.poll()]
187
class RankRunner(object):
    """ Provides the SQL for indexing the rows of `placex` that have a
        given search rank.
    """

    def __init__(self, rank):
        self.rank = rank

    def name(self):
        """ Return a description of the run for log messages.
        """
        return "rank {}".format(self.rank)

    @classmethod
    def prepare(cls):
        """ Return the SQL that creates the prepared statement which is
            executed by sql_index_place().
        """
        return """PREPARE rnk_index AS
                  UPDATE placex
                  SET indexed_status = 0 WHERE place_id = $1"""

    def sql_index_sectors(self):
        """ Return the SQL that lists each sector together with the number
            of rows of this rank that still have indexed_status > 0.
        """
        return """SELECT geometry_sector, count(*) FROM placex
                  WHERE rank_search = {} and indexed_status > 0
                  GROUP BY geometry_sector
                  ORDER BY geometry_sector""".format(self.rank)

    def sql_nosector_places(self):
        """ Return the SQL that selects all pending places of this rank.
        """
        return """SELECT place_id FROM placex
                  WHERE indexed_status > 0 and rank_search = {}
                  ORDER BY geometry_sector""".format(self.rank)

    def sql_sector_places(self):
        """ Return the SQL that selects the pending places of this rank in
            one sector. The sector is a query parameter.
        """
        # The rank filter is required here as well. Without it a sector
        # pass would pick up the pending places of every rank.
        return """SELECT place_id FROM placex
                  WHERE indexed_status > 0 and rank_search = {}
                        and geometry_sector = %s
                  ORDER BY geometry_sector""".format(self.rank)

    def sql_index_place(self):
        """ Return the SQL that indexes a single place. The place id is a
            query parameter.
        """
        return "EXECUTE rnk_index(%s)"
220
221
class InterpolationRunner(object):
    """ Provides the SQL for indexing the rows of the table
        `location_property_osmline`.
    """

    # The statements are constant, so they are kept as class attributes.
    _PREPARE = """PREPARE ipl_index AS
                  UPDATE location_property_osmline
                  SET indexed_status = 0 WHERE place_id = $1"""

    _SECTORS = """SELECT geometry_sector, count(*) FROM location_property_osmline
                  WHERE indexed_status > 0
                  GROUP BY geometry_sector
                  ORDER BY geometry_sector"""

    _ALL_PLACES = """SELECT place_id FROM location_property_osmline
                  WHERE indexed_status > 0
                  ORDER BY geometry_sector"""

    _SECTOR_PLACES = """SELECT place_id FROM location_property_osmline
                  WHERE indexed_status > 0 and geometry_sector = %s
                  ORDER BY geometry_sector"""

    _INDEX_PLACE = "EXECUTE ipl_index(%s)"

    def name(self):
        """ Return a description of the run for log messages.
        """
        return "interpolation lines (location_property_osmline)"

    @classmethod
    def prepare(cls):
        """ Return the SQL that creates the prepared statement which is
            executed by sql_index_place().
        """
        return cls._PREPARE

    def sql_index_sectors(self):
        """ Return the SQL that lists each sector together with its number
            of rows that still have indexed_status > 0.
        """
        return self._SECTORS

    def sql_nosector_places(self):
        """ Return the SQL that selects all pending rows.
        """
        return self._ALL_PLACES

    def sql_sector_places(self):
        """ Return the SQL that selects the pending rows of one sector.
            The sector is a query parameter.
        """
        return self._SECTOR_PLACES

    def sql_index_place(self):
        """ Return the SQL that indexes a single row. The place id is a
            query parameter.
        """
        return self._INDEX_PLACE
251
252
def nominatim_arg_parser():
    """ Setup the command-line parser for the tool.

        Returns an ArgumentParser. Parsing yields the attributes `dbname`,
        `user`, `password_prompt`, `host`, `port`, `minrank`, `maxrank`,
        `threads` and `loglevel`.
    """
    def thread_count(value):
        # A ValueError from int() is turned into a usage error by argparse.
        count = int(value)
        # Without at least one worker the indexer would never make progress.
        if count < 1:
            raise ArgumentTypeError(
                "number of threads must be at least 1, got {}".format(count))
        return count

    p = ArgumentParser(description=__doc__,
                       formatter_class=RawDescriptionHelpFormatter)

    p.add_argument('-d', '--database',
                   dest='dbname', action='store', default='nominatim',
                   help='Name of the PostgreSQL database to connect to.')
    p.add_argument('-U', '--username',
                   dest='user', action='store',
                   help='PostgreSQL user name.')
    p.add_argument('-W', '--password',
                   dest='password_prompt', action='store_true',
                   help='Force password prompt.')
    p.add_argument('-H', '--host',
                   dest='host', action='store',
                   help='PostgreSQL server hostname or socket location.')
    p.add_argument('-P', '--port',
                   dest='port', action='store',
                   help='PostgreSQL server port')
    p.add_argument('-r', '--minrank',
                   dest='minrank', type=int, metavar='RANK', default=0,
                   help='Minimum/starting rank.')
    p.add_argument('-R', '--maxrank',
                   dest='maxrank', type=int, metavar='RANK', default=30,
                   help='Maximum/finishing rank.')
    p.add_argument('-t', '--threads',
                   dest='threads', type=thread_count, metavar='NUM', default=1,
                   help='Number of threads to create for indexing.')
    p.add_argument('-v', '--verbose',
                   dest='loglevel', action='count', default=0,
                   help='Increase verbosity')

    return p
291
if __name__ == '__main__':
    logging.basicConfig(stream=sys.stderr, format='%(levelname)s: %(message)s')

    options = nominatim_arg_parser().parse_args(sys.argv[1:])

    # Every -v lowers the log level by one step of 10, starting from 30
    # and never going below 0.
    remaining_steps = max(3 - options.loglevel, 0)
    log.setLevel(remaining_steps * 10)

    # Only ask for a password when -W was given.
    if options.password_prompt:
        options.password = getpass.getpass("Database password: ")
    else:
        options.password = None

    Indexer(options).run()