]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/connection.py
Merge pull request #3773 from lonvia/small-countries
[nominatim.git] / src / nominatim_api / connection.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2024 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Extended SQLAlchemy connection class that also includes access to the schema.
9 """
10 from typing import cast, Any, Mapping, Sequence, Union, Dict, Optional, Set, \
11                    Awaitable, Callable, TypeVar
12 import asyncio
13
14 import sqlalchemy as sa
15 from sqlalchemy.ext.asyncio import AsyncConnection
16
17 from .typing import SaFromClause
18 from .sql.sqlalchemy_schema import SearchTables
19 from .sql.sqlalchemy_types import Geometry
20 from .logging import log
21 from .config import Configuration
22
23 T = TypeVar('T')
24
25
26 class SearchConnection:
27     """ An extended SQLAlchemy connection class, that also contains
28         the table definitions. The underlying asynchronous SQLAlchemy
29         connection can be accessed with the 'connection' property.
30         The 't' property is the collection of Nominatim tables.
31     """
32
33     def __init__(self, conn: AsyncConnection,
34                  tables: SearchTables,
35                  properties: Dict[str, Any],
36                  config: Configuration) -> None:
37         self.connection = conn
38         self.t = tables
39         self.config = config
40         self._property_cache = properties
41         self._classtables: Optional[Set[str]] = None
42         self.query_timeout: Optional[int] = None
43
44     def set_query_timeout(self, timeout: Optional[int]) -> None:
45         """ Set the timeout after which a query over this connection
46             is cancelled.
47         """
48         self.query_timeout = timeout
49
50     async def scalar(self, sql: sa.sql.base.Executable,
51                      params: Union[Mapping[str, Any], None] = None) -> Any:
52         """ Execute a 'scalar()' query on the connection.
53         """
54         log().sql(self.connection, sql, params)
55         return await asyncio.wait_for(self.connection.scalar(sql, params), self.query_timeout)
56
57     async def execute(self, sql: 'sa.Executable',
58                       params: Union[Mapping[str, Any], Sequence[Mapping[str, Any]], None] = None
59                       ) -> 'sa.Result[Any]':
60         """ Execute a 'execute()' query on the connection.
61         """
62         log().sql(self.connection, sql, params)
63         return await asyncio.wait_for(self.connection.execute(sql, params), self.query_timeout)
64
65     async def get_property(self, name: str, cached: bool = True) -> str:
66         """ Get a property from Nominatim's property table.
67
68             Property values are normally cached so that they are only
69             retrieved from the database when they are queried for the
70             first time with this function. Set 'cached' to False to force
71             reading the property from the database.
72
73             Raises a ValueError if the property does not exist.
74         """
75         lookup_name = f'DBPROP:{name}'
76
77         if cached and lookup_name in self._property_cache:
78             return cast(str, self._property_cache[lookup_name])
79
80         sql = sa.select(self.t.properties.c.value)\
81             .where(self.t.properties.c.property == name)
82         value = await self.connection.scalar(sql)
83
84         if value is None:
85             raise ValueError(f"Property '{name}' not found in database.")
86
87         self._property_cache[lookup_name] = cast(str, value)
88
89         return cast(str, value)
90
91     async def get_db_property(self, name: str) -> Any:
92         """ Get a setting from the database. At the moment, only
93             'server_version', the version of the database software, can
94             be retrieved with this function.
95
96             Raises a ValueError if the property does not exist.
97         """
98         if name != 'server_version':
99             raise ValueError(f"DB setting '{name}' not found in database.")
100
101         return self._property_cache['DB:server_version']
102
103     async def get_cached_value(self, group: str, name: str,
104                                factory: Callable[[], Awaitable[T]]) -> T:
105         """ Access the cache for this Nominatim instance.
106             Each cache value needs to belong to a group and have a name.
107             This function is for internal API use only.
108
109             `factory` is an async callback function that produces
110             the value if it is not already cached.
111
112             Returns the cached value or the result of factory (also caching
113             the result).
114         """
115         full_name = f'{group}:{name}'
116
117         if full_name in self._property_cache:
118             return cast(T, self._property_cache[full_name])
119
120         value = await factory()
121         self._property_cache[full_name] = value
122
123         return value
124
125     async def get_class_table(self, cls: str, typ: str) -> Optional[SaFromClause]:
126         """ Lookup up if there is a classtype table for the given category
127             and return a SQLAlchemy table for it, if it exists.
128         """
129         if self._classtables is None:
130             res = await self.execute(sa.text("""SELECT tablename FROM pg_tables
131                                                 WHERE tablename LIKE 'place_classtype_%'
132                                              """))
133             self._classtables = {r[0] for r in res}
134
135         tablename = f"place_classtype_{cls}_{typ}"
136
137         if tablename not in self._classtables:
138             return None
139
140         if tablename in self.t.meta.tables:
141             return self.t.meta.tables[tablename]
142
143         return sa.Table(tablename, self.t.meta,
144                         sa.Column('place_id', sa.BigInteger),
145                         sa.Column('centroid', Geometry))