# SPDX-License-Identifier: GPL-2.0-only
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2025 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Helper functions to compare expected values.
"""
import json
import re
import math

from psycopg import sql as pysql
from psycopg.rows import dict_row
from .geometry_alias import ALIASES


def _matches_exactly(exp, act):
    """ True when the expected and the actual value compare equal. """
    return exp == act


def _exceeds(exp, act):
    """ True when the actual value is strictly greater than the expected one. """
    return act > exp


def _falls_below(exp, act):
    """ True when the actual value is strictly smaller than the expected one. """
    return act < exp


# Maps a comparison phrase to a predicate called as func(expected, actual).
COMPARATOR_TERMS = {
    'exactly': _matches_exactly,
    'more than': _exceeds,
    'less than': _falls_below,
}


def _pretty(obj):
    """ Serialise `obj` as JSON text with sorted keys and an
        indentation of two spaces.
    """
    dump_options = {'sort_keys': True, 'indent': 2}
    return json.dumps(obj, **dump_options)


def _pt_close(p1, p2):
    """ Check that the first two coordinates of the points `p1` and `p2`
        agree. Each pair is compared with math.isclose() using an
        absolute tolerance of 1e-07. Further coordinates are ignored.
    """
    for axis in (0, 1):
        if not math.isclose(p1[axis], p2[axis], abs_tol=1e-07):
            return False

    return True


def within_box(value, expect):
    """ Check a coordinate or a bounding box against the box `expect`.

        `expect` is a comma-separated string of numbers. `value` may be
        a string of the form 'POINT(<x> <y>)', a comma-separated string
        of numbers or an iterable of numbers.

        For two numbers (v0, v1) the function tests
        expect[0] <= v0 <= expect[2] and expect[1] <= v1 <= expect[3].
        For four numbers it tests that v0 and v2 are not smaller than
        expect[0] and expect[2] and that v1 and v3 are not larger than
        expect[1] and expect[3].

        Raises a ValueError when `value` has neither two nor four numbers.
    """
    box = list(map(float, expect.split(',')))

    if isinstance(value, str):
        # Strip the leading 'POINT(' and the closing ')' from WKT points.
        parts = value[6:-1].split(' ') if value.startswith('POINT') else value.split(',')
    else:
        parts = value
    numbers = [float(part) for part in parts]

    if len(numbers) == 2:
        first, second = numbers
        return box[0] <= first <= box[2] and box[1] <= second <= box[3]

    if len(numbers) == 4:
        low_a, high_a, low_b, high_b = numbers
        return low_a >= box[0] and high_a <= box[1] \
            and low_b >= box[2] and high_b <= box[3]

    raise ValueError("Not a coordinate or bbox.")


def _cmp_plain(val, exp):
    """ Compare the string representation of the value with the expectation. """
    return str(val) == exp


def _cmp_ignore_case(val, exp):
    """ Like the plain comparison but on the lower-cased strings. """
    return str(val).lower() == exp.lower()


def _cmp_full_match(val, exp):
    """ Treat the expectation as a regular expression that must match
        the complete value.
    """
    return re.fullmatch(exp, val) is not None


def _cmp_dict(val, exp):
    """ Compare the value with a dictionary whose content is given by
        the expectation. An expectation of '-' requires the value to be None.
    """
    if exp == '-':
        return val is None

    # NOTE: the expectation is evaluated as Python code. It must only
    # ever come from trusted test descriptions.
    return val == eval('{' + exp + '}')


# Comparison functions by format suffix, each called as func(value, expected).
# The key None is used when no format suffix was given.
COMPARISON_FUNCS = {
    None: _cmp_plain,
    'i': _cmp_ignore_case,
    'fm': _cmp_full_match,
    'dict': _cmp_dict,
    'in_box': within_box
}

# Translates an OSM object type into its lower-case one-letter code.
OSM_TYPE = {
    # spelled-out type names
    'node': 'n',
    'way': 'w',
    'relation': 'r',
    # upper-case one-letter codes
    'N': 'n',
    'W': 'w',
    'R': 'r',
}


class ResultAttr:
    """ Returns the given attribute as a string.

        The key parameter determines how the value is formatted before
        returning. To refer to sub attributes, use '+' to add more keys
        (e.g. 'name+ref' will access obj['name']['ref']). A '!' introduces
        a formatting suffix. If no suffix is given, the value will be
        converted using the str() function.

        The special key 'object' refers to the one-letter OSM type followed
        by the OSM id. It is always compared case-insensitively.

        Available formatters:

        !:...   - use a formatting expression according to Python Mini Format Spec
        !i      - make case-insensitive comparison
        !fm     - consider comparison string a regular expression and match full value
        !dict   - consider comparison string the content of a Python dictionary
        !wkt    - convert the expected value to a WKT string before comparing
        !in_box - the expected value is a comma-separated bbox description
    """

    def __init__(self, obj, key, grid=None):
        """ Look up the attribute `key` in the mapping `obj`.

            `grid` is only needed for '!wkt' comparisons where the
            expected geometry refers to grid points. Missing attributes
            are reported through an AssertionError.
        """
        self.grid = grid
        self.obj = obj
        if '!' in key:
            self.key, self.fmt = key.rsplit('!', 1)
        else:
            self.key = key
            self.fmt = None

        if self.key == 'object':
            assert 'osm_id' in obj
            assert 'osm_type' in obj
            self.subobj = OSM_TYPE[obj['osm_type']] + str(obj['osm_id'])
            self.fmt = 'i'
        else:
            # Descend one level per '+'-separated key and remember the
            # path walked so far for the error message.
            done = ''
            self.subobj = self.obj
            for sub in self.key.split('+'):
                done += f"[{sub}]"
                assert sub in self.subobj, \
                    f"Missing attribute {done}. Full object:\n{_pretty(self.obj)}"
                self.subobj = self.subobj[sub]

    def __eq__(self, other):
        """ Compare the attribute with the expected value `other`.

            Expected values that are not strings are compared directly.
            Strings are compared according to the format suffix of the key.
            Raises a RuntimeError for an unknown format suffix.
        """
        # work around bad quoting by pytest-bdd
        if not isinstance(other, str):
            return self.subobj == other

        other = other.replace(r'\\', '\\')

        # Also covers the case without a suffix (self.fmt is None).
        if self.fmt in COMPARISON_FUNCS:
            return COMPARISON_FUNCS[self.fmt](self.subobj, other)

        if self.fmt.startswith(':'):
            return other == f"{{{self.fmt}}}".format(self.subobj)

        if self.fmt == 'wkt':
            return self.compare_wkt(self.subobj, other)

        raise RuntimeError(f"Unknown format string '{self.fmt}'.")

    def __repr__(self):
        k = self.key.replace('+', '][')
        if self.fmt:
            k += '!' + self.fmt
        return f"result[{k}]({self.subobj})"

    def compare_wkt(self, value, expected):
        """ Compare a WKT value against a compact geometry format.
            The function understands the following formats:

              country:<country code>
                 Point geometry guaranteed to be in the given country
              <P>
                 Point geometry
              <P>,...,<P>
                 Line geometry
              (<P>,...,<P>)
                 Polygon geometry

           <P> may either be a coordinate of the form '<x> <y>' or a single
           number. In the latter case it must refer to a point in
           a previously defined grid.

           Lines and polygons only compare equal when they have the same
           number of points as the expected geometry. Values that are not
           a simple POINT, LINESTRING or single-ring POLYGON never match.
        """
        m = re.fullmatch(r'(POINT)\(([0-9. -]*)\)', value) \
            or re.fullmatch(r'(LINESTRING)\(([0-9,. -]*)\)', value) \
            or re.fullmatch(r'(POLYGON)\(\(([0-9,. -]*)\)\)', value)
        if not m:
            return False

        converted = [list(map(float, pt.split(' ', 1)))
                     for pt in map(str.strip, m[2].split(','))]

        if expected.startswith('country:'):
            ccode = expected[8:].upper()
            assert ccode in ALIASES, f"Geometry error: unknown country {ccode}"
            return m[1] == 'POINT' and _pt_close(converted[0], ALIASES[ccode])

        if ',' not in expected:
            return m[1] == 'POINT' and _pt_close(converted[0], self.get_point(expected))

        if '(' not in expected:
            if m[1] != 'LINESTRING':
                return False

            exp_line = [self.get_point(p) for p in expected.split(',')]
            # zip() stops at the shorter sequence, so a differing number
            # of points has to be ruled out explicitly.
            return len(converted) == len(exp_line) \
                and all(_pt_close(p1, p2) for p1, p2 in zip(converted, exp_line))

        if m[1] != 'POLYGON':
            return False

        # Polygon comparison is tricky because the polygons don't necessarily
        # end at the same point or have the same winding order.
        # Brute force all possible variants of the expected polygon
        exp_coords = [self.get_point(p) for p in expected[1:-1].split(',')]
        if exp_coords[0] != exp_coords[-1]:
            raise RuntimeError(f"Invalid polygon {expected}. "
                               "First and last point need to be the same")

        # Both rings repeat their start point at the end. The rotations
        # below drop that closing point, so without this check zip() would
        # silently accept rings with additional or missing points.
        if len(converted) != len(exp_coords):
            return False

        for line in (exp_coords[:-1], exp_coords[-1:0:-1]):
            for i in range(len(line)):
                if all(_pt_close(p1, p2) for p1, p2 in
                       zip(converted, line[i:] + line[:i])):
                    return True

        return False

    def get_point(self, pt):
        """ Convert a point description into a coordinate.

            A description of the form '<x> <y>' is returned as a list of
            two floats. Anything else is looked up in the grid.
        """
        pt = pt.strip()
        if ' ' in pt:
            return list(map(float, pt.split(' ', 1)))

        assert self.grid, f"Point '{pt}' refers to a grid but no grid is available."

        return self.grid.get(pt)


def check_table_content(conn, tablename, data, grid=None, exact=False):
    """ Check that the table `tablename` has a row for each expected line.

        `data[0]` holds the column descriptions in the key syntax of
        ResultAttr, all following entries one expected line each. An
        expected value of '-' is compared against None. Every expected
        line is matched against at most one table row.

        When `exact` is set, table rows that match none of the remaining
        expected lines make the check fail as well. `grid` is handed on
        to ResultAttr for geometry comparisons.

        Failures are reported through an AssertionError.
    """
    header = data[0]
    pending = set(range(1, len(data)))

    # Work out which table columns are needed for the comparison.
    select_cols = []
    for col in header:
        if col == 'object':
            select_cols += ['osm_id', 'osm_type']
            continue

        name, sep, fmt = col.rpartition('!')
        if not sep:
            name = col

        if sep and fmt in ('wkt', 'in_box'):
            # geometry columns are compared via their text representation
            select_cols.append(f"ST_AsText({name}) as {name}")
        else:
            select_cols.append(name.split('+')[0])

    def _differs(row, lineno):
        """ True when any column of the row contradicts the expected line. """
        return any(ResultAttr(row, col, grid=grid) != (None if value == '-' else value)
                   for col, value in zip(header, data[lineno]))

    seen_rows = []
    with conn.cursor(row_factory=dict_row) as cur:
        query = pysql.SQL(f"SELECT {','.join(select_cols)} FROM") \
            + pysql.Identifier(tablename)
        cur.execute(query)

        for row in cur:
            seen_rows.append(str(row))
            hit = next((lineno for lineno in pending if not _differs(row, lineno)), None)
            if hit is not None:
                pending.remove(hit)
            else:
                assert not exact, f"Unexpected row in table {tablename}: {row}"

        assert not pending, \
               "Rows not found:\n" \
               + '\n'.join(str(data[i]) for i in pending) \
               + "\nTable content:\n" \
               + ''.join('\n' + line for line in seen_rows)
