From 1952290359cd1cffccc1af503ec2edcad1f401d2 Mon Sep 17 00:00:00 2001 From: anqixxx Date: Fri, 11 Apr 2025 12:03:57 -0700 Subject: [PATCH] Removed magic mocking, using monkeypatch instead, and using a placex table to simulate a 'real database' --- .../tools/special_phrases/sp_importer.py | 53 +++--- test/python/tools/test_sp_importer.py | 180 +++++++----------- 2 files changed, 99 insertions(+), 134 deletions(-) diff --git a/src/nominatim_db/tools/special_phrases/sp_importer.py b/src/nominatim_db/tools/special_phrases/sp_importer.py index d9bf3165..a4d0eaf6 100644 --- a/src/nominatim_db/tools/special_phrases/sp_importer.py +++ b/src/nominatim_db/tools/special_phrases/sp_importer.py @@ -16,7 +16,7 @@ from typing import Iterable, Tuple, Mapping, Sequence, Optional, Set import logging import re -import json +import json from psycopg.sql import Identifier, SQL from ...typing import Protocol @@ -65,37 +65,37 @@ class SPImporter(): # special phrases class/type on the wiki. self.table_phrases_to_delete: Set[str] = set() - def get_classtype_pairs_style(self) -> Set[Tuple[str, str]]: + def get_classtype_pairs_style(self) -> Set[Tuple[str, str]]: """ - Returns list of allowed special phrases from the the style file, - restricting to a list of combinations of classes and types + Returns list of allowed special phrases from the the style file, + restricting to a list of combinations of classes and types which have a 'main' property Note: This requirement was from 2021 and I am a bit unsure if it is still relevant """ - style_file = self.config.get_import_style_file() # this gives the path, so i will import it as a json + style_file = self.config.get_import_style_file() # import style file as json with open(style_file, 'r') as file: style_data = json.loads(f'[{file.read()}]') style_combinations = set() - for _map in style_data: # following ../settings/import-extratags.style + for _map in style_data: # following ../settings/import-extratags.style classes = _map.get("keys", []) values = _map.get("values", {}) for _type, properties in values.items(): - if "main" in properties and _type: # make sure the tag is not an empty string. since type is the value of the main tag + if "main" in properties and _type: # make sure the tag is a non-empty string for _class in classes: - style_combinations.add((_class, _type)) + style_combinations.add((_class, _type)) # type is the value of the main tag return style_combinations - - def get_classtype_pairs(self) -> Set[Tuple[str, str]]: + + def get_classtype_pairs(self) -> Set[Tuple[str, str]]: """ - Returns list of allowed special phrases from the database, - restricting to a list of combinations of classes and types - whic occur more than 100 times + Returns list of allowed special phrases from the database, + restricting to a list of combinations of classes and types + which occur more than 100 times """ - db_combinations = set() + db_combinations = set() query = """ SELECT class AS CLS, type AS typ FROM placex @@ -104,13 +104,12 @@ class SPImporter(): """ with self.db_connection.cursor() as db_cursor: - db_cursor.execute(SQL(query)) + db_cursor.execute(SQL(query)) for row in db_cursor.fetchall(): db_combinations.add((row[0], row[1])) - return db_combinations + return db_combinations - def import_phrases(self, tokenizer: AbstractTokenizer, should_replace: bool) -> None: """ Iterate through all SpecialPhrases extracted from the @@ -131,11 +130,10 @@ class SPImporter(): if result: class_type_pairs.add(result) - self._create_classtype_table_and_indexes(class_type_pairs) + self._create_classtype_table_and_indexes(class_type_pairs) if should_replace: self._remove_non_existent_tables_from_db() - self.db_connection.commit() with tokenizer.name_analyzer() as analyzer: @@ -235,7 +233,7 @@ class SPImporter(): LOG.warning("Skipping phrase %s=%s: not in allowed special phrases", phrase_class, phrase_type) continue - + table_name = _classtype_table(phrase_class, phrase_type) if table_name in self.table_phrases_to_delete: @@ -268,13 +266,13 @@ class SPImporter(): """ table_name = _classtype_table(phrase_class, phrase_type) with self.db_connection.cursor() as db_cursor: - db_cursor.execute(SQL("""CREATE TABLE IF NOT EXISTS {} {} AS - SELECT place_id AS place_id, - st_centroid(geometry) AS centroid - FROM placex - WHERE class = %s AND type = %s - """).format(Identifier(table_name), SQL(sql_tablespace)), - (phrase_class, phrase_type)) + db_cursor.execute(SQL( + """CREATE TABLE IF NOT EXISTS {} {} AS + SELECT place_id AS place_id, + st_centroid(geometry) AS centroid + FROM placex WHERE class = %s AND type = %s + """).format(Identifier(table_name), SQL(sql_tablespace)), + (phrase_class, phrase_type)) def _create_place_classtype_indexes(self, sql_tablespace: str, phrase_class: str, phrase_type: str) -> None: @@ -321,4 +319,3 @@ class SPImporter(): drop_tables(self.db_connection, *self.table_phrases_to_delete) for _ in self.table_phrases_to_delete: self.statistics_handler.notify_one_table_deleted() - diff --git a/test/python/tools/test_sp_importer.py b/test/python/tools/test_sp_importer.py index b49d2ea1..4d2dd8d4 100644 --- a/test/python/tools/test_sp_importer.py +++ b/test/python/tools/test_sp_importer.py @@ -2,13 +2,10 @@ import pytest import tempfile import json import os -from unittest.mock import MagicMock -from nominatim_db.errors import UsageError -from nominatim_db.tools.special_phrases.sp_csv_loader import SPCsvLoader -from nominatim_db.tools.special_phrases.special_phrase import SpecialPhrase from nominatim_db.tools.special_phrases.sp_importer import SPImporter +# Testing Style Class Pair Retrival @pytest.fixture def sample_style_file(): sample_data = [ @@ -59,135 +56,106 @@ def sample_style_file(): yield tmp_path os.remove(tmp_path) +def test_get_classtype_style(sample_style_file): + class Config: + def get_import_style_file(self): + return sample_style_file + + def load_sub_configuration(self, name): + return {'blackList': {}, 'whiteList': {}} -def test_get_sp_style(sample_style_file): - mock_config = MagicMock() - mock_config.get_import_style_file.return_value = sample_style_file + config = Config() + importer = SPImporter(config=config, conn=None, sp_loader=None) - importer = SPImporter(config=mock_config, conn=None, sp_loader=None) - result = importer.get_sp_style() + result = importer.get_classtype_pairs_style() expected = { ("highway", "motorway"), } - assert result == expected - -@pytest.fixture -def mock_phrase(): - return SpecialPhrase( - p_label="test", - p_class="highway", - p_type="motorway", - p_operator="eq" - ) - -def test_create_classtype_table_and_indexes(): - mock_config = MagicMock() - mock_config.TABLESPACE_AUX_DATA = '' - mock_config.DATABASE_WEBUSER = 'www-data' + assert expected.issubset(result) - mock_cursor = MagicMock() - mock_conn = MagicMock() - mock_conn.cursor.return_value.__enter__.return_value = mock_cursor +# Testing Database Class Pair Retrival using Mock Database +def test_get_classtype_pairs(monkeypatch): + class Config: + def load_sub_configuration(self, path, section=None): + return {"blackList": {}, "whiteList": {}} - importer = SPImporter(config=mock_config, conn=mock_conn, sp_loader=None) + class Cursor: + def execute(self, query): pass + def fetchall(self): + return [ + ("highway", "motorway"), + ("historic", "castle") + ] + def __enter__(self): return self + def __exit__(self, exc_type, exc_val, exc_tb): pass - importer._create_place_classtype_table = MagicMock() - importer._create_place_classtype_indexes = MagicMock() - importer._grant_access_to_webuser = MagicMock() - importer.statistics_handler.notify_one_table_created = lambda: print("✓ Created table") - importer.statistics_handler.notify_one_table_ignored = lambda: print("⨉ Ignored table") + class Connection: + def cursor(self): return Cursor() - importer.table_phrases_to_delete = {"place_classtype_highway_motorway"} + config = Config() + conn = Connection() + importer = SPImporter(config=config, conn=conn, sp_loader=None) - test_pairs = [("highway", "motorway"), ("natural", "peak")] - importer._create_classtype_table_and_indexes(test_pairs) + result = importer.get_classtype_pairs() - print("create_place_classtype_table calls:") - for call in importer._create_place_classtype_table.call_args_list: - print(call) - - print("\ncreate_place_classtype_indexes calls:") - for call in importer._create_place_classtype_indexes.call_args_list: - print(call) - - print("\ngrant_access_to_webuser calls:") - for call in importer._grant_access_to_webuser.call_args_list: - print(call) - -@pytest.fixture -def mock_config(): - config = MagicMock() - config.TABLESPACE_AUX_DATA = '' - config.DATABASE_WEBUSER = 'www-data' - config.load_sub_configuration.return_value = {'blackList': {}, 'whiteList': {}} - return config - - -def test_import_phrases_original(mock_config): - phrase = SpecialPhrase("roundabout", "highway", "motorway", "eq") - - mock_conn = MagicMock() - mock_cursor = MagicMock() - mock_conn.cursor.return_value.__enter__.return_value = mock_cursor - mock_loader = MagicMock() - mock_loader.generate_phrases.return_value = [phrase] - - mock_analyzer = MagicMock() - mock_tokenizer = MagicMock() - mock_tokenizer.name_analyzer.return_value.__enter__.return_value = mock_analyzer - - importer = SPImporter(config=mock_config, conn=mock_conn, sp_loader=mock_loader) - importer._fetch_existing_place_classtype_tables = MagicMock() - importer._create_classtype_table_and_indexes = MagicMock() - importer._remove_non_existent_tables_from_db = MagicMock() + expected = { + ("highway", "motorway"), + ("historic", "castle") + } - importer.import_phrases(tokenizer=mock_tokenizer, should_replace=True) + assert result == expected - assert importer.word_phrases == {("roundabout", "highway", "motorway", "-")} +# Testing Database Class Pair Retrival using Conftest.py and placex +def test_get_classtype_pair_data(placex_table, temp_db_conn): + class Config: + def load_sub_configuration(self, *_): + return {'blackList': {}, 'whiteList': {}} + + for _ in range(101): + placex_table.add(cls='highway', typ='motorway') # edge case 101 - mock_analyzer.update_special_phrases.assert_called_once_with( - importer.word_phrases, True - ) + for _ in range(99): + placex_table.add(cls='amenity', typ='prison') # edge case 99 + for _ in range(150): + placex_table.add(cls='tourism', typ='hotel') -def test_get_sp_filters_correctly(sample_style_file): - mock_config = MagicMock() - mock_config.get_import_style_file.return_value = sample_style_file - mock_config.load_sub_configuration.return_value = {"blackList": {}, "whiteList": {}} + config = Config() + importer = SPImporter(config=config, conn=temp_db_conn, sp_loader=None) - importer = SPImporter(config=mock_config, conn=MagicMock(), sp_loader=None) + result = importer.get_classtype_pairs() - allowed_from_db = {("highway", "motorway"), ("historic", "castle")} - importer.get_sp_db = lambda: allowed_from_db + expected = { + ("highway", "motorway"), + ("tourism", "hotel") + } - result = importer.get_sp() + assert result == expected, f"Expected {expected}, got {result}" - expected = {("highway", "motorway")} +def test_get_classtype_pair_data_more(placex_table, temp_db_conn): + class Config: + def load_sub_configuration(self, *_): + return {'blackList': {}, 'whiteList': {}} + + for _ in range(100): + placex_table.add(cls='emergency', typ='firehydrant') # edge case 100, not included - assert result == expected, f"Expected {expected}, got {result}" + for _ in range(199): + placex_table.add(cls='amenity', typ='prison') -def test_get_sp_db_filters_by_count_threshold(mock_config): - mock_cursor = MagicMock() - - # Simulate only results above the threshold being returned (as SQL would) - # These tuples simulate real SELECT class, type FROM placex GROUP BY ... HAVING COUNT(*) > 100 - mock_cursor.fetchall.return_value = [ - ("highway", "motorway"), - ("historic", "castle") - ] + for _ in range(3478): + placex_table.add(cls='tourism', typ='hotel') - mock_conn = MagicMock() - mock_conn.cursor.return_value.__enter__.return_value = mock_cursor - importer = SPImporter(config=mock_config, conn=mock_conn, sp_loader=None) + config = Config() + importer = SPImporter(config=config, conn=temp_db_conn, sp_loader=None) - result = importer.get_sp_db() + result = importer.get_classtype_pairs() expected = { - ("highway", "motorway"), - ("historic", "castle") + ("amenity", "prison"), + ("tourism", "hotel") } - assert result == expected - mock_cursor.execute.assert_called_once() + assert result == expected, f"Expected {expected}, got {result}" -- 2.39.5