git.openstreetmap.org Git - osqa.git/commitdiff
Several improvements in the sx importer.
author hernani <hernani@0cfe37f9-358a-4d5e-be75-b63607b5c754>
Wed, 9 Jun 2010 22:21:48 +0000 (22:21 +0000)
committer hernani <hernani@0cfe37f9-358a-4d5e-be75-b63607b5c754>
Wed, 9 Jun 2010 22:21:48 +0000 (22:21 +0000)
git-svn-id: http://svn.osqa.net/svnroot/osqa/trunk@401 0cfe37f9-358a-4d5e-be75-b63607b5c754

forum_modules/sximporter/importer.py
forum_modules/sximporter/views.py

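diff --git a/forum_modules/sximporter/importer.py b/forum_modules/sximporter/importer.py
index bc1cf9fa785fd7a4e2e1110c6c02c2b52330ac87..d63541c03c64eb86bf92c610d05d542533c66156 100644
--- a/forum_modules/sximporter/importer.py
+++ b/forum_modules/sximporter/importer.py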
@@ -4,6 +4,8 @@ from xml.dom import minidom
 from datetime import datetime, timedelta
 import time
 import re
+import os
+import gc
 from django.utils.translation import ugettext as _
 from django.template.defaultfilters import slugify
 from forum.models.utils import dbsafe_encode
@@ -20,6 +22,51 @@ from copy import deepcopy
 from base64 import b64encode, b64decode
 from zlib import compress, decompress
 
+from xml.sax import make_parser
+from xml.sax.handler import ContentHandler
+
+class SXTableHandler(ContentHandler):
+    def __init__(self, fname, callback):
+        self.in_row = False
+        self.el_data = {}
+        self.ch_data = ''
+
+        self.fname = fname.lower()
+        self.callback = callback
+
+    def startElement(self, name, attrs):
+        if name.lower() == self.fname:
+            pass
+        elif name.lower() == "row":
+            self.in_row = True
+
+    def characters(self, ch):
+        self.ch_data += ch
+
+    def endElement(self, name):
+        if name.lower() == self.fname:
+            pass
+        elif name.lower() == "row":
+            self.callback(self.el_data)
+
+            self.in_row = False
+            del self.el_data
+            self.el_data = {}
+        elif self.in_row:
+            self.el_data[name.lower()] = self.ch_data.strip()
+            del self.ch_data
+            self.ch_data = ''
+
+
+def readTable(path, name, callback):
+    parser = make_parser()
+    handler = SXTableHandler(name, callback)
+    parser.setContentHandler(handler)
+
+    f = os.path.join(path, "%s.xml" % name)
+    parser.parse(f)
+
+
 def dbsafe_encode(value):
     return force_unicode(b64encode(compress(dumps(deepcopy(value)))))
 
@@ -38,11 +85,13 @@ def readTime(ts):
 
     return datetime(*time.strptime(ts, '%Y-%m-%dT%H:%M:%S')[0:6])
 
-def readEl(el):
-    return dict([(n.tagName.lower(), getText(n)) for n in el.childNodes if n.nodeType == el.ELEMENT_NODE])
+#def readEl(el):
+#    return dict([(n.tagName.lower(), getText(n)) for n in el.childNodes if n.nodeType == el.ELEMENT_NODE])
 
-def readTable(dump, name):
-    return [readEl(e) for e in minidom.parseString(dump.read("%s.xml" % name)).getElementsByTagName('row')]
+#def readTable(dump, name):
+#    for e in minidom.parseString(dump.read("%s.xml" % name)).getElementsByTagName('row'):
+#        yield readEl(e)
+#return [readEl(e) for e in minidom.parseString(dump.read("%s.xml" % name)).getElementsByTagName('row')]
 
 google_accounts_lookup = re.compile(r'^https?://www.google.com/accounts/')
 yahoo_accounts_lookup = re.compile(r'^https?://me.yahoo.com/a/')
@@ -109,25 +158,32 @@ class IdMapper(dict):
     def __setitem__(self, key, value):
         super(IdMapper, self).__setitem__(int(key), int(value))
 
+class IdIncrementer():
+    def __init__(self, initial):
+        self.value = initial
+
+    def inc(self):
+        self.value += 1
+
 openidre = re.compile('^https?\:\/\/')
-def userimport(dump, options):
-    users = readTable(dump, "Users")
+def userimport(path, options):
+#users = readTable(dump, "Users")
 
     user_by_name = {}
     uidmapper = IdMapper()
-    merged_users = []
+    #merged_users = []
 
     owneruid = options.get('owneruid', None)
     #check for empty values
     if not owneruid:
         owneruid = None
 
-    for sxu in users:
+    def callback(sxu):
         create = True
 
         if sxu.get('id') == '-1':
-            continue
-
+            return
+        #print "\n".join(["%s : %s" % i for i in sxu.items()])
         if int(sxu.get('id')) == int(owneruid):
             osqau = orm.User.objects.get(id=1)
             uidmapper[owneruid] = 1
@@ -224,7 +280,7 @@ def userimport(dump, options):
             osqau.location = sxu.get('location', '')
             osqau.real_name = sxu.get('realname', '')
 
-            merged_users.append(osqau.id)
+            #merged_users.append(osqau.id)
             osqau.save()
 
         user_by_name[osqau.username] = osqau
@@ -234,17 +290,19 @@ def userimport(dump, options):
             assoc = orm.AuthKeyUserAssociation(user=osqau, key=openid, provider="openidurl")
             assoc.save()
 
+    readTable(path, "Users", callback)
+
     if uidmapper[-1] == -1:
         uidmapper[-1] = 1
 
-    return (uidmapper, merged_users)
+    return uidmapper
 
 def tagsimport(dump, uidmap):
-    tags = readTable(dump, "Tags")
+#tags = readTable(dump, "Tags")
 
     tagmap = {}
 
-    for sxtag in tags:
+    def callback(sxtag):
         otag = orm.Tag(
                 id = int(sxtag['id']),
                 name = sxtag['name'],
@@ -255,6 +313,8 @@ def tagsimport(dump, uidmap):
 
         tagmap[otag.name] = otag
 
+    readTable(dump, "Tags", callback)
+
     return tagmap
 
 def add_post_state(name, post, action):
@@ -280,19 +340,19 @@ def remove_post_state(name, post):
     post.state_string = "".join("(%s)" % s for s in re.findall('\w+', post.state_string) if s != name)
 
 def postimport(dump, uidmap, tagmap):
-    history = {}
-    accepted = {}
-    all = {}
+#history = {}
+#accepted = {}
+    all = []
 
-    for h in readTable(dump, "PostHistory"):
-        if not history.get(h.get('postid'), None):
-            history[h.get('postid')] = []
+    #for h in readTable(dump, "PostHistory"):
+    #    if not history.get(h.get('postid'), None):
+    #        history[h.get('postid')] = []
+    #
+    #    history[h.get('postid')].append(h)
 
-        history[h.get('postid')].append(h)
+    #posts = readTable(dump, "Posts")
 
-    posts = readTable(dump, "Posts")
-
-    for sxpost in posts:
+    def callback(sxpost):
         nodetype = (sxpost.get('posttypeid') == '1') and "nodetype" or "answer"
 
         post = orm.Node(
@@ -350,24 +410,30 @@ def postimport(dump, uidmap, tagmap):
 
             post.extra_count = sxpost.get('viewcount', 0)
 
+            add_tags_to_post(post, tagmap)
+
         else:
             post.parent_id = sxpost['parentid']
 
         post.save()
 
-        all[int(post.id)] = post
+        all.append(int(post.id))
+
+        del post
+
+    readTable(dump, "Posts", callback)
 
     return all
 
 def comment_import(dump, uidmap, posts):
-    comments = readTable(dump, "PostComments")
-    currid = max(posts.keys())
+#comments = readTable(dump, "PostComments")
+    currid = IdIncrementer(max(posts))
     mapping = {}
 
-    for sxc in comments:
-        currid += 1
+    def callback(sxc):
+        currid.inc()
         oc = orm.Node(
-                id = currid,
+                id = currid.value,
                 node_type = "comment",
                 added_at = readTime(sxc['creationdate']),
                 author_id = uidmap[sxc.get('userid', 1)],
@@ -403,20 +469,18 @@ def comment_import(dump, uidmap, posts):
         create_action.save()
         oc.save()
 
-        posts[oc.id] = oc
+        posts.append(int(oc.id))
         mapping[int(sxc['id'])] = int(oc.id)
 
+    readTable(dump, "PostComments", callback)
     return posts, mapping
 
 
-def add_tags_to_posts(posts, tagmap):
-    for post in posts.values():
-        if post.node_type == "question":
-            tags = [tag for tag in [tagmap.get(name.strip()) for name in post.tagnames.split(u' ') if name] if tag]
-            post.tagnames = " ".join([t.name for t in tags]).strip()
-            post.tags = tags
-
-        create_and_activate_revision(post)
+def add_tags_to_post(post, tagmap):
+    tags = [tag for tag in [tagmap.get(name.strip()) for name in post.tagnames.split(u' ') if name] if tag]
+    post.tagnames = " ".join([t.name for t in tags]).strip()
+    post.tags = tags
+    create_and_activate_revision(post)
 
 
 def create_and_activate_revision(post):
@@ -436,24 +500,29 @@ def create_and_activate_revision(post):
     post.save()
 
 def post_vote_import(dump, uidmap, posts):
-    votes = readTable(dump, "Posts2Votes")
-    close_reasons = dict([(r['id'], r['name']) for r in readTable(dump, "CloseReasons")])
+#votes = readTable(dump, "Posts2Votes")
+    close_reasons = {}
+
+    def close_callback(r):
+        close_reasons[r['id']] = r['name']
+
+    readTable(dump, "CloseReasons", close_callback)
 
     user2vote = []
 
-    for sxv in votes:
+    def callback(sxv):
         action = orm.Action(
                 user_id=uidmap[sxv['userid']],
                 action_date = readTime(sxv['creationdate']),
                 )
 
-        node = posts.get(int(sxv['postid']), None)
-        if not node: continue
+        if not int(sxv['postid']) in posts: return
+        node = orm.Node.objects.get(id=sxv['postid'])
         action.node = node
 
         if sxv['votetypeid'] == '1':
             answer = node
-            question = posts.get(int(answer.parent_id), None)
+            question = orm.Node.objects.get(id=answer.parent_id)
 
             action.action_type = "acceptanswer"
             action.save()
@@ -557,12 +626,15 @@ def post_vote_import(dump, uidmap, posts):
             state = {"acceptanswer": "accepted", "delete": "deleted", "close": "closed"}[action.action_type]
             add_post_state(state, node, action)
 
+    readTable(dump, "Posts2Votes", callback)
 
-def comment_vote_import(dump, uidmap, comments, posts):
-    votes = readTable(dump, "Comments2Votes")
+
+def comment_vote_import(dump, uidmap, comments):
+#votes = readTable(dump, "Comments2Votes")
     user2vote = []
+    comments2score = {}
 
-    for sxv in votes:
+    def callback(sxv):
         if sxv['votetypeid'] == "2":
             comment_id = comments[int(sxv['postcommentid'])]
             user_id = uidmap[sxv['userid']]
@@ -588,14 +660,28 @@ def comment_vote_import(dump, uidmap, comments, posts):
 
                 ov.save()
 
-                posts[int(action.node_id)].score += 1
-                posts[int(action.node_id)].save()
+                if not comment_id in comments2score:
+                    comments2score[comment_id] = 1
+                else:
+                    comments2score[comment_id] += 1
+
+    readTable(dump, "Comments2Votes", callback)
+
+    for cid, score in comments2score.items():
+        orm.Node.objects.filter(id=cid).update(score=score)
 
 
 def badges_import(dump, uidmap, post_list):
-    node_ctype = orm['contenttypes.contenttype'].objects.get(name='node')
+#node_ctype = orm['contenttypes.contenttype'].objects.get(name='node')
+
+    sxbadges = {}
+
+    def sxcallback(b):
+        sxbadges[int(b['id'])] = b
+
+    readTable(dump, "Badges", sxcallback)
+
     obadges = dict([(b.cls, b) for b in orm.Badge.objects.all()])
-    sxbadges = dict([(int(b['id']), b) for b in readTable(dump, "Badges")])
     user_badge_count = {}
 
     sx_to_osqa = {}
@@ -614,10 +700,9 @@ def badges_import(dump, uidmap, post_list):
             osqab.save()
             sx_to_osqa[id] = osqab
 
-    sxawards = readTable(dump, "Users2Badges")
     osqaawards = []
 
-    for sxa in sxawards:
+    def callback(sxa):
         badge = sx_to_osqa[int(sxa['badgeid'])]
 
         user_id = uidmap[sxa['userid']]
@@ -635,7 +720,7 @@ def badges_import(dump, uidmap, post_list):
         osqaa = orm.Award(
                 user_id = uidmap[sxa['userid']],
                 badge = badge,
-                node = post_list[user_badge_count[user_id]],
+                node_id = post_list[user_badge_count[user_id]],
                 awarded_at = action.action_date,
                 action = action
                 )
@@ -644,15 +729,20 @@ def badges_import(dump, uidmap, post_list):
         badge.awarded_count += 1
         user_badge_count[user_id] += 1
 
+    readTable(dump, "Users2Badges", callback)
+
     for badge in obadges.values():
         badge.save()
 
-def pages_import(dump):
+def pages_import(dump, currid):
+    currid = IdIncrementer(currid)
     registry = {}
-    sx_pages = readTable(dump, "FlatPages")
+    #sx_pages = readTable(dump, "FlatPages")
 
-    for sxp in sx_pages:
+    def callback(sxp):
+        currid.inc()
         page = orm.Node(
+                id = currid.value,
                 node_type = "page",
                 title = sxp['name'],
                 body = b64decode(sxp['value']),
@@ -690,6 +780,8 @@ def pages_import(dump):
             pub_action.save()
             add_post_state("published", page, pub_action)
 
+    readTable(dump, "FlatPages", callback)
+
     kv = orm.KeyValue(key='STATIC_PAGE_REGISTRY', value=dbsafe_encode(registry))
     kv.save()
 
@@ -721,10 +813,10 @@ def html_decode(html):
 
 
 def static_import(dump):
-    sx_sets = readTable(dump, "ThemeTextResources")
+#sx_sets = readTable(dump, "ThemeTextResources")
     sx_unknown = {}
 
-    for set in sx_sets:
+    def callback(set):
         if unicode(set['name']) in sx2osqa_set_map:
             kv = orm.KeyValue(
                     key = sx2osqa_set_map[set['name']],
@@ -735,9 +827,24 @@ def static_import(dump):
         else:
             sx_unknown[set['name']] = html_decode(set['value'])
 
+    readTable(dump, "ThemeTextResources", callback)
+
     unknown = orm.KeyValue(key='SXIMPORT_UNKNOWN_SETS', value=dbsafe_encode(sx_unknown))
     unknown.save()
 
+def disable_triggers():
+    from south.db import db
+    if db.backend_name == "postgres":
+        db.execute_many(PG_DISABLE_TRIGGERS)
+        db.commit_transaction()
+        db.start_transaction()
+
+def enable_triggers():
+    from south.db import db
+    if db.backend_name == "postgres":
+        db.start_transaction()
+        db.execute_many(PG_ENABLE_TRIGGERS)
+        db.commit_transaction()
 
 def reset_sequences():
     from south.db import db
@@ -746,24 +853,89 @@ def reset_sequences():
         db.execute_many(PG_SEQUENCE_RESETS)
         db.commit_transaction()
 
+
 def sximport(dump, options):
-    uidmap, merged_users = userimport(dump, options)
+    disable_triggers()
+    uidmap = userimport(dump, options)
     tagmap = tagsimport(dump, uidmap)
+    gc.collect()
+
     posts = postimport(dump, uidmap, tagmap)
+    gc.collect()
+
     posts, comments = comment_import(dump, uidmap, posts)
-    add_tags_to_posts(posts, tagmap)
+    gc.collect()
+
     post_vote_import(dump, uidmap, posts)
-    comment_vote_import(dump, uidmap, comments, posts)
-    badges_import(dump, uidmap, posts.values())
+    gc.collect()
 
-    pages_import(dump)
+    comment_vote_import(dump, uidmap, comments)
+    gc.collect()
+
+    badges_import(dump, uidmap, posts)
+
+    pages_import(dump, max(posts))
     static_import(dump)
+    gc.collect()
 
     from south.db import db
     db.commit_transaction()
 
     reset_sequences()
+    enable_triggers()
+
+
+PG_DISABLE_TRIGGERS = """
+ALTER table auth_user DISABLE TRIGGER ALL;
+ALTER table auth_user_groups DISABLE TRIGGER ALL;
+ALTER table auth_user_user_permissions DISABLE TRIGGER ALL;
+ALTER table forum_keyvalue DISABLE TRIGGER ALL;
+ALTER table forum_action DISABLE TRIGGER ALL;
+ALTER table forum_actionrepute DISABLE TRIGGER ALL;
+ALTER table forum_subscriptionsettings DISABLE TRIGGER ALL;
+ALTER table forum_validationhash DISABLE TRIGGER ALL;
+ALTER table forum_authkeyuserassociation DISABLE TRIGGER ALL;
+ALTER table forum_tag DISABLE TRIGGER ALL;
+ALTER table forum_markedtag DISABLE TRIGGER ALL;
+ALTER table forum_node DISABLE TRIGGER ALL;
+ALTER table forum_nodestate DISABLE TRIGGER ALL;
+ALTER table forum_node_tags DISABLE TRIGGER ALL;
+ALTER table forum_noderevision DISABLE TRIGGER ALL;
+ALTER table forum_node_tags DISABLE TRIGGER ALL;
+ALTER table forum_questionsubscription DISABLE TRIGGER ALL;
+ALTER table forum_vote DISABLE TRIGGER ALL;
+ALTER table forum_flag DISABLE TRIGGER ALL;
+ALTER table forum_badge DISABLE TRIGGER ALL;
+ALTER table forum_award DISABLE TRIGGER ALL;
+ALTER table forum_openidnonce DISABLE TRIGGER ALL;
+ALTER table forum_openidassociation DISABLE TRIGGER ALL;
+"""
 
+PG_ENABLE_TRIGGERS = """
+ALTER table auth_user ENABLE TRIGGER ALL;
+ALTER table auth_user_groups ENABLE TRIGGER ALL;
+ALTER table auth_user_user_permissions ENABLE TRIGGER ALL;
+ALTER table forum_keyvalue ENABLE TRIGGER ALL;
+ALTER table forum_action ENABLE TRIGGER ALL;
+ALTER table forum_actionrepute ENABLE TRIGGER ALL;
+ALTER table forum_subscriptionsettings ENABLE TRIGGER ALL;
+ALTER table forum_validationhash ENABLE TRIGGER ALL;
+ALTER table forum_authkeyuserassociation ENABLE TRIGGER ALL;
+ALTER table forum_tag ENABLE TRIGGER ALL;
+ALTER table forum_markedtag ENABLE TRIGGER ALL;
+ALTER table forum_node ENABLE TRIGGER ALL;
+ALTER table forum_nodestate ENABLE TRIGGER ALL;
+ALTER table forum_node_tags ENABLE TRIGGER ALL;
+ALTER table forum_noderevision ENABLE TRIGGER ALL;
+ALTER table forum_node_tags ENABLE TRIGGER ALL;
+ALTER table forum_questionsubscription ENABLE TRIGGER ALL;
+ALTER table forum_vote ENABLE TRIGGER ALL;
+ALTER table forum_flag ENABLE TRIGGER ALL;
+ALTER table forum_badge ENABLE TRIGGER ALL;
+ALTER table forum_award ENABLE TRIGGER ALL;
+ALTER table forum_openidnonce ENABLE TRIGGER ALL;
+ALTER table forum_openidassociation ENABLE TRIGGER ALL;
+"""
 
 PG_SEQUENCE_RESETS = """
 SELECT setval('"auth_user_id_seq"', coalesce(max("id"), 1) + 2, max("id") IS NOT null) FROM "auth_user";
@@ -783,8 +955,6 @@ SELECT setval('"forum_node_tags_id_seq"', coalesce(max("id"), 1) + 2, max("id")
 SELECT setval('"forum_noderevision_id_seq"', coalesce(max("id"), 1) + 2, max("id") IS NOT null) FROM "forum_noderevision";
 SELECT setval('"forum_node_tags_id_seq"', coalesce(max("id"), 1) + 2, max("id") IS NOT null) FROM "forum_node_tags";
 SELECT setval('"forum_questionsubscription_id_seq"', coalesce(max("id"), 1) + 2, max("id") IS NOT null) FROM "forum_questionsubscription";
-SELECT setval('"forum_node_tags_id_seq"', coalesce(max("id"), 1) + 2, max("id") IS NOT null) FROM "forum_node_tags";
-SELECT setval('"forum_node_tags_id_seq"', coalesce(max("id"), 1) + 2, max("id") IS NOT null) FROM "forum_node_tags";
 SELECT setval('"forum_vote_id_seq"', coalesce(max("id"), 1) + 2, max("id") IS NOT null) FROM "forum_vote";
 SELECT setval('"forum_flag_id_seq"', coalesce(max("id"), 1) + 2, max("id") IS NOT null) FROM "forum_flag";
 SELECT setval('"forum_badge_id_seq"', coalesce(max("id"), 1) + 2, max("id") IS NOT null) FROM "forum_badge";
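diff --git a/forum_modules/sximporter/views.py b/forum_modules/sximporter/views.py
index fb0bcd1bbf3a4e16bb8739a58e1b77eb3c085197..76bc8e4be99087b82c1f1f035e7ced07f6a06f83 100644
--- a/forum_modules/sximporter/views.py
+++ b/forum_modules/sximporter/views.py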
@@ -3,15 +3,19 @@ from django.template import RequestContext
 from forum.views.admin import super_user_required\r
 import importer\r
 from zipfile import ZipFile\r
+import os\r
 \r
 @super_user_required\r
 def sximporter(request):\r
     list = []\r
     if request.method == "POST" and "dump" in request.FILES:\r
         dump = ZipFile(request.FILES['dump'])\r
-        importer.sximport(dump, request.POST)\r
+        members = [f for f in dump.namelist() if f.endswith('.xml')]\r
+        extract_to = os.path.join(os.path.dirname(__file__), 'tmp')\r
+        dump.extractall(extract_to, members)\r
+        importer.sximport(extract_to, request.POST)\r
         dump.close()\r
 \r
     return render_to_response('modules/sximporter/page.html', {\r
-        'names': list\r
+    'names': list\r
     }, context_instance=RequestContext(request))
\ No newline at end of file