1 # SPDX-License-Identifier: GPL-3.0-or-later
 
   3 # This file is part of Nominatim. (https://nominatim.org)
 
   5 # Copyright (C) 2024 by the Nominatim developer community.
 
   6 # For a full list of authors see the git log.
 
   8 Custom functions for SQLite.
 
  10 from typing import cast, Optional, Set, Any
 
  14 def weigh_search(search_vector: Optional[str], rankings: str, default: float) -> float:
 
  15     """ Custom weight function for search results.
 
  17     if search_vector is not None:
 
  18         svec = [int(x) for x in search_vector.split(',')]
 
  19         for rank in json.loads(rankings):
 
  20             if all(r in svec for r in rank[1]):
 
  21                 return cast(float, rank[0])
 
  26 class ArrayIntersectFuzzy:
 
  27     """ Compute the array of common elements of all input integer arrays.
 
  28         Very large input parameters may be ignored to speed up
 
  29         computation. Therefore, the result is a superset of common elements.
 
  31         Input and output arrays are given as comma-separated lists.
 
  33     def __init__(self) -> None:
 
  35         self.values: Optional[Set[int]] = None
 
  37     def step(self, value: Optional[str]) -> None:
 
  38         """ Add the next array to the intersection.
 
  43             elif len(value) < 10000000:
 
  44                 if self.values is None:
 
  45                     self.values = {int(x) for x in self.first.split(',')}
 
  46                 self.values.intersection_update((int(x) for x in value.split(',')))
 
  48     def finalize(self) -> str:
 
  49         """ Return the final result.
 
  51         if self.values is not None:
 
  52             return ','.join(map(str, self.values))
 
  58     """ Compute the set of all elements of the input integer arrays.
 
  60         Input and output arrays are given as strings of comma-separated lists.
 
  62     def __init__(self) -> None:
 
  63         self.values: Optional[Set[str]] = None
 
  65     def step(self, value: Optional[str]) -> None:
 
  66         """ Add the next array to the union.
 
  69             if self.values is None:
 
  70                 self.values = set(value.split(','))
 
  72                 self.values.update(value.split(','))
 
  74     def finalize(self) -> str:
 
  75         """ Return the final result.
 
  77         return '' if self.values is None else ','.join(self.values)
 
  80 def array_contains(container: Optional[str], containee: Optional[str]) -> Optional[bool]:
 
  81     """ Is the array 'containee' completely contained in array 'container'.
 
  83     if container is None or containee is None:
 
  86     vset = container.split(',')
 
  87     return all(v in vset for v in containee.split(','))
 
  90 def array_pair_contains(container1: Optional[str], container2: Optional[str],
 
  91                         containee: Optional[str]) -> Optional[bool]:
 
  92     """ Is the array 'containee' completely contained in the union of
 
  93         array 'container1' and array 'container2'.
 
  95     if container1 is None or container2 is None or containee is None:
 
  98     vset = container1.split(',') + container2.split(',')
 
  99     return all(v in vset for v in containee.split(','))
 
 102 def install_custom_functions(conn: Any) -> None:
 
 103     """ Install helper functions for Nominatim into the given SQLite
 
 106     conn.create_function('weigh_search', 3, weigh_search, deterministic=True)
 
 107     conn.create_function('array_contains', 2, array_contains, deterministic=True)
 
 108     conn.create_function('array_pair_contains', 3, array_pair_contains, deterministic=True)
 
 109     _create_aggregate(conn, 'array_intersect_fuzzy', 1, ArrayIntersectFuzzy)
 
 110     _create_aggregate(conn, 'array_union', 1, ArrayUnion)
 
 113 async def _make_aggregate(aioconn: Any, *args: Any) -> None:
 
 114     await aioconn._execute(aioconn._conn.create_aggregate, *args)
 
 117 def _create_aggregate(conn: Any, name: str, nargs: int, aggregate: Any) -> None:
 
 119         conn.await_(_make_aggregate(conn._connection, name, nargs, aggregate))
 
 120     except Exception as error:
 
 121         conn._handle_exception(error)