Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions config/start_celery.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#!/bin/bash

cd /ars/tr_sys
celery -A tr_sys beat -l info --detach
celery -A tr_sys worker -l info
# Start beat with log file
celery -A tr_sys beat -l info -f /var/log/celerybeat.log &
# Start worker in foreground
celery -A tr_sys worker -l info
63 changes: 61 additions & 2 deletions tr_sys/tr_ars/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from django.shortcuts import redirect, get_object_or_404
from django.urls import path, re_path, include, reverse
from django.utils import timezone
from pika.spec import NOT_FOUND
from tr_ars import utils
from tr_ars import tasks
from utils2 import urlRemoteFromInforesid
Expand Down Expand Up @@ -1195,6 +1196,7 @@ def query_event_unsubscribe(req=None, key=None):
return HttpResponse(json.dumps(response), status=status)

except Exception as e:
response={}
logger.error("Unexpected error at unsubscribe endpoint: {}".format(traceback.format_exception(type(e), e, e.__traceback__)))
logger.error(str(e.with_traceback()))
response['message']=str(e.with_traceback())
Expand All @@ -1220,6 +1222,63 @@ def query_event_unsubscribe(req=None, key=None):
mesg.clients.remove(subscriber_client)
except Exception as e:
logger.error("Error during auto-unsubscribing pk %s" % key)

@csrf_exempt
def get_status(req=None, key=None):
status_map={
'D': 'Done',
'S': 'Stopped',
'R': 'Running',
'E': 'Error',
'W': 'Waiting',
'U': 'Unknown'
}
response={}
if req is not None:
if req.method=='POST':
try:
body = json.loads(req.body)
pks = body["pks"]
# Make dict keys strings so lookup matches JSON pk strings
QuerySet_tuple = Message.objects.filter(pk__in=pks).values_list("pk", "status","merged_versions_list", 'params')
resultMap = { str(pk): (status, merged, params) for pk, status, merged, params in QuerySet_tuple }
response = []
for pk in pks:
key =str(pk)
if key in resultMap:
status, merged, params = resultMap[key]
response.append({'pk':key,
'status':status_map[status],
'merged_list':merged,
'stats': params['stats'] if 'stats' in params else None})
else:
response.append({'pk':key,
'status':None,
'merged_list':None,
'stats':None}
)
return JsonResponse(response,safe=False, status=200)

except Exception as e:
logger.error("Unexpected error at get notification status endpoint: {}".format(traceback.format_exception(type(e), e, e.__traceback__)))
response['message']=str(e.with_traceback())
response['timestamp']= timezone.now().isoformat()
return JsonResponse("messages", status =405)

else:
return HttpResponse('Only POST is permitted!', status=405)
else:
try:
logging.info("getting the status for the message %s"% key)
mesg = get_object_or_404(Message, pk=key)
response['status']=mesg.status
return JsonResponse(response, status=200)
except Exception as e:
logger.error(f"Message with ID {key} does not exist")
response['message']=str(e.with_traceback())
response['timestamp']= timezone.now().isoformat()
return JsonResponse(response, status =405)

@csrf_exempt
def health(req):
if req.method == 'GET':
Expand Down Expand Up @@ -1268,8 +1327,8 @@ def health(req):
re_path(r'^query_event_subscribe/?$', query_event_subscribe, name='ars-subscribe'),
re_path(r'^query_event_unsubscribe/?$', query_event_unsubscribe, name='ars-unsubscribe'),
path('post_process/<uuid:key>', post_process, name='ars-post_process_debug'),
re_path(r'^health/?$', health, name='ars-health')

re_path(r'^health/?$', health, name='ars-health'),
re_path(r'^get_status/?$', get_status, name='ars-status')
]

