git.openstreetmap.org Git - nominatim.git/commitdiff
Merge pull request #2129 from lonvia/cleanup-bdd-tests
author Sarah Hoffmann <lonvia@denofr.de>
Thu, 7 Jan 2021 08:10:40 +0000 (09:10 +0100)
committer GitHub <noreply@github.com>
Thu, 7 Jan 2021 08:10:40 +0000 (09:10 +0100)
Clean up Python support code for BDD tests

test/bdd/steps/db_ops.py [deleted file]
test/bdd/steps/geometry_factory.py [new file with mode: 0644]
test/bdd/steps/http_responses.py [new file with mode: 0644]
test/bdd/steps/nominatim_environment.py [new file with mode: 0644]
test/bdd/steps/place_inserter.py [new file with mode: 0644]
test/bdd/steps/steps_api_queries.py [moved from test/bdd/steps/queries.py with 56% similarity]
test/bdd/steps/steps_db_ops.py [new file with mode: 0644]
test/bdd/steps/steps_osm_data.py [moved from test/bdd/steps/osm_data.py with 53% similarity]
test/bdd/steps/table_compare.py [new file with mode: 0644]
test/bdd/steps/utils.py [new file with mode: 0644]

index 0ac92104cd5792975bbca2678bc7b8eb8b86afcf..aeee2301bb6c063d8efcc7b73d01d698c99a541f 100644 (file)
@@ -1,16 +1,11 @@
 from behave import *
