Skip to content

Commit 75088c5

Browse files
committed
SIANXSVC-826: Added direct-reply result backend
1 parent c617b06 commit 75088c5

File tree

7 files changed

+297
-8
lines changed

7 files changed

+297
-8
lines changed

.github/workflows/test.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@ jobs:
1313
fail-fast: false
1414
matrix:
1515
python-version:
16-
- "3.6"
1716
- "3.7"
1817
- "3.8"
1918
- "3.9"
2019
- "3.10"
2120
celery-version:
2221
- "5.0"
2322
- "5.1"
23+
- "5.2"
2424

2525
steps:
2626
- uses: actions/checkout@v2
@@ -51,7 +51,8 @@ jobs:
5151
ln -s ./tests/test_project/manage.py manage.py
5252
5353
# run tests with coverage
54-
coverage run --source='./celery_amqp_backend' manage.py test
54+
coverage run --append --source='./celery_amqp_backend' manage.py test --settings=test_project.settings.backend
55+
coverage run --append --source='./celery_amqp_backend' manage.py test --settings=test_project.settings.direct_reply_backend
5556
coverage xml
5657
5758
- name: Upload coverage to Codecov

celery_amqp_backend/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
from .exceptions import *
21
from .backend import *
2+
from .exceptions import *

celery_amqp_backend/backend.py

Lines changed: 261 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
import collections
2-
import kombu
32
import socket
43

4+
import kombu
5+
6+
from celery import exceptions
57
from celery import states
68
from celery.backends import base
79

810
from .exceptions import *
911

10-
1112
__all__ = [
1213
'AMQPBackend',
14+
'DirectReplyAMQPBackend',
1315
]
1416

1517

@@ -265,7 +267,7 @@ def get_task_meta(self, task_id, backlog_limit=1000):
265267
else:
266268
raise self.BacklogLimitExceededException(task=task_id)
267269

