From 1a323165f93d26d8d519d7b9c6a160663a031e7f Mon Sep 17 00:00:00 2001 From: anqixxx Date: Mon, 7 Apr 2025 21:40:42 -0700 Subject: [PATCH] Filter special phrases by style and frequency to fix #235 --- .../tools/special_phrases/sp_importer.py | 64 +++++- test/python/tools/test_sp_importer.py | 193 ++++++++++++++++++ 2 files changed, 253 insertions(+), 4 deletions(-) create mode 100644 test/python/tools/test_sp_importer.py diff --git a/src/nominatim_db/tools/special_phrases/sp_importer.py b/src/nominatim_db/tools/special_phrases/sp_importer.py index 40b089a7..d9bf3165 100644 --- a/src/nominatim_db/tools/special_phrases/sp_importer.py +++ b/src/nominatim_db/tools/special_phrases/sp_importer.py @@ -16,7 +16,7 @@ from typing import Iterable, Tuple, Mapping, Sequence, Optional, Set import logging import re - +import json from psycopg.sql import Identifier, SQL from ...typing import Protocol @@ -65,6 +65,52 @@ class SPImporter(): # special phrases class/type on the wiki. self.table_phrases_to_delete: Set[str] = set() + def get_classtype_pairs_style(self) -> Set[Tuple[str, str]]: + """ + Returns list of allowed special phrases from the style file, + restricting to a list of combinations of classes and types + which have a 'main' property + + Note: This requirement was from 2021 and I am a bit unsure if it is still relevant + """ + style_file = self.config.get_import_style_file() # this gives the path, so i will import it as a json + with open(style_file, 'r') as file: + style_data = json.loads(f'[{file.read()}]') + + style_combinations = set() + for _map in style_data: # following ../settings/import-extratags.style + classes = _map.get("keys", []) + values = _map.get("values", {}) + + for _type, properties in values.items(): + if "main" in properties and _type: # make sure the tag is not an empty string. 
since type is the value of the main tag + for _class in classes: + style_combinations.add((_class, _type)) + + return style_combinations + + def get_classtype_pairs(self) -> Set[Tuple[str, str]]: + """ + Returns list of allowed special phrases from the database, + restricting to a list of combinations of classes and types + which occur more than 100 times + """ + db_combinations = set() + query = """ + SELECT class AS CLS, type AS typ + FROM placex + GROUP BY class, type + HAVING COUNT(*) > 100 + """ + + with self.db_connection.cursor() as db_cursor: + db_cursor.execute(SQL(query)) + for row in db_cursor.fetchall(): + db_combinations.add((row[0], row[1])) + + return db_combinations + + def import_phrases(self, tokenizer: AbstractTokenizer, should_replace: bool) -> None: """ Iterate through all SpecialPhrases extracted from the @@ -85,9 +131,11 @@ class SPImporter(): if result: class_type_pairs.add(result) - self._create_classtype_table_and_indexes(class_type_pairs) + self._create_classtype_table_and_indexes(class_type_pairs) if should_replace: self._remove_non_existent_tables_from_db() + + self.db_connection.commit() with tokenizer.name_analyzer() as analyzer: @@ -177,10 +225,17 @@ class SPImporter(): with self.db_connection.cursor() as db_cursor: db_cursor.execute("CREATE INDEX idx_placex_classtype ON placex (class, type)") + allowed_special_phrases = self.get_classtype_pairs() + for pair in class_type_pairs: phrase_class = pair[0] phrase_type = pair[1] + if (phrase_class, phrase_type) not in allowed_special_phrases: + LOG.warning("Skipping phrase %s=%s: not in allowed special phrases", + phrase_class, phrase_type) + continue + table_name = _classtype_table(phrase_class, phrase_type) if table_name in self.table_phrases_to_delete: @@ -212,8 +267,8 @@ class SPImporter(): if doesn't exit. 
""" table_name = _classtype_table(phrase_class, phrase_type) - with self.db_connection.cursor() as cur: - cur.execute(SQL("""CREATE TABLE IF NOT EXISTS {} {} AS + with self.db_connection.cursor() as db_cursor: + db_cursor.execute(SQL("""CREATE TABLE IF NOT EXISTS {} {} AS SELECT place_id AS place_id, st_centroid(geometry) AS centroid FROM placex @@ -266,3 +321,4 @@ class SPImporter(): drop_tables(self.db_connection, *self.table_phrases_to_delete) for _ in self.table_phrases_to_delete: self.statistics_handler.notify_one_table_deleted() + diff --git a/test/python/tools/test_sp_importer.py b/test/python/tools/test_sp_importer.py new file mode 100644 index 00000000..b49d2ea1 --- /dev/null +++ b/test/python/tools/test_sp_importer.py @@ -0,0 +1,193 @@ +import pytest +import tempfile +import json +import os +from unittest.mock import MagicMock + +from nominatim_db.errors import UsageError +from nominatim_db.tools.special_phrases.sp_csv_loader import SPCsvLoader +from nominatim_db.tools.special_phrases.special_phrase import SpecialPhrase +from nominatim_db.tools.special_phrases.sp_importer import SPImporter + +@pytest.fixture +def sample_style_file(): + sample_data = [ + { + "keys" : ["emergency"], + "values" : { + "fire_hydrant" : "skip", + "yes" : "skip", + "no" : "skip", + "" : "main" + } + }, + { + "keys" : ["historic", "military"], + "values" : { + "no" : "skip", + "yes" : "skip", + "" : "main" + } + }, + { + "keys" : ["name:prefix", "name:suffix", "name:prefix:*", "name:suffix:*", + "name:botanical", "wikidata", "*:wikidata"], + "values" : { + "" : "extra" + } + }, + { + "keys" : ["addr:housename"], + "values" : { + "" : "name,house" + } + }, + { + "keys": ["highway"], + "values": { + "motorway": "main", + "": "skip" + } + } + ] + content = ",".join(json.dumps(entry) for entry in sample_data) + + with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tmp: + tmp.write(content) + tmp_path = tmp.name + + yield tmp_path + os.remove(tmp_path) + + +def 
test_get_sp_style(sample_style_file): + mock_config = MagicMock() + mock_config.get_import_style_file.return_value = sample_style_file + + importer = SPImporter(config=mock_config, conn=None, sp_loader=None) + result = importer.get_sp_style() + + expected = { + ("highway", "motorway"), + } + + assert result == expected + +@pytest.fixture +def mock_phrase(): + return SpecialPhrase( + p_label="test", + p_class="highway", + p_type="motorway", + p_operator="eq" + ) + +def test_create_classtype_table_and_indexes(): + mock_config = MagicMock() + mock_config.TABLESPACE_AUX_DATA = '' + mock_config.DATABASE_WEBUSER = 'www-data' + + mock_cursor = MagicMock() + mock_conn = MagicMock() + mock_conn.cursor.return_value.__enter__.return_value = mock_cursor + + importer = SPImporter(config=mock_config, conn=mock_conn, sp_loader=None) + + importer._create_place_classtype_table = MagicMock() + importer._create_place_classtype_indexes = MagicMock() + importer._grant_access_to_webuser = MagicMock() + importer.statistics_handler.notify_one_table_created = lambda: print("✓ Created table") + importer.statistics_handler.notify_one_table_ignored = lambda: print("⨉ Ignored table") + + importer.table_phrases_to_delete = {"place_classtype_highway_motorway"} + + test_pairs = [("highway", "motorway"), ("natural", "peak")] + importer._create_classtype_table_and_indexes(test_pairs) + + print("create_place_classtype_table calls:") + for call in importer._create_place_classtype_table.call_args_list: + print(call) + + print("\ncreate_place_classtype_indexes calls:") + for call in importer._create_place_classtype_indexes.call_args_list: + print(call) + + print("\ngrant_access_to_webuser calls:") + for call in importer._grant_access_to_webuser.call_args_list: + print(call) + +@pytest.fixture +def mock_config(): + config = MagicMock() + config.TABLESPACE_AUX_DATA = '' + config.DATABASE_WEBUSER = 'www-data' + config.load_sub_configuration.return_value = {'blackList': {}, 'whiteList': {}} + return config 
+ + +def test_import_phrases_original(mock_config): + phrase = SpecialPhrase("roundabout", "highway", "motorway", "eq") + + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_conn.cursor.return_value.__enter__.return_value = mock_cursor + mock_loader = MagicMock() + mock_loader.generate_phrases.return_value = [phrase] + + mock_analyzer = MagicMock() + mock_tokenizer = MagicMock() + mock_tokenizer.name_analyzer.return_value.__enter__.return_value = mock_analyzer + + importer = SPImporter(config=mock_config, conn=mock_conn, sp_loader=mock_loader) + importer._fetch_existing_place_classtype_tables = MagicMock() + importer._create_classtype_table_and_indexes = MagicMock() + importer._remove_non_existent_tables_from_db = MagicMock() + + importer.import_phrases(tokenizer=mock_tokenizer, should_replace=True) + + assert importer.word_phrases == {("roundabout", "highway", "motorway", "-")} + + mock_analyzer.update_special_phrases.assert_called_once_with( + importer.word_phrases, True + ) + + +def test_get_sp_filters_correctly(sample_style_file): + mock_config = MagicMock() + mock_config.get_import_style_file.return_value = sample_style_file + mock_config.load_sub_configuration.return_value = {"blackList": {}, "whiteList": {}} + + importer = SPImporter(config=mock_config, conn=MagicMock(), sp_loader=None) + + allowed_from_db = {("highway", "motorway"), ("historic", "castle")} + importer.get_sp_db = lambda: allowed_from_db + + result = importer.get_sp() + + expected = {("highway", "motorway")} + + assert result == expected, f"Expected {expected}, got {result}" + +def test_get_sp_db_filters_by_count_threshold(mock_config): + mock_cursor = MagicMock() + + # Simulate only results above the threshold being returned (as SQL would) + # These tuples simulate real SELECT class, type FROM placex GROUP BY ... 
HAVING COUNT(*) > 100 + mock_cursor.fetchall.return_value = [ + ("highway", "motorway"), + ("historic", "castle") + ] + + mock_conn = MagicMock() + mock_conn.cursor.return_value.__enter__.return_value = mock_cursor + importer = SPImporter(config=mock_config, conn=mock_conn, sp_loader=None) + + result = importer.get_sp_db() + + expected = { + ("highway", "motorway"), + ("historic", "castle") + } + + assert result == expected + mock_cursor.execute.assert_called_once() -- 2.39.5