git.openstreetmap.org Git - nominatim.git/commitdiff
api: delay setup of initial database connection
author Sarah Hoffmann <lonvia@denofr.de>
Tue, 24 Jan 2023 09:56:22 +0000 (10:56 +0100)
committer Sarah Hoffmann <lonvia@denofr.de>
Tue, 24 Jan 2023 09:56:22 +0000 (10:56 +0100)
Defer database setup until the first call to a function. Needs an
additional lock because the setup still needs to be done sequentially.

nominatim/api.py
nominatim/apicmd/status.py

index 10cca5330d6ae11fa1699933c27e61ecac330e21..4ce89595994637dc6d74fc0984b31100948cea5e 100644 (file)
@@ -2,18 +2,18 @@
 #
 # This file is part of Nominatim. (https://nominatim.org)
 #
-# Copyright (C) 2022 by the Nominatim developer community.
+# Copyright (C) 2023 by the Nominatim developer community.
 # For a full list of authors see the git log.
 """
 Implementation of classes for API access via libraries.
 """
-from typing import Mapping, Optional, cast, Any
+from typing import Mapping, Optional, Any, AsyncIterator
 import asyncio
+import contextlib
 from pathlib import Path
 
-from sqlalchemy import text, event
-from sqlalchemy.engine.url import URL
-from sqlalchemy.ext.asyncio import create_async_engine
+import sqlalchemy as sa
+import sqlalchemy.ext.asyncio as sa_asyncio
 import asyncpg
 
 from nominatim.config import Configuration
@@ -25,52 +25,93 @@ class NominatimAPIAsync:
     def __init__(self, project_dir: Path,
                  environ: Optional[Mapping[str, str]] = None) -> None:
         self.config = Configuration(project_dir, environ)
+        self.server_version = 0
+
+        self._engine_lock = asyncio.Lock()
+        self._engine: Optional[sa_asyncio.AsyncEngine] = None
+
+
+    async def setup_database(self) -> None:
+        """ Set up the engine and connection parameters.
+
+            This function will be implicitly called when the database is
+            accessed for the first time. You may also call it explicitly to
+            avoid that the first call is delayed by the setup.
+        """
+        async with self._engine_lock:
+            if self._engine:
+                return
+
+            dsn = self.config.get_database_params()
+
+            dburl = sa.engine.URL.create(
+                       'postgresql+asyncpg',
+                       database=dsn.get('dbname'),
+                       username=dsn.get('user'), password=dsn.get('password'),
+                       host=dsn.get('host'), port=int(dsn['port']) if 'port' in dsn else None,
+                       query={k: v for k, v in dsn.items()
+                              if k not in ('user', 'password', 'dbname', 'host', 'port')})
+            engine = sa_asyncio.create_async_engine(
+                             dburl, future=True,
+                             connect_args={'server_settings': {
+                                'DateStyle': 'sql,european',
+                                'max_parallel_workers_per_gather': '0'
+                             }})
+            # Probe the server version; on a PostgresError it is recorded as 0.
+            try:
+                async with engine.begin() as conn:
+                    result = await conn.scalar(sa.text('SHOW server_version_num'))
+                    self.server_version = int(result)
+            except asyncpg.PostgresError:
+                self.server_version = 0
+
+            if self.server_version >= 110000:
+                @sa.event.listens_for(engine.sync_engine, "connect") # type: ignore[misc]
+                def _on_connect(dbapi_con: Any, _: Any) -> None:
+                    cursor = dbapi_con.cursor()
+                    cursor.execute("SET jit_above_cost TO '-1'")
+                # Drop the probe connection; self.close() would be a no-op as _engine is unset.
+                await engine.dispose()
+
+            self._engine = engine
 
-        dsn = self.config.get_database_params()
-
-        dburl = URL.create(
-                   'postgresql+asyncpg',
-                   database=dsn.get('dbname'),
-                   username=dsn.get('user'), password=dsn.get('password'),
-                   host=dsn.get('host'), port=int(dsn['port']) if 'port' in dsn else None,
-                   query={k: v for k, v in dsn.items()
-                          if k not in ('user', 'password', 'dbname', 'host', 'port')})
-        self.engine = create_async_engine(
-                         dburl, future=True,
-                         connect_args={'server_settings': {
-                            'DateStyle': 'sql,european',
-                            'max_parallel_workers_per_gather': '0'
-                         }})
-        asyncio.get_event_loop().run_until_complete(self._query_server_version())
-        asyncio.get_event_loop().run_until_complete(self.close())
-
-        if self.server_version >= 110000:
-            @event.listens_for(self.engine.sync_engine, "connect") # type: ignore[misc]
-            def _on_connect(dbapi_con: Any, _: Any) -> None:
-                cursor = dbapi_con.cursor()
-                cursor.execute("SET jit_above_cost TO '-1'")
-
-
-    async def _query_server_version(self) -> None:
-        try:
-            async with self.engine.begin() as conn:
-                result = await conn.scalar(text('SHOW server_version_num'))
-                self.server_version = int(cast(str, result))
-        except asyncpg.PostgresError:
-            self.server_version = 0
 
     async def close(self) -> None:
         """ Close all active connections to the database. The NominatimAPIAsync
             object remains usable after closing. If a new API functions is
             called, new connections are created.
         """
