From d7298dc1b49ed3c31c6f46b105d5ee8ed5bdfb71 Mon Sep 17 00:00:00 2001 From: froosty Date: Mon, 7 May 2018 17:58:20 +0300 Subject: [PATCH] Init library --- .gitignore | 28 +++ .isort.cfg | 2 + .travis.yml | 20 +++ LICENCE | 21 +++ MANIFEST.in | 4 + README.rst | 262 +++++++++++++++++++++++++++- aiohttp_csrf/__init__.py | 175 +++++++++++++++++++ aiohttp_csrf/policy.py | 55 ++++++ aiohttp_csrf/storage.py | 116 ++++++++++++ aiohttp_csrf/token_generator.py | 30 ++++ demo/manual_protection.py | 68 ++++++++ demo/middleware.py | 106 +++++++++++ demo/session_storage.py | 113 ++++++++++++ pytest.ini | 2 + requirements_dev.txt | 9 + setup.py | 70 ++++++++ test.py | 35 ++++ tests/__init__.py | 0 tests/conftest.py | 69 ++++++++ tests/test_custom_error_renderer.py | 99 +++++++++++ tests/test_errors.py | 77 ++++++++ tests/test_exempt_decorator.py | 55 ++++++ tests/test_form_policy.py | 158 +++++++++++++++++ tests/test_header_policy.py | 111 ++++++++++++ tests/test_protect_decorator.py | 136 +++++++++++++++ tests/test_storage_api.py | 65 +++++++ tests/test_token_generator.py | 37 ++++ tox.ini | 23 +++ 28 files changed, 1945 insertions(+), 1 deletion(-) create mode 100644 .gitignore create mode 100644 .isort.cfg create mode 100644 .travis.yml create mode 100644 LICENCE create mode 100644 MANIFEST.in create mode 100644 aiohttp_csrf/__init__.py create mode 100644 aiohttp_csrf/policy.py create mode 100644 aiohttp_csrf/storage.py create mode 100644 aiohttp_csrf/token_generator.py create mode 100644 demo/manual_protection.py create mode 100644 demo/middleware.py create mode 100644 demo/session_storage.py create mode 100644 pytest.ini create mode 100644 requirements_dev.txt create mode 100644 setup.py create mode 100644 test.py create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/test_custom_error_renderer.py create mode 100644 tests/test_errors.py create mode 100644 tests/test_exempt_decorator.py create mode 100644 tests/test_form_policy.py create mode 100644 tests/test_header_policy.py create mode 100644 tests/test_protect_decorator.py create mode 100644 tests/test_storage_api.py create mode 100644 tests/test_token_generator.py create mode 100644 tox.ini diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3ae2f48 --- /dev/null +++ b/.gitignore @@ -0,0 +1,28 @@ +# python specific +env* +.cache/ +.pytest_cache/ +.idea/ +*.pyc +*.so +*.pyd +aiohttp_csrf.egg-info +build/* +dist/* +MANIFEST +__pycache__/ +*.egg-info/ +.coverage +.python-version +htmlcov + +# generic files to ignore +*~ +*.lock +*.DS_Store +*.swp +*.out + +.tox/ +deps/ +docs/_build/ diff --git a/.isort.cfg b/.isort.cfg new file mode 100644 index 0000000..47f77e9 --- /dev/null +++ b/.isort.cfg @@ -0,0 +1,2 @@ +[settings] +known_third_party=aiohttp_csrf diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..53c1bfd --- /dev/null +++ b/.travis.yml @@ -0,0 +1,20 @@ +dist: trusty +language: python +python: + - "3.5" + - "3.6" +install: + - pip install -U setuptools + - pip install -U pip + - pip install -U wheel + - pip install -U tox +script: + - export TOXENV=py`python -c 'import sys; print("".join(map(str, sys.version_info[:2])))'` + - echo "$TOXENV" + + - tox +cache: + directories: + - $HOME/.cache/pip +notifications: + email: false diff --git a/LICENCE b/LICENCE new file mode 100644 index 0000000..22f6914 --- /dev/null +++ b/LICENCE @@ -0,0 +1,21 @@ +The MIT License + +Copyright (c) Ocean S.A. https://ocean.io/ + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..e24206f --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,4 @@ +include README.rst +include LICENSE +recursive-exclude * __pycache__ +recursive-exclude * *.py[co] diff --git a/README.rst b/README.rst index b779bd2..4ebe07c 100644 --- a/README.rst +++ b/README.rst @@ -1 +1,261 @@ -# aiohttp-csrf \ No newline at end of file +aiohttp_csrf +============ + +The library provides csrf (xsrf) protection for `aiohttp.web`__. + +.. _aiohttp_web: https://docs.aiohttp.org/en/latest/web.html + +__ aiohttp_web_ + +.. image:: https://img.shields.io/travis/wikibusiness/aiohttp-csrf.svg + :target: https://travis-ci.org/wikibusiness/aiohttp-csrf + +Basic usage +----------- + +The library allows you to implement csrf (xsrf) protection for requests + + +Basic usage example: + +.. code-block:: python + + import aiohttp_csrf + from aiohttp import web + + FORM_FIELD_NAME = '_csrf_token' + COOKIE_NAME = 'csrf_token' + + + def make_app(): + csrf_policy = aiohttp_csrf.policy.FormPolicy(FORM_FIELD_NAME) + + csrf_storage = aiohttp_csrf.storage.CookieStorage(COOKIE_NAME) + + app = web.Application() + + aiohttp_csrf.setup(app, policy=csrf_policy, storage=csrf_storage) + + app.middlewares.append(aiohttp_csrf.csrf_middleware) + + async def handler_get_form_with_token(request): + token = await aiohttp_csrf.generate_token(request) + + + body = ''' + + Form with csrf protection + +
+ + + +
+ + + ''' # noqa + + body = body.format(field_name=FORM_FIELD_NAME, token=token) + + return web.Response( + body=body.encode('utf-8'), + content_type='text/html', + ) + + async def handler_post_check(request): + post = await request.post() + + body = 'Hello, {name}'.format(name=post['name']) + + return web.Response( + body=body.encode('utf-8'), + content_type='text/html', + ) + + app.router.add_route( + 'GET', + '/', + handler_get_form_with_token, + ) + + app.router.add_route( + 'POST', + '/', + handler_post_check, + ) + + return app + + + web.run_app(make_app()) + + +Initialize +~~~~~~~~~~ + + +First of all, you need to initialize ``aiohttp_csrf`` in your application: + +.. code-block:: python + + app = web.Application() + + csrf_policy = aiohttp_csrf.policy.FormPolicy(FORM_FIELD_NAME) + + csrf_storage = aiohttp_csrf.storage.CookieStorage(COOKIE_NAME) + + aiohttp_csrf.setup(app, policy=csrf_policy, storage=csrf_storage) + + +Middleware and decorators +~~~~~~~~~~~~~~~~~~~~~~~~~ + + +After initialize you can use ``@aiohttp_csrf.csrf_protect`` for handlers, that you want to protect. +Or you can initialize ``aiohttp_csrf.csrf_middleware`` and do not disturb about using decorator (`full middleware example here`_): + +.. _full middleware example here: demo/middleware.py + +.. code-block:: python + + ... + app.middlewares.append(aiohttp_csrf.csrf_middleware) + ... + + +In this case all your handlers will be protected. + + +**Note:** we strongly recommend to use ``aiohttp_csrf.csrf_middleware`` and ``@aiohttp_csrf.csrf_exempt`` instead of manually managing with ``@aiohttp_csrf.csrf_protect``. +But if you prefer to use ``@aiohttp_csrf.csrf_protect``, don't forget to use ``@aiohttp_csrf.csrf_protect`` for both methods: GET and POST +(`manual protection example`_) + +.. _manual protection example: demo/manual_protection.py + + +If you want to use middleware, but need handlers without protection, you can use ``@aiohttp_csrf.csrf_exempt``. +Mark you handler with this decorator and this handler will not check the token: + +.. code-block:: python + + @aiohttp_csrf.csrf_exempt + async def handler_post_not_check(request): + ... + + + +Generate token +~~~~~~~~~~~~~~ + +For generate token you need to call ``aiohttp_csrf.generate_token`` in your handler: + +.. code-block:: python + + @aiohttp_csrf.csrf_protect + async def handler_get(request): + token = await aiohttp_csrf.generate_token(request) + ... + + +Advanced usage +-------------- + + +Policies +~~~~~~~~ + +You can use different policies for check tokens. Library provides 3 types of policy: + +- **FormPolicy**. This policy will search token in the body of your POST request (Usually use for forms). You need to specify name of field that will be checked. +- **HeaderPolicy**. This policy will search token in headers of your POST request (Usually use for AJAX requests). You need to specify name of header that will be checked. +- **FormAndHeaderPolicy**. This policy combines behavior of **FormPolicy** and **HeaderPolicy**. + +You can implement your custom policies if needed. But make sure that your custom policy implements ``aiohttp_csrf.policy.AbstractPolicy`` interface. + +Storages +~~~~~~~~ + +You can use different types of storages for storing token. Library provides 2 types of storage: + +- **CookieStorage**. Your token will be stored in cookie variable. You need to specify cookie name. +- **SessionStorage**. Your token will be stored in session. You need to specify session variable name. + +**Important:** If you want to use session storage, you need setup aiohttp_session in your application +(`session storage example`_) + +.. _session storage example: demo/session_storage.py#L22 + +You can implement your custom storages if needed. But make sure that your custom storage implements ``aiohttp_csrf.storage.AbstractStorage`` interface. + + +Token generators +~~~~~~~~~~~~~~~~ + +You can use different token generator in your application. +By default storages using ``aiohttp_csrf.token_generator.SimpleTokenGenerator`` + +But if you need more secure token generator - you can use ``aiohttp_csrf.token_generator.HashedTokenGenerator`` + +And you can implement your custom token generators if needed. But make sure that your custom token generator implements ``aiohttp_csrf.token_generator.AbstractTokenGenerator`` interface. + + +Invalid token behavior +~~~~~~~~~~~~~~~~~~~~~~ + +By default, if token is invalid, ``aiohttp_csrf`` will raise ``aiohttp.web.HTTPForbidden`` exception. + +You have abbility to specify your custom error handler. It can be: + +- **callable instance**. Input parameter - aiohttp request. +.. code-block:: python + + def custom_error_handler(request): + # do something + return aiohttp.web.Response(status=403) + + # or + + async def custom_async_error_handler(request): + # await do something + return aiohttp.web.Response(status=403) + +It will be called instead of protected handler. + +- **sub class of Exception**. In this case this Exception will be raised. + +.. code-block:: python + + class CustomException(Exception): + pass + + +You can specify custom error handler globally, when initialize ``aiohttp_csrf`` in your application: + +.. code-block:: python + + ... + class CustomException(Exception): + pass + + ... + aiohttp_csrf.setup(app, policy=csrf_policy, storage=csrf_storage, error_renderer=CustomException) + ... + +In this case custom error handler will be applied to all protected handlers. + +Or you can specify custom error handler locally, for specific handler: + +.. code-block:: python + + ... + class CustomException(Exception): + pass + + ... + @aiohttp_csrf.csrf_protect(error_renderer=CustomException) + def handler_with_custom_csrf_error(request): + ... + + +In this case custom error handler will be applied to this handler only. +For all other handlers will be applied global error handler. diff --git a/aiohttp_csrf/__init__.py b/aiohttp_csrf/__init__.py new file mode 100644 index 0000000..7a4c361 --- /dev/null +++ b/aiohttp_csrf/__init__.py @@ -0,0 +1,175 @@ +import asyncio +import inspect + +from functools import wraps + +from aiohttp import web + +from .policy import AbstractPolicy +from .storage import AbstractStorage + + +__version__ = '0.0.1' + +APP_POLICY_KEY = 'aiohttp_csrf_policy' +APP_STORAGE_KEY = 'aiohttp_csrf_storage' +APP_ERROR_RENDERER_KEY = 'aiohttp_csrf_error_renderer' + +MIDDLEWARE_SKIP_PROPERTY = 'csrf_middleware_skip' + +UNPROTECTED_HTTP_METHODS = ('GET', 'HEAD', 'OPTIONS', 'TRACE') + + +def setup(app, *, policy, storage, error_renderer=web.HTTPForbidden): + if not isinstance(policy, AbstractPolicy): + raise TypeError('Policy must be instance of AbstractPolicy') + + if not isinstance(storage, AbstractStorage): + raise TypeError('Storage must be instance of AbstractStorage') + + if not isinstance(error_renderer, Exception) and not callable(error_renderer): # noqa + raise TypeError( + 'Default error renderer must be instance of Exception or callable.' + ) + + app[APP_POLICY_KEY] = policy + app[APP_STORAGE_KEY] = storage + app[APP_ERROR_RENDERER_KEY] = error_renderer + + +def _get_policy(request): + try: + return request.app[APP_POLICY_KEY] + except KeyError: + raise RuntimeError( + 'Policy not found. Install aiohttp_csrf in your ' + 'aiohttp.web.Application using aiohttp_csrf.setup()' + ) + + +def _get_storage(request): + try: + return request.app[APP_STORAGE_KEY] + except KeyError: + raise RuntimeError( + 'Storage not found. Install aiohttp_csrf in your ' + 'aiohttp.web.Application using aiohttp_csrf.setup()' + ) + + +async def _render_error(request, error_renderer=None): + if error_renderer is None: + try: + error_renderer = request.app[APP_ERROR_RENDERER_KEY] + except KeyError: + raise RuntimeError( + 'Default error renderer not found. Install aiohttp_csrf in ' + 'your aiohttp.web.Application using aiohttp_csrf.setup()' + ) + + if inspect.isclass(error_renderer) and issubclass(error_renderer, Exception): # noqa + raise error_renderer + elif callable(error_renderer): + if asyncio.iscoroutinefunction(error_renderer): + return await error_renderer(request) + else: + return error_renderer(request) + else: + raise NotImplementedError + + +async def get_token(request): + storage = _get_storage(request) + + return await storage.get(request) + + +async def generate_token(request): + storage = _get_storage(request) + + return await storage.generate_new_token(request) + + +async def save_token(request, response): + storage = _get_storage(request) + + await storage.save_token(request, response) + + +def csrf_exempt(handler): + @wraps(handler) + def wrapped_handler(*args, **kwargs): + return handler(*args, **kwargs) + + setattr(wrapped_handler, MIDDLEWARE_SKIP_PROPERTY, True) + + return wrapped_handler + + +async def _check(request): + if not isinstance(request, web.Request): + raise RuntimeError('Can\'t get request from handler params') + + original_token = await get_token(request) + + policy = _get_policy(request) + + return await policy.check(request, original_token) + + +def csrf_protect(handler=None, error_renderer=None): + if ( + error_renderer is not None + and not isinstance(error_renderer, Exception) + and not callable(error_renderer) + ): + raise TypeError( + 'Renderer must be instance of Exception or callable.' + ) + + def wrapper(handler): + @wraps(handler) + async def wrapped(*args, **kwargs): + request = args[-1] + + if isinstance(request, web.View): + request = request.request + + if ( + request.method not in UNPROTECTED_HTTP_METHODS + and not await _check(request) + ): + return await _render_error(request, error_renderer) + + raise_response = False + + try: + response = await handler(*args, **kwargs) + except web.HTTPException as exc: + response = exc + raise_response = True + + if isinstance(response, web.Response): + await save_token(request, response) + + if raise_response: + raise response + + return response + + setattr(wrapped, MIDDLEWARE_SKIP_PROPERTY, True) + + return wrapped + + if handler is None: + return wrapper + + return wrapper(handler) + + +@web.middleware +async def csrf_middleware(request, handler): + if not getattr(handler, MIDDLEWARE_SKIP_PROPERTY, False): + handler = csrf_protect(handler=handler) + + return await handler(request) diff --git a/aiohttp_csrf/policy.py b/aiohttp_csrf/policy.py new file mode 100644 index 0000000..1700e1d --- /dev/null +++ b/aiohttp_csrf/policy.py @@ -0,0 +1,55 @@ +import abc + + +class AbstractPolicy(metaclass=abc.ABCMeta): + @abc.abstractmethod + async def check(self, request, original_value): + pass # pragma: no cover + + +class FormPolicy(AbstractPolicy): + + def __init__(self, field_name): + self.field_name = field_name + + async def check(self, request, original_value): + post = await request.post() + + token = post.get(self.field_name) + + return token == original_value + + +class HeaderPolicy(AbstractPolicy): + + def __init__(self, header_name): + self.header_name = header_name + + async def check(self, request, original_value): + token = request.headers.get(self.header_name) + + return token == original_value + + +class FormAndHeaderPolicy(HeaderPolicy, FormPolicy): + + def __init__(self, header_name, field_name): + self.header_name = header_name + self.field_name = field_name + + async def check(self, request, original_value): + header_check = await HeaderPolicy.check( + self, + request, + original_value, + ) + + if header_check: + return True + + form_check = await FormPolicy.check(self, request, original_value) + + if form_check: + return True + + return False diff --git a/aiohttp_csrf/storage.py b/aiohttp_csrf/storage.py new file mode 100644 index 0000000..c77a227 --- /dev/null +++ b/aiohttp_csrf/storage.py @@ -0,0 +1,116 @@ +import abc + +from .token_generator import AbstractTokenGenerator, SimpleTokenGenerator + +try: + from aiohttp_session import get_session +except ImportError: # pragma: no cover + pass + + +REQUEST_NEW_TOKEN_KEY = 'aiohttp_csrf_new_token' + + +class AbstractStorage(metaclass=abc.ABCMeta): + + @abc.abstractmethod + async def generate_new_token(self, request): + pass # pragma: no cover + + @abc.abstractmethod + async def get(self, request): + pass # pragma: no cover + + @abc.abstractmethod + async def save_token(self, request, response): + pass # pragma: no cover + + +class BaseStorage(AbstractStorage, metaclass=abc.ABCMeta): + + def __init__(self, token_generator=None): + if token_generator is None: + token_generator = SimpleTokenGenerator() + elif not isinstance(token_generator, AbstractTokenGenerator): + raise TypeError( + 'Token generator must be instance of AbstractTokenGenerator', + ) + + self.token_generator = token_generator + + def _generate_token(self): + return self.token_generator.generate() + + async def generate_new_token(self, request): + if REQUEST_NEW_TOKEN_KEY in request: + return request[REQUEST_NEW_TOKEN_KEY] + + token = self._generate_token() + + request[REQUEST_NEW_TOKEN_KEY] = token + + return token + + @abc.abstractmethod + async def _get(self, request): + pass # pragma: no cover + + async def get(self, request): + token = await self._get(request) + + await self.generate_new_token(request) + + return token + + @abc.abstractmethod + async def _save_token(self, request, response, token): + pass # pragma: no cover + + async def save_token(self, request, response): + old_token = await self._get(request) + + if REQUEST_NEW_TOKEN_KEY in request: + token = request[REQUEST_NEW_TOKEN_KEY] + elif old_token is None: + token = await self.generate_new_token(request) + else: + token = None + + if token is not None: + await self._save_token(request, response, token) + + +class CookieStorage(BaseStorage): + + def __init__(self, cookie_name, cookie_kwargs=None, *args, **kwargs): + self.cookie_name = cookie_name + self.cookie_kwargs = cookie_kwargs or {} + + super().__init__(*args, **kwargs) + + async def _get(self, request): + return request.cookies.get(self.cookie_name, None) + + async def _save_token(self, request, response, token): + response.set_cookie( + self.cookie_name, + token, + **self.cookie_kwargs, + ) + + +class SessionStorage(BaseStorage): + def __init__(self, session_name, *args, **kwargs): + self.session_name = session_name + + super().__init__(*args, **kwargs) + + async def _get(self, request): + session = await get_session(request) + + return session.get(self.session_name, None) + + async def _save_token(self, request, response, token): + session = await get_session(request) + + session[self.session_name] = token diff --git a/aiohttp_csrf/token_generator.py b/aiohttp_csrf/token_generator.py new file mode 100644 index 0000000..dcecbb7 --- /dev/null +++ b/aiohttp_csrf/token_generator.py @@ -0,0 +1,30 @@ +import abc +import hashlib +import uuid + + +class AbstractTokenGenerator(metaclass=abc.ABCMeta): + @abc.abstractmethod + def generate(self): + pass # pragma: no cover + + +class SimpleTokenGenerator(AbstractTokenGenerator): + def generate(self): + return uuid.uuid4().hex + + +class HashedTokenGenerator(AbstractTokenGenerator): + encoding = 'utf-8' + + def __init__(self, secret_phrase): + self.secret_phrase = secret_phrase + + def generate(self): + token = uuid.uuid4().hex + + token += self.secret_phrase + + hasher = hashlib.sha256(token.encode(self.encoding)) + + return hasher.hexdigest() diff --git a/demo/manual_protection.py b/demo/manual_protection.py new file mode 100644 index 0000000..07caaa1 --- /dev/null +++ b/demo/manual_protection.py @@ -0,0 +1,68 @@ +import aiohttp_csrf +from aiohttp import web + +FORM_FIELD_NAME = '_csrf_token' +COOKIE_NAME = 'csrf_token' + + +def make_app(): + csrf_policy = aiohttp_csrf.policy.FormPolicy(FORM_FIELD_NAME) + + csrf_storage = aiohttp_csrf.storage.CookieStorage(COOKIE_NAME) + + app = web.Application() + + aiohttp_csrf.setup(app, policy=csrf_policy, storage=csrf_storage) + + # IMPORTANT! You need use @csrf_protect for both methods: GET and POST + @aiohttp_csrf.csrf_protect + async def handler_get(request): + token = await aiohttp_csrf.generate_token(request) + + body = ''' + + Form with csrf protection + +
+ + + +
+ + + ''' # noqa + + body = body.format(field_name=FORM_FIELD_NAME, token=token) + + return web.Response( + body=body.encode('utf-8'), + content_type='text/html', + ) + + @aiohttp_csrf.csrf_protect + async def handler_post(request): + post = await request.post() + + body = 'Hello, {name}'.format(name=post['name']) + + return web.Response( + body=body.encode('utf-8'), + content_type='text/html', + ) + + app.router.add_route( + 'GET', + '/', + handler_get, + ) + + app.router.add_route( + 'POST', + '/', + handler_post, + ) + + return app + + +web.run_app(make_app()) diff --git a/demo/middleware.py b/demo/middleware.py new file mode 100644 index 0000000..42220b0 --- /dev/null +++ b/demo/middleware.py @@ -0,0 +1,106 @@ +import aiohttp_csrf +from aiohttp import web + +FORM_FIELD_NAME = '_csrf_token' +COOKIE_NAME = 'csrf_token' + + +def make_app(): + csrf_policy = aiohttp_csrf.policy.FormPolicy(FORM_FIELD_NAME) + + csrf_storage = aiohttp_csrf.storage.CookieStorage(COOKIE_NAME) + + app = web.Application() + + aiohttp_csrf.setup(app, policy=csrf_policy, storage=csrf_storage) + + app.middlewares.append(aiohttp_csrf.csrf_middleware) + + async def handler_get_form_with_token(request): + token = await aiohttp_csrf.generate_token(request) + + body = ''' + + Form with csrf protection + +
+ + + +
+ + + ''' # noqa + + body = body.format(field_name=FORM_FIELD_NAME, token=token) + + return web.Response( + body=body.encode('utf-8'), + content_type='text/html', + ) + + async def handler_post_check(request): + post = await request.post() + + body = 'Hello, {name}'.format(name=post['name']) + + return web.Response( + body=body.encode('utf-8'), + content_type='text/html', + ) + + async def handler_get_form_without_token(request): + body = ''' + + Form without csrf protection + +
+ + +
+ + + ''' + + return web.Response( + body=body.encode('utf-8'), + content_type='text/html', + ) + + @aiohttp_csrf.csrf_exempt + async def handler_post_not_check(request): + post = await request.post() + + body = 'Hello, {name}'.format(name=post['name']) + + return web.Response( + body=body.encode('utf-8'), + content_type='text/html', + ) + + app.router.add_route( + 'GET', + '/form_with_check', + handler_get_form_with_token, + ) + app.router.add_route( + 'POST', + '/post_with_check', + handler_post_check, + ) + + app.router.add_route( + 'GET', + '/form_without_check', + handler_get_form_without_token, + ) + app.router.add_route( + 'POST', + '/post_without_check', + handler_post_not_check, + ) + + return app + + +web.run_app(make_app()) diff --git a/demo/session_storage.py b/demo/session_storage.py new file mode 100644 index 0000000..e27a489 --- /dev/null +++ b/demo/session_storage.py @@ -0,0 +1,113 @@ +import aiohttp_csrf +from aiohttp import web +from aiohttp_session import setup as setup_session +from aiohttp_session import SimpleCookieStorage + +FORM_FIELD_NAME = '_csrf_token' +SESSION_NAME = 'csrf_token' + + +def make_app(): + csrf_policy = aiohttp_csrf.policy.FormPolicy(FORM_FIELD_NAME) + + csrf_storage = aiohttp_csrf.storage.SessionStorage(SESSION_NAME) + + app = web.Application() + + aiohttp_csrf.setup(app, policy=csrf_policy, storage=csrf_storage) + + session_storage = SimpleCookieStorage() + + # Important!!! + setup_session(app, session_storage) + + app.middlewares.append(aiohttp_csrf.csrf_middleware) + + async def handler_get_form_with_token(request): + token = await aiohttp_csrf.generate_token(request) + + body = ''' + + Form with csrf protection + +
+ + + +
+ + + ''' # noqa + + body = body.format(field_name=FORM_FIELD_NAME, token=token) + + return web.Response( + body=body.encode('utf-8'), + content_type='text/html', + ) + + async def handler_post_check(request): + post = await request.post() + + body = 'Hello, {name}'.format(name=post['name']) + + return web.Response( + body=body.encode('utf-8'), + content_type='text/html', + ) + + async def handler_get_form_without_token(request): + body = ''' + + Form without csrf protection + +
+ + +
+ + + ''' + + return web.Response( + body=body.encode('utf-8'), + content_type='text/html', + ) + + @aiohttp_csrf.csrf_exempt + async def handler_post_not_check(request): + post = await request.post() + + body = 'Hello, {name}'.format(name=post['name']) + + return web.Response( + body=body.encode('utf-8'), + content_type='text/html', + ) + + app.router.add_route( + 'GET', + '/form_with_check', + handler_get_form_with_token, + ) + app.router.add_route( + 'POST', + '/post_with_check', + handler_post_check, + ) + + app.router.add_route( + 'GET', + '/form_without_check', + handler_get_form_without_token, + ) + app.router.add_route( + 'POST', + '/post_without_check', + handler_post_not_check, + ) + + return app + + +web.run_app(make_app()) diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..870e595 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +addopts= --no-cov-on-fail --cov=aiohttp_csrf --cov-report=term --cov-report=html diff --git a/requirements_dev.txt b/requirements_dev.txt new file mode 100644 index 0000000..f5d515e --- /dev/null +++ b/requirements_dev.txt @@ -0,0 +1,9 @@ +aiohttp==2.3.10 +-e . +aiohttp-session==1.0.1 +flake8==3.4.1 +isort==4.2.15 +pytest==3.4.0 +pytest-aiohttp==0.1.3 +pytest-cov==2.5.1 +tox==2.9.1 diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..4396b14 --- /dev/null +++ b/setup.py @@ -0,0 +1,70 @@ +import codecs +import io +import os +import re + +from setuptools import setup + + +def get_version(): + with codecs.open( + os.path.join( + os.path.abspath( + os.path.dirname(__file__), + ), + 'aiohttp_csrf', + '__init__.py', + ), + 'r', + 'utf-8', + ) as fp: + try: + return re.findall(r"^__version__ = '([^']+)'$", fp.read(), re.M)[0] + except IndexError: + raise RuntimeError('Unable to determine version.') + + +def read(*parts): + filename = os.path.join(os.path.abspath(os.path.dirname(__file__)), *parts) + + with io.open(filename, encoding='utf-8', mode='rt') as fp: + return fp.read() + + +install_requires = ['aiohttp>=3.2.0'] +extras_require = { + 'session': ['aiohttp-session>=2.4.0'], +} + + +setup( + name='aiohttp-csrf', + version=get_version(), + description=('CSRF protection for aiohttp.web',), + long_description=read('README.rst'), + author='Ocean S.A.', + author_email='osf@ocean.io', + url='https://github.com/wikibusiness/aiohttp-csrf', + packages=['aiohttp_csrf'], + include_package_data=True, + install_requires=install_requires, + extras_require=extras_require, + zip_safe=False, + classifiers=[ + 'Development Status :: 5 - Production/Stable', + 'Intended Audience :: Developers', + 'License :: OSI Approved :: MIT License', + 'Operating System :: POSIX', + 'Operating System :: MacOS :: MacOS X', + 'Operating System :: Microsoft :: Windows', + 'Programming Language :: Python', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.6', + ], + keywords=[ + 'csrf', + 'xsrf', + 'aiohttp', + ], +) diff --git a/test.py b/test.py new file mode 100644 index 0000000..beea3a1 --- /dev/null +++ b/test.py @@ -0,0 +1,35 @@ +import asyncio + +from aiohttp import web + +async def hello(request): + return web.Response(text="Hello, world") + + +def dec(handler): + def wrapped(*args, **kwargs): + request = args[-1] + import ipdb;ipdb.set_trace() + return handler(*args, **kwargs) + + return wrapped + + +class MyView(web.View): + @dec + async def get(self): + return web.Response(text="Get Hello, world") + + async def post(self): + return web.Response(text="Post Hello, world") + + +@web.middleware +async def middleware(request, handler): + return await handler(request) + + +app = web.Application(middlewares=[middleware]) +app.router.add_route('*', '/', MyView) + +web.run_app(app) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..1a2f168 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,69 @@ +import aiohttp_csrf +import pytest +from aiohttp import web + +SESSION_NAME = COOKIE_NAME = 'csrf_token' +FORM_FIELD_NAME = HEADER_NAME = 'X-CSRF-TOKEN' + + +@pytest.yield_fixture +def init_app(): + def go( + loop, + policy, + storage, + handlers, + error_renderer=None, + ): + app = web.Application() + + kwargs = { + 'policy': policy, + 'storage': storage, + } + + if error_renderer is not None: + kwargs['error_renderer'] = error_renderer + + aiohttp_csrf.setup(app, **kwargs) + + for method, url, handler in handlers: + app.router.add_route( + method, + url, + handler, + ) + + return app + + yield go + + +@pytest.fixture(params=[ + (aiohttp_csrf.policy.FormPolicy, (FORM_FIELD_NAME,)), + (aiohttp_csrf.policy.FormAndHeaderPolicy, (HEADER_NAME, FORM_FIELD_NAME)), +]) +def csrf_form_policy(request): + _class, args = request.param + + return _class(*args) + + +@pytest.fixture(params=[ + (aiohttp_csrf.policy.HeaderPolicy, (HEADER_NAME,)), + (aiohttp_csrf.policy.FormAndHeaderPolicy, (HEADER_NAME, FORM_FIELD_NAME)), +]) +def csrf_header_policy(request): + _class, args = request.param + + return _class(*args) + + +@pytest.fixture(params=[ + (aiohttp_csrf.storage.SessionStorage, (SESSION_NAME,)), + (aiohttp_csrf.storage.CookieStorage, (COOKIE_NAME,)), +]) +def csrf_storage(request): + _class, args = request.param + + return _class(*args) diff --git a/tests/test_custom_error_renderer.py b/tests/test_custom_error_renderer.py new file mode 100644 index 0000000..a4160a0 --- /dev/null +++ b/tests/test_custom_error_renderer.py @@ -0,0 +1,99 @@ +import asyncio + +import aiohttp_csrf +import pytest +from aiohttp import web + +COOKIE_NAME = 'csrf_token' +HEADER_NAME = 'X-CSRF-TOKEN' + + +@pytest.yield_fixture +def create_app(init_app): + def go(loop, error_renderer): + @aiohttp_csrf.csrf_protect + async def handler_get(request): + await aiohttp_csrf.generate_token(request) + + return web.Response(body=b'OK') + + @aiohttp_csrf.csrf_protect(error_renderer=error_renderer) + async def handler_post(request): + return web.Response(body=b'OK') + + handlers = [ + ('GET', '/', handler_get), + ('POST', '/', handler_post) + ] + + storage = aiohttp_csrf.storage.CookieStorage(COOKIE_NAME) + policy = aiohttp_csrf.policy.HeaderPolicy(HEADER_NAME) + + app = init_app( + policy=policy, + storage=storage, + handlers=handlers, + loop=loop, + ) + + return app + + yield go + + +async def test_custom_exception_error_renderer(test_client, create_app): + client = await test_client( + create_app, + error_renderer=web.HTTPBadRequest, + ) + + await client.get('/') + + resp = await client.post('/') + + assert resp.status == web.HTTPBadRequest.status_code + + +@pytest.fixture(params=[False, True]) +def make_error_renderer(request): + is_coroutine = request.param + + def make_renderer(error_body): + def error_renderer(request): + return web.Response(body=error_body) + + if not is_coroutine: + return error_renderer + + return asyncio.coroutine(error_renderer) + + return make_renderer + + +async def test_custom_coroutine_callable_error_renderer(test_client, create_app, make_error_renderer): # noqa + error_body = b'CSRF error' + + error_renderer = make_error_renderer(error_body) + + client = await test_client( + create_app, + error_renderer=error_renderer, + ) + + await client.get('/') + + resp = await client.post('/') + + assert resp.status == 200 + + assert await resp.read() == error_body + + +async def test_bad_error_renderer(test_client, create_app): + error_renderer = 'trololo' + + with pytest.raises(TypeError): + await test_client( + create_app, + error_renderer=error_renderer, + ) diff --git a/tests/test_errors.py b/tests/test_errors.py new file mode 100644 index 0000000..0d7cf36 --- /dev/null +++ b/tests/test_errors.py @@ -0,0 +1,77 @@ +import aiohttp_csrf +import pytest +from aiohttp import web + +COOKIE_NAME = 'csrf_token' +HEADER_NAME = 'X-CSRF-TOKEN' + + +class FakeClass: + pass + + +async def test_bad_policy(test_client, init_app): + policy = FakeClass() + storage = aiohttp_csrf.storage.CookieStorage(COOKIE_NAME) + + with pytest.raises(TypeError): + await test_client( + init_app, + policy=policy, + storage=storage, + handlers=[], + ) + + +async def test_bad_storage(test_client, init_app): + policy = aiohttp_csrf.policy.HeaderPolicy(HEADER_NAME) + storage = FakeClass() + + with pytest.raises(TypeError): + await test_client( + init_app, + policy=policy, + storage=storage, + handlers=[], + ) + + +async def test_bad_error_renderer(test_client, init_app): + policy = aiohttp_csrf.policy.HeaderPolicy(HEADER_NAME) + storage = aiohttp_csrf.storage.CookieStorage(COOKIE_NAME) + + with pytest.raises(TypeError): + await test_client( + init_app, + policy=policy, + storage=storage, + error_renderer=1, + handlers=[], + ) + + +async def test_app_without_setup(test_client): + def create_app(loop): + app = web.Application() + + @aiohttp_csrf.csrf_protect + async def handler(request): + await aiohttp_csrf.generate_token(request) + + return web.Response() + + app.router.add_route( + 'GET', + '/', + handler, + ) + + return app + + client = await test_client( + create_app, + ) + + resp = await client.get('/') + + assert resp.status == 500 diff --git a/tests/test_exempt_decorator.py b/tests/test_exempt_decorator.py new file mode 100644 index 0000000..6802827 --- /dev/null +++ b/tests/test_exempt_decorator.py @@ -0,0 +1,55 @@ +import aiohttp_csrf +import pytest +from aiohttp import web + +COOKIE_NAME = 'csrf_token' +HEADER_NAME = 'X-CSRF-TOKEN' + + +@pytest.yield_fixture +def create_app(init_app): + def go(loop): + async def handler_get(request): + await aiohttp_csrf.generate_token(request) + + return web.Response(body=b'OK') + + @aiohttp_csrf.csrf_exempt + async def handler_post(request): + return web.Response(body=b'OK') + + handlers = [ + ('GET', '/', handler_get), + ('POST', '/', handler_post), + ] + + policy = aiohttp_csrf.policy.HeaderPolicy(HEADER_NAME) + storage = aiohttp_csrf.storage.CookieStorage(COOKIE_NAME) + + app = init_app( + policy=policy, + storage=storage, + handlers=handlers, + loop=loop, + ) + + app.middlewares.append(aiohttp_csrf.csrf_middleware) + + return app + + yield go + + +async def test_decorator_method_view(test_client, create_app): + + client = await test_client( + create_app, + ) + + resp = await client.get('/') + + assert resp.status == 200 + + resp = await client.post('/') + + assert resp.status == 200 diff --git a/tests/test_form_policy.py b/tests/test_form_policy.py new file mode 100644 index 0000000..11a9689 --- /dev/null +++ b/tests/test_form_policy.py @@ -0,0 +1,158 @@ +import re +import uuid +from unittest import mock + +import aiohttp_csrf +import pytest +from aiohttp import web +from aiohttp_session import setup as setup_session +from aiohttp_session import SimpleCookieStorage + +from .conftest import FORM_FIELD_NAME + +FORM_FIELD_REGEX = re.compile( + r'', +) + + +@pytest.yield_fixture +def create_app(init_app): + def go(loop, policy, storage): + async def handler_get(request): + token = await aiohttp_csrf.generate_token(request) + + body = ''' + + + +
+ +
+ + + ''' # noqa + + body = body.format(field_name=FORM_FIELD_NAME, token=token) + + return web.Response(body=body.encode('utf-8')) + + async def handler_post(request): + return web.Response(body=b'OK') + + handlers = [ + ('GET', '/', handler_get), + ('POST', '/', handler_post) + ] + + app = init_app( + policy=policy, + storage=storage, + handlers=handlers, + loop=loop, + ) + + if isinstance(storage, aiohttp_csrf.storage.SessionStorage): + session_storage = SimpleCookieStorage() + setup_session(app, session_storage) + + app.middlewares.append(aiohttp_csrf.csrf_middleware) + + return app + + yield go + + +async def test_form_policy_success( + test_client, + create_app, + csrf_form_policy, + csrf_storage, +): + client = await test_client( + create_app, + policy=csrf_form_policy, + storage=csrf_storage, + ) + + resp = await client.get('/') + + assert resp.status == 200 + + body = await resp.text() + + search_result = FORM_FIELD_REGEX.search(body) + + token = search_result.group('token') + + data = {FORM_FIELD_NAME: token} + + resp = await client.post('/', data=data) + + assert resp.status == 200 + + +async def test_form_policy_bad_token( + test_client, + create_app, + csrf_form_policy, + csrf_storage, +): + real_token = uuid.uuid4().hex + + bad_token = real_token + + while bad_token == real_token: + bad_token = uuid.uuid4().hex + + with mock.patch( + 'aiohttp_csrf.token_generator.SimpleTokenGenerator.generate', + return_value=real_token, + ): + client = await test_client( + create_app, + policy=csrf_form_policy, + storage=csrf_storage, + ) + + resp = await client.get('/') + + assert resp.status == 200 + + data = {FORM_FIELD_NAME: bad_token} + + resp = await client.post('/', data=data) + + assert resp.status == 403 + + +async def test_form_policy_reuse_token( + test_client, + create_app, + csrf_form_policy, + csrf_storage, +): + client = await test_client( + create_app, + policy=csrf_form_policy, + storage=csrf_storage, + ) + + resp = await client.get('/') + + assert resp.status == 200 + + body = await resp.text() + + search_result = FORM_FIELD_REGEX.search(body) + + token = search_result.group('token') + + data = {FORM_FIELD_NAME: token} + + resp = await client.post('/', data=data) + + assert resp.status == 200 + + resp = await client.post('/', data=data) + + assert resp.status == 403 diff --git a/tests/test_header_policy.py b/tests/test_header_policy.py new file mode 100644 index 0000000..66b9bbf --- /dev/null +++ b/tests/test_header_policy.py @@ -0,0 +1,111 @@ +import uuid +from unittest import mock + +import aiohttp_csrf +import pytest +from aiohttp import web + +from .conftest import COOKIE_NAME, HEADER_NAME + + +@pytest.yield_fixture +def create_app(init_app): + def go(loop, policy): + async def handler_get(request): + await aiohttp_csrf.generate_token(request) + + return web.Response(body=b'OK') + + async def handler_post(request): + return web.Response(body=b'OK') + + handlers = [ + ('GET', '/', handler_get), + ('POST', '/', handler_post) + ] + + storage = aiohttp_csrf.storage.CookieStorage(COOKIE_NAME) + + app = init_app( + policy=policy, + storage=storage, + handlers=handlers, + loop=loop, + ) + + app.middlewares.append(aiohttp_csrf.csrf_middleware) + + return app + + yield go + + +async def test_header_policy_success(test_client, create_app, csrf_header_policy): # noqa + client = await test_client( + create_app, + policy=csrf_header_policy, + ) + + resp = await client.get('/') + + assert resp.status == 200 + + token = resp.cookies[COOKIE_NAME].value + + headers = {HEADER_NAME: token} + + resp = await client.post('/', headers=headers) + + assert resp.status == 200 + + +async def test_header_policy_bad_token(test_client, create_app, csrf_header_policy): # noqa + real_token = uuid.uuid4().hex + + bad_token = real_token + + while bad_token == real_token: + bad_token = uuid.uuid4().hex + + with mock.patch( + 'aiohttp_csrf.token_generator.SimpleTokenGenerator.generate', + return_value=real_token, + ): + + client = await test_client( + create_app, + policy=csrf_header_policy, + ) + + resp = await client.get('/') + + assert resp.status == 200 + + headers = {HEADER_NAME: bad_token} + + resp = await client.post('/', headers=headers) + + assert resp.status == 403 + + +async def test_header_policy_reuse_token(test_client, create_app, csrf_header_policy): # noqa + client = await test_client( + create_app, + policy=csrf_header_policy, + ) + + resp = await client.get('/') + + assert resp.status == 200 + + token = resp.cookies[COOKIE_NAME].value + + headers = {HEADER_NAME: token} + + resp = await client.post('/', headers=headers) + + assert resp.status == 200 + + resp = await client.post('/', headers=headers) + + assert resp.status == 403 diff --git a/tests/test_protect_decorator.py b/tests/test_protect_decorator.py new file mode 100644 index 0000000..94976f6 --- /dev/null +++ b/tests/test_protect_decorator.py @@ -0,0 +1,136 @@ +import aiohttp_csrf +from aiohttp import web + +COOKIE_NAME = 'csrf_token' +HEADER_NAME = 'X-CSRF-TOKEN' + + +async def test_decorator_method_view(test_client, init_app): + @aiohttp_csrf.csrf_protect + async def handler_get(request): + await aiohttp_csrf.generate_token(request) + + return web.Response(body=b'OK') + + @aiohttp_csrf.csrf_protect + async def handler_post(request): + return web.Response(body=b'OK') + + handlers = [ + ('GET', '/', handler_get), + ('POST', '/', handler_post) + ] + + policy = aiohttp_csrf.policy.HeaderPolicy(HEADER_NAME) + storage = aiohttp_csrf.storage.CookieStorage(COOKIE_NAME) + + client = await test_client( + init_app, + policy=policy, + storage=storage, + handlers=handlers, + ) + + resp = await client.get('/') + + assert resp.status == 200 + + token = resp.cookies[COOKIE_NAME].value + + headers = {HEADER_NAME: token} + + resp = await client.post('/', headers=headers) + + assert resp.status == 200 + + resp = await client.post('/', headers=headers) + + assert resp.status == 403 + + +async def test_decorator_class_view(test_client): + class TestView(web.View): + @aiohttp_csrf.csrf_protect + async def get(self): + await aiohttp_csrf.generate_token(self.request) + + return web.Response(body=b'OK') + + @aiohttp_csrf.csrf_protect + async def post(self): + return web.Response(body=b'OK') + + def create_app(loop): + policy = aiohttp_csrf.policy.HeaderPolicy(HEADER_NAME) + storage = aiohttp_csrf.storage.CookieStorage(COOKIE_NAME) + + app = web.Application() + + aiohttp_csrf.setup(app, policy=policy, storage=storage) + + if hasattr(app.router, 'add_view'): + # For aiohttp >= 3.0.0 + app.router.add_view('/', TestView) + else: + app.router.add_route('*', '/', TestView) + + return app + + client = await test_client( + create_app, + ) + + resp = await client.get('/') + + assert resp.status == 200 + + token = resp.cookies[COOKIE_NAME].value + + headers = {HEADER_NAME: token} + + resp = await client.post('/', headers=headers) + + assert resp.status == 200 + + resp = await client.post('/', headers=headers) + + assert resp.status == 403 + + +async def test_handle_http_exceptions(test_client, init_app): + @aiohttp_csrf.csrf_protect + async def handler_get(request): + await aiohttp_csrf.generate_token(request) + + return web.Response(body=b'OK') + + @aiohttp_csrf.csrf_protect + async def handler_post(request): + raise web.HTTPBadRequest + + handlers = [ + ('GET', '/', handler_get), + ('POST', '/', handler_post) + ] + + policy = aiohttp_csrf.policy.HeaderPolicy(HEADER_NAME) + storage = aiohttp_csrf.storage.CookieStorage(COOKIE_NAME) + + client = await test_client( + init_app, + policy=policy, + storage=storage, + handlers=handlers, + ) + + resp = await client.get('/') + + assert resp.status == 200 + + token = resp.cookies[COOKIE_NAME].value + + headers = {HEADER_NAME: token} + + resp = await client.post('/', headers=headers) + + assert resp.status == 400 diff --git a/tests/test_storage_api.py b/tests/test_storage_api.py new file mode 100644 index 0000000..519dab8 --- /dev/null +++ b/tests/test_storage_api.py @@ -0,0 +1,65 @@ +from unittest.mock import MagicMock + +import aiohttp_csrf +import pytest +from aiohttp.test_utils import make_mocked_request + + +class FakeStorage(aiohttp_csrf.storage.BaseStorage): + + async def _get(self, request): + return request.get('my_field') + + async def _save_token(self, request, response, token): + request['my_field'] = token + + +async def test_1(): + storage = FakeStorage() + + storage._generate_token = MagicMock(return_value='1') + storage._get = MagicMock(return_value='1') + storage._save = MagicMock() + + assert storage._generate_token.call_count == 0 + + request = make_mocked_request('/', 'GET') + + await storage.generate_new_token(request) + + assert storage._generate_token.call_count == 1 + + await storage.generate_new_token(request) + await storage.generate_new_token(request) + + assert storage._generate_token.call_count == 1 + + +async def test_2(): + storage = FakeStorage() + + storage._generate_token = MagicMock(return_value='1') + + request = make_mocked_request('/', 'GET') + + assert storage._generate_token.call_count == 0 + + await storage.save_token(request, None) + + assert storage._generate_token.call_count == 1 + + request2 = make_mocked_request('/', 'GET') + + request2['my_field'] = 1 + + await storage.save_token(request2, None) + + +async def test_3(): + class Some: + pass + + token_generator = Some() + + with pytest.raises(TypeError): + FakeStorage(token_generator=token_generator) diff --git a/tests/test_token_generator.py b/tests/test_token_generator.py new file mode 100644 index 0000000..e4ac6b0 --- /dev/null +++ b/tests/test_token_generator.py @@ -0,0 +1,37 @@ +import hashlib +import uuid +from unittest import mock + +import aiohttp_csrf + +COOKIE_NAME = 'csrf_token' +HEADER_NAME = 'X-CSRF-TOKEN' + + +def test_simple_token_generator(): + token_generator = aiohttp_csrf.token_generator.SimpleTokenGenerator() + + u = uuid.uuid4() + + with mock.patch('uuid.uuid4', return_value=u): + token = token_generator.generate() + + assert u.hex == token + + +def test_hashed_token_generator(): + encoding = aiohttp_csrf.token_generator.HashedTokenGenerator.encoding + + token_generator = aiohttp_csrf.token_generator.HashedTokenGenerator( + 'secret', + ) + + u = uuid.uuid4() + token_string = u.hex + 'secret' + + hasher = hashlib.sha256(token_string.encode(encoding=encoding)) + + with mock.patch('hashlib.sha256', return_value=hasher): + token = token_generator.generate() + + assert token == hasher.hexdigest() diff --git a/tox.ini b/tox.ini new file mode 100644 index 0000000..6e8ebe8 --- /dev/null +++ b/tox.ini @@ -0,0 +1,23 @@ +[tox] +envlist = + py3{5,6} +skip_missing_interpreters = True +skipsdist = True + +[testenv] +deps = + -r{toxinidir}/requirements_dev.txt +commands = + flake8 --show-source aiohttp_csrf + isort --check-only -rc aiohttp_csrf --diff + + flake8 --show-source demo + isort --check-only -rc demo --diff + + flake8 --show-source tests + isort --check-only -rc tests --diff + + flake8 --show-source setup.py + isort --check-only setup.py --diff + + pytest tests