git.openstreetmap.org Git - nominatim.git/commitdiff
correctly quote strings when copying in data
author: Sarah Hoffmann <lonvia@denofr.de>
Thu, 10 Jun 2021 07:36:43 +0000 (09:36 +0200)
committer: Sarah Hoffmann <lonvia@denofr.de>
Sun, 4 Jul 2021 08:28:20 +0000 (10:28 +0200)
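Encapsulate the copy string in a class that ensures that
copy lines are written with correct quoting.

Also load the ICU rule file with yaml.safe_load(), accept empty rule
sections, ignore abbreviation pairs whose terms normalize to an empty
string, and make two small lint fixes in the legacy ICU tokenizer.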

nominatim/db/utils.py
nominatim/tokenizer/icu_rule_loader.py
nominatim/tokenizer/legacy_icu_tokenizer.py
test/python/test_db_utils.py
test/python/test_tokenizer_icu_rule_loader.py

index b376940d804af364c049864a07649c897b515f0f..4d4305e7d67ff74c93119bbc67ef4acfa7036e2c 100644 (file)
@@ -4,6 +4,7 @@ Helper functions for handling DB accesses.
 import subprocess
 import logging
 import gzip
 import subprocess
 import logging
 import gzip
+import io
 
 from nominatim.db.connection import get_pg_env
 from nominatim.errors import UsageError
 
 from nominatim.db.connection import get_pg_env
 from nominatim.errors import UsageError
@@ -57,3 +58,49 @@ def execute_file(dsn, fname, ignore_errors=False, pre_code=None, post_code=None)
 
     if ret != 0 or remain > 0:
         raise UsageError("Failed to execute SQL file.")
 
     if ret != 0 or remain > 0:
         raise UsageError("Failed to execute SQL file.")
+
+
+# Characters that have to be escaped in the text format of PostgreSQL's
+# COPY command: the backslash itself, the tab used as column delimiter
+# and the row delimiters newline and carriage return. PostgreSQL requires
+# all of them to be backslash-escaped when they appear inside a value.
+_SQL_TRANSLATION = {ord('\\'): '\\\\',
+                    ord('\t'): '\\t',
+                    ord('\n'): '\\n',
+                    ord('\r'): '\\r'}
+
+
+class CopyBuffer:
+    """ Data collector for the copy_from command.
+
+        Collect rows with `add()` and send them to the database in one
+        go with `copy_out()`. Column values are escaped for the text
+        format of COPY, so arbitrary strings can be copied safely.
+        Use the class as a context manager so that the buffer is
+        released afterwards.
+    """
+
+    def __init__(self):
+        self.buffer = io.StringIO()
+
+
+    def __enter__(self):
+        return self
+
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        # Drop the reference after closing, so the guard is meaningful
+        # and leaving the context a second time is a no-op.
+        if self.buffer is not None:
+            self.buffer.close()
+            self.buffer = None
+
+
+    def add(self, *data):
+        """ Add another row of data to the copy buffer.
+
+            Every positional argument is one column of the row. `None`
+            is written as SQL NULL, any other value is converted with
+            `str()` and escaped.
+        """
+        self.buffer.write('\t'.join('\\N' if column is None
+                                    else str(column).translate(_SQL_TRANSLATION)
+                                    for column in data))
+        self.buffer.write('\n')
+
+
+    def copy_out(self, cur, table, columns=None):
+        """ Copy all collected data into the given table.
+
+            `cur` is a database cursor providing `copy_from()`. `columns`
+            optionally names the target columns in the order the values
+            were added; when it is None, every row must supply all columns
+            of the table. When no rows have been added, nothing is sent and
+            the cursor is not used at all.
+        """
+        if self.buffer.tell() > 0:
+            self.buffer.seek(0)
+            cur.copy_from(self.buffer, table, columns=columns)
index a11b9bd86e0140b97d4fd189e2d44b6b79adc13b..269faed981abbbb9ffc530bd32d6b38ae0c30df4 100644 (file)
@@ -93,7 +93,7 @@ class ICURuleLoader:
 
 
     def _load_from_yaml(self):
 
 
     def _load_from_yaml(self):
-        rules = yaml.load(self.configfile.read_text())
+        rules = yaml.safe_load(self.configfile.read_text())
 
         self.normalization_rules = self._cfg_to_icu_rules(rules, 'normalization')
         self.transliteration_rules = self._cfg_to_icu_rules(rules, 'transliteration')
 
         self.normalization_rules = self._cfg_to_icu_rules(rules, 'normalization')
         self.transliteration_rules = self._cfg_to_icu_rules(rules, 'transliteration')
@@ -122,6 +122,9 @@ class ICURuleLoader:
         """
         content = self._get_section(rules, section)
 
         """
         content = self._get_section(rules, section)
 
+        if content is None:
+            return ''
+
         if isinstance(content, str):
             return (self.configfile.parent / content).read_text().replace('\n', ' ')
 
         if isinstance(content, str):
             return (self.configfile.parent / content).read_text().replace('\n', ' ')
 
@@ -160,4 +163,5 @@ class ICURuleLoader:
             abbrterms = (norm.transliterate(t.strip()) for t in parts[1].split(','))
 
             for full, abbr in itertools.product(fullterms, abbrterms):
             abbrterms = (norm.transliterate(t.strip()) for t in parts[1].split(','))
 
             for full, abbr in itertools.product(fullterms, abbrterms):
-                self.abbreviations[full].append(abbr)
+                if full and abbr:
+                    self.abbreviations[full].append(abbr)
index f3eb7b4ef4fd9fae8bfcf6f7ef538ded91cfa08e..af53e825426bc1c1fcaed5fb1b4895e5a250719e 100644 (file)
@@ -14,6 +14,7 @@ import psycopg2.extras
 
 from nominatim.db.connection import connect
 from nominatim.db.properties import set_property, get_property
 
 from nominatim.db.connection import connect
 from nominatim.db.properties import set_property, get_property
+from nominatim.db.utils import CopyBuffer
 from nominatim.db.sql_preprocessor import SQLPreprocessor
 from nominatim.tokenizer.icu_rule_loader import ICURuleLoader
 from nominatim.tokenizer.icu_name_processor import ICUNameProcessor, ICUNameProcessorRules
 from nominatim.db.sql_preprocessor import SQLPreprocessor
 from nominatim.tokenizer.icu_rule_loader import ICURuleLoader
 from nominatim.tokenizer.icu_name_processor import ICUNameProcessor, ICUNameProcessorRules
@@ -134,7 +135,7 @@ class LegacyICUTokenizer:
             @define('CONST_Term_Normalization_Rules', "{0.term_normalization}");
             @define('CONST_Transliteration', "{0.naming_rules.search_rules}");
             require_once('{1}/tokenizer/legacy_icu_tokenizer.php');
             @define('CONST_Term_Normalization_Rules', "{0.term_normalization}");
             @define('CONST_Transliteration', "{0.naming_rules.search_rules}");
             require_once('{1}/tokenizer/legacy_icu_tokenizer.php');
-            """.format(self, phpdir)))
+            """.format(self, phpdir))) # pylint: disable=missing-format-attribute
 
 
     def _save_config(self, config):
 
 
     def _save_config(self, config):
@@ -171,14 +172,15 @@ class LegacyICUTokenizer:
                             words[term] += cnt
 
             # copy them back into the word table
                             words[term] += cnt
 
             # copy them back into the word table
-            copystr = io.StringIO(''.join(('{}\t{}\n'.format(*args) for args in words.items())))
+            with CopyBuffer() as copystr:
+                for args in words.items():
+                    copystr.add(*args)
 
 
-
-            with conn.cursor() as cur:
-                copystr.seek(0)
-                cur.copy_from(copystr, 'word', columns=['word_token', 'search_name_count'])
-                cur.execute("""UPDATE word SET word_id = nextval('seq_word')
-                               WHERE word_id is null""")
+                with conn.cursor() as cur:
+                    copystr.copy_out(cur, 'word',
+                                      columns=['word_token', 'search_name_count'])
+                    cur.execute("""UPDATE word SET word_id = nextval('seq_word')
+                                   WHERE word_id is null""")
 
             conn.commit()
 
 
             conn.commit()
 
@@ -265,7 +267,6 @@ class LegacyICUNameAnalyzer:
             table.
         """
         to_delete = []
             table.
         """
         to_delete = []
-        copystr = io.StringIO()
         with self.conn.cursor() as cur:
             # This finds us the rows in location_postcode and word that are
             # missing in the other table.
         with self.conn.cursor() as cur:
             # This finds us the rows in location_postcode and word that are
             # missing in the other table.
@@ -278,26 +279,25 @@ class LegacyICUNameAnalyzer:
                               ON pc = word) x
                            WHERE pc is null or word is null""")
 
                               ON pc = word) x
                            WHERE pc is null or word is null""")
 
-            for postcode, word in cur:
-                if postcode is None:
-                    to_delete.append(word)
-                else:
-                    copystr.write(postcode)
-                    copystr.write('\t ')
-                    copystr.write(self.name_processor.get_search_normalized(postcode))
-                    copystr.write('\tplace\tpostcode\t0\n')
+            with CopyBuffer() as copystr:
+                for postcode, word in cur:
+                    if postcode is None:
+                        to_delete.append(word)
+                    else:
+                        copystr.add(
+                            postcode,
+                            ' ' + self.name_processor.get_search_normalized(postcode),
+                            'place', 'postcode', 0)
 
 
-            if to_delete:
-                cur.execute("""DELETE FROM WORD
-                               WHERE class ='place' and type = 'postcode'
-                                     and word = any(%s)
-                            """, (to_delete, ))
+                if to_delete:
+                    cur.execute("""DELETE FROM WORD
+                                   WHERE class ='place' and type = 'postcode'
+                                         and word = any(%s)
+                                """, (to_delete, ))
 
 
-            if copystr.getvalue():
-                copystr.seek(0)
-                cur.copy_from(copystr, 'word',
-                              columns=['word', 'word_token', 'class', 'type',
-                                       'search_name_count'])
+                copystr.copy_out(cur, 'word',
+                                 columns=['word', 'word_token', 'class', 'type',
+                                          'search_name_count'])
 
 
     def update_special_phrases(self, phrases, should_replace):
 
 
     def update_special_phrases(self, phrases, should_replace):
@@ -331,34 +331,24 @@ class LegacyICUNameAnalyzer:
         """
         to_add = new_phrases - existing_phrases
 
         """
         to_add = new_phrases - existing_phrases
 
-        copystr = io.StringIO()
         added = 0
         added = 0
-        for word, cls, typ, oper in to_add:
-            term = self.name_processor.get_search_normalized(word)
-            if term:
-                copystr.write(word)
-                copystr.write('\t ')
-                copystr.write(term)
-                copystr.write('\t')
-                copystr.write(cls)
-                copystr.write('\t')
-                copystr.write(typ)
-                copystr.write('\t')
-                copystr.write(oper if oper in ('in', 'near')  else '\\N')
-                copystr.write('\t0\n')
-                added += 1
-
-
-        if copystr.tell() > 0:
-            copystr.seek(0)
-            cursor.copy_from(copystr, 'word',
+        with CopyBuffer() as copystr:
+            for word, cls, typ, oper in to_add:
+                term = self.name_processor.get_search_normalized(word)
+                if term:
+                    copystr.add(word, term, cls, typ,
+                                oper if oper in ('in', 'near')  else None, 0)
+                    added += 1
+
+            copystr.copy_out(cursor, 'word',
                              columns=['word', 'word_token', 'class', 'type',
                                       'operator', 'search_name_count'])
 
         return added
 
 
                              columns=['word', 'word_token', 'class', 'type',
                                       'operator', 'search_name_count'])
 
         return added
 
 
-    def _remove_special_phrases(self, cursor, new_phrases, existing_phrases):
+    @staticmethod
+    def _remove_special_phrases(cursor, new_phrases, existing_phrases):
         """ Remove all phrases from the databse that are no longer in the
             new phrase list.
         """
         """ Remove all phrases from the databse that are no longer in the
             new phrase list.
         """
index d549b70f803ec8f7e873e21e8c79f350e73f7af6..545cc58f633448096fbd2f212a19e69160ae01ff 100644 (file)
@@ -50,3 +50,68 @@ def test_execute_file_with_post_code(dsn, tmp_path, temp_db_cursor):
     db_utils.execute_file(dsn, tmpfile, post_code='INSERT INTO test VALUES(23)')
 
     assert temp_db_cursor.row_set('SELECT * FROM test') == {(23, )}
     db_utils.execute_file(dsn, tmpfile, post_code='INSERT INTO test VALUES(23)')
 
     assert temp_db_cursor.row_set('SELECT * FROM test') == {(23, )}
+
+
+class TestCopyBuffer:
+    """ Tests for db_utils.CopyBuffer writing into a test table.
+    """
+    TABLE_NAME = 'copytable'
+
+    @pytest.fixture(autouse=True)
+    def setup_test_table(self, table_factory):
+        """ Create the target table with an integer and a text column.
+        """
+        table_factory(self.TABLE_NAME, 'colA INT, colB TEXT')
+
+
+    def table_rows(self, cursor):
+        """ Return the content of the test table for comparison with a set.
+        """
+        return cursor.row_set('SELECT * FROM ' + self.TABLE_NAME)
+
+
+    def test_copybuffer_empty(self):
+        # An empty buffer must not use the cursor at all, hence None.
+        with db_utils.CopyBuffer() as buf:
+            buf.copy_out(None, "dummy")
+
+
+    def test_all_columns(self, temp_db_cursor):
+        # 'f\\t' is an f, a literal backslash and a t; it must round-trip.
+        with db_utils.CopyBuffer() as buf:
+            buf.add(3, 'hum')
+            buf.add(None, 'f\\t')
+
+            buf.copy_out(temp_db_cursor, self.TABLE_NAME)
+
+        assert self.table_rows(temp_db_cursor) == {(3, 'hum'), (None, 'f\\t')}
+
+
+    def test_null_in_last_column(self, temp_db_cursor):
+        with db_utils.CopyBuffer() as buf:
+            buf.add(7, None)
+
+            buf.copy_out(temp_db_cursor, self.TABLE_NAME)
+
+        assert self.table_rows(temp_db_cursor) == {(7, None)}
+
+
+    def test_selected_columns(self, temp_db_cursor):
+        with db_utils.CopyBuffer() as buf:
+            buf.add('foo')
+
+            buf.copy_out(temp_db_cursor, self.TABLE_NAME,
+                         columns=['colB'])
+
+        assert self.table_rows(temp_db_cursor) == {(None, 'foo')}
+
+
+    def test_reordered_columns(self, temp_db_cursor):
+        with db_utils.CopyBuffer() as buf:
+            buf.add('one', 1)
+            buf.add(' two ', 2)
+
+            buf.copy_out(temp_db_cursor, self.TABLE_NAME,
+                         columns=['colB', 'colA'])
+
+        assert self.table_rows(temp_db_cursor) == {(1, 'one'), (2, ' two ')}
+
+
+    def test_special_characters(self, temp_db_cursor):
+        # Delimiters and a literal '\N' must arrive unchanged, not as
+        # column/row separators or NULL.
+        with db_utils.CopyBuffer() as buf:
+            buf.add('foo\tbar')
+            buf.add('sun\nson')
+            buf.add('\\N')
+
+            buf.copy_out(temp_db_cursor, self.TABLE_NAME,
+                         columns=['colB'])
+
+        assert self.table_rows(temp_db_cursor) == {(None, 'foo\tbar'),
+                                                   (None, 'sun\nson'),
+                                                   (None, '\\N')}
index 20b127f39c7c622e2424e3b50fc0c38c7f3bb934..abbc92423f4d9b1f44941b90f16534491c9dd2b6 100644 (file)
@@ -21,6 +21,7 @@ def cfgfile(tmp_path, suffix='.yaml'):
             - ":: NFC ()"
         transliteration:
             - "::  Latin ()"
             - ":: NFC ()"
         transliteration:
             - "::  Latin ()"
+            - "[[:Punctuation:][:Space:]]+ > ' '"
         """)
         content += "compound_suffixes:\n"
         content += '\n'.join(("    - " + s for s in suffixes)) + '\n'
         """)
         content += "compound_suffixes:\n"
         content += '\n'.join(("    - " + s for s in suffixes)) + '\n'
@@ -32,13 +33,33 @@ def cfgfile(tmp_path, suffix='.yaml'):
 
     return _create_config
 
 
     return _create_config
 
-def test_missing_normalization(tmp_path):
+
+def test_empty_rule_file(tmp_path):
+    """ A rule file that lists every section but gives them no content
+        loads without error and yields empty rules and no replacement pairs.
+    """
     fpath = tmp_path / ('test_config.yaml')
     fpath.write_text(dedent("""\
     fpath = tmp_path / ('test_config.yaml')
     fpath.write_text(dedent("""\
-        normalizatio:
-            - ":: NFD ()"
+        normalization:
+        transliteration:
+        compound_suffixes:
+        abbreviations:
         """))
 
         """))
 
+    rules = ICURuleLoader(fpath)
+    assert rules.get_search_rules() == ''
+    assert rules.get_normalization_rules() == ''
+    assert rules.get_transliteration_rules() == ''
+    assert rules.get_replacement_pairs() == []
+
+
+# Sections a rule file must contain; the test below leaves out each in turn.
+CONFIG_SECTIONS = ('normalization', 'transliteration',
+                   'compound_suffixes', 'abbreviations')
+
+
+@pytest.mark.parametrize("section", CONFIG_SECTIONS)
+def test_missing_section(tmp_path, section):
+    """ Loading a rule file that lacks one of the sections raises a UsageError.
+    """
+    fpath = tmp_path / ('test_config.yaml')
+    with fpath.open('w') as fd:
+        for name in CONFIG_SECTIONS:
+            if name != section:
+                fd.write(name + ':\n')
+
     with pytest.raises(UsageError):
         ICURuleLoader(fpath)
 
     with pytest.raises(UsageError):
         ICURuleLoader(fpath)
 
@@ -53,6 +74,7 @@ def test_get_search_rules(cfgfile):
     rules = loader.get_search_rules()
     trans = Transliterator.createFromRules("test", rules)
 
     rules = loader.get_search_rules()
     trans = Transliterator.createFromRules("test", rules)
 
+    assert trans.transliterate(" Baum straße ") == " baum straße "
     assert trans.transliterate(" Baumstraße ") == " baum straße "
     assert trans.transliterate(" Baumstrasse ") == " baum strasse "
     assert trans.transliterate(" Baumstr ") == " baum str "
     assert trans.transliterate(" Baumstraße ") == " baum straße "
     assert trans.transliterate(" Baumstrasse ") == " baum strasse "
     assert trans.transliterate(" Baumstr ") == " baum str "
@@ -61,6 +83,28 @@ def test_get_search_rules(cfgfile):
     assert trans.transliterate(" проспект ") == " prospekt "
 
 
     assert trans.transliterate(" проспект ") == " prospekt "
 
 
+def _transliterator_for(cfgfile, select_rules):
+    """ Compile one rule set of a sample configuration into a transliterator.
+
+        `select_rules` receives the ICURuleLoader for the configuration and
+        returns the rule string to compile.
+    """
+    config_file = cfgfile(['strasse', 'straße', 'weg'],
+                          ['strasse,straße => str'])
+    rule_string = select_rules(ICURuleLoader(config_file))
+    return Transliterator.createFromRules("test", rule_string)
+
+
+def test_get_normalization_rules(cfgfile):
+    """ Normalization alone keeps Cyrillic, lower-cases Latin text and
+        turns the hyphen into a space.
+    """
+    trans = _transliterator_for(cfgfile, ICURuleLoader.get_normalization_rules)
+
+    assert trans.transliterate(" проспект-Prospekt ") == " проспект prospekt "
+
+
+def test_get_transliteration_rules(cfgfile):
+    """ Transliteration alone converts Cyrillic to Latin, keeps the case
+        and turns the hyphen into a space.
+    """
+    trans = _transliterator_for(cfgfile, ICURuleLoader.get_transliteration_rules)
+
+    assert trans.transliterate(" проспект-Prospekt ") == " prospekt Prospekt "
+
 def test_get_synonym_pairs(cfgfile):
     fpath = cfgfile(['Weg', 'Strasse'],
                     ['Strasse => str,st'])
 def test_get_synonym_pairs(cfgfile):
     fpath = cfgfile(['Weg', 'Strasse'],
                     ['Strasse => str,st'])