urlpatterns = [
Expand Down
4 changes: 2 additions & 2 deletions tr_sys/tr_ars/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def save(self, *args, **kwargs):
self.original_data = {} # Clear original data to avoid redundancy

super().save(*args, **kwargs)
if self.should_notify():
if self.should_notify() and self.ref is None:
self.notify_subscribers()

def save_compressed_dict(self, data):
Expand Down Expand Up @@ -254,7 +254,7 @@ def notify_subscribers(self, additional_notification_fields=None):
from .tasks import notify_subscribers_task
if self.status == 'D':
additional_notification_fields = {
"event_type":"admin",
"event_type": "admin",
"complete" : True
}
if self.status == 'E':
Expand Down
15 changes: 12 additions & 3 deletions tr_sys/tr_ars/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sys, logging
from .models import Actor, Agent, Message, Channel
from .pubsub import send_messages
from .utils import get_safe, createMessage
from .utils import createMessage
from django.utils import timezone
logger = logging.getLogger(__name__)
from .api import query_event_unsubscribe,get_ars_actor
Expand Down Expand Up @@ -61,12 +61,12 @@ def message_post_save(sender, instance, **kwargs):
finished = False
logger.info('+++ Parent message %s not Done because of child: %s in state %s' % (str(pmessage.id),str(child.id),str(child.status)))

if child.status == 'D' and child.actor.agent.name.startswith('ar') and (child.result_count is not None and child.result_count > 0):
elif child.status == 'D' and child.actor.agent.name.startswith('ar') and (child.result_count is not None and child.result_count > 0):
if child.actor.agent.name == 'ars-ars-agent':
merge_count += 1
else:
orig_count += 1
if child.status == 'E' and child.actor.agent.name == 'ars-ars-agent':
elif child.status == 'E' and child.actor.agent.name == 'ars-ars-agent':
if child.code == 444:
merge_count += 1
else:
Expand Down Expand Up @@ -103,6 +103,15 @@ def message_post_save(sender, instance, **kwargs):
pmessage._skip_post_save = True
pmessage.save(update_fields=['status','code','updated_at', 'merged_version', 'merged_versions_list'])
else:
#adding one last notification about last merge being done
logger.info('+++ merged_versions : %s' % (pmessage.merged_version))
logger.info('+++ merged_versions_list : %s' % (pmessage.merged_versions_list))
notification={
"event_type":"last_merged_completed",
"complete":True,
"merged_versions_list":pmessage.merged_versions_list if pmessage.merged_versions_list is not None else []
}
pmessage.notify_subscribers(notification)
pmessage.status = 'D'
pmessage.code = 200
pmessage.updated_at = timezone.now()
Expand Down
68 changes: 40 additions & 28 deletions tr_sys/tr_ars/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
from django.utils import timezone
from django.shortcuts import get_object_or_404
from opentelemetry import trace
from opentelemetry.propagate import inject
# Ensure that the tracing context is properly propagated within tasks
from opentelemetry.context import attach, detach, set_value, get_current
from requests.exceptions import RequestException, Timeout
import time as sleeptime
from .api import decrypt_secret
import hmac
Expand Down Expand Up @@ -296,8 +296,38 @@ def catch_timeout_async():
logging.info(f'NOT TIMING OUT for pk: {str(id)}')
logging.info(f'{query_type} : max_time_pathfinder: {max_time_pathfinder} -- timestamp: {timestamp}')

@shared_task(name="notify_subscribers")
def notify_subscribers_task(pk, status_code, additional_notification_fields=None, count=0):
@shared_task(name="notify-one-client",
bind=True,
autoretry_for=(Timeout, RequestException),
retry_backoff=True, # exponential backoff
retry_backoff_max=300, # cap backoff
retry_jitter=True, # avoid thundering herd
retry_kwargs={"max_retries": 8})
def notify_one_client_task(self, client_pk, notification):

from .models import Client
client = Client.objects.get(pk=client_pk)
callback = client.callback_url
encrpyted_secret = client.client_secret
encoded_master_key = os.getenv("AES_MASTER_KEY")
if not encoded_master_key:
raise RuntimeError("AES_MASTER_KEY is not set")
master_key= base64.b64decode(encoded_master_key)
client_secret = decrypt_secret(encrpyted_secret, master_key)
data_json = json.dumps(notification, separators=(',', ':'), sort_keys=True).encode('utf-8') #convert notification to a consistent byte representation
digest = hmac.new(client_secret.encode('utf-8'), data_json, hashlib.sha256).hexdigest()
headers={
"Content-Type": "application/json",
"x-event-signature": digest
}

r = requests.post(url=callback, data=data_json, headers=headers, timeout=10)
if r.status_code != 200:
raise RequestException(f"notify failed: status={r.status_code}, body={r.text[:200]}")


@shared_task(name="notify-subscribers")
def notify_subscribers_task(pk, status_code, additional_notification_fields=None):
from .models import Message
try:
message = get_object_or_404(Message.objects.filter(pk=pk))
Expand All @@ -308,32 +338,14 @@ def notify_subscribers_task(pk, status_code, additional_notification_fields=None
}
if additional_notification_fields:
for k, v in additional_notification_fields.items():
if k =="event_type" and v == "last_merged_completed":
notification['code'] = 200
notification[k] = v

all_subscribed_clients = message.clients.all()
for client in all_subscribed_clients:
callback = client.callback_url
encrpyted_secret = client.client_secret
encoded_master_key = os.getenv("AES_MASTER_KEY")
master_key= base64.b64decode(encoded_master_key)
client_secret = decrypt_secret(encrpyted_secret, master_key)
data_json = json.dumps(notification, separators=(',', ':'), sort_keys=True).encode('utf-8') #convert notification to a consistent byte representation
digest = hmac.new(client_secret.encode('utf-8'), data_json, hashlib.sha256).hexdigest()
headers={
"Content-Type": "application/json",
"x-event-signature": digest
}
try:
r = requests.post(url=callback, data=data_json, headers=headers)
if r.status_code != 200:
if count <= 10:
count = count + 1
delay = 5 * pow(2, count)
sleeptime.sleep(delay)
notify_subscribers_task.apply_async((message.pk, status_code,additional_notification_fields, count))
except Exception as e:
logger.info("Unexpected error notifying %s about %s: %s" % (callback, str(message.pk), str(e)))

logger.info(f"Sending final version of notification: {notification}")
# fan out one task per client so one slow/bad callback doesn't block others
for client in message.clients.all():
logger.info("sending msg to client_id %s" %(client.client_id))
notify_one_client_task.apply_async(args=(client.pk, notification), queue="notify")
except Message.DoesNotExist:
logger.error(f"Message with ID {pk} does not exist")

Expand Down
8 changes: 2 additions & 6 deletions tr_sys/tr_ars/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,11 @@
re_path(r'^query_event_subscribe/?$', api.query_event_subscribe, name='ars-subscribe'),
re_path(r'^query_event_unsubscribe/?$', api.query_event_unsubscribe, name='ars-unsubscribe'),
path('post_process/<uuid:key>', api.post_process, name='ars-post_process_debug'),
re_path(r'^health/?$', api.health, name='ars-health')
re_path(r'^health/?$', api.health, name='ars-health'),
re_path(r'^get_status/?$', api.get_status, name='ars-status')
]



