1 # SPDX-License-Identifier: GPL-3.0-or-later
3 # This file is part of Nominatim. (https://nominatim.org)
5 # Copyright (C) 2025 by the Nominatim developer community.
6 # For a full list of authors see the git log.
"""
Server implementation using the starlette webserver framework.
"""
from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, \
                   Awaitable, AsyncIterator
from pathlib import Path
import datetime as dt
import asyncio
import contextlib

from starlette.applications import Starlette
from starlette.routing import Route
from starlette.exceptions import HTTPException
from starlette.responses import Response, PlainTextResponse, HTMLResponse
from starlette.requests import Request
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.middleware.cors import CORSMiddleware

from ...config import Configuration
from ...core import NominatimAPIAsync
from ...types import QueryStatistics
from ... import v1 as api_impl
from ...result_formatting import FormatDispatcher, load_format_dispatcher
from ..asgi_adaptor import ASGIAdaptor, EndpointFunc
from ... import logging as loglib
class ParamWrapper(ASGIAdaptor):
    """ Adaptor class for server glue to Starlette framework.

        Wraps a single Starlette `Request` and exposes the accessors
        required by the generic `ASGIAdaptor` endpoint code.
    """

    def __init__(self, request: Request) -> None:
        self.request = request

    def get(self, name: str, default: Optional[str] = None) -> Optional[str]:
        """ Return the query parameter `name` or `default` when it is absent.
        """
        return self.request.query_params.get(name, default=default)

    def get_header(self, name: str, default: Optional[str] = None) -> Optional[str]:
        """ Return the HTTP header `name` or `default` when it is absent.
        """
        return self.request.headers.get(name, default)

    def error(self, msg: str, status: int = 400) -> HTTPException:
        """ Build (but do not raise) an HTTP exception with the given message
            and status, using the content type of the current output format.
        """
        return HTTPException(status, detail=msg,
                             headers={'content-type': self.content_type})

    def create_response(self, status: int, output: str, num_results: int) -> Response:
        """ Build the final response for the request.

            The number of results is remembered in the request state, where
            the file-logging middleware picks it up.
        """
        self.request.state.num_results = num_results
        return Response(output, status_code=status, media_type=self.content_type)

    def base_uri(self) -> str:
        """ Return the base URI (scheme, host, optional port and root path)
            under which the application was reached.

            The port is left out when it is the default for the scheme.
            IPv6 literal hosts are enclosed in square brackets.
        """
        scheme = self.request.url.scheme
        host = self.request.url.hostname
        port = self.request.url.port
        # 'root_path' is optional in an ASGI scope and defaults to ''.
        root = self.request.scope.get('root_path', '')

        # URL.hostname strips the brackets from IPv6 literals. Put them back,
        # otherwise the ':' before the port becomes ambiguous.
        if host and ':' in host:
            host = f"[{host}]"

        # Default ports are not shown in the URI.
        if (scheme == 'http' and port == 80) or (scheme == 'https' and port == 443):
            port = None

        if port is not None:
            return f"{scheme}://{host}:{port}{root}"

        return f"{scheme}://{host}{root}"

    def config(self) -> Configuration:
        """ Return the configuration of the API object attached to the app.
        """
        return cast(Configuration, self.request.app.state.API.config)

    def formatting(self) -> FormatDispatcher:
        """ Return the result formatter attached to the app.
        """
        return cast(FormatDispatcher, self.request.app.state.formatter)

    def query_stats(self) -> Optional[QueryStatistics]:
        """ Return the query statistics collector for this request.

            Returns None when no middleware installed one.
        """
        return cast(Optional[QueryStatistics], getattr(self.request.state, 'query_stats', None))
def _wrap_endpoint(func: EndpointFunc)\
        -> Callable[[Request], Coroutine[Any, Any, Response]]:
    """ Turn a generic Nominatim endpoint function into a Starlette
        request handler.

        The returned coroutine function calls `func` with the API object
        stored in the application state and a ParamWrapper for the request.
    """
    async def _callback(request: Request) -> Response:
        return cast(Response, await func(request.app.state.API, ParamWrapper(request)))

    return _callback
