1 # SPDX-License-Identifier: GPL-3.0-or-later
 
   3 # This file is part of Nominatim. (https://nominatim.org)
 
   5 # Copyright (C) 2025 by the Nominatim developer community.
 
   6 # For a full list of authors see the git log.
 
   8 Server implementation using the starlette webserver framework.
 
  10 from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, \
 
  11                    Awaitable, AsyncIterator
 
  12 from pathlib import Path
 
  17 from starlette.applications import Starlette
 
  18 from starlette.routing import Route
 
  19 from starlette.exceptions import HTTPException
 
  20 from starlette.responses import Response, PlainTextResponse, HTMLResponse
 
  21 from starlette.requests import Request
 
  22 from starlette.middleware import Middleware
 
  23 from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
 
  24 from starlette.middleware.cors import CORSMiddleware
 
  26 from ...config import Configuration
 
  27 from ...core import NominatimAPIAsync
 
  28 from ...types import QueryStatistics
 
  29 from ... import v1 as api_impl
 
  30 from ...result_formatting import FormatDispatcher, load_format_dispatcher
 
  31 from ..asgi_adaptor import ASGIAdaptor, EndpointFunc
 
  32 from ... import logging as loglib
 
  35 class ParamWrapper(ASGIAdaptor):
 
  36     """ Adaptor class for server glue to Starlette framework.
 
  39     def __init__(self, request: Request) -> None:
 
  40         self.request = request
 
  42     def get(self, name: str, default: Optional[str] = None) -> Optional[str]:
 
  43         return self.request.query_params.get(name, default=default)
 
  45     def get_header(self, name: str, default: Optional[str] = None) -> Optional[str]:
 
  46         return self.request.headers.get(name, default)
 
  48     def error(self, msg: str, status: int = 400) -> HTTPException:
 
  49         return HTTPException(status, detail=msg,
 
  50                              headers={'content-type': self.content_type})
 
  52     def create_response(self, status: int, output: str, num_results: int) -> Response:
 
  53         self.request.state.num_results = num_results
 
  54         return Response(output, status_code=status, media_type=self.content_type)
 
  56     def base_uri(self) -> str:
 
  57         scheme = self.request.url.scheme
 
  58         host = self.request.url.hostname
 
  59         port = self.request.url.port
 
  60         root = self.request.scope['root_path']
 
  61         if (scheme == 'http' and port == 80) or (scheme == 'https' and port == 443):
 
  64             return f"{scheme}://{host}:{port}{root}"
 
  66         return f"{scheme}://{host}{root}"
 
  68     def config(self) -> Configuration:
 
  69         return cast(Configuration, self.request.app.state.API.config)
 
  71     def formatting(self) -> FormatDispatcher:
 
  72         return cast(FormatDispatcher, self.request.app.state.formatter)
 
  74     def query_stats(self) -> Optional[QueryStatistics]:
 
  75         return cast(Optional[QueryStatistics], getattr(self.request.state, 'query_stats', None))
 
  78 def _wrap_endpoint(func: EndpointFunc)\
 
  79         -> Callable[[Request], Coroutine[Any, Any, Response]]:
 
  80     async def _callback(request: Request) -> Response:
 
  81         return cast(Response, await func(request.app.state.API, ParamWrapper(request)))
 
  86 class FileLoggingMiddleware(BaseHTTPMiddleware):
 
  87     """ Middleware to log selected requests into a file.
 
  90     def __init__(self, app: Starlette, file_name: str = '', logstr: str = ''):
 
  92         self.fd = open(file_name, 'a', buffering=1, encoding='utf8')
 
  93         self.logstr = logstr + '\n'
 
  95     async def dispatch(self, request: Request,
 
  96                        call_next: RequestResponseEndpoint) -> Response:
 
  97         qs = QueryStatistics()
 
  98         request.state.query_stats = qs
 
  99         response = await call_next(request)
 
 101         if response.status_code != 200 or 'start' not in qs:
 
 104         for endpoint in ('reverse', 'search', 'lookup', 'details'):
 
 105             if request.url.path.startswith('/' + endpoint):
 
 106                 qs['endpoint'] = endpoint
 
 111         qs['query_string'] = request.scope['query_string'].decode('utf8')
 
 112         qs['results_total'] = getattr(request.state, 'num_results', 0)
 
 113         for param in ('start', 'end', 'start_query'):
 
 114             if isinstance(qs.get(param), dt.datetime):
 
 115                 qs[param] = qs[param].replace(tzinfo=None)\
 
 116                                      .isoformat(sep=' ', timespec='milliseconds')
 
 118         self.fd.write(self.logstr.format_map(qs))
 
 123 async def timeout_error(request: Request,
 
 124                         _: Exception) -> Response:
 
 125     """ Error handler for query timeouts.
 
 127     loglib.log().comment('Aborted: Query took too long to process.')
 
 128     logdata = loglib.get_and_disable()
 
 131         return HTMLResponse(logdata)
 
 133     return PlainTextResponse("Query took too long to process.", status_code=503)
 
 136 def get_application(project_dir: Path,
 
 137                     environ: Optional[Mapping[str, str]] = None,
 
 138                     debug: bool = True) -> Starlette:
 
 139     """ Create a Nominatim falcon ASGI application.
 
 141     config = Configuration(project_dir, environ)
 
 144     if config.get_bool('CORS_NOACCESSCONTROL'):
 
 145         middleware.append(Middleware(CORSMiddleware,
 
 147                                      allow_methods=['GET', 'OPTIONS'],
 
 150     log_file = config.LOG_FILE
 
 152         middleware.append(Middleware(FileLoggingMiddleware, file_name=log_file,  # type: ignore
 
 153                                      logstr=config.LOG_FORMAT))
 
 155     exceptions: Dict[Any, Callable[[Request, Exception], Awaitable[Response]]] = {
 
 156         TimeoutError: timeout_error,
 
 157         asyncio.TimeoutError: timeout_error
 
 160     @contextlib.asynccontextmanager
 
 161     async def lifespan(app: Starlette) -> AsyncIterator[Any]:
 
 162         app.state.API = NominatimAPIAsync(project_dir, environ)
 
 163         config = app.state.API.config
 
 165         legacy_urls = config.get_bool('SERVE_LEGACY_URLS')
 
 166         for name, func in await api_impl.get_routes(app.state.API):
 
 167             endpoint = _wrap_endpoint(func)
 
 168             app.routes.append(Route(f"/{name}", endpoint=endpoint))
 
 170                 app.routes.append(Route(f"/{name}.php", endpoint=endpoint))
 
 174         await app.state.API.close()
 
 176     app = Starlette(debug=debug, middleware=middleware,
 
 177                     exception_handlers=exceptions,
 
 180     app.state.formatter = load_format_dispatcher('v1', project_dir)
 
 185 def run_wsgi() -> Starlette:
 
 186     """ Entry point for uvicorn.
 
 188     return get_application(Path('.'), debug=False)