diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..3ae2f48
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,28 @@
+# python specific
+env*
+.cache/
+.pytest_cache/
+.idea/
+*.pyc
+*.so
+*.pyd
+aiohttp_csrf.egg-info
+build/*
+dist/*
+MANIFEST
+__pycache__/
+*.egg-info/
+.coverage
+.python-version
+htmlcov
+
+# generic files to ignore
+*~
+*.lock
+*.DS_Store
+*.swp
+*.out
+
+.tox/
+deps/
+docs/_build/
diff --git a/.isort.cfg b/.isort.cfg
new file mode 100644
index 0000000..47f77e9
--- /dev/null
+++ b/.isort.cfg
@@ -0,0 +1,2 @@
+[settings]
+known_third_party=aiohttp_csrf
diff --git a/.travis.yml b/.travis.yml
new file mode 100644
index 0000000..53c1bfd
--- /dev/null
+++ b/.travis.yml
@@ -0,0 +1,20 @@
+dist: trusty
+language: python
+python:
+ - "3.5"
+ - "3.6"
+install:
+ - pip install -U setuptools
+ - pip install -U pip
+ - pip install -U wheel
+ - pip install -U tox
+script:
+ - export TOXENV=py`python -c 'import sys; print("".join(map(str, sys.version_info[:2])))'`
+ - echo "$TOXENV"
+
+ - tox
+cache:
+ directories:
+ - $HOME/.cache/pip
+notifications:
+ email: false
diff --git a/LICENCE b/LICENCE
new file mode 100644
index 0000000..22f6914
--- /dev/null
+++ b/LICENCE
@@ -0,0 +1,21 @@
+The MIT License
+
+Copyright (c) Ocean S.A. https://ocean.io/
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000..e24206f
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,4 @@
+include README.rst
+include LICENSE
+recursive-exclude * __pycache__
+recursive-exclude * *.py[co]
diff --git a/README.rst b/README.rst
index b779bd2..4ebe07c 100644
--- a/README.rst
+++ b/README.rst
@@ -1 +1,261 @@
-# aiohttp-csrf
\ No newline at end of file
+aiohttp_csrf
+============
+
+The library provides csrf (xsrf) protection for `aiohttp.web`__.
+
+.. _aiohttp_web: https://docs.aiohttp.org/en/latest/web.html
+
+__ aiohttp_web_
+
+.. image:: https://img.shields.io/travis/wikibusiness/aiohttp-csrf.svg
+ :target: https://travis-ci.org/wikibusiness/aiohttp-csrf
+
+Basic usage
+-----------
+
+The library allows you to implement csrf (xsrf) protection for requests
+
+
+Basic usage example:
+
+.. code-block:: python
+
+ import aiohttp_csrf
+ from aiohttp import web
+
+ FORM_FIELD_NAME = '_csrf_token'
+ COOKIE_NAME = 'csrf_token'
+
+
+ def make_app():
+ csrf_policy = aiohttp_csrf.policy.FormPolicy(FORM_FIELD_NAME)
+
+ csrf_storage = aiohttp_csrf.storage.CookieStorage(COOKIE_NAME)
+
+ app = web.Application()
+
+ aiohttp_csrf.setup(app, policy=csrf_policy, storage=csrf_storage)
+
+ app.middlewares.append(aiohttp_csrf.csrf_middleware)
+
+ async def handler_get_form_with_token(request):
+ token = await aiohttp_csrf.generate_token(request)
+
+
+ body = '''
+
+
Form with csrf protection
+
+
+
+
+ ''' # noqa
+
+ body = body.format(field_name=FORM_FIELD_NAME, token=token)
+
+ return web.Response(
+ body=body.encode('utf-8'),
+ content_type='text/html',
+ )
+
+ async def handler_post_check(request):
+ post = await request.post()
+
+ body = 'Hello, {name}'.format(name=post['name'])
+
+ return web.Response(
+ body=body.encode('utf-8'),
+ content_type='text/html',
+ )
+
+ app.router.add_route(
+ 'GET',
+ '/',
+ handler_get_form_with_token,
+ )
+
+ app.router.add_route(
+ 'POST',
+ '/',
+ handler_post_check,
+ )
+
+ return app
+
+
+ web.run_app(make_app())
+
+
+Initialize
+~~~~~~~~~~
+
+
+First of all, you need to initialize ``aiohttp_csrf`` in your application:
+
+.. code-block:: python
+
+ app = web.Application()
+
+ csrf_policy = aiohttp_csrf.policy.FormPolicy(FORM_FIELD_NAME)
+
+ csrf_storage = aiohttp_csrf.storage.CookieStorage(COOKIE_NAME)
+
+ aiohttp_csrf.setup(app, policy=csrf_policy, storage=csrf_storage)
+
+
+Middleware and decorators
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+
+After initialize you can use ``@aiohttp_csrf.csrf_protect`` for handlers, that you want to protect.
+Or you can initialize ``aiohttp_csrf.csrf_middleware`` and do not disturb about using decorator (`full middleware example here`_):
+
+.. _full middleware example here: demo/middleware.py
+
+.. code-block:: python
+
+ ...
+ app.middlewares.append(aiohttp_csrf.csrf_middleware)
+ ...
+
+
+In this case all your handlers will be protected.
+
+
+**Note:** we strongly recommend to use ``aiohttp_csrf.csrf_middleware`` and ``@aiohttp_csrf.csrf_exempt`` instead of manually managing with ``@aiohttp_csrf.csrf_protect``.
+But if you prefer to use ``@aiohttp_csrf.csrf_protect``, don't forget to use ``@aiohttp_csrf.csrf_protect`` for both methods: GET and POST
+(`manual protection example`_)
+
+.. _manual protection example: demo/manual_protection.py
+
+
+If you want to use middleware, but need handlers without protection, you can use ``@aiohttp_csrf.csrf_exempt``.
+Mark you handler with this decorator and this handler will not check the token:
+
+.. code-block:: python
+
+ @aiohttp_csrf.csrf_exempt
+ async def handler_post_not_check(request):
+ ...
+
+
+
+Generate token
+~~~~~~~~~~~~~~
+
+For generate token you need to call ``aiohttp_csrf.generate_token`` in your handler:
+
+.. code-block:: python
+
+ @aiohttp_csrf.csrf_protect
+ async def handler_get(request):
+ token = await aiohttp_csrf.generate_token(request)
+ ...
+
+
+Advanced usage
+--------------
+
+
+Policies
+~~~~~~~~
+
+You can use different policies for check tokens. Library provides 3 types of policy:
+
+- **FormPolicy**. This policy will search token in the body of your POST request (Usually use for forms). You need to specify name of field that will be checked.
+- **HeaderPolicy**. This policy will search token in headers of your POST request (Usually use for AJAX requests). You need to specify name of header that will be checked.
+- **FormAndHeaderPolicy**. This policy combines behavior of **FormPolicy** and **HeaderPolicy**.
+
+You can implement your custom policies if needed. But make sure that your custom policy implements ``aiohttp_csrf.policy.AbstractPolicy`` interface.
+
+Storages
+~~~~~~~~
+
+You can use different types of storages for storing token. Library provides 2 types of storage:
+
+- **CookieStorage**. Your token will be stored in cookie variable. You need to specify cookie name.
+- **SessionStorage**. Your token will be stored in session. You need to specify session variable name.
+
+**Important:** If you want to use session storage, you need setup aiohttp_session in your application
+(`session storage example`_)
+
+.. _session storage example: demo/session_storage.py#L22
+
+You can implement your custom storages if needed. But make sure that your custom storage implements ``aiohttp_csrf.storage.AbstractStorage`` interface.
+
+
+Token generators
+~~~~~~~~~~~~~~~~
+
+You can use different token generator in your application.
+By default storages using ``aiohttp_csrf.token_generator.SimpleTokenGenerator``
+
+But if you need more secure token generator - you can use ``aiohttp_csrf.token_generator.HashedTokenGenerator``
+
+And you can implement your custom token generators if needed. But make sure that your custom token generator implements ``aiohttp_csrf.token_generator.AbstractTokenGenerator`` interface.
+
+
+Invalid token behavior
+~~~~~~~~~~~~~~~~~~~~~~
+
+By default, if token is invalid, ``aiohttp_csrf`` will raise ``aiohttp.web.HTTPForbidden`` exception.
+
+You have abbility to specify your custom error handler. It can be:
+
+- **callable instance**. Input parameter - aiohttp request.
+.. code-block:: python
+
+ def custom_error_handler(request):
+ # do something
+ return aiohttp.web.Response(status=403)
+
+ # or
+
+ async def custom_async_error_handler(request):
+ # await do something
+ return aiohttp.web.Response(status=403)
+
+It will be called instead of protected handler.
+
+- **sub class of Exception**. In this case this Exception will be raised.
+
+.. code-block:: python
+
+ class CustomException(Exception):
+ pass
+
+
+You can specify custom error handler globally, when initialize ``aiohttp_csrf`` in your application:
+
+.. code-block:: python
+
+ ...
+ class CustomException(Exception):
+ pass
+
+ ...
+ aiohttp_csrf.setup(app, policy=csrf_policy, storage=csrf_storage, error_renderer=CustomException)
+ ...
+
+In this case custom error handler will be applied to all protected handlers.
+
+Or you can specify custom error handler locally, for specific handler:
+
+.. code-block:: python
+
+ ...
+ class CustomException(Exception):
+ pass
+
+ ...
+ @aiohttp_csrf.csrf_protect(error_renderer=CustomException)
+ def handler_with_custom_csrf_error(request):
+ ...
+
+
+In this case custom error handler will be applied to this handler only.
+For all other handlers will be applied global error handler.
diff --git a/aiohttp_csrf/__init__.py b/aiohttp_csrf/__init__.py
new file mode 100644
index 0000000..7a4c361
--- /dev/null
+++ b/aiohttp_csrf/__init__.py
@@ -0,0 +1,175 @@
+import asyncio
+import inspect
+
+from functools import wraps
+
+from aiohttp import web
+
+from .policy import AbstractPolicy
+from .storage import AbstractStorage
+
+
+__version__ = '0.0.1'
+
+APP_POLICY_KEY = 'aiohttp_csrf_policy'
+APP_STORAGE_KEY = 'aiohttp_csrf_storage'
+APP_ERROR_RENDERER_KEY = 'aiohttp_csrf_error_renderer'
+
+MIDDLEWARE_SKIP_PROPERTY = 'csrf_middleware_skip'
+
+UNPROTECTED_HTTP_METHODS = ('GET', 'HEAD', 'OPTIONS', 'TRACE')
+
+
+def setup(app, *, policy, storage, error_renderer=web.HTTPForbidden):
+ if not isinstance(policy, AbstractPolicy):
+ raise TypeError('Policy must be instance of AbstractPolicy')
+
+ if not isinstance(storage, AbstractStorage):
+ raise TypeError('Storage must be instance of AbstractStorage')
+
+ if not isinstance(error_renderer, Exception) and not callable(error_renderer): # noqa
+ raise TypeError(
+ 'Default error renderer must be instance of Exception or callable.'
+ )
+
+ app[APP_POLICY_KEY] = policy
+ app[APP_STORAGE_KEY] = storage
+ app[APP_ERROR_RENDERER_KEY] = error_renderer
+
+
+def _get_policy(request):
+ try:
+ return request.app[APP_POLICY_KEY]
+ except KeyError:
+ raise RuntimeError(
+ 'Policy not found. Install aiohttp_csrf in your '
+ 'aiohttp.web.Application using aiohttp_csrf.setup()'
+ )
+
+
+def _get_storage(request):
+ try:
+ return request.app[APP_STORAGE_KEY]
+ except KeyError:
+ raise RuntimeError(
+ 'Storage not found. Install aiohttp_csrf in your '
+ 'aiohttp.web.Application using aiohttp_csrf.setup()'
+ )
+
+
+async def _render_error(request, error_renderer=None):
+ if error_renderer is None:
+ try:
+ error_renderer = request.app[APP_ERROR_RENDERER_KEY]
+ except KeyError:
+ raise RuntimeError(
+ 'Default error renderer not found. Install aiohttp_csrf in '
+ 'your aiohttp.web.Application using aiohttp_csrf.setup()'
+ )
+
+ if inspect.isclass(error_renderer) and issubclass(error_renderer, Exception): # noqa
+ raise error_renderer
+ elif callable(error_renderer):
+ if asyncio.iscoroutinefunction(error_renderer):
+ return await error_renderer(request)
+ else:
+ return error_renderer(request)
+ else:
+ raise NotImplementedError
+
+
+async def get_token(request):
+ storage = _get_storage(request)
+
+ return await storage.get(request)
+
+
+async def generate_token(request):
+ storage = _get_storage(request)
+
+ return await storage.generate_new_token(request)
+
+
+async def save_token(request, response):
+ storage = _get_storage(request)
+
+ await storage.save_token(request, response)
+
+
+def csrf_exempt(handler):
+ @wraps(handler)
+ def wrapped_handler(*args, **kwargs):
+ return handler(*args, **kwargs)
+
+ setattr(wrapped_handler, MIDDLEWARE_SKIP_PROPERTY, True)
+
+ return wrapped_handler
+
+
+async def _check(request):
+ if not isinstance(request, web.Request):
+ raise RuntimeError('Can\'t get request from handler params')
+
+ original_token = await get_token(request)
+
+ policy = _get_policy(request)
+
+ return await policy.check(request, original_token)
+
+
+def csrf_protect(handler=None, error_renderer=None):
+ if (
+ error_renderer is not None
+ and not isinstance(error_renderer, Exception)
+ and not callable(error_renderer)
+ ):
+ raise TypeError(
+ 'Renderer must be instance of Exception or callable.'
+ )
+
+ def wrapper(handler):
+ @wraps(handler)
+ async def wrapped(*args, **kwargs):
+ request = args[-1]
+
+ if isinstance(request, web.View):
+ request = request.request
+
+ if (
+ request.method not in UNPROTECTED_HTTP_METHODS
+ and not await _check(request)
+ ):
+ return await _render_error(request, error_renderer)
+
+ raise_response = False
+
+ try:
+ response = await handler(*args, **kwargs)
+ except web.HTTPException as exc:
+ response = exc
+ raise_response = True
+
+ if isinstance(response, web.Response):
+ await save_token(request, response)
+
+ if raise_response:
+ raise response
+
+ return response
+
+ setattr(wrapped, MIDDLEWARE_SKIP_PROPERTY, True)
+
+ return wrapped
+
+ if handler is None:
+ return wrapper
+
+ return wrapper(handler)
+
+
+@web.middleware
+async def csrf_middleware(request, handler):
+ if not getattr(handler, MIDDLEWARE_SKIP_PROPERTY, False):
+ handler = csrf_protect(handler=handler)
+
+ return await handler(request)
diff --git a/aiohttp_csrf/policy.py b/aiohttp_csrf/policy.py
new file mode 100644
index 0000000..1700e1d
--- /dev/null
+++ b/aiohttp_csrf/policy.py
@@ -0,0 +1,55 @@
+import abc
+
+
+class AbstractPolicy(metaclass=abc.ABCMeta):
+ @abc.abstractmethod
+ async def check(self, request, original_value):
+ pass # pragma: no cover
+
+
+class FormPolicy(AbstractPolicy):
+
+ def __init__(self, field_name):
+ self.field_name = field_name
+
+ async def check(self, request, original_value):
+ post = await request.post()
+
+ token = post.get(self.field_name)
+
+ return token == original_value
+
+
+class HeaderPolicy(AbstractPolicy):
+
+ def __init__(self, header_name):
+ self.header_name = header_name
+
+ async def check(self, request, original_value):
+ token = request.headers.get(self.header_name)
+
+ return token == original_value
+
+
+class FormAndHeaderPolicy(HeaderPolicy, FormPolicy):
+
+ def __init__(self, header_name, field_name):
+ self.header_name = header_name
+ self.field_name = field_name
+
+ async def check(self, request, original_value):
+ header_check = await HeaderPolicy.check(
+ self,
+ request,
+ original_value,
+ )
+
+ if header_check:
+ return True
+
+ form_check = await FormPolicy.check(self, request, original_value)
+
+ if form_check:
+ return True
+
+ return False
diff --git a/aiohttp_csrf/storage.py b/aiohttp_csrf/storage.py
new file mode 100644
index 0000000..c77a227
--- /dev/null
+++ b/aiohttp_csrf/storage.py
@@ -0,0 +1,116 @@
+import abc
+
+from .token_generator import AbstractTokenGenerator, SimpleTokenGenerator
+
+try:
+ from aiohttp_session import get_session
+except ImportError: # pragma: no cover
+ pass
+
+
+REQUEST_NEW_TOKEN_KEY = 'aiohttp_csrf_new_token'
+
+
+class AbstractStorage(metaclass=abc.ABCMeta):
+
+ @abc.abstractmethod
+ async def generate_new_token(self, request):
+ pass # pragma: no cover
+
+ @abc.abstractmethod
+ async def get(self, request):
+ pass # pragma: no cover
+
+ @abc.abstractmethod
+ async def save_token(self, request, response):
+ pass # pragma: no cover
+
+
+class BaseStorage(AbstractStorage, metaclass=abc.ABCMeta):
+
+ def __init__(self, token_generator=None):
+ if token_generator is None:
+ token_generator = SimpleTokenGenerator()
+ elif not isinstance(token_generator, AbstractTokenGenerator):
+ raise TypeError(
+ 'Token generator must be instance of AbstractTokenGenerator',
+ )
+
+ self.token_generator = token_generator
+
+ def _generate_token(self):
+ return self.token_generator.generate()
+
+ async def generate_new_token(self, request):
+ if REQUEST_NEW_TOKEN_KEY in request:
+ return request[REQUEST_NEW_TOKEN_KEY]
+
+ token = self._generate_token()
+
+ request[REQUEST_NEW_TOKEN_KEY] = token
+
+ return token
+
+ @abc.abstractmethod
+ async def _get(self, request):
+ pass # pragma: no cover
+
+ async def get(self, request):
+ token = await self._get(request)
+
+ await self.generate_new_token(request)
+
+ return token
+
+ @abc.abstractmethod
+ async def _save_token(self, request, response, token):
+ pass # pragma: no cover
+
+ async def save_token(self, request, response):
+ old_token = await self._get(request)
+
+ if REQUEST_NEW_TOKEN_KEY in request:
+ token = request[REQUEST_NEW_TOKEN_KEY]
+ elif old_token is None:
+ token = await self.generate_new_token(request)
+ else:
+ token = None
+
+ if token is not None:
+ await self._save_token(request, response, token)
+
+
+class CookieStorage(BaseStorage):
+
+ def __init__(self, cookie_name, cookie_kwargs=None, *args, **kwargs):
+ self.cookie_name = cookie_name
+ self.cookie_kwargs = cookie_kwargs or {}
+
+ super().__init__(*args, **kwargs)
+
+ async def _get(self, request):
+ return request.cookies.get(self.cookie_name, None)
+
+ async def _save_token(self, request, response, token):
+ response.set_cookie(
+ self.cookie_name,
+ token,
+ **self.cookie_kwargs,
+ )
+
+
+class SessionStorage(BaseStorage):
+ def __init__(self, session_name, *args, **kwargs):
+ self.session_name = session_name
+
+ super().__init__(*args, **kwargs)
+
+ async def _get(self, request):
+ session = await get_session(request)
+
+ return session.get(self.session_name, None)
+
+ async def _save_token(self, request, response, token):
+ session = await get_session(request)
+
+ session[self.session_name] = token
diff --git a/aiohttp_csrf/token_generator.py b/aiohttp_csrf/token_generator.py
new file mode 100644
index 0000000..dcecbb7
--- /dev/null
+++ b/aiohttp_csrf/token_generator.py
@@ -0,0 +1,30 @@
+import abc
+import hashlib
+import uuid
+
+
+class AbstractTokenGenerator(metaclass=abc.ABCMeta):
+ @abc.abstractmethod
+ def generate(self):
+ pass # pragma: no cover
+
+
+class SimpleTokenGenerator(AbstractTokenGenerator):
+ def generate(self):
+ return uuid.uuid4().hex
+
+
+class HashedTokenGenerator(AbstractTokenGenerator):
+ encoding = 'utf-8'
+
+ def __init__(self, secret_phrase):
+ self.secret_phrase = secret_phrase
+
+ def generate(self):
+ token = uuid.uuid4().hex
+
+ token += self.secret_phrase
+
+ hasher = hashlib.sha256(token.encode(self.encoding))
+
+ return hasher.hexdigest()
diff --git a/demo/manual_protection.py b/demo/manual_protection.py
new file mode 100644
index 0000000..07caaa1
--- /dev/null
+++ b/demo/manual_protection.py
@@ -0,0 +1,68 @@
+import aiohttp_csrf
+from aiohttp import web
+
+FORM_FIELD_NAME = '_csrf_token'
+COOKIE_NAME = 'csrf_token'
+
+
+def make_app():
+ csrf_policy = aiohttp_csrf.policy.FormPolicy(FORM_FIELD_NAME)
+
+ csrf_storage = aiohttp_csrf.storage.CookieStorage(COOKIE_NAME)
+
+ app = web.Application()
+
+ aiohttp_csrf.setup(app, policy=csrf_policy, storage=csrf_storage)
+
+ # IMPORTANT! You need use @csrf_protect for both methods: GET and POST
+ @aiohttp_csrf.csrf_protect
+ async def handler_get(request):
+ token = await aiohttp_csrf.generate_token(request)
+
+ body = '''
+
+ Form with csrf protection
+
+
+
+
+ ''' # noqa
+
+ body = body.format(field_name=FORM_FIELD_NAME, token=token)
+
+ return web.Response(
+ body=body.encode('utf-8'),
+ content_type='text/html',
+ )
+
+ @aiohttp_csrf.csrf_protect
+ async def handler_post(request):
+ post = await request.post()
+
+ body = 'Hello, {name}'.format(name=post['name'])
+
+ return web.Response(
+ body=body.encode('utf-8'),
+ content_type='text/html',
+ )
+
+ app.router.add_route(
+ 'GET',
+ '/',
+ handler_get,
+ )
+
+ app.router.add_route(
+ 'POST',
+ '/',
+ handler_post,
+ )
+
+ return app
+
+
+web.run_app(make_app())
diff --git a/demo/middleware.py b/demo/middleware.py
new file mode 100644
index 0000000..42220b0
--- /dev/null
+++ b/demo/middleware.py
@@ -0,0 +1,106 @@
+import aiohttp_csrf
+from aiohttp import web
+
+FORM_FIELD_NAME = '_csrf_token'
+COOKIE_NAME = 'csrf_token'
+
+
+def make_app():
+ csrf_policy = aiohttp_csrf.policy.FormPolicy(FORM_FIELD_NAME)
+
+ csrf_storage = aiohttp_csrf.storage.CookieStorage(COOKIE_NAME)
+
+ app = web.Application()
+
+ aiohttp_csrf.setup(app, policy=csrf_policy, storage=csrf_storage)
+
+ app.middlewares.append(aiohttp_csrf.csrf_middleware)
+
+ async def handler_get_form_with_token(request):
+ token = await aiohttp_csrf.generate_token(request)
+
+ body = '''
+
+ Form with csrf protection
+
+
+
+
+ ''' # noqa
+
+ body = body.format(field_name=FORM_FIELD_NAME, token=token)
+
+ return web.Response(
+ body=body.encode('utf-8'),
+ content_type='text/html',
+ )
+
+ async def handler_post_check(request):
+ post = await request.post()
+
+ body = 'Hello, {name}'.format(name=post['name'])
+
+ return web.Response(
+ body=body.encode('utf-8'),
+ content_type='text/html',
+ )
+
+ async def handler_get_form_without_token(request):
+ body = '''
+
+ Form without csrf protection
+
+
+
+
+ '''
+
+ return web.Response(
+ body=body.encode('utf-8'),
+ content_type='text/html',
+ )
+
+ @aiohttp_csrf.csrf_exempt
+ async def handler_post_not_check(request):
+ post = await request.post()
+
+ body = 'Hello, {name}'.format(name=post['name'])
+
+ return web.Response(
+ body=body.encode('utf-8'),
+ content_type='text/html',
+ )
+
+ app.router.add_route(
+ 'GET',
+ '/form_with_check',
+ handler_get_form_with_token,
+ )
+ app.router.add_route(
+ 'POST',
+ '/post_with_check',
+ handler_post_check,
+ )
+
+ app.router.add_route(
+ 'GET',
+ '/form_without_check',
+ handler_get_form_without_token,
+ )
+ app.router.add_route(
+ 'POST',
+ '/post_without_check',
+ handler_post_not_check,
+ )
+
+ return app
+
+
+web.run_app(make_app())
diff --git a/demo/session_storage.py b/demo/session_storage.py
new file mode 100644
index 0000000..e27a489
--- /dev/null
+++ b/demo/session_storage.py
@@ -0,0 +1,113 @@
+import aiohttp_csrf
+from aiohttp import web
+from aiohttp_session import setup as setup_session
+from aiohttp_session import SimpleCookieStorage
+
+FORM_FIELD_NAME = '_csrf_token'
+SESSION_NAME = 'csrf_token'
+
+
+def make_app():
+ csrf_policy = aiohttp_csrf.policy.FormPolicy(FORM_FIELD_NAME)
+
+ csrf_storage = aiohttp_csrf.storage.SessionStorage(SESSION_NAME)
+
+ app = web.Application()
+
+ aiohttp_csrf.setup(app, policy=csrf_policy, storage=csrf_storage)
+
+ session_storage = SimpleCookieStorage()
+
+ # Important!!!
+ setup_session(app, session_storage)
+
+ app.middlewares.append(aiohttp_csrf.csrf_middleware)
+
+ async def handler_get_form_with_token(request):
+ token = await aiohttp_csrf.generate_token(request)
+
+ body = '''
+
+ Form with csrf protection
+
+
+
+
+ ''' # noqa
+
+ body = body.format(field_name=FORM_FIELD_NAME, token=token)
+
+ return web.Response(
+ body=body.encode('utf-8'),
+ content_type='text/html',
+ )
+
+ async def handler_post_check(request):
+ post = await request.post()
+
+ body = 'Hello, {name}'.format(name=post['name'])
+
+ return web.Response(
+ body=body.encode('utf-8'),
+ content_type='text/html',
+ )
+
+ async def handler_get_form_without_token(request):
+ body = '''
+
+ Form without csrf protection
+
+
+
+
+ '''
+
+ return web.Response(
+ body=body.encode('utf-8'),
+ content_type='text/html',
+ )
+
+ @aiohttp_csrf.csrf_exempt
+ async def handler_post_not_check(request):
+ post = await request.post()
+
+ body = 'Hello, {name}'.format(name=post['name'])
+
+ return web.Response(
+ body=body.encode('utf-8'),
+ content_type='text/html',
+ )
+
+ app.router.add_route(
+ 'GET',
+ '/form_with_check',
+ handler_get_form_with_token,
+ )
+ app.router.add_route(
+ 'POST',
+ '/post_with_check',
+ handler_post_check,
+ )
+
+ app.router.add_route(
+ 'GET',
+ '/form_without_check',
+ handler_get_form_without_token,
+ )
+ app.router.add_route(
+ 'POST',
+ '/post_without_check',
+ handler_post_not_check,
+ )
+
+ return app
+
+
+web.run_app(make_app())
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 0000000..870e595
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,2 @@
+[pytest]
+addopts= --no-cov-on-fail --cov=aiohttp_csrf --cov-report=term --cov-report=html
diff --git a/requirements_dev.txt b/requirements_dev.txt
new file mode 100644
index 0000000..f5d515e
--- /dev/null
+++ b/requirements_dev.txt
@@ -0,0 +1,9 @@
+aiohttp==2.3.10
+-e .
+aiohttp-session==1.0.1
+flake8==3.4.1
+isort==4.2.15
+pytest==3.4.0
+pytest-aiohttp==0.1.3
+pytest-cov==2.5.1
+tox==2.9.1
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..4396b14
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,70 @@
+import codecs
+import io
+import os
+import re
+
+from setuptools import setup
+
+
+def get_version():
+ with codecs.open(
+ os.path.join(
+ os.path.abspath(
+ os.path.dirname(__file__),
+ ),
+ 'aiohttp_csrf',
+ '__init__.py',
+ ),
+ 'r',
+ 'utf-8',
+ ) as fp:
+ try:
+ return re.findall(r"^__version__ = '([^']+)'$", fp.read(), re.M)[0]
+ except IndexError:
+ raise RuntimeError('Unable to determine version.')
+
+
+def read(*parts):
+ filename = os.path.join(os.path.abspath(os.path.dirname(__file__)), *parts)
+
+ with io.open(filename, encoding='utf-8', mode='rt') as fp:
+ return fp.read()
+
+
+install_requires = ['aiohttp>=3.2.0']
+extras_require = {
+ 'session': ['aiohttp-session>=2.4.0'],
+}
+
+
+setup(
+ name='aiohttp-csrf',
+ version=get_version(),
+ description=('CSRF protection for aiohttp.web',),
+ long_description=read('README.rst'),
+ author='Ocean S.A.',
+ author_email='osf@ocean.io',
+ url='https://github.com/wikibusiness/aiohttp-csrf',
+ packages=['aiohttp_csrf'],
+ include_package_data=True,
+ install_requires=install_requires,
+ extras_require=extras_require,
+ zip_safe=False,
+ classifiers=[
+ 'Development Status :: 5 - Production/Stable',
+ 'Intended Audience :: Developers',
+ 'License :: OSI Approved :: MIT License',
+ 'Operating System :: POSIX',
+ 'Operating System :: MacOS :: MacOS X',
+ 'Operating System :: Microsoft :: Windows',
+ 'Programming Language :: Python',
+ 'Programming Language :: Python :: 3',
+ 'Programming Language :: Python :: 3.5',
+ 'Programming Language :: Python :: 3.6',
+ ],
+ keywords=[
+ 'csrf',
+ 'xsrf',
+ 'aiohttp',
+ ],
+)
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..beea3a1
--- /dev/null
+++ b/test.py
@@ -0,0 +1,35 @@
+import asyncio
+
+from aiohttp import web
+
+async def hello(request):
+ return web.Response(text="Hello, world")
+
+
+def dec(handler):
+ def wrapped(*args, **kwargs):
+ request = args[-1]
+ import ipdb;ipdb.set_trace()
+ return handler(*args, **kwargs)
+
+ return wrapped
+
+
+class MyView(web.View):
+ @dec
+ async def get(self):
+ return web.Response(text="Get Hello, world")
+
+ async def post(self):
+ return web.Response(text="Post Hello, world")
+
+
+@web.middleware
+async def middleware(request, handler):
+ return await handler(request)
+
+
+app = web.Application(middlewares=[middleware])
+app.router.add_route('*', '/', MyView)
+
+web.run_app(app)
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..1a2f168
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,69 @@
+import aiohttp_csrf
+import pytest
+from aiohttp import web
+
+SESSION_NAME = COOKIE_NAME = 'csrf_token'
+FORM_FIELD_NAME = HEADER_NAME = 'X-CSRF-TOKEN'
+
+
+@pytest.yield_fixture
+def init_app():
+ def go(
+ loop,
+ policy,
+ storage,
+ handlers,
+ error_renderer=None,
+ ):
+ app = web.Application()
+
+ kwargs = {
+ 'policy': policy,
+ 'storage': storage,
+ }
+
+ if error_renderer is not None:
+ kwargs['error_renderer'] = error_renderer
+
+ aiohttp_csrf.setup(app, **kwargs)
+
+ for method, url, handler in handlers:
+ app.router.add_route(
+ method,
+ url,
+ handler,
+ )
+
+ return app
+
+ yield go
+
+
+@pytest.fixture(params=[
+ (aiohttp_csrf.policy.FormPolicy, (FORM_FIELD_NAME,)),
+ (aiohttp_csrf.policy.FormAndHeaderPolicy, (HEADER_NAME, FORM_FIELD_NAME)),
+])
+def csrf_form_policy(request):
+ _class, args = request.param
+
+ return _class(*args)
+
+
+@pytest.fixture(params=[
+ (aiohttp_csrf.policy.HeaderPolicy, (HEADER_NAME,)),
+ (aiohttp_csrf.policy.FormAndHeaderPolicy, (HEADER_NAME, FORM_FIELD_NAME)),
+])
+def csrf_header_policy(request):
+ _class, args = request.param
+
+ return _class(*args)
+
+
+@pytest.fixture(params=[
+ (aiohttp_csrf.storage.SessionStorage, (SESSION_NAME,)),
+ (aiohttp_csrf.storage.CookieStorage, (COOKIE_NAME,)),
+])
+def csrf_storage(request):
+ _class, args = request.param
+
+ return _class(*args)
diff --git a/tests/test_custom_error_renderer.py b/tests/test_custom_error_renderer.py
new file mode 100644
index 0000000..a4160a0
--- /dev/null
+++ b/tests/test_custom_error_renderer.py
@@ -0,0 +1,99 @@
+import asyncio
+
+import aiohttp_csrf
+import pytest
+from aiohttp import web
+
+COOKIE_NAME = 'csrf_token'
+HEADER_NAME = 'X-CSRF-TOKEN'
+
+
+@pytest.yield_fixture
+def create_app(init_app):
+ def go(loop, error_renderer):
+ @aiohttp_csrf.csrf_protect
+ async def handler_get(request):
+ await aiohttp_csrf.generate_token(request)
+
+ return web.Response(body=b'OK')
+
+ @aiohttp_csrf.csrf_protect(error_renderer=error_renderer)
+ async def handler_post(request):
+ return web.Response(body=b'OK')
+
+ handlers = [
+ ('GET', '/', handler_get),
+ ('POST', '/', handler_post)
+ ]
+
+ storage = aiohttp_csrf.storage.CookieStorage(COOKIE_NAME)
+ policy = aiohttp_csrf.policy.HeaderPolicy(HEADER_NAME)
+
+ app = init_app(
+ policy=policy,
+ storage=storage,
+ handlers=handlers,
+ loop=loop,
+ )
+
+ return app
+
+ yield go
+
+
+async def test_custom_exception_error_renderer(test_client, create_app):
+ client = await test_client(
+ create_app,
+ error_renderer=web.HTTPBadRequest,
+ )
+
+ await client.get('/')
+
+ resp = await client.post('/')
+
+ assert resp.status == web.HTTPBadRequest.status_code
+
+
+@pytest.fixture(params=[False, True])
+def make_error_renderer(request):
+ is_coroutine = request.param
+
+ def make_renderer(error_body):
+ def error_renderer(request):
+ return web.Response(body=error_body)
+
+ if not is_coroutine:
+ return error_renderer
+
+ return asyncio.coroutine(error_renderer)
+
+ return make_renderer
+
+
+async def test_custom_coroutine_callable_error_renderer(test_client, create_app, make_error_renderer): # noqa
+ error_body = b'CSRF error'
+
+ error_renderer = make_error_renderer(error_body)
+
+ client = await test_client(
+ create_app,
+ error_renderer=error_renderer,
+ )
+
+ await client.get('/')
+
+ resp = await client.post('/')
+
+ assert resp.status == 200
+
+ assert await resp.read() == error_body
+
+
+async def test_bad_error_renderer(test_client, create_app):
+ error_renderer = 'trololo'
+
+ with pytest.raises(TypeError):
+ await test_client(
+ create_app,
+ error_renderer=error_renderer,
+ )
diff --git a/tests/test_errors.py b/tests/test_errors.py
new file mode 100644
index 0000000..0d7cf36
--- /dev/null
+++ b/tests/test_errors.py
@@ -0,0 +1,77 @@
+import aiohttp_csrf
+import pytest
+from aiohttp import web
+
+COOKIE_NAME = 'csrf_token'
+HEADER_NAME = 'X-CSRF-TOKEN'
+
+
+class FakeClass:
+ pass
+
+
+async def test_bad_policy(test_client, init_app):
+ policy = FakeClass()
+ storage = aiohttp_csrf.storage.CookieStorage(COOKIE_NAME)
+
+ with pytest.raises(TypeError):
+ await test_client(
+ init_app,
+ policy=policy,
+ storage=storage,
+ handlers=[],
+ )
+
+
+async def test_bad_storage(test_client, init_app):
+ policy = aiohttp_csrf.policy.HeaderPolicy(HEADER_NAME)
+ storage = FakeClass()
+
+ with pytest.raises(TypeError):
+ await test_client(
+ init_app,
+ policy=policy,
+ storage=storage,
+ handlers=[],
+ )
+
+
+async def test_bad_error_renderer(test_client, init_app):
+ policy = aiohttp_csrf.policy.HeaderPolicy(HEADER_NAME)
+ storage = aiohttp_csrf.storage.CookieStorage(COOKIE_NAME)
+
+ with pytest.raises(TypeError):
+ await test_client(
+ init_app,
+ policy=policy,
+ storage=storage,
+ error_renderer=1,
+ handlers=[],
+ )
+
+
+async def test_app_without_setup(test_client):
+ def create_app(loop):
+ app = web.Application()
+
+ @aiohttp_csrf.csrf_protect
+ async def handler(request):
+ await aiohttp_csrf.generate_token(request)
+
+ return web.Response()
+
+ app.router.add_route(
+ 'GET',
+ '/',
+ handler,
+ )
+
+ return app
+
+ client = await test_client(
+ create_app,
+ )
+
+ resp = await client.get('/')
+
+ assert resp.status == 500
diff --git a/tests/test_exempt_decorator.py b/tests/test_exempt_decorator.py
new file mode 100644
index 0000000..6802827
--- /dev/null
+++ b/tests/test_exempt_decorator.py
@@ -0,0 +1,55 @@
+import aiohttp_csrf
+import pytest
+from aiohttp import web
+
+COOKIE_NAME = 'csrf_token'
+HEADER_NAME = 'X-CSRF-TOKEN'
+
+
+@pytest.yield_fixture
+def create_app(init_app):
+ def go(loop):
+ async def handler_get(request):
+ await aiohttp_csrf.generate_token(request)
+
+ return web.Response(body=b'OK')
+
+ @aiohttp_csrf.csrf_exempt
+ async def handler_post(request):
+ return web.Response(body=b'OK')
+
+ handlers = [
+ ('GET', '/', handler_get),
+ ('POST', '/', handler_post),
+ ]
+
+ policy = aiohttp_csrf.policy.HeaderPolicy(HEADER_NAME)
+ storage = aiohttp_csrf.storage.CookieStorage(COOKIE_NAME)
+
+ app = init_app(
+ policy=policy,
+ storage=storage,
+ handlers=handlers,
+ loop=loop,
+ )
+
+ app.middlewares.append(aiohttp_csrf.csrf_middleware)
+
+ return app
+
+ yield go
+
+
+async def test_decorator_method_view(test_client, create_app):
+
+ client = await test_client(
+ create_app,
+ )
+
+ resp = await client.get('/')
+
+ assert resp.status == 200
+
+ resp = await client.post('/')
+
+ assert resp.status == 200
diff --git a/tests/test_form_policy.py b/tests/test_form_policy.py
new file mode 100644
index 0000000..11a9689
--- /dev/null
+++ b/tests/test_form_policy.py
@@ -0,0 +1,158 @@
+import re
+import uuid
+from unittest import mock
+
+import aiohttp_csrf
+import pytest
+from aiohttp import web
+from aiohttp_session import setup as setup_session
+from aiohttp_session import SimpleCookieStorage
+
+from .conftest import FORM_FIELD_NAME
+
+FORM_FIELD_REGEX = re.compile(
+ r'[^"]+)".*>',
+)
+
+
+@pytest.yield_fixture
+def create_app(init_app):
+ def go(loop, policy, storage):
+ async def handler_get(request):
+ token = await aiohttp_csrf.generate_token(request)
+
+ body = '''
+
+
+
+
+
+
+ ''' # noqa
+
+ body = body.format(field_name=FORM_FIELD_NAME, token=token)
+
+ return web.Response(body=body.encode('utf-8'))
+
+ async def handler_post(request):
+ return web.Response(body=b'OK')
+
+ handlers = [
+ ('GET', '/', handler_get),
+ ('POST', '/', handler_post)
+ ]
+
+ app = init_app(
+ policy=policy,
+ storage=storage,
+ handlers=handlers,
+ loop=loop,
+ )
+
+ if isinstance(storage, aiohttp_csrf.storage.SessionStorage):
+ session_storage = SimpleCookieStorage()
+ setup_session(app, session_storage)
+
+ app.middlewares.append(aiohttp_csrf.csrf_middleware)
+
+ return app
+
+ yield go
+
+
+async def test_form_policy_success(
+ test_client,
+ create_app,
+ csrf_form_policy,
+ csrf_storage,
+):
+ client = await test_client(
+ create_app,
+ policy=csrf_form_policy,
+ storage=csrf_storage,
+ )
+
+ resp = await client.get('/')
+
+ assert resp.status == 200
+
+ body = await resp.text()
+
+ search_result = FORM_FIELD_REGEX.search(body)
+
+ token = search_result.group('token')
+
+ data = {FORM_FIELD_NAME: token}
+
+ resp = await client.post('/', data=data)
+
+ assert resp.status == 200
+
+
+async def test_form_policy_bad_token(
+ test_client,
+ create_app,
+ csrf_form_policy,
+ csrf_storage,
+):
+ real_token = uuid.uuid4().hex
+
+ bad_token = real_token
+
+ while bad_token == real_token:
+ bad_token = uuid.uuid4().hex
+
+ with mock.patch(
+ 'aiohttp_csrf.token_generator.SimpleTokenGenerator.generate',
+ return_value=real_token,
+ ):
+ client = await test_client(
+ create_app,
+ policy=csrf_form_policy,
+ storage=csrf_storage,
+ )
+
+ resp = await client.get('/')
+
+ assert resp.status == 200
+
+ data = {FORM_FIELD_NAME: bad_token}
+
+ resp = await client.post('/', data=data)
+
+ assert resp.status == 403
+
+
+async def test_form_policy_reuse_token(
+ test_client,
+ create_app,
+ csrf_form_policy,
+ csrf_storage,
+):
+ client = await test_client(
+ create_app,
+ policy=csrf_form_policy,
+ storage=csrf_storage,
+ )
+
+ resp = await client.get('/')
+
+ assert resp.status == 200
+
+ body = await resp.text()
+
+ search_result = FORM_FIELD_REGEX.search(body)
+
+ token = search_result.group('token')
+
+ data = {FORM_FIELD_NAME: token}
+
+ resp = await client.post('/', data=data)
+
+ assert resp.status == 200
+
+ resp = await client.post('/', data=data)
+
+ assert resp.status == 403
diff --git a/tests/test_header_policy.py b/tests/test_header_policy.py
new file mode 100644
index 0000000..66b9bbf
--- /dev/null
+++ b/tests/test_header_policy.py
@@ -0,0 +1,111 @@
+import uuid
+from unittest import mock
+
+import aiohttp_csrf
+import pytest
+from aiohttp import web
+
+from .conftest import COOKIE_NAME, HEADER_NAME
+
+
+@pytest.yield_fixture
+def create_app(init_app):
+ def go(loop, policy):
+ async def handler_get(request):
+ await aiohttp_csrf.generate_token(request)
+
+ return web.Response(body=b'OK')
+
+ async def handler_post(request):
+ return web.Response(body=b'OK')
+
+ handlers = [
+ ('GET', '/', handler_get),
+ ('POST', '/', handler_post)
+ ]
+
+ storage = aiohttp_csrf.storage.CookieStorage(COOKIE_NAME)
+
+ app = init_app(
+ policy=policy,
+ storage=storage,
+ handlers=handlers,
+ loop=loop,
+ )
+
+ app.middlewares.append(aiohttp_csrf.csrf_middleware)
+
+ return app
+
+ yield go
+
+
+async def test_header_policy_success(test_client, create_app, csrf_header_policy): # noqa
+ client = await test_client(
+ create_app,
+ policy=csrf_header_policy,
+ )
+
+ resp = await client.get('/')
+
+ assert resp.status == 200
+
+ token = resp.cookies[COOKIE_NAME].value
+
+ headers = {HEADER_NAME: token}
+
+ resp = await client.post('/', headers=headers)
+
+ assert resp.status == 200
+
+
+async def test_header_policy_bad_token(test_client, create_app, csrf_header_policy): # noqa
+ real_token = uuid.uuid4().hex
+
+ bad_token = real_token
+
+ while bad_token == real_token:
+ bad_token = uuid.uuid4().hex
+
+ with mock.patch(
+ 'aiohttp_csrf.token_generator.SimpleTokenGenerator.generate',
+ return_value=real_token,
+ ):
+
+ client = await test_client(
+ create_app,
+ policy=csrf_header_policy,
+ )
+
+ resp = await client.get('/')
+
+ assert resp.status == 200
+
+ headers = {HEADER_NAME: bad_token}
+
+ resp = await client.post('/', headers=headers)
+
+ assert resp.status == 403
+
+
+async def test_header_policy_reuse_token(test_client, create_app, csrf_header_policy): # noqa
+ client = await test_client(
+ create_app,
+ policy=csrf_header_policy,
+ )
+
+ resp = await client.get('/')
+
+ assert resp.status == 200
+
+ token = resp.cookies[COOKIE_NAME].value
+
+ headers = {HEADER_NAME: token}
+
+ resp = await client.post('/', headers=headers)
+
+ assert resp.status == 200
+
+ resp = await client.post('/', headers=headers)
+
+ assert resp.status == 403
diff --git a/tests/test_protect_decorator.py b/tests/test_protect_decorator.py
new file mode 100644
index 0000000..94976f6
--- /dev/null
+++ b/tests/test_protect_decorator.py
@@ -0,0 +1,136 @@
+import aiohttp_csrf
+from aiohttp import web
+
+COOKIE_NAME = 'csrf_token'
+HEADER_NAME = 'X-CSRF-TOKEN'
+
+
+async def test_decorator_method_view(test_client, init_app):
+ @aiohttp_csrf.csrf_protect
+ async def handler_get(request):
+ await aiohttp_csrf.generate_token(request)
+
+ return web.Response(body=b'OK')
+
+ @aiohttp_csrf.csrf_protect
+ async def handler_post(request):
+ return web.Response(body=b'OK')
+
+ handlers = [
+ ('GET', '/', handler_get),
+ ('POST', '/', handler_post)
+ ]
+
+ policy = aiohttp_csrf.policy.HeaderPolicy(HEADER_NAME)
+ storage = aiohttp_csrf.storage.CookieStorage(COOKIE_NAME)
+
+ client = await test_client(
+ init_app,
+ policy=policy,
+ storage=storage,
+ handlers=handlers,
+ )
+
+ resp = await client.get('/')
+
+ assert resp.status == 200
+
+ token = resp.cookies[COOKIE_NAME].value
+
+ headers = {HEADER_NAME: token}
+
+ resp = await client.post('/', headers=headers)
+
+ assert resp.status == 200
+
+ resp = await client.post('/', headers=headers)
+
+ assert resp.status == 403
+
+
+async def test_decorator_class_view(test_client):
+ class TestView(web.View):
+ @aiohttp_csrf.csrf_protect
+ async def get(self):
+ await aiohttp_csrf.generate_token(self.request)
+
+ return web.Response(body=b'OK')
+
+ @aiohttp_csrf.csrf_protect
+ async def post(self):
+ return web.Response(body=b'OK')
+
+ def create_app(loop):
+ policy = aiohttp_csrf.policy.HeaderPolicy(HEADER_NAME)
+ storage = aiohttp_csrf.storage.CookieStorage(COOKIE_NAME)
+
+ app = web.Application()
+
+ aiohttp_csrf.setup(app, policy=policy, storage=storage)
+
+ if hasattr(app.router, 'add_view'):
+ # For aiohttp >= 3.0.0
+ app.router.add_view('/', TestView)
+ else:
+ app.router.add_route('*', '/', TestView)
+
+ return app
+
+ client = await test_client(
+ create_app,
+ )
+
+ resp = await client.get('/')
+
+ assert resp.status == 200
+
+ token = resp.cookies[COOKIE_NAME].value
+
+ headers = {HEADER_NAME: token}
+
+ resp = await client.post('/', headers=headers)
+
+ assert resp.status == 200
+
+ resp = await client.post('/', headers=headers)
+
+ assert resp.status == 403
+
+
+async def test_handle_http_exceptions(test_client, init_app):
+ @aiohttp_csrf.csrf_protect
+ async def handler_get(request):
+ await aiohttp_csrf.generate_token(request)
+
+ return web.Response(body=b'OK')
+
+ @aiohttp_csrf.csrf_protect
+ async def handler_post(request):
+ raise web.HTTPBadRequest
+
+ handlers = [
+ ('GET', '/', handler_get),
+ ('POST', '/', handler_post)
+ ]
+
+ policy = aiohttp_csrf.policy.HeaderPolicy(HEADER_NAME)
+ storage = aiohttp_csrf.storage.CookieStorage(COOKIE_NAME)
+
+ client = await test_client(
+ init_app,
+ policy=policy,
+ storage=storage,
+ handlers=handlers,
+ )
+
+ resp = await client.get('/')
+
+ assert resp.status == 200
+
+ token = resp.cookies[COOKIE_NAME].value
+
+ headers = {HEADER_NAME: token}
+
+ resp = await client.post('/', headers=headers)
+
+ assert resp.status == 400
diff --git a/tests/test_storage_api.py b/tests/test_storage_api.py
new file mode 100644
index 0000000..519dab8
--- /dev/null
+++ b/tests/test_storage_api.py
@@ -0,0 +1,65 @@
+from unittest.mock import MagicMock
+
+import aiohttp_csrf
+import pytest
+from aiohttp.test_utils import make_mocked_request
+
+
+class FakeStorage(aiohttp_csrf.storage.BaseStorage):
+
+ async def _get(self, request):
+ return request.get('my_field')
+
+ async def _save_token(self, request, response, token):
+ request['my_field'] = token
+
+
+async def test_1():
+ storage = FakeStorage()
+
+ storage._generate_token = MagicMock(return_value='1')
+ storage._get = MagicMock(return_value='1')
+ storage._save = MagicMock()
+
+ assert storage._generate_token.call_count == 0
+
+ request = make_mocked_request('/', 'GET')
+
+ await storage.generate_new_token(request)
+
+ assert storage._generate_token.call_count == 1
+
+ await storage.generate_new_token(request)
+ await storage.generate_new_token(request)
+
+ assert storage._generate_token.call_count == 1
+
+
+async def test_2():
+ storage = FakeStorage()
+
+ storage._generate_token = MagicMock(return_value='1')
+
+ request = make_mocked_request('/', 'GET')
+
+ assert storage._generate_token.call_count == 0
+
+ await storage.save_token(request, None)
+
+ assert storage._generate_token.call_count == 1
+
+ request2 = make_mocked_request('/', 'GET')
+
+ request2['my_field'] = 1
+
+ await storage.save_token(request2, None)
+
+
+async def test_3():
+ class Some:
+ pass
+
+ token_generator = Some()
+
+ with pytest.raises(TypeError):
+ FakeStorage(token_generator=token_generator)
diff --git a/tests/test_token_generator.py b/tests/test_token_generator.py
new file mode 100644
index 0000000..e4ac6b0
--- /dev/null
+++ b/tests/test_token_generator.py
@@ -0,0 +1,37 @@
+import hashlib
+import uuid
+from unittest import mock
+
+import aiohttp_csrf
+
+COOKIE_NAME = 'csrf_token'
+HEADER_NAME = 'X-CSRF-TOKEN'
+
+
+def test_simple_token_generator():
+ token_generator = aiohttp_csrf.token_generator.SimpleTokenGenerator()
+
+ u = uuid.uuid4()
+
+ with mock.patch('uuid.uuid4', return_value=u):
+ token = token_generator.generate()
+
+ assert u.hex == token
+
+
+def test_hashed_token_generator():
+ encoding = aiohttp_csrf.token_generator.HashedTokenGenerator.encoding
+
+ token_generator = aiohttp_csrf.token_generator.HashedTokenGenerator(
+ 'secret',
+ )
+
+ u = uuid.uuid4()
+ token_string = u.hex + 'secret'
+
+ hasher = hashlib.sha256(token_string.encode(encoding=encoding))
+
+ with mock.patch('hashlib.sha256', return_value=hasher):
+ token = token_generator.generate()
+
+ assert token == hasher.hexdigest()
diff --git a/tox.ini b/tox.ini
new file mode 100644
index 0000000..6e8ebe8
--- /dev/null
+++ b/tox.ini
@@ -0,0 +1,23 @@
+[tox]
+envlist =
+ py3{5,6}
+skip_missing_interpreters = True
+skipsdist = True
+
+[testenv]
+deps =
+ -r{toxinidir}/requirements_dev.txt
+commands =
+ flake8 --show-source aiohttp_csrf
+ isort --check-only -rc aiohttp_csrf --diff
+
+ flake8 --show-source demo
+ isort --check-only -rc demo --diff
+
+ flake8 --show-source tests
+ isort --check-only -rc tests --diff
+
+ flake8 --show-source setup.py
+ isort --check-only setup.py --diff
+
+ pytest tests