]> git.openstreetmap.org Git - osqa.git/blob - forum_modules/oauthauth/lib/oauth.py
initial import
[osqa.git] / forum_modules / oauthauth / lib / oauth.py
1 """
2 The MIT License
3
4 Copyright (c) 2007 Leah Culver
5
6 Permission is hereby granted, free of charge, to any person obtaining a copy
7 of this software and associated documentation files (the "Software"), to deal
8 in the Software without restriction, including without limitation the rights
9 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 copies of the Software, and to permit persons to whom the Software is
11 furnished to do so, subject to the following conditions:
12
13 The above copyright notice and this permission notice shall be included in
14 all copies or substantial portions of the Software.
15
16 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 THE SOFTWARE.
23 """
24
25 import cgi
26 import urllib
27 import time
28 import random
29 import urlparse
30 import hmac
31 import binascii
32
33
34 VERSION = '1.0' # Hi Blaine!
35 HTTP_METHOD = 'GET'
36 SIGNATURE_METHOD = 'PLAINTEXT'
37
38
39 class OAuthError(RuntimeError):
40     """Generic exception class."""
41     def __init__(self, message='OAuth error occured.'):
42         self.message = message
43
44 def build_authenticate_header(realm=''):
45     """Optional WWW-Authenticate header (401 error)"""
46     return {'WWW-Authenticate': 'OAuth realm="%s"' % realm}
47
48 def escape(s):
49     """Escape a URL including any /."""
50     return urllib.quote(s, safe='~')
51
52 def _utf8_str(s):
53     """Convert unicode to utf-8."""
54     if isinstance(s, unicode):
55         return s.encode("utf-8")
56     else:
57         return str(s)
58
59 def generate_timestamp():
60     """Get seconds since epoch (UTC)."""
61     return int(time.time())
62
63 def generate_nonce(length=8):
64     """Generate pseudorandom number."""
65     return ''.join([str(random.randint(0, 9)) for i in range(length)])
66
67
68 class OAuthConsumer(object):
69     """Consumer of OAuth authentication.
70
71     OAuthConsumer is a data type that represents the identity of the Consumer
72     via its shared secret with the Service Provider.
73
74     """
75     key = None
76     secret = None
77
78     def __init__(self, key, secret):
79         self.key = key
80         self.secret = secret
81
82
83 class OAuthToken(object):
84     """OAuthToken is a data type that represents an End User via either an access
85     or request token.
86
87     key -- the token
88     secret -- the token secret
89
90     """
91     key = None
92     secret = None
93
94     def __init__(self, key, secret):
95         self.key = key
96         self.secret = secret
97
98     def to_string(self):
99         return urllib.urlencode({'oauth_token': self.key,
100             'oauth_token_secret': self.secret})
101
102     def from_string(s):
103         """ Returns a token from something like:
104         oauth_token_secret=xxx&oauth_token=xxx
105         """
106         params = cgi.parse_qs(s, keep_blank_values=False)
107         key = params['oauth_token'][0]
108         secret = params['oauth_token_secret'][0]
109         return OAuthToken(key, secret)
110     from_string = staticmethod(from_string)
111
112     def __str__(self):
113         return self.to_string()
114
115
116 class OAuthRequest(object):
117     """OAuthRequest represents the request and can be serialized.
118
119     OAuth parameters:
120         - oauth_consumer_key
121         - oauth_token
122         - oauth_signature_method
123         - oauth_signature
124         - oauth_timestamp
125         - oauth_nonce
126         - oauth_version
127         ... any additional parameters, as defined by the Service Provider.
128     """
129     parameters = None # OAuth parameters.
130     http_method = HTTP_METHOD
131     http_url = None
132     version = VERSION
133
134     def __init__(self, http_method=HTTP_METHOD, http_url=None, parameters=None):
135         self.http_method = http_method
136         self.http_url = http_url
137         self.parameters = parameters or {}
138
139     def set_parameter(self, parameter, value):
140         self.parameters[parameter] = value
141
142     def get_parameter(self, parameter):
143         try:
144             return self.parameters[parameter]
145         except:
146             raise OAuthError('Parameter not found: %s' % parameter)
147
148     def _get_timestamp_nonce(self):
149         return self.get_parameter('oauth_timestamp'), self.get_parameter(
150             'oauth_nonce')
151
152     def get_nonoauth_parameters(self):
153         """Get any non-OAuth parameters."""
154         parameters = {}
155         for k, v in self.parameters.iteritems():
156             # Ignore oauth parameters.
157             if k.find('oauth_') < 0:
158                 parameters[k] = v
159         return parameters
160
161     def to_header(self, realm=''):
162         """Serialize as a header for an HTTPAuth request."""
163         auth_header = 'OAuth realm="%s"' % realm
164         # Add the oauth parameters.
165         if self.parameters:
166             for k, v in self.parameters.iteritems():
167                 if k[:6] == 'oauth_':
168                     auth_header += ', %s="%s"' % (k, escape(str(v)))
169         return {'Authorization': auth_header}
170
171     def to_postdata(self):
172         """Serialize as post data for a POST request."""
173         return '&'.join(['%s=%s' % (escape(str(k)), escape(str(v))) \
174             for k, v in self.parameters.iteritems()])
175
176     def to_url(self):
177         """Serialize as a URL for a GET request."""
178         return '%s?%s' % (self.get_normalized_http_url(), self.to_postdata())
179
180     def get_normalized_parameters(self):
181         """Return a string that contains the parameters that must be signed."""
182         params = self.parameters
183         try:
184             # Exclude the signature if it exists.
185             del params['oauth_signature']
186         except:
187             pass
188         # Escape key values before sorting.
189         key_values = [(escape(_utf8_str(k)), escape(_utf8_str(v))) \
190             for k,v in params.items()]
191         # Sort lexicographically, first after key, then after value.
192         key_values.sort()
193         # Combine key value pairs into a string.
194         return '&'.join(['%s=%s' % (k, v) for k, v in key_values])
195
196     def get_normalized_http_method(self):
197         """Uppercases the http method."""
198         return self.http_method.upper()
199
200     def get_normalized_http_url(self):
201         """Parses the URL and rebuilds it to be scheme://host/path."""
202         parts = urlparse.urlparse(self.http_url)
203         scheme, netloc, path = parts[:3]
204         # Exclude default port numbers.
205         if scheme == 'http' and netloc[-3:] == ':80':
206             netloc = netloc[:-3]
207         elif scheme == 'https' and netloc[-4:] == ':443':
208             netloc = netloc[:-4]
209         return '%s://%s%s' % (scheme, netloc, path)
210
211     def sign_request(self, signature_method, consumer, token):
212         """Set the signature parameter to the result of build_signature."""
213         # Set the signature method.
214         self.set_parameter('oauth_signature_method',
215             signature_method.get_name())
216         # Set the signature.
217         self.set_parameter('oauth_signature',
218             self.build_signature(signature_method, consumer, token))
219
220     def build_signature(self, signature_method, consumer, token):
221         """Calls the build signature method within the signature method."""
222         return signature_method.build_signature(self, consumer, token)
223
224     def from_request(http_method, http_url, headers=None, parameters=None,
225             query_string=None):
226         """Combines multiple parameter sources."""
227         if parameters is None:
228             parameters = {}
229
230         # Headers
231         if headers and 'Authorization' in headers:
232             auth_header = headers['Authorization']
233             # Check that the authorization header is OAuth.
234             if auth_header.index('OAuth') > -1:
235                 auth_header = auth_header.lstrip('OAuth ')
236                 try:
237                     # Get the parameters from the header.
238                     header_params = OAuthRequest._split_header(auth_header)
239                     parameters.update(header_params)
240                 except:
241                     raise OAuthError('Unable to parse OAuth parameters from '
242                         'Authorization header.')
243
244         # GET or POST query string.
245         if query_string:
246             query_params = OAuthRequest._split_url_string(query_string)
247             parameters.update(query_params)
248
249         # URL parameters.
250         param_str = urlparse.urlparse(http_url)[4] # query
251         url_params = OAuthRequest._split_url_string(param_str)
252         parameters.update(url_params)
253
254         if parameters:
255             return OAuthRequest(http_method, http_url, parameters)
256
257         return None
258     from_request = staticmethod(from_request)
259
260     def from_consumer_and_token(oauth_consumer, token=None,
261             http_method=HTTP_METHOD, http_url=None, parameters=None):
262         if not parameters:
263             parameters = {}
264
265         defaults = {
266             'oauth_consumer_key': oauth_consumer.key,
267             'oauth_timestamp': generate_timestamp(),
268             'oauth_nonce': generate_nonce(),
269             'oauth_version': OAuthRequest.version,
270         }
271
272         defaults.update(parameters)
273         parameters = defaults
274
275         if token:
276             parameters['oauth_token'] = token.key
277
278         return OAuthRequest(http_method, http_url, parameters)
279     from_consumer_and_token = staticmethod(from_consumer_and_token)
280
281     def from_token_and_callback(token, callback=None, http_method=HTTP_METHOD,
282             http_url=None, parameters=None):
283         if not parameters:
284             parameters = {}
285
286         parameters['oauth_token'] = token.key
287
288         if callback:
289             parameters['oauth_callback'] = callback
290
291         return OAuthRequest(http_method, http_url, parameters)
292     from_token_and_callback = staticmethod(from_token_and_callback)
293
294     def _split_header(header):
295         """Turn Authorization: header into parameters."""
296         params = {}
297         parts = header.split(',')
298         for param in parts:
299             # Ignore realm parameter.
300             if param.find('realm') > -1:
301                 continue
302             # Remove whitespace.
303             param = param.strip()
304             # Split key-value.
305             param_parts = param.split('=', 1)
306             # Remove quotes and unescape the value.
307             params[param_parts[0]] = urllib.unquote(param_parts[1].strip('\"'))
308         return params
309     _split_header = staticmethod(_split_header)
310
311     def _split_url_string(param_str):
312         """Turn URL string into parameters."""
313         parameters = cgi.parse_qs(param_str, keep_blank_values=False)
314         for k, v in parameters.iteritems():
315             parameters[k] = urllib.unquote(v[0])
316         return parameters
317     _split_url_string = staticmethod(_split_url_string)
318
319 class OAuthServer(object):
320     """A worker to check the validity of a request against a data store."""
321     timestamp_threshold = 300 # In seconds, five minutes.
322     version = VERSION
323     signature_methods = None
324     data_store = None
325
326     def __init__(self, data_store=None, signature_methods=None):
327         self.data_store = data_store
328         self.signature_methods = signature_methods or {}
329
330     def set_data_store(self, data_store):
331         self.data_store = data_store
332
333     def get_data_store(self):
334         return self.data_store
335
336     def add_signature_method(self, signature_method):
337         self.signature_methods[signature_method.get_name()] = signature_method
338         return self.signature_methods
339
340     def fetch_request_token(self, oauth_request):
341         """Processes a request_token request and returns the
342         request token on success.
343         """
344         try:
345             # Get the request token for authorization.
346             token = self._get_token(oauth_request, 'request')
347         except OAuthError:
348             # No token required for the initial token request.
349             version = self._get_version(oauth_request)
350             consumer = self._get_consumer(oauth_request)
351             self._check_signature(oauth_request, consumer, None)
352             # Fetch a new token.
353             token = self.data_store.fetch_request_token(consumer)
354         return token
355
356     def fetch_access_token(self, oauth_request):
357         """Processes an access_token request and returns the
358         access token on success.
359         """
360         version = self._get_version(oauth_request)
361         consumer = self._get_consumer(oauth_request)
362         # Get the request token.
363         token = self._get_token(oauth_request, 'request')
364         self._check_signature(oauth_request, consumer, token)
365         new_token = self.data_store.fetch_access_token(consumer, token)
366         return new_token
367
368     def verify_request(self, oauth_request):
369         """Verifies an api call and checks all the parameters."""
370         # -> consumer and token
371         version = self._get_version(oauth_request)
372         consumer = self._get_consumer(oauth_request)
373         # Get the access token.
374         token = self._get_token(oauth_request, 'access')
375         self._check_signature(oauth_request, consumer, token)
376         parameters = oauth_request.get_nonoauth_parameters()
377         return consumer, token, parameters
378
379     def authorize_token(self, token, user):
380         """Authorize a request token."""
381         return self.data_store.authorize_request_token(token, user)
382
383     def get_callback(self, oauth_request):
384         """Get the callback URL."""
385         return oauth_request.get_parameter('oauth_callback')
386
387     def build_authenticate_header(self, realm=''):
388         """Optional support for the authenticate header."""
389         return {'WWW-Authenticate': 'OAuth realm="%s"' % realm}
390
391     def _get_version(self, oauth_request):
392         """Verify the correct version request for this server."""
393         try:
394             version = oauth_request.get_parameter('oauth_version')
395         except:
396             version = VERSION
397         if version and version != self.version:
398             raise OAuthError('OAuth version %s not supported.' % str(version))
399         return version
400
401     def _get_signature_method(self, oauth_request):
402         """Figure out the signature with some defaults."""
403         try:
404             signature_method = oauth_request.get_parameter(
405                 'oauth_signature_method')
406         except:
407             signature_method = SIGNATURE_METHOD
408         try:
409             # Get the signature method object.
410             signature_method = self.signature_methods[signature_method]
411         except:
412             signature_method_names = ', '.join(self.signature_methods.keys())
413             raise OAuthError('Signature method %s not supported try one of the '
414                 'following: %s' % (signature_method, signature_method_names))
415
416         return signature_method
417
418     def _get_consumer(self, oauth_request):
419         consumer_key = oauth_request.get_parameter('oauth_consumer_key')
420         consumer = self.data_store.lookup_consumer(consumer_key)
421         if not consumer:
422             raise OAuthError('Invalid consumer.')
423         return consumer
424
425     def _get_token(self, oauth_request, token_type='access'):
426         """Try to find the token for the provided request token key."""
427         token_field = oauth_request.get_parameter('oauth_token')
428         token = self.data_store.lookup_token(token_type, token_field)
429         if not token:
430             raise OAuthError('Invalid %s token: %s' % (token_type, token_field))
431         return token
432
433     def _check_signature(self, oauth_request, consumer, token):
434         timestamp, nonce = oauth_request._get_timestamp_nonce()
435         self._check_timestamp(timestamp)
436         self._check_nonce(consumer, token, nonce)
437         signature_method = self._get_signature_method(oauth_request)
438         try:
439             signature = oauth_request.get_parameter('oauth_signature')
440         except:
441             raise OAuthError('Missing signature.')
442         # Validate the signature.
443         valid_sig = signature_method.check_signature(oauth_request, consumer,
444             token, signature)
445         if not valid_sig:
446             key, base = signature_method.build_signature_base_string(
447                 oauth_request, consumer, token)
448             raise OAuthError('Invalid signature. Expected signature base '
449                 'string: %s' % base)
450         built = signature_method.build_signature(oauth_request, consumer, token)
451
452     def _check_timestamp(self, timestamp):
453         """Verify that timestamp is recentish."""
454         timestamp = int(timestamp)
455         now = int(time.time())
456         lapsed = now - timestamp
457         if lapsed > self.timestamp_threshold:
458             raise OAuthError('Expired timestamp: given %d and now %s has a '
459                 'greater difference than threshold %d' %
460                 (timestamp, now, self.timestamp_threshold))
461
462     def _check_nonce(self, consumer, token, nonce):
463         """Verify that the nonce is uniqueish."""
464         nonce = self.data_store.lookup_nonce(consumer, token, nonce)
465         if nonce:
466             raise OAuthError('Nonce already used: %s' % str(nonce))
467
468
469 class OAuthClient(object):
470     """OAuthClient is a worker to attempt to execute a request."""
471     consumer = None
472     token = None
473
474     def __init__(self, oauth_consumer, oauth_token):
475         self.consumer = oauth_consumer
476         self.token = oauth_token
477
478     def get_consumer(self):
479         return self.consumer
480
481     def get_token(self):
482         return self.token
483
484     def fetch_request_token(self, oauth_request):
485         """-> OAuthToken."""
486         raise NotImplementedError
487
488     def fetch_access_token(self, oauth_request):
489         """-> OAuthToken."""
490         raise NotImplementedError
491
492     def access_resource(self, oauth_request):
493         """-> Some protected resource."""
494         raise NotImplementedError
495
496
497 class OAuthDataStore(object):
498     """A database abstraction used to lookup consumers and tokens."""
499
500     def lookup_consumer(self, key):
501         """-> OAuthConsumer."""
502         raise NotImplementedError
503
504     def lookup_token(self, oauth_consumer, token_type, token_token):
505         """-> OAuthToken."""
506         raise NotImplementedError
507
508     def lookup_nonce(self, oauth_consumer, oauth_token, nonce):
509         """-> OAuthToken."""
510         raise NotImplementedError
511
512     def fetch_request_token(self, oauth_consumer):
513         """-> OAuthToken."""
514         raise NotImplementedError
515
516     def fetch_access_token(self, oauth_consumer, oauth_token):
517         """-> OAuthToken."""
518         raise NotImplementedError
519
520     def authorize_request_token(self, oauth_token, user):
521         """-> OAuthToken."""
522         raise NotImplementedError
523
524
525 class OAuthSignatureMethod(object):
526     """A strategy class that implements a signature method."""
527     def get_name(self):
528         """-> str."""
529         raise NotImplementedError
530
531     def build_signature_base_string(self, oauth_request, oauth_consumer, oauth_token):
532         """-> str key, str raw."""
533         raise NotImplementedError
534
535     def build_signature(self, oauth_request, oauth_consumer, oauth_token):
536         """-> str."""
537         raise NotImplementedError
538
539     def check_signature(self, oauth_request, consumer, token, signature):
540         built = self.build_signature(oauth_request, consumer, token)
541         return built == signature
542
543
544 class OAuthSignatureMethod_HMAC_SHA1(OAuthSignatureMethod):
545
546     def get_name(self):
547         return 'HMAC-SHA1'
548
549     def build_signature_base_string(self, oauth_request, consumer, token):
550         sig = (
551             escape(oauth_request.get_normalized_http_method()),
552             escape(oauth_request.get_normalized_http_url()),
553             escape(oauth_request.get_normalized_parameters()),
554         )
555
556         key = '%s&' % escape(consumer.secret)
557         if token:
558             key += escape(token.secret)
559         raw = '&'.join(sig)
560         return key, raw
561
562     def build_signature(self, oauth_request, consumer, token):
563         """Builds the base signature string."""
564         key, raw = self.build_signature_base_string(oauth_request, consumer,
565             token)
566
567         # HMAC object.
568         try:
569             import hashlib # 2.5
570             hashed = hmac.new(key, raw, hashlib.sha1)
571         except:
572             import sha # Deprecated
573             hashed = hmac.new(key, raw, sha)
574
575         # Calculate the digest base 64.
576         return binascii.b2a_base64(hashed.digest())[:-1]
577
578
579 class OAuthSignatureMethod_PLAINTEXT(OAuthSignatureMethod):
580
581     def get_name(self):
582         return 'PLAINTEXT'
583
584     def build_signature_base_string(self, oauth_request, consumer, token):
585         """Concatenates the consumer key and secret."""
586         sig = '%s&' % escape(consumer.secret)
587         if token:
588             sig = sig + escape(token.secret)
589         return sig, sig
590
591     def build_signature(self, oauth_request, consumer, token):
592         key, raw = self.build_signature_base_string(oauth_request, consumer,
593             token)
594         return key