]> git.openstreetmap.org Git - nominatim.git/blob - test/python/cursor.py
rewrite indexing tests to use standard table fixtures
[nominatim.git] / test / python / cursor.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2026 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Specialised psycopg cursor with shortcut functions useful for testing.
9 """
10 import psycopg
11 from psycopg import sql as pysql
12
13
14 class CursorForTesting(psycopg.Cursor):
15     """ Extension to the DictCursor class that provides execution
16         short-cuts that simplify writing assertions.
17     """
18
19     def scalar(self, sql, params=None):
20         """ Execute a query with a single return value and return this value.
21             Raises an assertion when not exactly one row is returned.
22         """
23         self.execute(sql, params)
24         assert self.rowcount == 1
25         return self.fetchone()[0]
26
27     def row_set(self, sql, params=None):
28         """ Execute a query and return the result as a set of tuples.
29             Fails when the SQL command returns duplicate rows.
30         """
31         self.execute(sql, params)
32
33         result = set((tuple(row) for row in self))
34         assert len(result) == self.rowcount
35
36         return result
37
38     def table_exists(self, table):
39         """ Check that a table with the given name exists in the database.
40         """
41         num = self.scalar("""SELECT count(*) FROM pg_tables
42                              WHERE tablename = %s""", (table, ))
43         return num == 1
44
45     def index_exists(self, table, index):
46         """ Check that an indexwith the given name exists on the given table.
47         """
48         num = self.scalar("""SELECT count(*) FROM pg_indexes
49                              WHERE tablename = %s and indexname = %s""",
50                           (table, index))
51         return num == 1
52
53     def table_rows(self, table, where=None):
54         """ Return the number of rows in the given table.
55         """
56         sql = pysql.SQL('SELECT count(*) FROM') + pysql.Identifier(table)
57         if where is not None:
58             sql += pysql.SQL('WHERE') + pysql.SQL(where)
59
60         return self.scalar(sql)
61
62     def insert_row(self, table, **data):
63         """ Insert a row into the given table.
64
65             'data' is a dictionary of column names and associated values.
66             When the value is a pysql.Literal or pysql.SQL, then the expression
67             will be inserted as is instead of loading the value. When the
68             value is a tuple, then the first element will be added as an
69             SQL expression for the value and the second element is treated
70             as the actual value to insert. The SQL expression must contain
71             a %s placeholder in that case.
72
73             If data contains a 'place_id' column, then the value of the
74             place_id column after insert is returned. Otherwise the function
75             returns nothing.
76         """
77         columns = []
78         placeholders = []
79         values = []
80         for k, v in data.items():
81             columns.append(pysql.Identifier(k))
82             if isinstance(v, tuple):
83                 placeholders.append(pysql.SQL(v[0]))
84                 values.append(v[1])
85             elif isinstance(v, (pysql.Literal, pysql.SQL)):
86                 placeholders.append(v)
87             else:
88                 placeholders.append(pysql.Placeholder())
89                 values.append(v)
90
91         sql = pysql.SQL("INSERT INTO {table} ({columns}) VALUES({values})")\
92                    .format(table=pysql.Identifier(table),
93                            columns=pysql.SQL(',').join(columns),
94                            values=pysql.SQL(',').join(placeholders))
95
96         if 'place_id' in data:
97             sql += pysql.SQL('RETURNING place_id')
98
99         self.execute(sql, values)
100
101         return self.fetchone()[0] if 'place_id' in data else None