git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/server/starlette/server.py
allow OPTIONS method in starlette CORS middleware
[nominatim.git] / nominatim / server / starlette / server.py
index 2bf569edd2b9c21ba778b8aa3de3a62d2fefe85e..f89e52a151dac89cf205ce25964f4483cbd5e272 100644 (file)
@@ -9,15 +9,20 @@ Server implementation using the starlette webserver framework.
 """
 from typing import Any, Optional, Mapping, Callable, cast, Coroutine
 from pathlib import Path
 """
 from typing import Any, Optional, Mapping, Callable, cast, Coroutine
 from pathlib import Path
+import datetime as dt
 
 from starlette.applications import Starlette
 from starlette.routing import Route
 from starlette.exceptions import HTTPException
 from starlette.responses import Response
 from starlette.requests import Request
 
 from starlette.applications import Starlette
 from starlette.routing import Route
 from starlette.exceptions import HTTPException
 from starlette.responses import Response
 from starlette.requests import Request
+from starlette.middleware import Middleware
+from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
+from starlette.middleware.cors import CORSMiddleware
 
 from nominatim.api import NominatimAPIAsync
 import nominatim.api.v1 as api_impl
 
 from nominatim.api import NominatimAPIAsync
 import nominatim.api.v1 as api_impl
+from nominatim.config import Configuration
 
 class ParamWrapper(api_impl.ASGIAdaptor):
     """ Adaptor class for server glue to Starlette framework.
 
 class ParamWrapper(api_impl.ASGIAdaptor):
     """ Adaptor class for server glue to Starlette framework.
@@ -35,12 +40,18 @@ class ParamWrapper(api_impl.ASGIAdaptor):
         return self.request.headers.get(name, default)
 
 
         return self.request.headers.get(name, default)
 
 
-    def error(self, msg: str) -> HTTPException:
-        return HTTPException(400, detail=msg)
+    def error(self, msg: str, status: int = 400) -> HTTPException:
+        """ Build (but do not raise) an HTTPException describing an error.
+
+            The exception carries `msg` as detail, the HTTP `status`
+            (400 unless given) and a content-type header taken from
+            `self.content_type`.
+        """
+        return HTTPException(status, detail=msg,
+                             headers={'content-type': self.content_type})
 
 
 
 
-    def create_response(self, status: int, output: str, content_type: str) -> Response:
-        return Response(output, status_code=status, media_type=content_type)
+    def create_response(self, status: int, output: str, num_results: int) -> Response:
+        """ Wrap the string `output` into a Starlette Response with the
+            given status code, using `self.content_type` as media type.
+
+            `num_results` does not go into the response. It is stored on
+            `request.state` so that FileLoggingMiddleware can include the
+            result count in its log line.
+        """
+        self.request.state.num_results = num_results
+        return Response(output, status_code=status, media_type=self.content_type)
+
+
+    def config(self) -> Configuration:
+        """ Give access to the configuration of the NominatimAPIAsync
+            object that is stored on the application state.
+        """
+        api = self.request.app.state.API
+        return cast(Configuration, api.config)
 
 
 def _wrap_endpoint(func: api_impl.EndpointFunc)\
 
 
 def _wrap_endpoint(func: api_impl.EndpointFunc)\
@@ -51,16 +62,79 @@ def _wrap_endpoint(func: api_impl.EndpointFunc)\
     return _callback
 
 
     return _callback
 
 