268-
# If we got a latest task result from the queue, we store this message to the local cache, send the task
270+
# If we got the latest task result from the queue, we store this message to the local cache, send the task
269271
# result message back to the queue, and return it. Else, we try to get the task result from the local
270272
# cache, and assume that the task result is pending if it is not present on the cache.
271273
if latest:
@@ -379,3 +381,259 @@ def __reduce__(self, args=(), kwargs=None):
379381
expires=self.expires,
380382
)
381383
return super().__reduce__(args, kwargs)
384+
385+
386+
class DirectReplyAMQPBackend(base.BaseBackend):
387+
"""
388+
Celery result backend that uses RabbitMQ's direct-reply functionality for results.
389+
"""
390+
READY_STATES = states.READY_STATES
391+
PROPAGATE_STATES = states.PROPAGATE_STATES
392+
393+
Exchange = kombu.Exchange
394+
Consumer = kombu.Consumer
395+
Producer = kombu.Producer
396+
Queue = kombu.Queue
397+
398+
BacklogLimitExceededException = AMQPBacklogLimitExceededException
399+
WaitEmptyException = AMQPWaitEmptyException
400+
WaitTimeoutException = AMQPWaitTimeoutException
401+
402+
persistent = True
403+
supports_autoexpire = True
404+
supports_native_join = True
405+
406+
retry_policy = {
407+
'max_retries': 20,
408+
'interval_start': 0,
409+
'interval_step': 1,
410+
'interval_max': 1,
411+
}
412+
413+
def __init__(self, app, serializer=None, **kwargs):
414+
super().__init__(app, **kwargs)
415+
416+
conf = self.app.conf
417+
418+
self.persistent = False
419+
self.delivery_mode = 1
420+
self.result_exchange = ''
421+
self.result_exchange_type = 'direct'
422+
self.exchange = self._create_exchange(
423+
self.result_exchange,
424+
self.result_exchange_type,
425+
self.delivery_mode,
426+
)
427+
self.serializer = serializer or conf.result_serializer
428+
429+
self._consumers = {}
430+
self._cache = kombu.utils.functional.LRUCache(limit=10000)
431+
432+
def reload_task_result(self, task_id):
433+
raise NotImplementedError('reload_task_result is not supported by this backend.')
434+
435+
def reload_group_result(self, task_id):
436+
raise NotImplementedError('reload_group_result is not supported by this backend.')
437+
438+
def save_group(self, group_id, result):
439+
raise NotImplementedError('save_group is not supported by this backend.')
440+
441+
def restore_group(self, group_id, cache=True):
442+
raise NotImplementedError('restore_group is not supported by this backend.')
443+
444+
def delete_group(self, group_id):
445+
raise NotImplementedError('delete_group is not supported by this backend.')
446+
447+
def add_to_chord(self, chord_id, result):
448+
raise NotImplementedError('add_to_chord is not supported by this backend.')
449+
450+
def get_many(self, task_ids, **kwargs):
451+
raise NotImplementedError('get_many is not supported by this backend.')
452+
453+
def wait_for(self, task_id, timeout=None, interval=0.5, on_interval=None, no_ack=True):
454+
"""
455+
Waits for task and returns the result.
456+
457+
:param task_id: The task identifiers we want the result for
458+
:param timeout: Consumer read timeout
459+
:param no_ack: If enabled the messages are automatically acknowledged by the broker
460+
:param interval: Interval to drain messages from the queue
461+
:param on_interval: Callback function for message poll intervals
462+
:param kwargs:
463+
:return: Task result body as dict
464+
"""
465+
try:
466+
return super().wait_for(
467+
task_id,
468+
timeout=timeout,
469+
interval=interval,
470+
no_ack=no_ack,
471+
on_interval=on_interval
472+
)
473+
except exceptions.TimeoutError:
474+
consumer = self._consumers.pop(task_id, None)
475+
if consumer and consumer not in self._consumers.values():
476+
consumer.cancel()
477+
478+
raise self.WaitTimeoutException()
479+
480+
def get_task_meta(self, task_id, backlog_limit=1000):
481+
def _on_message_callback(message):
482+
nonlocal meta, task_id
483+
payload = message.decode()
484+
485+
if not isinstance(payload, (dict,)) or 'task_id' not in payload:
486+
return
487+
488+
if task_id == payload['task_id']:
489+
meta = payload
490+
else:
491+
self._cache[payload['task_id']] = payload
492+
493+
meta = self._cache.pop(task_id, None)
494+
495+
if meta is not None:
496+
return meta
497+
498+
consumer = self._consumers.get(task_id)
499+
500+
if not consumer:
501+
return {
502+
'status': states.FAILURE,
503+
'result': None,
504+
}
505+
506+
consumer.on_message = _on_message_callback
507+
consumer.consume()
508+
509+
try:
510+
consumer.connection.drain_events(timeout=0.5)
511+
except socket.timeout:
512+
pass
513+
514+
if meta:
515+
consumer = self._consumers.pop(task_id, None)
516+
if consumer and consumer not in self._consumers.values():
517+
consumer.cancel()
518+
519+
return self.meta_from_decoded(meta)
520+
else:
521+
return {
522+
'status': states.PENDING,
523+
'result': None,
524+
}
525+
526+
def store_result(self, task_id, result, state, traceback=None, request=None, **kwargs):
527+
"""
528+
Sends the task result for the given task identifier to the task result queue and returns the sent result dict.
529+
530+
:param task_id: Task identifier to send the result for
531+
:param result: The task result as dict
532+
:param state: The task result state
533+
:param traceback: The traceback if the task resulted in an exception
534+
:param request: Request data
535+
:param kwargs:
536+
:return: The task result as dict
537+
"""
538+
# Determine the routing key and a potential correlation identifier.
539+
routing_key = self._create_routing_key(task_id, request)
540+
correlation_id = self._create_correlation_id(task_id, request)
541+
542+
with self.app.amqp.producer_pool.acquire(block=True) as producer:
543+
producer.publish(
544+
{
545+
'task_id': task_id,
546+
'status': state,
547+
'result': self.encode_result(result, state),
548+
'traceback': traceback,
549+
'children': self.current_task_children(request),
550+
},
551+
exchange='',
552+
routing_key=routing_key,
553+
correlation_id=correlation_id,
554+
serializer=self.serializer,
555+
retry=True,
556+
retry_policy=self.retry_policy,
557+
delivery_mode=self.delivery_mode,
558+
)
559+
560+
return result
561+
562+
def on_task_call(self, producer, task_id):
563+
"""
564+
Creates and saves a consumer for the direct-reply pseudo-queue, before the task request is sent
565+
to the queue.
566+
567+
:param producer: The producer for the task request
568+
:param task_id: The task identifier
569+
"""
570+
for _, consumer in self._consumers.items():
571+
if consumer.channel is producer.channel:
572+
self._consumers[task_id] = consumer
573+
break
574+
else:
575+
self._consumers[task_id] = self._create_consumer(
576+
producer.channel,
577+
)
578+
579+
def _create_consumer(self, channel):
580+
"""
581+
Creates a consumer with the given parameters.
582+
583+
:param channel: The channel to use for the consumer
584+
:return: Created consumer
585+
"""
586+
consumer_queue = kombu.Queue("amq.rabbitmq.reply-to", no_ack=True)
587+
consumer = kombu.Consumer(
588+
channel,
589+
queues=[consumer_queue],
590+
auto_declare=True,
591+
)
592+
consumer.consume()
593+
594+
return consumer
595+
596+
def _create_exchange(self, name, exchange_type='direct', delivery_mode=2):
597+
"""
598+
Creates an exchange with the given parameters.
599+
600+
:param name: Name of the exchange as string
601+
:param exchange_type: Type of the exchange as string (e.g. 'direct', 'topic', …)
602+
:param delivery_mode: Exchange delivery mode as integer (1 for transient, 2 for persistent)
603+
:return: Created exchange
604+
"""
605+
return self.Exchange(
606+
name=name,
607+
type=exchange_type,
608+
delivery_mode=delivery_mode,
609+
durable=self.persistent,
610+
auto_delete=False,
611+
)
612+
613+
def _create_routing_key(self, task_id, request=None):
614+
"""
615+
Creates a routing key from the given request or task identifier.
616+
617+
:param task_id: Task identifier as string
618+
:param request: The task request object
619+
:return: Routing key as string
620+
"""
621+
return request and request.reply_to or task_id
622+
623+
def _create_correlation_id(self, task_id, request=None):
624+
"""
625+
Creates a correlation identifier from the given task identifier.
626+
627+
:param task_id: Task identifier as string
628+
:param request: The task request object
629+
:return: Routing key as string
630+
"""
631+
return request and request.correlation_id or task_id
632+
633+
def __reduce__(self, args=(), kwargs=None):
634+
kwargs = kwargs if kwargs else {}
635+
kwargs.update(
636+
url=self.url,
637+
serializer=self.serializer,
638+
)
639+
return super().__reduce__(args, kwargs)

tests/test_project/manage.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
def main():
88
"""Run administrative tasks."""
9-
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'test_project.settings')
9+
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'test_project.settings.backend')
1010
try:
1111
from django.core.management import execute_from_command_line
1212
except ImportError as exc:
File renamed without changes.
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .backend import *
2+
3+
4+
CELERY_RESULT_BACKEND = 'celery_amqp_backend.DirectReplyAMQPBackend://'

0 commit comments

Comments
 (0)