git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/tokenizer/icu_rule_loader.py
use yaml tag syntax to mark include files
[nominatim.git] / nominatim / tokenizer / icu_rule_loader.py
index 6bf23201cf953545bc182cf0c3d12b9333827e93..ddb17ae76698025dd9c9fd82619137d0a5b1e22e 100644 (file)
@@ -5,6 +5,7 @@ import io
 import logging
 from collections import defaultdict
 import itertools
+from pathlib import Path
 
 import yaml
 from icu import Transliterator
@@ -13,6 +14,22 @@ from nominatim.errors import UsageError
 
 LOG = logging.getLogger()
 
+def _flatten_yaml_list(content):
+    if not content:
+        return []
+
+    if not isinstance(content, list):
+        raise UsageError("List expected in ICU yaml configuration.")
+
+    output = []
+    for ele in content:
+        if isinstance(ele, list):
+            output.extend(_flatten_yaml_list(ele))
+        else:
+            output.append(ele)
+
+    return output
+
 
 class ICURuleLoader:
     """ Compiler for ICU rules from a tokenizer configuration file.
@@ -87,8 +104,20 @@ class ICURuleLoader:
 
         return [(k, list(synonyms[k])) for k in sorted_keys]
 
+    def _yaml_include_representer(self, loader, node):
+        value = loader.construct_scalar(node)
+
+        if Path(value).is_absolute():
+            content = Path(value).read_text()
+        else:
+            content = (self.configfile.parent / value).read_text()
+
+        return yaml.safe_load(content)
+
 
     def _load_from_yaml(self):
+        yaml.add_constructor('!include', self._yaml_include_representer,
+                             Loader=yaml.SafeLoader)
         rules = yaml.safe_load(self.configfile.read_text())
 
         self.normalization_rules = self._cfg_to_icu_rules(rules, 'normalization')
@@ -121,10 +150,8 @@ class ICURuleLoader:
         if content is None:
             return ''
 
-        if isinstance(content, str):
-            return (self.configfile.parent / content).read_text().replace('\n', ' ')
+        return ';'.join(_flatten_yaml_list(content)) + ';'
 
-        return ';'.join(content) + ';'
 
 
     def _parse_compound_suffix_list(self, rules):