git.openstreetmap.org Git - nominatim.git/blob - test/python/tools/test_sp_importer.py
Filter special phrases by style and frequency to fix #235
[nominatim.git] / test / python / tools / test_sp_importer.py
1 import pytest
2 import tempfile
3 import json
4 import os
5 from unittest.mock import MagicMock
6
7 from nominatim_db.errors import UsageError
8 from nominatim_db.tools.special_phrases.sp_csv_loader import SPCsvLoader
9 from nominatim_db.tools.special_phrases.special_phrase import SpecialPhrase
10 from nominatim_db.tools.special_phrases.sp_importer import SPImporter
11
@pytest.fixture
def sample_style_file():
    """Yield the path of a temporary import-style file for the tests.

    The file contains the JSON encodings of the style entries below, joined
    by commas without enclosing brackets. The file is removed again once the
    requesting test has finished.
    """
    def style_entry(keys, values):
        # Key order ("keys" first, then "values") determines the serialized text.
        return {"keys": keys, "values": values}

    style_entries = [
        style_entry(
            ["emergency"],
            {"fire_hydrant": "skip", "yes": "skip", "no": "skip", "": "main"},
        ),
        style_entry(
            ["historic", "military"],
            {"no": "skip", "yes": "skip", "": "main"},
        ),
        style_entry(
            ["name:prefix", "name:suffix", "name:prefix:*", "name:suffix:*",
             "name:botanical", "wikidata", "*:wikidata"],
            {"": "extra"},
        ),
        style_entry(["addr:housename"], {"": "name,house"}),
        style_entry(["highway"], {"motorway": "main", "": "skip"}),
    ]
    serialized = ",".join(map(json.dumps, style_entries))

    # delete=False keeps the file on disk after the handle is closed so the
    # code under test can open it by name; cleanup happens after the yield.
    with tempfile.NamedTemporaryFile(mode='w+', delete=False) as handle:
        handle.write(serialized)
        style_path = handle.name

    yield style_path
    os.remove(style_path)
61
62
def test_get_sp_style(sample_style_file):
    """Check that get_sp_style() returns exactly {("highway", "motorway")}
    when the config points at the sample style file.
    """
    config = MagicMock()
    config.get_import_style_file.return_value = sample_style_file

    style_pairs = SPImporter(config=config, conn=None, sp_loader=None).get_sp_style()

    assert style_pairs == {("highway", "motorway")}
75
@pytest.fixture
def mock_phrase():
    """Provide a SpecialPhrase labelled "test" for highway/motorway with operator "eq"."""
    phrase_fields = {
        "p_label": "test",
        "p_class": "highway",
        "p_type": "motorway",
        "p_operator": "eq",
    }
    return SpecialPhrase(**phrase_fields)
84
def test_create_classtype_table_and_indexes():
    """Run _create_classtype_table_and_indexes() with stubbed helpers and
    print the calls recorded on each stub.

    NOTE(review): this test contains no assertions; it only prints what the
    stubs recorded, so it can fail only if the call under test raises.
    """
    config = MagicMock()
    config.TABLESPACE_AUX_DATA = ''
    config.DATABASE_WEBUSER = 'www-data'

    # The connection's cursor is used as a context manager; hand back a mock cursor.
    connection = MagicMock()
    connection.cursor.return_value.__enter__.return_value = MagicMock()

    importer = SPImporter(config=config, conn=connection, sp_loader=None)

    # Replace the table/index/grant helpers with recording stubs.
    for helper_name in ("_create_place_classtype_table",
                        "_create_place_classtype_indexes",
                        "_grant_access_to_webuser"):
        setattr(importer, helper_name, MagicMock())

    def report_created():
        print("✓ Created table")

    def report_ignored():
        print("⨉ Ignored table")

    importer.statistics_handler.notify_one_table_created = report_created
    importer.statistics_handler.notify_one_table_ignored = report_ignored

    importer.table_phrases_to_delete = {"place_classtype_highway_motorway"}

    importer._create_classtype_table_and_indexes(
        [("highway", "motorway"), ("natural", "peak")]
    )

    reports = (
        ("create_place_classtype_table calls:",
         importer._create_place_classtype_table),
        ("\ncreate_place_classtype_indexes calls:",
         importer._create_place_classtype_indexes),
        ("\ngrant_access_to_webuser calls:",
         importer._grant_access_to_webuser),
    )
    for heading, stub in reports:
        print(heading)
        for recorded_call in stub.call_args_list:
            print(recorded_call)
