]> git.openstreetmap.org Git - nominatim.git/blob - src/nominatim_api/server/starlette/server.py
Merge pull request #3833 from lonvia/rework-logging
[nominatim.git] / src / nominatim_api / server / starlette / server.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2025 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Server implementation using the starlette webserver framework.
9 """
10 from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, \
11                    Awaitable, AsyncIterator
12 from pathlib import Path
13 import asyncio
14 import contextlib
15 import datetime as dt
16
17 from starlette.applications import Starlette
18 from starlette.routing import Route
19 from starlette.exceptions import HTTPException
20 from starlette.responses import Response, PlainTextResponse, HTMLResponse
21 from starlette.requests import Request
22 from starlette.middleware import Middleware
23 from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
24 from starlette.middleware.cors import CORSMiddleware
25
26 from ...config import Configuration
27 from ...core import NominatimAPIAsync
28 from ...types import QueryStatistics
29 from ... import v1 as api_impl
30 from ...result_formatting import FormatDispatcher, load_format_dispatcher
31 from ..asgi_adaptor import ASGIAdaptor, EndpointFunc
32 from ... import logging as loglib
33
34
class ParamWrapper(ASGIAdaptor):
    """ Adaptor class for server glue to Starlette framework.

        Wraps a single Starlette request and gives the generic API endpoint
        functions access to its parameters, headers, configuration and
        result formatting, and lets them create errors and responses.
    """

    def __init__(self, request: Request) -> None:
        self.request = request

    def get(self, name: str, default: Optional[str] = None) -> Optional[str]:
        """ Return the URL query parameter `name`, or `default` when the
            parameter is not present.
        """
        return self.request.query_params.get(name, default=default)

    def get_header(self, name: str, default: Optional[str] = None) -> Optional[str]:
        """ Return the HTTP request header `name`, or `default` when the
            header is not present.
        """
        return self.request.headers.get(name, default)

    def error(self, msg: str, status: int = 400) -> HTTPException:
        """ Create (but do not raise) an HTTPException with the given
            message and status, tagged with the current content type.
        """
        return HTTPException(status, detail=msg,
                             headers={'content-type': self.content_type})

    def create_response(self, status: int, output: str, num_results: int) -> Response:
        """ Create the final response for the request.

            The number of results is stored in the request state so that
            the file logging middleware can include it in its log line.
        """
        self.request.state.num_results = num_results
        return Response(output, status_code=status, media_type=self.content_type)

    def base_uri(self) -> str:
        """ Return the base URL under which the server was reached:
            scheme, host, port (omitted when it is the default port for
            http/https) and the ASGI root path.
        """
        url = self.request.url
        scheme = url.scheme
        host = url.hostname
        port = url.port
        # 'root_path' is optional in the ASGI connection scope; the spec
        # defines its default as the empty string.
        root = self.request.scope.get('root_path', '')

        # URL.hostname strips the brackets from IPv6 literals. Put them
        # back, otherwise the host and port cannot be told apart.
        if host and ':' in host:
            host = f"[{host}]"

        if (scheme == 'http' and port == 80) or (scheme == 'https' and port == 443):
            port = None
        if port is not None:
            return f"{scheme}://{host}:{port}{root}"

        return f"{scheme}://{host}{root}"

    def config(self) -> Configuration:
        """ Return the configuration of the API object stored in the
            application state.
        """
        return cast(Configuration, self.request.app.state.API.config)

    def formatting(self) -> FormatDispatcher:
        """ Return the result format dispatcher stored in the application state.
        """
        return cast(FormatDispatcher, self.request.app.state.formatter)

    def query_stats(self) -> Optional[QueryStatistics]:
        """ Return the query statistics object attached to the request
            state, or None when no middleware attached one.
        """
        return cast(Optional[QueryStatistics], getattr(self.request.state, 'query_stats', None))
76
77
def _wrap_endpoint(func: EndpointFunc)\
        -> Callable[[Request], Coroutine[Any, Any, Response]]:
    """ Turn a generic API endpoint function into a Starlette request handler.

        The returned coroutine function calls `func` with the API object
        kept in the application state and with the request wrapped in a
        ParamWrapper. It returns whatever `func` produces as the response.
    """
    async def _handler(request: Request) -> Response:
        api = request.app.state.API
        params = ParamWrapper(request)
        result = await func(api, params)
        return cast(Response, result)

    return _handler
84
85
class FileLoggingMiddleware(BaseHTTPMiddleware):
    """ Middleware to log selected requests into a file.

        Only requests that meet all of the following are written:
        - they finished with HTTP status 200;
        - they recorded a 'start' entry in their query statistics;
        - they were addressed to the reverse, search, lookup or details
          endpoint.
        Each log line is produced by formatting `logstr` with the
        collected query statistics.
    """

    def __init__(self, app: Starlette, file_name: str = '', logstr: str = ''):
        super().__init__(app)
        # Line buffered, so every log entry is flushed once written.
        self.fd = open(file_name, 'a', buffering=1, encoding='utf8')
        self.logstr = logstr + '\n'

    async def dispatch(self, request: Request,
                       call_next: RequestResponseEndpoint) -> Response:
        """ Attach a fresh statistics object to the request, run the
            request and append a log line for qualifying requests.
        """
        qs = QueryStatistics()
        # Made available to the endpoint via ParamWrapper.query_stats().
        request.state.query_stats = qs
        response = await call_next(request)

        # Presumably 'start' is only set when a query was actually run;
        # requests without it are not logged.
        if response.status_code != 200 or 'start' not in qs:
            return response

        for endpoint in ('reverse', 'search', 'lookup', 'details'):
            if request.url.path.startswith('/' + endpoint):
                qs['endpoint'] = endpoint
                break
        else:
            return response

        # The raw query string is client-supplied and need not be valid
        # UTF-8 (Starlette itself parses it as latin-1). Escape bad bytes
        # instead of raising after a successful response was produced.
        qs['query_string'] = request.scope['query_string'].decode('utf8',
                                                                  errors='backslashreplace')
        qs['results_total'] = getattr(request.state, 'num_results', 0)
        # Timestamps are written as naive, millisecond-precision strings.
        for param in ('start', 'end', 'start_query'):
            if isinstance(qs.get(param), dt.datetime):
                qs[param] = qs[param].replace(tzinfo=None)\
                                     .isoformat(sep=' ', timespec='milliseconds')

        self.fd.write(self.logstr.format_map(qs))

        return response
121
122
async def timeout_error(request: Request,
                        _: Exception) -> Response:
    """ Exception handler for requests whose query ran into a timeout.

        A note about the abort is added to the request log. If the logging
        module returns any collected output, that output is sent back as an
        HTML page. Otherwise a plain-text 503 error is returned.
    """
    loglib.log().comment('Aborted: Query took too long to process.')
    collected_log = loglib.get_and_disable()

    if not collected_log:
        return PlainTextResponse("Query took too long to process.", status_code=503)

    return HTMLResponse(collected_log)
134
135
def get_application(project_dir: Path,
                    environ: Optional[Mapping[str, str]] = None,
                    debug: bool = True) -> Starlette:
    """ Create a Nominatim Starlette ASGI application.

        Arguments:
            project_dir: Nominatim project directory.
            environ: Optional environment mapping handed to the configuration
                     and to the API object.
            debug: Run Starlette in debug mode.

        The API object and the endpoint routes are only created in the
        lifespan handler, i.e. when the application starts up.
    """
    config = Configuration(project_dir, environ)

    middleware = []
    if config.get_bool('CORS_NOACCESSCONTROL'):
        middleware.append(Middleware(CORSMiddleware,
                                     allow_origins=['*'],
                                     allow_methods=['GET', 'OPTIONS'],
                                     max_age=86400))

    log_file = config.LOG_FILE
    if log_file:
        middleware.append(Middleware(FileLoggingMiddleware, file_name=log_file,  # type: ignore
                                     logstr=config.LOG_FORMAT))

    exceptions: Dict[Any, Callable[[Request, Exception], Awaitable[Response]]] = {
        TimeoutError: timeout_error,
        asyncio.TimeoutError: timeout_error
    }

    @contextlib.asynccontextmanager
    async def lifespan(app: Starlette) -> AsyncIterator[Any]:
        app.state.API = NominatimAPIAsync(project_dir, environ)
        try:
            api_config = app.state.API.config

            legacy_urls = api_config.get_bool('SERVE_LEGACY_URLS')
            for name, func in await api_impl.get_routes(app.state.API):
                endpoint = _wrap_endpoint(func)
                app.routes.append(Route(f"/{name}", endpoint=endpoint))
                if legacy_urls:
                    # Also serve the endpoint under its '.php' URL.
                    app.routes.append(Route(f"/{name}.php", endpoint=endpoint))

            yield
        finally:
            # Release the API even when route setup fails or the
            # application is shut down through an exception.
            await app.state.API.close()

    app = Starlette(debug=debug, middleware=middleware,
                    exception_handlers=exceptions,
                    lifespan=lifespan)

    app.state.formatter = load_format_dispatcher('v1', project_dir)

    return app
183
184
def run_wsgi() -> Starlette:
    """ Application factory used as the entry point for uvicorn.

        Builds the application for the project in the current working
        directory, with Starlette debug mode switched off.
    """
    project_dir = Path('.')
    return get_application(project_dir, debug=False)