1 # SPDX-License-Identifier: GPL-2.0-only
 
   3 # This file is part of Nominatim. (https://nominatim.org)
 
   5 # Copyright (C) 2025 by the Nominatim developer community.
 
   6 # For a full list of authors see the git log.
 
   8 Helper functions to compare expected values.
 
  10 import collections.abc
 
  15 from psycopg import sql as pysql
 
  16 from psycopg.rows import dict_row
 
  17 from .geometry_alias import ALIASES
 
  21     'exactly': lambda exp, act: exp == act,
 
  22     'more than': lambda exp, act: act > exp,
 
  23     'less than': lambda exp, act: act < exp,
 
  28     return json.dumps(obj, sort_keys=True, indent=2)
 
  31 def _pt_close(p1, p2):
 
  32     return math.isclose(p1[0], p2[0], abs_tol=1e-07) \
 
  33            and math.isclose(p1[1], p2[1], abs_tol=1e-07)
 
  36 def within_box(value, expect):
 
  37     coord = [float(x) for x in expect.split(',')]
 
  39     if isinstance(value, str):
 
  40         if value.startswith('POINT'):
 
  41             value = value[6:-1].split(' ')
 
  43             value = value.split(',')
 
  44     value = list(map(float, value))
 
  47         return coord[0] <= value[0] <= coord[2] \
 
  48                and coord[1] <= value[1] <= coord[3]
 
  51         return value[0] >= coord[0] and value[1] <= coord[1] \
 
  52                and value[2] >= coord[2] and value[3] <= coord[3]
 
  54     raise ValueError("Not a coordinate or bbox.")
 
  58     None: lambda val, exp: str(val) == exp,
 
  59     'i': lambda val, exp: str(val).lower() == exp.lower(),
 
  60     'fm': lambda val, exp: re.fullmatch(exp, val) is not None,
 
  61     'dict': lambda val, exp: val is None if exp == '-' else (val == eval('{' + exp + '}')),
 
  65 OSM_TYPE = {'node': 'n', 'way': 'w', 'relation': 'r',
 
  66             'N': 'n', 'W': 'w', 'R': 'r'}
 
  70     """ Returns the given attribute as a string.
 
  72         The key parameter determines how the value is formatted before
 
  73         returning. To refer to sub attributes, use '+' to add more keys
 
  74         (e.g. 'name+ref' will access obj['name']['ref']). A '!' introduces
 
  75         a formatting suffix. If no suffix is given, the value will be
 
  76         converted using the str() function.
 
  80         !:...   - use a formatting expression according to Python Mini Format Spec
 
  81         !i      - make case-insensitive comparison
 
  82         !fm     - consider comparison string a regular expression and match full value
 
  83         !wkt    - convert the expected value to a WKT string before comparing
 
  84         !in_box - the expected value is a comma-separated bbox description
 
  87     def __init__(self, obj, key, grid=None):
 
  91             self.key, self.fmt = key.rsplit('!', 1)
 
  96         if self.key == 'object':
 
  97             assert 'osm_id' in obj
 
  98             assert 'osm_type' in obj
 
  99             self.subobj = OSM_TYPE[obj['osm_type']] + str(obj['osm_id'])
 
 103             self.subobj = self.obj
 
 104             for sub in self.key.split('+'):
 
 106                 if isinstance(self.subobj, collections.abc.Sequence) and sub.isdigit():
 
 108                     assert sub < len(self.subobj), \
 
 109                         f"Out of bound index {done}. Full object:\n{_pretty(self.obj)}"
 
 111                     assert sub in self.subobj, \
 
 112                         f"Missing attribute {done}. Full object:\n{_pretty(self.obj)}"
 
 113                 self.subobj = self.subobj[sub]
 
 115     def __eq__(self, other):
 
 116         # work around bad quoting by pytest-bdd
 
 117         if not isinstance(other, str):
 
 118             return self.subobj == other
 
 120         other = other.replace(r'\\', '\\')
 
 122         if self.fmt in COMPARISON_FUNCS:
 
 123             return COMPARISON_FUNCS[self.fmt](self.subobj, other)
 
 125         if self.fmt.startswith(':'):
 
 126             return other == f"{{{self.fmt}}}".format(self.subobj)
 
 128         if self.fmt == 'wkt':
 
 129             return self.compare_wkt(self.subobj, other)
 
 131         raise RuntimeError(f"Unknown format string '{self.fmt}'.")
 
 134         k = self.key.replace('+', '][')
 
 137         return f"result[{k}]({self.subobj})"
 
 139     def compare_wkt(self, value, expected):
 
 140         """ Compare a WKT value against a compact geometry format.
 
 141             The function understands the following formats:
 
 143               country:<country code>
 
 144                  Point geometry guaranteed to be in the given country
 
 152            <P> may either be a coordinate of the form '<x> <y>' or a single
 
 153            number. In the latter case it must refer to a point in
 
 154            a previously defined grid.
 
 156         m = re.fullmatch(r'(POINT)\(([0-9. -]*)\)', value) \
 
 157             or re.fullmatch(r'(LINESTRING)\(([0-9,. -]*)\)', value) \
 
 158             or re.fullmatch(r'(POLYGON)\(\(([0-9,. -]*)\)\)', value)
 
 162         converted = [list(map(float, pt.split(' ', 1)))
 
 163                      for pt in map(str.strip, m[2].split(','))]
 
 165         if expected.startswith('country:'):
 
 166             ccode = expected[8:].upper()
 
 167             assert ccode in ALIASES, f"Geometry error: unknown country {ccode}"
 
 168             return m[1] == 'POINT' and _pt_close(converted[0], ALIASES[ccode])
 
 170         if ',' not in expected:
 
 171             return m[1] == 'POINT' and _pt_close(converted[0], self.get_point(expected))
 
 173         if '(' not in expected:
 
 174             return m[1] == 'LINESTRING' and \
 
 175                 all(_pt_close(p1, p2) for p1, p2 in
 
 176                     zip(converted, (self.get_point(p) for p in expected.split(','))))
 
 178         if m[1] != 'POLYGON':
 
 181         # Polygon comparison is tricky because the polygons don't necessarily
 
 182         # end at the same point or have the same winding order.
 
 183         # Brute force all possible variants of the expected polygon
 
 184         exp_coords = [self.get_point(p) for p in expected[1:-1].split(',')]
 
 185         if exp_coords[0] != exp_coords[-1]:
 
 186             raise RuntimeError(f"Invalid polygon {expected}. "
 
 187                                "First and last point need to be the same")
 
 188         for line in (exp_coords[:-1], exp_coords[-1:0:-1]):
 
 189             for i in range(len(line)):
 
 190                 if all(_pt_close(p1, p2) for p1, p2 in
 
 191                        zip(converted, line[i:] + line[:i])):
 
 196     def get_point(self, pt):
 
 199             return list(map(float, pt.split(' ', 1)))
 
 203         return self.grid.get(pt)
 
 206 def check_table_content(conn, tablename, data, grid=None, exact=False):
 
 207     lines = set(range(1, len(data)))
 
 212             cols.extend(('osm_id', 'osm_type'))
 
 214             name, fmt = col.rsplit('!', 1)
 
 215             if fmt in ('wkt', 'in_box'):
 
 216                 cols.append(f"ST_AsText({name}) as {name}")
 
 218                 cols.append(name.split('+')[0])
 
 220             cols.append(col.split('+')[0])
 
 222     with conn.cursor(row_factory=dict_row) as cur:
 
 223         cur.execute(pysql.SQL(f"SELECT {','.join(cols)} FROM")
 
 224                     + pysql.Identifier(tablename))
 
 228             table_content += '\n' + str(row)
 
 230                 for col, value in zip(data[0], data[i]):
 
 231                     if ResultAttr(row, col, grid=grid) != (None if value == '-' else value):
 
 237                 assert not exact, f"Unexpected row in table {tablename}: {row}"
 
 240                "Rows not found:\n" \
 
 241                + '\n'.join(str(data[i]) for i in lines) \
 
 242                + "\nTable content:\n" \