]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/server/starlette/server.py
allow OPTIONS method in starlette CORS middleware
[nominatim.git] / nominatim / server / starlette / server.py
index 38eac8dce24c66089b1dcd0c8c3a14ce5864de04..f89e52a151dac89cf205ce25964f4483cbd5e272 100644 (file)
 #
 # This file is part of Nominatim. (https://nominatim.org)
 #
 #
 # This file is part of Nominatim. (https://nominatim.org)
 #
-# Copyright (C) 2022 by the Nominatim developer community.
+# Copyright (C) 2023 by the Nominatim developer community.
 # For a full list of authors see the git log.
 """
 Server implementation using the starlette webserver framework.
 """
 # For a full list of authors see the git log.
 """
 Server implementation using the starlette webserver framework.
 """
-from typing import Any, Type
+from typing import Any, Optional, Mapping, Callable, cast, Coroutine
 from pathlib import Path
 from pathlib import Path
+import datetime as dt
 
 from starlette.applications import Starlette
 from starlette.routing import Route
 from starlette.exceptions import HTTPException
 from starlette.responses import Response
 from starlette.requests import Request
 
 from starlette.applications import Starlette
 from starlette.routing import Route
 from starlette.exceptions import HTTPException
 from starlette.responses import Response
 from starlette.requests import Request
+from starlette.middleware import Middleware
+from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
+from starlette.middleware.cors import CORSMiddleware
 
 from nominatim.api import NominatimAPIAsync
 
 from nominatim.api import NominatimAPIAsync
-from nominatim.apicmd.status import StatusResult
-import nominatim.result_formatter.v1 as formatting
+import nominatim.api.v1 as api_impl
+from nominatim.config import Configuration
 
 
-CONTENT_TYPE = {
-  'text': 'text/plain; charset=utf-8',
-  'xml': 'text/xml; charset=utf-8'
-}
+class ParamWrapper(api_impl.ASGIAdaptor):
+    """ Adaptor class for server glue to Starlette framework.
+    """
 
 
-FORMATTERS = {
-    StatusResult: formatting.create(StatusResult)
-}
+    def __init__(self, request: Request) -> None:
+        self.request = request
 
 
 
 
-def parse_format(request: Request, rtype: Type[Any], default: str) -> None:
-    """ Get and check the 'format' parameter and prepare the formatter.
-        `rtype` describes the expected return type and `default` the
-        format value to assume when no parameter is present.
-    """
-    fmt = request.query_params.get('format', default=default)
-    fmtter = FORMATTERS[rtype]
+    def get(self, name: str, default: Optional[str] = None) -> Optional[str]:
+        return self.request.query_params.get(name, default=default)
 
 
-    if not fmtter.supports_format(fmt):
-        raise HTTPException(400, detail="Parameter 'format' must be one of: " +
-                                        ', '.join(fmtter.list_formats()))
 
 
-    request.state.format = fmt
-    request.state.formatter = fmtter
+    def get_header(self, name: str, default: Optional[str] = None) -> Optional[str]:
+        return self.request.headers.get(name, default)
 
 
 
 
-def format_response(request: Request, result: Any) -> Response:
-    """ Render response into a string according to the formatter
-        set in `parse_format()`.
-    """
-    fmt = request.state.format
-    return Response(request.state.formatter.format(result, fmt),
-                    media_type=CONTENT_TYPE.get(fmt, 'application/json'))
+    def error(self, msg: str, status: int = 400) -> HTTPException:
+        return HTTPException(status, detail=msg,
+                             headers={'content-type': self.content_type})
+
+
+    def create_response(self, status: int, output: str, num_results: int) -> Response:
+        self.request.state.num_results = num_results
+        return Response(output, status_code=status, media_type=self.content_type)
+
+
+    def config(self) -> Configuration:
+        return cast(Configuration, self.request.app.state.API.config)
 
 
 
 
-async def on_status(request: Request) -> Response:
-    """ Implementation of status endpoint.
+def _wrap_endpoint(func: api_impl.EndpointFunc)\
+        -> Callable[[Request], Coroutine[Any, Any, Response]]:
+    async def _callback(request: Request) -> Response:
+        return cast(Response, await func(request.app.state.API, ParamWrapper(request)))
+
+    return _callback
+
+
+class FileLoggingMiddleware(BaseHTTPMiddleware):
+    """ Middleware to log selected requests into a file.
     """
     """
-    parse_format(request, StatusResult, 'text')
-    result = await request.app.state.API.status()
-    return format_response(request, result)
 
 
+    def __init__(self, app: Starlette, file_name: str = ''):
+        super().__init__(app)
+        self.fd = open(file_name, 'a', buffering=1, encoding='utf8') # pylint: disable=R1732
+
+    async def dispatch(self, request: Request,
+                       call_next: RequestResponseEndpoint) -> Response:
+        start = dt.datetime.now(tz=dt.timezone.utc)
+        response = await call_next(request)
+
+        if response.status_code != 200:
+            return response
+
+        finish = dt.datetime.now(tz=dt.timezone.utc)
 
 
-V1_ROUTES = [
-    Route('/status', endpoint=on_status)
-]
+        for endpoint in ('reverse', 'search', 'lookup'):
+            if request.url.path.startswith('/' + endpoint):
+                qtype = endpoint
+                break
+        else:
+            return response
 
 
-def get_application(project_dir: Path) -> Starlette:
+        duration = (finish - start).total_seconds()
+        params = request.scope['query_string'].decode('utf8')
+
+        self.fd.write(f"[{start.replace(tzinfo=None).isoformat(sep=' ', timespec='milliseconds')}] "
+                      f"{duration:.4f} {getattr(request.state, 'num_results', 0)} "
+                      f'{qtype} "{params}"\n')
+
+        return response
+
+
+def get_application(project_dir: Path,
+                    environ: Optional[Mapping[str, str]] = None,
+                    debug: bool = True) -> Starlette:
     """ Create a Nominatim falcon ASGI application.
     """
     """ Create a Nominatim falcon ASGI application.
     """
-    app = Starlette(debug=True, routes=V1_ROUTES)
+    config = Configuration(project_dir, environ)
+
+    routes = []
+    legacy_urls = config.get_bool('SERVE_LEGACY_URLS')
+    for name, func in api_impl.ROUTES:
+        endpoint = _wrap_endpoint(func)
+        routes.append(Route(f"/{name}", endpoint=endpoint))
+        if legacy_urls:
+            routes.append(Route(f"/{name}.php", endpoint=endpoint))
+
+    middleware = []
+    if config.get_bool('CORS_NOACCESSCONTROL'):
+        middleware.append(Middleware(CORSMiddleware,
+                                     allow_origins=['*'],
+                                     allow_methods=['GET', 'OPTIONS'],
+                                     max_age=86400))
 
 
-    app.state.API = NominatimAPIAsync(project_dir)
+    log_file = config.LOG_FILE
+    if log_file:
+        middleware.append(Middleware(FileLoggingMiddleware, file_name=log_file))
+
+    async def _shutdown() -> None:
+        await app.state.API.close()
+
+    app = Starlette(debug=debug, routes=routes, middleware=middleware,
+                    on_shutdown=[_shutdown])
+
+    app.state.API = NominatimAPIAsync(project_dir, environ)
 
     return app
 
     return app
+
+
+def run_wsgi() -> Starlette:
+    """ Entry point for uvicorn.
+    """
+    return get_application(Path('.'), debug=False)