]> git.openstreetmap.org Git - nominatim.git/blob - test/python/cursor.py
use better SQL quoting in test cursor implementation
[nominatim.git] / test / python / cursor.py
1 # SPDX-License-Identifier: GPL-3.0-or-later
2 #
3 # This file is part of Nominatim. (https://nominatim.org)
4 #
5 # Copyright (C) 2026 by the Nominatim developer community.
6 # For a full list of authors see the git log.
7 """
8 Specialised psycopg cursor with shortcut functions useful for testing.
9 """
10 import psycopg
11 from psycopg import sql as pysql
12
13
14 class CursorForTesting(psycopg.Cursor):
15     """ Extension to the DictCursor class that provides execution
16         short-cuts that simplify writing assertions.
17     """
18
19     def scalar(self, sql, params=None):
20         """ Execute a query with a single return value and return this value.
21             Raises an assertion when not exactly one row is returned.
22         """
23         self.execute(sql, params)
24         assert self.rowcount == 1
25         return self.fetchone()[0]
26
27     def row_set(self, sql, params=None):
28         """ Execute a query and return the result as a set of tuples.
29             Fails when the SQL command returns duplicate rows.
30         """
31         self.execute(sql, params)
32
33         result = set((tuple(row) for row in self))
34         assert len(result) == self.rowcount
35
36         return result
37
38     def table_exists(self, table):
39         """ Check that a table with the given name exists in the database.
40         """
41         num = self.scalar("""SELECT count(*) FROM pg_tables
42                              WHERE tablename = %s""", (table, ))
43         return num == 1
44
45     def index_exists(self, table, index):
46         """ Check that an indexwith the given name exists on the given table.
47         """
48         num = self.scalar("""SELECT count(*) FROM pg_indexes
49                              WHERE tablename = %s and indexname = %s""",
50                           (table, index))
51         return num == 1
52
53     def table_rows(self, table, where=None):
54         """ Return the number of rows in the given table.
55         """
56         sql = pysql.SQL('SELECT count(*) FROM') + pysql.Identifier(table)
57         if where is not None:
58             sql += pysql.SQL('WHERE') + pysql.SQL(where)
59
60         return self.scalar(sql)