1 # SPDX-License-Identifier: GPL-3.0-or-later
 
   3 # This file is part of Nominatim. (https://nominatim.org)
 
   5 # Copyright (C) 2025 by the Nominatim developer community.
 
   6 # For a full list of authors see the git log.
 
   8 Specialised psycopg cursor with shortcut functions useful for testing.
 
  13 class CursorForTesting(psycopg.Cursor):
 
  14     """ Extension to the DictCursor class that provides execution
 
  15         short-cuts that simplify writing assertions.
 
  18     def scalar(self, sql, params=None):
 
  19         """ Execute a query with a single return value and return this value.
 
  20             Raises an assertion when not exactly one row is returned.
 
  22         self.execute(sql, params)
 
  23         assert self.rowcount == 1
 
  24         return self.fetchone()[0]
 
  26     def row_set(self, sql, params=None):
 
  27         """ Execute a query and return the result as a set of tuples.
 
  28             Fails when the SQL command returns duplicate rows.
 
  30         self.execute(sql, params)
 
  32         result = set((tuple(row) for row in self))
 
  33         assert len(result) == self.rowcount
 
  37     def table_exists(self, table):
 
  38         """ Check that a table with the given name exists in the database.
 
  40         num = self.scalar("""SELECT count(*) FROM pg_tables
 
  41                              WHERE tablename = %s""", (table, ))
 
  44     def index_exists(self, table, index):
 
  45         """ Check that an indexwith the given name exists on the given table.
 
  47         num = self.scalar("""SELECT count(*) FROM pg_indexes
 
  48                              WHERE tablename = %s and indexname = %s""",
 
  52     def table_rows(self, table, where=None):
 
  53         """ Return the number of rows in the given table.
 
  56             return self.scalar('SELECT count(*) FROM ' + table)
 
  58         return self.scalar('SELECT count(*) FROM {} WHERE {}'.format(table, where))