-import logging
-import os
-import psycopg2
-import psycopg2.extras
-import subprocess
-import tempfile
-from sys import version_info as python_version
+from pathlib import Path
-logger = logging.getLogger(__name__)
+from steps.geometry_factory import GeometryFactory
+from steps.nominatim_environment import NominatimEnvironment
 userconfig = {
-    'BUILDDIR' : os.path.join(os.path.split(__file__)[0], "../../build"),
+    'BUILDDIR' : (Path(__file__) / '..' / '..' / '..' / 'build').resolve(),
     'REMOVE_TEMPLATE' : False,
     'KEEP_TEST_DB' : False,
     'DB_HOST' : None,
@@ -26,290 +21,24 @@ userconfig = {
-class NominatimEnvironment(object):
-    """ Collects all functions for the execution of Nominatim functions.
-    """
-    def __init__(self, config):
-        self.build_dir = os.path.abspath(config['BUILDDIR'])
-        self.src_dir = os.path.abspath(os.path.join(os.path.split(__file__)[0], "../.."))
-        self.db_host = config['DB_HOST']
-        self.db_port = config['DB_PORT']
-        self.db_user = config['DB_USER']
-        self.db_pass = config['DB_PASS']
-        self.template_db = config['TEMPLATE_DB']
-        self.test_db = config['TEST_DB']
-        self.api_test_db = config['API_TEST_DB']
-        self.server_module_path = config['SERVER_MODULE_PATH']
-        self.reuse_template = not config['REMOVE_TEMPLATE']
-        self.keep_scenario_db = config['KEEP_TEST_DB']
-        self.code_coverage_path = config['PHPCOV']
-        self.code_coverage_id = 1
-        self.test_env = None
-        self.template_db_done = False
-        self.website_dir = None
-    def connect_database(self, dbname):
-        dbargs = {'database': dbname}
-        if self.db_host:
-            dbargs['host'] = self.db_host
-        if self.db_port:
-            dbargs['port'] = self.db_port
-        if self.db_user:
-            dbargs['user'] = self.db_user
-        if self.db_pass:
-            dbargs['password'] = self.db_pass
-        conn = psycopg2.connect(**dbargs)
-        return conn
-    def next_code_coverage_file(self):
-        fn = os.path.join(self.code_coverage_path, "%06d.cov" % self.code_coverage_id)
-        self.code_coverage_id += 1
-        return fn
-    def write_nominatim_config(self, dbname):
-        dsn = 'pgsql:dbname={}{}{}{}{}'.format(
-                dbname,
-                 (';host=' + self.db_host) if self.db_host else '',
-                 (';port=' + self.db_port) if self.db_port else '',
-                 (';user=' + self.db_user) if self.db_user else '',
-                 (';password=' + self.db_pass) if self.db_pass else ''
-                 )
-        if self.website_dir is not None \
-           and self.test_env is not None \
-           and dsn == self.test_env['NOMINATIM_DATABASE_DSN']:
-            return # environment already set uo
-        self.test_env = os.environ
-        self.test_env['NOMINATIM_DATABASE_DSN'] = dsn
-        self.test_env['NOMINATIM_FLATNODE_FILE'] = ''
-        self.test_env['NOMINATIM_IMPORT_STYLE'] = 'full'
-        self.test_env['NOMINATIM_USE_US_TIGER_DATA'] = 'yes'
-        if self.website_dir is not None:
-            self.website_dir.cleanup()
-        self.website_dir = tempfile.TemporaryDirectory()
-        self.run_setup_script('setup-website')
-    def db_drop_database(self, name):
-        conn = self.connect_database('postgres')
-        conn.set_isolation_level(0)
-        cur = conn.cursor()
-        cur.execute('DROP DATABASE IF EXISTS %s' % (name, ))
-        conn.close()
-    def setup_template_db(self):
-        if self.template_db_done:
-            return
-        self.template_db_done = True
-        if self.reuse_template:
-            # check that the template is there
-            conn = self.connect_database('postgres')
-            cur = conn.cursor()
-            cur.execute('select count(*) from pg_database where datname = %s',
-                        (self.template_db,))
-            if cur.fetchone()[0] == 1:
-                return
-            conn.close()
-        else:
-            # just in case... make sure a previous table has been dropped
-            self.db_drop_database(self.template_db)
-        try:
-            # call the first part of database setup
-            self.write_nominatim_config(self.template_db)
-            self.run_setup_script('create-db', 'setup-db')
-            # remove external data to speed up indexing for tests
-            conn = self.connect_database(self.template_db)
-            cur = conn.cursor()
-            cur.execute("""select tablename from pg_tables
-                           where tablename in ('gb_postcode', 'us_postcode')""")
-            for t in cur:
-                conn.cursor().execute('TRUNCATE TABLE %s' % (t[0],))
-            conn.commit()
-            conn.close()
-            # execute osm2pgsql import on an empty file to get the right tables
-            with tempfile.NamedTemporaryFile(dir='/tmp', suffix='.xml') as fd:
-                fd.write(b'<osm version="0.6"></osm>')
-                fd.flush()
-                self.run_setup_script('import-data',
-                                      'ignore-errors',
-                                      'create-functions',
-                                      'create-tables',
-                                      'create-partition-tables',
-                                      'create-partition-functions',
-                                      'load-data',
-                                      'create-search-indices',
-                                      osm_file=fd.name,
-                                      osm2pgsql_cache='200')
-        except:
-            self.db_drop_database(self.template_db)
-            raise
-    def setup_api_db(self, context):
-        self.write_nominatim_config(self.api_test_db)
-    def setup_unknown_db(self, context):
-        self.write_nominatim_config('UNKNOWN_DATABASE_NAME')
-    def setup_db(self, context):
-        self.setup_template_db()
-        self.write_nominatim_config(self.test_db)
-        conn = self.connect_database(self.template_db)
-        conn.set_isolation_level(0)
-        cur = conn.cursor()
-        cur.execute('DROP DATABASE IF EXISTS %s' % (self.test_db, ))
-        cur.execute('CREATE DATABASE %s TEMPLATE = %s' % (self.test_db, self.template_db))
-        conn.close()
-        context.db = self.connect_database(self.test_db)
-        if python_version[0] < 3:
-            psycopg2.extras.register_hstore(context.db, globally=False, unicode=True)
-        else:
-            psycopg2.extras.register_hstore(context.db, globally=False)
-    def teardown_db(self, context):
-        if 'db' in context:
-            context.db.close()
-        if not self.keep_scenario_db:
-            self.db_drop_database(self.test_db)
-    def run_setup_script(self, *args, **kwargs):
-        if self.server_module_path:
-            kwargs = dict(kwargs)
-            kwargs['module_path'] = self.server_module_path
-        self.run_nominatim_script('setup', *args, **kwargs)
-    def run_update_script(self, *args, **kwargs):
-        self.run_nominatim_script('update', *args, **kwargs)
-    def run_nominatim_script(self, script, *args, **kwargs):
-        cmd = ['/usr/bin/env', 'php', '-Cq']
-        cmd.append(os.path.join(self.build_dir, 'utils', '%s.php' % script))
-        cmd.extend(['--%s' % x for x in args])
-        for k, v in kwargs.items():
-            cmd.extend(('--' + k.replace('_', '-'), str(v)))
-        if self.website_dir is not None:
-            cwd = self.website_dir.name
-        else:
-            cwd = self.build_dir
-        proc = subprocess.Popen(cmd, cwd=cwd, env=self.test_env,
-                                stdout=subprocess.PIPE, stderr=subprocess.PIPE)
-        (outp, outerr) = proc.communicate()
-        outerr = outerr.decode('utf-8').replace('\\n', '\n')
-        logger.debug("run_nominatim_script: %s\n%s\n%s" % (cmd, outp, outerr))
-        assert (proc.returncode == 0), "Script '%s' failed:\n%s\n%s\n" % (script, outp, outerr)
-class OSMDataFactory(object):
-    def __init__(self):
-        scriptpath = os.path.dirname(os.path.abspath(__file__))
-        self.scene_path = os.environ.get('SCENE_PATH',
-                           os.path.join(scriptpath, '..', 'scenes', 'data'))
-        self.scene_cache = {}
-        self.clear_grid()
-    def parse_geometry(self, geom, scene):
-        if geom.find(':') >= 0:
-            return "ST_SetSRID(%s, 4326)" % self.get_scene_geometry(scene, geom)
-        if geom.find(',') < 0:
-            out = "POINT(%s)" % self.mk_wkt_point(geom)
-        elif geom.find('(') < 0:
-            line = ','.join([self.mk_wkt_point(x) for x in geom.split(',')])
-            out = "LINESTRING(%s)" % line
-        else:
-            inner = geom.strip('() ')
-            line = ','.join([self.mk_wkt_point(x) for x in inner.split(',')])
-            out = "POLYGON((%s))" % line
-        return "ST_SetSRID('%s'::geometry, 4326)" % out
-    def mk_wkt_point(self, point):
-        geom = point.strip()
-        if geom.find(' ') >= 0:
-            return geom
-        else:
-            pt = self.grid_node(int(geom))
-            assert pt is not None, "Bad scenario: Point '{}' not found in grid".format(geom)
-            return "%f %f" % pt
-    def get_scene_geometry(self, default_scene, name):
-        geoms = []
-        for obj in name.split('+'):
-            oname = obj.strip()
-            if oname.startswith(':'):
-                assert default_scene is not None, "Bad scenario: You need to set a scene"
-                defscene = self.load_scene(default_scene)
-                wkt = defscene[oname[1:]]
-            else:
-                scene, obj = oname.split(':', 2)
-                scene_geoms = self.load_scene(scene)
-                wkt = scene_geoms[obj]
-            geoms.append("'%s'::geometry" % wkt)
-        if len(geoms) == 1:
-            return geoms[0]
-        else:
-            return 'ST_LineMerge(ST_Collect(ARRAY[%s]))' % ','.join(geoms)
-    def load_scene(self, name):
-        if name in self.scene_cache:
-            return self.scene_cache[name]
-        scene = {}
-        with open(os.path.join(self.scene_path, "%s.wkt" % name), 'r') as fd:
-            for line in fd:
-                if line.strip():
-                    obj, wkt = line.split('|', 2)
-                    scene[obj.strip()] = wkt.strip()
-            self.scene_cache[name] = scene
-        return scene
-    def clear_grid(self):
-        self.grid = {}
-    def add_grid_node(self, nodeid, x, y):
-        self.grid[nodeid] = (x, y)
-    def grid_node(self, nodeid):
-        return self.grid.get(nodeid)
 def before_all(context):
     # logging setup
     # set up -D options
     for k,v in userconfig.items():
         context.config.userdata.setdefault(k, v)
-    logging.debug('User config: %s' %(str(context.config.userdata)))
     # Nominatim test setup
     context.nominatim = NominatimEnvironment(context.config.userdata)
-    context.osm = OSMDataFactory()
+    context.osm = GeometryFactory()
 def before_scenario(context, scenario):
     if 'DB' in context.tags:
     elif 'APIDB' in context.tags:
-        context.nominatim.setup_api_db(context)
+        context.nominatim.setup_api_db()
     elif 'UNKNOWNDB' in context.tags:
-        context.nominatim.setup_unknown_db(context)
+        context.nominatim.setup_unknown_db()
     context.scene = None
 def after_scenario(context, scenario):
diff --git a/test/bdd/steps/db_ops.py b/test/bdd/steps/db_ops.py
deleted file mode 100644 (file)
index 078e29f..0000000
+++ /dev/null
@@ -1,623 +0,0 @@
-import base64
-import random
-import string
-import re
-import psycopg2.extras
-from check_functions import Almost
-class PlaceColumn:
-    def __init__(self, context, force_name):
-        self.columns = { 'admin_level' : 15}
-        self.force_name = force_name
-        self.context = context
-        self.geometry = None
-    def add(self, key, value):
-        if hasattr(self, 'set_key_' + key):
-            getattr(self, 'set_key_' + key)(value)
-        elif key.startswith('name+'):
-            self.add_hstore('name', key[5:], value)
-        elif key.startswith('extra+'):
-            self.add_hstore('extratags', key[6:], value)
-        elif key.startswith('addr+'):
-            self.add_hstore('address', key[5:], value)
-        elif key in ('name', 'address', 'extratags'):
-            self.columns[key] = eval('{' + value + '}')
-        else:
-            assert key in ('class', 'type')
-            self.columns[key] = None if value == '' else value
-    def set_key_name(self, value):
-        self.add_hstore('name', 'name', value)
-    def set_key_osm(self, value):
-        assert value[0] in 'NRW'
-        assert value[1:].isdigit()
-        self.columns['osm_type'] = value[0]
-        self.columns['osm_id'] = int(value[1:])
-    def set_key_admin(self, value):
-        self.columns['admin_level'] = int(value)
-    def set_key_housenr(self, value):
-        if value:
-            self.add_hstore('address', 'housenumber', value)
-    def set_key_postcode(self, value):
-        if value:
-            self.add_hstore('address', 'postcode', value)
-    def set_key_street(self, value):
-        if value:
-            self.add_hstore('address', 'street', value)
-    def set_key_addr_place(self, value):
-        if value:
-            self.add_hstore('address', 'place', value)
-    def set_key_country(self, value):
-        if value:
-            self.add_hstore('address', 'country', value)
-    def set_key_geometry(self, value):
-        self.geometry = self.context.osm.parse_geometry(value, self.context.scene)
-        assert self.geometry is not None
-    def add_hstore(self, column, key, value):
-        if column in self.columns:
-            self.columns[column][key] = value
-        else:
-            self.columns[column] = { key : value }
-    def db_insert(self, cursor):
-        assert 'osm_type' in self.columns
-        if self.force_name and 'name' not in self.columns:
-            self.add_hstore('name', 'name', ''.join(random.choice(string.printable)
-                                           for _ in range(int(random.random()*30))))
-        if self.columns['osm_type'] == 'N' and self.geometry is None:
-            pt = self.context.osm.grid_node(self.columns['osm_id'])
-            if pt is None:
-                pt = (random.random()*360 - 180, random.random()*180 - 90)
-            self.geometry = "ST_SetSRID(ST_Point(%f, %f), 4326)" % pt
-        else:
-            assert self.geometry is not None, "Geometry missing"
-        query = 'INSERT INTO place (%s, geometry) values(%s, %s)' % (
-                     ','.join(self.columns.keys()),
-                     ','.join(['%s' for x in range(len(self.columns))]),
-                     self.geometry)
-        cursor.execute(query, list(self.columns.values()))
-class LazyFmt(object):
-    def __init__(self, fmtstr, *args):
-        self.fmt = fmtstr
-        self.args = args
-    def __str__(self):
-        return self.fmt % self.args
-class PlaceObjName(object):
-    def __init__(self, placeid, conn):
-        self.pid = placeid
-        self.conn = conn
-    def __str__(self):
-        if self.pid is None:
-            return "<null>"
-        if self.pid == 0:
-            return "place ID 0"
-        cur = self.conn.cursor()
-        cur.execute("""SELECT osm_type, osm_id, class
-                       FROM placex WHERE place_id = %s""",
-                    (self.pid, ))
-        assert cur.rowcount == 1, "No entry found for place id %s" % self.pid
-        return "%s%s:%s" % cur.fetchone()
-def compare_place_id(expected, result, column, context):
-    if expected == '0':
-        assert result == 0, \
-            LazyFmt("Bad place id in column %s. Expected: 0, got: %s.",
-                    column, PlaceObjName(result, context.db))
-    elif expected == '-':
-        assert result is None, \
-                LazyFmt("bad place id in column %s: %s.",
-                        column, PlaceObjName(result, context.db))
-    else:
-        assert NominatimID(expected).get_place_id(context.db.cursor()) == result, \
-            LazyFmt("Bad place id in column %s. Expected: %s, got: %s.",
-                    column, expected, PlaceObjName(result, context.db))
-def check_database_integrity(context):
-    """ Check some generic constraints on the tables.
-    """
-    # place_addressline should not have duplicate (place_id, address_place_id)
-    cur = context.db.cursor()
-    cur.execute("""SELECT count(*) FROM
-                    (SELECT place_id, address_place_id, count(*) as c
-                     FROM place_addressline GROUP BY place_id, address_place_id) x
-                   WHERE c > 1""")
-    assert cur.fetchone()[0] == 0, "Duplicates found in place_addressline"
-class NominatimID:
-    """ Splits a unique identifier for places into its components.
-        As place_ids cannot be used for testing, we use a unique
-        identifier instead that is of the form <osmtype><osmid>[:<class>].
-    """
-    id_regex = re.compile(r"(?P<tp>[NRW])(?P<id>\d+)(:(?P<cls>\w+))?")
-    def __init__(self, oid):
-        self.typ = self.oid = self.cls = None
-        if oid is not None:
-            m = self.id_regex.fullmatch(oid)
-            assert m is not None, "ID '%s' not of form <osmtype><osmid>[:<class>]" % oid
-            self.typ = m.group('tp')
-            self.oid = m.group('id')
-            self.cls = m.group('cls')
-    def __str__(self):
-        if self.cls is None:
-            return self.typ + self.oid
-        return '%s%d:%s' % (self.typ, self.oid, self.cls)
-    def table_select(self):
-        """ Return where clause and parameter list to select the object
-            from a Nominatim table.
-        """
-        where = 'osm_type = %s and osm_id = %s'
-        params = [self.typ, self. oid]
-        if self.cls is not None:
-            where += ' and class = %s'
-            params.append(self.cls)
-        return where, params
-    def get_place_id(self, cur):
-        where, params = self.table_select()
-        cur.execute("SELECT place_id FROM placex WHERE %s" % where, params)
-        assert cur.rowcount == 1, \
-            "Expected exactly 1 entry in placex for %s found %s" % (str(self), cur.rowcount)
-        return cur.fetchone()[0]
-def assert_db_column(row, column, value, context):
-    if column == 'object':
-        return
-    if column.startswith('centroid'):
-        if value == 'in geometry':
-            query = """SELECT ST_Within(ST_SetSRID(ST_Point({}, {}), 4326),
-                                        ST_SetSRID('{}'::geometry, 4326))""".format(
-                      row['cx'], row['cy'], row['geomtxt'])
-            cur = context.db.cursor()
-            cur.execute(query)
-            assert cur.fetchone()[0], "(Row %s failed: %s)" % (column, query)
-        else:
-            fac = float(column[9:]) if column.startswith('centroid*') else 1.0
-            x, y = value.split(' ')
-            assert Almost(float(x) * fac) == row['cx'], "Bad x coordinate"
-            assert Almost(float(y) * fac) == row['cy'], "Bad y coordinate"
-    elif column == 'geometry':
-        geom = context.osm.parse_geometry(value, context.scene)
-        cur = context.db.cursor()
-        query = "SELECT ST_Equals(ST_SnapToGrid(%s, 0.00001, 0.00001), ST_SnapToGrid(ST_SetSRID('%s'::geometry, 4326), 0.00001, 0.00001))" % (
-                 geom, row['geomtxt'],)
-        cur.execute(query)
-        assert cur.fetchone()[0], "(Row %s failed: %s)" % (column, query)
-    elif value == '-':
-        assert row[column] is None, "Row %s" % column
-    else:
-        assert value == str(row[column]), \
-            "Row '%s': expected: %s, got: %s" % (column, value, str(row[column]))
-################################ STEPS ##################################
-@given(u'the scene (?P<scene>.+)')
-def set_default_scene(context, scene):
-    context.scene = scene
-@given("the (?P<named>named )?places")
-def add_data_to_place_table(context, named):
-    cur = context.db.cursor()
-    cur.execute('ALTER TABLE place DISABLE TRIGGER place_before_insert')
-    for r in context.table:
-        col = PlaceColumn(context, named is not None)
-        for h in r.headings:
-            col.add(h, r[h])
-        col.db_insert(cur)
-    cur.execute('ALTER TABLE place ENABLE TRIGGER place_before_insert')
-    cur.close()
-    context.db.commit()
-@given("the relations")
-def add_data_to_planet_relations(context):
-    cur = context.db.cursor()
-    for r in context.table:
-        last_node = 0
-        last_way = 0
-        parts = []
-        if r['members']:
-            members = []
-            for m in r['members'].split(','):
-                mid = NominatimID(m)
-                if mid.typ == 'N':
-                    parts.insert(last_node, int(mid.oid))
-                    last_node += 1
-                    last_way += 1
-                elif mid.typ == 'W':
-                    parts.insert(last_way, int(mid.oid))
-                    last_way += 1
-                else:
-                    parts.append(int(mid.oid))
-                members.extend((mid.typ.lower() + mid.oid, mid.cls or ''))
-        else:
-            members = None
-        tags = []
-        for h in r.headings:
-            if h.startswith("tags+"):
-                tags.extend((h[5:], r[h]))
-        cur.execute("""INSERT INTO planet_osm_rels (id, way_off, rel_off, parts, members, tags)
-                       VALUES (%s, %s, %s, %s, %s, %s)""",
-                    (r['id'], last_node, last_way, parts, members, tags))
-    context.db.commit()
-@given("the ways")
-def add_data_to_planet_ways(context):
-    cur = context.db.cursor()
-    for r in context.table:
-        tags = []
-        for h in r.headings:
-            if h.startswith("tags+"):
-                tags.extend((h[5:], r[h]))
-        nodes = [ int(x.strip()) for x in r['nodes'].split(',') ]
-        cur.execute("INSERT INTO planet_osm_ways (id, nodes, tags) VALUES (%s, %s, %s)",
-                    (r['id'], nodes, tags))
-    context.db.commit()
-def import_and_index_data_from_place_table(context):
-    context.nominatim.run_setup_script('create-functions', 'create-partition-functions')
-    cur = context.db.cursor()
-    cur.execute(
-        """insert into placex (osm_type, osm_id, class, type, name, admin_level, address, extratags, geometry)
-           select              osm_type, osm_id, class, type, name, admin_level, address, extratags, geometry
-           from place where not (class='place' and type='houses' and osm_type='W')""")
-    cur.execute(
-            """insert into location_property_osmline (osm_id, address, linegeo)
-             SELECT osm_id, address, geometry from place
-              WHERE class='place' and type='houses' and osm_type='W'
-                    and ST_GeometryType(geometry) = 'ST_LineString'""")
-    context.db.commit()
-    context.nominatim.run_setup_script('calculate-postcodes', 'index', 'index-noanalyse')
-    check_database_integrity(context)
-@when("updating places")
-def update_place_table(context):
-    context.nominatim.run_setup_script(
-        'create-functions', 'create-partition-functions', 'enable-diff-updates')
-    cur = context.db.cursor()
-    for r in context.table:
-        col = PlaceColumn(context, False)
-        for h in r.headings:
-            col.add(h, r[h])
-        col.db_insert(cur)
-    context.db.commit()
-    while True:
-        context.nominatim.run_update_script('index')
-        cur = context.db.cursor()
-        cur.execute("SELECT 'a' FROM placex WHERE indexed_status != 0 LIMIT 1")
-        if cur.rowcount == 0:
-            break
-    check_database_integrity(context)
-@when("updating postcodes")
-def update_postcodes(context):
-    context.nominatim.run_update_script('calculate-postcodes')
-@when("marking for delete (?P<oids>.*)")
-def delete_places(context, oids):
-    context.nominatim.run_setup_script(
-        'create-functions', 'create-partition-functions', 'enable-diff-updates')
-    cur = context.db.cursor()
-    for oid in oids.split(','):
-        where, params = NominatimID(oid).table_select()
-        cur.execute("DELETE FROM place WHERE " + where, params)
-    context.db.commit()
-    while True:
-        context.nominatim.run_update_script('index')
-        cur = context.db.cursor()
-        cur.execute("SELECT 'a' FROM placex WHERE indexed_status != 0 LIMIT 1")
-        if cur.rowcount == 0:
-            break
-@then("placex contains(?P<exact> exactly)?")
-def check_placex_contents(context, exact):
-    cur = context.db.cursor(cursor_factory=psycopg2.extras.DictCursor)
-    expected_content = set()
-    for row in context.table:
-        nid = NominatimID(row['object'])
-        where, params = nid.table_select()
-        cur.execute("""SELECT *, ST_AsText(geometry) as geomtxt,
-                       ST_X(centroid) as cx, ST_Y(centroid) as cy
-                       FROM placex where %s""" % where,
-                    params)
-        assert cur.rowcount > 0, "No rows found for " + row['object']
-        for res in cur:
-            if exact:
-                expected_content.add((res['osm_type'], res['osm_id'], res['class']))
-            for h in row.headings:
-                if h in ('extratags', 'address'):
-                    if row[h] == '-':
-                        assert res[h] is None
-                    else:
-                        vdict = eval('{' + row[h] + '}')
-                        assert vdict == res[h]
-                elif h.startswith('name'):
-                    name = h[5:] if h.startswith('name+') else 'name'
-                    assert name in res['name']
-                    assert res['name'][name] == row[h]
-                elif h.startswith('extratags+'):
-                    assert res['extratags'][h[10:]] == row[h]
-                elif h.startswith('addr+'):
-                    if row[h] == '-':
-                        if res['address'] is not None:
-                            assert h[5:] not in res['address']
-                    else:
-                        assert h[5:] in res['address'], "column " + h
-                        assert res['address'][h[5:]] == row[h], "column %s" % h
-                elif h in ('linked_place_id', 'parent_place_id'):
-                    compare_place_id(row[h], res[h], h, context)
-                else:
-                    assert_db_column(res, h, row[h], context)
-    if exact:
-        cur.execute('SELECT osm_type, osm_id, class from placex')
-        assert expected_content == set([(r[0], r[1], r[2]) for r in cur])
-    context.db.commit()
-@then("place contains(?P<exact> exactly)?")
-def check_placex_contents(context, exact):
-    cur = context.db.cursor(cursor_factory=psycopg2.extras.DictCursor)
-    expected_content = set()
-    for row in context.table:
-        nid = NominatimID(row['object'])
-        where, params = nid.table_select()
-        cur.execute("""SELECT *, ST_AsText(geometry) as geomtxt,
-                       ST_GeometryType(geometry) as geometrytype
-                       FROM place where %s""" % where,
-                    params)
-        assert cur.rowcount > 0, "No rows found for " + row['object']
-        for res in cur:
-            if exact:
-                expected_content.add((res['osm_type'], res['osm_id'], res['class']))
-            for h in row.headings:
-                msg = "%s: %s" % (row['object'], h)
-                if h in ('name', 'extratags', 'address'):
-                    if row[h] == '-':
-                        assert res[h] is None, msg
-                    else:
-                        vdict = eval('{' + row[h] + '}')
-                        assert vdict == res[h], msg
-                elif h.startswith('name+'):
-                    assert res['name'][h[5:]] == row[h], msg
-                elif h.startswith('extratags+'):
-                    assert res['extratags'][h[10:]] == row[h], msg
-                elif h.startswith('addr+'):
-                    if row[h] == '-':
-                        if res['address']  is not None:
-                            assert h[5:] not in res['address']
-                    else:
-                        assert res['address'][h[5:]] == row[h], msg
-                elif h in ('linked_place_id', 'parent_place_id'):
-                    compare_place_id(row[h], res[h], h, context)
-                else:
-                    assert_db_column(res, h, row[h], context)
-    if exact:
-        cur.execute('SELECT osm_type, osm_id, class from place')
-        assert expected_content, set([(r[0], r[1], r[2]) for r in cur])
-    context.db.commit()
-@then("search_name contains(?P<exclude> not)?")
-def check_search_name_contents(context, exclude):
-    cur = context.db.cursor(cursor_factory=psycopg2.extras.DictCursor)
-    for row in context.table:
-        pid = NominatimID(row['object']).get_place_id(cur)
-        cur.execute("""SELECT *, ST_X(centroid) as cx, ST_Y(centroid) as cy
-                       FROM search_name WHERE place_id = %s""", (pid, ))
-        assert cur.rowcount > 0, "No rows found for " + row['object']
-        for res in cur:
-            for h in row.headings:
-                if h in ('name_vector', 'nameaddress_vector'):
-                    terms = [x.strip() for x in row[h].split(',') if not x.strip().startswith('#')]
-                    words = [x.strip()[1:] for x in row[h].split(',') if x.strip().startswith('#')]
-                    subcur = context.db.cursor()
-                    subcur.execute(""" SELECT word_id, word_token
-                                       FROM word, (SELECT unnest(%s::TEXT[]) as term) t
-                                       WHERE word_token = make_standard_name(t.term)
-                                             and class is null and country_code is null
-                                             and operator is null
-                                      UNION
-                                       SELECT word_id, word_token
-                                       FROM word, (SELECT unnest(%s::TEXT[]) as term) t
-                                       WHERE word_token = ' ' || make_standard_name(t.term)
-                                             and class is null and country_code is null
-                                             and operator is null
-                                   """,
-                                   (terms, words))
-                    if not exclude:
-                        assert subcur.rowcount >= len(terms) + len(words), \
-                            "No word entry found for " + row[h] + ". Entries found: " + str(subcur.rowcount)
-                    for wid in subcur:
-                        if exclude:
-                            assert wid[0] not in res[h], "Found term for %s/%s: %s" % (pid, h, wid[1])
-                        else:
-                            assert wid[0] in res[h], "Missing term for %s/%s: %s" % (pid, h, wid[1])
-                else:
-                    assert_db_column(res, h, row[h], context)
-    context.db.commit()
-@then("location_postcode contains exactly")
-def check_location_postcode(context):
-    cur = context.db.cursor(cursor_factory=psycopg2.extras.DictCursor)
-    cur.execute("SELECT *, ST_AsText(geometry) as geomtxt FROM location_postcode")
-    assert cur.rowcount == len(list(context.table)), \
-        "Postcode table has %d rows, expected %d rows." % (cur.rowcount, len(list(context.table)))
-    table = list(cur)
-    for row in context.table:
-        for i in range(len(table)):
-            if table[i]['country_code'] != row['country'] \
-                    or table[i]['postcode'] != row['postcode']:
-                continue
-            for h in row.headings:
-                if h not in ('country', 'postcode'):
-                    assert_db_column(table[i], h, row[h], context)
-@then("word contains(?P<exclude> not)?")
-def check_word_table(context, exclude):
-    cur = context.db.cursor(cursor_factory=psycopg2.extras.DictCursor)
-    for row in context.table:
-        wheres = []
-        values = []
-        for h in row.headings:
-            wheres.append("%s = %%s" % h)
-            values.append(row[h])
-        cur.execute("SELECT * from word WHERE %s" % ' AND '.join(wheres), values)
-        if exclude:
-            assert cur.rowcount == 0, "Row still in word table: %s" % '/'.join(values)
-        else:
-            assert cur.rowcount > 0, "Row not in word table: %s" % '/'.join(values)
-@then("place_addressline contains")
-def check_place_addressline(context):
-    cur = context.db.cursor(cursor_factory=psycopg2.extras.DictCursor)
-    for row in context.table:
-        pid = NominatimID(row['object']).get_place_id(cur)
-        apid = NominatimID(row['address']).get_place_id(cur)
-        cur.execute(""" SELECT * FROM place_addressline
-                        WHERE place_id = %s AND address_place_id = %s""",
-                    (pid, apid))
-        assert cur.rowcount > 0, \
-                    "No rows found for place %s and address %s" % (row['object'], row['address'])
-        for res in cur:
-            for h in row.headings:
-                if h not in ('address', 'object'):
-                    assert_db_column(res, h, row[h], context)
-    context.db.commit()
-@then("place_addressline doesn't contain")
-def check_place_addressline_exclude(context):
-    cur = context.db.cursor(cursor_factory=psycopg2.extras.DictCursor)
-    for row in context.table:
-        pid = NominatimID(row['object']).get_place_id(cur)
-        apid = NominatimID(row['address']).get_place_id(cur)
-        cur.execute(""" SELECT * FROM place_addressline
-                        WHERE place_id = %s AND address_place_id = %s""",
-                    (pid, apid))
-        assert cur.rowcount == 0, \
-            "Row found for place %s and address %s" % (row['object'], row['address'])
-    context.db.commit()
-@then("(?P<oid>\w+) expands to(?P<neg> no)? interpolation")
-def check_location_property_osmline(context, oid, neg):
-    cur = context.db.cursor(cursor_factory=psycopg2.extras.DictCursor)
-    nid = NominatimID(oid)
-    assert 'W' == nid.typ, "interpolation must be a way"
-    cur.execute("""SELECT *, ST_AsText(linegeo) as geomtxt
-                   FROM location_property_osmline
-                   WHERE osm_id = %s AND startnumber IS NOT NULL""",
-                (nid.oid, ))
-    if neg:
-        assert cur.rowcount == 0
-        return
-    todo = list(range(len(list(context.table))))
-    for res in cur:
-        for i in todo:
-            row = context.table[i]
-            if (int(row['start']) == res['startnumber']
-                and int(row['end']) == res['endnumber']):
-                todo.remove(i)
-                break
-        else:
-            assert False, "Unexpected row %s" % (str(res))
-        for h in row.headings:
-            if h in ('start', 'end'):
-                continue
-            elif h == 'parent_place_id':
-                compare_place_id(row[h], res[h], h, context)
-            else:
-                assert_db_column(res, h, row[h], context)
-    assert not todo
-@then("(?P<table>placex|place) has no entry for (?P<oid>.*)")
-def check_placex_has_entry(context, table, oid):
-    cur = context.db.cursor(cursor_factory=psycopg2.extras.DictCursor)
-    nid = NominatimID(oid)
-    where, params = nid.table_select()
-    cur.execute("SELECT * FROM %s where %s" % (table, where), params)
-    assert cur.rowcount == 0
-    context.db.commit()
-@then("search_name has no entry for (?P<oid>.*)")
-def check_search_name_has_entry(context, oid):
-    cur = context.db.cursor(cursor_factory=psycopg2.extras.DictCursor)
-    pid = NominatimID(oid).get_place_id(cur)
-    cur.execute("SELECT * FROM search_name WHERE place_id = %s", (pid, ))
-    assert cur.rowcount == 0
-    context.db.commit()
diff --git a/test/bdd/steps/geometry_factory.py b/test/bdd/steps/geometry_factory.py
new file mode 100644 (file)
index 0000000..0a40383
--- /dev/null
@@ -0,0 +1,122 @@
+from pathlib import Path
+import os
+class GeometryFactory:
+    """ Provides functions to create geometries from scenes and data grids.
+        Scenes are read from '<name>.wkt' files below `scene_path` and are
+        cached after the first load. The grid maps integer node ids to
+        (x, y) coordinate tuples.
+    """
+    def __init__(self):
+        defpath = Path(__file__) / '..' / '..' / '..' / 'scenes' / 'data'
+        # The environment variable SCENE_PATH overrides the default location.
+        self.scene_path = os.environ.get('SCENE_PATH', defpath.resolve())
+        self.scene_cache = {}
+        self.grid = {}
+    def parse_geometry(self, geom, scene):
+        """ Create a WKT SQL term for the given geometry.
+            The function understands the following formats:
+              [<scene>]:<name>
+                 Geometry from a scene. If the scene is omitted, use the
+                 default scene.
+              <P>
+                 Point geometry
+              <P>,...,<P>
+                 Line geometry
+              (<P>,...,<P>)
+                 Polygon geometry
+           <P> may either be a coordinate of the form '<x> <y>' or a single
+           number. In the latter case it must refer to a point in
+           a previously defined grid.
+        """
+        if ':' in geom:
+            return "ST_SetSRID({}, 4326)".format(self.get_scene_geometry(scene, geom))
+        if ',' not in geom:
+            out = "POINT({})".format(self.mk_wkt_point(geom))
+        elif '(' not in geom:
+            out = "LINESTRING({})".format(self.mk_wkt_points(geom))
+        else:
+            out = "POLYGON(({}))".format(self.mk_wkt_points(geom.strip('() ')))
+        return "ST_SetSRID('{}'::geometry, 4326)".format(out)
+    def mk_wkt_point(self, point):
+        """ Parse a point description.
+            The point may either consist of 'x y' coordinates or a number
+            that refers to a grid setup. Raises an AssertionError when the
+            description is neither a coordinate pair nor a known grid node.
+        """
+        geom = point.strip()
+        if ' ' in geom:
+            return geom
+        try:
+            nodeid = int(geom)
+        except ValueError:
+            raise AssertionError(
+                "Scenario error: Point '{}' is not a number".format(geom)) from None
+        pt = self.grid_node(nodeid)
+        assert pt is not None, "Scenario error: Point '{}' not found in grid".format(geom)
+        return "{} {}".format(*pt)
+    def mk_wkt_points(self, geom):
+        """ Parse a list of points.
+            The list must be a comma-separated list of points. Points
+            in coordinate and grid format may be mixed.
+        """
+        return ','.join([self.mk_wkt_point(x) for x in geom.split(',')])
+    def get_scene_geometry(self, default_scene, name):
+        """ Load the geometry from a scene.
+            `name` may list multiple objects separated by '+'. A single
+            object is returned as a geometry literal, multiple objects are
+            merged into one line with ST_LineMerge.
+        """
+        geoms = []
+        for part in name.split('+'):
+            oname = part.strip()
+            if oname.startswith(':'):
+                assert default_scene is not None, "Scenario error: You need to set a scene"
+                defscene = self.load_scene(default_scene)
+                wkt = defscene[oname[1:]]
+            else:
+                # Split at the first colon only, so that the unpacking into
+                # exactly two values cannot fail on additional colons.
+                scene, objname = oname.split(':', 1)
+                scene_geoms = self.load_scene(scene)
+                wkt = scene_geoms[objname]
+            geoms.append("'{}'::geometry".format(wkt))
+        if len(geoms) == 1:
+            return geoms[0]
+        return 'ST_LineMerge(ST_Collect(ARRAY[{}]))'.format(','.join(geoms))
+    def load_scene(self, name):
+        """ Load a scene from a file.
+            Each non-empty line of the file has the form '<name> | <wkt>'.
+            Returns a dictionary of object name to WKT string.
+        """
+        if name in self.scene_cache:
+            return self.scene_cache[name]
+        scene = {}
+        with open(Path(self.scene_path) / "{}.wkt".format(name), 'r') as fd:
+            for line in fd:
+                if line.strip():
+                    # Everything after the first '|' belongs to the geometry.
+                    obj, wkt = line.split('|', 1)
+                    scene[obj.strip()] = wkt.strip()
+        self.scene_cache[name] = scene
+        return scene
+    def set_grid(self, lines, grid_step):
+        """ Replace the grid with one from the given lines.
+            `lines` is a sequence of rows, each a sequence of cell strings.
+            Cells consisting of digits only define a grid node with that id,
+            all other cells only advance the x position by `grid_step`.
+        """
+        self.grid = {}
+        y = 0
+        for line in lines:
+            x = 0
+            for pt_id in line:
+                if pt_id.isdigit():
+                    self.grid[int(pt_id)] = (x, y)
+                x += grid_step
+            y += grid_step
+    def grid_node(self, nodeid):
+        """ Get the coordinates for the given grid node or None when
+            the grid has no such node.
+        """
+        return self.grid.get(nodeid)
diff --git a/test/bdd/steps/http_responses.py b/test/bdd/steps/http_responses.py
new file mode 100644 (file)
index 0000000..161e29f
--- /dev/null
@@ -0,0 +1,198 @@
+""" Classes wrapping HTTP responses from the Nominatim API. """
+from collections import OrderedDict
+import re
+import json
+import xml.etree.ElementTree as ET
+from check_functions import Almost
+def _geojson_result_to_json_result(geojson_result):
+    """ Flatten a GeoJSON feature into the layout of a plain JSON result.
+        The 'properties' dictionary of the feature is extended in place
+        with the geometry (as 'geojson') and, if the feature has a bbox,
+        with the reordered 'boundingbox'. Returns that dictionary.
+    """
+    flat = geojson_result['properties']
+    flat['geojson'] = geojson_result['geometry']
+    if 'bbox' not in geojson_result:
+        return flat
+    # bbox order:        minlon, minlat, maxlon, maxlat
+    # boundingbox order: minlat, maxlat, minlon, maxlon
+    bbox = geojson_result['bbox']
+    flat['boundingbox'] = [bbox[pos] for pos in (1, 3, 0, 2)]
+    return flat
+class BadRowValueAssert:
+    """ Assertion message for a result field with an unexpected value.
+        The message text is only put together when the object is
+        converted to a string.
+    """
+    def __init__(self, response, idx, field, value):
+        self.idx = idx
+        self.field = field
+        self.value = value
+        self.row = response.result[idx]
+    def __str__(self):
+        actual = self.row[self.field]
+        full_row = json.dumps(self.row, indent=4)
+        template = "\nBad value for row {} field '{}'. Expected: {}, got: {}.\nFull row: {}"
+        return template.format(self.idx, self.field, self.value, actual, full_row)
+class GenericResponse:
+    """ Common base class for all API responses.
+        The page is parsed according to the given format into `result`,
+        a list of result rows. Pages with an error code other than 200
+        are not parsed and leave `result` empty.
+    """
+    def __init__(self, page, fmt, errorcode=200):
+        fmt = fmt.strip()
+        # jsonv2 output is parsed with the json parser
+        if fmt == 'jsonv2':
+            fmt = 'json'
+        self.page = page
+        self.format = fmt
+        self.errorcode = errorcode
+        self.result = []
+        self.header = dict()
+        if errorcode == 200:
+            # dispatch to the parser of the format, e.g. _parse_json()
+            getattr(self, '_parse_' + fmt)()
+    def _parse_json(self):
+        # A page of the form 'func(...)' is a JSONP response. Remember the
+        # function name in the header and parse only the argument.
+        m = re.fullmatch(r'([\w$][^(]*)\((.*)\)', self.page)
+        if m is None:
+            code = self.page
+        else:
+            code = m.group(2)
+            self.header['json_func'] = m.group(1)
+        self.result = json.JSONDecoder(object_pairs_hook=OrderedDict).decode(code)
+        # a single object counts as a result list with one row
+        if isinstance(self.result, OrderedDict):
+            self.result = [self.result]
+    def _parse_geojson(self):
+        self._parse_json()
+        if 'error' in self.result[0]:
+            self.result = []
+        else:
+            self.result = list(map(_geojson_result_to_json_result, self.result[0]['features']))
+    def _parse_geocodejson(self):
+        self._parse_geojson()
+        if self.result is not None:
+            self.result = [r['geocoding'] for r in self.result]
+    def assert_field(self, idx, field, value):
+        """ Check that result row `idx` has a field `field` with value `value`.
+            Float numbers are matched approximately. When the expected value
+            starts with a caret, regular expression matching is used.
+            All other values are compared by their string representation.
+        """
+        assert field in self.result[idx], \
+               "Result row {} has no field '{}'.\nFull row: {}"\
+                   .format(idx, field, json.dumps(self.result[idx], indent=4))
+        if isinstance(value, float):
+            assert Almost(value) == float(self.result[idx][field]), \
+                   BadRowValueAssert(self, idx, field, value)
+        elif value.startswith("^"):
+            # JSON fields may be numbers, re.fullmatch() only accepts strings.
+            assert re.fullmatch(value, str(self.result[idx][field])), \
+                   BadRowValueAssert(self, idx, field, value)
+        else:
+            assert str(self.result[idx][field]) == str(value), \
+                   BadRowValueAssert(self, idx, field, value)
+    def match_row(self, row):
+        """ Match the result fields against the given behave table row.
+            If the row has an 'ID' column, only the result with that index
+            is checked, otherwise all results must match. The columns 'osm'
+            and 'centroid' are split up into the fields osm_type/osm_id
+            and lat/lon respectively.
+        """
+        if 'ID' in row.headings:
+            todo = [int(row['ID'])]
+        else:
+            todo = range(len(self.result))
+        for i in todo:
+            for name, value in zip(row.headings, row.cells):
+                if name == 'ID':
+                    pass
+                elif name == 'osm':
+                    self.assert_field(i, 'osm_type', value[0])
+                    self.assert_field(i, 'osm_id', value[1:])
+                elif name == 'centroid':
+                    # centroids are given as '<lon> <lat>'
+                    lon, lat = value.split(' ')
+                    self.assert_field(i, 'lat', float(lat))
+                    self.assert_field(i, 'lon', float(lon))
+                else:
+                    self.assert_field(i, name, value)
+    def property_list(self, prop):
+        """ Return the value of field `prop` for every result row.
+        """
+        return [x[prop] for x in self.result]
+class SearchResponse(GenericResponse):
+    """ Response of a search or lookup request.
+        XML output is converted into rows that resemble the json output.
+    """
+    def _parse_xml(self):
+        root = ET.fromstring(self.page)
+        self.header = dict(root.attrib)
+        for place in root:
+            assert place.tag == "place"
+            entry = dict(place.attrib)
+            self.result.append(entry)
+            # sub-elements that are not handled specially are address parts
+            address = {}
+            for sub in place:
+                kind = sub.tag
+                if kind == 'extratags':
+                    extratags = entry['extratags'] = {}
+                    for tag in sub:
+                        extratags[tag.attrib['key']] = tag.attrib['value']
+                elif kind == 'namedetails':
+                    namedetails = entry['namedetails'] = {}
+                    for tag in sub:
+                        namedetails[tag.attrib['desc']] = tag.text
+                elif kind == 'geokml':
+                    # only record that the geometry is there
+                    entry[kind] = True
+                else:
+                    address[kind] = sub.text
+            if address:
+                entry['address'] = address
+class ReverseResponse(GenericResponse):
+    """ Response of a reverse request.
+        XML output is converted into a single row that resembles the
+        json output.
+    """
+    def _parse_xml(self):
+        root = ET.fromstring(self.page)
+        self.header = dict(root.attrib)
+        self.result = []
+        for child in root:
+            kind = child.tag
+            if kind == 'result':
+                assert not self.result, "More than one result in reverse result"
+                self.result.append(dict(child.attrib))
+            elif kind == 'addressparts':
+                address = {}
+                for sub in child:
+                    address[sub.tag] = sub.text
+                self.result[0]['address'] = address
+            elif kind == 'extratags':
+                extratags = {}
+                self.result[0]['extratags'] = extratags
+                for tag in child:
+                    extratags[tag.attrib['key']] = tag.attrib['value']
+            elif kind == 'namedetails':
+                namedetails = {}
+                self.result[0]['namedetails'] = namedetails
+                for tag in child:
+                    namedetails[tag.attrib['desc']] = tag.text
+            elif kind == 'geokml':
+                # only record that the geometry is there
+                self.result[0][kind] = True
+            else:
+                assert kind == 'error', \
+                       "Unknown XML tag {} on page: {}".format(child.tag, self.page)
+class StatusResponse(GenericResponse):
+    """ Response of a status request.
+        Accepts the text format in addition to the formats of the base class.
+    """
+    def _parse_text(self):
+        """ Text pages are not parsed, the result list stays as it is.
+        """
diff --git a/test/bdd/steps/nominatim_environment.py b/test/bdd/steps/nominatim_environment.py
new file mode 100644 (file)
index 0000000..7013a20
--- /dev/null
@@ -0,0 +1,254 @@
+import os
+from pathlib import Path
+import tempfile
+import psycopg2
+import psycopg2.extras
+from steps.utils import run_script
+class NominatimEnvironment:
+    """ Collects all functions for the execution of Nominatim functions.
+        The environment is set up from a dictionary of configuration
+        settings. It creates and drops the databases used by the tests
+        and runs the Nominatim PHP utility scripts from the build directory.
+    """
+    def __init__(self, config):
+        self.build_dir = Path(config['BUILDDIR']).resolve()
+        self.src_dir = (Path(__file__) / '..' / '..' / '..' / '..').resolve()
+        self.db_host = config['DB_HOST']
+        self.db_port = config['DB_PORT']
+        self.db_user = config['DB_USER']
+        self.db_pass = config['DB_PASS']
+        self.template_db = config['TEMPLATE_DB']
+        self.test_db = config['TEST_DB']
+        self.api_test_db = config['API_TEST_DB']
+        self.server_module_path = config['SERVER_MODULE_PATH']
+        self.reuse_template = not config['REMOVE_TEMPLATE']
+        self.keep_scenario_db = config['KEEP_TEST_DB']
+        self.code_coverage_path = config['PHPCOV']
+        self.code_coverage_id = 1
+        # environment for the scripts, set by write_nominatim_config()
+        self.test_env = None
+        self.template_db_done = False
+        # temporary project directory, set by write_nominatim_config()
+        self.website_dir = None
+    def connect_database(self, dbname):
+        """ Return a connection to the database with the given name.
+            Uses configured host, user and port.
+        """
+        dbargs = {'database': dbname}
+        if self.db_host:
+            dbargs['host'] = self.db_host
+        if self.db_port:
+            dbargs['port'] = self.db_port
+        if self.db_user:
+            dbargs['user'] = self.db_user
+        if self.db_pass:
+            dbargs['password'] = self.db_pass
+        conn = psycopg2.connect(**dbargs)
+        return conn
+    def next_code_coverage_file(self):
+        """ Generate the next name for a coverage file.
+            Files are numbered consecutively below the coverage path.
+        """
+        fn = Path(self.code_coverage_path) / "{:06d}.cov".format(self.code_coverage_id)
+        self.code_coverage_id += 1
+        return fn.resolve()
+    def write_nominatim_config(self, dbname):
+        """ Set up a custom test configuration that connects to the given
+            database. This sets up the environment variables so that they can
+            be picked up by dotenv and creates a project directory with the
+            appropriate website scripts.
+        """
+        dsn = 'pgsql:dbname={}'.format(dbname)
+        if self.db_host:
+            dsn += ';host=' + self.db_host
+        if self.db_port:
+            dsn += ';port=' + self.db_port
+        if self.db_user:
+            dsn += ';user=' + self.db_user
+        if self.db_pass:
+            dsn += ';password=' + self.db_pass
+        if self.website_dir is not None \
+           and self.test_env is not None \
+           and dsn == self.test_env['NOMINATIM_DATABASE_DSN']:
+            return # environment already set up
+        # NOTE: this is os.environ itself, not a copy. The settings below
+        # therefore also change the environment of the current process.
+        self.test_env = os.environ
+        self.test_env['NOMINATIM_DATABASE_DSN'] = dsn
+        self.test_env['NOMINATIM_FLATNODE_FILE'] = ''
+        self.test_env['NOMINATIM_IMPORT_STYLE'] = 'full'
+        self.test_env['NOMINATIM_USE_US_TIGER_DATA'] = 'yes'
+        if self.server_module_path:
+            self.test_env['NOMINATIM_DATABASE_MODULE_PATH'] = self.server_module_path
+        # always start with a fresh project directory
+        if self.website_dir is not None:
+            self.website_dir.cleanup()
+        self.website_dir = tempfile.TemporaryDirectory()
+        self.run_setup_script('setup-website')
+    def db_drop_database(self, name):
+        """ Drop the database with the given name.
+        """
+        conn = self.connect_database('postgres')
+        try:
+            # DROP DATABASE cannot run inside a transaction block
+            conn.set_isolation_level(0)
+            cur = conn.cursor()
+            cur.execute('DROP DATABASE IF EXISTS {}'.format(name))
+        finally:
+            conn.close()
+    def setup_template_db(self):
+        """ Setup a template database that already contains common test data.
+            Having a template database speeds up tests considerably but at
+            the price that the tests sometimes run with stale data.
+        """
+        if self.template_db_done:
+            return
+        self.template_db_done = True
+        if self.reuse_template:
+            # check that the template is there
+            conn = self.connect_database('postgres')
+            try:
+                cur = conn.cursor()
+                cur.execute('select count(*) from pg_database where datname = %s',
+                            (self.template_db,))
+                if cur.fetchone()[0] == 1:
+                    return
+            finally:
+                # also closes the connection on the early return above
+                conn.close()
+        else:
+            # just in case... make sure a previous table has been dropped
+            self.db_drop_database(self.template_db)
+        try:
+            # call the first part of database setup
+            self.write_nominatim_config(self.template_db)
+            self.run_setup_script('create-db', 'setup-db')
+            # remove external data to speed up indexing for tests
+            conn = self.connect_database(self.template_db)
+            try:
+                cur = conn.cursor()
+                cur.execute("""select tablename from pg_tables
+                               where tablename in ('gb_postcode', 'us_postcode')""")
+                for t in cur:
+                    conn.cursor().execute('TRUNCATE TABLE {}'.format(t[0]))
+                conn.commit()
+            finally:
+                # an open connection would block dropping the template below
+                conn.close()
+            # execute osm2pgsql import on an empty file to get the right tables
+            with tempfile.NamedTemporaryFile(dir='/tmp', suffix='.xml') as fd:
+                fd.write(b'<osm version="0.6"></osm>')
+                fd.flush()
+                self.run_setup_script('import-data',
+                                      'ignore-errors',
+                                      'create-functions',
+                                      'create-tables',
+                                      'create-partition-tables',
+                                      'create-partition-functions',
+                                      'load-data',
+                                      'create-search-indices',
+                                      osm_file=fd.name,
+                                      osm2pgsql_cache='200')
+        except:
+            # do not leave a half-finished template behind, then re-raise
+            self.db_drop_database(self.template_db)
+            raise
+    def setup_api_db(self):
+        """ Setup a test against the API test database.
+        """
+        self.write_nominatim_config(self.api_test_db)
+    def setup_unknown_db(self):
+        """ Setup a test against a non-existing database.
+        """
+        self.write_nominatim_config('UNKNOWN_DATABASE_NAME')
+    def setup_db(self, context):
+        """ Setup a test against a fresh, empty test database.
+            The connection to the new database is saved in `context.db`.
+        """
+        self.setup_template_db()
+        self.write_nominatim_config(self.test_db)
+        conn = self.connect_database(self.template_db)
+        try:
+            conn.set_isolation_level(0)
+            cur = conn.cursor()
+            cur.execute('DROP DATABASE IF EXISTS {}'.format(self.test_db))
+            cur.execute('CREATE DATABASE {} TEMPLATE = {}'.format(self.test_db, self.template_db))
+        finally:
+            conn.close()
+        context.db = self.connect_database(self.test_db)
+        context.db.autocommit = True
+        psycopg2.extras.register_hstore(context.db, globally=False)
+    def teardown_db(self, context):
+        """ Remove the test database, if it exists.
+            The database is kept when KEEP_TEST_DB was configured.
+        """
+        if 'db' in context:
+            context.db.close()
+        if not self.keep_scenario_db:
+            self.db_drop_database(self.test_db)
+    def reindex_placex(self, db):
+        """ Run the indexing step until all data in the placex has
+            been processed. Indexing during updates can produce more data
+            to index under some circumstances. That is why indexing may have
+            to be run multiple times.
+        """
+        with db.cursor() as cur:
+            while True:
+                self.run_update_script('index')
+                cur.execute("SELECT 'a' FROM placex WHERE indexed_status != 0 LIMIT 1")
+                if cur.rowcount == 0:
+                    return
+    def run_setup_script(self, *args, **kwargs):
+        """ Run the Nominatim setup script with the given arguments.
+        """
+        self.run_nominatim_script('setup', *args, **kwargs)
+    def run_update_script(self, *args, **kwargs):
+        """ Run the Nominatim update script with the given arguments.
+        """
+        self.run_nominatim_script('update', *args, **kwargs)
+    def run_nominatim_script(self, script, *args, **kwargs):
+        """ Run one of the Nominatim utility scripts with the given arguments.
+            Positional arguments are handed in as flags ('--<arg>'), keyword
+            arguments as options with a value ('--<key> <value>') where
+            underscores in the key are replaced with dashes.
+        """
+        cmd = ['/usr/bin/env', 'php', '-Cq']
+        cmd.append((Path(self.build_dir) / 'utils' / '{}.php'.format(script)).resolve())
+        cmd.extend(['--' + x for x in args])
+        for k, v in kwargs.items():
+            cmd.extend(('--' + k.replace('_', '-'), str(v)))
+        # run in the project directory once it has been created
+        if self.website_dir is not None:
+            cwd = self.website_dir.name
+        else:
+            cwd = self.build_dir
+        run_script(cmd, cwd=cwd, env=self.test_env)
+    def copy_from_place(self, db):
+        """ Copy data from place to the placex and location_property_osmline
+            tables invoking the appropriate triggers.
+        """
+        self.run_setup_script('create-functions', 'create-partition-functions')
+        with db.cursor() as cur:
+            cur.execute("""INSERT INTO placex (osm_type, osm_id, class, type,
+                                               name, admin_level, address,
+                                               extratags, geometry)
+                             SELECT osm_type, osm_id, class, type,
+                                    name, admin_level, address,
+                                    extratags, geometry
+                               FROM place
+                               WHERE not (class='place' and type='houses' and osm_type='W')""")
+            cur.execute("""INSERT INTO location_property_osmline (osm_id, address, linegeo)
+                             SELECT osm_id, address, geometry
+                               FROM place
+                              WHERE class='place' and type='houses'
+                                    and osm_type='W'
+                                    and ST_GeometryType(geometry) = 'ST_LineString'""")
diff --git a/test/bdd/steps/place_inserter.py b/test/bdd/steps/place_inserter.py
new file mode 100644 (file)
index 0000000..90f071b
--- /dev/null
@@ -0,0 +1,105 @@
+""" Helper classes for filling the place table. """
+import random
+import string
+class PlaceColumn:
+    """ Accumulates the cells of a behave table row and writes them
+        as one row into the place table.
+    """
+    def __init__(self, context):
+        self.columns = {'admin_level' : 15}
+        self.context = context
+        self.geometry = None
+    def add_row(self, row, force_name):
+        """ Read all cells of the given behave row into the column data.
+            With `force_name` set, a random name is made up for rows that
+            come without a name. Returns the object itself.
+        """
+        for heading, cell in zip(row.headings, row.cells):
+            self._add(heading, cell)
+        assert 'osm_type' in self.columns, "osm column missing"
+        if force_name and 'name' not in self.columns:
+            length = int(random.random()*30)
+            random_name = ''.join(random.choice(string.printable) for _ in range(length))
+            self._add_hstore('name', 'name', random_name)
+        return self
+    def _add(self, key, value):
+        # columns with a dedicated _set_key_<name>() function come first
+        setter = getattr(self, '_set_key_' + key, None)
+        if setter is not None:
+            setter(value)
+            return
+        # '<prefix>+<key>' columns set a single key of a hstore column
+        for prefix, column in (('name+', 'name'), ('extra+', 'extratags'), ('addr+', 'address')):
+            if key.startswith(prefix):
+                self._add_hstore(column, key[len(prefix):], value)
+                return
+        if key in ('name', 'address', 'extratags'):
+            # the cell contains the inner part of a Python dictionary
+            self.columns[key] = eval('{' + value + '}')
+            return
+        assert key in ('class', 'type'), "Unknown column '{}'.".format(key)
+        self.columns[key] = None if value == '' else value
+    def _add_optional_address(self, key, value):
+        # empty cells do not create an address entry
+        if value:
+            self._add_hstore('address', key, value)
+    def _set_key_name(self, value):
+        self._add_hstore('name', 'name', value)
+    def _set_key_osm(self, value):
+        osm_type, osm_id = value[0], value[1:]
+        assert osm_type in 'NRW' and osm_id.isdigit(), \
+               "OSM id needs to be of format <NRW><id>."
+        self.columns['osm_type'] = osm_type
+        self.columns['osm_id'] = int(osm_id)
+    def _set_key_admin(self, value):
+        self.columns['admin_level'] = int(value)
+    def _set_key_housenr(self, value):
+        self._add_optional_address('housenumber', value)
+    def _set_key_postcode(self, value):
+        self._add_optional_address('postcode', value)
+    def _set_key_street(self, value):
+        self._add_optional_address('street', value)
+    def _set_key_addr_place(self, value):
+        self._add_optional_address('place', value)
+    def _set_key_country(self, value):
+        self._add_optional_address('country', value)
+    def _set_key_geometry(self, value):
+        self.geometry = self.context.osm.parse_geometry(value, self.context.scene)
+        assert self.geometry is not None, "Bad geometry: {}".format(value)
+    def _add_hstore(self, column, key, value):
+        if column not in self.columns:
+            self.columns[column] = {key: value}
+        else:
+            self.columns[column][key] = value
+    def db_insert(self, cursor):
+        """ Write the collected data into the place table.
+            Nodes without a geometry get the position of the grid node
+            with their id or, without such a grid node, a random position.
+        """
+        if self.columns['osm_type'] == 'N' and self.geometry is None:
+            point = self.context.osm.grid_node(self.columns['osm_id'])
+            if point is None:
+                point = (random.random()*360 - 180, random.random()*180 - 90)
+            self.geometry = "ST_SetSRID(ST_Point(%f, %f), 4326)" % point
+        else:
+            assert self.geometry is not None, "Geometry missing"
+        names = ','.join(self.columns)
+        placeholders = ','.join(['%s'] * len(self.columns))
+        query = 'INSERT INTO place ({}, geometry) values({}, {})'.format(
+            names, placeholders, self.geometry)
+        cursor.execute(query, list(self.columns.values()))
similarity index 56%
rename from test/bdd/steps/queries.py
rename to test/bdd/steps/steps_api_queries.py
index 0ea4685b3a45d9fa42cd2a26b9cabf472e285133..47dc8ac3fa2889cc2f3529b55c09858272afc215 100644 (file)
@@ -1,22 +1,18 @@
-""" Steps that run search queries.
+""" Steps that run queries against the API.
     Queries may either be run directly via PHP using the query script
-    or via the HTTP interface.
+    or via the HTTP interface using php-cgi.
 import json
 import os
-import io
 import re
 import logging
-import xml.etree.ElementTree as ET
-import subprocess
 from urllib.parse import urlencode
-from collections import OrderedDict
-from check_functions import Almost
+from utils import run_script
+from http_responses import GenericResponse, SearchResponse, ReverseResponse, StatusResponse
-logger = logging.getLogger(__name__)
+LOG = logging.getLogger(__name__)
     'HTTP_HOST' : 'localhost',
@@ -56,208 +52,6 @@ def compare(operator, op1, op2):
         raise Exception("unknown operator '%s'" % operator)
-class GenericResponse(object):
-    def match_row(self, row):
-        if 'ID' in row.headings:
-            todo = [int(row['ID'])]
-        else:
-            todo = range(len(self.result))
-        for i in todo:
-            res = self.result[i]
-            for h in row.headings:
-                if h == 'ID':
-                    pass
-                elif h == 'osm':
-                    assert res['osm_type'] == row[h][0]
-                    assert res['osm_id'] == int(row[h][1:])
-                elif h == 'centroid':
-                    x, y = row[h].split(' ')
-                    assert Almost(float(y)) == float(res['lat'])
-                    assert Almost(float(x)) == float(res['lon'])
-                elif row[h].startswith("^"):
-                    assert h in res
-                    assert re.fullmatch(row[h], res[h]) is not None, \
-                           "attribute '%s': expected: '%s', got '%s'" % (h, row[h], res[h])
-                else:
-                    assert h in res
-                    assert str(res[h]) == str(row[h])
-    def property_list(self, prop):
-        return [ x[prop] for x in self.result ]
-class SearchResponse(GenericResponse):
-    def __init__(self, page, fmt='json', errorcode=200):
-        self.page = page
-        self.format = fmt
-        self.errorcode = errorcode
-        self.result = []
-        self.header = dict()
-        if errorcode == 200:
-            getattr(self, 'parse_' + fmt)()
-    def parse_json(self):
-        m = re.fullmatch(r'([\w$][^(]*)\((.*)\)', self.page)
-        if m is None:
-            code = self.page
-        else:
-            code = m.group(2)
-            self.header['json_func'] = m.group(1)
-        self.result = json.JSONDecoder(object_pairs_hook=OrderedDict).decode(code)
-    def parse_geojson(self):
-        self.parse_json()
-        self.result = geojson_results_to_json_results(self.result)
-    def parse_geocodejson(self):
-        self.parse_geojson()
-        if self.result is not None:
-            self.result = [r['geocoding'] for r in self.result]
-    def parse_xml(self):
-        et = ET.fromstring(self.page)
-        self.header = dict(et.attrib)
-        for child in et:
-            assert child.tag == "place"
-            self.result.append(dict(child.attrib))
-            address = {}
-            for sub in child:
-                if sub.tag == 'extratags':
-                    self.result[-1]['extratags'] = {}
-                    for tag in sub:
-                        self.result[-1]['extratags'][tag.attrib['key']] = tag.attrib['value']
-                elif sub.tag == 'namedetails':
-                    self.result[-1]['namedetails'] = {}
-                    for tag in sub:
-                        self.result[-1]['namedetails'][tag.attrib['desc']] = tag.text
-                elif sub.tag in ('geokml'):
-                    self.result[-1][sub.tag] = True
-                else:
-                    address[sub.tag] = sub.text
-            if len(address) > 0:
-                self.result[-1]['address'] = address
-class ReverseResponse(GenericResponse):
-    def __init__(self, page, fmt='json', errorcode=200):
-        self.page = page
-        self.format = fmt
-        self.errorcode = errorcode
-        self.result = []
-        self.header = dict()
-        if errorcode == 200:
-            getattr(self, 'parse_' + fmt)()
-    def parse_json(self):
-        m = re.fullmatch(r'([\w$][^(]*)\((.*)\)', self.page)
-        if m is None:
-            code = self.page
-        else:
-            code = m.group(2)
-            self.header['json_func'] = m.group(1)
-        self.result = [json.JSONDecoder(object_pairs_hook=OrderedDict).decode(code)]
-    def parse_geojson(self):
-        self.parse_json()
-        if 'error' in self.result:
-            return
-        self.result = geojson_results_to_json_results(self.result[0])
-    def parse_geocodejson(self):
-        self.parse_geojson()
-        if self.result is not None:
-            self.result = [r['geocoding'] for r in self.result]
-    def parse_xml(self):
-        et = ET.fromstring(self.page)
-        self.header = dict(et.attrib)
-        self.result = []
-        for child in et:
-            if child.tag == 'result':
-                assert len(self.result) == 0, "More than one result in reverse result"
-                self.result.append(dict(child.attrib))
-            elif child.tag == 'addressparts':
-                address = {}
-                for sub in child:
-                    address[sub.tag] = sub.text
-                self.result[0]['address'] = address
-            elif child.tag == 'extratags':
-                self.result[0]['extratags'] = {}
-                for tag in child:
-                    self.result[0]['extratags'][tag.attrib['key']] = tag.attrib['value']
-            elif child.tag == 'namedetails':
-                self.result[0]['namedetails'] = {}
-                for tag in child:
-                    self.result[0]['namedetails'][tag.attrib['desc']] = tag.text
-            elif child.tag in ('geokml'):
-                self.result[0][child.tag] = True
-            else:
-                assert child.tag == 'error', \
-                        "Unknown XML tag %s on page: %s" % (child.tag, self.page)
-class DetailsResponse(GenericResponse):
-    def __init__(self, page, fmt='json', errorcode=200):
-        self.page = page
-        self.format = fmt
-        self.errorcode = errorcode
-        self.result = []
-        self.header = dict()
-        if errorcode == 200:
-            getattr(self, 'parse_' + fmt)()
-    def parse_json(self):
-        self.result = [json.JSONDecoder(object_pairs_hook=OrderedDict).decode(self.page)]
-class StatusResponse(GenericResponse):
-    def __init__(self, page, fmt='text', errorcode=200):
-        self.page = page
-        self.format = fmt
-        self.errorcode = errorcode
-        if errorcode == 200 and fmt != 'text':
-            getattr(self, 'parse_' + fmt)()
-    def parse_json(self):
-        self.result = [json.JSONDecoder(object_pairs_hook=OrderedDict).decode(self.page)]
-def geojson_result_to_json_result(geojson_result):
-    result = geojson_result['properties']
-    result['geojson'] = geojson_result['geometry']
-    if 'bbox' in geojson_result:
-        # bbox is  minlon, minlat, maxlon, maxlat
-        # boundingbox is minlat, maxlat, minlon, maxlon
-        result['boundingbox'] = [
-                                    geojson_result['bbox'][1],
-                                    geojson_result['bbox'][3],
-                                    geojson_result['bbox'][0],
-                                    geojson_result['bbox'][2]
-                                ]
-    return result
-def geojson_results_to_json_results(geojson_results):
-    if 'error' in geojson_results:
-        return
-    return list(map(geojson_result_to_json_result, geojson_results['features']))
 @when(u'searching for "(?P<query>.*)"(?P<dups> with dups)?')
 def query_cmd(context, query, dups):
@@ -277,14 +71,9 @@ def query_cmd(context, query, dups):
     if dups:
         cmd.extend(('--dedupe', '0'))
-    proc = subprocess.Popen(cmd, cwd=context.nominatim.build_dir,
-                            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
-    (outp, err) = proc.communicate()
+    outp, err = run_script(cmd, cwd=context.nominatim.build_dir)
-    assert proc.returncode == 0, "query.php failed with message: %s\noutput: %s" % (err, outp)
-    logger.debug("run_nominatim_script: %s\n%s\n" % (cmd, outp.decode('utf-8').replace('\\n', '\n')))
-    context.response = SearchResponse(outp.decode('utf-8'), 'json')
+    context.response = SearchResponse(outp, 'json')
 def send_api_query(endpoint, params, fmt, context):
     if fmt is not None:
@@ -306,7 +95,7 @@ def send_api_query(endpoint, params, fmt, context):
     env['SCRIPT_FILENAME'] = os.path.join(env['CONTEXT_DOCUMENT_ROOT'],
                                           '%s.php' % endpoint)
-    logger.debug("Environment:" + json.dumps(env, sort_keys=True, indent=2))
+    LOG.debug("Environment:" + json.dumps(env, sort_keys=True, indent=2))
     if hasattr(context, 'http_headers'):
@@ -326,19 +115,7 @@ def send_api_query(endpoint, params, fmt, context):
     for k,v in params.items():
         cmd.append("%s=%s" % (k, v))
-    proc = subprocess.Popen(cmd, cwd=context.nominatim.website_dir.name, env=env,
-                            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
-    (outp, err) = proc.communicate()
-    outp = outp.decode('utf-8')
-    err = err.decode("utf-8")
-    logger.debug("Result: \n===============================\n"
-                 + outp + "\n===============================\n")
-    assert proc.returncode == 0, \
-                  "%s failed with message: %s" % (
-                      os.path.basename(env['SCRIPT_FILENAME']), err)
+    outp, err = run_script(cmd, cwd=context.nominatim.website_dir.name, env=env)
     assert len(err) == 0, "Unexpected PHP error: %s" % (err)
@@ -371,12 +148,7 @@ def website_search_request(context, fmt, query, addr):
     outp, status = send_api_query('search', params, fmt, context)
-    if fmt is None or fmt == 'jsonv2 ':
-        outfmt = 'json'
-    else:
-        outfmt = fmt.strip()
-    context.response = SearchResponse(outp, outfmt, status)
+    context.response = SearchResponse(outp, fmt or 'json', status)
 @when(u'sending (?P<fmt>\S+ )?reverse coordinates (?P<lat>.+)?,(?P<lon>.+)?')
 def website_reverse_request(context, fmt, lat, lon):
@@ -388,14 +160,7 @@ def website_reverse_request(context, fmt, lat, lon):
     outp, status = send_api_query('reverse', params, fmt, context)
-    if fmt is None:
-        outfmt = 'xml'
-    elif fmt == 'jsonv2 ':
-        outfmt = 'json'
-    else:
-        outfmt = fmt.strip()
-    context.response = ReverseResponse(outp, outfmt, status)
+    context.response = ReverseResponse(outp, fmt or 'xml', status)
 @when(u'sending (?P<fmt>\S+ )?details query for (?P<query>.*)')
 def website_details_request(context, fmt, query):
@@ -407,42 +172,21 @@ def website_details_request(context, fmt, query):
         params['place_id'] = query
     outp, status = send_api_query('details', params, fmt, context)
-    if fmt is None:
-        outfmt = 'json'
-    else:
-        outfmt = fmt.strip()
-    context.response = DetailsResponse(outp, outfmt, status)
+    context.response = GenericResponse(outp, fmt or 'json', status)
 @when(u'sending (?P<fmt>\S+ )?lookup query for (?P<query>.*)')
 def website_lookup_request(context, fmt, query):
     params = { 'osm_ids' : query }
     outp, status = send_api_query('lookup', params, fmt, context)
-    if fmt == 'json ':
-        outfmt = 'json'
-    elif fmt == 'jsonv2 ':
-        outfmt = 'json'
-    elif fmt == 'geojson ':
-        outfmt = 'geojson'
-    elif fmt == 'geocodejson ':
-        outfmt = 'geocodejson'
-    else:
-        outfmt = 'xml'
-    context.response = SearchResponse(outp, outfmt, status)
+    context.response = SearchResponse(outp, fmt or 'xml', status)
 @when(u'sending (?P<fmt>\S+ )?status query')
 def website_status_request(context, fmt):
     params = {}
     outp, status = send_api_query('status', params, fmt, context)
-    if fmt is None:
-        outfmt = 'text'
-    else:
-        outfmt = fmt.strip()
-    context.response = StatusResponse(outp, outfmt, status)
+    context.response = StatusResponse(outp, fmt or 'text', status)
 @step(u'(?P<operator>less than|more than|exactly|at least|at most) (?P<number>\d+) results? (?:is|are) returned')
 def validate_result_number(context, operator, number):
diff --git a/test/bdd/steps/steps_db_ops.py b/test/bdd/steps/steps_db_ops.py
new file mode 100644 (file)
index 0000000..c549f3e
--- /dev/null
@@ -0,0 +1,336 @@
+from itertools import chain
+import psycopg2.extras
+from place_inserter import PlaceColumn
+from table_compare import NominatimID, DBRow
+def check_database_integrity(context):
+    """ Check some generic constraints on the tables.
+        Raises an AssertionError when place_addressline contains more than
+        one row for the same (place_id, address_place_id) pair.
+    """
+    # place_addressline should not have duplicate (place_id, address_place_id)
+    # The cursor is used as a context manager, like in all other steps,
+    # so that it is closed again even when the check fails.
+    with context.db.cursor() as cur:
+        cur.execute("""SELECT count(*) FROM
+                        (SELECT place_id, address_place_id, count(*) as c
+                         FROM place_addressline GROUP BY place_id, address_place_id) x
+                       WHERE c > 1""")
+        assert cur.fetchone()[0] == 0, "Duplicates found in place_addressline"
+################################ GIVEN ##################################
+@given("the (?P<named>named )?places")
+def add_data_to_place_table(context, named):
+    """ Insert the rows of the step table into the place table. With
+        'named places' entries get a random name when none is explicitly
+        given. The place_before_insert trigger is switched off while the
+        rows are inserted.
+    """
+    force_name = named is not None
+    with context.db.cursor() as cursor:
+        cursor.execute('ALTER TABLE place DISABLE TRIGGER place_before_insert')
+        for table_row in context.table:
+            place = PlaceColumn(context).add_row(table_row, force_name)
+            place.db_insert(cursor)
+        cursor.execute('ALTER TABLE place ENABLE TRIGGER place_before_insert')
+@given("the relations")
+def add_data_to_planet_relations(context):
+    """ Add entries into the osm2pgsql relation middle table. This is needed
+        for tests on data that looks up members.
+        Each table row needs an 'id' and a 'members' column; the latter is a
+        comma-separated list of member identifiers and may be empty. Columns
+        named 'tags+<key>' are collected into the tag list.
+    """
+    with context.db.cursor() as cur:
+        for r in context.table:
+            # 'parts' is kept grouped by member type: node ids first, then
+            # way ids, then everything else. last_node and last_way are the
+            # insert positions at the end of the node and the way group and
+            # are written to the way_off and rel_off columns below.
+            last_node = 0
+            last_way = 0
+            parts = []
+            if r['members']:
+                members = []
+                for m in r['members'].split(','):
+                    mid = NominatimID(m)
+                    if mid.typ == 'N':
+                        # A new node also shifts the end of the way group.
+                        parts.insert(last_node, int(mid.oid))
+                        last_node += 1
+                        last_way += 1
+                    elif mid.typ == 'W':
+                        parts.insert(last_way, int(mid.oid))
+                        last_way += 1
+                    else:
+                        parts.append(int(mid.oid))
+                    # Flat list alternating between the lower-case member
+                    # reference and the optional ':<suffix>' part of the
+                    # identifier (presumably the member role - confirm).
+                    members.extend((mid.typ.lower() + mid.oid, mid.cls or ''))
+            else:
+                members = None
+            # Tags are flattened into an alternating key/value sequence.
+            tags = chain.from_iterable([(h[5:], r[h]) for h in r.headings if h.startswith("tags+")])
+            cur.execute("""INSERT INTO planet_osm_rels (id, way_off, rel_off, parts, members, tags)
+                           VALUES (%s, %s, %s, %s, %s, %s)""",
+                        (r['id'], last_node, last_way, parts, members, list(tags)))
+@given("the ways")
+def add_data_to_planet_ways(context):
+    """ Add entries into the osm2pgsql way middle table. This is necessary for
+        tests on data that looks up node ids in this table.
+        The 'nodes' column holds a comma-separated list of node ids, columns
+        named 'tags+<key>' are collected into the tag list.
+    """
+    with context.db.cursor() as cursor:
+        for way in context.table:
+            # Flat list alternating between tag key and tag value.
+            tag_list = []
+            for heading in way.headings:
+                if heading.startswith("tags+"):
+                    tag_list.extend((heading[5:], way[heading]))
+            node_ids = [int(nid.strip()) for nid in way['nodes'].split(',')]
+            cursor.execute("INSERT INTO planet_osm_ways (id, nodes, tags) VALUES (%s, %s, %s)",
+                           (way['id'], node_ids, tag_list))
+################################ WHEN ##################################
+def import_and_index_data_from_place_table(context):
+    """ Import the data previously set up in the place table, run the
+        postcode calculation and indexing and finally check the
+        integrity of the resulting database.
+    """
+    nominatim = context.nominatim
+    nominatim.copy_from_place(context.db)
+    nominatim.run_setup_script('calculate-postcodes', 'index', 'index-noanalyse')
+    check_database_integrity(context)
+@when("updating places")
+def update_place_table(context):
+    """ Update the place table with the rows of the step table. Also runs
+        all triggers related to updates, reindexes the new data and
+        checks the integrity of the database afterwards.
+    """
+    nominatim = context.nominatim
+    nominatim.run_setup_script(
+        'create-functions', 'create-partition-functions', 'enable-diff-updates')
+    with context.db.cursor() as cursor:
+        for table_row in context.table:
+            place = PlaceColumn(context).add_row(table_row, False)
+            place.db_insert(cursor)
+    nominatim.reindex_placex(context.db)
+    check_database_integrity(context)
+@when("updating postcodes")
+def update_postcodes(context):
+    """ Run the postcode calculation again via the update script.
+    """
+    nominatim = context.nominatim
+    nominatim.run_update_script('calculate-postcodes')
+@when("marking for delete (?P<oids>.*)")
+def delete_places(context, oids):
+    """ Delete the given objects from the place table. Multiple ids may
+        be given separated by commas. Also runs all triggers related to
+        updates and reindexes the new data.
+    """
+    nominatim = context.nominatim
+    nominatim.run_setup_script(
+        'create-functions', 'create-partition-functions', 'enable-diff-updates')
+    with context.db.cursor() as cursor:
+        for osm_ident in oids.split(','):
+            NominatimID(osm_ident).query_osm_id(cursor, 'DELETE FROM place WHERE {}')
+    nominatim.reindex_placex(context.db)
+################################ THEN ##################################
+@then("(?P<table>placex|place) contains(?P<exact> exactly)?")
+def check_place_contents(context, table, exact):
+    """ Check contents of place/placex tables. Each row represents a table row
+        and all data must match. Data not present in the expected table may
+        be arbitrary. The rows are identified via the 'object' column which
+        must have an identifier of the form '<NRW><osm id>[:<class>]'. When
+        multiple rows match (for example because 'class' was left out and
+        there are multiple entries for the given OSM object) then all must
+        match. Every expected row must be matched by at least one database
+        row. When 'exactly' is given, there must not be additional rows in
+        the database.
+    """
+    # The query only depends on the table, so build it once up front.
+    columns = 'SELECT *, ST_AsText(geometry) as geomtxt, ST_GeometryType(geometry) as geometrytype'
+    if table == 'placex':
+        columns += ' ,ST_X(centroid) as cx, ST_Y(centroid) as cy'
+    query = columns + " FROM %s WHERE {}" % (table, )
+    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+        seen = set()
+        for expected in context.table:
+            nid = NominatimID(expected['object'])
+            nid.query_osm_id(cur, query)
+            assert cur.rowcount > 0, "No rows found for " + expected['object']
+            for db_res in cur:
+                if exact:
+                    seen.add((db_res['osm_type'], db_res['osm_id'], db_res['class']))
+                DBRow(nid, db_res, context).assert_row(expected, ['object'])
+        if exact:
+            cur.execute('SELECT osm_type, osm_id, class from {}'.format(table))
+            assert seen == {(r[0], r[1], r[2]) for r in cur}
+@then("(?P<table>placex|place) has no entry for (?P<oid>.*)")
+def check_place_has_entry(context, table, oid):
+    """ Check that the given table has no row for the given object. The ID
+        must be of the form '<NRW><osm id>[:<class>]'.
+    """
+    query = "SELECT * FROM %s where {}" % table
+    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cursor:
+        NominatimID(oid).query_osm_id(cursor, query)
+        found = cursor.rowcount
+        assert found == 0, \
+               "Found {} entries for ID {}".format(found, oid)
+@then("search_name contains(?P<exclude> not)?")
+def check_search_name_contents(context, exclude):
+    """ Check contents of the search_name table. Each row represents a table
+        row and all data must match. Data not present in the expected table
+        may be arbitrary. The rows are identified via the 'object' column
+        which must have an identifier of the form '<NRW><osm id>[:<class>]'.
+        All expected rows are expected to be present with at least one
+        database row.
+        The columns 'name_vector' and 'nameaddress_vector' take a
+        comma-separated list of terms which are looked up in the word table.
+        Terms starting with '#' are matched against word tokens with a
+        leading blank, all other terms against the plain token. Without
+        'not' all resulting word ids must be part of the vector, with
+        'not' none of them may be.
+    """
+    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+        for row in context.table:
+            nid = NominatimID(row['object'])
+            nid.row_by_place_id(cur, 'search_name',
+                                ['ST_X(centroid) as cx', 'ST_Y(centroid) as cy'])
+            assert cur.rowcount > 0, "No rows found for " + row['object']
+            for res in cur:
+                db_row = DBRow(nid, res, context)
+                for name, value in zip(row.headings, row.cells):
+                    if name in ('name_vector', 'nameaddress_vector'):
+                        items = [x.strip() for x in value.split(',')]
+                        plain_terms = [x for x in items if not x.startswith('#')]
+                        hash_terms = [x for x in items if x.startswith('#')]
+                        with context.db.cursor() as subcur:
+                            subcur.execute(""" SELECT word_id, word_token
+                                               FROM word, (SELECT unnest(%s::TEXT[]) as term) t
+                                               WHERE word_token = make_standard_name(t.term)
+                                                     and class is null and country_code is null
+                                                     and operator is null
+                                              UNION
+                                               SELECT word_id, word_token
+                                               FROM word, (SELECT unnest(%s::TEXT[]) as term) t
+                                               WHERE word_token = ' ' || make_standard_name(t.term)
+                                                     and class is null and country_code is null
+                                                     and operator is null
+                                           """,
+                                           (plain_terms, hash_terms))
+                            if not exclude:
+                                assert subcur.rowcount >= len(items), \
+                                    "No word entry found for {}. Entries found: {!s}".format(value, subcur.rowcount)
+                            for wid in subcur:
+                                present = wid[0] in res[name]
+                                if exclude:
+                                    assert not present, "Found term for {}/{}: {}".format(row['object'], name, wid[1])
+                                else:
+                                    # This used to call the misspelled '.fromat', which hid
+                                    # the assertion message behind an AttributeError.
+                                    assert present, "Missing term for {}/{}: {}".format(row['object'], name, wid[1])
+                    elif name != 'object':
+                        assert db_row.contains(name, value), db_row.assert_msg(name, value)
+@then("search_name has no entry for (?P<oid>.*)")
+def check_search_name_has_entry(context, oid):
+    """ Check that there is no entry in the search_name table for the given
+        object. The ID is in the format '<NRW><osm id>[:<class>]'.
+    """
+    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cursor:
+        NominatimID(oid).row_by_place_id(cursor, 'search_name')
+        found = cursor.rowcount
+        assert found == 0, \
+               "Found {} entries for ID {}".format(found, oid)
+@then("location_postcode contains exactly")
+def check_location_postcode(context):
+    """ Check full contents for location_postcode table. Each row represents
+        a table row and all data must match. Data not present in the expected
+        table may be arbitrary. The rows are identified via the 'country' and
+        'postcode' columns. All rows must be present as expected and there
+        must not be additional rows.
+    """
+    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+        cur.execute("SELECT *, ST_AsText(geometry) as geomtxt FROM location_postcode")
+        expected_count = len(list(context.table))
+        assert cur.rowcount == expected_count, \
+            "Postcode table has {} rows, expected {}.".format(cur.rowcount, expected_count)
+        # Index the database rows by (country, postcode) for the lookup below.
+        results = {}
+        for row in cur:
+            key = (row['country_code'], row['postcode'])
+            assert key not in results, "Postcode table has duplicate entry: {}".format(row)
+            results[key] = DBRow(key, row, context)
+        for row in context.table:
+            db_row = results.get((row['country'], row['postcode']))
+            # The values are handed in positionally: a replacement field
+            # like {r['country']} would look up the key including its quotes.
+            assert db_row is not None, \
+                "Missing row for country '{}' postcode '{}'.".format(row['country'], row['postcode'])
+            db_row.assert_row(row, ('country', 'postcode'))
+@then("word contains(?P<exclude> not)?")
+def check_word_table(context, exclude):
+    """ Check the contents of the word table. Each row represents a table row
+        and all data must match. Data not present in the expected table may
+        be arbitrary. The rows are identified via all given columns.
+        With 'not', no row with the given values may exist in the table.
+    """
+    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+        for row in context.table:
+            # The headings give the column names, the cells are handed in
+            # as query parameters.
+            wheres = ' AND '.join(["{} = %s".format(h) for h in row.headings])
+            cur.execute("SELECT * from word WHERE " + wheres, list(row.cells))
+            # The messages used to reference an undefined name 'values'.
+            if exclude:
+                assert cur.rowcount == 0, "Row still in word table: %s" % '/'.join(row.cells)
+            else:
+                assert cur.rowcount > 0, "Row not in word table: %s" % '/'.join(row.cells)
+@then("place_addressline contains")
+def check_place_addressline(context):
+    """ Check the contents of the place_addressline table. Each row represents
+        a table row and all data must match. Data not present in the expected
+        table may be arbitrary. The rows are identified via the 'object'
+        column, representing the addressee, and the 'address' column,
+        representing the address item.
+    """
+    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+        for expected in context.table:
+            addressee = NominatimID(expected['object'])
+            place_id = addressee.get_place_id(cur)
+            address_id = NominatimID(expected['address']).get_place_id(cur)
+            cur.execute(""" SELECT * FROM place_addressline
+                            WHERE place_id = %s AND address_place_id = %s""",
+                        (place_id, address_id))
+            assert cur.rowcount > 0, \
+                        "No rows found for place %s and address %s" % (expected['object'], expected['address'])
+            for db_res in cur:
+                DBRow(addressee, db_res, context).assert_row(expected, ('address', 'object'))
+@then("place_addressline doesn't contain")
+def check_place_addressline_exclude(context):
+    """ Check that the place_addressline table has no entries for the
+        given pairs of addressee ('object' column) and address item
+        ('address' column).
+    """
+    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+        for expected in context.table:
+            place_id = NominatimID(expected['object']).get_place_id(cur)
+            address_id = NominatimID(expected['address']).get_place_id(cur)
+            cur.execute(""" SELECT * FROM place_addressline
+                            WHERE place_id = %s AND address_place_id = %s""",
+                        (place_id, address_id))
+            assert cur.rowcount == 0, \
+                "Row found for place %s and address %s" % (expected['object'], expected['address'])
+@then("W(?P<oid>\d+) expands to(?P<neg> no)? interpolation")
+def check_location_property_osmline(context, oid, neg):
+    """ Check the entries of the given way in the interpolation table.
+        With 'no', the way must not have any row with a start number.
+        Otherwise every database row must match a row of the step table
+        with the same 'start' and 'end' and every table row must be
+        matched by a database row.
+    """
+    with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+        cur.execute("""SELECT *, ST_AsText(linegeo) as geomtxt
+                       FROM location_property_osmline
+                       WHERE osm_id = %s AND startnumber IS NOT NULL""",
+                    (oid, ))
+        if neg:
+            assert cur.rowcount == 0, "Interpolation found for way {}.".format(oid)
+            return
+        # Indexes of the expected rows that have not been matched yet.
+        unmatched = list(range(len(list(context.table))))
+        for db_res in cur:
+            for idx in unmatched:
+                expected = context.table[idx]
+                same_start = int(expected['start']) == db_res['startnumber']
+                if same_start and int(expected['end']) == db_res['endnumber']:
+                    unmatched.remove(idx)
+                    break
+            else:
+                assert False, "Unexpected row " + str(db_res)
+            DBRow(oid, db_res, context).assert_row(expected, ('start', 'end'))
+        assert not unmatched
similarity index 53%
rename from test/bdd/steps/osm_data.py
rename to test/bdd/steps/steps_osm_data.py
index 0f8b11886576948c26a23378d1c20194b0b63a77..3858198b680112017d785c522a62ca4f1b6c2243 100644 (file)
@@ -1,34 +1,45 @@
-import subprocess
 import tempfile
 import random
 import os
+def write_opl_file(opl, grid):
+    """ Write the given OPL data into a temporary OSM file and return the
+        name of the file. The caller is responsible for deleting the
+        file again.
+        Nodes without coordinates get their position from the supplied
+        grid. When the grid does not know the node, a random coordinate
+        is assigned.
+    """
+    with tempfile.NamedTemporaryFile(suffix='.opl', delete=False) as outfile:
+        for entry in opl.splitlines():
+            if entry.startswith('n') and ' x' not in entry:
+                node_id = int(entry[1:].split(' ')[0])
+                position = grid.grid_node(node_id)
+                if position is None:
+                    position = (random.random() * 360 - 180,
+                                random.random() * 180 - 90)
+                entry += " x%f y%f" % position
+            outfile.write(entry.encode('utf-8') + b'\n')
+        return outfile.name
+@given(u'the scene (?P<scene>.+)')
+def set_default_scene(context, scene):
+    """ Remember the given scene name in the 'scene' attribute of the
+        test context.
+    """
+    context.scene = scene
 @given(u'the ([0-9.]+ )?grid')
 def define_node_grid(context, grid_step):
     Define a grid of node positions.
+    Use a table to define the grid. The nodes must be integer ids. Optionally
+    you can give the grid distance. The default is 0.00001 degrees.
     if grid_step is not None:
         grid_step = float(grid_step.strip())
         grid_step = 0.00001
-    context.osm.clear_grid()
-    i = 0
-    for h in context.table.headings:
-        if h.isdigit():
-            context.osm.add_grid_node(int(h), 0, i)
-        i += grid_step
-    x = grid_step
-    for r in context.table:
-        y = 0
-        for h in r:
-            if h.isdigit():
-                context.osm.add_grid_node(int(h), x, y)
-            y += grid_step
-        x += grid_step
+    context.osm.set_grid([context.table.headings] + [list(h) for h in context.table],
+                         grid_step)
 @when(u'loading osm data')
@@ -39,21 +50,11 @@ def load_osm_file(context):
     The data is expected as attached text in OPL format.
-    # create a OSM file in /tmp and import it
-    with tempfile.NamedTemporaryFile(dir='/tmp', suffix='.opl', delete=False) as fd:
-        fname = fd.name
-        for line in context.text.splitlines():
-            if line.startswith('n') and line.find(' x') < 0:
-                coord = context.osm.grid_node(int(line[1:].split(' ')[0]))
-                if coord is None:
-                    coord = (random.random() * 360 - 180,
-                             random.random() * 180 - 90)
-                line += " x%f y%f" % coord
-            fd.write(line.encode('utf-8'))
-            fd.write(b'\n')
+    # create an OSM file and import it
+    fname = write_opl_file(context.text, context.osm)
     context.nominatim.run_setup_script('import-data', osm_file=fname,
+    os.remove(fname)
     ### reintroduce the triggers/indexes we've lost by having osm2pgsql set up place again
     cur = context.db.cursor()
@@ -64,7 +65,6 @@ def load_osm_file(context):
     cur.execute("""CREATE UNIQUE INDEX idx_place_osm_unique on place using btree(osm_id,osm_type,class,type)""")
-    os.remove(fname)
 @when(u'updating osm data')
 def update_from_osm_file(context):
@@ -74,30 +74,12 @@ def update_from_osm_file(context):
     The data is expected as attached text in OPL format.
-    context.nominatim.run_setup_script('create-functions', 'create-partition-functions')
-    cur = context.db.cursor()
-    cur.execute("""insert into placex (osm_type, osm_id, class, type, name, admin_level, address, extratags, geometry)
-                     select            osm_type, osm_id, class, type, name, admin_level, address, extratags, geometry from place""")
-    cur.execute(
-        """insert into location_property_osmline (osm_id, address, linegeo)
-             SELECT osm_id, address, geometry from place
-              WHERE class='place' and type='houses' and osm_type='W'
-                    and ST_GeometryType(geometry) = 'ST_LineString'""")
-    context.db.commit()
+    context.nominatim.copy_from_place(context.db)
     context.nominatim.run_setup_script('index', 'index-noanalyse')
     context.nominatim.run_setup_script('create-functions', 'create-partition-functions',
-    # create a OSM file in /tmp and import it
-    with tempfile.NamedTemporaryFile(dir='/tmp', suffix='.opl', delete=False) as fd:
-        fname = fd.name
-        for line in context.text.splitlines():
-            if line.startswith('n') and line.find(' x') < 0:
-                    line += " x%d y%d" % (random.random() * 360 - 180,
-                                          random.random() * 180 - 90)
-            fd.write(line.encode('utf-8'))
-            fd.write(b'\n')
+    # create an OSM file and import it
+    fname = write_opl_file(context.text, context.osm)
diff --git a/test/bdd/steps/table_compare.py b/test/bdd/steps/table_compare.py
new file mode 100644 (file)
index 0000000..2e71d94
--- /dev/null
@@ -0,0 +1,209 @@
+Functions to facilitate accessing and comparing the content of DB tables.
+import re
+import json
+from steps.check_functions import Almost
+ID_REGEX = re.compile(r"(?P<typ>[NRW])(?P<oid>\d+)(:(?P<cls>\w+))?")
+class NominatimID:
+    """ Test-friendly identifier of a place, split into its components.
+        As place_ids cannot be used for testing, places are addressed with
+        identifiers of the form <osmtype><osmid>[:<class>] instead.
+        The parts are available as the attributes `typ`, `oid` (the digits
+        as a string) and `cls`. All three are None when created from None.
+    """
+    def __init__(self, oid):
+        self.typ = self.oid = self.cls = None
+        if oid is None:
+            return
+        match = ID_REGEX.fullmatch(oid)
+        assert match is not None, \
+               "ID '{}' not of form <osmtype><osmid>[:<class>]".format(oid)
+        self.typ, self.oid, self.cls = match.group('typ', 'oid', 'cls')
+    def __str__(self):
+        if self.cls is not None:
+            return '{self.typ}{self.oid}:{self.cls}'.format(self=self)
+        return self.typ + self.oid
+    def query_osm_id(self, cur, query):
+        """ Execute `query` on cursor `cur`, restricted to this OSM object.
+            `query` must contain exactly one '{}' placeholder. It is replaced
+            with the condition on osm_type and osm_id, plus class when the
+            ID has one. The values themselves are passed as query parameters.
+        """
+        conditions = ['osm_type = %s', 'osm_id = %s']
+        values = [self.typ, self.oid]
+        if self.cls is not None:
+            conditions.append('class = %s')
+            values.append(self.cls)
+        cur.execute(query.format(' and '.join(conditions)), values)
+    def row_by_place_id(self, cur, table, extra_columns=None):
+        """ Select the row of `table` whose place_id belongs to this ID,
+            using cursor `cur`. Nothing is returned, the result has to be
+            fetched from the cursor.
+            `extra_columns` may contain a list of additional elements for
+            the select part of the query.
+        """
+        place_id = self.get_place_id(cur)
+        columns = ','.join(['*'] + (extra_columns or []))
+        cur.execute("SELECT {} FROM {} WHERE place_id = %s".format(columns, table),
+                    (place_id, ))
+    def get_place_id(self, cur):
+        """ Return the place_id that placex holds for this ID, using
+            cursor `cur`. Throws an assertion unless exactly one row matches.
+        """
+        self.query_osm_id(cur, "SELECT place_id FROM placex WHERE {}")
+        found = cur.rowcount
+        assert found == 1, \
+               "Place ID {!s} not unique. Found {} entries.".format(self, found)
+        return cur.fetchone()[0]
+class DBRow:
+    """ Represents a row from a database and offers comparison functions.
+        `nid` is the identifier used for the row in messages and `db_row`
+        the dictionary-like row itself. `context` supplies the database
+        connection (`context.db`) and the geometry helpers (`context.osm`
+        and `context.scene`).
+    """
+    def __init__(self, nid, db_row, context):
+        self.nid = nid
+        self.db_row = db_row
+        self.context = context
+    def assert_row(self, row, exclude_columns):
+        """ Check that all columns of the given behave row are contained
+            in the database row. Exclude behave rows with the names given
+            in the `exclude_columns` list.
+        """
+        for name, value in zip(row.headings, row.cells):
+            if name not in exclude_columns:
+                assert self.contains(name, value), self.assert_msg(name, value)
+    def contains(self, name, expected):
+        """ Check that the DB row contains a column `name` with the given value.
+            Special forms: `<column>+<key>` tests a single key of a
+            dictionary-valued column. 'geometry' and 'centroid' go through
+            dedicated geometry checks. Columns with 'place_id' in their name
+            expect '0' or an ID understood by NominatimID. An expected value
+            of '-' requires the column to be NULL.
+        """
+        if '+' in name:
+            column, field = name.split('+', 1)
+            return self._contains_hstore_value(column, field, expected)
+        if name == 'geometry':
+            return self._has_geometry(expected)
+        if name not in self.db_row:
+            return False
+        actual = self.db_row[name]
+        if expected == '-':
+            return actual is None
+        if name == 'name' and ':' not in expected:
+            # A plain value is compared against the 'name' entry only.
+            # A NULL column or a missing entry is a mismatch that must be
+            # reported through assert_msg(), not a TypeError/KeyError.
+            if actual is None or name not in actual:
+                return False
+            return self._compare_column(actual[name], expected)
+        if 'place_id' in name:
+            return self._compare_place_id(actual, expected)
+        if name == 'centroid':
+            return self._has_centroid(expected)
+        return self._compare_column(actual, expected)
+    def _contains_hstore_value(self, column, field, expected):
+        """ Check the entry `field` of the dictionary-valued column `column`.
+            'addr' is accepted as a short form of 'address'. An expected
+            value of '-' means that the entry must not exist.
+        """
+        if column == 'addr':
+            column = 'address'
+        if column not in self.db_row:
+            return False
+        if expected == '-':
+            return self.db_row[column] is None or field not in self.db_row[column]
+        if self.db_row[column] is None:
+            return False
+        return self._compare_column(self.db_row[column].get(field), expected)
+    def _compare_column(self, actual, expected):
+        """ Compare a column value with its expected textual form.
+            Dictionaries are compared against `expected` read as the inner
+            part of a dict literal, everything else by string representation.
+        """
+        if isinstance(actual, dict):
+            # NOTE(review): eval() executes `expected` as Python code. That is
+            # tolerable only while the value comes from trusted test
+            # descriptions. Never feed it external input.
+            return actual == eval('{' + expected + '}')
+        return str(actual) == expected
+    def _compare_place_id(self, actual, expected):
+        """ Compare a place_id column with the expected place. '0' stands
+            for the literal id 0, anything else is resolved via NominatimID.
+        """
+        if expected == '0':
+            return actual == 0
+        with self.context.db.cursor() as cur:
+            return NominatimID(expected).get_place_id(cur) == actual
+    def _has_centroid(self, expected):
+        """ Check the centroid given by the columns 'cx' and 'cy': either it
+            lies within the row geometry ('in geometry') or it matches the
+            space-separated expected coordinates.
+        """
+        if expected == 'in geometry':
+            with self.context.db.cursor() as cur:
+                cur.execute("""SELECT ST_Within(ST_SetSRID(ST_Point({cx}, {cy}), 4326),
+                                        ST_SetSRID('{geomtxt}'::geometry, 4326))""".format(**self.db_row))
+                return cur.fetchone()[0]
+        x, y = expected.split(' ')
+        return Almost(float(x)) == self.db_row['cx'] and Almost(float(y)) == self.db_row['cy']
+    def _has_geometry(self, expected):
+        """ Check that the geometry in column 'geomtxt' equals the expected
+            geometry once both are snapped to a grid of 0.00001.
+        """
+        geom = self.context.osm.parse_geometry(expected, self.context.scene)
+        with self.context.db.cursor() as cur:
+            cur.execute("""SELECT ST_Equals(ST_SnapToGrid({}, 0.00001, 0.00001),
+                                   ST_SnapToGrid(ST_SetSRID('{}'::geometry, 4326), 0.00001, 0.00001))""".format(
+                            geom, self.db_row['geomtxt']))
+            return cur.fetchone()[0]
+    def assert_msg(self, name, value):
+        """ Return a string with an informative message for a failed compare.
+        """
+        msg = "\nBad column '{}' in row '{!s}'.".format(name, self.nid)
+        actual = self._get_actual(name)
+        if actual is not None:
+            msg += " Expected: {}, got: {}.".format(value, actual)
+        else:
+            msg += " No such column."
+        return msg + "\nFull DB row: {}".format(json.dumps(dict(self.db_row), indent=4, default=str))
+    def _get_actual(self, name):
+        """ Return a printable form of the value that column `name` has in
+            the DB row, or None when there is nothing to show. place_ids
+            are translated back to the <osmtype><osmid>:<class> form.
+        """
+        if '+' in name:
+            column, field = name.split('+', 1)
+            if column == 'addr':
+                column = 'address'
+            return (self.db_row.get(column) or {}).get(field)
+        if name == 'geometry':
+            return self.db_row['geomtxt']
+        if name not in self.db_row:
+            return None
+        if name == 'centroid':
+            return "POINT({cx} {cy})".format(**self.db_row)
+        actual = self.db_row[name]
+        if 'place_id' in name:
+            if actual is None:
+                return '<null>'
+            if actual == 0:
+                return "place ID 0"
+            with self.context.db.cursor() as cur:
+                cur.execute("""SELECT osm_type, osm_id, class
+                               FROM placex WHERE place_id = %s""",
+                            (actual, ))
+                if cur.rowcount == 1:
+                    return "{0[0]}{0[1]}:{0[2]}".format(cur.fetchone())
+                return "[place ID {} not found]".format(actual)
+        return actual
diff --git a/test/bdd/steps/utils.py b/test/bdd/steps/utils.py
new file mode 100644 (file)
index 0000000..64d020d
--- /dev/null
@@ -0,0 +1,22 @@
+Various smaller helpers for step execution.
+import logging
+import subprocess
+LOG = logging.getLogger(__name__)
+def run_script(cmd, **kwargs):
+    """ Execute the command `cmd` and wait for it to finish. `kwargs` are
+        handed on to the subprocess call. stdout and stderr are captured,
+        decoded as UTF-8 and logged at debug level. Literal '\\n' sequences
+        in stderr are turned into real line breaks.
+        Throws an assertion when the exit code is not 0, otherwise returns
+        the tuple (stdout, stderr).
+    """
+    result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
+                            **kwargs)
+    outp = result.stdout.decode('utf-8')
+    outerr = result.stderr.decode('utf-8').replace('\\n', '\n')
+    LOG.debug("Run command: %s\n%s\n%s", cmd, outp, outerr)
+    assert result.returncode == 0, "Script '{}' failed:\n{}\n{}\n".format(cmd[0], outp, outerr)
+    return outp, outerr