urlpatterns = [
path(r'', api.api_redirect, name='ars-base'),
path(r'app/', views.app_home, name='ars-app-home'),
path(r'app/status', views.status, name='ars-app-status'),
path(r'api/', include(apipatterns)),
path(r'answer/<uuid:key>', views.answer, name='ars-answer'),
]
31 changes: 24 additions & 7 deletions tr_sys/tr_ars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,16 @@ def to_dict(self):
def __json__(self):
return self.to_dict()

def get_msg_stats(mesg_dict):
stats={}
for component in mesg_dict['message'].keys():
#print(component)
if component == "knowledge_graph":
for subComp in ["nodes", "edges"]:
stats[f'{component}_{subComp}']=len(get_safe(mesg_dict, "message", f'{component}',f'{subComp}'))
else:
stats[component]=len(get_safe(mesg_dict, "message", f"{component}"))
return stats
def mergeMessages(messageList,pk):
messageListCopy = copy.deepcopy(messageList)
message = messageListCopy.pop()
Expand Down Expand Up @@ -656,10 +666,10 @@ def lock_merge(message):
message.save()
return False

@shared_task(name="merge_and_post_process")
@shared_task(name="merge-and-post-process")
def merge_and_post_process(parent_pk,message_to_merge, agent_name, counter=0):
merged=None