class FileLoggingMiddleware(BaseHTTPMiddleware):
    """ Middleware to log selected requests into a file.

        A request is logged when all of the following hold:
        - it was answered with status 200;
        - query statistics were recorded for it;
        - its path starts with one of the 'reverse', 'search', 'lookup' or
          'details' endpoints.

        Each log line is produced by filling the `logstr` template with the
        collected statistics.
    """

    def __init__(self, app: Starlette, file_name: str = '', logstr: str = ''):
        super().__init__(app)
        # Line-buffered so that every entry reaches the file immediately.
        # NOTE(review): the handle is never explicitly closed; it lives as
        # long as the middleware instance.
        self.fd = open(file_name, 'a', buffering=1, encoding='utf8')
        self.logstr = logstr + '\n'

    async def dispatch(self, request: Request,
                       call_next: RequestResponseEndpoint) -> Response:
        qs = QueryStatistics()
        # Exposed to the endpoint via request state so that it can be
        # filled while the query runs.
        request.state.query_stats = qs
        response = await call_next(request)

        if response.status_code != 200 or 'start' not in qs:
            return response

        for endpoint in ('reverse', 'search', 'lookup', 'details'):
            if request.url.path.startswith('/' + endpoint):
                qs['endpoint'] = endpoint
                break
        else:
            # Not one of the endpoints that are logged.
            return response

        qs['query_string'] = request.scope['query_string'].decode('utf8')
        qs['results_total'] = getattr(request.state, 'num_results', 0)
        # Drop timezone information and format with millisecond precision.
        for param in ('start', 'end', 'start_query'):
            if isinstance(qs.get(param), dt.datetime):
                qs[param] = qs[param].replace(tzinfo=None)\
                                     .isoformat(sep=' ', timespec='milliseconds')

        self.fd.write(self.logstr.format_map(qs))

        return response
async def timeout_error(request: Request,
                        _: Exception) -> Response:
    """ Error handler for query timeouts.

        If the logging facility returns any collected output (presumably
        only in debug mode), that output is sent back as HTML. Otherwise a
        plain-text 503 error is returned.
    """
    loglib.log().comment('Aborted: Query took too long to process.')
    logdata = loglib.get_and_disable()

    if logdata:
        return HTMLResponse(logdata)

    return PlainTextResponse("Query took too long to process.", status_code=503)
def get_application(project_dir: Path,
                    environ: Optional[Mapping[str, str]] = None,
                    debug: bool = True) -> Starlette:
    """ Create a Nominatim starlette ASGI application.

        Arguments:
            project_dir: Nominatim project directory, passed on to the
                configuration, the API object and the format dispatcher.
            environ: Optional mapping of settings, passed on to the
                configuration and the API object (presumably overriding the
                process environment).
            debug: Enable Starlette's debug mode.

        The API object and the routes are set up in the application's
        lifespan handler. The API object is closed again at shutdown.
    """
    config = Configuration(project_dir, environ)

    middleware = []
    if config.get_bool('CORS_NOACCESSCONTROL'):
        # No access control: allow cross-origin GET requests from anywhere.
        middleware.append(Middleware(CORSMiddleware,
                                     allow_origins=['*'],
                                     allow_methods=['GET', 'OPTIONS'],
                                     max_age=86400))

    log_file = config.LOG_FILE
    if log_file:
        middleware.append(Middleware(FileLoggingMiddleware, file_name=log_file,  # type: ignore
                                     logstr=config.LOG_FORMAT))

    exceptions: Dict[Any, Callable[[Request, Exception], Awaitable[Response]]] = {
        TimeoutError: timeout_error,
        asyncio.TimeoutError: timeout_error
    }

    @contextlib.asynccontextmanager
    async def lifespan(app: Starlette) -> AsyncIterator[Any]:
        app.state.API = NominatimAPIAsync(project_dir, environ)
        # Close the API even when route setup fails or the server stops.
        try:
            api_config = app.state.API.config

            legacy_urls = api_config.get_bool('SERVE_LEGACY_URLS')
            for name, func in await api_impl.get_routes(app.state.API):
                endpoint = _wrap_endpoint(func)
                app.routes.append(Route(f"/{name}", endpoint=endpoint))
                if legacy_urls:
                    app.routes.append(Route(f"/{name}.php", endpoint=endpoint))

            yield
        finally:
            await app.state.API.close()

    app = Starlette(debug=debug, middleware=middleware,
                    exception_handlers=exceptions,
                    lifespan=lifespan)

    app.state.formatter = load_format_dispatcher('v1', project_dir)

    return app
def run_wsgi() -> Starlette:
    """ Entry point for uvicorn.

        Builds the application for the project in the current working
        directory, with Starlette's debug mode switched off. Despite the
        name, the returned object is an ASGI application.
    """
    cwd_project = Path()  # same as Path('.')
    return get_application(project_dir=cwd_project, debug=False)