Skip to content

Commit

Permalink
Add per-view force_https
Browse files Browse the repository at this point in the history
  • Loading branch information
Jon Wayne Parrott committed Mar 2, 2017
1 parent e61e8f3 commit 011fe58
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 16 deletions.
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ Per-view options
~~~~~~~~~~~~~~~~

Sometimes you want to change the policy for a specific view. The
``frame_options``, ``frame_options_allow_from``, and
``force_https``, ``frame_options``, ``frame_options_allow_from``, and
``content_security_policy`` options can be changed on a per-view basis.

.. code:: python
Expand Down
45 changes: 31 additions & 14 deletions flask_talisman/talisman.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,27 @@ def init_app(
app.before_request(self._force_https)
app.after_request(self._set_response_headers)

def _update_local_options(
self,
frame_options=_sentinel,
frame_options_allow_from=_sentinel,
content_security_policy=_sentinel):
def _update_local_options(self):
"""Updates view-local options with defaults or specified values."""
view_function = flask.current_app.view_functions[
flask.request.endpoint]

view_options = getattr(
view_function, 'talisman_view_options', {})

print(type(flask.request.endpoint))
print(flask.request.endpoint)
print(flask.request.endpoint.__hash__())
print(view_options)
force_https = view_options.get('force_https', _sentinel)
frame_options = view_options.get('frame_options', _sentinel)
frame_options_allow_from = view_options.get(
'frame_options_allow_from', _sentinel)
content_security_policy = view_options.get(
'content_security_policy', _sentinel)
setattr(self.local_options, 'force_https',
force_https if force_https is not _sentinel
else self.force_https)
setattr(self.local_options, 'frame_options',
frame_options if frame_options is not _sentinel
else self.frame_options)
Expand All @@ -179,7 +194,7 @@ def _force_https(self):
flask.request.headers.get('X-Forwarded-Proto', 'http') == 'https',
]

if self.force_https and not any(criteria):
if self.local_options.force_https and not any(criteria):
if flask.request.url.startswith('http://'):
url = flask.request.url.replace('http://', 'https://', 1)
code = 302
Expand Down Expand Up @@ -253,6 +268,7 @@ def _set_hsts_headers(self, headers):

def __call__(
self,
force_https=_sentinel,
frame_options=_sentinel,
frame_options_allow_from=_sentinel,
content_security_policy=_sentinel):
Expand All @@ -277,12 +293,13 @@ def embeddable():
return 'Embeddable'
"""
def decorator(f):
@wraps(f)
def decorated_function(*args, **kwargs):
self._update_local_options(
frame_options=frame_options,
frame_options_allow_from=frame_options_allow_from,
content_security_policy=content_security_policy)
return f(*args, **kwargs)
return decorated_function
setattr(f, 'talisman_view_options', dict(
force_https=force_https,
frame_options=frame_options,
frame_options_allow_from=frame_options_allow_from,
content_security_policy=content_security_policy))
print(f.talisman_view_options)
print(f)
print(f.__hash__())
return f
return decorator
10 changes: 9 additions & 1 deletion flask_talisman/talisman_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ def testContentSecurityPolicyOptionsReport(self):
content_security_policy_report_only=True)

def testDecorator(self):

@self.app.route('/nocsp')
@self.talisman(content_security_policy=None)
def nocsp():
Expand All @@ -179,6 +178,15 @@ def nocsp():
self.assertFalse('Content-Security-Policy' in response.headers)
self.assertEqual(response.headers['X-Frame-Options'], 'SAMEORIGIN')

def testDecoratorForceHttps(self):
@self.app.route('/noforcehttps')
@self.talisman(force_https=False)
def noforcehttps():
return 'Hello, world'

response = self.client.get('/noforcehttps')
self.assertEqual(response.status_code, 200)

def testForceFileSave(self):
self.talisman.force_file_save = True
response = self.client.get('/', environ_overrides=HTTPS_ENVIRON)
Expand Down

0 comments on commit 011fe58

Please sign in to comment.