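"""Importer counterpart of the OSQA exporter: reads the XML files extracted to
TMP_FOLDER (users, tags, nodes, actions, awards) and recreates their contents
in the database through the frozen South ORM, optionally merging into an
existing site."""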
import os, tarfile, datetime, ConfigParser, logging

from django.utils.translation import ugettext as _
from django.core.cache import cache

from south.db import db

from xml.sax import make_parser
from xml.sax.handler import ContentHandler, ErrorHandler

from forum.templatetags.extra_tags import diff_date

from exporter import TMP_FOLDER, DATETIME_FORMAT, DATE_FORMAT, META_INF_SECTION, CACHE_KEY
from orm import orm
import commands, settings

NO_DEFAULT = object()

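# Wraps a piece of XML text content and exposes typed accessors
# (as_bool/as_date/as_datetime/as_int) that fall back to a default on bad input.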
class ContentElement():
    def __init__(self, content):
        self._content = content

    def content(self):
        return self._content.strip()

    def as_bool(self):
        return self.content() == "true"

    def as_date(self, default=NO_DEFAULT):
        try:
            return datetime.datetime.strptime(self.content(), DATE_FORMAT)
        except:
            if default is NO_DEFAULT:
                return datetime.date.fromtimestamp(0)
            else:
                return default

    def as_datetime(self, default=NO_DEFAULT):
        try:
            return datetime.datetime.strptime(self.content(), DATETIME_FORMAT)
        except:
            if default is NO_DEFAULT:
                return datetime.datetime.fromtimestamp(0)
            else:
                return default

    def as_int(self, default=0):
        try:
            return int(self.content())
        except:
            return default

    def __str__(self):
        return self.content()

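# A parsed XML element of a row: keeps its attributes, text content and child
# elements, and can rebuild values exported as <value type="..."> structures.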
class RowElement(ContentElement):
    def __init__(self, name, attrs, parent=None):
        self.name = name.lower()
        self.parent = parent
        self.attrs = dict([(k.lower(), ContentElement(v)) for k, v in attrs.items()])
        self._content = u''
        self.sub_elements = {}

        if parent:
            parent.add(self)

    def add_to_content(self, ch):
        self._content += unicode(ch)

    def add(self, sub):
        curr = self.sub_elements.get(sub.name, None)

        if not curr:
            curr = []
            self.sub_elements[sub.name] = curr

        curr.append(sub)

    def get(self, name, default=None):
        return self.sub_elements.get(name.lower(), [default])[-1]

    def get_list(self, name):
        return self.sub_elements.get(name.lower(), [])

    def get_listc(self, name):
        return [r.content() for r in self.get_list(name)]

    def getc(self, name, default=""):
        el = self.get(name, None)

        if el:
            return el.content()
        else:
            return default

    def get_attr(self, name, default=""):
        return self.attrs.get(name.lower(), default)

    def as_pickled(self, default=None):
        value_el = self.get('value')

        if value_el:
            return value_el._as_pickled(default)
        else:
            return default

    TYPES_MAP = dict([(c.__name__, c) for c in (int, long, str, unicode, float)])

    def _as_pickled(self, default=None):
        el_type = self.get_attr('type').content()

        try:
            if el_type == 'dict':
                return dict([(item.get_attr('key').content(), item.as_pickled()) for item in self.get_list('item')])
            elif el_type == 'list':
                return [item.as_pickled() for item in self.get_list('item')]
            elif el_type == 'bool':
                return self.content().lower() == 'true'
            elif el_type in RowElement.TYPES_MAP:
                return RowElement.TYPES_MAP[el_type](self.content())
            else:
                return self.content()
        except:
            return default


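# SAX content handler that builds a RowElement tree for every row element under
# the given root and hands each completed row to the callback.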
class TableHandler(ContentHandler):
    def __init__(self, root_name, row_name, callback, callback_args=[], ping=None):
        self.root_name = root_name.lower()
        self.row_name = row_name.lower()
        self.callback = callback
        self.callback_args = callback_args
        self.ping = ping

        self._reset()

    def _reset(self):
        self.curr_element = None
        self.in_tag = None

    def startElement(self, name, attrs):
        name = name.lower()

        if name == self.root_name:
            pass
        elif name == self.row_name:
            self.curr_element = RowElement(name, attrs)
        else:
            self.curr_element = RowElement(name, attrs, self.curr_element)

    def characters(self, ch):
        if self.curr_element:
            self.curr_element.add_to_content(ch)

    def endElement(self, name):
        name = name.lower()

        if name == self.root_name:
            pass
        elif name == self.row_name:
            self.callback(self.curr_element, *self.callback_args)
            if self.ping:
                self.ping()

            self._reset()
        else:
            self.curr_element = self.curr_element.parent

class SaxErrorHandler(ErrorHandler):
    def error(self, e):
        raise e

    def fatalError(self, e):
        raise e

    def warning(self, e):
        raise e

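# PostgreSQL-only helpers: start_import disables triggers before loading data,
# re-enables them afterwards, and resets sequences once the import is done.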
def disable_triggers():
    if db.backend_name == "postgres":
        db.start_transaction()
        db.execute_many(commands.PG_DISABLE_TRIGGERS)
        db.commit_transaction()

def enable_triggers():
    if db.backend_name == "postgres":
        db.start_transaction()
        db.execute_many(commands.PG_ENABLE_TRIGGERS)
        db.commit_transaction()

def reset_sequences():
    if db.backend_name == "postgres":
        db.start_transaction()
        db.execute_many(commands.PG_SEQUENCE_RESETS)
        db.commit_transaction()

def reset_fts_indexes():
    pass

FILE_HANDLERS = []

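# Runs every registered file handler inside one transaction, publishing progress
# to the cache under CACHE_KEY so the import status page can poll it.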
def start_import(fname, tag_merge, user):

    start_time = datetime.datetime.now()
    steps = list(FILE_HANDLERS)

    with open(os.path.join(TMP_FOLDER, 'backup.inf'), 'r') as inffile:
        inf = ConfigParser.SafeConfigParser()
        inf.readfp(inffile)

        state = dict([(s['id'], {
            'status': _('Queued'), 'count': int(inf.get(META_INF_SECTION, s['id'])), 'parsed': 0
        }) for s in steps] + [
            ('overall', {
                'status': _('Starting'), 'count': int(inf.get(META_INF_SECTION, 'overall')), 'parsed': 0
            })
        ])

    full_state = dict(running=True, state=state, time_started="")

    def set_state():
        full_state['time_started'] = diff_date(start_time)
        cache.set(CACHE_KEY, full_state)

    set_state()

    def ping_state(name):
        state[name]['parsed'] += 1
        state['overall']['parsed'] += 1
        set_state()

    data = {
        'is_merge': True,
        'tag_merge': tag_merge
    }

    def run(fn, name, title):
        def ping():
            ping_state(name)

        state['overall']['status'] = _('Importing %s') % title
        state[name]['status'] = _('Importing')

        fn(TMP_FOLDER, user, ping, data)

        state[name]['status'] = _('Done')

        set_state()

    #dump = tarfile.open(fname, 'r')
    #dump.extractall(TMP_FOLDER)

    try:
        disable_triggers()
        db.start_transaction()

        for h in FILE_HANDLERS:
            run(h['fn'], h['id'], h['name'])

        db.commit_transaction()
        enable_triggers()

        settings.MERGE_MAPPINGS.set_value(dict(merged_nodes=data['nodes_map'], merged_users=data['users_map']))

        reset_sequences()
    except Exception, e:
        full_state['running'] = False
        full_state['errors'] = "%s: %s" % (e.__class__.__name__, unicode(e))
        set_state()

        import traceback
        logging.error("Error executing xml import: \n %s" % traceback.format_exc())

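# Registers an import step: the decorated function is invoked once per row of
# file_name, with extra positional arguments built by args_handler.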
def file_handler(file_name, root_tag, el_tag, name, args_handler=None, pre_callback=None, post_callback=None):
    def decorator(fn):
        def decorated(location, current_user, ping, data):
            if pre_callback:
                pre_callback(current_user, data)

            if args_handler:
                args = args_handler(current_user, data)
            else:
                args = []

            parser = make_parser()
            handler = TableHandler(root_tag, el_tag, fn, args, ping)
            parser.setContentHandler(handler)
            #parser.setErrorHandler(SaxErrorHandler())

            parser.parse(os.path.join(location, file_name))

            if post_callback:
                post_callback()

        FILE_HANDLERS.append(dict(id=root_tag, name=name, fn=decorated))
        return decorated
    return decorator

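# When merging, tries to match an imported user to an existing account, first by
# email and then by any non-Google/Yahoo authentication keys.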
def verify_existence(row):
    try:
        return orm.User.objects.get(email=row.getc('email'))
    except:
        for key in row.get('authKeys').get_list('key'):
            key = key.getc('key')

            if not ("google.com" in key or "yahoo.com" in key):
                try:
                    return orm.AuthKeyUserAssociation.objects.get(key=key).user
                except:
                    pass

    return None

def user_import_pre_callback(user, data):
    data['users_map'] = {}

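# users.xml: creates new accounts or merges into matching existing ones, and
# records old id -> new id in users_map for the steps that follow.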
@file_handler('users.xml', 'users', 'user', _('Users'), pre_callback=user_import_pre_callback, args_handler=lambda u, d: [u, d['is_merge'], d['users_map']])
def user_import(row, current_user, is_merge, users_map):
    existent = is_merge and verify_existence(row) or None

    roles = row.get('roles').get_listc('role')
    valid_email = row.get('email').get_attr('validated').as_bool()
    badges = row.get('badges')

    if existent:
        user = existent

        user.reputation += row.get('reputation').as_int()
        user.gold += badges.get_attr('gold').as_int()
        user.silver += badges.get_attr('silver').as_int()
        user.bronze += badges.get_attr('bronze').as_int()

    else:
        username = row.getc('username')

        if is_merge:
            username_count = 0

            while orm.User.objects.filter(username=username).count():
                username_count += 1
                username = "%s %s" % (row.getc('username'), username_count)

        user = orm.User(
                id            = (not is_merge) and row.getc('id') or None,
                username      = username,
                password      = row.getc('password'),
                email         = row.getc('email'),
                email_isvalid = valid_email,
                is_superuser  = (not is_merge) and 'superuser' in roles,
                is_staff      = ('moderator' in roles) or (is_merge and 'superuser' in roles),
                is_active     = row.get('active').as_bool(),
                date_joined   = row.get('joindate').as_datetime(),
                about         = row.getc('bio'),
                date_of_birth = row.get('birthdate').as_date(None),
                website       = row.getc('website'),
                reputation    = row.get('reputation').as_int(),
                gold          = badges.get_attr('gold').as_int(),
                silver        = badges.get_attr('silver').as_int(),
                bronze        = badges.get_attr('bronze').as_int(),
                real_name     = row.getc('realname'),
                location      = row.getc('location'),
        )

    user.save()

    users_map[row.get('id').as_int()] = user.id

    authKeys = row.get('authKeys')

    for key in authKeys.get_list('key'):
        if (not is_merge) or orm.AuthKeyUserAssociation.objects.filter(key=key.getc('key')).count() == 0:
            orm.AuthKeyUserAssociation(user=user, key=key.getc('key'), provider=key.getc('provider')).save()

    if not existent:
        notifications = row.get('notifications')

        attributes = dict([(str(k), v.as_bool() and 'i' or 'n') for k, v in notifications.get('notify').attrs.items()])
        attributes.update(dict([(str(k), v.as_bool()) for k, v in notifications.get('autoSubscribe').attrs.items()]))
        attributes.update(dict([(str("notify_%s" % k), v.as_bool()) for k, v in notifications.get('notifyOnSubscribed').attrs.items()]))

        ss = orm.SubscriptionSettings(user=user, enable_notifications=notifications.get_attr('enabled').as_bool(), **attributes)

        if current_user.id == row.get('id').as_int():
            ss.id = current_user.subscription_settings.id

        ss.save()

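# tags.xml: creates tags, applying the optional tag_merge renames and folding
# usage counts into already existing tags when merging.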
def pre_tag_import(user, data):
    data['tag_mappings'] = dict([(t.name, t) for t in orm.Tag.objects.all()])


@file_handler('tags.xml', 'tags', 'tag', _('Tags'), pre_callback=pre_tag_import, args_handler=lambda u, d: [d['is_merge'], d['tag_merge'], d['users_map'], d['tag_mappings']])
def tag_import(row, is_merge, tag_merge, users_map, tag_mappings):
    created_by = row.get('used').as_int()
    created_by = users_map.get(created_by, created_by)

    tag_name = row.getc('name')
    tag_name = tag_merge and tag_merge.get(tag_name, tag_name) or tag_name

    if is_merge and tag_name in tag_mappings:
        tag = tag_mappings[tag_name]
        tag.used_count += row.get('used').as_int()
    else:
        tag = orm.Tag(name=tag_name, used_count=row.get('used').as_int(), created_by_id=created_by)
        tag_mappings[tag.name] = tag

    tag.save()

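# nodes.xml: recreates questions, answers and comments together with their
# revisions; nodes_map records old id -> new id.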
def pre_node_import(user, data):
    data['nodes_map'] = {}

@file_handler('nodes.xml', 'nodes', 'node', _('Nodes'), pre_callback=pre_node_import,
              args_handler=lambda u, d: [d['is_merge'], d['tag_merge'], d['tag_mappings'], d['nodes_map'], d['users_map']])
def node_import(row, is_merge, tag_merge, tags, nodes_map, users_map):

    ntags = []

    for t in row.get('tags').get_list('tag'):
        t = t.content()
        ntags.append(tags[tag_merge and tag_merge.get(t, t) or t])

    author = row.get('author').as_int()

    last_act = row.get('lastactivity')
    last_act_user = last_act.get('by').as_int(None)

    parent = row.get('parent').as_int(None)
    abs_parent = row.get('absparent').as_int(None)

    node = orm.Node(
            id            = (not is_merge) and row.getc('id') or None,
            node_type     = row.getc('type'),
            author_id     = users_map.get(author, author),
            added_at      = row.get('date').as_datetime(),
            parent_id     = nodes_map.get(parent, parent),
            abs_parent_id = nodes_map.get(abs_parent, abs_parent),
            score         = row.get('score').as_int(0),

            last_activity_by_id = last_act_user and users_map.get(last_act_user, last_act_user) or last_act_user,
            last_activity_at    = last_act.get('at').as_datetime(None),

            title         = row.getc('title'),
            body          = row.getc('body'),
            tagnames      = " ".join([t.name for t in ntags]),

            marked        = row.get('marked').as_bool(),
            extra_ref_id  = row.get('extraRef').as_int(None),
            extra_count   = row.get('extraCount').as_int(0),
            extra         = row.get('extraData').as_pickled()
    )

    node.save()

    nodes_map[row.get('id').as_int()] = node.id

    node.tags = ntags

    revisions = row.get('revisions')
    active = revisions.get_attr('active').as_int()

    if active == 0:
        active = orm.NodeRevision(
            author_id = node.author_id,
            body = row.getc('body'),
            node = node,
            revised_at = row.get('date').as_datetime(),
            revision = 1,
            summary = _('Initial revision'),
            tagnames = " ".join([t.name for t in ntags]),
            title = row.getc('title'),
        )

        active.save()
    else:
        for r in revisions.get_list('revision'):
            author = row.get('author').as_int()

            rev = orm.NodeRevision(
                author_id = users_map.get(author, author),
                body = r.getc('body'),
                node = node,
                revised_at = r.get('date').as_datetime(),
                revision = r.get('number').as_int(),
                summary = r.getc('summary'),
                tagnames = " ".join(r.getc('tags').split(',')),
                title = r.getc('title'),
            )

            rev.save()
            if rev.revision == active:
                active = rev

    node.active_revision = active
    node.save()

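# Registry of per-action-type handlers that recreate the side effects of
# imported actions (votes, node states, flags).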
POST_ACTION = {}

def post_action(*types):
    def decorator(fn):
        for t in types:
            POST_ACTION[t] = fn
        return fn
    return decorator

def pre_action_import_callback(user, data):
    data['actions_map'] = {}

def post_action_import_callback():
    with_state = orm.Node.objects.filter(id__in=orm.NodeState.objects.values_list('node_id', flat=True).distinct())

    for n in with_state:
        n.state_string = "".join(["(%s)" % s for s in n.states.values_list('state_type')])
        n.save()

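# actions.xml: recreates actions and their reputation changes, then re-applies
# side effects through POST_ACTION; node state strings are rebuilt afterwards.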
@file_handler('actions.xml', 'actions', 'action', _('Actions'), post_callback=post_action_import_callback,
              pre_callback=pre_action_import_callback, args_handler=lambda u, d: [d['nodes_map'], d['users_map'], d['actions_map']])
def actions_import(row, nodes, users, actions_map):
    node = row.get('node').as_int(None)
    user = row.get('user').as_int()
    real_user = row.get('realUser').as_int(None)

    action = orm.Action(
        #id           = row.get('id').as_int(),
        action_type  = row.getc('type'),
        action_date  = row.get('date').as_datetime(),
        node_id      = nodes.get(node, node),
        user_id      = users.get(user, user),
        real_user_id = users.get(real_user, real_user),
        ip           = row.getc('ip'),
        extra        = row.get('extraData').as_pickled(),
    )

    canceled = row.get('canceled')
    if canceled.get_attr('state').as_bool():
        by = canceled.get('user').as_int()
        action.canceled = True
        action.canceled_by_id = users.get(by, by)
        action.canceled_at = canceled.getc('date') #.as_datetime(),
        action.canceled_ip = canceled.getc('ip')

    action.save()

    actions_map[row.get('id').as_int()] = action.id

    for r in row.get('reputes').get_list('repute'):
        by_canceled = r.get_attr('byCanceled').as_bool()

        orm.ActionRepute(
            action = action,
            user_id = users[r.get('user').as_int()],
            value = r.get('value').as_int(),

            date = by_canceled and action.canceled_at or action.action_date,
            by_canceled = by_canceled
        ).save()

    if (not action.canceled) and (action.action_type in POST_ACTION):
        POST_ACTION[action.action_type](row, action, users, nodes, actions_map)

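# Side-effect handlers for individual action types.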
@post_action('voteup', 'votedown', 'voteupcomment')
def vote_action(row, action, users, nodes, actions):
    orm.Vote(user_id=action.user_id, node_id=action.node_id, action=action,
             voted_at=action.action_date, value=(action.action_type != 'votedown') and 1 or -1).save()

def state_action(state):
    def fn(row, action, users, nodes, actions):
        if orm.NodeState.objects.filter(state_type = state, node = action.node_id).count():
            return

        orm.NodeState(
            state_type = state,
            node_id = action.node_id,
            action = action
        ).save()
    return fn

post_action('wikify')(state_action('wiki'))
post_action('delete')(state_action('deleted'))
post_action('acceptanswer')(state_action('accepted'))
post_action('publish')(state_action('published'))

@post_action('flag')
def flag_action(row, action, users, nodes, actions):
    orm.Flag(user_id=action.user_id, node_id=action.node_id, action=action, reason=action.extra or "").save()


def award_import_args(user, data):
    return [dict([(b.cls, b) for b in orm.Badge.objects.all()]), data['nodes_map'], data['users_map'], data['actions_map']]

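# awards.xml: recreates badge awards, skipping unknown badges and duplicates.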
@file_handler('awards.xml', 'awards', 'award', _('Awards'), args_handler=award_import_args)
def awards_import(row, badges, nodes, users, actions):
    badge_type = badges.get(row.getc('badge'), None)

    if not badge_type:
        return

    action = row.get('action').as_int(None)
    trigger = row.get('trigger').as_int(None)
    node = row.get('node').as_int(None)
    user = row.get('user').as_int()

    if orm.Award.objects.filter(badge=badges[row.getc('badge')], user=users.get(user, user), node=nodes.get(node, node)).count():
        return

    orm.Award(
        user_id = users.get(user, user),
        badge = badge_type,
        node_id = nodes.get(node, node),
        action_id = actions.get(action, action),
        trigger_id = actions.get(trigger, trigger)
    ).save()

#@file_handler('settings.xml', 'settings', 'setting', _('Settings'))
def settings_import(row):
    orm.KeyValue(key=row.getc('key'), value=row.get('value').as_pickled()).save()