From 1eabb0020a5975e6ca960203883d46d137edbe0c Mon Sep 17 00:00:00 2001 From: nishika26 Date: Fri, 17 Apr 2026 15:06:48 +0530 Subject: [PATCH 01/19] default file size and addding documentation --- backend/app/api/docs/documents/upload.md | 1 + .../services/collections/create_collection.py | 7 +++--- backend/app/services/collections/helpers.py | 22 ++++++++++++++++++- 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/backend/app/api/docs/documents/upload.md b/backend/app/api/docs/documents/upload.md index e667015f5..c4c06caa6 100644 --- a/backend/app/api/docs/documents/upload.md +++ b/backend/app/api/docs/documents/upload.md @@ -1,6 +1,7 @@ Upload a document to Kaapi. - If only a file is provided, the document will be uploaded and stored, and its ID will be returned. +- The maximum file size allowed for upload is 25 MB. - If a target format is specified, a transformation job will also be created to transform document into target format in the background. The response will include both the uploaded document details and information about the transformation job. - If a callback URL is provided, you will receive a notification at that URL once the document transformation job is completed. 
diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py index eb37fd039..d12b7be3f 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -22,6 +22,7 @@ CreationRequest, ) from app.services.collections.helpers import ( + calculate_total_size_kb, extract_error_message, to_collection_public, ) @@ -156,6 +157,7 @@ def execute_job( result = None creation_request = None provider = None + storage = None try: creation_request = CreationRequest(**request) @@ -169,9 +171,10 @@ def execute_job( with Session(engine) as session: document_crud = DocumentCrud(session, project_id) flat_docs = document_crud.read_each(creation_request.documents) + storage = get_cloud_storage(session=session, project_id=project_id) file_exts = {doc.fname.split(".")[-1] for doc in flat_docs if "." in doc.fname} - total_size_kb = sum(doc.file_size_kb or 0 for doc in flat_docs) + total_size_kb = calculate_total_size_kb(flat_docs, storage) total_size_mb = round(total_size_kb / 1024, 2) with Session(engine) as session: @@ -186,8 +189,6 @@ def execute_job( ), ) - storage = get_cloud_storage(session=session, project_id=project_id) - provider = get_llm_provider( session=session, provider=creation_request.provider, diff --git a/backend/app/services/collections/helpers.py b/backend/app/services/collections/helpers.py index 6985ac78e..66f9dc1c0 100644 --- a/backend/app/services/collections/helpers.py +++ b/backend/app/services/collections/helpers.py @@ -2,6 +2,7 @@ import json import ast import re +from typing import TYPE_CHECKING from uuid import UUID from fastapi import HTTPException @@ -11,6 +12,9 @@ from app.api.deps import SessionDep from app.models import DocumentCollection, Collection, CollectionPublic, Document +if TYPE_CHECKING: + from app.core.cloud.storage import CloudStorage + logger = logging.getLogger(__name__) @@ -63,6 +67,22 @@ def 
extract_error_message(err: Exception) -> str: return message.strip()[:1000] +def calculate_total_size_kb(documents: list[Document], storage: CloudStorage) -> float: + """ + Sum document sizes in KB. Uses the stored file_size_kb if available. + """ + total: float = 0 + for doc in documents: + if doc.file_size_kb is not None: + total += doc.file_size_kb + else: + logger.info( + f"[calculate_total_size_kb] file_size_kb missing, fetching from storage | {{'doc_id': '{doc.id}', 'fname': '{doc.fname}'}}" + ) + total += storage.get_file_size_kb(doc.object_store_url) + return total + + def batch_documents(documents: list[Document]) -> list[list[Document]]: """ Batch documents dynamically based on size and count limits. @@ -83,7 +103,7 @@ def batch_documents(documents: list[Document]) -> list[list[Document]]: current_batch_size_kb = 0 for doc in documents: - doc_size_kb = doc.file_size_kb or 0 + doc_size_kb = doc.file_size_kb or 15 * 1024 would_exceed_size = (current_batch_size_kb + doc_size_kb) > MAX_BATCH_SIZE_KB would_exceed_count = len(current_batch) >= MAX_BATCH_COUNT From 7f5d86f45af99e6b8133d9954158377c41106df1 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Fri, 17 Apr 2026 15:09:09 +0530 Subject: [PATCH 02/19] default file size and addding documentation --- backend/app/api/docs/documents/upload.md | 3 +-- backend/app/services/collections/create_collection.py | 1 - backend/app/services/collections/helpers.py | 1 - 3 files changed, 1 insertion(+), 4 deletions(-) diff --git a/backend/app/api/docs/documents/upload.md b/backend/app/api/docs/documents/upload.md index c4c06caa6..438dc3e9b 100644 --- a/backend/app/api/docs/documents/upload.md +++ b/backend/app/api/docs/documents/upload.md @@ -1,7 +1,6 @@ Upload a document to Kaapi. -- If only a file is provided, the document will be uploaded and stored, and its ID will be returned. -- The maximum file size allowed for upload is 25 MB. 
+- If only a file is provided, the document will be uploaded and stored, and its ID will be returned. The maximum file size allowed for upload is 25 MB. - If a target format is specified, a transformation job will also be created to transform document into target format in the background. The response will include both the uploaded document details and information about the transformation job. - If a callback URL is provided, you will receive a notification at that URL once the document transformation job is completed. diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py index d12b7be3f..009d55fd1 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -152,7 +152,6 @@ def execute_job( """ start_time = time.time() - # Keeping the references for potential backout/cleanup on failure collection_job = None result = None creation_request = None diff --git a/backend/app/services/collections/helpers.py b/backend/app/services/collections/helpers.py index 66f9dc1c0..1b0ae0ace 100644 --- a/backend/app/services/collections/helpers.py +++ b/backend/app/services/collections/helpers.py @@ -23,7 +23,6 @@ MAX_DOC_SIZE_MB = 25 # 25 MB maximum per document # Maximum batch size for uploading documents to vector store -# Derived from MAX_DOC_SIZE + buffer to ensure single docs always fit MAX_BATCH_SIZE_KB = (MAX_DOC_SIZE_MB + 5) * 1024 # 30 MB in KB (25 + 5 MB buffer) MAX_BATCH_COUNT = 200 # Maximum documents per batch From 8e3d29d2093f488d7c54a0ec2f1708c05d7023b7 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Fri, 17 Apr 2026 15:30:57 +0530 Subject: [PATCH 03/19] coderabbit reviews --- .../services/collections/create_collection.py | 19 +++++++++++++++- backend/app/services/collections/helpers.py | 22 +------------------ 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/backend/app/services/collections/create_collection.py 
b/backend/app/services/collections/create_collection.py index 009d55fd1..887208e18 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -22,7 +22,6 @@ CreationRequest, ) from app.services.collections.helpers import ( - calculate_total_size_kb, extract_error_message, to_collection_public, ) @@ -136,6 +135,24 @@ def _mark_job_failed( return None +def calculate_total_size_kb( + documents: list[Document], storage: "CloudStorage" +) -> float: + """ + Sum document sizes in KB. Uses the stored file_size_kb if available. + """ + total: float = 0 + for doc in documents: + if doc.file_size_kb is not None: + total += doc.file_size_kb + else: + logger.info( + f"[calculate_total_size_kb] file_size_kb missing, fetching from storage | {{'doc_id': '{doc.id}', 'fname': '{doc.fname}'}}" + ) + total += storage.get_file_size_kb(doc.object_store_url) + return total + + def execute_job( request: dict, with_assistant: bool, diff --git a/backend/app/services/collections/helpers.py b/backend/app/services/collections/helpers.py index 1b0ae0ace..db972c92d 100644 --- a/backend/app/services/collections/helpers.py +++ b/backend/app/services/collections/helpers.py @@ -2,7 +2,6 @@ import json import ast import re -from typing import TYPE_CHECKING from uuid import UUID from fastapi import HTTPException @@ -12,9 +11,6 @@ from app.api.deps import SessionDep from app.models import DocumentCollection, Collection, CollectionPublic, Document -if TYPE_CHECKING: - from app.core.cloud.storage import CloudStorage - logger = logging.getLogger(__name__) @@ -66,22 +62,6 @@ def extract_error_message(err: Exception) -> str: return message.strip()[:1000] -def calculate_total_size_kb(documents: list[Document], storage: CloudStorage) -> float: - """ - Sum document sizes in KB. Uses the stored file_size_kb if available. 
- """ - total: float = 0 - for doc in documents: - if doc.file_size_kb is not None: - total += doc.file_size_kb - else: - logger.info( - f"[calculate_total_size_kb] file_size_kb missing, fetching from storage | {{'doc_id': '{doc.id}', 'fname': '{doc.fname}'}}" - ) - total += storage.get_file_size_kb(doc.object_store_url) - return total - - def batch_documents(documents: list[Document]) -> list[list[Document]]: """ Batch documents dynamically based on size and count limits. @@ -102,7 +82,7 @@ def batch_documents(documents: list[Document]) -> list[list[Document]]: current_batch_size_kb = 0 for doc in documents: - doc_size_kb = doc.file_size_kb or 15 * 1024 + doc_size_kb = doc.file_size_kb if doc.file_size_kb is not None else 15 * 1024 would_exceed_size = (current_batch_size_kb + doc_size_kb) > MAX_BATCH_SIZE_KB would_exceed_count = len(current_batch) >= MAX_BATCH_COUNT From bed1d1a4f13f2cd918e639c0dd288fe9a661dfdc Mon Sep 17 00:00:00 2001 From: nishika26 Date: Fri, 17 Apr 2026 15:35:28 +0530 Subject: [PATCH 04/19] test cases failing --- backend/app/services/collections/create_collection.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py index 887208e18..bd55c2871 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -14,6 +14,7 @@ CollectionJobCrud, ) from app.models import ( + Document, CollectionJobStatus, CollectionJob, Collection, @@ -21,6 +22,7 @@ CollectionJobPublic, CreationRequest, ) +from app.core.cloud.storage import CloudStorage from app.services.collections.helpers import ( extract_error_message, to_collection_public, @@ -135,9 +137,7 @@ def _mark_job_failed( return None -def calculate_total_size_kb( - documents: list[Document], storage: "CloudStorage" -) -> float: +def calculate_total_size_kb(documents: list[Document], storage: CloudStorage) -> 
float: """ Sum document sizes in KB. Uses the stored file_size_kb if available. """ From d02bac8a57d037327e501eb3e74ecfaa46c70c34 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Fri, 17 Apr 2026 17:01:00 +0530 Subject: [PATCH 05/19] changing the logic --- .../services/collections/create_collection.py | 39 +++++++++---------- backend/app/services/collections/helpers.py | 2 +- 2 files changed, 20 insertions(+), 21 deletions(-) diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py index bd55c2871..14696b191 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -14,7 +14,6 @@ CollectionJobCrud, ) from app.models import ( - Document, CollectionJobStatus, CollectionJob, Collection, @@ -22,7 +21,6 @@ CollectionJobPublic, CreationRequest, ) -from app.core.cloud.storage import CloudStorage from app.services.collections.helpers import ( extract_error_message, to_collection_public, @@ -137,22 +135,6 @@ def _mark_job_failed( return None -def calculate_total_size_kb(documents: list[Document], storage: CloudStorage) -> float: - """ - Sum document sizes in KB. Uses the stored file_size_kb if available. - """ - total: float = 0 - for doc in documents: - if doc.file_size_kb is not None: - total += doc.file_size_kb - else: - logger.info( - f"[calculate_total_size_kb] file_size_kb missing, fetching from storage | {{'doc_id': '{doc.id}', 'fname': '{doc.fname}'}}" - ) - total += storage.get_file_size_kb(doc.object_store_url) - return total - - def execute_job( request: dict, with_assistant: bool, @@ -190,10 +172,27 @@ def execute_job( storage = get_cloud_storage(session=session, project_id=project_id) file_exts = {doc.fname.split(".")[-1] for doc in flat_docs if "." 
in doc.fname} - total_size_kb = calculate_total_size_kb(flat_docs, storage) - total_size_mb = round(total_size_kb / 1024, 2) + + backfill: list[tuple[UUID, float]] = [] + for doc in flat_docs: + if doc.file_size_kb is None: + size_kb = round(storage.get_file_size_kb(doc.object_store_url)) + doc.file_size_kb = size_kb + backfill.append((doc.id, size_kb)) + + total_size_kb = sum( + doc.file_size_kb for doc in flat_docs if doc.file_size_kb is not None + ) + total_size_mb = total_size_kb / 1024 with Session(engine) as session: + if backfill: + document_crud = DocumentCrud(session, project_id) + for doc_id, size_kb in backfill: + doc = document_crud.read_one(doc_id) + doc.file_size_kb = size_kb + document_crud.update(doc) + collection_job_crud = CollectionJobCrud(session, project_id) collection_job = collection_job_crud.read_one(job_uuid) collection_job = collection_job_crud.update( diff --git a/backend/app/services/collections/helpers.py b/backend/app/services/collections/helpers.py index db972c92d..3f0a0cefd 100644 --- a/backend/app/services/collections/helpers.py +++ b/backend/app/services/collections/helpers.py @@ -82,7 +82,7 @@ def batch_documents(documents: list[Document]) -> list[list[Document]]: current_batch_size_kb = 0 for doc in documents: - doc_size_kb = doc.file_size_kb if doc.file_size_kb is not None else 15 * 1024 + doc_size_kb = doc.file_size_kb would_exceed_size = (current_batch_size_kb + doc_size_kb) > MAX_BATCH_SIZE_KB would_exceed_count = len(current_batch) >= MAX_BATCH_COUNT From 8b7556c226449f0f9a52bc66906642bd55fa8068 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Fri, 17 Apr 2026 20:37:30 +0530 Subject: [PATCH 06/19] fixing test cases --- backend/app/services/collections/create_collection.py | 2 +- backend/app/tests/services/collections/test_helpers.py | 10 ++++------ 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py index 
14696b191..25aba0919 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -200,7 +200,7 @@ def execute_job( CollectionJobUpdate( task_id=task_id, status=CollectionJobStatus.PROCESSING, - total_size_mb=total_size_mb, + total_size_mb=round(total_size_mb, 2), ), ) diff --git a/backend/app/tests/services/collections/test_helpers.py b/backend/app/tests/services/collections/test_helpers.py index 7cddaf305..8b43946a1 100644 --- a/backend/app/tests/services/collections/test_helpers.py +++ b/backend/app/tests/services/collections/test_helpers.py @@ -122,14 +122,12 @@ def test_batch_documents_mixed_size_batching() -> None: assert len(batches[2]) == 1 # 15 MB total -def test_batch_documents_with_none_file_size() -> None: - """Test that documents with None file_size are treated as 0 bytes.""" +def test_batch_documents_with_none_file_size_raises() -> None: + """Test that documents with None file_size raise TypeError — sizes must be backfilled before batching.""" docs = create_fake_documents(10, file_size_kb=None) - batches = helpers.batch_documents(docs) - # All files with None/0 size should fit in one batch (under both limits) - assert len(batches) == 1 - assert len(batches[0]) == 10 + with pytest.raises(TypeError): + helpers.batch_documents(docs) def test_batch_documents_empty_input() -> None: From baaeac2eec68e1c89fdf764dcb563bbfdf408c98 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Tue, 5 May 2026 09:26:17 +0530 Subject: [PATCH 07/19] adding alembic file --- .../055_add_columns_to_collections.py | 62 +++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 backend/app/alembic/versions/055_add_columns_to_collections.py diff --git a/backend/app/alembic/versions/055_add_columns_to_collections.py b/backend/app/alembic/versions/055_add_columns_to_collections.py new file mode 100644 index 000000000..804e5cc7d --- /dev/null +++ 
b/backend/app/alembic/versions/055_add_columns_to_collections.py @@ -0,0 +1,62 @@ +"""add batch tracking to collection_job and provider_file_id to document + +Revision ID: 055 +Revises: 054 +Create Date: 2026-04-13 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "055" +down_revision = "054" +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column( + "collection_jobs", + sa.Column( + "total_batches", + sa.Integer(), + nullable=True, + comment="Total number of batches the documents are split into", + ), + ) + op.add_column( + "collection_jobs", + sa.Column( + "current_batch_number", + sa.Integer(), + nullable=True, + comment="Which batch is currently being processed (1-indexed)", + ), + ) + op.add_column( + "collection_jobs", + sa.Column( + "documents_uploaded", + sa.JSON(), + nullable=True, + comment="List of document IDs successfully uploaded so far", + ), + ) + op.add_column( + "document", + sa.Column( + "openai_file_id", + sa.String(), + nullable=True, + comment="File ID assigned by the LLM provider (e.g. 
OpenAI file ID) to avoid re-uploading", + ), + ) + + +def downgrade(): + op.drop_column("collection_jobs", "total_batches") + op.drop_column("collection_jobs", "current_batch_number") + op.drop_column("collection_jobs", "documents_uploaded") + op.drop_column("document", "openai_file_id") From 47525e49639a2080d96f6a8b9bd980a79c74ff1f Mon Sep 17 00:00:00 2001 From: nishika26 Date: Tue, 5 May 2026 10:16:25 +0530 Subject: [PATCH 08/19] adding logic to the pr --- ..._add_batch_tracking_to_collections_jobs.py | 62 ++ backend/app/celery/utils.py | 18 + backend/app/crud/rag/open_ai.py | 53 ++ backend/app/models/collection_job.py | 27 +- backend/app/models/document.py | 5 + .../services/collections/create_collection.py | 560 +++++++++++------- .../services/collections/providers/base.py | 56 +- .../services/collections/providers/openai.py | 86 ++- 8 files changed, 603 insertions(+), 264 deletions(-) create mode 100644 backend/app/alembic/versions/055_add_batch_tracking_to_collections_jobs.py diff --git a/backend/app/alembic/versions/055_add_batch_tracking_to_collections_jobs.py b/backend/app/alembic/versions/055_add_batch_tracking_to_collections_jobs.py new file mode 100644 index 000000000..804e5cc7d --- /dev/null +++ b/backend/app/alembic/versions/055_add_batch_tracking_to_collections_jobs.py @@ -0,0 +1,62 @@ +"""add batch tracking to collection_job and provider_file_id to document + +Revision ID: 055 +Revises: 054 +Create Date: 2026-04-13 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision = "055" +down_revision = "054" +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column( + "collection_jobs", + sa.Column( + "total_batches", + sa.Integer(), + nullable=True, + comment="Total number of batches the documents are split into", + ), + ) + op.add_column( + "collection_jobs", + sa.Column( + "current_batch_number", + sa.Integer(), + nullable=True, + comment="Which batch is currently being processed (1-indexed)", + ), + ) + op.add_column( + "collection_jobs", + sa.Column( + "documents_uploaded", + sa.JSON(), + nullable=True, + comment="List of document IDs successfully uploaded so far", + ), + ) + op.add_column( + "document", + sa.Column( + "openai_file_id", + sa.String(), + nullable=True, + comment="File ID assigned by the LLM provider (e.g. OpenAI file ID) to avoid re-uploading", + ), + ) + + +def downgrade(): + op.drop_column("collection_jobs", "total_batches") + op.drop_column("collection_jobs", "current_batch_number") + op.drop_column("collection_jobs", "documents_uploaded") + op.drop_column("document", "openai_file_id") diff --git a/backend/app/celery/utils.py b/backend/app/celery/utils.py index 5ebbf624a..475a4772a 100644 --- a/backend/app/celery/utils.py +++ b/backend/app/celery/utils.py @@ -85,6 +85,24 @@ def start_doctransform_job( return task_id +def start_create_collection_setup_job( + project_id: int, job_id: str, trace_id: str = "N/A", **kwargs +) -> str: + from app.celery.tasks.job_execution import run_create_collection_setup_job + + task_id = _enqueue_with_trace_context( + run_create_collection_job, + project_id=project_id, + job_id=job_id, + trace_id=trace_id, + **kwargs, + ) + logger.info( + f"[start_create_collection_job] Started job {job_id} with Celery task {task_id}" + ) + return task_id + + def start_create_collection_job( project_id: int, job_id: str, trace_id: str = "N/A", **kwargs ) -> str: diff --git a/backend/app/crud/rag/open_ai.py b/backend/app/crud/rag/open_ai.py index cdae82440..be6970235 100644 
--- a/backend/app/crud/rag/open_ai.py +++ b/backend/app/crud/rag/open_ai.py @@ -1,6 +1,7 @@ import json import logging import functools as ft +import time from io import BytesIO from typing import Iterable @@ -149,6 +150,58 @@ def update( yield from docs + def update_batch( + self, + vector_store_id: str, + docs: list[Document], + ) -> tuple[list[Document], list[Document]]: + """ + Attach a batch of documents to the vector store via a single upload_and_poll call. + + All docs must have provider_file_id set before calling this method. + Returns (succeeded, failed) — failed docs should be retried in the next batch. + """ + succeeded: list[Document] = [] + failed: list[Document] = [] + + if not docs: + return succeeded, failed + + try: + _t0 = time.monotonic() + batch = self.client.vector_stores.file_batches.upload_and_poll( + vector_store_id=vector_store_id, + files=[], + file_ids=[doc.openai_file_id for doc in docs], + ) + logger.info( + f"[OpenAIVectorStoreCrud.update_batch] Batch upload_and_poll duration | " + f"{{'vector_store_id': '{vector_store_id}', 'duration_s': {time.monotonic() - _t0:.3f}, " + f"'completed': {batch.file_counts.completed}, 'failed': {batch.file_counts.failed}}}" + ) + if batch.file_counts.failed == 0: + succeeded.extend(docs) + else: + # Can't identify which specific files failed — retry all of them + logger.warning( + f"[OpenAIVectorStoreCrud.update_batch] Batch had failures, marking all for retry | " + f"{{'vector_store_id': '{vector_store_id}', 'failed_count': {batch.file_counts.failed}}}" + ) + failed.extend(docs) + except OpenAIError as err: + logger.error( + f"[OpenAIVectorStoreCrud.update_batch] Batch attach failed | " + f"{{'vector_store_id': '{vector_store_id}', 'error': '{str(err)}'}}", + exc_info=True, + ) + failed.extend(docs) + + logger.info( + f"[OpenAIVectorStoreCrud.update_batch] Batch complete | " + f"{{'vector_store_id': '{vector_store_id}', 'succeeded': {len(succeeded)}, 'failed': {len(failed)}}}" + ) + return succeeded, 
failed + def delete(self, vector_store_id: str, retries: int = 3): if retries < 1: try: diff --git a/backend/app/models/collection_job.py b/backend/app/models/collection_job.py index 333ebfd14..6b628ad7e 100644 --- a/backend/app/models/collection_job.py +++ b/backend/app/models/collection_job.py @@ -77,7 +77,29 @@ class CollectionJob(SQLModel, table=True): documents: list[str] | None = Field( default=None, sa_column=Column( - JSON, nullable=True, comment="List of documents given to make collection" + JSON, nullable=True, comment="List of document IDs given to make collection" + ), + ) + total_batches: int | None = Field( + default=None, + nullable=True, + sa_column_kwargs={ + "comment": "Total number of batches the documents are split into" + }, + ) + current_batch_number: int | None = Field( + default=None, + nullable=True, + sa_column_kwargs={ + "comment": "Which batch is currently being processed (1-indexed)" + }, + ) + documents_uploaded: list[str] | None = Field( + default=None, + sa_column=Column( + JSON, + nullable=True, + comment="List of document IDs successfully uploaded so far", ), ) @@ -139,6 +161,9 @@ class CollectionJobUpdate(SQLModel): collection_id: UUID | None = None total_size_mb: float | None = None trace_id: str | None = None + total_batches: int | None = None + current_batch_number: int | None = None + documents_uploaded: list[str] | None = None ##Response models diff --git a/backend/app/models/document.py b/backend/app/models/document.py index 12843e72a..5bbcddc77 100644 --- a/backend/app/models/document.py +++ b/backend/app/models/document.py @@ -46,6 +46,11 @@ class Document(DocumentBase, table=True): description="The size of the document in kilobytes", sa_column_kwargs={"comment": "Size of the document in kilobytes (KB)"}, ) + openai_file_id: str | None = Field( + default=None, + nullable=True, + sa_column_kwargs={"comment": "File ID assigned by OpenAI (avoid re-uploading)"}, + ) # Foreign keys source_document_id: UUID | None = Field( diff 
--git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py index bc42aa0d0..884f8e3bd 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -2,13 +2,11 @@ import time from uuid import UUID, uuid4 -from opentelemetry import trace from sqlmodel import Session from asgi_correlation_id import correlation_id from app.core.cloud import get_cloud_storage from app.core.db import engine -from app.core.telemetry import log_context from app.crud import ( CollectionCrud, DocumentCrud, @@ -23,17 +21,19 @@ CollectionJobPublic, CreationRequest, ) +from app.crud.rag import OpenAIVectorStoreCrud from app.services.collections.helpers import ( + batch_documents, extract_error_message, to_collection_public, ) from app.services.collections.providers.registry import get_llm_provider -from app.celery.utils import start_create_collection_job -from app.utils import send_callback, get_webhook_secret, APIResponse +from gevent import Timeout +from app.celery.utils import start_create_collection_job, start_collection_batch_job +from app.utils import send_callback, APIResponse logger = logging.getLogger(__name__) -tracer = trace.get_tracer(__name__) def start_job( @@ -44,49 +44,31 @@ def start_job( with_assistant: bool, organization_id: int, ) -> str: - with log_context( - tag="collection", - lifecycle="collection.create.start_job", - action="create", - collection_job_id=collection_job_id, - project_id=project_id, - organization_id=organization_id, - ): - trace_id = correlation_id.get() or "N/A" + trace_id = correlation_id.get() or "N/A" - job_crud = CollectionJobCrud(db, project_id) - collection_job = job_crud.update( - collection_job_id, CollectionJobUpdate(trace_id=trace_id) - ) + job_crud = CollectionJobCrud(db, project_id) + job_crud.update(collection_job_id, CollectionJobUpdate(trace_id=trace_id)) - task_id = start_create_collection_job( - 
project_id=project_id, - job_id=str(collection_job_id), - trace_id=trace_id, - request=request.model_dump(mode="json"), - with_assistant=with_assistant, - organization_id=organization_id, - ) + task_id = start_create_collection_job( + project_id=project_id, + job_id=str(collection_job_id), + trace_id=trace_id, + request=request.model_dump(mode="json"), + with_assistant=with_assistant, + organization_id=organization_id, + ) - logger.info( - "[create_collection.start_job] Job scheduled to create collection | " - f"collection_job_id={collection_job_id}, project_id={project_id}, task_id={task_id}" - ) + logger.info( + "[create_collection.start_job] Job scheduled to create collection | " + f"collection_job_id={collection_job_id}, project_id={project_id}, task_id={task_id}" + ) - return collection_job_id + return collection_job_id def build_success_payload( collection_job: CollectionJob, collection: Collection ) -> dict: - """ - { - "success": true, - "data": { job fields + full collection }, - "error": null, - "metadata": null - } - """ collection_public = to_collection_public(collection) collection_dict = collection_public.model_dump(mode="json", exclude_none=True) @@ -100,15 +82,6 @@ def build_success_payload( def build_failure_payload(collection_job: CollectionJob, error_message: str) -> dict: - """ - { - "success": false, - "data": { job fields, collection: null }, - "error": "something went wrong", - "metadata": null - } - """ - # ensure `collection` is explicitly null in the payload job_public = CollectionJobPublic.model_validate( collection_job, update={"collection": None}, @@ -142,11 +115,63 @@ def _mark_job_failed( ) return collection_job except Exception: - logger.warning("[create_collection.execute_job] Failed to mark job as FAILED") + logger.warning("[create_collection] Failed to mark job as FAILED") return None -def execute_job( +def _persist_succeeded_docs(succeeded: list, project_id: int) -> list[str]: + with Session(engine) as session: + document_crud = 
DocumentCrud(session, project_id) + for doc in succeeded: + if doc.openai_file_id: + db_doc = document_crud.read_one(doc.id) + if db_doc.openai_file_id != doc.openai_file_id: + db_doc.openai_file_id = doc.openai_file_id + document_crud.update(db_doc) + return [str(doc.id) for doc in succeeded] + + +def _retry_failed_uploads( + vector_store_crud, + vector_store_id: str, + failed_docs: list, + project_id: int, + max_retries: int = 3, +) -> list[str]: + """ + Retry attaching docs that failed the initial batch upload_and_poll. + All docs must already have provider_file_id set. + Returns the list of successfully retried doc IDs. + Raises RuntimeError if any docs still fail after all retries. + """ + pending = failed_docs + all_succeeded_ids: list[str] = [] + + for attempt in range(1, max_retries + 1): + logger.warning( + "[_retry_failed_uploads] Retry attempt %d/%d: %d doc(s) | vector_store_id=%s", + attempt, + max_retries, + len(pending), + vector_store_id, + ) + succeeded, failed = vector_store_crud.update_batch(vector_store_id, pending) + + if succeeded: + all_succeeded_ids += _persist_succeeded_docs(succeeded, project_id) + + if not failed: + return all_succeeded_ids + + pending = failed + + ids = [str(d.id) for d in pending] + raise RuntimeError( + f"Failed to upload {len(pending)} document(s) after {max_retries} retries: {ids}" + ) + + +def execute_setup_job( request: dict, with_assistant: bool, project_id: int, @@ -156,206 +181,311 @@ def execute_job( task_instance, ) -> None: """ - Worker entrypoint scheduled by start_job. - Orchestrates: job state, provider init, collection creation, - optional assistant creation, collection persistence, linking, callbacks, and cleanup. + Phase 1: Fetch documents, create the vector store, split into batches, + update job state to PROCESSING, then queue the first batch task. 
""" - start_time = time.time() - collection_job = None - result = None creation_request = None - provider = None - storage = None - - with log_context( - tag="collection", - lifecycle="collection.create.execute_job", - action="create", - collection_job_id=job_id, - task_id=task_id, - project_id=project_id, - organization_id=organization_id, - ), tracer.start_as_current_span("collections.create.execute_job") as span: - span.set_attribute("collection.job_id", str(job_id)) - span.set_attribute("kaapi.project_id", project_id) - span.set_attribute("kaapi.organization_id", organization_id) - - try: - creation_request = CreationRequest(**request) - if with_assistant: - creation_request.provider = "openai" - - span.set_attribute("collection.provider", str(creation_request.provider)) - - job_uuid = UUID(job_id) - - with Session(engine) as session: - document_crud = DocumentCrud(session, project_id) - flat_docs = document_crud.read_each(creation_request.documents) - - file_exts = { - doc.fname.split(".")[-1] for doc in flat_docs if "." 
in doc.fname - } - total_size_kb = sum(doc.file_size_kb or 0 for doc in flat_docs) - total_size_mb = round(total_size_kb / 1024, 2) - span.set_attribute("collection.documents.count", len(flat_docs)) - span.set_attribute("collection.documents.total_size_mb", total_size_mb) - - with Session(engine) as session: - collection_job_crud = CollectionJobCrud(session, project_id) - collection_job = collection_job_crud.read_one(job_uuid) - collection_job = collection_job_crud.update( - job_uuid, - CollectionJobUpdate( - task_id=task_id, - status=CollectionJobStatus.PROCESSING, - total_size_mb=total_size_mb, - ), - ) - - storage = get_cloud_storage(session=session, project_id=project_id) - provider = get_llm_provider( - session=session, - provider=creation_request.provider, - project_id=project_id, - organization_id=organization_id, - ) + + try: + creation_request = CreationRequest(**request) + if with_assistant: + creation_request.provider = "openai" + + job_uuid = UUID(job_id) + trace_id = correlation_id.get() or "N/A" with Session(engine) as session: document_crud = DocumentCrud(session, project_id) flat_docs = document_crud.read_each(creation_request.documents) storage = get_cloud_storage(session=session, project_id=project_id) - file_exts = {doc.fname.split(".")[-1] for doc in flat_docs if "." 
in doc.fname} + provider = get_llm_provider( + session=session, + provider=creation_request.provider, + project_id=project_id, + organization_id=organization_id, + ) + + for doc in flat_docs: + session.expunge(doc) - backfill: list[tuple[UUID, float]] = [] - for doc in flat_docs: - if doc.file_size_kb is None: - size_kb = round(storage.get_file_size_kb(doc.object_store_url)) - doc.file_size_kb = size_kb - backfill.append((doc.id, size_kb)) + provider.upload_files(storage, flat_docs, project_id) - total_size_kb = sum( - doc.file_size_kb for doc in flat_docs if doc.file_size_kb is not None + logger.info( + "[create_collection.execute_setup_job] All file uploads complete | " + "job_id=%s, total=%d", + job_id, + len(flat_docs), ) + + total_size_kb = sum(doc.file_size_kb for doc in flat_docs) total_size_mb = total_size_kb / 1024 - with Session(engine) as session: - if backfill: - document_crud = DocumentCrud(session, project_id) - for doc_id, size_kb in backfill: - doc = document_crud.read_one(doc_id) - doc.file_size_kb = size_kb - document_crud.update(doc) + docs_batches = batch_documents(flat_docs) + total_batches = len(docs_batches) + batch_doc_ids = [[str(doc.id) for doc in batch] for batch in docs_batches] + with Session(engine) as session: collection_job_crud = CollectionJobCrud(session, project_id) - collection_job = collection_job_crud.read_one(job_uuid) collection_job = collection_job_crud.update( job_uuid, CollectionJobUpdate( task_id=task_id, status=CollectionJobStatus.PROCESSING, - total_size_mb=round(total_size_mb, 2), + total_size_mb=total_size_mb, + current_batch_number=0, + total_batches=total_batches, + documents_uploaded=[], ), ) + start_collection_batch_job( + project_id=project_id, + job_id=job_id, + trace_id=trace_id, + batch_number=1, + batch_doc_ids=batch_doc_ids[0], + remaining_batches=batch_doc_ids[1:], + request=request, + with_assistant=with_assistant, + organization_id=organization_id, + ) + + logger.info( + 
"[create_collection.execute_setup_job] Setup complete, first batch queued | " + f"job_id={job_id}, total_batches={total_batches}" + ) + + except Timeout as err: + timeout_err = TimeoutError( + f"[execute_setup_job] Task exceeded soft time limit of {err.seconds}s" + ) + _mark_job_failed( + project_id=project_id, + job_id=job_id, + err=timeout_err, + collection_job=collection_job, + ) + raise + + except Exception as err: + logger.error( + "[create_collection.execute_setup_job] Setup failed | job_id=%s, error=%s", + job_id, + str(err), + exc_info=True, + ) + + collection_job = _mark_job_failed( + project_id=project_id, + job_id=job_id, + err=err, + collection_job=collection_job, + ) + if creation_request and creation_request.callback_url and collection_job: + failure_payload = build_failure_payload(collection_job, str(err)) + send_callback(creation_request.callback_url, failure_payload) + + +def execute_batch_job( + request: dict, + with_assistant: bool, + project_id: int, + organization_id: int, + task_id: str, + job_id: str, + task_instance, + vector_store_id: str | None, + batch_number: int, + batch_doc_ids: list[str], + remaining_batches: list[list[str]], +) -> None: + """ + Phase 2: Upload one batch of documents to the vector store. 
+ - Uploads the batch; any failures within the batch are retried inline by _upload_batch_with_retry + - Raises immediately if all retries for the batch are exhausted + - Checkpoints progress to the DB + - If more batches remain, queues the next batch task + - If this is the last batch, finalizes: creates Collection, links docs, marks job SUCCESSFUL + """ + collection_job = None + creation_request = None + + try: + batch_start_time = time.time() + creation_request = CreationRequest(**request) + if with_assistant: + creation_request.provider = "openai" + + job_uuid = UUID(job_id) + trace_id = correlation_id.get() or "N/A" + + logger.info( + "[create_collection.execute_batch_job] Starting batch | " + "job_id=%s, batch_number=%d, doc_count=%d, remaining_batches=%d", + job_id, + batch_number, + len(batch_doc_ids), + len(remaining_batches), + ) + + all_doc_ids_this_batch = [UUID(d) for d in batch_doc_ids] + is_final = not remaining_batches + + with Session(engine) as session: provider = get_llm_provider( session=session, provider=creation_request.provider, project_id=project_id, organization_id=organization_id, - with tracer.start_as_current_span("collections.create.provider"): - result = provider.create( - collection_request=creation_request, - storage=storage, - documents=flat_docs, - ) - - llm_service_id = result.llm_service_id - llm_service_name = result.llm_service_name - - with Session(engine) as session: - collection_crud = CollectionCrud(session, project_id) - collection_id = uuid4() - - collection = Collection( - id=collection_id, - project_id=project_id, - llm_service_id=llm_service_id, - llm_service_name=llm_service_name, - provider=creation_request.provider, - name=creation_request.name, - description=creation_request.description, - ) - collection_crud.create(collection) - collection = collection_crud.read_one(collection.id) - - if flat_docs: - DocumentCollectionCrud(session).create(collection, flat_docs) - - collection_job_crud = CollectionJobCrud(session, 
project_id) - collection_job = collection_job_crud.update( - collection_job.id, - CollectionJobUpdate( - status=CollectionJobStatus.SUCCESSFUL, - collection_id=collection.id, - ), - ) - - success_payload = build_success_payload(collection_job, collection) - - span.set_attribute("collection.id", str(collection_id)) - - elapsed = time.time() - start_time - logger.info( - "[create_collection.execute_job] Collection created: %s | Time: %.2fs | Files: %d | Total Size: %s MB | Types: %s", - collection_id, - elapsed, - len(flat_docs), - collection_job.total_size_mb, - list(file_exts), ) - if creation_request.callback_url: - webhook_secret = get_webhook_secret(project_id, organization_id) - send_callback( - str(creation_request.callback_url), - success_payload, - webhook_secret=webhook_secret, - ) - - except Exception as err: - span.record_exception(err) - span.set_status(trace.Status(trace.StatusCode.ERROR, str(err))) - logger.error( - "[create_collection.execute_job] Collection Creation Failed | {'collection_job_id': '%s', 'error': '%s'}", - job_id, - str(err), - exc_info=True, + with Session(engine) as session: + document_crud = DocumentCrud(session, project_id) + batch_docs = ( + document_crud.read_each(all_doc_ids_this_batch) + if all_doc_ids_this_batch + else [] ) + for doc in batch_docs: + session.expunge(doc) + + collection_result = provider.create( + creation_request, + batch_docs, + vector_store_id=vector_store_id, + is_final=is_final, + ) + resolved_vector_store_id = ( + collection_result.llm_service_id + if not is_final + else vector_store_id or collection_result.llm_service_id + ) + + with Session(engine) as session: + collection_job_crud = CollectionJobCrud(session, project_id) + collection_job = collection_job_crud.read_one(job_uuid) + already_uploaded = collection_job.documents_uploaded or [] + now_uploaded = already_uploaded + [str(d) for d in all_doc_ids_this_batch] - if provider is not None and result is not None: - try: - provider.delete(result) - 
except Exception: - logger.warning( - "[create_collection.execute_job] Provider cleanup failed" - ) + collection_job = collection_job_crud.update( + job_uuid, + CollectionJobUpdate( + current_batch_number=batch_number, + documents_uploaded=now_uploaded, + ), + ) - collection_job = _mark_job_failed( + logger.info( + "[create_collection.execute_batch_job] Batch %d complete | " + "doc_count=%d, job_id=%s", + batch_number, + len(all_doc_ids_this_batch), + job_id, + ) + + if remaining_batches: + start_collection_batch_job( project_id=project_id, job_id=job_id, - err=err, - collection_job=collection_job, + trace_id=trace_id, + vector_store_id=resolved_vector_store_id, + batch_number=batch_number + 1, + batch_doc_ids=remaining_batches[0], + remaining_batches=remaining_batches[1:], + request=request, + with_assistant=with_assistant, + organization_id=organization_id, + ) + logger.info( + "[create_collection.execute_batch_job] Batch %d/%d done, next batch queued | " + "job_id=%s, elapsed=%.2fs", + batch_number, + batch_number + len(remaining_batches), + job_id, + time.time() - batch_start_time, + ) + return + + # Final batch: collection_result already has assistant/vector_store finalized + finalize_start_time = time.time() + + with Session(engine) as session: + all_uploaded_ids = [UUID(d) for d in now_uploaded] + document_crud = DocumentCrud(session, project_id) + all_docs = ( + document_crud.read_each(all_uploaded_ids) if all_uploaded_ids else [] ) + for doc in all_docs: + session.expunge(doc) - if creation_request and creation_request.callback_url and collection_job: - failure_payload = build_failure_payload(collection_job, str(err)) - webhook_secret = get_webhook_secret(project_id, organization_id) - send_callback( - str(creation_request.callback_url), - failure_payload, - webhook_secret=webhook_secret, - ) - raise + with Session(engine) as session: + collection_id = uuid4() + collection = Collection( + id=collection_id, + project_id=project_id, + 
llm_service_id=collection_result.llm_service_id, + llm_service_name=collection_result.llm_service_name, + provider=creation_request.provider, + name=creation_request.name, + description=creation_request.description, + ) + collection_crud = CollectionCrud(session, project_id) + collection_crud.create(collection) + collection = collection_crud.read_one(collection.id) + + if all_docs: + DocumentCollectionCrud(session).create(collection, all_docs) + + collection_job_crud = CollectionJobCrud(session, project_id) + collection_job = collection_job_crud.update( + job_uuid, + CollectionJobUpdate( + status=CollectionJobStatus.SUCCESSFUL, + collection_id=collection.id, + ), + ) + + success_payload = build_success_payload(collection_job, collection) + + logger.info( + "[create_collection.execute_batch_job] All batches done, collection created: %s | " + "finalize_time=%.2fs, total_time=%.2fs, total_docs=%d", + collection_id, + time.time() - finalize_start_time, + time.time() - batch_start_time, + len(all_docs), + ) + + if creation_request.callback_url: + send_callback(creation_request.callback_url, success_payload) + + except Timeout as err: + timeout_err = TimeoutError( + f"[execute_batch_job] Task exceeded soft time limit of {err.seconds}s" + ) + _mark_job_failed( + project_id=project_id, + job_id=job_id, + err=timeout_err, + collection_job=collection_job, + ) + raise + except BaseException as err: + logger.error( + "[create_collection.execute_batch_job] Batch %d failed | job_id=%s, error=%s", + batch_number, + job_id, + str(err), + exc_info=True, + ) + collection_job = _mark_job_failed( + project_id=project_id, + job_id=job_id, + err=err, + collection_job=collection_job, + ) + if creation_request and creation_request.callback_url and collection_job: + failure_payload = build_failure_payload(collection_job, str(err)) + send_callback(creation_request.callback_url, failure_payload) diff --git a/backend/app/services/collections/providers/base.py 
b/backend/app/services/collections/providers/base.py index 36283d1fa..6649a0725 100644 --- a/backend/app/services/collections/providers/base.py +++ b/backend/app/services/collections/providers/base.py @@ -19,48 +19,46 @@ class BaseProvider(ABC): """ def __init__(self, client: Any) -> None: - """Initialize provider with client. + self.client = client + + @abstractmethod + def upload_files( + self, + storage: CloudStorage, + docs: list[Document], + project_id: int, + ) -> None: + """Upload all documents to the provider's file storage and persist their file IDs. Args: - client: Provider-specific client instance + storage: Cloud storage instance to fetch raw file bytes from + docs: Documents to upload + project_id: Project ID used to persist the provider file IDs to the DB """ - self.client = client + raise NotImplementedError("Providers must implement upload_files method") @abstractmethod def create( self, collection_request: CreationRequest, - storage: CloudStorage, - documents: list[Document], + docs: list[Document], + vector_store_id: str | None = None, + is_final: bool = False, ) -> Collection: - """Create collection with documents and optionally an assistant. - - Args: - collection_request: Collection parameters (name, description, document list, etc.) - storage: Cloud storage instance for file access - documents: Pre-fetched list of Document objects to add to the collection - - Returns: - Collection object with llm_service_id and llm_service_name populated - """ - raise NotImplementedError("Providers must implement execute method") + """Upload docs batch to vector store (creating it if vector_store_id is None). + Creates assistant only when is_final=True and model/instructions are set. 
+ Returns Collection with llm_service_id set to vector_store_id on intermediate batches, + or to assistant/vector_store id on the final batch.""" + raise NotImplementedError("Providers must implement create method") @abstractmethod def delete(self, collection: Collection) -> None: - """Delete remote resources associated with a collection. - - Called when a collection is being deleted and remote resources need to be cleaned up. - - Args: - llm_service_id: ID of the resource to delete - llm_service_name: Name of the service (determines resource type) - """ + """Delete remote resources associated with a collection.""" raise NotImplementedError("Providers must implement delete method") - def get_provider_name(self) -> str: - """Get the name of the provider. + def get_existing_file_id(self, _doc: Document) -> str | None: + """Return the already-uploaded file ID for this provider, or None to trigger upload.""" + return None - Returns: - Provider name (e.g., "openai", "bedrock", "pinecone") - """ + def get_provider_name(self) -> str: return self.__class__.__name__.replace("Provider", "").lower() diff --git a/backend/app/services/collections/providers/openai.py b/backend/app/services/collections/providers/openai.py index f52e83394..3afaaba81 100644 --- a/backend/app/services/collections/providers/openai.py +++ b/backend/app/services/collections/providers/openai.py @@ -1,12 +1,16 @@ import logging +from io import BytesIO from typing import List from openai import OpenAI +from sqlmodel import Session from app.services.collections.providers import BaseProvider from app.core.cloud.storage import CloudStorage +from app.core.db import engine +from app.crud import DocumentCrud from app.crud.rag import OpenAIVectorStoreCrud, OpenAIAssistantCrud -from app.services.collections.helpers import get_service_name, batch_documents +from app.services.collections.helpers import get_service_name from app.models import CreationRequest, Collection, Document @@ -20,29 +24,72 @@ def 
__init__(self, client: OpenAI): super().__init__(client) self.client = client + def get_existing_file_id(self, doc: Document) -> str | None: + return doc.openai_file_id + + def upload_files( + self, + storage: CloudStorage, + docs: list[Document], + project_id: int, + ) -> None: + for doc in docs: + if self.get_existing_file_id(doc): + continue + try: + content = storage.get(doc.object_store_url) + if doc.file_size_kb is None: + doc.file_size_kb = round(len(content) / 1024, 2) + f_obj = BytesIO(content) + f_obj.name = doc.fname + uploaded = self.client.files.create(file=f_obj, purpose="assistants") + doc.openai_file_id = uploaded.id + with Session(engine) as session: + document_crud = DocumentCrud(session, project_id) + db_doc = document_crud.read_one(doc.id) + db_doc.openai_file_id = uploaded.id + db_doc.file_size_kb = doc.file_size_kb + document_crud.update(db_doc) + except Exception as err: + logger.error( + "[OpenAIProvider.upload_files] Failed to upload file | doc_id=%s, error=%s", + doc.id, + str(err), + exc_info=True, + ) + def create( self, collection_request: CreationRequest, - storage: CloudStorage, - documents: List[Document], + docs: List[Document], + vector_store_id: str | None = None, + is_final: bool = False, ) -> Collection: - """ - Create OpenAI vector store with documents and optionally an assistant. - docs_batches must be pre-fetched inside a DB session before this call. 
- """ try: - docs_batches = batch_documents(documents) vector_store_crud = OpenAIVectorStoreCrud(self.client) - vector_store = vector_store_crud.create() - list(vector_store_crud.update(vector_store.id, storage, docs_batches)) + if vector_store_id is None: + vector_store = vector_store_crud.create() + vector_store_id = vector_store.id + logger.info( + "[OpenAIProvider.create] Vector store created | vector_store_id=%s", + vector_store_id, + ) - logger.info( - "[OpenAIProvider.create] Vector store created | " - f"vector_store_id={vector_store.id}, batches={len(docs_batches)}" - ) + if docs: + vector_store_crud.update_batch(vector_store_id, docs) + logger.info( + "[OpenAIProvider.create] Batch uploaded | vector_store_id=%s, doc_count=%d", + vector_store_id, + len(docs), + ) + + if not is_final: + return Collection( + llm_service_id=vector_store_id, + llm_service_name=get_service_name("openai"), + ) - # Check if we need to create an assistant (based on assistant options in request) with_assistant = ( collection_request.model is not None and collection_request.instructions is not None @@ -59,11 +106,12 @@ def create( k: v for k, v in assistant_options.items() if v is not None } - assistant = assistant_crud.create(vector_store.id, **filtered_options) + assistant = assistant_crud.create(vector_store_id, **filtered_options) logger.info( - "[OpenAIProvider.create] Assistant created | " - f"assistant_id={assistant.id}, vector_store_id={vector_store.id}" + "[OpenAIProvider.create] Assistant created | assistant_id=%s, vector_store_id=%s", + assistant.id, + vector_store_id, ) return Collection( @@ -76,7 +124,7 @@ def create( ) return Collection( - llm_service_id=vector_store.id, + llm_service_id=vector_store_id, llm_service_name=get_service_name("openai"), ) From 2a2e2680ef3f4b7a58821f9f98d7ff119aedaaf7 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Wed, 6 May 2026 08:55:41 +0530 Subject: [PATCH 09/19] pushing few changes --- ..._add_batch_tracking_to_collections_jobs.py | 2 +- 
.../055_add_columns_to_collections.py | 62 ----- backend/app/celery/tasks/job_execution.py | 215 +++++++----------- backend/app/celery/utils.py | 175 +++++++------- 4 files changed, 158 insertions(+), 296 deletions(-) delete mode 100644 backend/app/alembic/versions/055_add_columns_to_collections.py diff --git a/backend/app/alembic/versions/055_add_batch_tracking_to_collections_jobs.py b/backend/app/alembic/versions/055_add_batch_tracking_to_collections_jobs.py index 804e5cc7d..26fb1a8d3 100644 --- a/backend/app/alembic/versions/055_add_batch_tracking_to_collections_jobs.py +++ b/backend/app/alembic/versions/055_add_batch_tracking_to_collections_jobs.py @@ -1,4 +1,4 @@ -"""add batch tracking to collection_job and provider_file_id to document +"""add batch tracking to collection_jobs Revision ID: 055 Revises: 054 diff --git a/backend/app/alembic/versions/055_add_columns_to_collections.py b/backend/app/alembic/versions/055_add_columns_to_collections.py deleted file mode 100644 index 804e5cc7d..000000000 --- a/backend/app/alembic/versions/055_add_columns_to_collections.py +++ /dev/null @@ -1,62 +0,0 @@ -"""add batch tracking to collection_job and provider_file_id to document - -Revision ID: 055 -Revises: 054 -Create Date: 2026-04-13 - -""" -from alembic import op -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. 
-revision = "055" -down_revision = "054" -branch_labels = None -depends_on = None - - -def upgrade(): - op.add_column( - "collection_jobs", - sa.Column( - "total_batches", - sa.Integer(), - nullable=True, - comment="Total number of batches the documents are split into", - ), - ) - op.add_column( - "collection_jobs", - sa.Column( - "current_batch_number", - sa.Integer(), - nullable=True, - comment="Which batch is currently being processed (1-indexed)", - ), - ) - op.add_column( - "collection_jobs", - sa.Column( - "documents_uploaded", - sa.JSON(), - nullable=True, - comment="List of document IDs successfully uploaded so far", - ), - ) - op.add_column( - "document", - sa.Column( - "openai_file_id", - sa.String(), - nullable=True, - comment="File ID assigned by the LLM provider (e.g. OpenAI file ID) to avoid re-uploading", - ), - ) - - -def downgrade(): - op.drop_column("collection_jobs", "total_batches") - op.drop_column("collection_jobs", "current_batch_number") - op.drop_column("collection_jobs", "documents_uploaded") - op.drop_column("document", "openai_file_id") diff --git a/backend/app/celery/tasks/job_execution.py b/backend/app/celery/tasks/job_execution.py index 8dd20091a..0156459aa 100644 --- a/backend/app/celery/tasks/job_execution.py +++ b/backend/app/celery/tasks/job_execution.py @@ -2,11 +2,10 @@ from asgi_correlation_id import correlation_id from celery import current_task -from opentelemetry import context as otel_context -from opentelemetry import trace -from opentelemetry.propagate import extract from app.celery.celery_app import celery_app +from app.celery.utils import gevent_timeout +from app.core.config import settings logger = logging.getLogger(__name__) @@ -16,60 +15,17 @@ def _set_trace(trace_id: str) -> None: logger.info(f"[_set_trace] Set correlation ID: {trace_id}") -def _extract_parent_context(task_instance) -> otel_context.Context: - """Extract OTel parent context from Celery headers if available.""" - headers = 
getattr(task_instance.request, "headers", None) or {} - carrier: dict[str, str] = {} - - if isinstance(headers, dict): - for key, value in headers.items(): - if isinstance(value, str): - carrier[str(key)] = value - - nested = headers.get("otel", {}) - if isinstance(nested, dict): - for key, value in nested.items(): - if isinstance(value, str): - carrier[str(key)] = value - - return extract(carrier) - - -def _run_with_otel_parent(task_instance, fn): - """Attach extracted parent context and execute function. - - When Celery auto-instrumentation is active, there is already a current - `run/...` span. Re-attaching extracted parent context here would make - service spans become siblings of `run/...` instead of children. - - We only attach extracted context as a fallback when no active span exists. - """ - current_ctx = trace.get_current_span().get_span_context() - if current_ctx and current_ctx.is_valid: - return fn() - - parent_ctx = _extract_parent_context(task_instance) - token = otel_context.attach(parent_ctx) - try: - return fn() - finally: - otel_context.detach(token) - - @celery_app.task(bind=True, queue="high_priority", priority=9) def run_llm_job(self, project_id: int, job_id: str, trace_id: str, **kwargs): from app.services.llm.jobs import execute_job _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @@ -78,15 +34,12 @@ def run_llm_chain_job(self, project_id: int, job_id: str, trace_id: str, **kwarg from app.services.llm.jobs import execute_chain_job _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_chain_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_chain_job( 
+ project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @@ -95,15 +48,12 @@ def run_response_job(self, project_id: int, job_id: str, trace_id: str, **kwargs from app.services.response.jobs import execute_job _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @@ -112,34 +62,46 @@ def run_doctransform_job(self, project_id: int, job_id: str, trace_id: str, **kw from app.services.doctransform.job import execute_job _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @celery_app.task(bind=True, queue="low_priority", priority=1) +@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_create_collection_job") def run_create_collection_job( self, project_id: int, job_id: str, trace_id: str, **kwargs ): - from app.services.collections.create_collection import execute_job + from app.services.collections.create_collection import execute_setup_job + + _set_trace(trace_id) + return execute_setup_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, + ) + + +@celery_app.task(bind=True, queue="low_priority", priority=1) +@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_collection_batch_job") +def run_collection_batch_job( + self, project_id: int, job_id: str, trace_id: str, **kwargs +): + from app.services.collections.create_collection import execute_batch_job _set_trace(trace_id) - return 
_run_with_otel_parent( - self, - lambda: execute_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_batch_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @@ -150,15 +112,12 @@ def run_delete_collection_job( from app.services.collections.delete_collection import execute_job _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @@ -169,15 +128,12 @@ def run_stt_batch_submission( from app.services.stt_evaluations.batch_job import execute_batch_submission _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_batch_submission( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_batch_submission( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @@ -188,15 +144,12 @@ def run_stt_metric_computation( from app.services.stt_evaluations.metric_job import execute_metric_computation _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_metric_computation( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_metric_computation( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @@ -207,15 +160,12 @@ def run_tts_batch_submission( from app.services.tts_evaluations.batch_job import execute_batch_submission _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_batch_submission( - 
project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_batch_submission( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @@ -228,13 +178,10 @@ def run_tts_result_processing( ) _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_tts_result_processing( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_tts_result_processing( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) diff --git a/backend/app/celery/utils.py b/backend/app/celery/utils.py index 475a4772a..9ffe47113 100644 --- a/backend/app/celery/utils.py +++ b/backend/app/celery/utils.py @@ -3,34 +3,25 @@ Business logic modules can use these functions without knowing Celery internals. """ import logging +import functools from typing import Any, Dict from celery.result import AsyncResult -from opentelemetry.propagate import inject +from gevent import Timeout from app.celery.celery_app import celery_app logger = logging.getLogger(__name__) -def _enqueue_with_trace_context(task, **kwargs) -> str: - """Publish Celery task with explicit trace context headers.""" - otel_headers: dict[str, str] = {} - inject(otel_headers) - celery_headers = dict(otel_headers) - celery_headers["otel"] = otel_headers - async_result = task.apply_async(kwargs=kwargs, headers=celery_headers) - return async_result.id - - def start_llm_job(project_id: int, job_id: str, trace_id: str = "N/A", **kwargs) -> str: from app.celery.tasks.job_execution import run_llm_job - task_id = _enqueue_with_trace_context( - run_llm_job, project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs + task = run_llm_job.delay( + project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs ) - logger.info(f"[start_llm_job] Started job {job_id} 
with Celery task {task_id}") - return task_id + logger.info(f"[start_llm_job] Started job {job_id} with Celery task {task.id}") + return task.id def start_llm_chain_job( @@ -38,17 +29,13 @@ def start_llm_chain_job( ) -> str: from app.celery.tasks.job_execution import run_llm_chain_job - task_id = _enqueue_with_trace_context( - run_llm_chain_job, - project_id=project_id, - job_id=job_id, - trace_id=trace_id, - **kwargs, + task = run_llm_chain_job.delay( + project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs ) logger.info( - f"[start_llm_chain_job] Started job {job_id} with Celery task {task_id}" + f"[start_llm_chain_job] Started job {job_id} with Celery task {task.id}" ) - return task_id + return task.id def start_response_job( @@ -56,15 +43,11 @@ def start_response_job( ) -> str: from app.celery.tasks.job_execution import run_response_job - task_id = _enqueue_with_trace_context( - run_response_job, - project_id=project_id, - job_id=job_id, - trace_id=trace_id, - **kwargs, + task = run_response_job.delay( + project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs ) - logger.info(f"[start_response_job] Started job {job_id} with Celery task {task_id}") - return task_id + logger.info(f"[start_response_job] Started job {job_id} with Celery task {task.id}") + return task.id def start_doctransform_job( @@ -72,53 +55,41 @@ def start_doctransform_job( ) -> str: from app.celery.tasks.job_execution import run_doctransform_job - task_id = _enqueue_with_trace_context( - run_doctransform_job, - project_id=project_id, - job_id=job_id, - trace_id=trace_id, - **kwargs, + task = run_doctransform_job.delay( + project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs ) logger.info( - f"[start_doctransform_job] Started job {job_id} with Celery task {task_id}" + f"[start_doctransform_job] Started job {job_id} with Celery task {task.id}" ) - return task_id + return task.id -def start_create_collection_setup_job( +def start_create_collection_job( project_id: 
int, job_id: str, trace_id: str = "N/A", **kwargs ) -> str: - from app.celery.tasks.job_execution import run_create_collection_setup_job + from app.celery.tasks.job_execution import run_create_collection_job - task_id = _enqueue_with_trace_context( - run_create_collection_job, - project_id=project_id, - job_id=job_id, - trace_id=trace_id, - **kwargs, + task = run_create_collection_job.delay( + project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs ) logger.info( - f"[start_create_collection_job] Started job {job_id} with Celery task {task_id}" + f"[start_create_collection_job] Started job {job_id} with Celery task {task.id}" ) - return task_id + return task.id -def start_create_collection_job( +def start_collection_batch_job( project_id: int, job_id: str, trace_id: str = "N/A", **kwargs ) -> str: - from app.celery.tasks.job_execution import run_create_collection_job + from app.celery.tasks.job_execution import run_collection_batch_job - task_id = _enqueue_with_trace_context( - run_create_collection_job, - project_id=project_id, - job_id=job_id, - trace_id=trace_id, - **kwargs, + task = run_collection_batch_job.delay( + project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs ) logger.info( - f"[start_create_collection_job] Started job {job_id} with Celery task {task_id}" + f"[start_collection_batch_job] Started batch job {job_id} with Celery task {task.id}" ) - return task_id + return task.id def start_delete_collection_job( @@ -126,17 +97,13 @@ def start_delete_collection_job( ) -> str: from app.celery.tasks.job_execution import run_delete_collection_job - task_id = _enqueue_with_trace_context( - run_delete_collection_job, - project_id=project_id, - job_id=job_id, - trace_id=trace_id, - **kwargs, + task = run_delete_collection_job.delay( + project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs ) logger.info( - f"[start_delete_collection_job] Started job {job_id} with Celery task {task_id}" + f"[start_delete_collection_job] Started 
job {job_id} with Celery task {task.id}" ) - return task_id + return task.id def start_stt_batch_submission( @@ -144,17 +111,13 @@ def start_stt_batch_submission( ) -> str: from app.celery.tasks.job_execution import run_stt_batch_submission - task_id = _enqueue_with_trace_context( - run_stt_batch_submission, - project_id=project_id, - job_id=job_id, - trace_id=trace_id, - **kwargs, + task = run_stt_batch_submission.delay( + project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs ) logger.info( - f"[start_stt_batch_submission] Started job {job_id} with Celery task {task_id}" + f"[start_stt_batch_submission] Started job {job_id} with Celery task {task.id}" ) - return task_id + return task.id def start_stt_metric_computation( @@ -162,17 +125,13 @@ def start_stt_metric_computation( ) -> str: from app.celery.tasks.job_execution import run_stt_metric_computation - task_id = _enqueue_with_trace_context( - run_stt_metric_computation, - project_id=project_id, - job_id=job_id, - trace_id=trace_id, - **kwargs, + task = run_stt_metric_computation.delay( + project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs ) logger.info( - f"[start_stt_metric_computation] Started job {job_id} with Celery task {task_id}" + f"[start_stt_metric_computation] Started job {job_id} with Celery task {task.id}" ) - return task_id + return task.id def start_tts_batch_submission( @@ -180,17 +139,13 @@ def start_tts_batch_submission( ) -> str: from app.celery.tasks.job_execution import run_tts_batch_submission - task_id = _enqueue_with_trace_context( - run_tts_batch_submission, - project_id=project_id, - job_id=job_id, - trace_id=trace_id, - **kwargs, + task = run_tts_batch_submission.delay( + project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs ) logger.info( - f"[start_tts_batch_submission] Started job {job_id} with Celery task {task_id}" + f"[start_tts_batch_submission] Started job {job_id} with Celery task {task.id}" ) - return task_id + return task.id def 
start_tts_result_processing( @@ -198,17 +153,13 @@ def start_tts_result_processing( ) -> str: from app.celery.tasks.job_execution import run_tts_result_processing - task_id = _enqueue_with_trace_context( - run_tts_result_processing, - project_id=project_id, - job_id=job_id, - trace_id=trace_id, - **kwargs, + task = run_tts_result_processing.delay( + project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs ) logger.info( - f"[start_tts_result_processing] Started job {job_id} with Celery task {task_id}" + f"[start_tts_result_processing] Started job {job_id} with Celery task {task.id}" ) - return task_id + return task.id def get_task_status(task_id: str) -> Dict[str, Any]: @@ -229,3 +180,29 @@ def revoke_task(task_id: str, terminate: bool = False) -> bool: except Exception as e: logger.error(f"[revoke_task] Failed to revoke task {task_id}: {e}") return False + + +def gevent_timeout(seconds, task_name=None): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + name = task_name or func.__name__ + timeout = Timeout(seconds) + timeout.start() + try: + return func(*args, **kwargs) + except Timeout: + logger.error( + f"[{name}] Timed out after {seconds}s — args={args}, kwargs={kwargs}" + ) + raise + # raise TimeoutError(f"[{name}] Task exceeded soft time limit of {seconds}s") + finally: + raise TimeoutError( + f"[{name}] Task exceeded soft time limit of {seconds}s" + ) + timeout.cancel() + + return wrapper + + return decorator From 0d7f4a1ad9fd62fa7969acb3398073019ac0eae9 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Sat, 9 May 2026 17:53:36 +0530 Subject: [PATCH 10/19] adding test cases --- ...add_batch_tracking_to_collections_jobs.py} | 8 +- backend/app/crud/rag/open_ai.py | 77 +- .../services/collections/create_collection.py | 546 +++++++------ .../services/collections/providers/openai.py | 5 +- .../providers/test_openai_provider.py | 203 ++++- .../collections/test_create_collection.py | 765 ++++++++++-------- 6 files changed, 920 
insertions(+), 684 deletions(-) rename backend/app/alembic/versions/{055_add_batch_tracking_to_collections_jobs.py => 058_add_batch_tracking_to_collections_jobs.py} (95%) diff --git a/backend/app/alembic/versions/055_add_batch_tracking_to_collections_jobs.py b/backend/app/alembic/versions/058_add_batch_tracking_to_collections_jobs.py similarity index 95% rename from backend/app/alembic/versions/055_add_batch_tracking_to_collections_jobs.py rename to backend/app/alembic/versions/058_add_batch_tracking_to_collections_jobs.py index 26fb1a8d3..6bf4b97bf 100644 --- a/backend/app/alembic/versions/055_add_batch_tracking_to_collections_jobs.py +++ b/backend/app/alembic/versions/058_add_batch_tracking_to_collections_jobs.py @@ -1,7 +1,7 @@ """add batch tracking to collection_jobs -Revision ID: 055 -Revises: 054 +Revision ID: 058 +Revises: 057 Create Date: 2026-04-13 """ @@ -10,8 +10,8 @@ # revision identifiers, used by Alembic. -revision = "055" -down_revision = "054" +revision = "058" +down_revision = "057" branch_labels = None depends_on = None diff --git a/backend/app/crud/rag/open_ai.py b/backend/app/crud/rag/open_ai.py index be6970235..07a9e671c 100644 --- a/backend/app/crud/rag/open_ai.py +++ b/backend/app/crud/rag/open_ai.py @@ -2,13 +2,10 @@ import logging import functools as ft import time -from io import BytesIO -from typing import Iterable from openai import OpenAI, OpenAIError from pydantic import BaseModel -from app.core.cloud import CloudStorage from app.models import Document logger = logging.getLogger(__name__) @@ -79,11 +76,6 @@ def clean(self, resource): class VectorStoreCleaner(ResourceCleaner): def clean(self, resource): - logger.info( - f"[VectorStoreCleaner.clean] Starting vector store cleanup | {{'vector_store_id': '{resource}'}}" - ) - for i in vs_ls(self.client, resource): - self.client.files.delete(i.id) logger.info( f"[VectorStoreCleaner.clean] Deleting vector store | {{'vector_store_id': '{resource}'}}" ) @@ -117,90 +109,35 @@ def read(self, 
vector_store_id: str): yield from vs_ls(self.client, vector_store_id) def update( - self, - vector_store_id: str, - storage: CloudStorage, - documents: Iterable[Document], - ): - for docs in documents: - files = [] - for d in docs: - # Get file bytes and wrap in BytesIO for OpenAI API - content = storage.get(d.object_store_url) - f_obj = BytesIO(content) - f_obj.name = d.fname - files.append(f_obj) - - logger.info( - f"[OpenAIVectorStoreCrud.update] Uploading files to vector store | {{'vector_store_id': '{vector_store_id}', 'file_count': {len(files)}}}" - ) - req = self.client.vector_stores.file_batches.upload_and_poll( - vector_store_id=vector_store_id, - files=files, - ) - logger.info( - f"[OpenAIVectorStoreCrud.update] File upload completed | {{'vector_store_id': '{vector_store_id}', 'completed_files': {req.file_counts.completed}, 'total_files': {req.file_counts.total}}}" - ) - if req.file_counts.completed != req.file_counts.total: - error_msg = f"OpenAI document processing error: {req.file_counts.completed}/{req.file_counts.total} files completed" - logger.error( - f"[OpenAIVectorStoreCrud.update] Document processing error | {{'vector_store_id': '{vector_store_id}', 'completed_files': {req.file_counts.completed}, 'total_files': {req.file_counts.total}}}" - ) - raise InterruptedError(error_msg) - - yield from docs - - def update_batch( self, vector_store_id: str, docs: list[Document], - ) -> tuple[list[Document], list[Document]]: - """ - Attach a batch of documents to the vector store via a single upload_and_poll call. - - All docs must have provider_file_id set before calling this method. - Returns (succeeded, failed) — failed docs should be retried in the next batch. 
- """ - succeeded: list[Document] = [] - failed: list[Document] = [] - + ) -> None: if not docs: - return succeeded, failed + return try: - _t0 = time.monotonic() batch = self.client.vector_stores.file_batches.upload_and_poll( vector_store_id=vector_store_id, files=[], file_ids=[doc.openai_file_id for doc in docs], ) logger.info( - f"[OpenAIVectorStoreCrud.update_batch] Batch upload_and_poll duration | " - f"{{'vector_store_id': '{vector_store_id}', 'duration_s': {time.monotonic() - _t0:.3f}, " + f"[OpenAIVectorStoreCrud.update] Batch complete | " + f"{{'vector_store_id': '{vector_store_id}', " f"'completed': {batch.file_counts.completed}, 'failed': {batch.file_counts.failed}}}" ) - if batch.file_counts.failed == 0: - succeeded.extend(docs) - else: - # Can't identify which specific files failed — retry all of them + if batch.file_counts.failed > 0: logger.warning( - f"[OpenAIVectorStoreCrud.update_batch] Batch had failures, marking all for retry | " + f"[OpenAIVectorStoreCrud.update] Batch had failures | " f"{{'vector_store_id': '{vector_store_id}', 'failed_count': {batch.file_counts.failed}}}" ) - failed.extend(docs) except OpenAIError as err: logger.error( - f"[OpenAIVectorStoreCrud.update_batch] Batch attach failed | " + f"[OpenAIVectorStoreCrud.update] Batch attach failed | " f"{{'vector_store_id': '{vector_store_id}', 'error': '{str(err)}'}}", exc_info=True, ) - failed.extend(docs) - - logger.info( - f"[OpenAIVectorStoreCrud.update_batch] Batch complete | " - f"{{'vector_store_id': '{vector_store_id}', 'succeeded': {len(succeeded)}, 'failed': {len(failed)}}}" - ) - return succeeded, failed def delete(self, vector_store_id: str, retries: int = 3): if retries < 1: diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py index 22f5e5602..b87e86667 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -3,12 +3,14 @@ from 
uuid import UUID, uuid4 from sqlmodel import Session -from celery.exceptions import SoftTimeLimitExceeded from gevent import Timeout +from celery.exceptions import SoftTimeLimitExceeded +from opentelemetry import trace from asgi_correlation_id import correlation_id from app.core.cloud import get_cloud_storage from app.core.db import engine +from app.core.telemetry import log_context from app.crud import ( CollectionCrud, DocumentCrud, @@ -23,19 +25,18 @@ CollectionJobPublic, CreationRequest, ) -from app.crud.rag import OpenAIVectorStoreCrud from app.services.collections.helpers import ( batch_documents, extract_error_message, to_collection_public, ) from app.services.collections.providers.registry import get_llm_provider -from gevent import Timeout from app.celery.utils import start_create_collection_job, start_collection_batch_job -from app.utils import send_callback, APIResponse +from app.utils import send_callback, APIResponse, get_webhook_secret logger = logging.getLogger(__name__) +tracer = trace.get_tracer(__name__) def start_job( @@ -121,59 +122,6 @@ def _mark_job_failed( return None -def _persist_succeeded_docs(succeeded: list, project_id: int) -> list[str]: - with Session(engine) as session: - document_crud = DocumentCrud(session, project_id) - for doc in succeeded: - if doc.openai_file_id: - db_doc = document_crud.read_one(doc.id) - if db_doc.openai_file_id != doc.openai_file_id: - db_doc.openai_file_id = doc.openai_file_id - document_crud.update(db_doc) - return [str(doc.id) for doc in succeeded] - - -def _retry_failed_uploads( - vector_store_crud, - vector_store_id: str, - failed_docs: list, - project_id: int, - max_retries: int = 3, -) -> list[str]: - """ - Retry attaching docs that failed the initial batch upload_and_poll. - All docs must already have provider_file_id set. - Returns the list of successfully retried doc IDs. - Raises RuntimeError if any docs still fail after all retries. 
- """ - pending = failed_docs - all_succeeded_ids: list[str] = [] - - for attempt in range(1, max_retries + 1): - logger.warning( - "[_retry_failed_uploads] Retry attempt %d/%d: %d doc(s) | vector_store_id=%s", - attempt, - max_retries, - len(pending), - vector_store_id, - ) - succeeded, failed = vector_store_crud.update_batch(vector_store_id, pending) - - if succeeded: - all_succeeded_ids += _persist_succeeded_docs(succeeded, project_id) - - if not failed: - return all_succeeded_ids - - pending = failed - - ids = [str(d.id) for d in pending] - raise RuntimeError( - f"Failed to upload {len(pending)} document(s) after {max_retries} retries: {ids}" - ) - - -def execute_setup_job( def _handle_job_failure( span, project_id: int, @@ -212,7 +160,7 @@ def _handle_job_failure( ) -def execute_job( +def execute_setup_job( request: dict, with_assistant: bool, project_id: int, @@ -228,105 +176,136 @@ def execute_job( collection_job = None creation_request = None - try: - creation_request = CreationRequest(**request) - if with_assistant: - creation_request.provider = "openai" + with log_context( + tag="collection", + lifecycle="collection.create.execute_setup_job", + action="create", + collection_job_id=job_id, + task_id=task_id, + project_id=project_id, + organization_id=organization_id, + ), tracer.start_as_current_span("collections.create.execute_setup_job") as span: + span.set_attribute("collection.job_id", str(job_id)) + span.set_attribute("kaapi.project_id", project_id) + span.set_attribute("kaapi.organization_id", organization_id) - job_uuid = UUID(job_id) - trace_id = correlation_id.get() or "N/A" + try: + creation_request = CreationRequest(**request) + if with_assistant: + creation_request.provider = "openai" - with Session(engine) as session: - document_crud = DocumentCrud(session, project_id) - flat_docs = document_crud.read_each(creation_request.documents) - storage = get_cloud_storage(session=session, project_id=project_id) + 
span.set_attribute("collection.provider", str(creation_request.provider)) - provider = get_llm_provider( - session=session, - provider=creation_request.provider, - project_id=project_id, - organization_id=organization_id, - ) + job_uuid = UUID(job_id) + trace_id = correlation_id.get() or "N/A" - for doc in flat_docs: - session.expunge(doc) + with Session(engine) as session: + document_crud = DocumentCrud(session, project_id) + flat_docs = document_crud.read_each(creation_request.documents) + storage = get_cloud_storage(session=session, project_id=project_id) - provider.upload_files(storage, flat_docs, project_id) + provider = get_llm_provider( + session=session, + provider=creation_request.provider, + project_id=project_id, + organization_id=organization_id, + ) - logger.info( - "[create_collection.execute_setup_job] All file uploads complete | " - "job_id=%s, total=%d, failed=%d, duration_s=%.2f", - job_id, - len(flat_docs), - ) + collection_job_crud = CollectionJobCrud(session, project_id) + collection_job = collection_job_crud.update( + job_uuid, + CollectionJobUpdate( + task_id=task_id, + status=CollectionJobStatus.PROCESSING, + ), + ) - total_size_kb = sum(doc.file_size_kb for doc in flat_docs) - total_size_mb = total_size_kb / 1024 + for doc in flat_docs: + session.expunge(doc) - docs_batches = batch_documents(flat_docs) - total_batches = len(docs_batches) - batch_doc_ids = [[str(doc.id) for doc in batch] for batch in docs_batches] + provider.upload_files(storage, flat_docs, project_id) - with Session(engine) as session: - collection_job_crud = CollectionJobCrud(session, project_id) - collection_job = collection_job_crud.update( - job_uuid, - CollectionJobUpdate( - task_id=task_id, - status=CollectionJobStatus.PROCESSING, - total_size_mb=total_size_mb, - current_batch_number=0, - total_batches=total_batches, - documents_uploaded=[], - ), + logger.info( + "[create_collection.execute_setup_job] All file uploads complete | " + "job_id=%s, total=%d", + job_id, + 
len(flat_docs), ) - start_collection_batch_job( - project_id=project_id, - job_id=job_id, - trace_id=trace_id, - batch_number=1, - batch_doc_ids=batch_doc_ids[0], - remaining_batches=batch_doc_ids[1:], - request=request, - with_assistant=with_assistant, - organization_id=organization_id, - ) + total_size_kb = sum(doc.file_size_kb for doc in flat_docs) + total_size_mb = total_size_kb / 1024 + + docs_batches = batch_documents(flat_docs) + total_batches = len(docs_batches) + batch_doc_ids = [[str(doc.id) for doc in batch] for batch in docs_batches] + + with Session(engine) as session: + collection_job_crud = CollectionJobCrud(session, project_id) + collection_job = collection_job_crud.update( + job_uuid, + CollectionJobUpdate( + task_id=task_id, + status=CollectionJobStatus.PROCESSING, + total_size_mb=total_size_mb, + current_batch_number=0, + total_batches=total_batches, + documents_uploaded=[], + ), + ) - logger.info( - "[create_collection.execute_setup_job] Setup complete, first batch queued | " - f"job_id={job_id}, total_batches={total_batches}" - ) + start_collection_batch_job( + project_id=project_id, + job_id=job_id, + trace_id=trace_id, + batch_number=1, + batch_doc_ids=batch_doc_ids[0], + remaining_batches=batch_doc_ids[1:], + request=request, + vector_store_id=None, + with_assistant=with_assistant, + organization_id=organization_id, + ) - except Timeout as err: - timeout_err = TimeoutError( - f"[execute_setup_job] Task exceeded soft time limit of {err.seconds}s" - ) - _mark_job_failed( - project_id=project_id, - job_id=job_id, - err=timeout_err, - collection_job=collection_job, - ) - raise - - except Exception as err: - logger.error( - "[create_collection.execute_setup_job] Setup failed | job_id=%s, error=%s", - job_id, - str(err), - exc_info=True, - ) + logger.info( + "[create_collection.execute_setup_job] Setup complete, first batch queued | " + f"job_id={job_id}, total_batches={total_batches}" + ) - collection_job = _mark_job_failed( - 
project_id=project_id, - job_id=job_id, - err=err, - collection_job=collection_job, - ) - if creation_request and creation_request.callback_url and collection_job: - failure_payload = build_failure_payload(collection_job, str(err)) - send_callback(creation_request.callback_url, failure_payload) + except (Timeout, SoftTimeLimitExceeded) as err: + timeout_err = TimeoutError("Task exceeded soft time limit") + logger.warning( + "[create_collection.execute_setup_job] Collection Creation Timed Out | {'collection_job_id': '%s', 'error': '%s'}", + job_id, + str(timeout_err), + ) + _handle_job_failure( + span, + project_id, + organization_id, + job_id, + timeout_err, + collection_job, + creation_request, + ) + raise + + except Exception as err: + logger.error( + "[create_collection.execute_setup_job] Setup failed | job_id=%s, error=%s", + job_id, + str(err), + exc_info=True, + ) + _handle_job_failure( + span, + project_id, + organization_id, + job_id, + err, + collection_job, + creation_request, + ) + raise def execute_batch_job( @@ -344,161 +323,180 @@ def execute_batch_job( ) -> None: """ Phase 2: Upload one batch of documents to the vector store. 
- - Uploads the batch; any failures within the batch are retried inline by _upload_batch_with_retry - - Raises immediately if all retries for the batch are exhausted + - Uploads the batch via provider.create(); raises immediately on failure - Checkpoints progress to the DB - If more batches remain, queues the next batch task - If this is the last batch, finalizes: creates Collection, links docs, marks job SUCCESSFUL """ collection_job = None + result = None creation_request = None + provider = None + + with log_context( + tag="collection", + lifecycle="collection.create.execute_batch_job", + action="create", + collection_job_id=job_id, + task_id=task_id, + project_id=project_id, + organization_id=organization_id, + ), tracer.start_as_current_span("collections.create.execute_batch_job") as span: + span.set_attribute("collection.job_id", str(job_id)) + span.set_attribute("kaapi.project_id", project_id) + span.set_attribute("kaapi.organization_id", organization_id) - try: - batch_start_time = time.time() - creation_request = CreationRequest(**request) - if with_assistant: - creation_request.provider = "openai" - - job_uuid = UUID(job_id) - trace_id = correlation_id.get() or "N/A" - - logger.info( - "[create_collection.execute_batch_job] Starting batch | " - "job_id=%s, batch_number=%d, doc_count=%d, remaining_batches=%d", - job_id, - batch_number, - len(batch_doc_ids), - len(remaining_batches), - ) + try: + batch_start_time = time.time() + creation_request = CreationRequest(**request) + if with_assistant: + creation_request.provider = "openai" - all_doc_ids_this_batch = [UUID(d) for d in batch_doc_ids] - is_final = not remaining_batches + span.set_attribute("collection.provider", str(creation_request.provider)) - with Session(engine) as session: - provider = get_llm_provider( - session=session, - provider=creation_request.provider, - project_id=project_id, - organization_id=organization_id, - ) + job_uuid = UUID(job_id) + trace_id = correlation_id.get() or "N/A" - 
with Session(engine) as session: - document_crud = DocumentCrud(session, project_id) - batch_docs = ( - document_crud.read_each(all_doc_ids_this_batch) - if all_doc_ids_this_batch - else [] + logger.info( + "[create_collection.execute_batch_job] Starting batch | " + "job_id=%s, batch_number=%d, doc_count=%d, remaining_batches=%d", + job_id, + batch_number, + len(batch_doc_ids), + len(remaining_batches), ) - for doc in batch_docs: - session.expunge(doc) - - collection_result = provider.create( - creation_request, - batch_docs, - vector_store_id=vector_store_id, - is_final=is_final, - ) - resolved_vector_store_id = ( - collection_result.llm_service_id - if not is_final - else vector_store_id or collection_result.llm_service_id - ) - with Session(engine) as session: - collection_job_crud = CollectionJobCrud(session, project_id) - collection_job = collection_job_crud.read_one(job_uuid) - already_uploaded = collection_job.documents_uploaded or [] - now_uploaded = already_uploaded + [str(d) for d in all_doc_ids_this_batch] + all_doc_ids_this_batch = [UUID(d) for d in batch_doc_ids] + is_final = not remaining_batches - collection_job = collection_job_crud.update( - job_uuid, - CollectionJobUpdate( - current_batch_number=batch_number, - documents_uploaded=now_uploaded, - ), - ) + with Session(engine) as session: + provider = get_llm_provider( + session=session, + provider=creation_request.provider, + project_id=project_id, + organization_id=organization_id, + ) - logger.info( - "[create_collection.execute_batch_job] Batch %d complete | " - "doc_count=%d, job_id=%s", - batch_number, - len(all_doc_ids_this_batch), - job_id, - ) + with Session(engine) as session: + document_crud = DocumentCrud(session, project_id) + batch_docs = ( + document_crud.read_each(all_doc_ids_this_batch) + if all_doc_ids_this_batch + else [] + ) + for doc in batch_docs: + session.expunge(doc) - if remaining_batches: - start_collection_batch_job( - project_id=project_id, - job_id=job_id, - 
trace_id=trace_id, - vector_store_id=resolved_vector_store_id, - batch_number=batch_number + 1, - batch_doc_ids=remaining_batches[0], - remaining_batches=remaining_batches[1:], - request=request, - with_assistant=with_assistant, - organization_id=organization_id, + collection_result = provider.create( + creation_request, + batch_docs, + vector_store_id=vector_store_id, + is_final=is_final, ) + result = collection_result + resolved_vector_store_id = collection_result.llm_service_id + + with Session(engine) as session: + collection_job_crud = CollectionJobCrud(session, project_id) + collection_job = collection_job_crud.read_one(job_uuid) + already_uploaded = collection_job.documents_uploaded or [] + now_uploaded = already_uploaded + [ + str(d) for d in all_doc_ids_this_batch + ] + + collection_job = collection_job_crud.update( + job_uuid, + CollectionJobUpdate( + current_batch_number=batch_number, + documents_uploaded=now_uploaded, + ), + ) + logger.info( - "[create_collection.execute_batch_job] Batch %d/%d done, next batch queued | " - "job_id=%s, elapsed=%.2fs", + "[create_collection.execute_batch_job] Batch %d complete | " + "doc_count=%d, job_id=%s", batch_number, - batch_number + len(remaining_batches), + len(all_doc_ids_this_batch), job_id, - time.time() - batch_start_time, ) - return - - # Final batch: collection_result already has assistant/vector_store finalized - finalize_start_time = time.time() - with Session(engine) as session: - all_uploaded_ids = [UUID(d) for d in now_uploaded] - document_crud = DocumentCrud(session, project_id) - all_docs = ( - document_crud.read_each(all_uploaded_ids) if all_uploaded_ids else [] - ) - for doc in all_docs: - session.expunge(doc) + if remaining_batches: + start_collection_batch_job( + project_id=project_id, + job_id=job_id, + trace_id=trace_id, + vector_store_id=resolved_vector_store_id, + batch_number=batch_number + 1, + batch_doc_ids=remaining_batches[0], + remaining_batches=remaining_batches[1:], + request=request, 
+ with_assistant=with_assistant, + organization_id=organization_id, + ) + logger.info( + "[create_collection.execute_batch_job] Batch %d/%d done, next batch queued | " + "job_id=%s, elapsed=%.2fs", + batch_number, + batch_number + len(remaining_batches), + job_id, + time.time() - batch_start_time, + ) + return + + # Final batch: collection_result already has assistant/vector_store finalized + finalize_start_time = time.time() + + with Session(engine) as session: + all_uploaded_ids = [UUID(d) for d in now_uploaded] + document_crud = DocumentCrud(session, project_id) + all_docs = ( + document_crud.read_each(all_uploaded_ids) + if all_uploaded_ids + else [] + ) + for doc in all_docs: + session.expunge(doc) + + with Session(engine) as session: + collection_id = uuid4() + collection = Collection( + id=collection_id, + project_id=project_id, + llm_service_id=collection_result.llm_service_id, + llm_service_name=collection_result.llm_service_name, + provider=creation_request.provider, + name=creation_request.name, + description=creation_request.description, + ) + collection_crud = CollectionCrud(session, project_id) + collection_crud.create(collection) + collection = collection_crud.read_one(collection.id) + + if all_docs: + DocumentCollectionCrud(session).create(collection, all_docs) + + collection_job_crud = CollectionJobCrud(session, project_id) + collection_job = collection_job_crud.update( + job_uuid, + CollectionJobUpdate( + status=CollectionJobStatus.SUCCESSFUL, + collection_id=collection.id, + ), + ) - with Session(engine) as session: - collection_id = uuid4() - collection = Collection( - id=collection_id, - project_id=project_id, - llm_service_id=collection_result.llm_service_id, - llm_service_name=collection_result.llm_service_name, - provider=creation_request.provider, - name=creation_request.name, - description=creation_request.description, - ) - collection_crud = CollectionCrud(session, project_id) - collection_crud.create(collection) - collection = 
collection_crud.read_one(collection.id) + success_payload = build_success_payload(collection_job, collection) - if all_docs: - DocumentCollectionCrud(session).create(collection, all_docs) + span.set_attribute("collection.id", str(collection_id)) - collection_job_crud = CollectionJobCrud(session, project_id) - collection_job = collection_job_crud.update( - job_uuid, - CollectionJobUpdate( - status=CollectionJobStatus.SUCCESSFUL, - collection_id=collection.id, - ), + logger.info( + "[create_collection.execute_batch_job] All batches done, collection created: %s | " + "finalize_time=%.2fs, total_time=%.2fs, total_docs=%d", + collection_id, + time.time() - finalize_start_time, + time.time() - batch_start_time, + len(all_docs), ) - success_payload = build_success_payload(collection_job, collection) - - logger.info( - "[create_collection.execute_batch_job] All batches done, collection created: %s | " - "finalize_time=%.2fs, total_time=%.2fs, total_docs=%d", - collection_id, - time.time() - finalize_start_time, - time.time() - batch_start_time, - len(all_docs), - ) - if creation_request.callback_url: webhook_secret = get_webhook_secret(project_id, organization_id) send_callback( @@ -510,7 +508,7 @@ def execute_batch_job( except (Timeout, SoftTimeLimitExceeded) as err: timeout_err = TimeoutError("Task exceeded soft time limit") logger.warning( - "[create_collection.execute_job] Collection Creation Timed Out | {'collection_job_id': '%s', 'error': '%s'}", + "[create_collection.execute_batch_job] Collection Creation Timed Out | {'collection_job_id': '%s', 'error': '%s'}", job_id, str(timeout_err), ) @@ -529,7 +527,7 @@ def execute_batch_job( except Exception as err: logger.error( - "[create_collection.execute_job] Collection Creation Failed | {'collection_job_id': '%s', 'error': '%s'}", + "[create_collection.execute_batch_job] Collection Creation Failed | {'collection_job_id': '%s', 'error': '%s'}", job_id, str(err), exc_info=True, diff --git 
a/backend/app/services/collections/providers/openai.py b/backend/app/services/collections/providers/openai.py index 3afaaba81..61e7c6374 100644 --- a/backend/app/services/collections/providers/openai.py +++ b/backend/app/services/collections/providers/openai.py @@ -57,6 +57,7 @@ def upload_files( str(err), exc_info=True, ) + raise def create( self, @@ -77,7 +78,7 @@ def create( ) if docs: - vector_store_crud.update_batch(vector_store_id, docs) + vector_store_crud.update(vector_store_id, docs) logger.info( "[OpenAIProvider.create] Batch uploaded | vector_store_id=%s, doc_count=%d", vector_store_id, @@ -89,7 +90,7 @@ def create( llm_service_id=vector_store_id, llm_service_name=get_service_name("openai"), ) - + # if "is_final" is true then only will assistant creation happen - with_assistant = ( collection_request.model is not None and collection_request.instructions is not None diff --git a/backend/app/tests/services/collections/providers/test_openai_provider.py b/backend/app/tests/services/collections/providers/test_openai_provider.py index b21577d49..8431ee512 100644 --- a/backend/app/tests/services/collections/providers/test_openai_provider.py +++ b/backend/app/tests/services/collections/providers/test_openai_provider.py @@ -1,5 +1,6 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch +from uuid import uuid4 import pytest @@ -35,7 +36,7 @@ def test_create_openai_vector_store_only() -> None: ) as vector_store_crud_cls: vector_store_crud = vector_store_crud_cls.return_value vector_store_crud.create.return_value = MagicMock(id=vector_store_id) - vector_store_crud.update.return_value = iter([None]) + vector_store_crud.update.return_value = None collection = provider.create( collection_request, @@ -71,7 +72,7 @@ def test_create_openai_with_assistant() -> None: ) as assistant_crud_cls: vector_store_crud = vector_store_crud_cls.return_value vector_store_crud.create.return_value = MagicMock(id=vector_store_id) - 
vector_store_crud.update.return_value = iter([None]) + vector_store_crud.update.return_value = None assistant_crud = assistant_crud_cls.return_value assistant_crud.create.return_value = MagicMock(id=assistant_id) @@ -124,6 +125,204 @@ def test_delete_openai_vector_store() -> None: vector_store_crud.delete.assert_called_once_with(collection.llm_service_id) +# --------------------------------------------------------------------------- +# upload_files +# --------------------------------------------------------------------------- + + +def _make_doc(*, openai_file_id=None, file_size_kb=None): + return SimpleNamespace( + id=uuid4(), + fname="test.md", + object_store_url="s3://bucket/test.md", + openai_file_id=openai_file_id, + file_size_kb=file_size_kb, + ) + + +def _patch_session_and_crud(): + """Patches Session and DocumentCrud used inside upload_files.""" + session_patcher = patch("app.services.collections.providers.openai.Session") + crud_patcher = patch("app.services.collections.providers.openai.DocumentCrud") + return session_patcher, crud_patcher + + +def test_upload_files_skips_doc_with_existing_openai_file_id() -> None: + client = MagicMock() + provider = OpenAIProvider(client=client) + storage = MagicMock() + doc = _make_doc(openai_file_id="file-already-exists", file_size_kb=10.0) + + session_p, crud_p = _patch_session_and_crud() + with session_p, crud_p: + provider.upload_files(storage, [doc], project_id=1) + + storage.get.assert_not_called() + client.files.create.assert_not_called() + + +def test_upload_files_uploads_doc_and_sets_openai_file_id() -> None: + client = MagicMock() + client.files.create.return_value = MagicMock(id="file-new-abc") + provider = OpenAIProvider(client=client) + + storage = MagicMock() + storage.get.return_value = b"file content" + + doc = _make_doc(file_size_kb=10.0) + + mock_crud = MagicMock() + session_p, crud_p = _patch_session_and_crud() + with session_p as MockSession, crud_p as MockDocCrud: + 
MockSession.return_value.__enter__.return_value = MagicMock() + MockSession.return_value.__exit__.return_value = False + MockDocCrud.return_value = mock_crud + + provider.upload_files(storage, [doc], project_id=1) + + assert doc.openai_file_id == "file-new-abc" + client.files.create.assert_called_once() + _, kwargs = client.files.create.call_args + assert kwargs.get("purpose") == "assistants" + mock_crud.update.assert_called_once() + + +def test_upload_files_sets_file_size_kb_when_none() -> None: + """file_size_kb should be computed from content length if not already set.""" + client = MagicMock() + client.files.create.return_value = MagicMock(id="file-xyz") + provider = OpenAIProvider(client=client) + + content = b"x" * 2048 # 2 KB + storage = MagicMock() + storage.get.return_value = content + + doc = _make_doc(file_size_kb=None) + + session_p, crud_p = _patch_session_and_crud() + with session_p as MockSession, crud_p: + MockSession.return_value.__enter__.return_value = MagicMock() + MockSession.return_value.__exit__.return_value = False + + provider.upload_files(storage, [doc], project_id=1) + + assert doc.file_size_kb == round(len(content) / 1024, 2) + + +def test_upload_files_preserves_existing_file_size_kb() -> None: + """file_size_kb should not be overwritten if already set.""" + client = MagicMock() + client.files.create.return_value = MagicMock(id="file-xyz") + provider = OpenAIProvider(client=client) + + storage = MagicMock() + storage.get.return_value = b"x" * 4096 + + doc = _make_doc(file_size_kb=99.0) + + session_p, crud_p = _patch_session_and_crud() + with session_p as MockSession, crud_p: + MockSession.return_value.__enter__.return_value = MagicMock() + MockSession.return_value.__exit__.return_value = False + + provider.upload_files(storage, [doc], project_id=1) + + assert doc.file_size_kb == 99.0 + + +def test_upload_files_updates_db_with_file_id_and_size() -> None: + client = MagicMock() + client.files.create.return_value = 
MagicMock(id="file-db-check") + provider = OpenAIProvider(client=client) + + storage = MagicMock() + storage.get.return_value = b"content" + + doc = _make_doc(file_size_kb=5.0) + mock_db_doc = MagicMock() + mock_crud = MagicMock() + mock_crud.read_one.return_value = mock_db_doc + + session_p, crud_p = _patch_session_and_crud() + with session_p as MockSession, crud_p as MockDocCrud: + MockSession.return_value.__enter__.return_value = MagicMock() + MockSession.return_value.__exit__.return_value = False + MockDocCrud.return_value = mock_crud + + provider.upload_files(storage, [doc], project_id=42) + + MockDocCrud.assert_called_once_with( + MockSession.return_value.__enter__.return_value, 42 + ) + mock_crud.read_one.assert_called_once_with(doc.id) + assert mock_db_doc.openai_file_id == "file-db-check" + assert mock_db_doc.file_size_kb == 5.0 + mock_crud.update.assert_called_once_with(mock_db_doc) + + +def test_upload_files_raises_on_storage_failure() -> None: + client = MagicMock() + provider = OpenAIProvider(client=client) + + storage = MagicMock() + storage.get.side_effect = RuntimeError("S3 error") + + doc = _make_doc() + + session_p, crud_p = _patch_session_and_crud() + with session_p, crud_p: + with pytest.raises(RuntimeError, match="S3 error"): + provider.upload_files(storage, [doc], project_id=1) + + client.files.create.assert_not_called() + + +def test_upload_files_raises_on_openai_failure() -> None: + client = MagicMock() + client.files.create.side_effect = RuntimeError("OpenAI error") + provider = OpenAIProvider(client=client) + + storage = MagicMock() + storage.get.return_value = b"content" + + doc = _make_doc() + + session_p, crud_p = _patch_session_and_crud() + with session_p, crud_p: + with pytest.raises(RuntimeError, match="OpenAI error"): + provider.upload_files(storage, [doc], project_id=1) + + +def test_upload_files_mixed_skips_uploaded_uploads_new() -> None: + """Docs with openai_file_id are skipped; others are uploaded.""" + client = MagicMock() + 
client.files.create.return_value = MagicMock(id="file-new") + provider = OpenAIProvider(client=client) + + storage = MagicMock() + storage.get.return_value = b"content" + + already_uploaded = _make_doc(openai_file_id="file-exists", file_size_kb=5.0) + new_doc = _make_doc(file_size_kb=5.0) + + session_p, crud_p = _patch_session_and_crud() + with session_p as MockSession, crud_p: + MockSession.return_value.__enter__.return_value = MagicMock() + MockSession.return_value.__exit__.return_value = False + + provider.upload_files(storage, [already_uploaded, new_doc], project_id=1) + + assert already_uploaded.openai_file_id == "file-exists" + assert new_doc.openai_file_id == "file-new" + client.files.create.assert_called_once() + storage.get.assert_called_once_with(new_doc.object_store_url) + + +# --------------------------------------------------------------------------- +# create (existing tests below) +# --------------------------------------------------------------------------- + + def test_create_propagates_exception() -> None: provider = OpenAIProvider(client=MagicMock()) diff --git a/backend/app/tests/services/collections/test_create_collection.py b/backend/app/tests/services/collections/test_create_collection.py index d8ca2829b..05213b879 100644 --- a/backend/app/tests/services/collections/test_create_collection.py +++ b/backend/app/tests/services/collections/test_create_collection.py @@ -1,26 +1,27 @@ from typing import Any import os -from pathlib import Path from unittest.mock import patch, MagicMock -from urllib.parse import urlparse import uuid from uuid import UUID, uuid4 +from celery.exceptions import SoftTimeLimitExceeded from gevent import Timeout import pytest -from moto import mock_aws from sqlmodel import Session -from app.core.cloud import AmazonCloudStorageClient from app.core.config import settings from app.crud import CollectionCrud, CollectionJobCrud, DocumentCollectionCrud from app.models import CollectionJobStatus, CollectionJob, 
CollectionActionType, Project from app.models.collection import CreationRequest -from app.services.collections.create_collection import start_job, execute_job +from app.services.collections.create_collection import ( + start_job, + execute_setup_job, + execute_batch_job, +) from app.tests.utils.llm_provider import get_mock_provider from app.tests.utils.utils import get_project -from app.tests.utils.collection import get_collection_job, get_assistant_collection +from app.tests.utils.collection import get_collection_job from app.tests.utils.document import DocumentStore @@ -33,30 +34,33 @@ def aws_credentials() -> Any: os.environ["AWS_DEFAULT_REGION"] = settings.AWS_DEFAULT_REGION -def create_collection_job_for_create( - db: Session, - project: Project, - job_id: UUID, -) -> CollectionJob: - """Pre-create a CREATE job with the given id so start_job can update it.""" - return CollectionJobCrud(db, project.id).create( - CollectionJob( - id=job_id, - action_type=CollectionActionType.CREATE, - project_id=project.id, - collection_id=None, - status=CollectionJobStatus.PENDING, - ) - ) +def _mock_provider_with_size(llm_service_id: str, llm_service_name: str): + """Returns a mock provider whose upload_files sets file_size_kb=10.0 on each doc.""" + mock_provider = get_mock_provider(llm_service_id, llm_service_name) + + def _set_file_size(storage, docs, project_id): + for doc in docs: + doc.file_size_kb = 10.0 + + mock_provider.upload_files.side_effect = _set_file_size + return mock_provider + + +def _patch_session(db: Session): + """Context manager that routes all Session(engine) calls to the test db.""" + patcher = patch("app.services.collections.create_collection.Session") + mock_ctor = patcher.start() + mock_ctor.return_value.__enter__.return_value = db + mock_ctor.return_value.__exit__.return_value = False + return patcher + + +# --------------------------------------------------------------------------- +# start_job +# 
--------------------------------------------------------------------------- def test_start_job_creates_collection_job_and_schedules_task(db: Session) -> None: - """ - start_job should: - - update an existing CollectionJob (status=PENDING, action=CREATE) - - call start_create_collection_job with the correct kwargs - - return the job UUID (same one that was passed in) - """ project = get_project(db) request = CreationRequest( documents=[UUID("f3e86a17-1e6f-41ec-b020-5b08eebef928")], @@ -65,7 +69,7 @@ def test_start_job_creates_collection_job_and_schedules_task(db: Session) -> Non ) job_id = uuid4() - _ = get_collection_job( + get_collection_job( db, project, job_id=job_id, @@ -88,472 +92,569 @@ def test_start_job_creates_collection_job_and_schedules_task(db: Session) -> Non organization_id=project.organization_id, ) - assert returned_job_id == job_id + assert returned_job_id == job_id + mock_schedule.assert_called_once() + kwargs = mock_schedule.call_args.kwargs + assert kwargs["project_id"] == project.id + assert kwargs["organization_id"] == project.organization_id + assert kwargs["job_id"] == str(job_id) + assert kwargs["request"] == request.model_dump(mode="json") - job = CollectionJobCrud(db, project.id).read_one(job_id) - assert job.id == job_id - assert job.project_id == project.id - assert job.status == CollectionJobStatus.PENDING - assert job.action_type in ( - CollectionActionType.CREATE, - CollectionActionType.CREATE.value, - ) - assert job.collection_id is None - mock_schedule.assert_called_once() - kwargs = mock_schedule.call_args.kwargs - assert kwargs["project_id"] == project.id - assert kwargs["organization_id"] == project.organization_id - assert kwargs["job_id"] == str(job_id) - assert kwargs["request"] == request.model_dump(mode="json") +# --------------------------------------------------------------------------- +# execute_setup_job +# --------------------------------------------------------------------------- 
-@pytest.mark.usefixtures("aws_credentials") -@mock_aws +@patch("app.services.collections.create_collection.get_cloud_storage") @patch("app.services.collections.create_collection.get_llm_provider") -def test_execute_job_success_flow_updates_job_and_creates_collection( - mock_get_llm_provider: MagicMock, db: Session +@patch("app.services.collections.create_collection.start_collection_batch_job") +def test_execute_setup_job_marks_processing_and_queues_first_batch( + mock_queue_batch: MagicMock, + mock_get_provider: MagicMock, + mock_get_storage: MagicMock, + db: Session, ) -> None: - """ - execute_job should: - - set task_id on the CollectionJob - - ingest documents into a vector store - - create an OpenAI assistant - - create a Collection with llm fields filled - - link the CollectionJob -> collection_id, set status=successful - - create DocumentCollection links - """ project = get_project(db) - - aws = AmazonCloudStorageClient() - aws.create() - store = DocumentStore(db=db, project_id=project.id) - document = store.put() - s3_key = Path(urlparse(document.object_store_url).path).relative_to("/") - aws.client.put_object(Bucket=settings.AWS_S3_BUCKET, Key=str(s3_key), Body=b"test") + doc = store.put() - sample_request = CreationRequest( - documents=[document.id], callback_url=None, provider="openai" + mock_get_provider.return_value = _mock_provider_with_size( + "vs_123", "openai vector store" ) - mock_get_llm_provider.return_value = get_mock_provider( - llm_service_id="mock_vector_store_id", llm_service_name="openai vector store" - ) - - job_id = uuid4() - _ = get_collection_job( + job = get_collection_job( db, project, - job_id=job_id, action_type=CollectionActionType.CREATE, status=CollectionJobStatus.PENDING, - collection_id=None, ) - - task_id = uuid4() - - with patch("app.services.collections.create_collection.Session") as SessionCtor: - SessionCtor.return_value.__enter__.return_value = db - SessionCtor.return_value.__exit__.return_value = False - - execute_job( 
- request=sample_request.model_dump(), + request = CreationRequest(documents=[doc.id], provider="openai", callback_url=None) + task_id = str(uuid4()) + + patcher = _patch_session(db) + try: + execute_setup_job( + request=request.model_dump(mode="json"), + with_assistant=False, project_id=project.id, organization_id=project.organization_id, - task_id=str(task_id), - with_assistant=True, - job_id=str(job_id), + task_id=task_id, + job_id=str(job.id), task_instance=None, ) + finally: + patcher.stop() - updated_job = CollectionJobCrud(db, project.id).read_one(job_id) - assert updated_job.task_id == str(task_id) - assert updated_job.status == CollectionJobStatus.SUCCESSFUL - assert updated_job.collection_id is not None - - created_collection = CollectionCrud(db, project.id).read_one( - updated_job.collection_id - ) - assert created_collection.llm_service_id == "mock_vector_store_id" - assert created_collection.llm_service_name == "openai vector store" - assert created_collection.updated_at is not None + updated_job = CollectionJobCrud(db, project.id).read_one(job.id) + assert updated_job.status == CollectionJobStatus.PROCESSING + assert updated_job.task_id == task_id - docs = DocumentCollectionCrud(db).read(created_collection, skip=0, limit=10) - assert len(docs) == 1 - assert docs[0].fname == document.fname + mock_queue_batch.assert_called_once() + kw = mock_queue_batch.call_args.kwargs + assert kw["batch_number"] == 1 + assert kw["vector_store_id"] is None + assert str(doc.id) in kw["batch_doc_ids"] + assert kw["remaining_batches"] == [] -@pytest.mark.usefixtures("aws_credentials") -@mock_aws +@patch("app.services.collections.create_collection.get_cloud_storage") @patch("app.services.collections.create_collection.get_llm_provider") -def test_execute_job_assistant_create_failure_marks_failed_and_deletes_collection( - mock_get_llm_provider: MagicMock, db +@patch("app.services.collections.create_collection.start_collection_batch_job") +def 
test_execute_setup_job_failure_marks_job_failed_and_raises( + mock_queue_batch: MagicMock, + mock_get_provider: MagicMock, + mock_get_storage: MagicMock, + db: Session, ) -> None: project = get_project(db) + store = DocumentStore(db=db, project_id=project.id) + doc = store.put() + + mock_provider = _mock_provider_with_size("vs_123", "openai vector store") + mock_provider.upload_files.side_effect = RuntimeError("S3 upload failed") + mock_get_provider.return_value = mock_provider job = get_collection_job( db, project, - job_id=uuid4(), action_type=CollectionActionType.CREATE, status=CollectionJobStatus.PENDING, - collection_id=None, ) + request = CreationRequest(documents=[doc.id], provider="openai", callback_url=None) - req = CreationRequest(documents=[], callback_url=None, provider="openai") - - mock_provider = get_mock_provider( - llm_service_id="vs_123", llm_service_name="openai vector store" - ) - mock_get_llm_provider.return_value = mock_provider - - with patch( - "app.services.collections.create_collection.Session" - ) as SessionCtor, patch( - "app.services.collections.create_collection.CollectionCrud" - ) as MockCrud: - SessionCtor.return_value.__enter__.return_value = db - SessionCtor.return_value.__exit__.return_value = False - - MockCrud.return_value.create.side_effect = Exception("DB constraint violation") - - task_id = str(uuid4()) - with pytest.raises(Exception, match="DB constraint violation"): - execute_job( - request=req.model_dump(), + patcher = _patch_session(db) + try: + with pytest.raises(RuntimeError, match="S3 upload failed"): + execute_setup_job( + request=request.model_dump(mode="json"), + with_assistant=False, project_id=project.id, organization_id=project.organization_id, - task_id=task_id, - with_assistant=True, + task_id=str(uuid4()), job_id=str(job.id), task_instance=None, ) + finally: + patcher.stop() - mock_provider.delete.assert_called_once() + updated_job = CollectionJobCrud(db, project.id).read_one(job.id) + assert 
updated_job.status == CollectionJobStatus.FAILED + assert "S3 upload failed" in (updated_job.error_message or "") + mock_queue_batch.assert_not_called() -@pytest.mark.usefixtures("aws_credentials") -@mock_aws -@patch("app.services.collections.create_collection.get_llm_provider") @patch("app.services.collections.create_collection.send_callback") -def test_execute_job_success_flow_callback_job_and_creates_collection( +@patch("app.services.collections.create_collection.get_cloud_storage") +@patch("app.services.collections.create_collection.get_llm_provider") +@patch("app.services.collections.create_collection.start_collection_batch_job") +def test_execute_setup_job_failure_sends_callback( + mock_queue_batch: MagicMock, + mock_get_provider: MagicMock, + mock_get_storage: MagicMock, mock_send_callback: MagicMock, - mock_get_llm_provider: MagicMock, - db, + db: Session, ) -> None: - """ - execute_job should: - - set task_id on the CollectionJob - - ingest documents into a vector store - - create an OpenAI assistant - - create a Collection with llm fields filled - - link the CollectionJob -> collection_id, set status=successful - - create DocumentCollection links - """ project = get_project(db) - - aws = AmazonCloudStorageClient() - aws.create() - store = DocumentStore(db=db, project_id=project.id) - document = store.put() - s3_key = Path(urlparse(document.object_store_url).path).relative_to("/") - aws.client.put_object(Bucket=settings.AWS_S3_BUCKET, Key=str(s3_key), Body=b"test") + doc = store.put() - callback_url = "https://example.com/collections/create-success" + mock_provider = _mock_provider_with_size("vs_123", "openai vector store") + mock_provider.upload_files.side_effect = RuntimeError("upload error") + mock_get_provider.return_value = mock_provider - sample_request = CreationRequest( - documents=[document.id], - callback_url=callback_url, - provider="openai", + callback_url = "https://example.com/callback" + job = get_collection_job( + db, + project, + 
action_type=CollectionActionType.CREATE, + status=CollectionJobStatus.PENDING, ) - - mock_get_llm_provider.return_value = get_mock_provider( - llm_service_id="mock_vector_store_id", llm_service_name="openai vector store" + request = CreationRequest( + documents=[doc.id], provider="openai", callback_url=callback_url ) - job_id = uuid.uuid4() - _ = get_collection_job( + patcher = _patch_session(db) + try: + with pytest.raises(RuntimeError): + execute_setup_job( + request=request.model_dump(mode="json"), + with_assistant=False, + project_id=project.id, + organization_id=project.organization_id, + task_id=str(uuid4()), + job_id=str(job.id), + task_instance=None, + ) + finally: + patcher.stop() + + mock_send_callback.assert_called_once() + cb_url, payload = mock_send_callback.call_args.args + assert str(cb_url) == callback_url + assert payload["success"] is False + assert payload["data"]["status"] == CollectionJobStatus.FAILED + + +@patch("app.services.collections.create_collection.get_cloud_storage") +@patch("app.services.collections.create_collection.get_llm_provider") +@patch("app.services.collections.create_collection.start_collection_batch_job") +def test_execute_setup_job_timeout_marks_failed_and_reraises( + mock_queue_batch: MagicMock, + mock_get_provider: MagicMock, + mock_get_storage: MagicMock, + db: Session, +) -> None: + project = get_project(db) + store = DocumentStore(db=db, project_id=project.id) + doc = store.put() + + mock_provider = _mock_provider_with_size("vs_123", "openai vector store") + mock_provider.upload_files.side_effect = Timeout(300) + mock_get_provider.return_value = mock_provider + + job = get_collection_job( db, project, - job_id=job_id, action_type=CollectionActionType.CREATE, status=CollectionJobStatus.PENDING, - collection_id=None, ) + request = CreationRequest(documents=[doc.id], provider="openai", callback_url=None) - task_id = uuid.uuid4() + patcher = _patch_session(db) + try: + with pytest.raises(Timeout): + execute_setup_job( + 
request=request.model_dump(mode="json"), + with_assistant=False, + project_id=project.id, + organization_id=project.organization_id, + task_id=str(uuid4()), + job_id=str(job.id), + task_instance=None, + ) + finally: + patcher.stop() + + updated_job = CollectionJobCrud(db, project.id).read_one(job.id) + assert updated_job.status == CollectionJobStatus.FAILED + assert "soft time limit" in (updated_job.error_message or "") + + +# --------------------------------------------------------------------------- +# execute_batch_job +# --------------------------------------------------------------------------- + + +@patch("app.services.collections.create_collection.get_llm_provider") +@patch("app.services.collections.create_collection.start_collection_batch_job") +def test_execute_batch_job_non_final_queues_next_batch( + mock_queue_batch: MagicMock, + mock_get_provider: MagicMock, + db: Session, +) -> None: + project = get_project(db) + store = DocumentStore(db=db, project_id=project.id) + doc1 = store.put() + doc2 = store.put() - with patch("app.services.collections.create_collection.Session") as SessionCtor: - SessionCtor.return_value.__enter__.return_value = db - SessionCtor.return_value.__exit__.return_value = False + mock_get_provider.return_value = get_mock_provider("vs_123", "openai vector store") - mock_send_callback.return_value = MagicMock(status_code=403) + job = get_collection_job( + db, + project, + action_type=CollectionActionType.CREATE, + status=CollectionJobStatus.PROCESSING, + ) + request = CreationRequest( + documents=[doc1.id, doc2.id], provider="openai", callback_url=None + ) + task_id = str(uuid4()) - execute_job( - request=sample_request.model_dump(), + patcher = _patch_session(db) + try: + execute_batch_job( + request=request.model_dump(mode="json"), + with_assistant=False, project_id=project.id, organization_id=project.organization_id, - task_id=str(task_id), - with_assistant=True, - job_id=str(job_id), + task_id=task_id, + job_id=str(job.id), 
task_instance=None, + vector_store_id="vs_123", + batch_number=1, + batch_doc_ids=[str(doc1.id)], + remaining_batches=[[str(doc2.id)]], ) + finally: + patcher.stop() - updated_job = CollectionJobCrud(db, project.id).read_one(job_id) - collection = CollectionCrud(db, project.id).read_one(updated_job.collection_id) + mock_queue_batch.assert_called_once() + kw = mock_queue_batch.call_args.kwargs + assert kw["batch_number"] == 2 + assert kw["batch_doc_ids"] == [str(doc2.id)] + assert kw["remaining_batches"] == [] + assert kw["vector_store_id"] == "vs_123" - mock_send_callback.assert_called_once() - cb_url_arg, payload_arg = mock_send_callback.call_args.args - assert str(cb_url_arg) == callback_url - assert payload_arg["success"] is True - assert payload_arg["data"]["status"] == CollectionJobStatus.SUCCESSFUL - assert payload_arg["data"]["collection"]["id"] == str(collection.id) - assert uuid.UUID(payload_arg["data"]["job_id"]) == job_id + updated_job = CollectionJobCrud(db, project.id).read_one(job.id) + assert updated_job.current_batch_number == 1 + assert str(doc1.id) in (updated_job.documents_uploaded or []) -@pytest.mark.usefixtures("aws_credentials") -@mock_aws @patch("app.services.collections.create_collection.get_llm_provider") -@patch("app.services.collections.create_collection.send_callback") -def test_execute_job_success_creates_collection_with_callback( - mock_send_callback: MagicMock, - mock_get_llm_provider: MagicMock, - db, +@patch("app.services.collections.create_collection.start_collection_batch_job") +def test_execute_batch_job_final_batch_creates_collection_and_marks_successful( + mock_queue_batch: MagicMock, + mock_get_provider: MagicMock, + db: Session, ) -> None: - """ - execute_job should: - - set task_id on the CollectionJob - - ingest documents into a vector store - - create an OpenAI assistant - - create a Collection with llm fields filled - - link the CollectionJob -> collection_id, set status=successful - - create DocumentCollection links - 
""" project = get_project(db) - - aws = AmazonCloudStorageClient() - aws.create() - store = DocumentStore(db=db, project_id=project.id) - document = store.put() - s3_key = Path(urlparse(document.object_store_url).path).relative_to("/") - aws.client.put_object(Bucket=settings.AWS_S3_BUCKET, Key=str(s3_key), Body=b"test") + doc = store.put() - callback_url = "https://example.com/collections/create-success" - - sample_request = CreationRequest( - documents=[document.id], - callback_url=callback_url, - provider="openai", - ) - - mock_get_llm_provider.return_value = get_mock_provider( - llm_service_id="mock_vector_store_id", llm_service_name="gpt-4o" + mock_get_provider.return_value = get_mock_provider( + "vs_final", "openai vector store" ) - job_id = uuid.uuid4() - _ = get_collection_job( + job = get_collection_job( db, project, - job_id=job_id, action_type=CollectionActionType.CREATE, - status=CollectionJobStatus.PENDING, - collection_id=None, + status=CollectionJobStatus.PROCESSING, ) + request = CreationRequest(documents=[doc.id], provider="openai", callback_url=None) - task_id = uuid.uuid4() - - with patch("app.services.collections.create_collection.Session") as SessionCtor: - SessionCtor.return_value.__enter__.return_value = db - SessionCtor.return_value.__exit__.return_value = False - - mock_send_callback.return_value = MagicMock(status_code=403) - - execute_job( - request=sample_request.model_dump(), + patcher = _patch_session(db) + try: + execute_batch_job( + request=request.model_dump(mode="json"), + with_assistant=False, project_id=project.id, organization_id=project.organization_id, - task_id=str(task_id), - with_assistant=True, - job_id=str(job_id), + task_id=str(uuid4()), + job_id=str(job.id), task_instance=None, + vector_store_id=None, + batch_number=1, + batch_doc_ids=[str(doc.id)], + remaining_batches=[], ) + finally: + patcher.stop() + + updated_job = CollectionJobCrud(db, project.id).read_one(job.id) + assert updated_job.status == 
CollectionJobStatus.SUCCESSFUL + assert updated_job.collection_id is not None - updated_job = CollectionJobCrud(db, project.id).read_one(job_id) collection = CollectionCrud(db, project.id).read_one(updated_job.collection_id) + assert collection.llm_service_id == "vs_final" - mock_send_callback.assert_called_once() - cb_url_arg, payload_arg = mock_send_callback.call_args.args - assert str(cb_url_arg) == callback_url - assert payload_arg["success"] is True - assert payload_arg["data"]["status"] == CollectionJobStatus.SUCCESSFUL - assert payload_arg["data"]["collection"]["id"] == str(collection.id) - assert uuid.UUID(payload_arg["data"]["job_id"]) == job_id + linked_docs = DocumentCollectionCrud(db).read(collection, skip=0, limit=10) + assert len(linked_docs) == 1 + assert linked_docs[0].id == doc.id + + mock_queue_batch.assert_not_called() -@pytest.mark.usefixtures("aws_credentials") -@mock_aws -@patch("app.services.collections.create_collection.get_llm_provider") @patch("app.services.collections.create_collection.send_callback") -@patch("app.services.collections.create_collection.CollectionCrud") -def test_execute_job_failure_flow_callback_job_and_marks_failed( - MockCollectionCrud, +@patch("app.services.collections.create_collection.get_llm_provider") +@patch("app.services.collections.create_collection.start_collection_batch_job") +def test_execute_batch_job_final_batch_sends_success_callback( + mock_queue_batch: MagicMock, + mock_get_provider: MagicMock, mock_send_callback: MagicMock, - mock_get_llm_provider: MagicMock, db: Session, ) -> None: - """ - When creation fails, the job should be marked as FAILED, an error should be logged, - and a failure callback with the error message should be triggered. 
- """ project = get_project(db) + store = DocumentStore(db=db, project_id=project.id) + doc = store.put() + + mock_get_provider.return_value = get_mock_provider( + "vs_final", "openai vector store" + ) - collection = get_assistant_collection(db, project, assistant_id="asst_123") + callback_url = "https://example.com/success" job = get_collection_job( db, project, action_type=CollectionActionType.CREATE, - status=CollectionJobStatus.PENDING, - collection_id=None, + status=CollectionJobStatus.PROCESSING, + ) + request = CreationRequest( + documents=[doc.id], provider="openai", callback_url=callback_url ) - mock_get_llm_provider.return_value = MagicMock() + patcher = _patch_session(db) + try: + execute_batch_job( + request=request.model_dump(mode="json"), + with_assistant=False, + project_id=project.id, + organization_id=project.organization_id, + task_id=str(uuid4()), + job_id=str(job.id), + task_instance=None, + vector_store_id=None, + batch_number=1, + batch_doc_ids=[str(doc.id)], + remaining_batches=[], + ) + finally: + patcher.stop() - callback_url = "https://example.com/collections/create-failure" + mock_send_callback.assert_called_once() + cb_url, payload = mock_send_callback.call_args.args + assert str(cb_url) == callback_url + assert payload["success"] is True + assert payload["data"]["status"] == CollectionJobStatus.SUCCESSFUL + assert payload["data"]["collection"] is not None - collection_crud_instance = MockCollectionCrud.return_value - collection_crud_instance.read_one.return_value = collection - sample_request = CreationRequest( - documents=[uuid.uuid4()], - callback_url=callback_url, - provider="openai", - ) +@patch("app.services.collections.create_collection.get_llm_provider") +def test_execute_batch_job_provider_failure_marks_failed_and_raises( + mock_get_provider: MagicMock, + db: Session, +) -> None: + project = get_project(db) + store = DocumentStore(db=db, project_id=project.id) + doc = store.put() - task_id = uuid.uuid4() + mock_provider = 
get_mock_provider("vs_123", "openai vector store") + mock_provider.create.side_effect = RuntimeError("vector store error") + mock_get_provider.return_value = mock_provider - with patch("app.services.collections.create_collection.Session") as SessionCtor: - SessionCtor.return_value.__enter__.return_value = db - SessionCtor.return_value.__exit__.return_value = False + job = get_collection_job( + db, + project, + action_type=CollectionActionType.CREATE, + status=CollectionJobStatus.PROCESSING, + ) + request = CreationRequest(documents=[doc.id], provider="openai", callback_url=None) - with pytest.raises( - ValueError, match="Requested atleast 1 document retrieved 0" - ): - execute_job( - request=sample_request.model_dump(), + patcher = _patch_session(db) + try: + with pytest.raises(RuntimeError, match="vector store error"): + execute_batch_job( + request=request.model_dump(mode="json"), + with_assistant=False, project_id=project.id, organization_id=project.organization_id, - task_id=str(task_id), - with_assistant=True, + task_id=str(uuid4()), job_id=str(job.id), task_instance=None, + vector_store_id=None, + batch_number=1, + batch_doc_ids=[str(doc.id)], + remaining_batches=[], ) + finally: + patcher.stop() updated_job = CollectionJobCrud(db, project.id).read_one(job.id) - assert updated_job.status == CollectionJobStatus.FAILED - assert "Requested atleast 1 document retrieved 0" in ( - updated_job.error_message or "" - ) - - mock_send_callback.assert_called_once() - cb_url_arg, payload_arg = mock_send_callback.call_args.args - assert str(cb_url_arg) == callback_url - assert payload_arg["success"] is False - assert "Requested atleast 1 document retrieved 0" in (payload_arg["error"] or "") - assert payload_arg["data"]["status"] == CollectionJobStatus.FAILED - assert payload_arg["data"]["collection"] is None - assert uuid.UUID(payload_arg["data"]["job_id"]) == job.id + assert "vector store error" in (updated_job.error_message or "") 
@patch("app.services.collections.create_collection.get_llm_provider") -def test_execute_job_timeout_marks_job_failed( - mock_get_llm_provider: MagicMock, db: Session +@patch("app.services.collections.create_collection.CollectionCrud") +def test_execute_batch_job_cleanup_called_when_provider_create_succeeds_but_db_fails( + MockCollectionCrud: MagicMock, + mock_get_provider: MagicMock, + db: Session, ) -> None: + """provider.delete should be called if create() succeeded but finalization fails.""" project = get_project(db) + store = DocumentStore(db=db, project_id=project.id) + doc = store.put() + + mock_provider = get_mock_provider("vs_123", "openai vector store") + mock_get_provider.return_value = mock_provider + + MockCollectionCrud.return_value.create.side_effect = Exception("DB write failed") job = get_collection_job( db, project, - job_id=uuid4(), action_type=CollectionActionType.CREATE, - status=CollectionJobStatus.PENDING, - collection_id=None, + status=CollectionJobStatus.PROCESSING, ) + request = CreationRequest(documents=[doc.id], provider="openai", callback_url=None) - mock_provider = MagicMock() - mock_provider.create.side_effect = Timeout(300) - mock_get_llm_provider.return_value = mock_provider + patcher = _patch_session(db) + try: + with pytest.raises(Exception, match="DB write failed"): + execute_batch_job( + request=request.model_dump(mode="json"), + with_assistant=False, + project_id=project.id, + organization_id=project.organization_id, + task_id=str(uuid4()), + job_id=str(job.id), + task_instance=None, + vector_store_id=None, + batch_number=1, + batch_doc_ids=[str(doc.id)], + remaining_batches=[], + ) + finally: + patcher.stop() + + mock_provider.delete.assert_called_once() + + +@patch("app.services.collections.create_collection.get_llm_provider") +def test_execute_batch_job_timeout_marks_failed_and_reraises( + mock_get_provider: MagicMock, + db: Session, +) -> None: + project = get_project(db) + store = DocumentStore(db=db, project_id=project.id) 
+ doc = store.put() - req = CreationRequest(documents=[], callback_url=None, provider="openai") + mock_provider = get_mock_provider("vs_123", "openai vector store") + mock_provider.create.side_effect = Timeout(300) + mock_get_provider.return_value = mock_provider - with patch("app.services.collections.create_collection.Session") as SessionCtor: - SessionCtor.return_value.__enter__.return_value = db - SessionCtor.return_value.__exit__.return_value = False + job = get_collection_job( + db, + project, + action_type=CollectionActionType.CREATE, + status=CollectionJobStatus.PROCESSING, + ) + request = CreationRequest(documents=[doc.id], provider="openai", callback_url=None) + patcher = _patch_session(db) + try: with pytest.raises(Timeout): - execute_job( - request=req.model_dump(), + execute_batch_job( + request=request.model_dump(mode="json"), + with_assistant=False, project_id=project.id, organization_id=project.organization_id, task_id=str(uuid4()), - with_assistant=False, job_id=str(job.id), task_instance=None, + vector_store_id=None, + batch_number=1, + batch_doc_ids=[str(doc.id)], + remaining_batches=[], ) + finally: + patcher.stop() updated_job = CollectionJobCrud(db, project.id).read_one(job.id) assert updated_job.status == CollectionJobStatus.FAILED assert "soft time limit" in (updated_job.error_message or "") -@patch("app.services.collections.create_collection.get_llm_provider") @patch("app.services.collections.create_collection.send_callback") -def test_execute_job_timeout_sends_failure_callback( +@patch("app.services.collections.create_collection.get_llm_provider") +def test_execute_batch_job_failure_sends_callback( + mock_get_provider: MagicMock, mock_send_callback: MagicMock, - mock_get_llm_provider: MagicMock, db: Session, ) -> None: project = get_project(db) - callback_url = "https://example.com/collections/timeout" + store = DocumentStore(db=db, project_id=project.id) + doc = store.put() + + mock_provider = get_mock_provider("vs_123", "openai vector 
store") + mock_provider.create.side_effect = RuntimeError("batch failed") + mock_get_provider.return_value = mock_provider + callback_url = "https://example.com/failure" job = get_collection_job( db, project, - job_id=uuid4(), action_type=CollectionActionType.CREATE, - status=CollectionJobStatus.PENDING, - collection_id=None, + status=CollectionJobStatus.PROCESSING, + ) + request = CreationRequest( + documents=[doc.id], provider="openai", callback_url=callback_url ) - mock_provider = MagicMock() - mock_provider.create.side_effect = Timeout(300) - mock_get_llm_provider.return_value = mock_provider - - req = CreationRequest(documents=[], callback_url=callback_url, provider="openai") - - with patch("app.services.collections.create_collection.Session") as SessionCtor: - SessionCtor.return_value.__enter__.return_value = db - SessionCtor.return_value.__exit__.return_value = False - - with pytest.raises(Timeout): - execute_job( - request=req.model_dump(), + patcher = _patch_session(db) + try: + with pytest.raises(RuntimeError): + execute_batch_job( + request=request.model_dump(mode="json"), + with_assistant=False, project_id=project.id, organization_id=project.organization_id, task_id=str(uuid4()), - with_assistant=False, job_id=str(job.id), task_instance=None, + vector_store_id=None, + batch_number=1, + batch_doc_ids=[str(doc.id)], + remaining_batches=[], ) + finally: + patcher.stop() mock_send_callback.assert_called_once() - cb_url_arg, payload_arg = mock_send_callback.call_args.args - assert str(cb_url_arg) == callback_url - assert payload_arg["success"] is False - assert "soft time limit" in (payload_arg["error"] or "") - assert payload_arg["data"]["status"] == CollectionJobStatus.FAILED - assert payload_arg["data"]["collection"] is None - assert uuid.UUID(payload_arg["data"]["job_id"]) == job.id + cb_url, payload = mock_send_callback.call_args.args + assert str(cb_url) == callback_url + assert payload["success"] is False + assert "batch failed" in (payload["error"] 
or "") + assert payload["data"]["status"] == CollectionJobStatus.FAILED From 290d5f08c07b17d9e874f310cdd06a47e3690b91 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Sun, 10 May 2026 19:50:13 +0530 Subject: [PATCH 11/19] fixing test cases --- backend/app/services/collections/create_collection.py | 6 +++--- .../services/collections/providers/test_openai_provider.py | 6 +----- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py index b87e86667..ca3197200 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -211,6 +211,9 @@ def execute_setup_job( organization_id=organization_id, ) + for doc in flat_docs: + session.expunge(doc) + collection_job_crud = CollectionJobCrud(session, project_id) collection_job = collection_job_crud.update( job_uuid, @@ -220,9 +223,6 @@ def execute_setup_job( ), ) - for doc in flat_docs: - session.expunge(doc) - provider.upload_files(storage, flat_docs, project_id) logger.info( diff --git a/backend/app/tests/services/collections/providers/test_openai_provider.py b/backend/app/tests/services/collections/providers/test_openai_provider.py index 8431ee512..f48d04d63 100644 --- a/backend/app/tests/services/collections/providers/test_openai_provider.py +++ b/backend/app/tests/services/collections/providers/test_openai_provider.py @@ -24,7 +24,6 @@ def test_create_openai_vector_store_only() -> None: temperature=None, ) - storage = MagicMock() documents = [ SimpleNamespace(file_size_kb=10), SimpleNamespace(file_size_kb=20), @@ -40,7 +39,6 @@ def test_create_openai_vector_store_only() -> None: collection = provider.create( collection_request, - storage, documents, ) @@ -60,7 +58,6 @@ def test_create_openai_with_assistant() -> None: temperature=0.7, ) - storage = MagicMock() documents = [SimpleNamespace(file_size_kb=10)] vector_store_id = 
generate_openai_id("vs_") assistant_id = generate_openai_id("asst_") @@ -79,8 +76,8 @@ def test_create_openai_with_assistant() -> None: collection = provider.create( collection_request, - storage, documents, + is_final=True, ) assert collection.llm_service_id == assistant_id @@ -341,6 +338,5 @@ def test_create_propagates_exception() -> None: with pytest.raises(RuntimeError): provider.create( collection_request, - MagicMock(), [SimpleNamespace(file_size_kb=10)], ) From 9f95a7634520c9d01436afd8987fde29d754dbd7 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Sun, 10 May 2026 20:16:37 +0530 Subject: [PATCH 12/19] increasing test cases --- .../services/collections/test_helpers.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/backend/app/tests/services/collections/test_helpers.py b/backend/app/tests/services/collections/test_helpers.py index 8b43946a1..15a8a8019 100644 --- a/backend/app/tests/services/collections/test_helpers.py +++ b/backend/app/tests/services/collections/test_helpers.py @@ -122,6 +122,28 @@ def test_batch_documents_mixed_size_batching() -> None: assert len(batches[2]) == 1 # 15 MB total +def test_batch_documents_zero_size_files_batches_by_count() -> None: + """Zero-size docs contribute nothing to size, so only the 200-doc count limit applies.""" + docs = create_fake_documents(250, file_size_kb=0) + batches = helpers.batch_documents(docs) + + assert len(batches) == 2 + assert len(batches[0]) == 200 + assert len(batches[1]) == 50 + + +def test_batch_documents_doc_exactly_at_size_limit_stays_in_same_batch() -> None: + """A doc whose size exactly equals MAX_BATCH_SIZE_KB should not trigger a new batch + on its own — the split only happens when adding it would *exceed* the limit.""" + from app.services.collections.helpers import MAX_BATCH_SIZE_KB + + docs = create_fake_documents(1, file_size_kb=MAX_BATCH_SIZE_KB) + batches = helpers.batch_documents(docs) + + assert len(batches) == 1 + assert len(batches[0]) == 1 + + def 
test_batch_documents_with_none_file_size_raises() -> None: """Test that documents with None file_size raise TypeError — sizes must be backfilled before batching.""" docs = create_fake_documents(10, file_size_kb=None) From cb976548b87b800a45819edeab8564f8b253d5fd Mon Sep 17 00:00:00 2001 From: nishika26 Date: Mon, 11 May 2026 00:01:18 +0530 Subject: [PATCH 13/19] incrreading test coverage --- backend/app/celery/tasks/job_execution.py | 214 ++++++++++++------ backend/app/celery/utils.py | 144 ++++++++---- .../services/collections/create_collection.py | 4 +- .../providers/test_openai_provider.py | 83 +++++++ .../collections/test_create_collection.py | 2 +- 5 files changed, 327 insertions(+), 120 deletions(-) diff --git a/backend/app/celery/tasks/job_execution.py b/backend/app/celery/tasks/job_execution.py index 9a13fddcf..cd3cf6397 100644 --- a/backend/app/celery/tasks/job_execution.py +++ b/backend/app/celery/tasks/job_execution.py @@ -1,9 +1,10 @@ import logging -from typing import Any -import celery from asgi_correlation_id import correlation_id from celery import current_task +from opentelemetry import context as otel_context +from opentelemetry import trace +from opentelemetry.propagate import extract from app.celery.celery_app import celery_app from app.celery.utils import gevent_timeout @@ -17,18 +18,61 @@ def _set_trace(trace_id: str) -> None: logger.info(f"[_set_trace] Set correlation ID: {trace_id}") +def _extract_parent_context(task_instance) -> otel_context.Context: + """Extract OTel parent context from Celery headers if available.""" + headers = getattr(task_instance.request, "headers", None) or {} + carrier: dict[str, str] = {} + + if isinstance(headers, dict): + for key, value in headers.items(): + if isinstance(value, str): + carrier[str(key)] = value + + nested = headers.get("otel", {}) + if isinstance(nested, dict): + for key, value in nested.items(): + if isinstance(value, str): + carrier[str(key)] = value + + return extract(carrier) + + +def 
_run_with_otel_parent(task_instance, fn): + """Attach extracted parent context and execute function. + + When Celery auto-instrumentation is active, there is already a current + `run/...` span. Re-attaching extracted parent context here would make + service spans become siblings of `run/...` instead of children. + + We only attach extracted context as a fallback when no active span exists. + """ + current_ctx = trace.get_current_span().get_span_context() + if current_ctx and current_ctx.is_valid: + return fn() + + parent_ctx = _extract_parent_context(task_instance) + token = otel_context.attach(parent_ctx) + try: + return fn() + finally: + otel_context.detach(token) + + @celery_app.task(bind=True, queue="high_priority", priority=9) @gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_llm_job") def run_llm_job(self, project_id: int, job_id: str, trace_id: str, **kwargs): from app.services.llm.jobs import execute_job _set_trace(trace_id) - return execute_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, + return _run_with_otel_parent( + self, + lambda: execute_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, + ), ) @@ -38,12 +82,15 @@ def run_llm_chain_job(self, project_id: int, job_id: str, trace_id: str, **kwarg from app.services.llm.jobs import execute_chain_job _set_trace(trace_id) - return execute_chain_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, + return _run_with_otel_parent( + self, + lambda: execute_chain_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, + ), ) @@ -53,12 +100,15 @@ def run_response_job(self, project_id: int, job_id: str, trace_id: str, **kwargs from app.services.response.jobs import execute_job _set_trace(trace_id) - return execute_job( - project_id=project_id, - 
job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, + return _run_with_otel_parent( + self, + lambda: execute_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, + ), ) @@ -68,29 +118,35 @@ def run_doctransform_job(self, project_id: int, job_id: str, trace_id: str, **kw from app.services.doctransform.job import execute_job _set_trace(trace_id) - return execute_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, + return _run_with_otel_parent( + self, + lambda: execute_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, + ), ) @celery_app.task(bind=True, queue="low_priority", priority=1) -@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_create_collection_job") -def run_create_collection_job( +@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_collection_setup_job") +def run_collection_setup_job( self, project_id: int, job_id: str, trace_id: str, **kwargs ): from app.services.collections.create_collection import execute_setup_job _set_trace(trace_id) - return execute_setup_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, + return _run_with_otel_parent( + self, + lambda: execute_setup_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, + ), ) @@ -102,12 +158,15 @@ def run_collection_batch_job( from app.services.collections.create_collection import execute_batch_job _set_trace(trace_id) - return execute_batch_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, + return _run_with_otel_parent( + self, + lambda: execute_batch_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, + 
), ) @@ -119,12 +178,15 @@ def run_delete_collection_job( from app.services.collections.delete_collection import execute_job _set_trace(trace_id) - return execute_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, + return _run_with_otel_parent( + self, + lambda: execute_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, + ), ) @@ -136,12 +198,15 @@ def run_stt_batch_submission( from app.services.stt_evaluations.batch_job import execute_batch_submission _set_trace(trace_id) - return execute_batch_submission( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, + return _run_with_otel_parent( + self, + lambda: execute_batch_submission( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, + ), ) @@ -153,12 +218,15 @@ def run_stt_metric_computation( from app.services.stt_evaluations.metric_job import execute_metric_computation _set_trace(trace_id) - return execute_metric_computation( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, + return _run_with_otel_parent( + self, + lambda: execute_metric_computation( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, + ), ) @@ -170,12 +238,15 @@ def run_tts_batch_submission( from app.services.tts_evaluations.batch_job import execute_batch_submission _set_trace(trace_id) - return execute_batch_submission( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, + return _run_with_otel_parent( + self, + lambda: execute_batch_submission( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, + ), ) @@ -189,10 +260,13 @@ def run_tts_result_processing( ) 
_set_trace(trace_id) - return execute_tts_result_processing( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, + return _run_with_otel_parent( + self, + lambda: execute_tts_result_processing( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, + ), ) diff --git a/backend/app/celery/utils.py b/backend/app/celery/utils.py index f39fb9d5d..907914661 100644 --- a/backend/app/celery/utils.py +++ b/backend/app/celery/utils.py @@ -18,14 +18,24 @@ F = TypeVar("F", bound=Callable[..., Any]) +def _enqueue_with_trace_context(task, **kwargs) -> str: + """Publish Celery task with explicit trace context headers.""" + otel_headers: dict[str, str] = {} + inject(otel_headers) + celery_headers = dict(otel_headers) + celery_headers["otel"] = otel_headers + async_result = task.apply_async(kwargs=kwargs, headers=celery_headers) + return async_result.id + + def start_llm_job(project_id: int, job_id: str, trace_id: str = "N/A", **kwargs) -> str: from app.celery.tasks.job_execution import run_llm_job - task = run_llm_job.delay( - project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs + task_id = _enqueue_with_trace_context( + run_llm_job, project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs ) - logger.info(f"[start_llm_job] Started job {job_id} with Celery task {task.id}") - return task.id + logger.info(f"[start_llm_job] Started job {job_id} with Celery task {task_id}") + return task_id def start_llm_chain_job( @@ -33,13 +43,17 @@ def start_llm_chain_job( ) -> str: from app.celery.tasks.job_execution import run_llm_chain_job - task = run_llm_chain_job.delay( - project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs + task_id = _enqueue_with_trace_context( + run_llm_chain_job, + project_id=project_id, + job_id=job_id, + trace_id=trace_id, + **kwargs, ) logger.info( - f"[start_llm_chain_job] Started job {job_id} with Celery task {task.id}" 
+ f"[start_llm_chain_job] Started job {job_id} with Celery task {task_id}" ) - return task.id + return task_id def start_response_job( @@ -47,11 +61,15 @@ def start_response_job( ) -> str: from app.celery.tasks.job_execution import run_response_job - task = run_response_job.delay( - project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs + task_id = _enqueue_with_trace_context( + run_response_job, + project_id=project_id, + job_id=job_id, + trace_id=trace_id, + **kwargs, ) - logger.info(f"[start_response_job] Started job {job_id} with Celery task {task.id}") - return task.id + logger.info(f"[start_response_job] Started job {job_id} with Celery task {task_id}") + return task_id def start_doctransform_job( @@ -59,27 +77,35 @@ def start_doctransform_job( ) -> str: from app.celery.tasks.job_execution import run_doctransform_job - task = run_doctransform_job.delay( - project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs + task_id = _enqueue_with_trace_context( + run_doctransform_job, + project_id=project_id, + job_id=job_id, + trace_id=trace_id, + **kwargs, ) logger.info( - f"[start_doctransform_job] Started job {job_id} with Celery task {task.id}" + f"[start_doctransform_job] Started job {job_id} with Celery task {task_id}" ) - return task.id + return task_id -def start_create_collection_job( +def start_collection_setup_job( project_id: int, job_id: str, trace_id: str = "N/A", **kwargs ) -> str: - from app.celery.tasks.job_execution import run_create_collection_job - - task = run_create_collection_job.delay( - project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs + from app.celery.tasks.job_execution import run_collection_setup_job + + task_id = _enqueue_with_trace_context( + run_collection_setup_job, + project_id=project_id, + job_id=job_id, + trace_id=trace_id, + **kwargs, ) logger.info( - f"[start_create_collection_job] Started job {job_id} with Celery task {task.id}" + f"[start_collection_setup_job] Started job {job_id} with Celery 
task {task_id}" ) - return task.id + return task_id def start_collection_batch_job( @@ -87,13 +113,17 @@ def start_collection_batch_job( ) -> str: from app.celery.tasks.job_execution import run_collection_batch_job - task = run_collection_batch_job.delay( - project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs + task_id = _enqueue_with_trace_context( + run_collection_batch_job, + project_id=project_id, + job_id=job_id, + trace_id=trace_id, + **kwargs, ) logger.info( - f"[start_collection_batch_job] Started batch job {job_id} with Celery task {task.id}" + f"[start_collection_setup_job] Started job {job_id} with Celery task {task_id}" ) - return task.id + return task_id def start_delete_collection_job( @@ -101,13 +131,17 @@ def start_delete_collection_job( ) -> str: from app.celery.tasks.job_execution import run_delete_collection_job - task = run_delete_collection_job.delay( - project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs + task_id = _enqueue_with_trace_context( + run_delete_collection_job, + project_id=project_id, + job_id=job_id, + trace_id=trace_id, + **kwargs, ) logger.info( - f"[start_delete_collection_job] Started job {job_id} with Celery task {task.id}" + f"[start_delete_collection_job] Started job {job_id} with Celery task {task_id}" ) - return task.id + return task_id def start_stt_batch_submission( @@ -115,13 +149,17 @@ def start_stt_batch_submission( ) -> str: from app.celery.tasks.job_execution import run_stt_batch_submission - task = run_stt_batch_submission.delay( - project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs + task_id = _enqueue_with_trace_context( + run_stt_batch_submission, + project_id=project_id, + job_id=job_id, + trace_id=trace_id, + **kwargs, ) logger.info( - f"[start_stt_batch_submission] Started job {job_id} with Celery task {task.id}" + f"[start_stt_batch_submission] Started job {job_id} with Celery task {task_id}" ) - return task.id + return task_id def start_stt_metric_computation( @@ 
-129,13 +167,17 @@ def start_stt_metric_computation( ) -> str: from app.celery.tasks.job_execution import run_stt_metric_computation - task = run_stt_metric_computation.delay( - project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs + task_id = _enqueue_with_trace_context( + run_stt_metric_computation, + project_id=project_id, + job_id=job_id, + trace_id=trace_id, + **kwargs, ) logger.info( - f"[start_stt_metric_computation] Started job {job_id} with Celery task {task.id}" + f"[start_stt_metric_computation] Started job {job_id} with Celery task {task_id}" ) - return task.id + return task_id def start_tts_batch_submission( @@ -143,13 +185,17 @@ def start_tts_batch_submission( ) -> str: from app.celery.tasks.job_execution import run_tts_batch_submission - task = run_tts_batch_submission.delay( - project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs + task_id = _enqueue_with_trace_context( + run_tts_batch_submission, + project_id=project_id, + job_id=job_id, + trace_id=trace_id, + **kwargs, ) logger.info( - f"[start_tts_batch_submission] Started job {job_id} with Celery task {task.id}" + f"[start_tts_batch_submission] Started job {job_id} with Celery task {task_id}" ) - return task.id + return task_id def start_tts_result_processing( @@ -157,13 +203,17 @@ def start_tts_result_processing( ) -> str: from app.celery.tasks.job_execution import run_tts_result_processing - task = run_tts_result_processing.delay( - project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs + task_id = _enqueue_with_trace_context( + run_tts_result_processing, + project_id=project_id, + job_id=job_id, + trace_id=trace_id, + **kwargs, ) logger.info( - f"[start_tts_result_processing] Started job {job_id} with Celery task {task.id}" + f"[start_tts_result_processing] Started job {job_id} with Celery task {task_id}" ) - return task.id + return task_id def get_task_status(task_id: str) -> Dict[str, Any]: diff --git a/backend/app/services/collections/create_collection.py 
b/backend/app/services/collections/create_collection.py index ca3197200..52852e63c 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -31,7 +31,7 @@ to_collection_public, ) from app.services.collections.providers.registry import get_llm_provider -from app.celery.utils import start_create_collection_job, start_collection_batch_job +from app.celery.utils import start_collection_setup_job, start_collection_batch_job from app.utils import send_callback, APIResponse, get_webhook_secret @@ -52,7 +52,7 @@ def start_job( job_crud = CollectionJobCrud(db, project_id) job_crud.update(collection_job_id, CollectionJobUpdate(trace_id=trace_id)) - task_id = start_create_collection_job( + task_id = start_collection_setup_job( project_id=project_id, job_id=str(collection_job_id), trace_id=trace_id, diff --git a/backend/app/tests/services/collections/providers/test_openai_provider.py b/backend/app/tests/services/collections/providers/test_openai_provider.py index f48d04d63..838dea26a 100644 --- a/backend/app/tests/services/collections/providers/test_openai_provider.py +++ b/backend/app/tests/services/collections/providers/test_openai_provider.py @@ -315,6 +315,89 @@ def test_upload_files_mixed_skips_uploaded_uploads_new() -> None: storage.get.assert_called_once_with(new_doc.object_store_url) +def test_upload_files_empty_docs_is_noop() -> None: + client = MagicMock() + provider = OpenAIProvider(client=client) + storage = MagicMock() + + session_p, crud_p = _patch_session_and_crud() + with session_p, crud_p: + provider.upload_files(storage, [], project_id=1) + + storage.get.assert_not_called() + client.files.create.assert_not_called() + + +def test_upload_files_file_object_name_matches_doc_fname() -> None: + """The BytesIO passed to OpenAI must carry the original filename.""" + from io import BytesIO + + client = MagicMock() + client.files.create.return_value = MagicMock(id="file-abc") + provider = 
OpenAIProvider(client=client) + + storage = MagicMock() + storage.get.return_value = b"data" + + doc = _make_doc(file_size_kb=1.0) + doc.fname = "report.pdf" + + session_p, crud_p = _patch_session_and_crud() + with session_p as MockSession, crud_p: + MockSession.return_value.__enter__.return_value = MagicMock() + MockSession.return_value.__exit__.return_value = False + provider.upload_files(storage, [doc], project_id=1) + + _, kwargs = client.files.create.call_args + f_obj = kwargs["file"] + assert isinstance(f_obj, BytesIO) + assert f_obj.name == "report.pdf" + + +def test_upload_files_raises_on_db_update_failure() -> None: + client = MagicMock() + client.files.create.return_value = MagicMock(id="file-ok") + provider = OpenAIProvider(client=client) + + storage = MagicMock() + storage.get.return_value = b"content" + + doc = _make_doc(file_size_kb=1.0) + mock_crud = MagicMock() + mock_crud.read_one.return_value = MagicMock() + mock_crud.update.side_effect = RuntimeError("DB write failed") + + session_p, crud_p = _patch_session_and_crud() + with session_p as MockSession, crud_p as MockDocCrud: + MockSession.return_value.__enter__.return_value = MagicMock() + MockSession.return_value.__exit__.return_value = False + MockDocCrud.return_value = mock_crud + + with pytest.raises(RuntimeError, match="DB write failed"): + provider.upload_files(storage, [doc], project_id=1) + + +def test_upload_files_first_failure_stops_remaining_docs() -> None: + """If the first doc raises, subsequent docs are never attempted.""" + client = MagicMock() + client.files.create.side_effect = RuntimeError("quota exceeded") + provider = OpenAIProvider(client=client) + + storage = MagicMock() + storage.get.return_value = b"content" + + doc1 = _make_doc(file_size_kb=1.0) + doc2 = _make_doc(file_size_kb=1.0) + + session_p, crud_p = _patch_session_and_crud() + with session_p, crud_p: + with pytest.raises(RuntimeError, match="quota exceeded"): + provider.upload_files(storage, [doc1, doc2], 
project_id=1) + + client.files.create.assert_called_once() + assert storage.get.call_count == 1 + + # --------------------------------------------------------------------------- # create (existing tests below) # --------------------------------------------------------------------------- diff --git a/backend/app/tests/services/collections/test_create_collection.py b/backend/app/tests/services/collections/test_create_collection.py index 05213b879..a5f9356a7 100644 --- a/backend/app/tests/services/collections/test_create_collection.py +++ b/backend/app/tests/services/collections/test_create_collection.py @@ -79,7 +79,7 @@ def test_start_job_creates_collection_job_and_schedules_task(db: Session) -> Non ) with patch( - "app.services.collections.create_collection.start_create_collection_job" + "app.services.collections.create_collection.start_collection_setup_job" ) as mock_schedule: mock_schedule.return_value = "fake-task-id" From fd37d1434115555f12de3274a18bb8caefaee32f Mon Sep 17 00:00:00 2001 From: nishika26 Date: Mon, 11 May 2026 11:04:04 +0530 Subject: [PATCH 14/19] coderabbit reviews --- backend/app/celery/utils.py | 2 +- backend/app/crud/rag/open_ai.py | 23 ++-- .../services/collections/providers/openai.py | 38 +++++- .../providers/test_openai_provider.py | 116 ++++++++++++++++++ .../collections/test_create_collection.py | 90 ++++++++++++++ 5 files changed, 252 insertions(+), 17 deletions(-) diff --git a/backend/app/celery/utils.py b/backend/app/celery/utils.py index 907914661..ab1dbc080 100644 --- a/backend/app/celery/utils.py +++ b/backend/app/celery/utils.py @@ -121,7 +121,7 @@ def start_collection_batch_job( **kwargs, ) logger.info( - f"[start_collection_setup_job] Started job {job_id} with Celery task {task_id}" + f"[start_collection_batch_job] Started job {job_id} with Celery task {task_id}" ) return task_id diff --git a/backend/app/crud/rag/open_ai.py b/backend/app/crud/rag/open_ai.py index 07a9e671c..5757a16ee 100644 --- a/backend/app/crud/rag/open_ai.py 
+++ b/backend/app/crud/rag/open_ai.py @@ -122,22 +122,25 @@ def update( files=[], file_ids=[doc.openai_file_id for doc in docs], ) - logger.info( - f"[OpenAIVectorStoreCrud.update] Batch complete | " - f"{{'vector_store_id': '{vector_store_id}', " - f"'completed': {batch.file_counts.completed}, 'failed': {batch.file_counts.failed}}}" - ) - if batch.file_counts.failed > 0: - logger.warning( - f"[OpenAIVectorStoreCrud.update] Batch had failures | " - f"{{'vector_store_id': '{vector_store_id}', 'failed_count': {batch.file_counts.failed}}}" - ) except OpenAIError as err: logger.error( f"[OpenAIVectorStoreCrud.update] Batch attach failed | " f"{{'vector_store_id': '{vector_store_id}', 'error': '{str(err)}'}}", exc_info=True, ) + raise + + logger.info( + f"[OpenAIVectorStoreCrud.update] Batch complete | " + f"{{'vector_store_id': '{vector_store_id}', " + f"'completed': {batch.file_counts.completed}, 'failed': {batch.file_counts.failed}}}" + ) + if batch.file_counts.failed > 0: + raise RuntimeError( + f"Batch attach to vector store {vector_store_id!r} completed with " + f"{batch.file_counts.failed} failed file(s) " + f"({batch.file_counts.completed} succeeded)" + ) def delete(self, vector_store_id: str, retries: int = 3): if retries < 1: diff --git a/backend/app/services/collections/providers/openai.py b/backend/app/services/collections/providers/openai.py index 61e7c6374..3612f8747 100644 --- a/backend/app/services/collections/providers/openai.py +++ b/backend/app/services/collections/providers/openai.py @@ -36,6 +36,7 @@ def upload_files( for doc in docs: if self.get_existing_file_id(doc): continue + try: content = storage.get(doc.object_store_url) if doc.file_size_kb is None: @@ -43,7 +44,17 @@ def upload_files( f_obj = BytesIO(content) f_obj.name = doc.fname uploaded = self.client.files.create(file=f_obj, purpose="assistants") - doc.openai_file_id = uploaded.id + except Exception as err: + logger.error( + "[OpenAIProvider.upload_files] Failed to upload file | 
doc_id=%s, error=%s", + doc.id, + str(err), + exc_info=True, + ) + raise + + doc.openai_file_id = uploaded.id + try: with Session(engine) as session: document_crud = DocumentCrud(session, project_id) db_doc = document_crud.read_one(doc.id) @@ -52,11 +63,30 @@ def upload_files( document_crud.update(db_doc) except Exception as err: logger.error( - "[OpenAIProvider.upload_files] Failed to upload file | doc_id=%s, error=%s", + "[OpenAIProvider.upload_files] DB persistence failed, rolling back OpenAI file | " + "doc_id=%s, openai_file_id=%s, error=%s", doc.id, + uploaded.id, str(err), exc_info=True, ) + try: + self.client.files.delete(uploaded.id) + logger.info( + "[OpenAIProvider.upload_files] Rolled back OpenAI file | " + "doc_id=%s, openai_file_id=%s", + doc.id, + uploaded.id, + ) + except Exception as delete_err: + logger.error( + "[OpenAIProvider.upload_files] Rollback failed, file is orphaned | " + "doc_id=%s, openai_file_id=%s, error=%s", + doc.id, + uploaded.id, + str(delete_err), + ) + doc.openai_file_id = None raise def create( @@ -72,10 +102,6 @@ def create( if vector_store_id is None: vector_store = vector_store_crud.create() vector_store_id = vector_store.id - logger.info( - "[OpenAIProvider.create] Vector store created | vector_store_id=%s", - vector_store_id, - ) if docs: vector_store_crud.update(vector_store_id, docs) diff --git a/backend/app/tests/services/collections/providers/test_openai_provider.py b/backend/app/tests/services/collections/providers/test_openai_provider.py index 838dea26a..24c726454 100644 --- a/backend/app/tests/services/collections/providers/test_openai_provider.py +++ b/backend/app/tests/services/collections/providers/test_openai_provider.py @@ -3,7 +3,9 @@ from uuid import uuid4 import pytest +from openai import OpenAIError +from app.crud.rag.open_ai import OpenAIVectorStoreCrud from app.services.collections.providers.openai import OpenAIProvider from app.models.collection import Collection from app.services.collections.helpers 
import get_service_name @@ -376,6 +378,37 @@ def test_upload_files_raises_on_db_update_failure() -> None: with pytest.raises(RuntimeError, match="DB write failed"): provider.upload_files(storage, [doc], project_id=1) + client.files.delete.assert_called_once_with("file-ok") + assert doc.openai_file_id is None + + +def test_upload_files_db_failure_rollback_delete_error_still_raises_original() -> None: + """If both DB persistence and the rollback delete fail, the original DB error propagates.""" + client = MagicMock() + client.files.create.return_value = MagicMock(id="file-ok") + client.files.delete.side_effect = RuntimeError("delete failed") + provider = OpenAIProvider(client=client) + + storage = MagicMock() + storage.get.return_value = b"content" + + doc = _make_doc(file_size_kb=1.0) + mock_crud = MagicMock() + mock_crud.read_one.return_value = MagicMock() + mock_crud.update.side_effect = RuntimeError("DB write failed") + + session_p, crud_p = _patch_session_and_crud() + with session_p as MockSession, crud_p as MockDocCrud: + MockSession.return_value.__enter__.return_value = MagicMock() + MockSession.return_value.__exit__.return_value = False + MockDocCrud.return_value = mock_crud + + with pytest.raises(RuntimeError, match="DB write failed"): + provider.upload_files(storage, [doc], project_id=1) + + client.files.delete.assert_called_once_with("file-ok") + assert doc.openai_file_id is None + def test_upload_files_first_failure_stops_remaining_docs() -> None: """If the first doc raises, subsequent docs are never attempted.""" @@ -398,6 +431,89 @@ def test_upload_files_first_failure_stops_remaining_docs() -> None: assert storage.get.call_count == 1 +# --------------------------------------------------------------------------- +# OpenAIVectorStoreCrud.update +# --------------------------------------------------------------------------- + + +def _make_batch(completed: int, failed: int) -> MagicMock: + batch = MagicMock() + batch.file_counts.completed = completed + 
batch.file_counts.failed = failed + return batch + + +def _make_openai_doc(file_id: str = "file-abc") -> MagicMock: + doc = MagicMock() + doc.openai_file_id = file_id + return doc + + +def test_vector_store_update_skips_when_no_docs() -> None: + client = MagicMock() + crud = OpenAIVectorStoreCrud(client) + crud.update("vs_123", []) + client.vector_stores.file_batches.upload_and_poll.assert_not_called() + + +def test_vector_store_update_succeeds_with_no_failures() -> None: + client = MagicMock() + client.vector_stores.file_batches.upload_and_poll.return_value = _make_batch( + completed=3, failed=0 + ) + crud = OpenAIVectorStoreCrud(client) + crud.update("vs_123", [_make_openai_doc() for _ in range(3)]) + client.vector_stores.file_batches.upload_and_poll.assert_called_once() + + +def test_vector_store_update_raises_on_openai_error() -> None: + client = MagicMock() + client.vector_stores.file_batches.upload_and_poll.side_effect = OpenAIError( + "rate limit" + ) + crud = OpenAIVectorStoreCrud(client) + + with pytest.raises(OpenAIError, match="rate limit"): + crud.update("vs_123", [_make_openai_doc()]) + + +def test_vector_store_update_raises_on_partial_failures() -> None: + client = MagicMock() + client.vector_stores.file_batches.upload_and_poll.return_value = _make_batch( + completed=2, failed=1 + ) + crud = OpenAIVectorStoreCrud(client) + + with pytest.raises(RuntimeError, match="1 failed file"): + crud.update("vs_123", [_make_openai_doc() for _ in range(3)]) + + +def test_vector_store_update_raises_on_all_failures() -> None: + client = MagicMock() + client.vector_stores.file_batches.upload_and_poll.return_value = _make_batch( + completed=0, failed=2 + ) + crud = OpenAIVectorStoreCrud(client) + + with pytest.raises(RuntimeError, match="2 failed file"): + crud.update("vs_123", [_make_openai_doc() for _ in range(2)]) + + +def test_vector_store_update_passes_file_ids_to_openai() -> None: + client = MagicMock() + 
client.vector_stores.file_batches.upload_and_poll.return_value = _make_batch( + completed=2, failed=0 + ) + crud = OpenAIVectorStoreCrud(client) + docs = [_make_openai_doc("file-1"), _make_openai_doc("file-2")] + + crud.update("vs_abc", docs) + + _, kwargs = client.vector_stores.file_batches.upload_and_poll.call_args + assert kwargs["vector_store_id"] == "vs_abc" + assert kwargs["file_ids"] == ["file-1", "file-2"] + + # --------------------------------------------------------------------------- # create (existing tests below) # --------------------------------------------------------------------------- diff --git a/backend/app/tests/services/collections/test_create_collection.py b/backend/app/tests/services/collections/test_create_collection.py index a5f9356a7..840ac3416 100644 --- a/backend/app/tests/services/collections/test_create_collection.py +++ b/backend/app/tests/services/collections/test_create_collection.py @@ -301,6 +301,51 @@ def test_execute_setup_job_timeout_marks_failed_and_reraises( assert "soft time limit" in (updated_job.error_message or "") +@patch("app.services.collections.create_collection.get_cloud_storage") +@patch("app.services.collections.create_collection.get_llm_provider") +@patch("app.services.collections.create_collection.start_collection_batch_job") +def test_execute_setup_job_soft_time_limit_marks_failed_and_reraises( + mock_queue_batch: MagicMock, + mock_get_provider: MagicMock, + mock_get_storage: MagicMock, + db: Session, +) -> None: + project = get_project(db) + store = DocumentStore(db=db, project_id=project.id) + doc = store.put() + + mock_provider = _mock_provider_with_size("vs_123", "openai vector store") + mock_provider.upload_files.side_effect = SoftTimeLimitExceeded() + mock_get_provider.return_value = mock_provider + + job = get_collection_job( + db, + project, + action_type=CollectionActionType.CREATE, + status=CollectionJobStatus.PENDING, + ) + request = CreationRequest(documents=[doc.id], provider="openai", 
callback_url=None) + + patcher = _patch_session(db) + try: + with pytest.raises(SoftTimeLimitExceeded): + execute_setup_job( + request=request.model_dump(mode="json"), + with_assistant=False, + project_id=project.id, + organization_id=project.organization_id, + task_id=str(uuid4()), + job_id=str(job.id), + task_instance=None, + ) + finally: + patcher.stop() + + updated_job = CollectionJobCrud(db, project.id).read_one(job.id) + assert updated_job.status == CollectionJobStatus.FAILED + assert "soft time limit" in (updated_job.error_message or "") + + # --------------------------------------------------------------------------- # execute_batch_job # --------------------------------------------------------------------------- @@ -607,6 +652,51 @@ def test_execute_batch_job_timeout_marks_failed_and_reraises( assert "soft time limit" in (updated_job.error_message or "") +@patch("app.services.collections.create_collection.get_llm_provider") +def test_execute_batch_job_soft_time_limit_marks_failed_and_reraises( + mock_get_provider: MagicMock, + db: Session, +) -> None: + project = get_project(db) + store = DocumentStore(db=db, project_id=project.id) + doc = store.put() + + mock_provider = get_mock_provider("vs_123", "openai vector store") + mock_provider.create.side_effect = SoftTimeLimitExceeded() + mock_get_provider.return_value = mock_provider + + job = get_collection_job( + db, + project, + action_type=CollectionActionType.CREATE, + status=CollectionJobStatus.PROCESSING, + ) + request = CreationRequest(documents=[doc.id], provider="openai", callback_url=None) + + patcher = _patch_session(db) + try: + with pytest.raises(SoftTimeLimitExceeded): + execute_batch_job( + request=request.model_dump(mode="json"), + with_assistant=False, + project_id=project.id, + organization_id=project.organization_id, + task_id=str(uuid4()), + job_id=str(job.id), + task_instance=None, + vector_store_id=None, + batch_number=1, + batch_doc_ids=[str(doc.id)], + remaining_batches=[], + ) + 
finally: + patcher.stop() + + updated_job = CollectionJobCrud(db, project.id).read_one(job.id) + assert updated_job.status == CollectionJobStatus.FAILED + assert "soft time limit" in (updated_job.error_message or "") + + @patch("app.services.collections.create_collection.send_callback") @patch("app.services.collections.create_collection.get_llm_provider") def test_execute_batch_job_failure_sends_callback( From 510912ccd1415420919ced9098a59597472205ae Mon Sep 17 00:00:00 2001 From: nishika26 Date: Wed, 13 May 2026 14:39:41 +0530 Subject: [PATCH 15/19] removing assistant --- backend/app/api/routes/collections.py | 20 +-- backend/app/api/routes/documents.py | 16 +-- backend/app/crud/collection/collection.py | 30 +++++ backend/app/crud/rag/__init__.py | 2 +- backend/app/crud/rag/open_ai.py | 65 --------- backend/app/models/collection.py | 127 +----------------- .../services/collections/create_collection.py | 15 --- backend/app/services/collections/helpers.py | 67 ++------- .../services/collections/providers/base.py | 11 +- .../services/collections/providers/openai.py | 73 ++-------- .../collections/test_create_collections.py | 94 +------------ .../collection/test_crud_collection_delete.py | 44 +++--- .../providers/test_openai_provider.py | 83 +----------- .../collections/test_create_collection.py | 14 -- 14 files changed, 85 insertions(+), 576 deletions(-) diff --git a/backend/app/api/routes/collections.py b/backend/app/api/routes/collections.py index 96272d38c..f6c7016fb 100644 --- a/backend/app/api/routes/collections.py +++ b/backend/app/api/routes/collections.py @@ -117,32 +117,16 @@ def create_collection( ) ) - # True if both model and instructions were provided in the request body - with_assistant = bool( - getattr(request, "model", None) and getattr(request, "instructions", None) - ) - create_service.start_job( db=session, request=request, collection_job_id=collection_job.id, project_id=current_user.project_.id, organization_id=current_user.organization_.id, - 
with_assistant=with_assistant, ) - metadata = None - if not with_assistant: - metadata = { - "note": ( - "This job will create a vector store only (no Assistant). " - "Assistant creation happens when both 'model' and 'instructions' are included." - ) - } - return APIResponse.success_response( CollectionJobImmediatePublic.model_validate(collection_job), - metadata=metadata, ) @@ -171,7 +155,9 @@ def delete_collection( if request and request.callback_url: validate_callback_url(str(request.callback_url)) - _ = CollectionCrud(session, current_user.project_.id).read_one(collection_id) + _ = CollectionCrud(session, current_user.project_.id).read_one_if_delete( + collection_id + ) deletion_request = DeletionRequest( collection_id=collection_id, diff --git a/backend/app/api/routes/documents.py b/backend/app/api/routes/documents.py index eb7d41ca3..6d17bc104 100644 --- a/backend/app/api/routes/documents.py +++ b/backend/app/api/routes/documents.py @@ -17,7 +17,7 @@ from app.api.deps import AuthContextDep, SessionDep from app.api.permissions import Permission, require_permission from app.crud import CollectionCrud, DocumentCrud -from app.crud.rag import OpenAIAssistantCrud, OpenAIVectorStoreCrud +from app.crud.rag import OpenAIVectorStoreCrud from app.models import ( Document, DocumentPublic, @@ -28,7 +28,7 @@ DocTransformationJobPublic, ) from app.core.cloud import get_cloud_storage -from app.services.collections.helpers import pick_service_for_documennt, MAX_DOC_SIZE_MB +from app.services.collections.helpers import MAX_DOC_SIZE_MB from app.services.documents.helpers import ( calculate_file_size, schedule_transformation, @@ -197,16 +197,12 @@ def remove_doc( session, current_user.organization_.id, current_user.project_.id ) - a_crud = OpenAIAssistantCrud(client) v_crud = OpenAIVectorStoreCrud(client) d_crud = DocumentCrud(session, current_user.project_.id) c_crud = CollectionCrud(session, current_user.project_.id) document = d_crud.read_one(doc_id) - remote = 
pick_service_for_documennt( - session, doc_id, a_crud, v_crud - ) # assistant crud or vector store crud - c_crud.delete(document, remote) + c_crud.delete(document, v_crud) d_crud.delete(doc_id) return APIResponse.success_response( @@ -228,7 +224,6 @@ def permanent_delete_doc( client = get_openai_client( session, current_user.organization_.id, current_user.project_.id ) - a_crud = OpenAIAssistantCrud(client) v_crud = OpenAIVectorStoreCrud(client) d_crud = DocumentCrud(session, current_user.project_.id) c_crud = CollectionCrud(session, current_user.project_.id) @@ -236,10 +231,7 @@ def permanent_delete_doc( document = d_crud.read_one(doc_id) - remote = pick_service_for_documennt( - session, doc_id, a_crud, v_crud - ) # assistant crud or vector store crud - c_crud.delete(document, remote) + c_crud.delete(document, v_crud) storage.delete(document.object_store_url) d_crud.delete(doc_id) diff --git a/backend/app/crud/collection/collection.py b/backend/app/crud/collection/collection.py index cb8b6e27d..ac0dd9497 100644 --- a/backend/app/crud/collection/collection.py +++ b/backend/app/crud/collection/collection.py @@ -82,6 +82,31 @@ def read_one(self, collection_id: UUID) -> Collection: ) return collection + def read_one_if_delete(self, collection_id: UUID) -> Collection: + statement = select(Collection).where( + and_( + Collection.project_id == self.project_id, + Collection.id == collection_id, + ) + ) + + collection = self.session.exec(statement).one_or_none() + if collection is None: + logger.warning( + "[CollectionCrud.read_one_if_delete] Collection not found | " + f"{{'project_id': '{self.project_id}', 'collection_id': '{collection_id}'}}" + ) + raise HTTPException(status_code=404, detail="Collection not found") + + if collection.deleted_at is not None: + logger.warning( + "[CollectionCrud.read_one_if_delete] Collection already deleted | " + f"{{'project_id': '{self.project_id}', 'collection_id': '{collection_id}'}}" + ) + raise HTTPException(status_code=400, 
detail="Collection already deleted") + + return collection + def read_all(self): statement = select(Collection).where( and_( @@ -122,6 +147,11 @@ def delete(self, model, remote): # remote should be an OpenAICrud @delete.register def _(self, model: Collection, remote): + if model.deleted_at is not None: + logger.info( + f"[CollectionCrud.delete] Collection already deleted | {{'collection_id': '{model.id}'}}" + ) + return model remote.delete(model.llm_service_id) model.deleted_at = now() collection = self._update(model) diff --git a/backend/app/crud/rag/__init__.py b/backend/app/crud/rag/__init__.py index 6d6586427..9c1a42723 100644 --- a/backend/app/crud/rag/__init__.py +++ b/backend/app/crud/rag/__init__.py @@ -1 +1 @@ -from .open_ai import OpenAICrud, OpenAIVectorStoreCrud, OpenAIAssistantCrud +from .open_ai import OpenAICrud, OpenAIVectorStoreCrud diff --git a/backend/app/crud/rag/open_ai.py b/backend/app/crud/rag/open_ai.py index 5757a16ee..bb5f79aff 100644 --- a/backend/app/crud/rag/open_ai.py +++ b/backend/app/crud/rag/open_ai.py @@ -66,14 +66,6 @@ def clean(self, resource): raise NotImplementedError() -class AssistantCleaner(ResourceCleaner): - def clean(self, resource): - logger.info( - f"[AssistantCleaner.clean] Deleting assistant | {{'assistant_id': '{resource}'}}" - ) - self.client.beta.assistants.delete(resource) - - class VectorStoreCleaner(ResourceCleaner): def clean(self, resource): logger.info( @@ -158,60 +150,3 @@ def delete(self, vector_store_id: str, retries: int = 3): logger.info( f"[OpenAIVectorStoreCrud.delete] Vector store deleted | {{'vector_store_id': '{vector_store_id}'}}" ) - - -class OpenAIAssistantCrud(OpenAICrud): - def create(self, vector_store_id: str, **kwargs): - logger.info( - f"[OpenAIAssistantCrud.create] Creating assistant | {{'vector_store_id': '{vector_store_id}'}}" - ) - assistant = self.client.beta.assistants.create( - tools=[ - { - "type": "file_search", - } - ], - tool_resources={ - "file_search": { - "vector_store_ids": [ 
- vector_store_id, - ], - }, - }, - **kwargs, - ) - logger.info( - f"[OpenAIAssistantCrud.create] Assistant created | {{'assistant_id': '{assistant.id}', 'vector_store_id': '{vector_store_id}'}}" - ) - return assistant - - def delete(self, assistant_id: str): - logger.info( - f"[OpenAIAssistantCrud.delete] Starting assistant deletion | {{'assistant_id': '{assistant_id}'}}" - ) - assistant = self.client.beta.assistants.retrieve(assistant_id) - vector_stores = assistant.tool_resources.file_search.vector_store_ids - - try: - (vector_store_id,) = vector_stores - except ValueError: - if vector_stores: - names = ", ".join(vector_stores) - err = ValueError(f"Too many attached vector stores: {names}") - else: - err = ValueError("No vector stores found") - - logger.error( - f"[OpenAIAssistantCrud.delete] Invalid vector store state | {{'assistant_id': '{assistant_id}', 'vector_stores': '{vector_stores}'}}", - exc_info=True, - ) - raise err - - v_crud = OpenAIVectorStoreCrud(self.client) - v_crud.delete(vector_store_id) - - cleaner = AssistantCleaner(self.client) - cleaner(assistant_id) - logger.info( - f"[OpenAIAssistantCrud.delete] Assistant deleted | {{'assistant_id': '{assistant_id}'}}" - ) diff --git a/backend/app/models/collection.py b/backend/app/models/collection.py index ccd606deb..9957cf7b0 100644 --- a/backend/app/models/collection.py +++ b/backend/app/models/collection.py @@ -3,7 +3,7 @@ from typing import Any, Literal from uuid import UUID, uuid4 -from pydantic import HttpUrl, model_validator, model_serializer +from pydantic import HttpUrl from sqlalchemy import Index, text from sqlmodel import Field, Relationship, SQLModel @@ -105,62 +105,6 @@ def model_post_init(self, __context: Any): self.documents = list(set(self.documents)) -class AssistantOptions(SQLModel): - # Fields to be passed along to OpenAI. They must be a subset of - # parameters accepted by the OpenAI.clien.beta.assistants.create - # API. 
- model: str | None = Field( - default=None, - description=( - "**[Deprecated]** " - "OpenAI model to attach to this assistant. The model " - "must be compatable with the assistants API; see the " - "OpenAI [model documentation](https://platform.openai.com/docs/models/compare) for more." - ), - ) - - instructions: str | None = Field( - default=None, - description=( - "**[Deprecated]** " - "Assistant instruction. Sometimes referred to as the " - '"system" prompt.' - ), - ) - temperature: float = Field( - default=1e-6, - description=( - "**[Deprecated]** " - "Model temperature. The default is slightly " - "greater-than zero because it is [unknown how OpenAI " - "handles zero](https://community.openai.com/t/clarifications-on-setting-temperature-0/886447/5)." - ), - ) - - @model_validator(mode="before") - def _assistant_fields_all_or_none(cls, values: dict[str, Any]) -> dict[str, Any]: - def norm(x: Any) -> Any: - if x is None: - return None - if isinstance(x, str): - s = x.strip() - return s if s else None - return x # let Pydantic handle non-strings - - model = norm(values.get("model")) - instructions = norm(values.get("instructions")) - - if (model is None) ^ (instructions is None): - raise ValueError( - "To create an Assistant, provide BOTH 'model' and 'instructions'. " - "If you only want a vector store, remove both fields." 
- ) - - values["model"] = model - values["instructions"] = instructions - return values - - class CallbackRequest(SQLModel): callback_url: HttpUrl | None = Field( default=None, @@ -177,7 +121,6 @@ class ProviderOptions(SQLModel): class CreationRequest( - AssistantOptions, CollectionOptions, ProviderOptions, CallbackRequest, @@ -201,21 +144,11 @@ class CollectionIDPublic(SQLModel): class CollectionPublic(SQLModel): id: UUID - llm_service_id: str | None = Field( - default=None, - description="LLM service ID (e.g., Assistant ID) when model and instructions were provided", - ) - llm_service_name: str | None = Field( - default=None, - description="LLM service name (e.g., model name) when model and instructions were provided", - ) - knowledge_base_id: str | None = Field( - default=None, - description="Knowledge base ID (e.g., Vector Store ID) when only vector store was created", + knowledge_base_id: str = Field( + description="Knowledge base ID (e.g., Vector Store ID)", ) - knowledge_base_provider: str | None = Field( - default=None, - description="Knowledge base provider name when only vector store was created", + knowledge_base_provider: str = Field( + description="Knowledge base provider name", ) project_id: int @@ -223,56 +156,6 @@ class CollectionPublic(SQLModel): updated_at: datetime deleted_at: datetime | None = None - @model_validator(mode="after") - def validate_service_fields(self) -> "CollectionPublic": - """Ensure either LLM service fields or knowledge base fields are set, not both.""" - has_llm = self.llm_service_id is not None or self.llm_service_name is not None - has_kb = ( - self.knowledge_base_id is not None - or self.knowledge_base_provider is not None - ) - - if has_llm and has_kb: - raise ValueError( - "Cannot have both LLM service fields and knowledge base fields set" - ) - - if not has_llm and not has_kb: - raise ValueError( - "Either LLM service fields or knowledge base fields must be set" - ) - - # Ensure both fields in the pair are set or both 
are None - if has_llm and ( - (self.llm_service_id is None) != (self.llm_service_name is None) - ): - raise ValueError("Both llm_service_id and llm_service_name must be set") - - if has_kb and ( - (self.knowledge_base_id is None) != (self.knowledge_base_provider is None) - ): - raise ValueError( - "Both knowledge_base_id and knowledge_base_provider must be set" - ) - - return self - - @model_serializer(mode="wrap", when_used="json") - def _serialize_model(self, serializer: Any, info: Any) -> dict[str, Any]: - """Exclude unused service fields from JSON serialization.""" - data = serializer(self) - - # If this is a knowledge base, remove llm_service fields - if data.get("knowledge_base_id") is not None: - data.pop("llm_service_id", None) - data.pop("llm_service_name", None) - # If this is an assistant, remove knowledge_base fields - elif data.get("llm_service_id") is not None: - data.pop("knowledge_base_id", None) - data.pop("knowledge_base_provider", None) - - return data - class CollectionWithDocsPublic(CollectionPublic): documents: list[DocumentPublic] | None = None diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py index 52852e63c..43056fe87 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -44,7 +44,6 @@ def start_job( request: CreationRequest, project_id: int, collection_job_id: UUID, - with_assistant: bool, organization_id: int, ) -> str: trace_id = correlation_id.get() or "N/A" @@ -57,7 +56,6 @@ def start_job( job_id=str(collection_job_id), trace_id=trace_id, request=request.model_dump(mode="json"), - with_assistant=with_assistant, organization_id=organization_id, ) @@ -162,7 +160,6 @@ def _handle_job_failure( def execute_setup_job( request: dict, - with_assistant: bool, project_id: int, organization_id: int, task_id: str, @@ -191,9 +188,6 @@ def execute_setup_job( try: creation_request = 
CreationRequest(**request) - if with_assistant: - creation_request.provider = "openai" - span.set_attribute("collection.provider", str(creation_request.provider)) job_uuid = UUID(job_id) @@ -262,7 +256,6 @@ def execute_setup_job( remaining_batches=batch_doc_ids[1:], request=request, vector_store_id=None, - with_assistant=with_assistant, organization_id=organization_id, ) @@ -310,7 +303,6 @@ def execute_setup_job( def execute_batch_job( request: dict, - with_assistant: bool, project_id: int, organization_id: int, task_id: str, @@ -349,9 +341,6 @@ def execute_batch_job( try: batch_start_time = time.time() creation_request = CreationRequest(**request) - if with_assistant: - creation_request.provider = "openai" - span.set_attribute("collection.provider", str(creation_request.provider)) job_uuid = UUID(job_id) @@ -367,7 +356,6 @@ def execute_batch_job( ) all_doc_ids_this_batch = [UUID(d) for d in batch_doc_ids] - is_final = not remaining_batches with Session(engine) as session: provider = get_llm_provider( @@ -388,10 +376,8 @@ def execute_batch_job( session.expunge(doc) collection_result = provider.create( - creation_request, batch_docs, vector_store_id=vector_store_id, - is_final=is_final, ) result = collection_result resolved_vector_store_id = collection_result.llm_service_id @@ -430,7 +416,6 @@ def execute_batch_job( batch_doc_ids=remaining_batches[0], remaining_batches=remaining_batches[1:], request=request, - with_assistant=with_assistant, organization_id=organization_id, ) logger.info( diff --git a/backend/app/services/collections/helpers.py b/backend/app/services/collections/helpers.py index 3f0a0cefd..8e01a8bac 100644 --- a/backend/app/services/collections/helpers.py +++ b/backend/app/services/collections/helpers.py @@ -2,14 +2,12 @@ import json import ast import re -from uuid import UUID from fastapi import HTTPException -from sqlmodel import select from app.crud import CollectionCrud from app.api.deps import SessionDep -from app.models import 
DocumentCollection, Collection, CollectionPublic, Document +from app.models import Collection, CollectionPublic, Document logger = logging.getLogger(__name__) @@ -111,30 +109,6 @@ def batch_documents(documents: list[Document]) -> list[list[Document]]: return docs_batches -# Even though this function is used in the documents router, it's kept here for now since the assistant creation logic will -# eventually be removed from Kaapi. Once that happens, this function can be safely deleted - -def pick_service_for_documennt(session, doc_id: UUID, a_crud, v_crud): - """ - Return the correct remote (v_crud or a_crud) for this document - by inspecting an active linked Collection's llm_service_name. - Defaults to a_crud if not vector store. - """ - coll = session.exec( - select(Collection) - .join(DocumentCollection, DocumentCollection.collection_id == Collection.id) - .where( - DocumentCollection.document_id == doc_id, - Collection.deleted_at.is_(None), - ) - .limit(1) - ).first() - - service = ( - (getattr(coll, "llm_service_name", "") or "").strip().lower() if coll else "" - ) - return v_crud if service == get_service_name("openai") else a_crud - - def ensure_unique_name( session: SessionDep, project_id: int, @@ -155,35 +129,12 @@ def ensure_unique_name( def to_collection_public(collection: Collection) -> CollectionPublic: - """ - Convert a Collection DB model to CollectionPublic response model. 
- - Maps fields based on service type: - - If llm_service_name is a vector store (matches get_service_name pattern), - use knowledge_base_id/knowledge_base_provider - - Otherwise (assistant), use llm_service_id/llm_service_name - """ - is_vector_store = collection.llm_service_name == get_service_name( - collection.provider + return CollectionPublic( + id=collection.id, + knowledge_base_id=collection.llm_service_id, + knowledge_base_provider=collection.llm_service_name, + project_id=collection.project_id, + inserted_at=collection.inserted_at, + updated_at=collection.updated_at, + deleted_at=collection.deleted_at, ) - - if is_vector_store: - return CollectionPublic( - id=collection.id, - knowledge_base_id=collection.llm_service_id, - knowledge_base_provider=collection.llm_service_name, - project_id=collection.project_id, - inserted_at=collection.inserted_at, - updated_at=collection.updated_at, - deleted_at=collection.deleted_at, - ) - else: - return CollectionPublic( - id=collection.id, - llm_service_id=collection.llm_service_id, - llm_service_name=collection.llm_service_name, - project_id=collection.project_id, - inserted_at=collection.inserted_at, - updated_at=collection.updated_at, - deleted_at=collection.deleted_at, - ) diff --git a/backend/app/services/collections/providers/base.py b/backend/app/services/collections/providers/base.py index 6649a0725..38355f7d5 100644 --- a/backend/app/services/collections/providers/base.py +++ b/backend/app/services/collections/providers/base.py @@ -2,7 +2,7 @@ from typing import Any from app.core.cloud.storage import CloudStorage -from app.models import CreationRequest, Collection, Document +from app.models import Collection, Document class BaseProvider(ABC): @@ -11,8 +11,7 @@ class BaseProvider(ABC): All provider implementations (OpenAI, Bedrock, etc.) must inherit from this class and implement the required methods. - Providers handle creation of collection and - optional assistant/agent creation backed by those collections. 
+ Providers handle creation of vector store collections. Attributes: client: The provider-specific client instance @@ -40,15 +39,11 @@ def upload_files( @abstractmethod def create( self, - collection_request: CreationRequest, docs: list[Document], vector_store_id: str | None = None, - is_final: bool = False, ) -> Collection: """Upload docs batch to vector store (creating it if vector_store_id is None). - Creates assistant only when is_final=True and model/instructions are set. - Returns Collection with llm_service_id set to vector_store_id on intermediate batches, - or to assistant/vector_store id on the final batch.""" + Returns Collection with llm_service_id set to the vector store ID.""" raise NotImplementedError("Providers must implement create method") @abstractmethod diff --git a/backend/app/services/collections/providers/openai.py b/backend/app/services/collections/providers/openai.py index 3612f8747..dbcbcc7fd 100644 --- a/backend/app/services/collections/providers/openai.py +++ b/backend/app/services/collections/providers/openai.py @@ -9,9 +9,9 @@ from app.core.cloud.storage import CloudStorage from app.core.db import engine from app.crud import DocumentCrud -from app.crud.rag import OpenAIVectorStoreCrud, OpenAIAssistantCrud +from app.crud.rag import OpenAIVectorStoreCrud from app.services.collections.helpers import get_service_name -from app.models import CreationRequest, Collection, Document +from app.models import Collection, Document logger = logging.getLogger(__name__) @@ -91,10 +91,8 @@ def upload_files( def create( self, - collection_request: CreationRequest, docs: List[Document], vector_store_id: str | None = None, - is_final: bool = False, ) -> Collection: try: vector_store_crud = OpenAIVectorStoreCrud(self.client) @@ -111,49 +109,10 @@ def create( len(docs), ) - if not is_final: - return Collection( - llm_service_id=vector_store_id, - llm_service_name=get_service_name("openai"), - ) - # if "is_final" is true then only will assistant creation 
happen - - with_assistant = ( - collection_request.model is not None - and collection_request.instructions is not None + return Collection( + llm_service_id=vector_store_id, + llm_service_name=get_service_name("openai"), ) - if with_assistant: - assistant_crud = OpenAIAssistantCrud(self.client) - - assistant_options = { - "model": collection_request.model, - "instructions": collection_request.instructions, - "temperature": collection_request.temperature, - } - filtered_options = { - k: v for k, v in assistant_options.items() if v is not None - } - - assistant = assistant_crud.create(vector_store_id, **filtered_options) - - logger.info( - "[OpenAIProvider.create] Assistant created | assistant_id=%s, vector_store_id=%s", - assistant.id, - vector_store_id, - ) - - return Collection( - llm_service_id=assistant.id, - llm_service_name=filtered_options.get("model", "assistant"), - ) - else: - logger.info( - "[OpenAIProvider.create] Skipping assistant creation | with_assistant=False" - ) - - return Collection( - llm_service_id=vector_store_id, - llm_service_name=get_service_name("openai"), - ) except Exception as e: logger.error( @@ -163,26 +122,14 @@ def create( raise def delete(self, collection: Collection) -> None: - """Delete OpenAI resources (assistant or vector store). 
- - Determines what to delete based on llm_service_name: - - If assistant was created, delete the assistant (which also removes the vector store) - - If only vector store was created, delete the vector store - """ try: - if collection.llm_service_name != get_service_name("openai"): - OpenAIAssistantCrud(self.client).delete(collection.llm_service_id) - logger.info( - f"[OpenAIProvider.delete] Deleted assistant | assistant_id={collection.llm_service_id}" - ) - else: - OpenAIVectorStoreCrud(self.client).delete(collection.llm_service_id) - logger.info( - f"[OpenAIProvider.delete] Deleted vector store | vector_store_id={collection.llm_service_id}" - ) + OpenAIVectorStoreCrud(self.client).delete(collection.llm_service_id) + logger.info( + f"[OpenAIProvider.delete] Deleted vector store | vector_store_id={collection.llm_service_id}" + ) except Exception as e: logger.error( - f"[OpenAIProvider.delete] Failed to delete resource | " + f"[OpenAIProvider.delete] Failed to delete vector store | " f"llm_service_id={collection.llm_service_id}, error={str(e)}", exc_info=True, ) diff --git a/backend/app/tests/api/routes/collections/test_create_collections.py b/backend/app/tests/api/routes/collections/test_create_collections.py index b51631939..5934394b3 100644 --- a/backend/app/tests/api/routes/collections/test_create_collections.py +++ b/backend/app/tests/api/routes/collections/test_create_collections.py @@ -11,14 +11,9 @@ from app.models.collection import CreationRequest -def _extract_metadata(body: dict) -> dict | None: - return body.get("metadata") or body.get("meta") - - def _create_test_document( db: Session, project_id: int, file_size: float = 1 ) -> Document: - """Helper to create a test document.""" doc = Document( id=uuid4(), fname="test_document.txt", @@ -33,20 +28,16 @@ def _create_test_document( @patch("app.api.routes.collections.create_service.start_job") -def test_collection_creation_with_assistant_calls_start_job_and_returns_job( +def 
test_collection_creation_calls_start_job_and_returns_job( mock_start_job: Any, client: TestClient, user_api_key_header: dict[str, str], user_api_key: TestAuthContext, db: Session, ) -> None: - # Create a test document in the database doc = _create_test_document(db, user_api_key.project_id, file_size=2) creation_data = CreationRequest( - model="gpt-4o", - instructions="string", - temperature=0.000001, documents=[doc.id], callback_url=None, ) @@ -65,14 +56,12 @@ def test_collection_creation_with_assistant_calls_start_job_and_returns_job( assert data["job_inserted_at"] assert data["job_updated_at"] - assert _extract_metadata(body) in (None, {}) - mock_start_job.assert_called_once() kwargs = mock_start_job.call_args.kwargs assert "db" in kwargs assert kwargs["project_id"] == user_api_key.project_id assert kwargs["organization_id"] == user_api_key.organization_id - assert kwargs["with_assistant"] is True + assert "with_assistant" not in kwargs returned_job_id = UUID(data["job_id"]) assert kwargs["collection_job_id"] == returned_job_id @@ -80,82 +69,3 @@ def test_collection_creation_with_assistant_calls_start_job_and_returns_job( assert kwargs["request"].model_dump(mode="json") == creation_data.model_dump( mode="json" ) - - -@patch("app.api.routes.collections.create_service.start_job") -def test_collection_creation_vector_only_adds_metadata_and_sets_with_assistant_false( - mock_start_job: Any, - client: TestClient, - user_api_key_header: dict[str, str], - user_api_key: TestAuthContext, - db: Session, -) -> None: - # Create a test document in the database - doc = _create_test_document(db, user_api_key.project_id, file_size=5) - - creation_data = CreationRequest( - temperature=0.000001, - documents=[doc.id], - callback_url=None, - ) - - resp = client.post( - f"{settings.API_V1_STR}/collections", - json=creation_data.model_dump(mode="json"), - headers=user_api_key_header, - ) - - assert resp.status_code == 200 - body = resp.json() - - data = body["data"] - assert 
data["status"] == CollectionJobStatus.PENDING - - meta = _extract_metadata(body) - assert isinstance(meta, dict) - assert "vector store only" in meta.get("note", "").lower() - - mock_start_job.assert_called_once() - kwargs = mock_start_job.call_args.kwargs - assert kwargs["project_id"] == user_api_key.project_id - assert kwargs["organization_id"] == user_api_key.organization_id - assert kwargs["with_assistant"] is False - assert kwargs["collection_job_id"] == UUID(data["job_id"]) - assert kwargs["request"].model_dump(mode="json") == creation_data.model_dump( - mode="json" - ) - - -def test_collection_creation_vector_only_request_validation_error( - client: TestClient, - user_api_key_header: dict[str, str], - user_api_key: TestAuthContext, - db: Session, -) -> None: - # Create a test document in the database - doc = _create_test_document(db, user_api_key.project_id) - - payload = { - "model": "gpt-4o", - "temperature": 0.000001, - "documents": [str(doc.id)], - "callback_url": None, - } - - resp = client.post( - f"{settings.API_V1_STR}/collections", - json=payload, - headers=user_api_key_header, - ) - - assert resp.status_code == 422 - body = resp.json() - assert body["success"] is False - assert body["data"] is None - assert body["metadata"] is None - assert body["errors"] - assert any( - "To create an Assistant, provide BOTH 'model' and 'instructions'" - in e["message"] - for e in body["errors"] - ) diff --git a/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py b/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py index 5cf4643d6..05eaa5bda 100644 --- a/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py +++ b/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py @@ -4,29 +4,19 @@ from app.crud import CollectionCrud from app.models import APIKey, Collection, ProviderType -from app.crud.rag import OpenAIAssistantCrud +from app.crud.rag import OpenAIVectorStoreCrud 
from app.tests.utils.utils import get_project from app.tests.utils.document import DocumentStore -def get_assistant_collection_for_delete( - db: Session, client=None, project_id: int = None +def get_vector_store_collection( + db: Session, client: OpenAI, project_id: int ) -> Collection: - project = get_project(db) - if client is None: - client = OpenAI(api_key="test_api_key") - vector_store = client.vector_stores.create() - assistant = client.beta.assistants.create( - model="gpt-4o", - tools=[{"type": "file_search"}], - tool_resources={"file_search": {"vector_store_ids": [vector_store.id]}}, - ) - return Collection( project_id=project_id, - llm_service_id=assistant.id, - llm_service_name="gpt-4o", + llm_service_id=vector_store.id, + llm_service_name="openai vector store", provider=ProviderType.openai, ) @@ -39,26 +29,24 @@ def test_delete_marks_deleted(self, db: Session) -> None: project = get_project(db) client = OpenAI(api_key="sk-test-key") - assistant = OpenAIAssistantCrud(client) - collection = get_assistant_collection_for_delete( - db, client, project_id=project.id - ) + v_crud = OpenAIVectorStoreCrud(client) + collection = get_vector_store_collection(db, client, project_id=project.id) crud = CollectionCrud(db, collection.project_id) - collection_ = crud.delete(collection, assistant) + collection_ = crud.delete(collection, v_crud) assert collection_.deleted_at is not None @openai_responses.mock() def test_delete_follows_insert(self, db: Session) -> None: + project = get_project(db) client = OpenAI(api_key="sk-test-key") - assistant = OpenAIAssistantCrud(client) - project = get_project(db) - collection = get_assistant_collection_for_delete(db, project_id=project.id) + v_crud = OpenAIVectorStoreCrud(client) + collection = get_vector_store_collection(db, client, project_id=project.id) crud = CollectionCrud(db, collection.project_id) - collection_ = crud.delete(collection, assistant) + collection_ = crud.delete(collection, v_crud) assert collection_.inserted_at 
<= collection_.deleted_at @@ -76,15 +64,13 @@ def test_delete_document_deletes_collections(self, db: Session) -> None: client = OpenAI(api_key="sk-test-key") resources = [] for _ in range(self._n_collections): - coll = get_assistant_collection_for_delete( - db, client, project_id=project.id - ) + coll = get_vector_store_collection(db, client, project_id=project.id) crud = CollectionCrud(db, project_id=project.id) collection = crud.create(coll, documents) resources.append((crud, collection)) ((crud, _), *_) = resources - assistant = OpenAIAssistantCrud(client) - crud.delete(documents[0], assistant) + v_crud = OpenAIVectorStoreCrud(client) + crud.delete(documents[0], v_crud) assert all(y.deleted_at for (_, y) in resources) diff --git a/backend/app/tests/services/collections/providers/test_openai_provider.py b/backend/app/tests/services/collections/providers/test_openai_provider.py index 24c726454..cedf5ed2f 100644 --- a/backend/app/tests/services/collections/providers/test_openai_provider.py +++ b/backend/app/tests/services/collections/providers/test_openai_provider.py @@ -15,17 +15,10 @@ ) -def test_create_openai_vector_store_only() -> None: +def test_create_openai_vector_store() -> None: client = get_mock_openai_client_with_vector_store() provider = OpenAIProvider(client=client) - collection_request = SimpleNamespace( - documents=["doc1", "doc2"], - model=None, - instructions=None, - temperature=None, - ) - documents = [ SimpleNamespace(file_size_kb=10), SimpleNamespace(file_size_kb=20), @@ -39,73 +32,13 @@ def test_create_openai_vector_store_only() -> None: vector_store_crud.create.return_value = MagicMock(id=vector_store_id) vector_store_crud.update.return_value = None - collection = provider.create( - collection_request, - documents, - ) + collection = provider.create(documents) assert isinstance(collection, Collection) assert collection.llm_service_id == vector_store_id assert collection.llm_service_name == get_service_name("openai") -def 
test_create_openai_with_assistant() -> None: - client = get_mock_openai_client_with_vector_store() - provider = OpenAIProvider(client=client) - - collection_request = SimpleNamespace( - documents=["doc1"], - model="gpt-4o", - instructions="You are helpful", - temperature=0.7, - ) - - documents = [SimpleNamespace(file_size_kb=10)] - vector_store_id = generate_openai_id("vs_") - assistant_id = generate_openai_id("asst_") - - with patch( - "app.services.collections.providers.openai.OpenAIVectorStoreCrud" - ) as vector_store_crud_cls, patch( - "app.services.collections.providers.openai.OpenAIAssistantCrud" - ) as assistant_crud_cls: - vector_store_crud = vector_store_crud_cls.return_value - vector_store_crud.create.return_value = MagicMock(id=vector_store_id) - vector_store_crud.update.return_value = None - - assistant_crud = assistant_crud_cls.return_value - assistant_crud.create.return_value = MagicMock(id=assistant_id) - - collection = provider.create( - collection_request, - documents, - is_final=True, - ) - - assert collection.llm_service_id == assistant_id - assert collection.llm_service_name == "gpt-4o" - - -def test_delete_openai_assistant() -> None: - client = MagicMock() - provider = OpenAIProvider(client=client) - - collection = Collection( - llm_service_id=generate_openai_id("asst_"), - llm_service_name="gpt-4o", - provider="openai", - project_id=1, - ) - - with patch( - "app.services.collections.providers.openai.OpenAIAssistantCrud" - ) as assistant_crud_cls: - assistant_crud = assistant_crud_cls.return_value - provider.delete(collection) - - assistant_crud.delete.assert_called_once_with(collection.llm_service_id) - - def test_delete_openai_vector_store() -> None: client = MagicMock() provider = OpenAIProvider(client=client) @@ -522,20 +455,10 @@ def test_vector_store_update_passes_file_ids_to_openai() -> None: def test_create_propagates_exception() -> None: provider = OpenAIProvider(client=MagicMock()) - collection_request = SimpleNamespace( - 
documents=["doc1"], - model=None, - instructions=None, - temperature=None, - ) - with patch( "app.services.collections.providers.openai.OpenAIVectorStoreCrud" ) as vector_store_crud_cls: vector_store_crud_cls.return_value.create.side_effect = RuntimeError("boom") with pytest.raises(RuntimeError): - provider.create( - collection_request, - [SimpleNamespace(file_size_kb=10)], - ) + provider.create([SimpleNamespace(file_size_kb=10)]) diff --git a/backend/app/tests/services/collections/test_create_collection.py b/backend/app/tests/services/collections/test_create_collection.py index 840ac3416..c17e794cb 100644 --- a/backend/app/tests/services/collections/test_create_collection.py +++ b/backend/app/tests/services/collections/test_create_collection.py @@ -88,7 +88,6 @@ def test_start_job_creates_collection_job_and_schedules_task(db: Session) -> Non request=request, project_id=project.id, collection_job_id=job_id, - with_assistant=True, organization_id=project.organization_id, ) @@ -136,7 +135,6 @@ def test_execute_setup_job_marks_processing_and_queues_first_batch( try: execute_setup_job( request=request.model_dump(mode="json"), - with_assistant=False, project_id=project.id, organization_id=project.organization_id, task_id=task_id, @@ -188,7 +186,6 @@ def test_execute_setup_job_failure_marks_job_failed_and_raises( with pytest.raises(RuntimeError, match="S3 upload failed"): execute_setup_job( request=request.model_dump(mode="json"), - with_assistant=False, project_id=project.id, organization_id=project.organization_id, task_id=str(uuid4()), @@ -239,7 +236,6 @@ def test_execute_setup_job_failure_sends_callback( with pytest.raises(RuntimeError): execute_setup_job( request=request.model_dump(mode="json"), - with_assistant=False, project_id=project.id, organization_id=project.organization_id, task_id=str(uuid4()), @@ -286,7 +282,6 @@ def test_execute_setup_job_timeout_marks_failed_and_reraises( with pytest.raises(Timeout): execute_setup_job( 
request=request.model_dump(mode="json"), - with_assistant=False, project_id=project.id, organization_id=project.organization_id, task_id=str(uuid4()), @@ -331,7 +326,6 @@ def test_execute_setup_job_soft_time_limit_marks_failed_and_reraises( with pytest.raises(SoftTimeLimitExceeded): execute_setup_job( request=request.model_dump(mode="json"), - with_assistant=False, project_id=project.id, organization_id=project.organization_id, task_id=str(uuid4()), @@ -380,7 +374,6 @@ def test_execute_batch_job_non_final_queues_next_batch( try: execute_batch_job( request=request.model_dump(mode="json"), - with_assistant=False, project_id=project.id, organization_id=project.organization_id, task_id=task_id, @@ -433,7 +426,6 @@ def test_execute_batch_job_final_batch_creates_collection_and_marks_successful( try: execute_batch_job( request=request.model_dump(mode="json"), - with_assistant=False, project_id=project.id, organization_id=project.organization_id, task_id=str(uuid4()), @@ -493,7 +485,6 @@ def test_execute_batch_job_final_batch_sends_success_callback( try: execute_batch_job( request=request.model_dump(mode="json"), - with_assistant=False, project_id=project.id, organization_id=project.organization_id, task_id=str(uuid4()), @@ -541,7 +532,6 @@ def test_execute_batch_job_provider_failure_marks_failed_and_raises( with pytest.raises(RuntimeError, match="vector store error"): execute_batch_job( request=request.model_dump(mode="json"), - with_assistant=False, project_id=project.id, organization_id=project.organization_id, task_id=str(uuid4()), @@ -590,7 +580,6 @@ def test_execute_batch_job_cleanup_called_when_provider_create_succeeds_but_db_f with pytest.raises(Exception, match="DB write failed"): execute_batch_job( request=request.model_dump(mode="json"), - with_assistant=False, project_id=project.id, organization_id=project.organization_id, task_id=str(uuid4()), @@ -633,7 +622,6 @@ def test_execute_batch_job_timeout_marks_failed_and_reraises( with pytest.raises(Timeout): 
execute_batch_job( request=request.model_dump(mode="json"), - with_assistant=False, project_id=project.id, organization_id=project.organization_id, task_id=str(uuid4()), @@ -678,7 +666,6 @@ def test_execute_batch_job_soft_time_limit_marks_failed_and_reraises( with pytest.raises(SoftTimeLimitExceeded): execute_batch_job( request=request.model_dump(mode="json"), - with_assistant=False, project_id=project.id, organization_id=project.organization_id, task_id=str(uuid4()), @@ -728,7 +715,6 @@ def test_execute_batch_job_failure_sends_callback( with pytest.raises(RuntimeError): execute_batch_job( request=request.model_dump(mode="json"), - with_assistant=False, project_id=project.id, organization_id=project.organization_id, task_id=str(uuid4()), From 4a492da772d8995e58fc3d60800701e6fc349513 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Thu, 14 May 2026 12:53:54 +0530 Subject: [PATCH 16/19] fixing test cases --- .../collections/test_collection_delete.py | 6 ++-- .../collections/test_collection_info.py | 21 ++++++------ .../collections/test_collection_job_info.py | 12 +++---- .../collections/test_collection_list.py | 12 +++---- .../test_crud_collection_read_all.py | 25 ++++++--------- .../test_crud_collection_read_one.py | 16 ++++------ .../services/collections/test_helpers.py | 32 ------------------- backend/app/tests/utils/collection.py | 26 --------------- 8 files changed, 38 insertions(+), 112 deletions(-) diff --git a/backend/app/tests/api/routes/collections/test_collection_delete.py b/backend/app/tests/api/routes/collections/test_collection_delete.py index f7ed400d9..441bca388 100644 --- a/backend/app/tests/api/routes/collections/test_collection_delete.py +++ b/backend/app/tests/api/routes/collections/test_collection_delete.py @@ -9,7 +9,7 @@ from app.tests.utils.auth import TestAuthContext from app.models import CollectionJobStatus from app.tests.utils.utils import get_project -from app.tests.utils.collection import get_assistant_collection +from 
app.tests.utils.collection import get_vector_store_collection @patch("app.api.routes.collections.delete_service.start_job") @@ -28,7 +28,7 @@ def test_delete_collection_calls_start_job_and_returns_job( - Calls delete_service.start_job with correct arguments """ project = get_project(db, "Dalgo") - collection = get_assistant_collection(db, project) + collection = get_vector_store_collection(db, project) resp = client.request( "DELETE", @@ -72,7 +72,7 @@ def test_delete_collection_with_callback_url_passes_it_to_start_job( into the DeletionRequest and then into delete_service.start_job. """ project = get_project(db, "Dalgo") - collection = get_assistant_collection(db, project) + collection = get_vector_store_collection(db, project) payload = { "callback_url": "https://example.com/collections/delete-callback", diff --git a/backend/app/tests/api/routes/collections/test_collection_info.py b/backend/app/tests/api/routes/collections/test_collection_info.py index f41623a54..73467251b 100644 --- a/backend/app/tests/api/routes/collections/test_collection_info.py +++ b/backend/app/tests/api/routes/collections/test_collection_info.py @@ -6,10 +6,7 @@ from app.core.config import settings from app.tests.utils.utils import get_project, get_document -from app.tests.utils.collection import ( - get_assistant_collection, - get_vector_store_collection, -) +from app.tests.utils.collection import get_vector_store_collection from app.crud import DocumentCollectionCrud from app.models import Collection, Document from app.services.collections.helpers import get_service_name @@ -37,20 +34,20 @@ def link_document_to_collection( return document -def test_collection_info_returns_assistant_collection_with_docs( +def test_collection_info_returns_collection_with_docs( client: TestClient, db: Session, user_api_key_header: dict[str, str], ) -> None: """ Happy path: - - Assistant-style collection (get_assistant_collection) + - Vector store collection - include_docs = True (default) - At least one 
document linked """ project = get_project(db, "Dalgo") - collection = get_assistant_collection(db, project) + collection = get_vector_store_collection(db, project) document = link_document_to_collection(db, collection) @@ -86,7 +83,7 @@ def test_collection_info_include_docs_false_returns_no_docs( When include_docs=false, the endpoint should not populate the documents list. """ project = get_project(db, "Dalgo") - collection = get_assistant_collection(db, project) + collection = get_vector_store_collection(db, project) link_document_to_collection(db, collection) @@ -102,8 +99,8 @@ def test_collection_info_include_docs_false_returns_no_docs( payload = data["data"] assert payload["id"] == str(collection.id) - assert payload["llm_service_name"] == "gpt-4o" - assert payload["llm_service_id"] == collection.llm_service_id + assert payload["knowledge_base_provider"] == collection.llm_service_name + assert payload["knowledge_base_id"] == collection.llm_service_id assert payload["documents"] is None @@ -117,7 +114,7 @@ def test_collection_info_pagination_skip_and_limit( We create multiple document links and then request a paginated slice. """ project = get_project(db, "Dalgo") - collection = get_assistant_collection(db, project) + collection = get_vector_store_collection(db, project) documents = db.exec( select(Document).where(Document.deleted_at.is_(None)).limit(2) @@ -205,7 +202,7 @@ def test_collection_info_include_docs_and_url( the endpoint returns documents with their URLs. 
""" project = get_project(db, "Dalgo") - collection = get_assistant_collection(db, project) + collection = get_vector_store_collection(db, project) document = link_document_to_collection(db, collection) diff --git a/backend/app/tests/api/routes/collections/test_collection_job_info.py b/backend/app/tests/api/routes/collections/test_collection_job_info.py index ad95b6769..1c7f3855c 100644 --- a/backend/app/tests/api/routes/collections/test_collection_job_info.py +++ b/backend/app/tests/api/routes/collections/test_collection_job_info.py @@ -5,7 +5,7 @@ from app.core.config import settings from app.tests.utils.utils import get_project -from app.tests.utils.collection import get_assistant_collection, get_collection_job +from app.tests.utils.collection import get_vector_store_collection, get_collection_job from app.models import ( CollectionActionType, CollectionJobStatus, @@ -41,7 +41,7 @@ def test_collection_info_create_successful( headers = user_api_key_header project = get_project(db, "Dalgo") - collection = get_assistant_collection(db, project) + collection = get_vector_store_collection(db, project) collection_job = get_collection_job( db, project, collection_id=collection.id, status=CollectionJobStatus.SUCCESSFUL @@ -61,8 +61,8 @@ def test_collection_info_create_successful( assert data["collection"] is not None col = data["collection"] assert col["id"] == str(collection.id) - assert col["llm_service_id"] == collection.llm_service_id - assert col["llm_service_name"] == "gpt-4o" + assert col["knowledge_base_id"] == collection.llm_service_id + assert col["knowledge_base_provider"] == collection.llm_service_name def test_collection_info_create_failed( @@ -101,7 +101,7 @@ def test_collection_info_delete_successful( headers = user_api_key_header project = get_project(db, "Dalgo") - collection = get_assistant_collection(db, project) + collection = get_vector_store_collection(db, project) collection_job = get_collection_job( db, @@ -133,7 +133,7 @@ def 
test_collection_info_delete_failed( headers = user_api_key_header project = get_project(db, "Dalgo") - collection = get_assistant_collection(db, project) + collection = get_vector_store_collection(db, project) collection_job = get_collection_job( db, diff --git a/backend/app/tests/api/routes/collections/test_collection_list.py b/backend/app/tests/api/routes/collections/test_collection_list.py index e1cf3d589..2f746117b 100644 --- a/backend/app/tests/api/routes/collections/test_collection_list.py +++ b/backend/app/tests/api/routes/collections/test_collection_list.py @@ -3,10 +3,7 @@ from app.core.config import settings from app.tests.utils.utils import get_project -from app.tests.utils.collection import ( - get_assistant_collection, - get_vector_store_collection, -) +from app.tests.utils.collection import get_vector_store_collection from app.services.collections.helpers import get_service_name @@ -34,14 +31,13 @@ def test_list_collections_returns_api_response( assert isinstance(data["data"], list) -def test_list_collections_includes_assistant_collection( +def test_list_collections_includes_new_collection( db: Session, client: TestClient, user_api_key_header: dict[str, str], ) -> None: """ - Ensure that a newly created assistant-style collection (get_assistant_collection) - appears in the list for the current project. + Ensure that a newly created collection appears in the list for the current project. 
""" project = get_project(db, "Dalgo") @@ -52,7 +48,7 @@ def test_list_collections_includes_assistant_collection( ) assert response_before.status_code == 200 - collection = get_assistant_collection(db, project) + collection = get_vector_store_collection(db, project) response_after = client.get( f"{settings.API_V1_STR}/collections", diff --git a/backend/app/tests/crud/collections/collection/test_crud_collection_read_all.py b/backend/app/tests/crud/collections/collection/test_crud_collection_read_all.py index 0382b8830..e6adfc58c 100644 --- a/backend/app/tests/crud/collections/collection/test_crud_collection_read_all.py +++ b/backend/app/tests/crud/collections/collection/test_crud_collection_read_all.py @@ -1,29 +1,24 @@ -from openai_responses import OpenAIMock -from openai import OpenAI from sqlmodel import Session, delete from app.crud import CollectionCrud from app.models import Collection from app.tests.utils.document import DocumentStore from app.tests.utils.utils import get_project -from app.tests.utils.collection import get_assistant_collection +from app.tests.utils.collection import get_vector_store_collection def create_collections(db: Session, n: int) -> Collection: crud = None project = get_project(db) - openai_mock = OpenAIMock() - with openai_mock.router: - client = OpenAI(api_key="sk-test-key") - for _ in range(n): - collection = get_assistant_collection(db, project=project) - store = DocumentStore(db, project_id=collection.project_id) - documents = store.fill(1) - if crud is None: - crud = CollectionCrud(db, collection.project_id) - crud.create(collection, documents) - - return crud.project_id + for _ in range(n): + collection = get_vector_store_collection(db, project=project) + store = DocumentStore(db, project_id=collection.project_id) + documents = store.fill(1) + if crud is None: + crud = CollectionCrud(db, collection.project_id) + crud.create(collection, documents) + + return crud.project_id class TestCollectionReadAll: diff --git 
a/backend/app/tests/crud/collections/collection/test_crud_collection_read_one.py b/backend/app/tests/crud/collections/collection/test_crud_collection_read_one.py index 2fc5f7676..a85b06ccb 100644 --- a/backend/app/tests/crud/collections/collection/test_crud_collection_read_one.py +++ b/backend/app/tests/crud/collections/collection/test_crud_collection_read_one.py @@ -1,6 +1,4 @@ import pytest -from openai import OpenAI -from openai_responses import OpenAIMock from fastapi import HTTPException from sqlmodel import Session @@ -8,18 +6,16 @@ from app.models import Collection from app.tests.utils.document import DocumentStore from app.tests.utils.utils import get_project -from app.tests.utils.collection import get_assistant_collection +from app.tests.utils.collection import get_vector_store_collection def mk_collection(db: Session) -> Collection: - openai_mock = OpenAIMock() project = get_project(db) - with openai_mock.router: - collection = get_assistant_collection(db, project=project) - store = DocumentStore(db, project_id=collection.project_id) - documents = store.fill(1) - crud = CollectionCrud(db, collection.project_id) - return crud.create(collection, documents) + collection = get_vector_store_collection(db, project=project) + store = DocumentStore(db, project_id=collection.project_id) + documents = store.fill(1) + crud = CollectionCrud(db, collection.project_id) + return crud.create(collection, documents) class TestDatabaseReadOne: diff --git a/backend/app/tests/services/collections/test_helpers.py b/backend/app/tests/services/collections/test_helpers.py index 15a8a8019..8caa61901 100644 --- a/backend/app/tests/services/collections/test_helpers.py +++ b/backend/app/tests/services/collections/test_helpers.py @@ -239,47 +239,15 @@ def test_to_collection_public_vector_store() -> None: result = to_collection_public(collection) - # For vector store, should map to knowledge_base fields assert result.id == collection.id assert result.knowledge_base_id == "vs_123" 
assert result.knowledge_base_provider == "openai vector store" - assert result.llm_service_id is None - assert result.llm_service_name is None assert result.project_id == 1 assert result.inserted_at == collection.inserted_at assert result.updated_at == collection.updated_at assert result.deleted_at is None -def test_to_collection_public_assistant() -> None: - """Test conversion of assistant collection to public model.""" - collection = Collection( - id=uuid4(), - project_id=2, - provider=ProviderType.openai, - llm_service_id="asst_456", - llm_service_name="gpt-4", # Does NOT match vector store name - name="Assistant Collection", - description="Assistant description", - inserted_at=now(), - updated_at=now(), - deleted_at=None, - ) - - result = to_collection_public(collection) - - # For assistant, should map to llm_service fields - assert result.id == collection.id - assert result.llm_service_id == "asst_456" - assert result.llm_service_name == "gpt-4" - assert result.knowledge_base_id is None - assert result.knowledge_base_provider is None - assert result.project_id == 2 - assert result.inserted_at == collection.inserted_at - assert result.updated_at == collection.updated_at - assert result.deleted_at is None - - def test_to_collection_public_with_deleted_at() -> None: """Test that deleted_at field is properly included when set.""" deleted_time = now() diff --git a/backend/app/tests/utils/collection.py b/backend/app/tests/utils/collection.py index f1844cb96..94231c999 100644 --- a/backend/app/tests/utils/collection.py +++ b/backend/app/tests/utils/collection.py @@ -25,32 +25,6 @@ def uuid_increment(value: UUID) -> UUID: return UUID(int=inc) -def get_assistant_collection( - db: Session, - project: Project, - *, - assistant_id: Optional[str] = None, - model: str = "gpt-4o", - collection_id: Optional[UUID] = None, -) -> Collection: - """ - Create a Collection configured for the Assistant path. - execute_job will treat this as `is_vector = False` and use assistant id. 
- """ - if assistant_id is None: - assistant_id = f"asst_{uuid4().hex}" - - collection = Collection( - id=collection_id or uuid4(), - project_id=project.id, - organization_id=project.organization_id, - llm_service_name=model, - llm_service_id=assistant_id, - provider=ProviderType.openai, - ) - return CollectionCrud(db, project.id).create(collection) - - def get_vector_store_collection( db: Session, project: Project, From 979f7561bfb6bcc943f6ae4d45957049882d566b Mon Sep 17 00:00:00 2001 From: nishika26 Date: Thu, 14 May 2026 13:00:16 +0530 Subject: [PATCH 17/19] fixing alembic migration --- ...s.py => 061_add_batch_tracking_to_collections_jobs.py} | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) rename backend/app/alembic/versions/{058_add_batch_tracking_to_collections_jobs.py => 061_add_batch_tracking_to_collections_jobs.py} (95%) diff --git a/backend/app/alembic/versions/058_add_batch_tracking_to_collections_jobs.py b/backend/app/alembic/versions/061_add_batch_tracking_to_collections_jobs.py similarity index 95% rename from backend/app/alembic/versions/058_add_batch_tracking_to_collections_jobs.py rename to backend/app/alembic/versions/061_add_batch_tracking_to_collections_jobs.py index 6bf4b97bf..aa91853e4 100644 --- a/backend/app/alembic/versions/058_add_batch_tracking_to_collections_jobs.py +++ b/backend/app/alembic/versions/061_add_batch_tracking_to_collections_jobs.py @@ -1,7 +1,7 @@ """add batch tracking to collection_jobs -Revision ID: 058 -Revises: 057 +Revision ID: 061 +Revises: 060 Create Date: 2026-04-13 """ @@ -10,8 +10,8 @@ # revision identifiers, used by Alembic. 
-revision = "058" -down_revision = "057" +revision = "061" +down_revision = "060" branch_labels = None depends_on = None From 0e027c2100157312cb7bd283c7175bea1e34c4ef Mon Sep 17 00:00:00 2001 From: nishika26 Date: Thu, 14 May 2026 16:07:20 +0530 Subject: [PATCH 18/19] test cases fix --- .../collection/test_crud_collection_delete.py | 10 ++++------ .../collection/test_crud_collection_read_all.py | 2 +- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py b/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py index 86da995ea..64abc4385 100644 --- a/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py +++ b/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py @@ -9,9 +9,7 @@ from app.tests.utils.document import DocumentStore -def get_vector_store_collection( - db: Session, client: OpenAI, project_id: int -) -> Collection: +def get_vector_store_collection(client: OpenAI, project_id: int) -> Collection: vector_store = client.vector_stores.create() return Collection( project_id=project_id, @@ -30,7 +28,7 @@ def test_delete_marks_deleted(self, db: Session) -> None: client = OpenAI(api_key="sk-test-key") v_crud = OpenAIVectorStoreCrud(client) - collection = get_vector_store_collection(db, client, project_id=project.id) + collection = get_vector_store_collection(client, project_id=project.id) crud = CollectionCrud(db, collection.project_id) collection_ = crud.delete(collection, v_crud) @@ -43,7 +41,7 @@ def test_delete_follows_insert(self, db: Session) -> None: client = OpenAI(api_key="sk-test-key") v_crud = OpenAIVectorStoreCrud(client) - collection = get_vector_store_collection(db, client, project_id=project.id) + collection = get_vector_store_collection(client, project_id=project.id) crud = CollectionCrud(db, collection.project_id) collection_ = crud.delete(collection, v_crud) @@ -64,7 +62,7 @@ def 
test_delete_document_deletes_collections(self, db: Session) -> None: client = OpenAI(api_key="sk-test-key") resources = [] for _ in range(self._n_collections): - coll = get_vector_store_collection(db, client, project_id=project.id) + coll = get_vector_store_collection(client, project_id=project.id) crud = CollectionCrud(db, project_id=project.id) collection = crud.create(coll, documents) resources.append((crud, collection)) diff --git a/backend/app/tests/crud/collections/collection/test_crud_collection_read_all.py b/backend/app/tests/crud/collections/collection/test_crud_collection_read_all.py index e6adfc58c..56cde1ca2 100644 --- a/backend/app/tests/crud/collections/collection/test_crud_collection_read_all.py +++ b/backend/app/tests/crud/collections/collection/test_crud_collection_read_all.py @@ -7,7 +7,7 @@ from app.tests.utils.collection import get_vector_store_collection -def create_collections(db: Session, n: int) -> Collection: +def create_collections(db: Session, n: int) -> int: crud = None project = get_project(db) for _ in range(n): From a8878cb01e462a729413feaaeb29bcd41c4f14a0 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Fri, 15 May 2026 16:26:06 +0530 Subject: [PATCH 19/19] left docs and adding openai file deletion --- backend/app/api/docs/collections/create.md | 9 +++------ backend/app/api/docs/collections/delete.md | 6 ------ backend/app/api/routes/documents.py | 9 +++++++-- backend/app/crud/rag/__init__.py | 2 +- backend/app/crud/rag/open_ai.py | 16 ++++++++++++++++ 5 files changed, 27 insertions(+), 15 deletions(-) diff --git a/backend/app/api/docs/collections/create.md b/backend/app/api/docs/collections/create.md index 5df4200c6..6a50b252c 100644 --- a/backend/app/api/docs/collections/create.md +++ b/backend/app/api/docs/collections/create.md @@ -3,14 +3,11 @@ pipeline: * Create a vector store from the document IDs you received after uploading your documents through the Documents module. 
-* Documents are automatically batched when creating the vector store to optimize - the upload process for large document sets. A new batch is created when either - the cumulative size reaches 30 MB of documents given to upload to a vector store - or the document count reaches 200 files in a batch, whichever limit is hit first. +* Given Documents are automatically batched during vector store creation to handle large uploads efficiently. A new batch starts when the total size reaches 30 MB or the file count reaches 200, whichever comes first. -If any step in the LLM service interaction fails, all previously created resources are cleaned up automatically. For example, if the vector store creation fails, any files already uploaded to OpenAI are removed. Failures can be caused by service downtime, invalid parameter values, or unsupported document types — the latter is especially common with PDFs that cannot be parsed. +If any step in the LLM service interaction fails, all previously created resources are cleaned up automatically. Failures can be caused by service downtime, invalid parameter values, or unsupported document types — the latter is especially common with PDFs that cannot be parsed. -The Vector store/assistant will be created asynchronously. +The Vector store will be created asynchronously. The immediate response from this endpoint is going to contain the collection "job ID" and status. Once the collection has been created, information about the collection will be returned to the user via diff --git a/backend/app/api/docs/collections/delete.md b/backend/app/api/docs/collections/delete.md index c6ffeabb2..193166bf6 100644 --- a/backend/app/api/docs/collections/delete.md +++ b/backend/app/api/docs/collections/delete.md @@ -1,11 +1,5 @@ Remove a collection from the platform. -This is a two-step process: - -1. Delete all resources that were allocated: file(s), the Vector - Store, and the Assistant. -2. Delete the collection entry from the kaapi database. 
- No action is taken on the documents themselves: the contents of the documents that were a part of the collection remain unchanged, those documents can still be accessed via the documents endpoints. The endpoint returns the job ID and status of the collection delete operation. When you take the id returned and use the `collection job info` endpoint, diff --git a/backend/app/api/routes/documents.py b/backend/app/api/routes/documents.py index 6d17bc104..720897b3a 100644 --- a/backend/app/api/routes/documents.py +++ b/backend/app/api/routes/documents.py @@ -17,7 +17,7 @@ from app.api.deps import AuthContextDep, SessionDep from app.api.permissions import Permission, require_permission from app.crud import CollectionCrud, DocumentCrud -from app.crud.rag import OpenAIVectorStoreCrud +from app.crud.rag import OpenAIFileCrud, OpenAIVectorStoreCrud from app.models import ( Document, DocumentPublic, @@ -198,11 +198,14 @@ def remove_doc( ) v_crud = OpenAIVectorStoreCrud(client) + f_crud = OpenAIFileCrud(client) d_crud = DocumentCrud(session, current_user.project_.id) c_crud = CollectionCrud(session, current_user.project_.id) document = d_crud.read_one(doc_id) c_crud.delete(document, v_crud) + if document.openai_file_id: + f_crud.delete(document.openai_file_id) d_crud.delete(doc_id) return APIResponse.success_response( @@ -225,6 +228,7 @@ def permanent_delete_doc( session, current_user.organization_.id, current_user.project_.id ) v_crud = OpenAIVectorStoreCrud(client) + f_crud = OpenAIFileCrud(client) d_crud = DocumentCrud(session, current_user.project_.id) c_crud = CollectionCrud(session, current_user.project_.id) storage = get_cloud_storage(session=session, project_id=current_user.project_.id) @@ -232,7 +236,8 @@ def permanent_delete_doc( document = d_crud.read_one(doc_id) c_crud.delete(document, v_crud) - + if document.openai_file_id: + f_crud.delete(document.openai_file_id) storage.delete(document.object_store_url) d_crud.delete(doc_id) diff --git 
a/backend/app/crud/rag/__init__.py b/backend/app/crud/rag/__init__.py index 9c1a42723..5319159a1 100644 --- a/backend/app/crud/rag/__init__.py +++ b/backend/app/crud/rag/__init__.py @@ -1 +1 @@ -from .open_ai import OpenAICrud, OpenAIVectorStoreCrud +from .open_ai import OpenAICrud, OpenAIFileCrud, OpenAIVectorStoreCrud diff --git a/backend/app/crud/rag/open_ai.py b/backend/app/crud/rag/open_ai.py index 9fd16f26a..296818684 100644 --- a/backend/app/crud/rag/open_ai.py +++ b/backend/app/crud/rag/open_ai.py @@ -149,3 +149,19 @@ def delete(self, vector_store_id: str, retries: int = 3): logger.info( f"[OpenAIVectorStoreCrud.delete] Vector store deleted | {{'vector_store_id': '{vector_store_id}'}}" ) + + +class OpenAIFileCrud(OpenAICrud): + def delete(self, file_id: str) -> None: + logger.info( + f"[OpenAIFileCrud.delete] Deleting OpenAI file | {{'file_id': '{file_id}'}}" + ) + try: + self.client.files.delete(file_id) + logger.info( + f"[OpenAIFileCrud.delete] OpenAI file deleted | {{'file_id': '{file_id}'}}" + ) + except OpenAIError as err: + logger.warning( + f"[OpenAIFileCrud.delete] Failed to delete OpenAI file | {{'file_id': '{file_id}', 'error': '{str(err)}'}}" + )