]> git.openstreetmap.org Git - nominatim.git/commitdiff
correctly quote strings when copying in data
author Sarah Hoffmann <lonvia@denofr.de>
Thu, 10 Jun 2021 07:36:43 +0000 (09:36 +0200)
committer Sarah Hoffmann <lonvia@denofr.de>
Sun, 4 Jul 2021 08:28:20 +0000 (10:28 +0200)
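Encapsulate the copy string in a class that ensures that
copy lines are written with correct quoting.
Also load the ICU rule file with yaml.safe_load(), accept empty
configuration sections and ignore empty abbreviation terms.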

nominatim/db/utils.py
nominatim/tokenizer/icu_rule_loader.py
nominatim/tokenizer/legacy_icu_tokenizer.py
test/python/test_db_utils.py
test/python/test_tokenizer_icu_rule_loader.py

index b376940d804af364c049864a07649c897b515f0f..4d4305e7d67ff74c93119bbc67ef4acfa7036e2c 100644 (file)
@@ -4,6 +4,7 @@ Helper functions for handling DB accesses.
 import subprocess
 import logging
 import gzip
+import io
 
 from nominatim.db.connection import get_pg_env
 from nominatim.errors import UsageError
@@ -57,3 +58,60 @@ def execute_file(dsn, fname, ignore_errors=False, pre_code=None, post_code=None)
 
     if ret != 0 or remain > 0:
         raise UsageError("Failed to execute SQL file.")
+
+
+# Characters that must be escaped in PostgreSQL's COPY text format:
+# the backslash itself, the tab used as column delimiter and the
+# newline and carriage return, which would otherwise end the row.
+_SQL_TRANSLATION = {ord(u'\\') : u'\\\\',
+                    ord(u'\t') : u'\\t',
+                    ord(u'\n') : u'\\n',
+                    ord(u'\r') : u'\\r'}
+
+class CopyBuffer:
+    """ Data collector for the copy_from command.
+
+        Rows are serialised into an in-memory text buffer in the COPY
+        text format (tab-separated columns, one row per line, None
+        written as \\N) and sent to the database with copy_out().
+        Use it as a context manager so that the buffer gets released.
+    """
+
+    def __init__(self):
+        self.buffer = io.StringIO()
+
+
+    def __enter__(self):
+        return self
+
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        if self.buffer is not None:
+            self.buffer.close()
+            # Mark the buffer as released so that the check above holds.
+            self.buffer = None
+
+
+    def add(self, *data):
+        """ Add another row of data to the copy buffer.
+
+            Every argument becomes one column of the row. None is
+            written as SQL NULL, any other value is converted with
+            str() and escaped so that it cannot break the row format.
+        """
+        self.buffer.write('\t'.join(
+            '\\N' if column is None else str(column).translate(_SQL_TRANSLATION)
+            for column in data))
+        self.buffer.write('\n')
+
+
+    def copy_out(self, cur, table, columns=None):
+        """ Copy all collected data into the given table.
+
+            'cur' is the cursor to run the copy on, 'columns' an optional
+            list of the table columns the row values map to. Nothing is
+            sent to the database when no rows have been added.
+        """
+        if self.buffer.tell() > 0:
+            self.buffer.seek(0)
+            cur.copy_from(self.buffer, table, columns=columns)
index a11b9bd86e0140b97d4fd189e2d44b6b79adc13b..269faed981abbbb9ffc530bd32d6b38ae0c30df4 100644 (file)
@@ -93,7 +93,7 @@ class ICURuleLoader:
 
 
     def _load_from_yaml(self):
-        rules = yaml.load(self.configfile.read_text())
+        rules = yaml.safe_load(self.configfile.read_text())
 
         self.normalization_rules = self._cfg_to_icu_rules(rules, 'normalization')
         self.transliteration_rules = self._cfg_to_icu_rules(rules, 'transliteration')
@@ -122,6 +122,9 @@ class ICURuleLoader:
         """
         content = self._get_section(rules, section)
 
+        if content is None:
+            return ''
+
         if isinstance(content, str):
             return (self.configfile.parent / content).read_text().replace('\n', ' ')
 
@@ -160,4 +163,5 @@ class ICURuleLoader:
             abbrterms = (norm.transliterate(t.strip()) for t in parts[1].split(','))
 
             for full, abbr in itertools.product(fullterms, abbrterms):
-                self.abbreviations[full].append(abbr)
+                if full and abbr:
+                    self.abbreviations[full].append(abbr)
index f3eb7b4ef4fd9fae8bfcf6f7ef538ded91cfa08e..af53e825426bc1c1fcaed5fb1b4895e5a250719e 100644 (file)
@@ -14,6 +14,7 @@ import psycopg2.extras
 
 from nominatim.db.connection import connect
 from nominatim.db.properties import set_property, get_property
+from nominatim.db.utils import CopyBuffer
 from nominatim.db.sql_preprocessor import SQLPreprocessor
 from nominatim.tokenizer.icu_rule_loader import ICURuleLoader
 from nominatim.tokenizer.icu_name_processor import ICUNameProcessor, ICUNameProcessorRules
@@ -134,7 +135,7 @@ class LegacyICUTokenizer:
             @define('CONST_Term_Normalization_Rules', "{0.term_normalization}");
             @define('CONST_Transliteration', "{0.naming_rules.search_rules}");
             require_once('{1}/tokenizer/legacy_icu_tokenizer.php');
-            """.format(self, phpdir)))
+            """.format(self, phpdir))) # pylint: disable=missing-format-attribute
 
 
     def _save_config(self, config):
@@ -171,14 +172,15 @@ class LegacyICUTokenizer:
                             words[term] += cnt
 
             # copy them back into the word table
-            copystr = io.StringIO(''.join(('{}\t{}\n'.format(*args) for args in words.items())))
+            with CopyBuffer() as copystr:
+                for args in words.items():
+                    copystr.add(*args)
 
-
-            with conn.cursor() as cur:
-                copystr.seek(0)
-                cur.copy_from(copystr, 'word', columns=['word_token', 'search_name_count'])
-                cur.execute("""UPDATE word SET word_id = nextval('seq_word')
-                               WHERE word_id is null""")
+                with conn.cursor() as cur:
+                    copystr.copy_out(cur, 'word',
+                                      columns=['word_token', 'search_name_count'])
+                    cur.execute("""UPDATE word SET word_id = nextval('seq_word')
+                                   WHERE word_id is null""")
 
             conn.commit()
 
@@ -265,7 +267,6 @@ class LegacyICUNameAnalyzer:
             table.
         """
         to_delete = []
-        copystr = io.StringIO()
         with self.conn.cursor() as cur:
             # This finds us the rows in location_postcode and word that are
             # missing in the other table.
@@ -278,26 +279,25 @@ class LegacyICUNameAnalyzer:
                               ON pc = word) x
                            WHERE pc is null or word is null""")
 
-            for postcode, word in cur:
-                if postcode is None:
-                    to_delete.append(word)
-                else:
-                    copystr.write(postcode)
-                    copystr.write('\t ')
-                    copystr.write(self.name_processor.get_search_normalized(postcode))
-                    copystr.write('\tplace\tpostcode\t0\n')
+            with CopyBuffer() as copystr:
+                for postcode, word in cur:
+                    if postcode is None:
+                        to_delete.append(word)
+                    else:
+                        copystr.add(
+                            postcode,
+                            ' ' + self.name_processor.get_search_normalized(postcode),
+                            'place', 'postcode', 0)
 
-            if to_delete:
-                cur.execute("""DELETE FROM WORD
-                               WHERE class ='place' and type = 'postcode'
-                                     and word = any(%s)
-                            """, (to_delete, ))
+                if to_delete:
+                    cur.execute("""DELETE FROM WORD
+                                   WHERE class ='place' and type = 'postcode'
+                                         and word = any(%s)
+                                """, (to_delete, ))
 
-            if copystr.getvalue():
-                copystr.seek(0)
-                cur.copy_from(copystr, 'word',
-                              columns=['word', 'word_token', 'class', 'type',
-                                       'search_name_count'])
+                copystr.copy_out(cur, 'word',
+                                 columns=['word', 'word_token', 'class', 'type',
+                                          'search_name_count'])
 
 
     def update_special_phrases(self, phrases, should_replace):
@@ -331,34 +331,24 @@ class LegacyICUNameAnalyzer:
         """
         to_add = new_phrases - existing_phrases
 
-        copystr = io.StringIO()
         added = 0
-        for word, cls, typ, oper in to_add:
-            term = self.name_processor.get_search_normalized(word)
-            if term:
-                copystr.write(word)
-                copystr.write('\t ')
-                copystr.write(term)
-                copystr.write('\t')
-                copystr.write(cls)
-                copystr.write('\t')
-                copystr.write(typ)
-                copystr.write('\t')
-                copystr.write(oper if oper in ('in', 'near')  else '\\N')
-                copystr.write('\t0\n')
-                added += 1
-
-
-        if copystr.tell() > 0:
-            copystr.seek(0)
-            cursor.copy_from(copystr, 'word',
+        with CopyBuffer() as copystr:
+            for word, cls, typ, oper in to_add:
+                term = self.name_processor.get_search_normalized(word)
+                if term:
+                    copystr.add(word, term, cls, typ,
+                                oper if oper in ('in', 'near')  else None, 0)
+                    added += 1
+
+            copystr.copy_out(cursor, 'word',
                              columns=['word', 'word_token', 'class', 'type',
                                       'operator', 'search_name_count'])
 
         return added
 
 
-    def _remove_special_phrases(self, cursor, new_phrases, existing_phrases):
+    @staticmethod
+    def _remove_special_phrases(cursor, new_phrases, existing_phrases):
         """ Remove all phrases from the databse that are no longer in the
             new phrase list.
         """
index d549b70f803ec8f7e873e21e8c79f350e73f7af6..545cc58f633448096fbd2f212a19e69160ae01ff 100644 (file)
@@ -50,3 +50,73 @@ def test_execute_file_with_post_code(dsn, tmp_path, temp_db_cursor):
     db_utils.execute_file(dsn, tmpfile, post_code='INSERT INTO test VALUES(23)')
 
     assert temp_db_cursor.row_set('SELECT * FROM test') == {(23, )}
+
+
+class TestCopyBuffer:
+    """ Tests for db_utils.CopyBuffer, using a table with an integer
+        column colA and a text column colB.
+    """
+    TABLE_NAME = 'copytable'
+
+    @pytest.fixture(autouse=True)
+    def setup_test_table(self, table_factory):
+        table_factory(self.TABLE_NAME, 'colA INT, colB TEXT')
+
+
+    def table_rows(self, cursor):
+        """ Return the content of the test table as a set of row tuples.
+        """
+        return cursor.row_set('SELECT * FROM ' + self.TABLE_NAME)
+
+
+    def test_copybuffer_empty(self):
+        # Nothing was added, so copy_out() must not use the cursor at all.
+        with db_utils.CopyBuffer() as buf:
+            buf.copy_out(None, "dummy")
+
+
+    def test_all_columns(self, temp_db_cursor):
+        with db_utils.CopyBuffer() as buf:
+            buf.add(3, 'hum')
+            buf.add(None, 'f\\t')
+
+            buf.copy_out(temp_db_cursor, self.TABLE_NAME)
+
+        assert self.table_rows(temp_db_cursor) == {(3, 'hum'), (None, 'f\\t')}
+
+
+    def test_selected_columns(self, temp_db_cursor):
+        with db_utils.CopyBuffer() as buf:
+            buf.add('foo')
+
+            buf.copy_out(temp_db_cursor, self.TABLE_NAME,
+                         columns=['colB'])
+
+        assert self.table_rows(temp_db_cursor) == {(None, 'foo')}
+
+
+    def test_reordered_columns(self, temp_db_cursor):
+        with db_utils.CopyBuffer() as buf:
+            buf.add('one', 1)
+            buf.add(' two ', 2)
+
+            buf.copy_out(temp_db_cursor, self.TABLE_NAME,
+                         columns=['colB', 'colA'])
+
+        assert self.table_rows(temp_db_cursor) == {(1, 'one'), (2, ' two ')}
+
+
+    def test_special_characters(self, temp_db_cursor):
+        # Tab, newline and a literal backslash-N must survive the copy
+        # unchanged instead of being read as delimiters or NULL.
+        with db_utils.CopyBuffer() as buf:
+            buf.add('foo\tbar')
+            buf.add('sun\nson')
+            buf.add('\\N')
+
+            buf.copy_out(temp_db_cursor, self.TABLE_NAME,
+                         columns=['colB'])
+
+        assert self.table_rows(temp_db_cursor) == {(None, 'foo\tbar'),
+                                                   (None, 'sun\nson'),
+                                                   (None, '\\N')}
index 20b127f39c7c622e2424e3b50fc0c38c7f3bb934..abbc92423f4d9b1f44941b90f16534491c9dd2b6 100644 (file)
@@ -21,6 +21,7 @@ def cfgfile(tmp_path, suffix='.yaml'):
             - ":: NFC ()"
         transliteration:
             - "::  Latin ()"
+            - "[[:Punctuation:][:Space:]]+ > ' '"
         """)
         content += "compound_suffixes:\n"
         content += '\n'.join(("    - " + s for s in suffixes)) + '\n'
@@ -32,13 +33,33 @@ def cfgfile(tmp_path, suffix='.yaml'):
 
     return _create_config
 
-def test_missing_normalization(tmp_path):
+
+def test_empty_rule_file(tmp_path):
     fpath = tmp_path / ('test_config.yaml')
     fpath.write_text(dedent("""\
-        normalizatio:
-            - ":: NFD ()"
+        normalization:
+        transliteration:
+        compound_suffixes:
+        abbreviations:
         """))
 
+    rules = ICURuleLoader(fpath)  # every section present, but empty
+    assert rules.get_search_rules() == ''
+    assert rules.get_normalization_rules() == ''
+    assert rules.get_transliteration_rules() == ''
+    assert rules.get_replacement_pairs() == []
+
+CONFIG_SECTIONS = ('normalization', 'transliteration',
+                   'compound_suffixes', 'abbreviations')  # all required sections
+
+@pytest.mark.parametrize("section", CONFIG_SECTIONS)
+def test_missing_section(tmp_path, section):
+    fpath = tmp_path / ('test_config.yaml')
+    with fpath.open('w') as fd:  # write every section except the one under test
+        for name in CONFIG_SECTIONS:
+            if name != section:
+                fd.write(name + ':\n')
+
     with pytest.raises(UsageError):
         ICURuleLoader(fpath)
 
@@ -53,6 +74,7 @@ def test_get_search_rules(cfgfile):
     rules = loader.get_search_rules()
     trans = Transliterator.createFromRules("test", rules)
 
+    assert trans.transliterate(" Baum straße ") == " baum straße "
     assert trans.transliterate(" Baumstraße ") == " baum straße "
     assert trans.transliterate(" Baumstrasse ") == " baum strasse "
     assert trans.transliterate(" Baumstr ") == " baum str "
@@ -61,6 +83,28 @@ def test_get_search_rules(cfgfile):
     assert trans.transliterate(" проспект ") == " prospekt "
 
 
+def test_get_normalization_rules(cfgfile):
+    fpath = cfgfile(['strasse', 'straße', 'weg'],
+                    ['strasse,straße => str'])
+
+    loader = ICURuleLoader(fpath)
+    rules = loader.get_normalization_rules()  # lower-cases, keeps the script
+    trans = Transliterator.createFromRules("test", rules)
+
+    assert trans.transliterate(" проспект-Prospekt ") == " проспект prospekt "
+
+
+def test_get_transliteration_rules(cfgfile):
+    fpath = cfgfile(['strasse', 'straße', 'weg'],
+                    ['strasse,straße => str'])
+
+    loader = ICURuleLoader(fpath)
+    rules = loader.get_transliteration_rules()  # to Latin, keeps the case
+    trans = Transliterator.createFromRules("test", rules)
+
+    assert trans.transliterate(" проспект-Prospekt ") == " prospekt Prospekt "
+
+
 def test_get_synonym_pairs(cfgfile):
     fpath = cfgfile(['Weg', 'Strasse'],
                     ['Strasse => str,st'])