-        await self.engine.dispose()
+        if self._engine is not None:
+            await self._engine.dispose()
+
+
+    @contextlib.asynccontextmanager
+    async def begin(self) -> AsyncIterator[sa_asyncio.AsyncConnection]:
+        """ Create a new connection with automatic transaction handling.
+
+            Use this for low-level access to the database; the engine is
+            set up on first use if needed. Refer to the SQLAlchemy
+            documentation for details on how to use the connection object.
+        """
+        if self._engine is None:
+            await self.setup_database()
+        # setup_database() leaves the engine set; the assert narrows the Optional type.
+        assert self._engine is not None
+
+        async with self._engine.begin() as conn:
+            yield conn
 
 
     async def status(self) -> StatusResult:
         """ Return the status of the database.
         """
-        return await get_status(self.engine)
+        try:
+            async with self.begin() as conn:
+                status = await get_status(conn)
+        except asyncpg.PostgresError:
+            return StatusResult(700, 'Database connection failed')
+
+        return status
 
 
 class NominatimAPI:
@@ -79,7 +120,8 @@ class NominatimAPI:
 
     def __init__(self, project_dir: Path,
                  environ: Optional[Mapping[str, str]] = None) -> None:
-        self.async_api = NominatimAPIAsync(project_dir, environ)
+        self._loop = asyncio.new_event_loop()
+        self._async_api = NominatimAPIAsync(project_dir, environ)
 
 
     def close(self) -> None:
@@ -87,10 +129,11 @@ class NominatimAPI:
             object remains usable after closing. If a new API functions is
             called, new connections are created.
         """
-        asyncio.get_event_loop().run_until_complete(self.async_api.close())
+        self._loop.run_until_complete(self._async_api.close())
+        self._loop.close()
 
 
     def status(self) -> StatusResult:
         """ Return the status of the database.
         """
-        return asyncio.get_event_loop().run_until_complete(self.async_api.status())
+        return self._loop.run_until_complete(self._async_api.status())
index 85071db9397853c64eb789c620a466a2cd81c313..560953d36079bc0c2f5e03abdb3b24787e29c7a7 100644 (file)
@@ -2,7 +2,7 @@
 #
 # This file is part of Nominatim. (https://nominatim.org)
 #
-# Copyright (C) 2022 by the Nominatim developer community.
+# Copyright (C) 2023 by the Nominatim developer community.
 # For a full list of authors see the git log.
 """
 Classes and function releated to status call.
@@ -10,8 +10,8 @@ Classes and function releated to status call.
 from typing import Optional, cast
 import datetime as dt
 
-import sqlalchemy as sqla
-from sqlalchemy.ext.asyncio.engine import AsyncEngine, AsyncConnection
+import sqlalchemy as sa
+from sqlalchemy.ext.asyncio.engine import AsyncConnection
 import asyncpg
 
 from nominatim import version
@@ -31,7 +31,7 @@ class StatusResult:
 async def _get_database_date(conn: AsyncConnection) -> Optional[dt.datetime]:
     """ Query the database date.
     """
-    sql = sqla.text('SELECT lastimportdate FROM import_status LIMIT 1')
+    sql = sa.text('SELECT lastimportdate FROM import_status LIMIT 1')
     result = await conn.execute(sql)
 
     for row in result:
@@ -41,8 +41,8 @@ async def _get_database_date(conn: AsyncConnection) -> Optional[dt.datetime]:
 
 
 async def _get_database_version(conn: AsyncConnection) -> Optional[version.NominatimVersion]:
-    sql = sqla.text("""SELECT value FROM nominatim_properties
-                       WHERE property = 'database_version'""")
+    sql = sa.text("""SELECT value FROM nominatim_properties
+                     WHERE property = 'database_version'""")
     result = await conn.execute(sql)
 
     for row in result:
@@ -51,14 +51,13 @@ async def _get_database_version(conn: AsyncConnection) -> Optional[version.Nomin
     return None
 
 
-async def get_status(engine: AsyncEngine) -> StatusResult:
+async def get_status(conn: AsyncConnection) -> StatusResult:
     """ Execute a status API call.
     """
     status = StatusResult(0, 'OK')
     try:
-        async with engine.begin() as conn:
-            status.data_updated = await _get_database_date(conn)
-            status.database_version = await _get_database_version(conn)
+        status.data_updated = await _get_database_date(conn)
+        status.database_version = await _get_database_version(conn)
     except asyncpg.PostgresError:
         return StatusResult(700, 'Database connection failed')