118
@pytest.fixture
def mock_config():
    """Provide a MagicMock configuration with an empty aux tablespace, the
    'www-data' web user and empty black/white lists as sub-configuration.
    """
    stub = MagicMock(TABLESPACE_AUX_DATA='', DATABASE_WEBUSER='www-data')
    stub.load_sub_configuration.return_value = {'blackList': {}, 'whiteList': {}}
    return stub
126
127
def test_import_phrases_original(mock_config):
    """Check that import_phrases() records the loader's phrase in word_phrases
    and passes that set to the analyzer exactly once with should_replace=True.
    """
    loader = MagicMock()
    loader.generate_phrases.return_value = [
        SpecialPhrase("roundabout", "highway", "motorway", "eq"),
    ]

    # Both the DB cursor and the name analyzer are obtained via context managers.
    connection = MagicMock()
    connection.cursor.return_value.__enter__.return_value = MagicMock()

    analyzer = MagicMock()
    tokenizer = MagicMock()
    tokenizer.name_analyzer.return_value.__enter__.return_value = analyzer

    importer = SPImporter(config=mock_config, conn=connection, sp_loader=loader)

    # Stub out the table-handling steps.
    for stubbed in ("_fetch_existing_place_classtype_tables",
                    "_create_classtype_table_and_indexes",
                    "_remove_non_existent_tables_from_db"):
        setattr(importer, stubbed, MagicMock())

    importer.import_phrases(tokenizer=tokenizer, should_replace=True)

    assert importer.word_phrases == {("roundabout", "highway", "motorway", "-")}

    analyzer.update_special_phrases.assert_called_once_with(
        importer.word_phrases, True
    )
153
154
def test_get_sp_filters_correctly(sample_style_file):
    """Check that get_sp() returns only ("highway", "motorway") when get_sp_db()
    is stubbed to report highway/motorway and historic/castle and the config
    points at the sample style file.
    """
    config = MagicMock()
    config.get_import_style_file.return_value = sample_style_file
    config.load_sub_configuration.return_value = {"blackList": {}, "whiteList": {}}

    importer = SPImporter(config=config, conn=MagicMock(), sp_loader=None)

    db_pairs = {("highway", "motorway"), ("historic", "castle")}

    def fake_get_sp_db():
        return db_pairs

    # Replace the database lookup with the fixed set above.
    importer.get_sp_db = fake_get_sp_db

    expected = {("highway", "motorway")}
    result = importer.get_sp()

    assert result == expected, f"Expected {expected}, got {result}"
170
def test_get_sp_db_filters_by_count_threshold(mock_config):
    """Check that get_sp_db() returns the cursor's rows as a set of pairs and
    executes exactly one statement.

    The mocked cursor hands back rows as if the database had already applied
    the count threshold (e.g. HAVING COUNT(*) > 100); the threshold filtering
    itself is not exercised here.
    """
    db_rows = [
        ("highway", "motorway"),
        ("historic", "castle"),
    ]
    expected_pairs = set(db_rows)

    cursor = MagicMock()
    cursor.fetchall.return_value = db_rows

    connection = MagicMock()
    connection.cursor.return_value.__enter__.return_value = cursor

    importer = SPImporter(config=mock_config, conn=connection, sp_loader=None)
    found_pairs = importer.get_sp_db()

    assert found_pairs == expected_pairs
    cursor.execute.assert_called_once()