Skip to content

Commit cb210b8

Browse files
committed
SIANXSVC-826: Added direct-reply result backend
1 parent c617b06 commit cb210b8

File tree

7 files changed

+320
-8
lines changed

7 files changed

+320
-8
lines changed

.github/workflows/test.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@ jobs:
1313
fail-fast: false
1414
matrix:
1515
python-version:
16-
- "3.6"
1716
- "3.7"
1817
- "3.8"
1918
- "3.9"
2019
- "3.10"
2120
celery-version:
2221
- "5.0"
2322
- "5.1"
23+
- "5.2"
2424

2525
steps:
2626
- uses: actions/checkout@v2
@@ -51,7 +51,8 @@ jobs:
5151
ln -s ./tests/test_project/manage.py manage.py
5252
5353
# run tests with coverage
54-
coverage run --source='./celery_amqp_backend' manage.py test
54+
coverage run --append --source='./celery_amqp_backend' manage.py test --settings=test_project.settings.backend
55+
coverage run --append --source='./celery_amqp_backend' manage.py test --settings=test_project.settings.direct_reply_backend
5556
coverage xml
5657
5758
- name: Upload coverage to Codecov

celery_amqp_backend/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
from .exceptions import *
21
from .backend import *
2+
from .exceptions import *

celery_amqp_backend/backend.py

Lines changed: 286 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
import collections
2-
import kombu
32
import socket
43

4+
import kombu
5+
6+
from celery import exceptions
57
from celery import states
68
from celery.backends import base
79

810
from .exceptions import *
911

10-
1112
__all__ = [
1213
'AMQPBackend',
14+
'DirectReplyAMQPBackend',
1315
]
1416

1517

@@ -265,7 +267,7 @@ def get_task_meta(self, task_id, backlog_limit=1000):
265267
else:
266268
raise self.BacklogLimitExceededException(task=task_id)
267269

