From ffd5c32f1773484d41926ce1822228417d946605 Mon Sep 17 00:00:00 2001 From: Sarah Hoffmann Date: Thu, 4 Dec 2025 18:28:04 +0100 Subject: [PATCH] fix comparision between countr tokens and country restriction --- src/nominatim_api/search/db_search_builder.py | 2 +- .../api/search/test_db_search_builder.py | 46 +++++++++++++------ 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/src/nominatim_api/search/db_search_builder.py b/src/nominatim_api/search/db_search_builder.py index 591d32ca..ef3e3195 100644 --- a/src/nominatim_api/search/db_search_builder.py +++ b/src/nominatim_api/search/db_search_builder.py @@ -413,7 +413,7 @@ class SearchBuilder: """ tokens = self.query.get_tokens(trange, qmod.TOKEN_COUNTRY) if self.details.countries: - tokens = [t for t in tokens if t.lookup_word in self.details.countries] + tokens = [t for t in tokens if t.get_country() in self.details.countries] return tokens diff --git a/test/python/api/search/test_db_search_builder.py b/test/python/api/search/test_db_search_builder.py index 18beb6f2..0380d041 100644 --- a/test/python/api/search/test_db_search_builder.py +++ b/test/python/api/search/test_db_search_builder.py @@ -2,12 +2,14 @@ # # This file is part of Nominatim. (https://nominatim.org) # -# Copyright (C) 2023 by the Nominatim developer community. +# Copyright (C) 2025 by the Nominatim developer community. # For a full list of authors see the git log. """ Tests for creating abstract searches from token assignments. """ +from typing import Optional import pytest +import dataclasses from nominatim_api.search.query import Token, TokenRange, QueryStruct, Phrase import nominatim_api.search.query as qmod @@ -17,12 +19,15 @@ from nominatim_api.types import SearchDetails import nominatim_api.search.db_searches as dbs +@dataclasses.dataclass class MyToken(Token): + cc: Optional[str] = None + def get_category(self): return 'this', 'that' def get_country(self): - return self.lookup_word + return self.cc def make_query(*args): @@ -33,18 +38,24 @@ def make_query(*args): q.add_node(qmod.BREAK_END, qmod.PHRASE_ANY) for start, tlist in enumerate(args): - for end, ttype, tinfo in tlist: - for tid, word in tinfo: - q.add_token(TokenRange(start, end), ttype, - MyToken(penalty=0.5 if ttype == qmod.TOKEN_PARTIAL else 0.0, - token=tid, count=1, addr_count=1, - lookup_word=word)) + for end, ttype, tinfos in tlist: + for tinfo in tinfos: + if isinstance(tinfo, tuple): + q.add_token(TokenRange(start, end), ttype, + MyToken(penalty=0.5 if ttype == qmod.TOKEN_PARTIAL else 0.0, + token=tinfo[0], count=1, addr_count=1, + lookup_word=tinfo[1])) + else: + q.add_token(TokenRange(start, end), ttype, tinfo) return q def test_country_search(): - q = make_query([(1, qmod.TOKEN_COUNTRY, [(2, 'de'), (3, 'en')])]) + q = make_query([(1, qmod.TOKEN_COUNTRY, [ + MyToken(penalty=0.0, token=2, count=1, addr_count=1, lookup_word='Germany', cc='de'), + MyToken(penalty=0.0, token=3, count=1, addr_count=1, lookup_word='UK', cc='en'), + ])]) builder = SearchBuilder(q, SearchDetails()) searches = list(builder.build(TokenAssignment(country=TokenRange(0, 1)))) @@ -58,7 +69,10 @@ def test_country_search(): def test_country_search_with_country_restriction(): - q = make_query([(1, qmod.TOKEN_COUNTRY, [(2, 'de'), (3, 'en')])]) + q = make_query([(1, qmod.TOKEN_COUNTRY, [ + MyToken(penalty=0.0, token=2, count=1, addr_count=1, lookup_word='Germany', cc='de'), + MyToken(penalty=0.0, token=3, count=1, addr_count=1, lookup_word='UK', cc='en'), + ])]) builder = SearchBuilder(q, SearchDetails.from_kwargs({'countries': 'en,fr'})) searches = list(builder.build(TokenAssignment(country=TokenRange(0, 1)))) @@ -72,7 +86,10 @@ def test_country_search_with_country_restriction(): def test_country_search_with_conflicting_country_restriction(): - q = make_query([(1, qmod.TOKEN_COUNTRY, [(2, 'de'), (3, 'en')])]) + q = make_query([(1, qmod.TOKEN_COUNTRY, [ + MyToken(penalty=0.0, token=2, count=1, addr_count=1, lookup_word='Germany', cc='de'), + MyToken(penalty=0.0, token=3, count=1, addr_count=1, lookup_word='UK', cc='en'), + ])]) builder = SearchBuilder(q, SearchDetails.from_kwargs({'countries': 'fr'})) searches = list(builder.build(TokenAssignment(country=TokenRange(0, 1)))) @@ -97,8 +114,11 @@ def test_postcode_search_simple(): def test_postcode_with_country(): - q = make_query([(1, qmod.TOKEN_POSTCODE, [(34, '2367')])], - [(2, qmod.TOKEN_COUNTRY, [(1, 'xx')])]) + q = make_query( + [(1, qmod.TOKEN_POSTCODE, [(34, '2367')])], + [(2, qmod.TOKEN_COUNTRY, [ + MyToken(penalty=0.0, token=1, count=1, addr_count=1, lookup_word='none', cc='xx'), + ])]) builder = SearchBuilder(q, SearchDetails()) searches = list(builder.build(TokenAssignment(postcode=TokenRange(0, 1), -- 2.39.5