1 # SPDX-License-Identifier: GPL-3.0-or-later
3 # This file is part of Nominatim. (https://nominatim.org)
5 # Copyright (C) 2026 by the Nominatim developer community.
6 # For a full list of authors see the git log.
8 Specialised psycopg cursor with shortcut functions useful for testing.
11 from psycopg import sql as pysql
class CursorForTesting(psycopg.Cursor):
    """ Extension to the psycopg.Cursor class that provides execution
        short-cuts that simplify writing assertions.
    """

    def scalar(self, sql, params=None):
        """ Execute a query with a single return value and return this value.
            Raises an AssertionError when not exactly one row is returned.
        """
        self.execute(sql, params)
        assert self.rowcount == 1
        return self.fetchone()[0]

    def row_set(self, sql, params=None):
        """ Execute a query and return the result as a set of tuples.
            Fails with an AssertionError when the SQL command returns
            duplicate rows.
        """
        self.execute(sql, params)

        result = {tuple(row) for row in self}
        # Duplicate rows collapse into one set entry, so a size mismatch
        # with the reported row count reveals them.
        assert len(result) == self.rowcount

        return result

    def table_exists(self, table):
        """ Check that a table with the given name exists in the database.
            Returns True when exactly one entry in pg_tables matches.
        """
        num = self.scalar("""SELECT count(*) FROM pg_tables
                             WHERE tablename = %s""", (table, ))
        return num == 1

    def index_exists(self, table, index):
        """ Check that an index with the given name exists on the given table.
            Returns True when exactly one entry in pg_indexes matches.
        """
        num = self.scalar("""SELECT count(*) FROM pg_indexes
                             WHERE tablename = %s and indexname = %s""",
                          (table, index))
        return num == 1

    def table_rows(self, table, where=None):
        """ Return the number of rows in the given table.

            'where' is an optional SQL condition restricting the rows
            that are counted. It is inserted verbatim into the query,
            so it must come from trusted test code.
        """
        sql = pysql.SQL('SELECT count(*) FROM {}').format(pysql.Identifier(table))
        if where:
            # Composed SQL fragments are concatenated without any separator,
            # so the keyword needs explicit surrounding spaces; otherwise
            # the query would read 'WHEREcondition'.
            sql += pysql.SQL(' WHERE ') + pysql.SQL(where)

        return self.scalar(sql)

    def insert_row(self, table, **data):
        """ Insert a row into the given table.

            'data' is a dictionary of column names and associated values.
            When the value is a psycopg SQL composable (for example
            pysql.Literal, pysql.SQL or pysql.Composed), then the expression
            will be inserted as is instead of loading the value. When the
            value is a tuple, then the first element will be added as an
            SQL expression for the value and the second element is treated
            as the actual value to insert. The SQL expression must contain
            a %s placeholder in that case. Any other value is passed as a
            query parameter.

            If data contains a 'place_id' column, then the value of the
            place_id column after insert is returned. Otherwise the function
            returns None.
        """
        columns = []
        placeholders = []
        values = []
        for k, v in data.items():
            columns.append(pysql.Identifier(k))
            if isinstance(v, tuple):
                # (sql_expression_with_%s, value)
                placeholders.append(pysql.SQL(v[0]))
                values.append(v[1])
            elif isinstance(v, pysql.Composable):
                # Ready-made SQL fragment: embed it directly, no parameter.
                placeholders.append(v)
            else:
                placeholders.append(pysql.Placeholder())
                values.append(v)

        sql = pysql.SQL("INSERT INTO {table} ({columns}) VALUES ({values})")\
                   .format(table=pysql.Identifier(table),
                           columns=pysql.SQL(',').join(columns),
                           values=pysql.SQL(',').join(placeholders))

        want_place_id = 'place_id' in data
        if want_place_id:
            # Leading space keeps the clause separated from the VALUES list.
            sql += pysql.SQL(' RETURNING place_id')

        self.execute(sql, values)

        return self.fetchone()[0] if want_place_id else None