]> git.openstreetmap.org Git - nominatim.git/blobdiff - tests/steps/db_results.py
properly close connection in test
[nominatim.git] / tests / steps / db_results.py
index f65e992462366373143519d03e1adeb7bf2a11ec..2566418e39fdcd0c0fc4b8880dfe5a568bd90855 100644 (file)
@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
 
 @step(u'table placex contains as names for (N|R|W)(\d+)')
 def check_placex_names(step, osmtyp, osmid):
-    """ Check for the exact content of the name hstaore in placex.
+    """ Check for the exact content of the name hstore in placex.
     """
     cur = world.conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
     cur.execute('SELECT name FROM placex where osm_type = %s and osm_id =%s', (osmtyp, int(osmid)))
@@ -43,47 +43,55 @@ def check_placex_content(step, tablename):
         given columns are tested. If there is more than one
         line for an OSM object, they must match in these columns.
     """
-    cur = world.conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
-    for line in step.hashes:
-        osmtype, osmid, cls = world.split_id(line['object'])
-        q = 'SELECT *'
-        if tablename == 'placex':
-            q = q + ", ST_X(centroid) as clat, ST_Y(centroid) as clon"
-        q = q + ", ST_GeometryType(geometry) as geometrytype"
-        q = q + ' FROM %s where osm_type = %%s and osm_id = %%s' % (tablename,)
-        if cls is None:
-            params = (osmtype, osmid)
-        else:
-            q = q + ' and class = %s'
-            params = (osmtype, osmid, cls)
-        cur.execute(q, params)
-        assert(cur.rowcount > 0)
-        for res in cur:
-            for k,v in line.iteritems():
-                if not k == 'object':
-                    assert_in(k, res)
-                    if type(res[k]) is dict:
-                        val = world.make_hash(v)
-                        assert_equals(res[k], val)
-                    elif k in ('parent_place_id', 'linked_place_id'):
-                        pid = world.get_placeid(v)
-                        assert_equals(pid, res[k], "Results for '%s'/'%s' differ: '%s' != '%s'" % (line['object'], k, pid, res[k]))
-                    elif k == 'centroid':
-                        world.match_geometry((res['clat'], res['clon']), v)
-                    else:
-                        assert_equals(str(res[k]), v, "Results for '%s'/'%s' differ: '%s' != '%s'" % (line['object'], k, str(res[k]), v))
+    try:
+        cur = world.conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
+        for line in step.hashes:
+            osmtype, osmid, cls = world.split_id(line['object'])
+            q = 'SELECT *'
+            if tablename == 'placex':
+                q = q + ", ST_X(centroid) as clat, ST_Y(centroid) as clon"
+            q = q + ", ST_GeometryType(geometry) as geometrytype"
+            q = q + ' FROM %s where osm_type = %%s and osm_id = %%s' % (tablename,)
+            if cls is None:
+                params = (osmtype, osmid)
+            else:
+                q = q + ' and class = %s'
+                params = (osmtype, osmid, cls)
+            cur.execute(q, params)
+            assert(cur.rowcount > 0)
+            for res in cur:
+                for k,v in line.iteritems():
+                    if not k == 'object':
+                        assert_in(k, res)
+                        if type(res[k]) is dict:
+                            val = world.make_hash(v)
+                            assert_equals(res[k], val)
+                        elif k in ('parent_place_id', 'linked_place_id'):
+                            pid = world.get_placeid(v)
+                            assert_equals(pid, res[k], "Results for '%s'/'%s' differ: '%s' != '%s'" % (line['object'], k, pid, res[k]))
+                        elif k == 'centroid':
+                            world.match_geometry((res['clat'], res['clon']), v)
+                        else:
+                            assert_equals(str(res[k]), v, "Results for '%s'/'%s' differ: '%s' != '%s'" % (line['object'], k, str(res[k]), v))
+    finally:
+        cur.close()
+        world.conn.commit()
 
 @step(u'table (placex?) has no entry for (N|R|W)(\d+)(:\w+)?')
 def check_placex_missing(step, tablename, osmtyp, osmid, placeclass):
     cur = world.conn.cursor()
-    q = 'SELECT count(*) FROM %s where osm_type = %%s and osm_id = %%s' % (tablename, )
-    args = [osmtyp, int(osmid)]
-    if placeclass is not None:
-        q = q + ' and class = %s'
-        args.append(placeclass[1:])
-    cur.execute(q, args)
-    numres = cur.fetchone()[0]
-    assert_equals (numres, 0)
+    try:
+        q = 'SELECT count(*) FROM %s where osm_type = %%s and osm_id = %%s' % (tablename, )
+        args = [osmtyp, int(osmid)]
+        if placeclass is not None:
+            q = q + ' and class = %s'
+            args.append(placeclass[1:])
+        cur.execute(q, args)
+        numres = cur.fetchone()[0]
+        assert_equals (numres, 0)
+    finally:
+        cur.close()
+        world.conn.commit()
 
 @step(u'search_name table contains$')
 def check_search_name_content(step):