stats={}
logging.info(f"Starting merge for %s with parent PK: %s"% (agent_name,parent_pk))

logging.info(f"Before atomic transaction for %s with parent PK: %s"% (agent_name,parent_pk))
Expand All @@ -675,7 +685,7 @@ def merge_and_post_process(parent_pk,message_to_merge, agent_name, counter=0):
try:

logging.info(f"Before merging for %s with parent PK: %s"% (agent_name,parent_pk))
merged, parent = merge_received(parent,message_to_merge, agent_name)
merged, parent, stats = merge_received(parent,message_to_merge, agent_name)
logging.info(f"After merging for %s with parent PK: %s"% (agent_name,parent_pk))
parent.save()
notification={
Expand Down Expand Up @@ -712,10 +722,10 @@ def merge_and_post_process(parent_pk,message_to_merge, agent_name, counter=0):
merged.code = code
merged.save()



notification["event_type"]="merged_version_available"
notification["merged_version"]=str(merged.pk)
notification['stats']=stats
logging.info(f"✅✅✅NOTIFICATION: {notification}✅✅✅")
parent.notify_subscribers(notification)

def remove_blocked(mesg, data, blocklist=None):
Expand Down Expand Up @@ -1453,7 +1463,7 @@ def createMessage(actor,parent_pk):


@app.task(name="merge_received")
def merge_received(parent,message_to_merge, agent_name, counter=0):
def merge_received(parent,message_to_merge, agent_name):
current_merged_pk=parent.merged_version_id
logging.info("Beginning merge for agent %s with current_pk: %s" %(agent_name,str(current_merged_pk)))
t_to_merge_message=TranslatorMessage(message_to_merge)
Expand Down Expand Up @@ -1483,6 +1493,8 @@ def merge_received(parent,message_to_merge, agent_name, counter=0):

merged_dict = merged.to_dict()
logging.info('the keys for merged_dict are %s' % merged_dict['message'].keys())
stats = get_msg_stats(merged_dict)
logging.info(f'the return stat is {stats}')
new_merged_message.save_compressed_dict(merged_dict)
# new_merged_message.data = merged_dict
new_merged_message.status='R'
Expand All @@ -1500,9 +1512,14 @@ def merge_received(parent,message_to_merge, agent_name, counter=0):
parent.merged_versions_list=[pk_infores_merge]
else:
parent.merged_versions_list.append(pk_infores_merge)
parameter = parent.params
if parameter is None:
parameter={}
parameter['stats']=stats
parent.params=parameter
parent.save()
logging.info("returning new_merged_message to be post processed with pk: %s" % str(new_merged_message.pk))
return new_merged_message, parent
return new_merged_message, parent, stats
except Exception as e:
logging.exception("problem with merging for %s :" % agent_name)
#If anything goes wrong, we at least need to unlock the semaphore
Expand Down
2 changes: 2 additions & 0 deletions tr_sys/tr_sys/celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,7 @@ def debug_task(self):
task_default_delivery_mode='persistent',
task_create_missing_queues=True, # ← This one ensures auto-creation with durability
worker_prefetch_multiplier=1, # useful for crash resilience,when you have task with long duration->
#task_soft_time_limit=60 * 15,
#task_time_limit=60 * 20,
# reserve one task per worker process at a time (https://docs.celeryq.dev/en/stable/userguide/optimizing.html#prefetch-limits)
)
9 changes: 8 additions & 1 deletion tr_sys/tr_sys/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,6 @@
)

# Celery settings

CELERY_RESULT_BACKEND = 'django-db'
CELERY_CACHE_BACKEND = 'django-cache'
CELERY_BROKER_URL = 'amqp://localhost'
Expand All @@ -224,3 +223,11 @@

USE_CELERY = True
DEFAULT_HOST = 'http://localhost:8000'

# CELERY_TASK_ROUTES = {
# "send-message-to-actor": {"queue": "agent_outbound"},
# "notify-subscribers": {"queue": "notify"},
# "notify-one-client": {"queue": "notify"},
# "merge-and-post-process": {"queue": "postprocess"},
# }
# CELERY_TASK_DEFAULT_QUEUE = "default"