git.openstreetmap.org Git - nominatim.git/blobdiff - test/python/api/search/test_search_near.py
Merge remote-tracking branch 'upstream/master'
[nominatim.git] / test / python / api / search / test_search_near.py
index cfbdadb2a551f23d565096df80ebe70ff12bcd5e..2a0acb745969a777a75856f8cc002ea7e33da91f 100644 (file)
@@ -16,18 +16,21 @@ from nominatim.api.search.db_search_fields import WeightedStrings, WeightedCateg
                                                   FieldLookup, FieldRanking, RankedTokens
 
 
-def run_search(apiobj, global_penalty, cat, cat_penalty=None,
+def run_search(apiobj, global_penalty, cat, cat_penalty=None, ccodes=[],
                details=SearchDetails()):
 
     class PlaceSearchData:
         penalty = 0.0
         postcodes = WeightedStrings([], [])
-        countries = WeightedStrings([], [])
+        countries = WeightedStrings(ccodes, [0.0] * len(ccodes))
         housenumbers = WeightedStrings([], [])
         qualifiers = WeightedStrings([], [])
         lookups = [FieldLookup('name_vector', [56], 'lookup_all')]
         rankings = []
 
+    if ccodes is not None:
+        details.countries = ccodes
+
     place_search = PlaceSearch(0.0, PlaceSearchData(), 2)
 
     if cat_penalty is None:
@@ -49,6 +52,18 @@ def test_no_results_inner_query(apiobj):
     assert not run_search(apiobj, 0.4, [('this', 'that')])
 
 
+def test_no_appropriate_results_inner_query(apiobj):
+    """ The inner place search has a match, but the bank next to it must
+        still not be returned.
+
+        Place 100 carries the name token 56 that run_search looks up, and
+        the bank sits right next to place 100's centroid. However, the bank
+        lies outside place 100's polygon: the polygon only spans 0..2 on the
+        second coordinate, while the bank sits at 4.2994.
+        NOTE(review): presumably the near search uses the place geometry
+        (not its centroid) as the search area here; confirm against
+        NearSearch.
+    """
+    # Deliberately inconsistent data: the centroid (5.6, 4.3) lies outside
+    # the polygon it belongs to.
+    apiobj.add_placex(place_id=100, country_code='us',
+                      centroid=(5.6, 4.3),
+                      geometry='POLYGON((0.0 0.0, 10.0 0.0, 10.0 2.0, 0.0 2.0, 0.0 0.0))')
+    apiobj.add_search_name(100, names=[56], country_code='us',
+                           centroid=(5.6, 4.3))
+    # Close to the centroid above, but outside the polygon.
+    apiobj.add_placex(place_id=22, class_='amenity', type='bank',
+                      centroid=(5.6001, 4.2994))
+
+    assert not run_search(apiobj, 0.4, [('amenity', 'bank')])
+
+
 class TestNearSearch:
 
     @pytest.fixture(autouse=True)
@@ -100,3 +115,51 @@ class TestNearSearch:
 
         assert [r.place_id for r in results] == [22]
 
+
+    @pytest.mark.parametrize('cc,rid', [('us', 22), ('mx', 23)])
+    def test_restrict_by_country(self, apiobj, cc, rid):
+        """ A country restriction also applies to the near results.
+
+            Each of the two locations has one 'us' bank and one 'mx' bank.
+            With the search restricted to [cc, 'fr'], only bank `rid` may be
+            returned:
+              - 22 (the 'us' bank at the first location) for 'us';
+              - 23 (the 'mx' bank at the second location) for 'mx'.
+            run_search passes ccodes both to the inner place search and to
+            the SearchDetails used for the near search.
+            NOTE(review): this relies on the class fixture (not visible
+            here) providing a matching 'us' place near the first location
+            and an 'mx' place near the second one; confirm.
+        """
+        # First location: one bank per country.
+        apiobj.add_placex(place_id=22, class_='amenity', type='bank',
+                          centroid=(5.6001, 4.2994),
+                          country_code='us')
+        apiobj.add_placex(place_id=122, class_='amenity', type='bank',
+                          centroid=(5.6001, 4.2994),
+                          country_code='mx')
+        # Second location: one bank per country.
+        apiobj.add_placex(place_id=23, class_='amenity', type='bank',
+                          centroid=(-10.3001, 56.9),
+                          country_code='mx')
+        apiobj.add_placex(place_id=123, class_='amenity', type='bank',
+                          centroid=(-10.3001, 56.9),
+                          country_code='us')
+
+        # This test adds no 'fr' data. Presumably 'fr' is included to check
+        # that a multi-country list is handled.
+        results = run_search(apiobj, 0.1, [('amenity', 'bank')], ccodes=[cc, 'fr'])
+
+        assert [r.place_id for r in results] == [rid]
+
+
+    @pytest.mark.parametrize('excluded,rid', [(22, 122), (122, 22)])
+    def test_exclude_place_by_id(self, apiobj, excluded, rid):
+        """ Places listed in SearchDetails.excluded are dropped from the
+            near results.
+
+            Banks 22 and 122 differ only in their place_id, so excluding
+            either one must leave exactly the other.
+        """
+        apiobj.add_placex(place_id=22, class_='amenity', type='bank',
+                          centroid=(5.6001, 4.2994),
+                          country_code='us')
+        apiobj.add_placex(place_id=122, class_='amenity', type='bank',
+                          centroid=(5.6001, 4.2994),
+                          country_code='us')
+
+
+        results = run_search(apiobj, 0.1, [('amenity', 'bank')],
+                             details=SearchDetails(excluded=[excluded]))
+
+        assert [r.place_id for r in results] == [rid]
+
+
+    @pytest.mark.parametrize('layer,rids', [(napi.DataLayer.POI, [22]),
+                                            (napi.DataLayer.MANMADE, [])])
+    def test_with_layer(self, apiobj, layer, rids):
+        """ The layer filter in SearchDetails applies to the near results.
+
+            The amenity=bank place is expected to be returned when the POI
+            layer is requested, and to be filtered out when only the
+            MANMADE layer is requested.
+        """
+        apiobj.add_placex(place_id=22, class_='amenity', type='bank',
+                          centroid=(5.6001, 4.2994),
+                          country_code='us')
+
+        results = run_search(apiobj, 0.1, [('amenity', 'bank')],
+                             details=SearchDetails(layers=layer))
+
+        assert [r.place_id for r in results] == rids