+class FileLoggingMiddleware(BaseHTTPMiddleware):
+    """ Middleware that appends one line per successful 'reverse',
+        'search' or 'lookup' request to a log file.
+
+        Each entry has the form
+        `[<start time>] <duration in seconds> <result count> <endpoint> "<query string>"`.
+    """
+
+    def __init__(self, app: Starlette, file_name: str = ''):
+        super().__init__(app)
+        # Line buffering makes every entry reach the file as soon as it is written.
+        # NOTE(review): nothing in this class ever closes the handle; it
+        # stays open for the lifetime of the middleware object.
+        self.fd = open(file_name, 'a', buffering=1, encoding='utf8') # pylint: disable=R1732
+
+    async def dispatch(self, request: Request,
+                       call_next: RequestResponseEndpoint) -> Response:
+        started = dt.datetime.now(tz=dt.timezone.utc)
+        response = await call_next(request)
+
+        # Only requests answered with status 200 are logged.
+        if response.status_code == 200:
+            finished = dt.datetime.now(tz=dt.timezone.utc)
+            path = request.url.path
+            # The prefix match also catches the legacy '<endpoint>.php' routes.
+            query_type = next((kind for kind in ('reverse', 'search', 'lookup')
+                               if path.startswith('/' + kind)), None)
+            if query_type is not None:
+                elapsed = (finished - started).total_seconds()
+                query_string = request.scope['query_string'].decode('utf8')
+                # UTC start time, written without the offset suffix.
+                stamp = started.replace(tzinfo=None).isoformat(
+                    sep=' ', timespec='milliseconds')
+                # Stored by ParamWrapper.create_response; 0 if it was never set.
+                result_count = getattr(request.state, 'num_results', 0)
+                self.fd.write(f'[{stamp}] {elapsed:.4f} {result_count} '
+                              f'{query_type} "{query_string}"\n')
+
+        return response
+
+
 def get_application(project_dir: Path,
 def get_application(project_dir: Path,
-                    environ: Optional[Mapping[str, str]] = None) -> Starlette:
+                    environ: Optional[Mapping[str, str]] = None,
+                    debug: bool = True) -> Starlette:
     """ Create a Nominatim falcon ASGI application.
     """
     """ Create a Nominatim falcon ASGI application.
     """
+    # NOTE(review): the docstring above says "falcon", but this function
+    # builds a Starlette application. `debug` is handed to Starlette
+    # unchanged; run_wsgi() calls this function with debug=False.
+    # The configuration is read here to decide on routes and middleware.
+    config = Configuration(project_dir, environ)
+
     routes = []
     routes = []
+    # With SERVE_LEGACY_URLS enabled, every endpoint is additionally
+    # reachable under its old '/<name>.php' URL, sharing the same handler.
+    legacy_urls = config.get_bool('SERVE_LEGACY_URLS')
     for name, func in api_impl.ROUTES:
     for name, func in api_impl.ROUTES:
-        routes.append(Route(f"/{name}", endpoint=_wrap_endpoint(func)))
+        endpoint = _wrap_endpoint(func)
+        routes.append(Route(f"/{name}", endpoint=endpoint))
+        if legacy_urls:
+            routes.append(Route(f"/{name}.php", endpoint=endpoint))
+
+    middleware = []
+    # CORS_NOACCESSCONTROL: accept GET and OPTIONS (preflight) requests from
+    # any origin; preflight answers may be cached for 86400 s (one day).
+    if config.get_bool('CORS_NOACCESSCONTROL'):
+        middleware.append(Middleware(CORSMiddleware,
+                                     allow_origins=['*'],
+                                     allow_methods=['GET', 'OPTIONS'],
+                                     max_age=86400))


-    app = Starlette(debug=True, routes=routes)
+    # Request logging is only switched on when LOG_FILE is set.
+    log_file = config.LOG_FILE
+    if log_file:
+        middleware.append(Middleware(FileLoggingMiddleware, file_name=log_file))
+
+    # Close the NominatimAPIAsync instance when the application shuts down.
+    # `app` is resolved only when the hook runs, i.e. after it is assigned below.
+    async def _shutdown() -> None:
+        await app.state.API.close()
+
+    app = Starlette(debug=debug, routes=routes, middleware=middleware,
+                    on_shutdown=[_shutdown])

     app.state.API = NominatimAPIAsync(project_dir, environ)

     return app

     app.state.API = NominatimAPIAsync(project_dir, environ)

     return app
+
+
+def run_wsgi() -> Starlette:
+    """ Entry point for uvicorn: build the application for the project
+        in the current working directory, with debug mode switched off.
+
+        NOTE(review): despite its name, the returned Starlette app is an
+        ASGI application, not a WSGI one.
+    """
+    current_dir = Path('.')
+    return get_application(current_dir, debug=False)