]> git.openstreetmap.org Git - osqa.git/blob - forum/templatetags/smart_if.py
initial import
[osqa.git] / forum / templatetags / smart_if.py
1 """
2 A smarter {% if %} tag for django templates.
3
4 While retaining current Django functionality, it also handles equality,
5 greater than and less than operators. Some common case examples::
6
7     {% if articles|length >= 5 %}...{% endif %}
8     {% if "ifnotequal tag" != "beautiful" %}...{% endif %}
9 """
10 import unittest
11 from django import template
12
13
14 register = template.Library()
15
16
17 #==============================================================================
18 # Calculation objects
19 #==============================================================================
20
21 class BaseCalc(object):
22     def __init__(self, var1, var2=None, negate=False):
23         self.var1 = var1
24         self.var2 = var2
25         self.negate = negate
26
27     def resolve(self, context):
28         try:
29             var1, var2 = self.resolve_vars(context)
30             outcome = self.calculate(var1, var2)
31         except:
32             outcome = False
33         if self.negate:
34             return not outcome
35         return outcome
36
37     def resolve_vars(self, context):
38         var2 = self.var2 and self.var2.resolve(context)
39         return self.var1.resolve(context), var2
40
41     def calculate(self, var1, var2):
42         raise NotImplementedError()
43
44
45 class Or(BaseCalc):
46     def calculate(self, var1, var2):
47         return var1 or var2
48
49
50 class And(BaseCalc):
51     def calculate(self, var1, var2):
52         return var1 and var2
53
54
55 class Equals(BaseCalc):
56     def calculate(self, var1, var2):
57         return var1 == var2
58
59
60 class Greater(BaseCalc):
61     def calculate(self, var1, var2):
62         return var1 > var2
63
64
65 class GreaterOrEqual(BaseCalc):
66     def calculate(self, var1, var2):
67         return var1 >= var2
68
69
70 class In(BaseCalc):
71     def calculate(self, var1, var2):
72         return var1 in var2
73
74
75 #==============================================================================
76 # Tests
77 #==============================================================================
78
79 class TestVar(object):
80     """
81     A basic self-resolvable object similar to a Django template variable. Used
82     to assist with tests.
83     """
84     def __init__(self, value):
85         self.value = value
86
87     def resolve(self, context):
88         return self.value
89
90
91 class SmartIfTests(unittest.TestCase):
92     def setUp(self):
93         self.true = TestVar(True)
94         self.false = TestVar(False)
95         self.high = TestVar(9000)
96         self.low = TestVar(1)
97
98     def assertCalc(self, calc, context=None):
99         """
100         Test a calculation is True, also checking the inverse "negate" case.
101         """
102         context = context or {}
103         self.assert_(calc.resolve(context))
104         calc.negate = not calc.negate
105         self.assertFalse(calc.resolve(context))
106
107     def assertCalcFalse(self, calc, context=None):
108         """
109         Test a calculation is False, also checking the inverse "negate" case.
110         """
111         context = context or {}
112         self.assertFalse(calc.resolve(context))
113         calc.negate = not calc.negate
114         self.assert_(calc.resolve(context))
115
116     def test_or(self):
117         self.assertCalc(Or(self.true))
118         self.assertCalcFalse(Or(self.false))
119         self.assertCalc(Or(self.true, self.true))
120         self.assertCalc(Or(self.true, self.false))
121         self.assertCalc(Or(self.false, self.true))
122         self.assertCalcFalse(Or(self.false, self.false))
123
124     def test_and(self):
125         self.assertCalc(And(self.true, self.true))
126         self.assertCalcFalse(And(self.true, self.false))
127         self.assertCalcFalse(And(self.false, self.true))
128         self.assertCalcFalse(And(self.false, self.false))
129
130     def test_equals(self):
131         self.assertCalc(Equals(self.low, self.low))
132         self.assertCalcFalse(Equals(self.low, self.high))
133
134     def test_greater(self):
135         self.assertCalc(Greater(self.high, self.low))
136         self.assertCalcFalse(Greater(self.low, self.low))
137         self.assertCalcFalse(Greater(self.low, self.high))
138
139     def test_greater_or_equal(self):
140         self.assertCalc(GreaterOrEqual(self.high, self.low))
141         self.assertCalc(GreaterOrEqual(self.low, self.low))
142         self.assertCalcFalse(GreaterOrEqual(self.low, self.high))
143
144     def test_in(self):
145         list_ = TestVar([1,2,3])
146         invalid_list = TestVar(None)
147         self.assertCalc(In(self.low, list_))
148         self.assertCalcFalse(In(self.low, invalid_list))
149
150     def test_parse_bits(self):
151         var = IfParser([True]).parse()
152         self.assert_(var.resolve({}))
153         var = IfParser([False]).parse()
154         self.assertFalse(var.resolve({}))
155
156         var = IfParser([False, 'or', True]).parse()
157         self.assert_(var.resolve({}))
158
159         var = IfParser([False, 'and', True]).parse()
160         self.assertFalse(var.resolve({}))
161
162         var = IfParser(['not', False, 'and', 'not', False]).parse()
163         self.assert_(var.resolve({}))
164
165         var = IfParser(['not', 'not', True]).parse()
166         self.assert_(var.resolve({}))
167
168         var = IfParser([1, '=', 1]).parse()
169         self.assert_(var.resolve({}))
170
171         var = IfParser([1, 'not', '=', 1]).parse()
172         self.assertFalse(var.resolve({}))
173
174         var = IfParser([1, 'not', 'not', '=', 1]).parse()
175         self.assert_(var.resolve({}))
176
177         var = IfParser([1, '!=', 1]).parse()
178         self.assertFalse(var.resolve({}))
179
180         var = IfParser([3, '>', 2]).parse()
181         self.assert_(var.resolve({}))
182
183         var = IfParser([1, '<', 2]).parse()
184         self.assert_(var.resolve({}))
185
186         var = IfParser([2, 'not', 'in', [2, 3]]).parse()
187         self.assertFalse(var.resolve({}))
188
189         var = IfParser([1, 'or', 1, '=', 2]).parse()
190         self.assert_(var.resolve({}))
191
192     def test_boolean(self):
193         var = IfParser([True, 'and', True, 'and', True]).parse()
194         self.assert_(var.resolve({}))
195         var = IfParser([False, 'or', False, 'or', True]).parse()
196         self.assert_(var.resolve({}))
197         var = IfParser([True, 'and', False, 'or', True]).parse()
198         self.assert_(var.resolve({}))
199         var = IfParser([False, 'or', True, 'and', True]).parse()
200         self.assert_(var.resolve({}))
201
202         var = IfParser([True, 'and', True, 'and', False]).parse()
203         self.assertFalse(var.resolve({}))
204         var = IfParser([False, 'or', False, 'or', False]).parse()
205         self.assertFalse(var.resolve({}))
206         var = IfParser([False, 'or', True, 'and', False]).parse()
207         self.assertFalse(var.resolve({}))
208         var = IfParser([False, 'and', True, 'or', False]).parse()
209         self.assertFalse(var.resolve({}))
210
211     def test_invalid(self):
212         self.assertRaises(ValueError, IfParser(['not']).parse)
213         self.assertRaises(ValueError, IfParser(['==']).parse)
214         self.assertRaises(ValueError, IfParser([1, 'in']).parse)
215         self.assertRaises(ValueError, IfParser([1, '>', 'in']).parse)
216         self.assertRaises(ValueError, IfParser([1, '==', 'not', 'not']).parse)
217         self.assertRaises(ValueError, IfParser([1, 2]).parse)
218
219
220 OPERATORS = {
221     '=': (Equals, True),
222     '==': (Equals, True),
223     '!=': (Equals, False),
224     '>': (Greater, True),
225     '>=': (GreaterOrEqual, True),
226     '<=': (Greater, False),
227     '<': (GreaterOrEqual, False),
228     'or': (Or, True),
229     'and': (And, True),
230     'in': (In, True),
231 }
232 BOOL_OPERATORS = ('or', 'and')
233
234
235 class IfParser(object):
236     error_class = ValueError
237
238     def __init__(self, tokens):
239         self.tokens = tokens
240
241     def _get_tokens(self):
242         return self._tokens
243
244     def _set_tokens(self, tokens):
245         self._tokens = tokens
246         self.len = len(tokens)
247         self.pos = 0
248
249     tokens = property(_get_tokens, _set_tokens)
250
251     def parse(self):
252         if self.at_end():
253             raise self.error_class('No variables provided.')
254         var1 = self.get_bool_var()
255         while not self.at_end():
256             op, negate = self.get_operator()
257             var2 = self.get_bool_var()
258             var1 = op(var1, var2, negate=negate)
259         return var1
260
261     def get_token(self, eof_message=None, lookahead=False):
262         negate = True
263         token = None
264         pos = self.pos
265         while token is None or token == 'not':
266             if pos >= self.len:
267                 if eof_message is None:
268                     raise self.error_class()
269                 raise self.error_class(eof_message)
270             token = self.tokens[pos]
271             negate = not negate
272             pos += 1
273         if not lookahead:
274             self.pos = pos
275         return token, negate
276
277     def at_end(self):
278         return self.pos >= self.len
279
280     def create_var(self, value):
281         return TestVar(value)
282
283     def get_bool_var(self):
284         """
285         Returns either a variable by itself or a non-boolean operation (such as
286         ``x == 0`` or ``x < 0``).
287
288         This is needed to keep correct precedence for boolean operations (i.e.
289         ``x or x == 0`` should be ``x or (x == 0)``, not ``(x or x) == 0``).
290         """
291         var = self.get_var()
292         if not self.at_end():
293             op_token = self.get_token(lookahead=True)[0]
294             if isinstance(op_token, basestring) and (op_token not in
295                                                      BOOL_OPERATORS):
296                 op, negate = self.get_operator()
297                 return op(var, self.get_var(), negate=negate)
298         return var
299
300     def get_var(self):
301         token, negate = self.get_token('Reached end of statement, still '
302                                        'expecting a variable.')
303         if isinstance(token, basestring) and token in OPERATORS:
304             raise self.error_class('Expected variable, got operator (%s).' %
305                                    token)
306         var = self.create_var(token)
307         if negate:
308             return Or(var, negate=True)
309         return var
310
311     def get_operator(self):
312         token, negate = self.get_token('Reached end of statement, still '
313                                        'expecting an operator.')
314         if not isinstance(token, basestring) or token not in OPERATORS:
315             raise self.error_class('%s is not a valid operator.' % token)
316         if self.at_end():
317             raise self.error_class('No variable provided after "%s".' % token)
318         op, true = OPERATORS[token]
319         if not true:
320             negate = not negate
321         return op, negate
322
323
324 #==============================================================================
325 # Actual templatetag code.
326 #==============================================================================
327
328 class TemplateIfParser(IfParser):
329     error_class = template.TemplateSyntaxError
330
331     def __init__(self, parser, *args, **kwargs):
332         self.template_parser = parser
333         return super(TemplateIfParser, self).__init__(*args, **kwargs)
334
335     def create_var(self, value):
336         return self.template_parser.compile_filter(value)
337
338
339 class SmartIfNode(template.Node):
340     def __init__(self, var, nodelist_true, nodelist_false=None):
341         self.nodelist_true, self.nodelist_false = nodelist_true, nodelist_false
342         self.var = var
343
344     def render(self, context):
345         if self.var.resolve(context):
346             return self.nodelist_true.render(context)
347         if self.nodelist_false:
348             return self.nodelist_false.render(context)
349         return ''
350
351     def __repr__(self):
352         return "<Smart If node>"
353
354     def __iter__(self):
355         for node in self.nodelist_true:
356             yield node
357         if self.nodelist_false:
358             for node in self.nodelist_false:
359                 yield node
360
361     def get_nodes_by_type(self, nodetype):
362         nodes = []
363         if isinstance(self, nodetype):
364             nodes.append(self)
365         nodes.extend(self.nodelist_true.get_nodes_by_type(nodetype))
366         if self.nodelist_false:
367             nodes.extend(self.nodelist_false.get_nodes_by_type(nodetype))
368         return nodes
369
370
371 @register.tag('if')
372 def smart_if(parser, token):
373     """
374     A smarter {% if %} tag for django templates.
375
376     While retaining current Django functionality, it also handles equality,
377     greater than and less than operators. Some common case examples::
378
379         {% if articles|length >= 5 %}...{% endif %}
380         {% if "ifnotequal tag" != "beautiful" %}...{% endif %}
381
382     Arguments and operators _must_ have a space between them, so
383     ``{% if 1>2 %}`` is not a valid smart if tag.
384
385     All supported operators are: ``or``, ``and``, ``in``, ``=`` (or ``==``),
386     ``!=``, ``>``, ``>=``, ``<`` and ``<=``.
387     """
388     bits = token.split_contents()[1:]
389     var = TemplateIfParser(parser, bits).parse()
390     nodelist_true = parser.parse(('else', 'endif'))
391     token = parser.next_token()
392     if token.contents == 'else':
393         nodelist_false = parser.parse(('endif',))
394         parser.delete_first_token()
395     else:
396         nodelist_false = None
397     return SmartIfNode(var, nodelist_true, nodelist_false)
398
399
400 if __name__ == '__main__':
401     unittest.main()