]> git.openstreetmap.org Git - chef.git/blob - cookbooks/mysql/libraries/mysql.rb
Limit columns selected from mysql user table to those we need
[chef.git] / cookbooks / mysql / libraries / mysql.rb
1 require "chef/mixin/shell_out"
2 require "rexml/document"
3
4 module OpenStreetMap
5   module MySQL
6     include Chef::Mixin::ShellOut
7
8     USER_PRIVILEGES = [
9       :select, :insert, :update, :delete, :create, :drop, :reload,
10       :shutdown, :process, :file, :grant, :references, :index, :alter,
11       :show_db, :super, :create_tmp_table, :lock_tables, :execute,
12       :repl_slave, :repl_client, :create_view, :show_view, :create_routine,
13       :alter_routine, :create_user, :event, :trigger, :create_tablespace
14     ].freeze
15
16     DATABASE_PRIVILEGES = [
17       :select, :insert, :update, :delete, :create, :drop, :grant,
18       :references, :index, :alter, :create_tmp_table, :lock_tables,
19       :create_view, :show_view, :create_routine, :alter_routine,
20       :execute, :event, :trigger
21     ].freeze
22
23     def mysql_execute(options)
24       # Create argument array
25       args = []
26
27       # Work out how to authenticate
28       if options[:user]
29         args.push("--username")
30         args.push(options[:user])
31
32         if options[:password]
33           args.push("--password")
34           args.push(options[:password])
35         end
36       else
37         args.push("--defaults-file=/etc/mysql/debian.cnf")
38       end
39
40       # Set output format
41       args.push("--xml") if options[:xml]
42
43       # Add any SQL command to execute
44       if options[:command]
45         args.push("--execute")
46         args.push(options[:command])
47       end
48
49       # Add the database name
50       args.push(options[:database] || "mysql")
51
52       # Run the command
53       shell_out!("/usr/bin/mysql", *args, :user => "root", :group => "root")
54     end
55
56     def query(sql, options = {})
57       # Run the query
58       result = mysql_execute(options.merge(:command => sql, :xml => true))
59
60       # Parse the output
61       document = REXML::Document.new(result.stdout)
62
63       # Create
64       records = []
65
66       # Loop over the rows in the result set
67       document.root.each_element("/resultset/row") do |row|
68         # Create a record
69         record = {}
70
71         # Loop over the fields, adding them to the record
72         row.each_element("field") do |field|
73           name = field.attributes["name"].downcase
74           value = field.text
75
76           record[name.to_sym] = value
77         end
78
79         # Add the record to the record list
80         records << record
81       end
82
83       # Return the record list
84       records
85     end
86
87     def mysql_users
88       privilege_columns = USER_PRIVILEGES.collect { |privilege| "#{privilege}_priv" }.join(", ")
89
90       @mysql_users ||= query("SELECT user, host, #{privilege_columns} FROM user").each_with_object({}) do |user, users|
91         name = "'#{user[:user]}'@'#{user[:host]}'"
92
93         users[name] = USER_PRIVILEGES.each_with_object({}) do |privilege, privileges|
94           privileges[privilege] = user["#{privilege}_priv".to_sym] == "Y"
95         end
96       end
97     end
98
99     def mysql_databases
100       @mysql_databases ||= query("SHOW databases").each_with_object({}) do |database, databases|
101         databases[database[:database]] = {
102           :permissions => {}
103         }
104       end
105
106       query("SELECT * FROM db").each do |record|
107         database = @mysql_databases[record[:db]]
108
109         next unless database
110
111         user = "'#{record[:user]}'@'#{record[:host]}'"
112
113         database[:permissions][user] = DATABASE_PRIVILEGES.each_with_object([]) do |privilege, privileges|
114           privileges << privilege if record["#{privilege}_priv".to_sym] == "Y"
115         end
116       end
117
118       @mysql_databases
119     end
120
121     def mysql_canonicalise_user(user)
122       local, host = user.split("@")
123
124       host ||= "%"
125
126       local = "'#{local}'" unless local =~ /^'.*'$/
127       host = "'#{host}'" unless host =~ /^'.*'$/
128
129       "#{local}@#{host}"
130     end
131
132     def mysql_privilege_name(privilege)
133       case privilege
134       when :grant
135         "GRANT OPTION"
136       when :create_tmp_table
137         "CREATE TEMPORARY TABLES"
138       else
139         privilege.to_s.upcase.tr("_", " ")
140       end
141     end
142   end
143 end