268-
# If we got a latest task result from the queue, we store this message to the local cache, send the task
270+
# If we got the latest task result from the queue, we store this message to the local cache, send the task
269271
# result message back to the queue, and return it. Else, we try to get the task result from the local
270272
# cache, and assume that the task result is pending if it is not present on the cache.
271273
if latest:
@@ -379,3 +381,284 @@ def __reduce__(self, args=(), kwargs=None):
379381
expires=self.expires,
380382
)
381383
return super().__reduce__(args, kwargs)
384+
385+
386+
class DirectReplyAMQPBackend(base.BaseBackend):
    """
    Celery result backend that uses RabbitMQ's direct-reply-to functionality for results.

    Instead of declaring one result queue per task, the requesting client consumes
    from the pseudo-queue ``amq.rabbitmq.reply-to`` on the channel that published
    the task request, and the worker publishes the result to the request's
    ``reply_to`` address. Group/chord bookkeeping is not supported by this backend.
    """
    READY_STATES = states.READY_STATES
    PROPAGATE_STATES = states.PROPAGATE_STATES

    # Factory classes; subclasses may override these to customize entities.
    Exchange = kombu.Exchange
    Consumer = kombu.Consumer
    Producer = kombu.Producer
    Queue = kombu.Queue

    BacklogLimitExceededException = AMQPBacklogLimitExceededException
    WaitEmptyException = AMQPWaitEmptyException
    WaitTimeoutException = AMQPWaitTimeoutException

    persistent = True
    supports_autoexpire = True
    supports_native_join = True

    # Retry policy used when publishing result messages to the broker.
    retry_policy = {
        'max_retries': 20,
        'interval_start': 0,
        'interval_step': 1,
        'interval_max': 1,
    }

    def __init__(self, app, serializer=None, **kwargs):
        """
        :param app: The celery application instance
        :param serializer: Serializer for result messages (defaults to the
            application's ``result_serializer`` setting)
        :param kwargs: Passed through to the base backend
        """
        super().__init__(app, **kwargs)

        conf = self.app.conf

        # Direct-reply results are inherently transient: they only exist while
        # the requesting client's channel is open, so force non-persistent
        # delivery regardless of the class-level default.
        self.persistent = False
        self.delivery_mode = 1
        self.result_exchange = ''
        self.result_exchange_type = 'direct'
        self.exchange = self._create_exchange(
            self.result_exchange,
            self.result_exchange_type,
            self.delivery_mode,
        )
        self.serializer = serializer or conf.result_serializer

        # Maps task identifiers to the consumer on the channel their request
        # was published on; several tasks may share one consumer.
        self._consumers = {}
        # Holds results that arrived while draining for a different task.
        self._cache = kombu.utils.functional.LRUCache(limit=10000)

    def reload_task_result(self, task_id):
        raise NotImplementedError('reload_task_result is not supported by this backend.')

    def reload_group_result(self, task_id):
        raise NotImplementedError('reload_group_result is not supported by this backend.')

    def save_group(self, group_id, result):
        raise NotImplementedError('save_group is not supported by this backend.')

    def restore_group(self, group_id, cache=True):
        raise NotImplementedError('restore_group is not supported by this backend.')

    def delete_group(self, group_id):
        raise NotImplementedError('delete_group is not supported by this backend.')

    def add_to_chord(self, chord_id, result):
        raise NotImplementedError('add_to_chord is not supported by this backend.')

    def get_many(self, task_ids, timeout=None, on_interval=None, **kwargs):
        """
        Waits for the given tasks and yields their results as they become ready.

        :param task_ids: Iterable of task identifiers to wait for
        :param timeout: Overall timeout in seconds (``None`` waits indefinitely)
        :param on_interval: Callback function for message poll intervals
        :param kwargs:
        :return: Generator of ``(task_id, meta)`` tuples
        :raises WaitTimeoutException: If the timeout elapses before all results arrive
        """
        interval = 0.25
        iterations = 0
        task_ids = task_ids if isinstance(task_ids, set) else set(task_ids)

        while task_ids:
            yielded_task_ids = set()

            for task_id in task_ids:
                meta = self.wait_for(
                    task_id,
                    timeout=timeout,
                    interval=interval,
                    on_interval=on_interval,
                    no_ack=True,
                )

                if meta['status'] in self.READY_STATES:
                    yielded_task_ids.add(task_id)
                    yield task_id, meta

            if timeout and iterations * interval >= timeout:
                raise self.WaitTimeoutException()

            iterations += 1

            # Only keep waiting for tasks that have not produced a result yet.
            task_ids.difference_update(yielded_task_ids)

    def wait_for(self, task_id, timeout=None, interval=0.5, on_interval=None, no_ack=True):
        """
        Waits for a task and returns the result.

        :param task_id: The task identifier we want the result for
        :param timeout: Consumer read timeout
        :param interval: Interval to drain messages from the queue
        :param on_interval: Callback function for message poll intervals
        :param no_ack: If enabled the messages are automatically acknowledged by the broker
        :return: Task result body as dict
        :raises WaitTimeoutException: If no result arrives within the timeout
        """
        try:
            return super().wait_for(
                task_id,
                timeout=timeout,
                interval=interval,
                no_ack=no_ack,
                on_interval=on_interval
            )
        except exceptions.TimeoutError:
            # The result can no longer be consumed for this task, so drop the
            # consumer registration before reporting the timeout.
            self._cancel_consumer(task_id)

            raise self.WaitTimeoutException()

    def _cancel_consumer(self, task_id):
        """
        Removes the consumer registered for the given task and cancels it if it
        is not shared with any other pending task.

        :param task_id: The task identifier
        """
        consumer = self._consumers.pop(task_id, None)
        if consumer and consumer not in self._consumers.values():
            consumer.cancel()

    def get_task_meta(self, task_id, backlog_limit=1000):
        """
        Gets the task result meta for the given task, either from the local
        cache or by draining the direct-reply consumer.

        :param task_id: The task identifier
        :param backlog_limit: Unused; kept for interface compatibility
        :return: Task result meta as dict (a PENDING meta if no result arrived yet)
        """
        def _on_message_callback(message):
            nonlocal meta
            payload = message.decode()

            # Ignore messages that do not look like task result payloads.
            if not isinstance(payload, dict) or 'task_id' not in payload:
                return

            if task_id == payload['task_id']:
                meta = payload
            else:
                # Result for another task on the same channel; cache it so a
                # later call for that task can pick it up.
                self._cache[payload['task_id']] = payload

        meta = self._cache.pop(task_id, None)

        if meta is not None:
            # Decode cached payloads too, so exception results are
            # deserialized the same way as freshly drained ones.
            return self.meta_from_decoded(meta)

        consumer = self._consumers.get(task_id)

        # Without a registered consumer the result can never arrive, as the
        # direct-reply queue only exists on the requesting channel.
        if not consumer:
            return {
                'status': states.FAILURE,
                'result': None,
            }

        consumer.on_message = _on_message_callback
        consumer.consume()

        try:
            consumer.connection.drain_events(timeout=0.5)
        except socket.timeout:
            pass

        if meta:
            self._cancel_consumer(task_id)
            return self.meta_from_decoded(meta)
        else:
            return {
                'status': states.PENDING,
                'result': None,
            }

    def store_result(self, task_id, result, state, traceback=None, request=None, **kwargs):
        """
        Sends the task result for the given task identifier to the task result queue and returns the sent result dict.

        :param task_id: Task identifier to send the result for
        :param result: The task result as dict
        :param state: The task result state
        :param traceback: The traceback if the task resulted in an exception
        :param request: Request data
        :param kwargs:
        :return: The task result as dict
        """
        # Determine the routing key and a potential correlation identifier.
        routing_key = self._create_routing_key(task_id, request)
        correlation_id = self._create_correlation_id(task_id, request)

        with self.app.amqp.producer_pool.acquire(block=True) as producer:
            producer.publish(
                {
                    'task_id': task_id,
                    'status': state,
                    'result': self.encode_result(result, state),
                    'traceback': traceback,
                    'children': self.current_task_children(request),
                },
                # The default exchange routes directly to the reply-to queue.
                exchange='',
                routing_key=routing_key,
                correlation_id=correlation_id,
                serializer=self.serializer,
                retry=True,
                retry_policy=self.retry_policy,
                delivery_mode=self.delivery_mode,
            )

        return result

    def on_task_call(self, producer, task_id):
        """
        Creates and saves a consumer for the direct-reply pseudo-queue, before the task request is sent
        to the queue.

        :param producer: The producer for the task request
        :param task_id: The task identifier
        """
        # Reuse an existing consumer if one is already attached to the channel
        # the task request is published on; otherwise create a new one.
        for consumer in self._consumers.values():
            if consumer.channel is producer.channel:
                self._consumers[task_id] = consumer
                break
        else:
            self._consumers[task_id] = self._create_consumer(
                producer.channel,
            )

    def _create_consumer(self, channel):
        """
        Creates a consumer for the direct-reply pseudo-queue on the given channel.

        :param channel: The channel to use for the consumer
        :return: Created consumer
        """
        # Use the overridable factory classes; `no_ack` is mandatory for the
        # direct-reply pseudo-queue.
        consumer_queue = self.Queue("amq.rabbitmq.reply-to", no_ack=True)
        consumer = self.Consumer(
            channel,
            queues=[consumer_queue],
            auto_declare=True,
        )
        consumer.consume()

        return consumer

    def _create_exchange(self, name, exchange_type='direct', delivery_mode=2):
        """
        Creates an exchange with the given parameters.

        :param name: Name of the exchange as string
        :param exchange_type: Type of the exchange as string (e.g. 'direct', 'topic', …)
        :param delivery_mode: Exchange delivery mode as integer (1 for transient, 2 for persistent)
        :return: Created exchange
        """
        return self.Exchange(
            name=name,
            type=exchange_type,
            delivery_mode=delivery_mode,
            durable=self.persistent,
            auto_delete=False,
        )

    def _create_routing_key(self, task_id, request=None):
        """
        Creates a routing key from the given request or task identifier.

        :param task_id: Task identifier as string
        :param request: The task request object
        :return: Routing key as string
        """
        return request and request.reply_to or task_id

    def _create_correlation_id(self, task_id, request=None):
        """
        Creates a correlation identifier from the given request or task identifier.

        :param task_id: Task identifier as string
        :param request: The task request object
        :return: Correlation identifier as string
        """
        return request and request.correlation_id or task_id

    def __reduce__(self, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}
        kwargs.update(
            url=self.url,
            serializer=self.serializer,
        )
        return super().__reduce__(args, kwargs)

tests/test_project/manage.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
def main():
88
"""Run administrative tasks."""
9-
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'test_project.settings')
9+
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'test_project.settings.backend')
1010
try:
1111
from django.core.management import execute_from_command_line
1212
except ImportError as exc:
File renamed without changes.
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
from .backend import *


# Settings variant that switches the project to the direct-reply result
# backend; the CI workflow runs the test suite once per settings module.
CELERY_RESULT_BACKEND = 'celery_amqp_backend.DirectReplyAMQPBackend://'

0 commit comments

Comments
 (0)