1 # SPDX-License-Identifier: GPL-2.0-only
 
   3 # This file is part of Nominatim. (https://nominatim.org)
 
   5 # Copyright (C) 2023 by the Nominatim developer community.
 
   6 # For a full list of authors see the git log.
 
   8 Server implementation using the starlette webserver framework.
 
  10 from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, Awaitable
 
  11 from pathlib import Path
 
  15 from starlette.applications import Starlette
 
  16 from starlette.routing import Route
 
  17 from starlette.exceptions import HTTPException
 
  18 from starlette.responses import Response, PlainTextResponse, HTMLResponse
 
  19 from starlette.requests import Request
 
  20 from starlette.middleware import Middleware
 
  21 from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
 
  22 from starlette.middleware.cors import CORSMiddleware
 
  24 from nominatim.api import NominatimAPIAsync
 
  25 import nominatim.api.v1 as api_impl
 
  26 import nominatim.api.logging as loglib
 
  27 from nominatim.config import Configuration
 
  29 class ParamWrapper(api_impl.ASGIAdaptor):
 
  30     """ Adaptor class for server glue to Starlette framework.
 
  33     def __init__(self, request: Request) -> None:
 
  34         self.request = request
 
  37     def get(self, name: str, default: Optional[str] = None) -> Optional[str]:
 
  38         return self.request.query_params.get(name, default=default)
 
  41     def get_header(self, name: str, default: Optional[str] = None) -> Optional[str]:
 
  42         return self.request.headers.get(name, default)
 
  45     def error(self, msg: str, status: int = 400) -> HTTPException:
 
  46         return HTTPException(status, detail=msg,
 
  47                              headers={'content-type': self.content_type})
 
  50     def create_response(self, status: int, output: str, num_results: int) -> Response:
 
  51         self.request.state.num_results = num_results
 
  52         return Response(output, status_code=status, media_type=self.content_type)
 
  55     def base_uri(self) -> str:
 
  56         scheme = self.request.url.scheme
 
  57         host = self.request.url.hostname
 
  58         port = self.request.url.port
 
  59         root = self.request.scope['root_path']
 
  60         if (scheme == 'http' and port == 80) or (scheme == 'https' and port == 443):
 
  63             return f"{scheme}://{host}:{port}{root}"
 
  65         return f"{scheme}://{host}{root}"
 
  68     def config(self) -> Configuration:
 
  69         return cast(Configuration, self.request.app.state.API.config)
 
  72 def _wrap_endpoint(func: api_impl.EndpointFunc)\
 
  73         -> Callable[[Request], Coroutine[Any, Any, Response]]:
 
  74     async def _callback(request: Request) -> Response:
 
  75         return cast(Response, await func(request.app.state.API, ParamWrapper(request)))
 
  80 class FileLoggingMiddleware(BaseHTTPMiddleware):
 
  81     """ Middleware to log selected requests into a file.
 
  84     def __init__(self, app: Starlette, file_name: str = ''):
 
  86         self.fd = open(file_name, 'a', buffering=1, encoding='utf8') # pylint: disable=R1732
 
  88     async def dispatch(self, request: Request,
 
  89                        call_next: RequestResponseEndpoint) -> Response:
 
  90         start = dt.datetime.now(tz=dt.timezone.utc)
 
  91         response = await call_next(request)
 
  93         if response.status_code != 200:
 
  96         finish = dt.datetime.now(tz=dt.timezone.utc)
 
  98         for endpoint in ('reverse', 'search', 'lookup', 'details'):
 
  99             if request.url.path.startswith('/' + endpoint):
 
 105         duration = (finish - start).total_seconds()
 
 106         params = request.scope['query_string'].decode('utf8')
 
 108         self.fd.write(f"[{start.replace(tzinfo=None).isoformat(sep=' ', timespec='milliseconds')}] "
 
 109                       f"{duration:.4f} {getattr(request.state, 'num_results', 0)} "
 
 110                       f'{qtype} "{params}"\n')
 
 115 async def timeout_error(request: Request, #pylint: disable=unused-argument
 
 116                         _: Exception) -> Response:
 
 117     """ Error handler for query timeouts.
 
 119     loglib.log().comment('Aborted: Query took too long to process.')
 
 120     logdata = loglib.get_and_disable()
 
 123         return HTMLResponse(logdata)
 
 125     return PlainTextResponse("Query took too long to process.", status_code=503)
 
 128 def get_application(project_dir: Path,
 
 129                     environ: Optional[Mapping[str, str]] = None,
 
 130                     debug: bool = True) -> Starlette:
 
 131     """ Create a Nominatim falcon ASGI application.
 
 133     config = Configuration(project_dir, environ)
 
 136     legacy_urls = config.get_bool('SERVE_LEGACY_URLS')
 
 137     for name, func in api_impl.ROUTES:
 
 138         endpoint = _wrap_endpoint(func)
 
 139         routes.append(Route(f"/{name}", endpoint=endpoint))
 
 141             routes.append(Route(f"/{name}.php", endpoint=endpoint))
 
 144     if config.get_bool('CORS_NOACCESSCONTROL'):
 
 145         middleware.append(Middleware(CORSMiddleware,
 
 147                                      allow_methods=['GET', 'OPTIONS'],
 
 150     log_file = config.LOG_FILE
 
 152         middleware.append(Middleware(FileLoggingMiddleware, file_name=log_file))
 
 154     exceptions: Dict[Any, Callable[[Request, Exception], Awaitable[Response]]] = {
 
 155         TimeoutError: timeout_error,
 
 156         asyncio.TimeoutError: timeout_error
 
 159     async def _shutdown() -> None:
 
 160         await app.state.API.close()
 
 162     app = Starlette(debug=debug, routes=routes, middleware=middleware,
 
 163                     exception_handlers=exceptions,
 
 164                     on_shutdown=[_shutdown])
 
 166     app.state.API = NominatimAPIAsync(project_dir, environ)
 
 171 def run_wsgi() -> Starlette:
 
 172     """ Entry point for uvicorn.
 
 174     return get_application(Path('.'), debug=False)