1 # SPDX-License-Identifier: GPL-3.0-or-later
 
   3 # This file is part of Nominatim. (https://nominatim.org)
 
   5 # Copyright (C) 2024 by the Nominatim developer community.
 
   6 # For a full list of authors see the git log.
 
   8 Server implementation using the starlette webserver framework.
 
  10 from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, Awaitable
 
  11 from pathlib import Path
 
  16 from starlette.applications import Starlette
 
  17 from starlette.routing import Route
 
  18 from starlette.exceptions import HTTPException
 
  19 from starlette.responses import Response, PlainTextResponse, HTMLResponse
 
  20 from starlette.requests import Request
 
  21 from starlette.middleware import Middleware
 
  22 from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
 
  23 from starlette.middleware.cors import CORSMiddleware
 
  25 from ...config import Configuration
 
  26 from ...core import NominatimAPIAsync
 
  27 from ... import v1 as api_impl
 
  28 from ...result_formatting import FormatDispatcher, load_format_dispatcher
 
  29 from ..asgi_adaptor import ASGIAdaptor, EndpointFunc
 
  30 from ... import logging as loglib
 
  33 class ParamWrapper(ASGIAdaptor):
 
  34     """ Adaptor class for server glue to Starlette framework.
 
  37     def __init__(self, request: Request) -> None:
 
  38         self.request = request
 
  40     def get(self, name: str, default: Optional[str] = None) -> Optional[str]:
 
  41         return self.request.query_params.get(name, default=default)
 
  43     def get_header(self, name: str, default: Optional[str] = None) -> Optional[str]:
 
  44         return self.request.headers.get(name, default)
 
  46     def error(self, msg: str, status: int = 400) -> HTTPException:
 
  47         return HTTPException(status, detail=msg,
 
  48                              headers={'content-type': self.content_type})
 
  50     def create_response(self, status: int, output: str, num_results: int) -> Response:
 
  51         self.request.state.num_results = num_results
 
  52         return Response(output, status_code=status, media_type=self.content_type)
 
  54     def base_uri(self) -> str:
 
  55         scheme = self.request.url.scheme
 
  56         host = self.request.url.hostname
 
  57         port = self.request.url.port
 
  58         root = self.request.scope['root_path']
 
  59         if (scheme == 'http' and port == 80) or (scheme == 'https' and port == 443):
 
  62             return f"{scheme}://{host}:{port}{root}"
 
  64         return f"{scheme}://{host}{root}"
 
  66     def config(self) -> Configuration:
 
  67         return cast(Configuration, self.request.app.state.API.config)
 
  69     def formatting(self) -> FormatDispatcher:
 
  70         return cast(FormatDispatcher, self.request.app.state.formatter)
 
  73 def _wrap_endpoint(func: EndpointFunc)\
 
  74         -> Callable[[Request], Coroutine[Any, Any, Response]]:
 
  75     async def _callback(request: Request) -> Response:
 
  76         return cast(Response, await func(request.app.state.API, ParamWrapper(request)))
 
  81 class FileLoggingMiddleware(BaseHTTPMiddleware):
 
  82     """ Middleware to log selected requests into a file.
 
  85     def __init__(self, app: Starlette, file_name: str = ''):
 
  87         self.fd = open(file_name, 'a', buffering=1, encoding='utf8')
 
  89     async def dispatch(self, request: Request,
 
  90                        call_next: RequestResponseEndpoint) -> Response:
 
  91         start = dt.datetime.now(tz=dt.timezone.utc)
 
  92         response = await call_next(request)
 
  94         if response.status_code != 200:
 
  97         finish = dt.datetime.now(tz=dt.timezone.utc)
 
  99         for endpoint in ('reverse', 'search', 'lookup', 'details'):
 
 100             if request.url.path.startswith('/' + endpoint):
 
 106         duration = (finish - start).total_seconds()
 
 107         params = request.scope['query_string'].decode('utf8')
 
 109         self.fd.write(f"[{start.replace(tzinfo=None).isoformat(sep=' ', timespec='milliseconds')}] "
 
 110                       f"{duration:.4f} {getattr(request.state, 'num_results', 0)} "
 
 111                       f'{qtype} "{params}"\n')
 
 116 async def timeout_error(request: Request,
 
 117                         _: Exception) -> Response:
 
 118     """ Error handler for query timeouts.
 
 120     loglib.log().comment('Aborted: Query took too long to process.')
 
 121     logdata = loglib.get_and_disable()
 
 124         return HTMLResponse(logdata)
 
 126     return PlainTextResponse("Query took too long to process.", status_code=503)
 
 129 def get_application(project_dir: Path,
 
 130                     environ: Optional[Mapping[str, str]] = None,
 
 131                     debug: bool = True) -> Starlette:
 
 132     """ Create a Nominatim falcon ASGI application.
 
 134     config = Configuration(project_dir, environ)
 
 137     if config.get_bool('CORS_NOACCESSCONTROL'):
 
 138         middleware.append(Middleware(CORSMiddleware,
 
 140                                      allow_methods=['GET', 'OPTIONS'],
 
 143     log_file = config.LOG_FILE
 
 145         middleware.append(Middleware(FileLoggingMiddleware, file_name=log_file))
 
 147     exceptions: Dict[Any, Callable[[Request, Exception], Awaitable[Response]]] = {
 
 148         TimeoutError: timeout_error,
 
 149         asyncio.TimeoutError: timeout_error
 
 152     @contextlib.asynccontextmanager
 
 153     async def lifespan(app: Starlette) -> None:
 
 154         app.state.API = NominatimAPIAsync(project_dir, environ)
 
 155         config = app.state.API.config
 
 157         legacy_urls = config.get_bool('SERVE_LEGACY_URLS')
 
 158         for name, func in api_impl.ROUTES:
 
 159             endpoint = _wrap_endpoint(func)
 
 160             app.routes.append(Route(f"/{name}", endpoint=endpoint))
 
 162                 app.routes.append(Route(f"/{name}.php", endpoint=endpoint))
 
 166         await app.state.API.close()
 
 168     app = Starlette(debug=debug, middleware=middleware,
 
 169                     exception_handlers=exceptions,
 
 172     app.state.formatter = load_format_dispatcher('v1', project_dir)
 
 177 def run_wsgi() -> Starlette:
 
 178     """ Entry point for uvicorn.
 
 180     return get_application(Path('.'), debug=False)