diff --git a/backend/app/alembic/versions/061_add_batch_tracking_to_collections_jobs.py b/backend/app/alembic/versions/061_add_batch_tracking_to_collections_jobs.py new file mode 100644 index 000000000..aa91853e4 --- /dev/null +++ b/backend/app/alembic/versions/061_add_batch_tracking_to_collections_jobs.py @@ -0,0 +1,62 @@ +"""add batch tracking to collection_jobs + +Revision ID: 061 +Revises: 060 +Create Date: 2026-04-13 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "061" +down_revision = "060" +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column( + "collection_jobs", + sa.Column( + "total_batches", + sa.Integer(), + nullable=True, + comment="Total number of batches the documents are split into", + ), + ) + op.add_column( + "collection_jobs", + sa.Column( + "current_batch_number", + sa.Integer(), + nullable=True, + comment="Which batch is currently being processed (1-indexed)", + ), + ) + op.add_column( + "collection_jobs", + sa.Column( + "documents_uploaded", + sa.JSON(), + nullable=True, + comment="List of document IDs successfully uploaded so far", + ), + ) + op.add_column( + "document", + sa.Column( + "openai_file_id", + sa.String(), + nullable=True, + comment="File ID assigned by the LLM provider (e.g. OpenAI file ID) to avoid re-uploading", + ), + ) + + +def downgrade(): + op.drop_column("collection_jobs", "total_batches") + op.drop_column("collection_jobs", "current_batch_number") + op.drop_column("collection_jobs", "documents_uploaded") + op.drop_column("document", "openai_file_id") diff --git a/backend/app/api/docs/collections/create.md b/backend/app/api/docs/collections/create.md index 5df4200c6..6a50b252c 100644 --- a/backend/app/api/docs/collections/create.md +++ b/backend/app/api/docs/collections/create.md @@ -3,14 +3,11 @@ pipeline: * Create a vector store from the document IDs you received after uploading your documents through the Documents module. -* Documents are automatically batched when creating the vector store to optimize - the upload process for large document sets. A new batch is created when either - the cumulative size reaches 30 MB of documents given to upload to a vector store - or the document count reaches 200 files in a batch, whichever limit is hit first. +* Given Documents are automatically batched during vector store creation to handle large uploads efficiently. A new batch starts when the total size reaches 30 MB or the file count reaches 200, whichever comes first. -If any step in the LLM service interaction fails, all previously created resources are cleaned up automatically. For example, if the vector store creation fails, any files already uploaded to OpenAI are removed. Failures can be caused by service downtime, invalid parameter values, or unsupported document types — the latter is especially common with PDFs that cannot be parsed. +If any step in the LLM service interaction fails, all previously created resources are cleaned up automatically. Failures can be caused by service downtime, invalid parameter values, or unsupported document types — the latter is especially common with PDFs that cannot be parsed. -The Vector store/assistant will be created asynchronously. +The Vector store will be created asynchronously. The immediate response from this endpoint is going to contain the collection "job ID" and status. Once the collection has been created, information about the collection will be returned to the user via diff --git a/backend/app/api/docs/collections/delete.md b/backend/app/api/docs/collections/delete.md index c6ffeabb2..193166bf6 100644 --- a/backend/app/api/docs/collections/delete.md +++ b/backend/app/api/docs/collections/delete.md @@ -1,11 +1,5 @@ Remove a collection from the platform. -This is a two-step process: - -1. Delete all resources that were allocated: file(s), the Vector - Store, and the Assistant. -2. Delete the collection entry from the kaapi database. - No action is taken on the documents themselves: the contents of the documents that were a part of the collection remain unchanged, those documents can still be accessed via the documents endpoints. The endpoint returns the job ID and status of the collection delete operation. When you take the id returned and use the `collection job info` endpoint, diff --git a/backend/app/api/docs/documents/upload.md b/backend/app/api/docs/documents/upload.md index c71a07804..7d5855654 100644 --- a/backend/app/api/docs/documents/upload.md +++ b/backend/app/api/docs/documents/upload.md @@ -1,6 +1,6 @@ Upload a document to Kaapi and optionally transform it as well. -- If only a file is provided, the document will be uploaded and stored, and its ID will be returned. +- If only a file is provided, the document will be uploaded and stored, and its ID will be returned. The maximum file size allowed for upload is 25 MB. - If a target format is specified, a transformation job will also be created to transform document into target format in the background. The response will include both the uploaded document details and information about the transformation job. - If a callback URL is provided, you will receive a notification at that URL once the document transformation job is completed. diff --git a/backend/app/api/routes/collections.py b/backend/app/api/routes/collections.py index 96272d38c..f6c7016fb 100644 --- a/backend/app/api/routes/collections.py +++ b/backend/app/api/routes/collections.py @@ -117,32 +117,16 @@ def create_collection( ) ) - # True if both model and instructions were provided in the request body - with_assistant = bool( - getattr(request, "model", None) and getattr(request, "instructions", None) - ) - create_service.start_job( db=session, request=request, collection_job_id=collection_job.id, project_id=current_user.project_.id, organization_id=current_user.organization_.id, - with_assistant=with_assistant, ) - metadata = None - if not with_assistant: - metadata = { - "note": ( - "This job will create a vector store only (no Assistant). " - "Assistant creation happens when both 'model' and 'instructions' are included." - ) - } - return APIResponse.success_response( CollectionJobImmediatePublic.model_validate(collection_job), - metadata=metadata, ) @@ -171,7 +155,9 @@ def delete_collection( if request and request.callback_url: validate_callback_url(str(request.callback_url)) - _ = CollectionCrud(session, current_user.project_.id).read_one(collection_id) + _ = CollectionCrud(session, current_user.project_.id).read_one_if_delete( + collection_id + ) deletion_request = DeletionRequest( collection_id=collection_id, diff --git a/backend/app/api/routes/documents.py b/backend/app/api/routes/documents.py index eb7d41ca3..720897b3a 100644 --- a/backend/app/api/routes/documents.py +++ b/backend/app/api/routes/documents.py @@ -17,7 +17,7 @@ from app.api.deps import AuthContextDep, SessionDep from app.api.permissions import Permission, require_permission from app.crud import CollectionCrud, DocumentCrud -from app.crud.rag import OpenAIAssistantCrud, OpenAIVectorStoreCrud +from app.crud.rag import OpenAIFileCrud, OpenAIVectorStoreCrud from app.models import ( Document, DocumentPublic, @@ -28,7 +28,7 @@ DocTransformationJobPublic, ) from app.core.cloud import get_cloud_storage -from app.services.collections.helpers import pick_service_for_documennt, MAX_DOC_SIZE_MB +from app.services.collections.helpers import MAX_DOC_SIZE_MB from app.services.documents.helpers import ( calculate_file_size, schedule_transformation, @@ -197,16 +197,15 @@ def remove_doc( session, current_user.organization_.id, current_user.project_.id ) - a_crud = OpenAIAssistantCrud(client) v_crud = OpenAIVectorStoreCrud(client) + f_crud = OpenAIFileCrud(client) d_crud = DocumentCrud(session, current_user.project_.id) c_crud = CollectionCrud(session, current_user.project_.id) document = d_crud.read_one(doc_id) - remote = pick_service_for_documennt( - session, doc_id, a_crud, v_crud - ) # assistant crud or vector store crud - c_crud.delete(document, remote) + c_crud.delete(document, v_crud) + if document.openai_file_id: + f_crud.delete(document.openai_file_id) d_crud.delete(doc_id) return APIResponse.success_response( @@ -228,19 +227,17 @@ def permanent_delete_doc( client = get_openai_client( session, current_user.organization_.id, current_user.project_.id ) - a_crud = OpenAIAssistantCrud(client) v_crud = OpenAIVectorStoreCrud(client) + f_crud = OpenAIFileCrud(client) d_crud = DocumentCrud(session, current_user.project_.id) c_crud = CollectionCrud(session, current_user.project_.id) storage = get_cloud_storage(session=session, project_id=current_user.project_.id) document = d_crud.read_one(doc_id) - remote = pick_service_for_documennt( - session, doc_id, a_crud, v_crud - ) # assistant crud or vector store crud - c_crud.delete(document, remote) - + c_crud.delete(document, v_crud) + if document.openai_file_id: + f_crud.delete(document.openai_file_id) storage.delete(document.object_store_url) d_crud.delete(doc_id) diff --git a/backend/app/celery/tasks/job_execution.py b/backend/app/celery/tasks/job_execution.py index adadf1c9c..cd3cf6397 100644 --- a/backend/app/celery/tasks/job_execution.py +++ b/backend/app/celery/tasks/job_execution.py @@ -1,7 +1,5 @@ import logging -from typing import Any -import celery from asgi_correlation_id import correlation_id from celery import current_task from opentelemetry import context as otel_context @@ -133,16 +131,36 @@ def run_doctransform_job(self, project_id: int, job_id: str, trace_id: str, **kw @celery_app.task(bind=True, queue="low_priority", priority=1) -@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_create_collection_job") -def run_create_collection_job( +@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_collection_setup_job") +def run_collection_setup_job( self, project_id: int, job_id: str, trace_id: str, **kwargs ): - from app.services.collections.create_collection import execute_job + from app.services.collections.create_collection import execute_setup_job _set_trace(trace_id) return _run_with_otel_parent( self, - lambda: execute_job( + lambda: execute_setup_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, + ), + ) + + +@celery_app.task(bind=True, queue="low_priority", priority=1) +@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_collection_batch_job") +def run_collection_batch_job( + self, project_id: int, job_id: str, trace_id: str, **kwargs +): + from app.services.collections.create_collection import execute_batch_job + + _set_trace(trace_id) + return _run_with_otel_parent( + self, + lambda: execute_batch_job( project_id=project_id, job_id=job_id, task_id=current_task.request.id, diff --git a/backend/app/celery/utils.py b/backend/app/celery/utils.py index 288cba7c4..ab1dbc080 100644 --- a/backend/app/celery/utils.py +++ b/backend/app/celery/utils.py @@ -90,20 +90,38 @@ def start_doctransform_job( return task_id -def start_create_collection_job( +def start_collection_setup_job( project_id: int, job_id: str, trace_id: str = "N/A", **kwargs ) -> str: - from app.celery.tasks.job_execution import run_create_collection_job + from app.celery.tasks.job_execution import run_collection_setup_job task_id = _enqueue_with_trace_context( - run_create_collection_job, + run_collection_setup_job, project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs, ) logger.info( - f"[start_create_collection_job] Started job {job_id} with Celery task {task_id}" + f"[start_collection_setup_job] Started job {job_id} with Celery task {task_id}" + ) + return task_id + + +def start_collection_batch_job( + project_id: int, job_id: str, trace_id: str = "N/A", **kwargs +) -> str: + from app.celery.tasks.job_execution import run_collection_batch_job + + task_id = _enqueue_with_trace_context( + run_collection_batch_job, + project_id=project_id, + job_id=job_id, + trace_id=trace_id, + **kwargs, + ) + logger.info( + f"[start_collection_batch_job] Started job {job_id} with Celery task {task_id}" ) return task_id diff --git a/backend/app/crud/collection/collection.py b/backend/app/crud/collection/collection.py index cb8b6e27d..ac0dd9497 100644 --- a/backend/app/crud/collection/collection.py +++ b/backend/app/crud/collection/collection.py @@ -82,6 +82,31 @@ def read_one(self, collection_id: UUID) -> Collection: ) return collection + def read_one_if_delete(self, collection_id: UUID) -> Collection: + statement = select(Collection).where( + and_( + Collection.project_id == self.project_id, + Collection.id == collection_id, + ) + ) + + collection = self.session.exec(statement).one_or_none() + if collection is None: + logger.warning( + "[CollectionCrud.read_one_if_delete] Collection not found | " + f"{{'project_id': '{self.project_id}', 'collection_id': '{collection_id}'}}" + ) + raise HTTPException(status_code=404, detail="Collection not found") + + if collection.deleted_at is not None: + logger.warning( + "[CollectionCrud.read_one_if_delete] Collection already deleted | " + f"{{'project_id': '{self.project_id}', 'collection_id': '{collection_id}'}}" + ) + raise HTTPException(status_code=400, detail="Collection already deleted") + + return collection + def read_all(self): statement = select(Collection).where( and_( @@ -122,6 +147,11 @@ def delete(self, model, remote): # remote should be an OpenAICrud @delete.register def _(self, model: Collection, remote): + if model.deleted_at is not None: + logger.info( + f"[CollectionCrud.delete] Collection already deleted | {{'collection_id': '{model.id}'}}" + ) + return model remote.delete(model.llm_service_id) model.deleted_at = now() collection = self._update(model) diff --git a/backend/app/crud/rag/__init__.py b/backend/app/crud/rag/__init__.py index 6d6586427..5319159a1 100644 --- a/backend/app/crud/rag/__init__.py +++ b/backend/app/crud/rag/__init__.py @@ -1 +1 @@ -from .open_ai import OpenAICrud, OpenAIVectorStoreCrud, OpenAIAssistantCrud +from .open_ai import OpenAICrud, OpenAIFileCrud, OpenAIVectorStoreCrud diff --git a/backend/app/crud/rag/open_ai.py b/backend/app/crud/rag/open_ai.py index 2ae36f4f1..296818684 100644 --- a/backend/app/crud/rag/open_ai.py +++ b/backend/app/crud/rag/open_ai.py @@ -1,13 +1,11 @@ import json import logging import functools as ft -from io import BytesIO -from typing import Iterable +import time from openai import OpenAI, OpenAIError from pydantic import BaseModel -from app.core.cloud import CloudStorage from app.models import Document logger = logging.getLogger(__name__) @@ -67,21 +65,8 @@ def clean(self, resource): raise NotImplementedError() -class AssistantCleaner(ResourceCleaner): - def clean(self, resource): - logger.info( - f"[AssistantCleaner.clean] Deleting assistant | {{'assistant_id': '{resource}'}}" - ) - self.client.beta.assistants.delete(resource) - - class VectorStoreCleaner(ResourceCleaner): def clean(self, resource): - logger.info( - f"[VectorStoreCleaner.clean] Starting vector store cleanup | {{'vector_store_id': '{resource}'}}" - ) - for i in vs_ls(self.client, resource): - self.client.files.delete(i.id) logger.info( f"[VectorStoreCleaner.clean] Deleting vector store | {{'vector_store_id': '{resource}'}}" ) @@ -117,36 +102,36 @@ def read(self, vector_store_id: str): def update( self, vector_store_id: str, - storage: CloudStorage, - documents: Iterable[Document], - ): - for docs in documents: - files = [] - for d in docs: - # Get file bytes and wrap in BytesIO for OpenAI API - content = storage.get(d.object_store_url) - f_obj = BytesIO(content) - f_obj.name = d.fname - files.append(f_obj) + docs: list[Document], + ) -> None: + if not docs: + return - logger.info( - f"[OpenAIVectorStoreCrud.update] Uploading files to vector store | {{'vector_store_id': '{vector_store_id}', 'file_count': {len(files)}}}" - ) - req = self.client.vector_stores.file_batches.upload_and_poll( + try: + batch = self.client.vector_stores.file_batches.upload_and_poll( vector_store_id=vector_store_id, - files=files, + files=[], + file_ids=[doc.openai_file_id for doc in docs], ) - logger.info( - f"[OpenAIVectorStoreCrud.update] File upload completed | {{'vector_store_id': '{vector_store_id}', 'completed_files': {req.file_counts.completed}, 'total_files': {req.file_counts.total}}}" + except OpenAIError as err: + logger.error( + f"[OpenAIVectorStoreCrud.update] Batch attach failed | " + f"{{'vector_store_id': '{vector_store_id}', 'error': '{str(err)}'}}", + exc_info=True, ) - if req.file_counts.completed != req.file_counts.total: - error_msg = f"OpenAI document processing error: {req.file_counts.completed}/{req.file_counts.total} files completed" - logger.error( - f"[OpenAIVectorStoreCrud.update] Document processing error | {{'vector_store_id': '{vector_store_id}', 'completed_files': {req.file_counts.completed}, 'total_files': {req.file_counts.total}}}" - ) - raise InterruptedError(error_msg) + raise - yield from docs + logger.info( + f"[OpenAIVectorStoreCrud.update] Batch complete | " + f"{{'vector_store_id': '{vector_store_id}', " + f"'completed': {batch.file_counts.completed}, 'failed': {batch.file_counts.failed}}}" + ) + if batch.file_counts.failed > 0: + raise RuntimeError( + f"Batch attach to vector store {vector_store_id!r} completed with " + f"{batch.file_counts.failed} failed file(s) " + f"({batch.file_counts.completed} succeeded)" + ) def delete(self, vector_store_id: str, retries: int = 3): if retries < 1: @@ -166,58 +151,17 @@ def delete(self, vector_store_id: str, retries: int = 3): ) -class OpenAIAssistantCrud(OpenAICrud): - def create(self, vector_store_id: str, **kwargs): +class OpenAIFileCrud(OpenAICrud): + def delete(self, file_id: str) -> None: logger.info( - f"[OpenAIAssistantCrud.create] Creating assistant | {{'vector_store_id': '{vector_store_id}'}}" + f"[OpenAIFileCrud.delete] Deleting OpenAI file | {{'file_id': '{file_id}'}}" ) - assistant = self.client.beta.assistants.create( - tools=[ - { - "type": "file_search", - } - ], - tool_resources={ - "file_search": { - "vector_store_ids": [ - vector_store_id, - ], - }, - }, - **kwargs, - ) - logger.info( - f"[OpenAIAssistantCrud.create] Assistant created | {{'assistant_id': '{assistant.id}', 'vector_store_id': '{vector_store_id}'}}" - ) - return assistant - - def delete(self, assistant_id: str): - logger.info( - f"[OpenAIAssistantCrud.delete] Starting assistant deletion | {{'assistant_id': '{assistant_id}'}}" - ) - assistant = self.client.beta.assistants.retrieve(assistant_id) - vector_stores = assistant.tool_resources.file_search.vector_store_ids - try: - (vector_store_id,) = vector_stores - except ValueError: - if vector_stores: - names = ", ".join(vector_stores) - err = ValueError(f"Too many attached vector stores: {names}") - else: - err = ValueError("No vector stores found") - - logger.error( - f"[OpenAIAssistantCrud.delete] Invalid vector store state | {{'assistant_id': '{assistant_id}', 'vector_stores': '{vector_stores}'}}", - exc_info=True, + self.client.files.delete(file_id) + logger.info( + f"[OpenAIFileCrud.delete] OpenAI file deleted | {{'file_id': '{file_id}'}}" + ) + except OpenAIError as err: + logger.warning( + f"[OpenAIFileCrud.delete] Failed to delete OpenAI file | {{'file_id': '{file_id}', 'error': '{str(err)}'}}" ) - raise err - - v_crud = OpenAIVectorStoreCrud(self.client) - v_crud.delete(vector_store_id) - - cleaner = AssistantCleaner(self.client) - cleaner(assistant_id) - logger.info( - f"[OpenAIAssistantCrud.delete] Assistant deleted | {{'assistant_id': '{assistant_id}'}}" - ) diff --git a/backend/app/models/collection.py b/backend/app/models/collection.py index ccd606deb..9957cf7b0 100644 --- a/backend/app/models/collection.py +++ b/backend/app/models/collection.py @@ -3,7 +3,7 @@ from typing import Any, Literal from uuid import UUID, uuid4 -from pydantic import HttpUrl, model_validator, model_serializer +from pydantic import HttpUrl from sqlalchemy import Index, text from sqlmodel import Field, Relationship, SQLModel @@ -105,62 +105,6 @@ def model_post_init(self, __context: Any): self.documents = list(set(self.documents)) -class AssistantOptions(SQLModel): - # Fields to be passed along to OpenAI. They must be a subset of - # parameters accepted by the OpenAI.clien.beta.assistants.create - # API. - model: str | None = Field( - default=None, - description=( - "**[Deprecated]** " - "OpenAI model to attach to this assistant. The model " - "must be compatable with the assistants API; see the " - "OpenAI [model documentation](https://platform.openai.com/docs/models/compare) for more." - ), - ) - - instructions: str | None = Field( - default=None, - description=( - "**[Deprecated]** " - "Assistant instruction. Sometimes referred to as the " - '"system" prompt.' - ), - ) - temperature: float = Field( - default=1e-6, - description=( - "**[Deprecated]** " - "Model temperature. The default is slightly " - "greater-than zero because it is [unknown how OpenAI " - "handles zero](https://community.openai.com/t/clarifications-on-setting-temperature-0/886447/5)." - ), - ) - - @model_validator(mode="before") - def _assistant_fields_all_or_none(cls, values: dict[str, Any]) -> dict[str, Any]: - def norm(x: Any) -> Any: - if x is None: - return None - if isinstance(x, str): - s = x.strip() - return s if s else None - return x # let Pydantic handle non-strings - - model = norm(values.get("model")) - instructions = norm(values.get("instructions")) - - if (model is None) ^ (instructions is None): - raise ValueError( - "To create an Assistant, provide BOTH 'model' and 'instructions'. " - "If you only want a vector store, remove both fields." - ) - - values["model"] = model - values["instructions"] = instructions - return values - - class CallbackRequest(SQLModel): callback_url: HttpUrl | None = Field( default=None, @@ -177,7 +121,6 @@ class ProviderOptions(SQLModel): class CreationRequest( - AssistantOptions, CollectionOptions, ProviderOptions, CallbackRequest, @@ -201,21 +144,11 @@ class CollectionIDPublic(SQLModel): class CollectionPublic(SQLModel): id: UUID - llm_service_id: str | None = Field( - default=None, - description="LLM service ID (e.g., Assistant ID) when model and instructions were provided", - ) - llm_service_name: str | None = Field( - default=None, - description="LLM service name (e.g., model name) when model and instructions were provided", - ) - knowledge_base_id: str | None = Field( - default=None, - description="Knowledge base ID (e.g., Vector Store ID) when only vector store was created", + knowledge_base_id: str = Field( + description="Knowledge base ID (e.g., Vector Store ID)", ) - knowledge_base_provider: str | None = Field( - default=None, - description="Knowledge base provider name when only vector store was created", + knowledge_base_provider: str = Field( + description="Knowledge base provider name", ) project_id: int @@ -223,56 +156,6 @@ class CollectionPublic(SQLModel): updated_at: datetime deleted_at: datetime | None = None - @model_validator(mode="after") - def validate_service_fields(self) -> "CollectionPublic": - """Ensure either LLM service fields or knowledge base fields are set, not both.""" - has_llm = self.llm_service_id is not None or self.llm_service_name is not None - has_kb = ( - self.knowledge_base_id is not None - or self.knowledge_base_provider is not None - ) - - if has_llm and has_kb: - raise ValueError( - "Cannot have both LLM service fields and knowledge base fields set" - ) - - if not has_llm and not has_kb: - raise ValueError( - "Either LLM service fields or knowledge base fields must be set" - ) - - # Ensure both fields in the pair are set or both are None - if has_llm and ( - (self.llm_service_id is None) != (self.llm_service_name is None) - ): - raise ValueError("Both llm_service_id and llm_service_name must be set") - - if has_kb and ( - (self.knowledge_base_id is None) != (self.knowledge_base_provider is None) - ): - raise ValueError( - "Both knowledge_base_id and knowledge_base_provider must be set" - ) - - return self - - @model_serializer(mode="wrap", when_used="json") - def _serialize_model(self, serializer: Any, info: Any) -> dict[str, Any]: - """Exclude unused service fields from JSON serialization.""" - data = serializer(self) - - # If this is a knowledge base, remove llm_service fields - if data.get("knowledge_base_id") is not None: - data.pop("llm_service_id", None) - data.pop("llm_service_name", None) - # If this is an assistant, remove knowledge_base fields - elif data.get("llm_service_id") is not None: - data.pop("knowledge_base_id", None) - data.pop("knowledge_base_provider", None) - - return data - class CollectionWithDocsPublic(CollectionPublic): documents: list[DocumentPublic] | None = None diff --git a/backend/app/models/collection_job.py b/backend/app/models/collection_job.py index 333ebfd14..6b628ad7e 100644 --- a/backend/app/models/collection_job.py +++ b/backend/app/models/collection_job.py @@ -77,7 +77,29 @@ class CollectionJob(SQLModel, table=True): documents: list[str] | None = Field( default=None, sa_column=Column( - JSON, nullable=True, comment="List of documents given to make collection" + JSON, nullable=True, comment="List of document IDs given to make collection" + ), + ) + total_batches: int | None = Field( + default=None, + nullable=True, + sa_column_kwargs={ + "comment": "Total number of batches the documents are split into" + }, + ) + current_batch_number: int | None = Field( + default=None, + nullable=True, + sa_column_kwargs={ + "comment": "Which batch is currently being processed (1-indexed)" + }, + ) + documents_uploaded: list[str] | None = Field( + default=None, + sa_column=Column( + JSON, + nullable=True, + comment="List of document IDs successfully uploaded so far", ), ) @@ -139,6 +161,9 @@ class CollectionJobUpdate(SQLModel): collection_id: UUID | None = None total_size_mb: float | None = None trace_id: str | None = None + total_batches: int | None = None + current_batch_number: int | None = None + documents_uploaded: list[str] | None = None ##Response models diff --git a/backend/app/models/document.py b/backend/app/models/document.py index c0aa6c8b6..12ab4f683 100644 --- a/backend/app/models/document.py +++ b/backend/app/models/document.py @@ -42,6 +42,11 @@ class Document(DocumentBase, table=True): description="The size of the document in kilobytes", sa_column_kwargs={"comment": "Size of the document in kilobytes (KB)"}, ) + openai_file_id: str | None = Field( + default=None, + nullable=True, + sa_column_kwargs={"comment": "File ID assigned by OpenAI (avoid re-uploading)"}, + ) # Foreign keys source_document_id: UUID | None = Field( diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py index a9b787f6b..43056fe87 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -2,10 +2,10 @@ import time from uuid import UUID, uuid4 -from opentelemetry import trace from sqlmodel import Session -from celery.exceptions import SoftTimeLimitExceeded from gevent import Timeout +from celery.exceptions import SoftTimeLimitExceeded +from opentelemetry import trace from asgi_correlation_id import correlation_id from app.core.cloud import get_cloud_storage @@ -26,12 +26,13 @@ CreationRequest, ) from app.services.collections.helpers import ( + batch_documents, extract_error_message, to_collection_public, ) from app.services.collections.providers.registry import get_llm_provider -from app.celery.utils import start_create_collection_job -from app.utils import send_callback, get_webhook_secret, APIResponse +from app.celery.utils import start_collection_setup_job, start_collection_batch_job +from app.utils import send_callback, APIResponse, get_webhook_secret logger = logging.getLogger(__name__) @@ -43,52 +44,32 @@ def start_job( request: CreationRequest, project_id: int, collection_job_id: UUID, - with_assistant: bool, organization_id: int, ) -> str: - with log_context( - tag="collection", - lifecycle="collection.create.start_job", - action="create", - collection_job_id=collection_job_id, - project_id=project_id, - organization_id=organization_id, - ): - trace_id = correlation_id.get() or "N/A" + trace_id = correlation_id.get() or "N/A" - job_crud = CollectionJobCrud(db, project_id) - collection_job = job_crud.update( - collection_job_id, CollectionJobUpdate(trace_id=trace_id) - ) + job_crud = CollectionJobCrud(db, project_id) + job_crud.update(collection_job_id, CollectionJobUpdate(trace_id=trace_id)) - task_id = start_create_collection_job( - project_id=project_id, - job_id=str(collection_job_id), - trace_id=trace_id, - request=request.model_dump(mode="json"), - with_assistant=with_assistant, - organization_id=organization_id, - ) + task_id = start_collection_setup_job( + project_id=project_id, + job_id=str(collection_job_id), + trace_id=trace_id, + request=request.model_dump(mode="json"), + organization_id=organization_id, + ) - logger.info( - "[create_collection.start_job] Job scheduled to create collection | " - f"collection_job_id={collection_job_id}, project_id={project_id}, task_id={task_id}" - ) + logger.info( + "[create_collection.start_job] Job scheduled to create collection | " + f"collection_job_id={collection_job_id}, project_id={project_id}, task_id={task_id}" + ) - return collection_job_id + return collection_job_id def build_success_payload( collection_job: CollectionJob, collection: Collection ) -> dict: - """ - { - "success": true, - "data": { job fields + full collection }, - "error": null, - "metadata": null - } - """ collection_public = to_collection_public(collection) collection_dict = collection_public.model_dump(mode="json", exclude_none=True) @@ -102,15 +83,6 @@ def build_success_payload( def build_failure_payload(collection_job: CollectionJob, error_message: str) -> dict: - """ - { - "success": false, - "data": { job fields, collection: null }, - "error": "something went wrong", - "metadata": null - } - """ - # ensure `collection` is explicitly null in the payload job_public = CollectionJobPublic.model_validate( collection_job, update={"collection": None}, @@ -144,7 +116,7 @@ def _mark_job_failed( ) return collection_job except Exception: - logger.warning("[create_collection.execute_job] Failed to mark job as FAILED") + logger.warning("[create_collection] Failed to mark job as FAILED") return None @@ -186,9 +158,8 @@ def _handle_job_failure( ) -def execute_job( +def execute_setup_job( request: dict, - with_assistant: bool, project_id: int, organization_id: int, task_id: str, @@ -196,64 +167,197 @@ def execute_job( task_instance, ) -> None: """ - Worker entrypoint scheduled by start_job. - Orchestrates: job state, provider init, collection creation, - optional assistant creation, collection persistence, linking, callbacks, and cleanup. + Phase 1: Fetch documents, create the vector store, split into batches, + update job state to PROCESSING, then queue the first batch task. """ - start_time = time.time() - collection_job = None - result = None creation_request = None - provider = None with log_context( tag="collection", - lifecycle="collection.create.execute_job", + lifecycle="collection.create.execute_setup_job", action="create", collection_job_id=job_id, task_id=task_id, project_id=project_id, organization_id=organization_id, - ), tracer.start_as_current_span("collections.create.execute_job") as span: + ), tracer.start_as_current_span("collections.create.execute_setup_job") as span: span.set_attribute("collection.job_id", str(job_id)) span.set_attribute("kaapi.project_id", project_id) span.set_attribute("kaapi.organization_id", organization_id) try: creation_request = CreationRequest(**request) - if with_assistant: - creation_request.provider = "openai" - span.set_attribute("collection.provider", str(creation_request.provider)) job_uuid = UUID(job_id) + trace_id = correlation_id.get() or "N/A" with Session(engine) as session: document_crud = DocumentCrud(session, project_id) flat_docs = document_crud.read_each(creation_request.documents) + storage = get_cloud_storage(session=session, project_id=project_id) + + provider = get_llm_provider( + session=session, + provider=creation_request.provider, + project_id=project_id, + organization_id=organization_id, + ) + + for doc in flat_docs: + session.expunge(doc) + + collection_job_crud = CollectionJobCrud(session, project_id) + collection_job = collection_job_crud.update( + job_uuid, + CollectionJobUpdate( + task_id=task_id, + status=CollectionJobStatus.PROCESSING, + ), + ) + + provider.upload_files(storage, flat_docs, project_id) + + logger.info( + "[create_collection.execute_setup_job] All file uploads complete | " + "job_id=%s, total=%d", + job_id, + len(flat_docs), + ) - file_exts = { - doc.fname.split(".")[-1] for doc in flat_docs if "." in doc.fname - } - total_size_kb = sum(doc.file_size_kb or 0 for doc in flat_docs) - total_size_mb = round(total_size_kb / 1024, 2) - span.set_attribute("collection.documents.count", len(flat_docs)) - span.set_attribute("collection.documents.total_size_mb", total_size_mb) + total_size_kb = sum(doc.file_size_kb for doc in flat_docs) + total_size_mb = total_size_kb / 1024 + + docs_batches = batch_documents(flat_docs) + total_batches = len(docs_batches) + batch_doc_ids = [[str(doc.id) for doc in batch] for batch in docs_batches] with Session(engine) as session: collection_job_crud = CollectionJobCrud(session, project_id) - collection_job = collection_job_crud.read_one(job_uuid) collection_job = collection_job_crud.update( job_uuid, CollectionJobUpdate( task_id=task_id, status=CollectionJobStatus.PROCESSING, total_size_mb=total_size_mb, + current_batch_number=0, + total_batches=total_batches, + documents_uploaded=[], ), ) - storage = get_cloud_storage(session=session, project_id=project_id) + start_collection_batch_job( + project_id=project_id, + job_id=job_id, + trace_id=trace_id, + batch_number=1, + batch_doc_ids=batch_doc_ids[0], + remaining_batches=batch_doc_ids[1:], + request=request, + vector_store_id=None, + organization_id=organization_id, + ) + + logger.info( + "[create_collection.execute_setup_job] Setup complete, first batch queued | " + f"job_id={job_id}, total_batches={total_batches}" + ) + + except (Timeout, SoftTimeLimitExceeded) as err: + timeout_err = TimeoutError("Task exceeded soft time limit") + logger.warning( + "[create_collection.execute_setup_job] Collection Creation Timed Out | {'collection_job_id': '%s', 'error': '%s'}", + job_id, + str(timeout_err), + ) + _handle_job_failure( + span, + project_id, + organization_id, + job_id, + timeout_err, + collection_job, + creation_request, + ) + raise + + except Exception as err: + logger.error( + "[create_collection.execute_setup_job] Setup failed | job_id=%s, error=%s", + job_id, + str(err), + exc_info=True, + ) + _handle_job_failure( + span, + project_id, + organization_id, + job_id, + err, + collection_job, + creation_request, + ) + raise + + +def execute_batch_job( + request: dict, + project_id: int, + organization_id: int, + task_id: str, + job_id: str, + task_instance, + vector_store_id: str | None, + batch_number: int, + batch_doc_ids: list[str], + remaining_batches: list[list[str]], +) -> None: + """ + Phase 2: Upload one batch of documents to the vector store. + - Uploads the batch via provider.create(); raises immediately on failure + - Checkpoints progress to the DB + - If more batches remain, queues the next batch task + - If this is the last batch, finalizes: creates Collection, links docs, marks job SUCCESSFUL + """ + collection_job = None + result = None + creation_request = None + provider = None + + with log_context( + tag="collection", + lifecycle="collection.create.execute_batch_job", + action="create", + collection_job_id=job_id, + task_id=task_id, + project_id=project_id, + organization_id=organization_id, + ), tracer.start_as_current_span("collections.create.execute_batch_job") as span: + span.set_attribute("collection.job_id", str(job_id)) + span.set_attribute("kaapi.project_id", project_id) + span.set_attribute("kaapi.organization_id", organization_id) + + try: + batch_start_time = time.time() + creation_request = CreationRequest(**request) + span.set_attribute("collection.provider", str(creation_request.provider)) + + job_uuid = UUID(job_id) + trace_id = correlation_id.get() or "N/A" + + logger.info( + "[create_collection.execute_batch_job] Starting batch | " + "job_id=%s, batch_number=%d, doc_count=%d, remaining_batches=%d", + job_id, + batch_number, + len(batch_doc_ids), + len(remaining_batches), + ) + + all_doc_ids_this_batch = [UUID(d) for d in batch_doc_ids] + + with Session(engine) as session: provider = get_llm_provider( session=session, provider=creation_request.provider, @@ -261,38 +365,104 @@ def execute_job( organization_id=organization_id, ) - with tracer.start_as_current_span("collections.create.provider"): - result = provider.create( - collection_request=creation_request, - storage=storage, - documents=flat_docs, + with Session(engine) as session: + document_crud = DocumentCrud(session, project_id) + batch_docs = ( + document_crud.read_each(all_doc_ids_this_batch) + if all_doc_ids_this_batch + else [] ) + for doc in batch_docs: + session.expunge(doc) - llm_service_id = result.llm_service_id - llm_service_name = result.llm_service_name + collection_result = provider.create( + batch_docs, + vector_store_id=vector_store_id, + ) + result = collection_result + resolved_vector_store_id = collection_result.llm_service_id with Session(engine) as session: - collection_crud = CollectionCrud(session, project_id) - collection_id = uuid4() + collection_job_crud = CollectionJobCrud(session, project_id) + collection_job = collection_job_crud.read_one(job_uuid) + already_uploaded = collection_job.documents_uploaded or [] + now_uploaded = already_uploaded + [ + str(d) for d in all_doc_ids_this_batch + ] + + collection_job = collection_job_crud.update( + job_uuid, + CollectionJobUpdate( + current_batch_number=batch_number, + documents_uploaded=now_uploaded, + ), + ) + + logger.info( + "[create_collection.execute_batch_job] Batch %d complete | " + "doc_count=%d, job_id=%s", + batch_number, + len(all_doc_ids_this_batch), + job_id, + ) + + if remaining_batches: + start_collection_batch_job( + project_id=project_id, + job_id=job_id, + trace_id=trace_id, + vector_store_id=resolved_vector_store_id, + batch_number=batch_number + 1, + batch_doc_ids=remaining_batches[0], + remaining_batches=remaining_batches[1:], + request=request, + organization_id=organization_id, + ) + logger.info( + "[create_collection.execute_batch_job] Batch %d/%d done, next batch queued | " + "job_id=%s, elapsed=%.2fs", + batch_number, + batch_number + len(remaining_batches), + job_id, + time.time() - batch_start_time, + ) + return + # Final batch: collection_result already has assistant/vector_store finalized + finalize_start_time = time.time() + + with Session(engine) as session: + all_uploaded_ids = [UUID(d) for d in now_uploaded] + document_crud = DocumentCrud(session, project_id) + all_docs = ( + document_crud.read_each(all_uploaded_ids) + if all_uploaded_ids + else [] + ) + for doc in all_docs: + session.expunge(doc) + + with Session(engine) as session: + collection_id = uuid4() collection = Collection( id=collection_id, project_id=project_id, - llm_service_id=llm_service_id, - llm_service_name=llm_service_name, + llm_service_id=collection_result.llm_service_id, + llm_service_name=collection_result.llm_service_name, provider=creation_request.provider, name=creation_request.name, description=creation_request.description, ) + collection_crud = CollectionCrud(session, project_id) collection_crud.create(collection) collection = collection_crud.read_one(collection.id) - if flat_docs: - DocumentCollectionCrud(session).create(collection, flat_docs) + if all_docs: + DocumentCollectionCrud(session).create(collection, all_docs) collection_job_crud = CollectionJobCrud(session, project_id) collection_job = collection_job_crud.update( - collection_job.id, + job_uuid, CollectionJobUpdate( status=CollectionJobStatus.SUCCESSFUL, collection_id=collection.id, @@ -303,14 +473,13 @@ def execute_job( span.set_attribute("collection.id", str(collection_id)) - elapsed = time.time() - start_time logger.info( - "[create_collection.execute_job] Collection created: %s | Time: %.2fs | Files: %d | Total Size: %s MB | Types: %s", + "[create_collection.execute_batch_job] All batches done, collection created: %s | " + "finalize_time=%.2fs, total_time=%.2fs, total_docs=%d", collection_id, - elapsed, - len(flat_docs), - collection_job.total_size_mb, - list(file_exts), + time.time() - finalize_start_time, + time.time() - batch_start_time, + len(all_docs), ) if creation_request.callback_url: @@ -324,7 +493,7 @@ def execute_job( except (Timeout, SoftTimeLimitExceeded) as err: timeout_err = TimeoutError("Task exceeded soft time limit") logger.warning( - "[create_collection.execute_job] Collection Creation Timed Out | {'collection_job_id': '%s', 'error': '%s'}", + "[create_collection.execute_batch_job] Collection Creation Timed Out | {'collection_job_id': '%s', 'error': '%s'}", job_id, str(timeout_err), ) @@ -343,7 +512,7 @@ def execute_job( except Exception as err: logger.error( - "[create_collection.execute_job] Collection Creation Failed | {'collection_job_id': '%s', 'error': '%s'}", + "[create_collection.execute_batch_job] Collection Creation Failed | {'collection_job_id': '%s', 'error': '%s'}", job_id, str(err), exc_info=True, diff --git a/backend/app/services/collections/helpers.py b/backend/app/services/collections/helpers.py index 6985ac78e..8e01a8bac 100644 --- a/backend/app/services/collections/helpers.py +++ b/backend/app/services/collections/helpers.py @@ -2,14 +2,12 @@ import json import ast import re -from uuid import UUID from fastapi import HTTPException -from sqlmodel import select from app.crud import CollectionCrud from app.api.deps import SessionDep -from app.models import DocumentCollection, Collection, CollectionPublic, Document +from app.models import Collection, CollectionPublic, Document logger = logging.getLogger(__name__) @@ -19,7 +17,6 @@ MAX_DOC_SIZE_MB = 25 # 25 MB maximum per document # Maximum batch size for uploading documents to vector store -# Derived from MAX_DOC_SIZE + buffer to ensure single docs always fit MAX_BATCH_SIZE_KB = (MAX_DOC_SIZE_MB + 5) * 1024 # 30 MB in KB (25 + 5 MB buffer) MAX_BATCH_COUNT = 200 # Maximum documents per batch @@ -83,7 +80,7 @@ def batch_documents(documents: list[Document]) -> list[list[Document]]: current_batch_size_kb = 0 for doc in documents: - doc_size_kb = doc.file_size_kb or 0 + doc_size_kb = doc.file_size_kb would_exceed_size = (current_batch_size_kb + doc_size_kb) > MAX_BATCH_SIZE_KB would_exceed_count = len(current_batch) >= MAX_BATCH_COUNT @@ -112,30 +109,6 @@ def batch_documents(documents: list[Document]) -> list[list[Document]]: return docs_batches -# Even though this function is used in the documents router, it's kept here for now since the assistant creation logic will -# eventually be removed from Kaapi. Once that happens, this function can be safely deleted - -def pick_service_for_documennt(session, doc_id: UUID, a_crud, v_crud): - """ - Return the correct remote (v_crud or a_crud) for this document - by inspecting an active linked Collection's llm_service_name. - Defaults to a_crud if not vector store. - """ - coll = session.exec( - select(Collection) - .join(DocumentCollection, DocumentCollection.collection_id == Collection.id) - .where( - DocumentCollection.document_id == doc_id, - Collection.deleted_at.is_(None), - ) - .limit(1) - ).first() - - service = ( - (getattr(coll, "llm_service_name", "") or "").strip().lower() if coll else "" - ) - return v_crud if service == get_service_name("openai") else a_crud - - def ensure_unique_name( session: SessionDep, project_id: int, @@ -156,35 +129,12 @@ def ensure_unique_name( def to_collection_public(collection: Collection) -> CollectionPublic: - """ - Convert a Collection DB model to CollectionPublic response model. - - Maps fields based on service type: - - If llm_service_name is a vector store (matches get_service_name pattern), - use knowledge_base_id/knowledge_base_provider - - Otherwise (assistant), use llm_service_id/llm_service_name - """ - is_vector_store = collection.llm_service_name == get_service_name( - collection.provider + return CollectionPublic( + id=collection.id, + knowledge_base_id=collection.llm_service_id, + knowledge_base_provider=collection.llm_service_name, + project_id=collection.project_id, + inserted_at=collection.inserted_at, + updated_at=collection.updated_at, + deleted_at=collection.deleted_at, ) - - if is_vector_store: - return CollectionPublic( - id=collection.id, - knowledge_base_id=collection.llm_service_id, - knowledge_base_provider=collection.llm_service_name, - project_id=collection.project_id, - inserted_at=collection.inserted_at, - updated_at=collection.updated_at, - deleted_at=collection.deleted_at, - ) - else: - return CollectionPublic( - id=collection.id, - llm_service_id=collection.llm_service_id, - llm_service_name=collection.llm_service_name, - project_id=collection.project_id, - inserted_at=collection.inserted_at, - updated_at=collection.updated_at, - deleted_at=collection.deleted_at, - ) diff --git a/backend/app/services/collections/providers/base.py b/backend/app/services/collections/providers/base.py index 36283d1fa..38355f7d5 100644 --- a/backend/app/services/collections/providers/base.py +++ b/backend/app/services/collections/providers/base.py @@ -2,7 +2,7 @@ from typing import Any from app.core.cloud.storage import CloudStorage -from app.models import CreationRequest, Collection, Document +from app.models import Collection, Document class BaseProvider(ABC): @@ -11,56 +11,49 @@ class BaseProvider(ABC): All provider implementations (OpenAI, Bedrock, etc.) must inherit from this class and implement the required methods. - Providers handle creation of collection and - optional assistant/agent creation backed by those collections. + Providers handle creation of vector store collections. Attributes: client: The provider-specific client instance """ def __init__(self, client: Any) -> None: - """Initialize provider with client. - - Args: - client: Provider-specific client instance - """ self.client = client @abstractmethod - def create( + def upload_files( self, - collection_request: CreationRequest, storage: CloudStorage, - documents: list[Document], - ) -> Collection: - """Create collection with documents and optionally an assistant. + docs: list[Document], + project_id: int, + ) -> None: + """Upload all documents to the provider's file storage and persist their file IDs. Args: - collection_request: Collection parameters (name, description, document list, etc.) - storage: Cloud storage instance for file access - documents: Pre-fetched list of Document objects to add to the collection - - Returns: - Collection object with llm_service_id and llm_service_name populated + storage: Cloud storage instance to fetch raw file bytes from + docs: Documents to upload + project_id: Project ID used to persist the provider file IDs to the DB """ - raise NotImplementedError("Providers must implement execute method") + raise NotImplementedError("Providers must implement upload_files method") @abstractmethod - def delete(self, collection: Collection) -> None: - """Delete remote resources associated with a collection. - - Called when a collection is being deleted and remote resources need to be cleaned up. + def create( + self, + docs: list[Document], + vector_store_id: str | None = None, + ) -> Collection: + """Upload docs batch to vector store (creating it if vector_store_id is None). + Returns Collection with llm_service_id set to the vector store ID.""" + raise NotImplementedError("Providers must implement create method") - Args: - llm_service_id: ID of the resource to delete - llm_service_name: Name of the service (determines resource type) - """ + @abstractmethod + def delete(self, collection: Collection) -> None: + """Delete remote resources associated with a collection.""" raise NotImplementedError("Providers must implement delete method") - def get_provider_name(self) -> str: - """Get the name of the provider. + def get_existing_file_id(self, _doc: Document) -> str | None: + """Return the already-uploaded file ID for this provider, or None to trigger upload.""" + return None - Returns: - Provider name (e.g., "openai", "bedrock", "pinecone") - """ + def get_provider_name(self) -> str: return self.__class__.__name__.replace("Provider", "").lower() diff --git a/backend/app/services/collections/providers/openai.py b/backend/app/services/collections/providers/openai.py index f52e83394..dbcbcc7fd 100644 --- a/backend/app/services/collections/providers/openai.py +++ b/backend/app/services/collections/providers/openai.py @@ -1,13 +1,17 @@ import logging +from io import BytesIO from typing import List from openai import OpenAI +from sqlmodel import Session from app.services.collections.providers import BaseProvider from app.core.cloud.storage import CloudStorage -from app.crud.rag import OpenAIVectorStoreCrud, OpenAIAssistantCrud -from app.services.collections.helpers import get_service_name, batch_documents -from app.models import CreationRequest, Collection, Document +from app.core.db import engine +from app.crud import DocumentCrud +from app.crud.rag import OpenAIVectorStoreCrud +from app.services.collections.helpers import get_service_name +from app.models import Collection, Document logger = logging.getLogger(__name__) @@ -20,65 +24,95 @@ def __init__(self, client: OpenAI): super().__init__(client) self.client = client - def create( + def get_existing_file_id(self, doc: Document) -> str | None: + return doc.openai_file_id + + def upload_files( self, - collection_request: CreationRequest, storage: CloudStorage, - documents: List[Document], + docs: list[Document], + project_id: int, + ) -> None: + for doc in docs: + if self.get_existing_file_id(doc): + continue + + try: + content = storage.get(doc.object_store_url) + if doc.file_size_kb is None: + doc.file_size_kb = round(len(content) / 1024, 2) + f_obj = BytesIO(content) + f_obj.name = doc.fname + uploaded = self.client.files.create(file=f_obj, purpose="assistants") + except Exception as err: + logger.error( + "[OpenAIProvider.upload_files] Failed to upload file | doc_id=%s, error=%s", + doc.id, + str(err), + exc_info=True, + ) + raise + + doc.openai_file_id = uploaded.id + try: + with Session(engine) as session: + document_crud = DocumentCrud(session, project_id) + db_doc = document_crud.read_one(doc.id) + db_doc.openai_file_id = uploaded.id + db_doc.file_size_kb = doc.file_size_kb + document_crud.update(db_doc) + except Exception as err: + logger.error( + "[OpenAIProvider.upload_files] DB persistence failed, rolling back OpenAI file | " + "doc_id=%s, openai_file_id=%s, error=%s", + doc.id, + uploaded.id, + str(err), + exc_info=True, + ) + try: + self.client.files.delete(uploaded.id) + logger.info( + "[OpenAIProvider.upload_files] Rolled back OpenAI file | " + "doc_id=%s, openai_file_id=%s", + doc.id, + uploaded.id, + ) + except Exception as delete_err: + logger.error( + "[OpenAIProvider.upload_files] Rollback failed, file is orphaned | " + "doc_id=%s, openai_file_id=%s, error=%s", + doc.id, + uploaded.id, + str(delete_err), + ) + doc.openai_file_id = None + raise + + def create( + self, + docs: List[Document], + vector_store_id: str | None = None, ) -> Collection: - """ - Create OpenAI vector store with documents and optionally an assistant. - docs_batches must be pre-fetched inside a DB session before this call. - """ try: - docs_batches = batch_documents(documents) vector_store_crud = OpenAIVectorStoreCrud(self.client) - vector_store = vector_store_crud.create() - - list(vector_store_crud.update(vector_store.id, storage, docs_batches)) - - logger.info( - "[OpenAIProvider.create] Vector store created | " - f"vector_store_id={vector_store.id}, batches={len(docs_batches)}" - ) - - # Check if we need to create an assistant (based on assistant options in request) - with_assistant = ( - collection_request.model is not None - and collection_request.instructions is not None - ) - if with_assistant: - assistant_crud = OpenAIAssistantCrud(self.client) - - assistant_options = { - "model": collection_request.model, - "instructions": collection_request.instructions, - "temperature": collection_request.temperature, - } - filtered_options = { - k: v for k, v in assistant_options.items() if v is not None - } - assistant = assistant_crud.create(vector_store.id, **filtered_options) + if vector_store_id is None: + vector_store = vector_store_crud.create() + vector_store_id = vector_store.id + if docs: + vector_store_crud.update(vector_store_id, docs) logger.info( - "[OpenAIProvider.create] Assistant created | " - f"assistant_id={assistant.id}, vector_store_id={vector_store.id}" + "[OpenAIProvider.create] Batch uploaded | vector_store_id=%s, doc_count=%d", + vector_store_id, + len(docs), ) - return Collection( - llm_service_id=assistant.id, - llm_service_name=filtered_options.get("model", "assistant"), - ) - else: - logger.info( - "[OpenAIProvider.create] Skipping assistant creation | with_assistant=False" - ) - - return Collection( - llm_service_id=vector_store.id, - llm_service_name=get_service_name("openai"), - ) + return Collection( + llm_service_id=vector_store_id, + llm_service_name=get_service_name("openai"), + ) except Exception as e: logger.error( @@ -88,26 +122,14 @@ def create( raise def delete(self, collection: Collection) -> None: - """Delete OpenAI resources (assistant or vector store). - - Determines what to delete based on llm_service_name: - - If assistant was created, delete the assistant (which also removes the vector store) - - If only vector store was created, delete the vector store - """ try: - if collection.llm_service_name != get_service_name("openai"): - OpenAIAssistantCrud(self.client).delete(collection.llm_service_id) - logger.info( - f"[OpenAIProvider.delete] Deleted assistant | assistant_id={collection.llm_service_id}" - ) - else: - OpenAIVectorStoreCrud(self.client).delete(collection.llm_service_id) - logger.info( - f"[OpenAIProvider.delete] Deleted vector store | vector_store_id={collection.llm_service_id}" - ) + OpenAIVectorStoreCrud(self.client).delete(collection.llm_service_id) + logger.info( + f"[OpenAIProvider.delete] Deleted vector store | vector_store_id={collection.llm_service_id}" + ) except Exception as e: logger.error( - f"[OpenAIProvider.delete] Failed to delete resource | " + f"[OpenAIProvider.delete] Failed to delete vector store | " f"llm_service_id={collection.llm_service_id}, error={str(e)}", exc_info=True, ) diff --git a/backend/app/tests/api/routes/collections/test_collection_delete.py b/backend/app/tests/api/routes/collections/test_collection_delete.py index f7ed400d9..441bca388 100644 --- a/backend/app/tests/api/routes/collections/test_collection_delete.py +++ b/backend/app/tests/api/routes/collections/test_collection_delete.py @@ -9,7 +9,7 @@ from app.tests.utils.auth import TestAuthContext from app.models import CollectionJobStatus from app.tests.utils.utils import get_project -from app.tests.utils.collection import get_assistant_collection +from app.tests.utils.collection import get_vector_store_collection @patch("app.api.routes.collections.delete_service.start_job") @@ -28,7 +28,7 @@ def test_delete_collection_calls_start_job_and_returns_job( - Calls delete_service.start_job with correct arguments """ project = get_project(db, "Dalgo") - collection = get_assistant_collection(db, project) + collection = get_vector_store_collection(db, project) resp = client.request( "DELETE", @@ -72,7 +72,7 @@ def test_delete_collection_with_callback_url_passes_it_to_start_job( into the DeletionRequest and then into delete_service.start_job. """ project = get_project(db, "Dalgo") - collection = get_assistant_collection(db, project) + collection = get_vector_store_collection(db, project) payload = { "callback_url": "https://example.com/collections/delete-callback", diff --git a/backend/app/tests/api/routes/collections/test_collection_info.py b/backend/app/tests/api/routes/collections/test_collection_info.py index f41623a54..73467251b 100644 --- a/backend/app/tests/api/routes/collections/test_collection_info.py +++ b/backend/app/tests/api/routes/collections/test_collection_info.py @@ -6,10 +6,7 @@ from app.core.config import settings from app.tests.utils.utils import get_project, get_document -from app.tests.utils.collection import ( - get_assistant_collection, - get_vector_store_collection, -) +from app.tests.utils.collection import get_vector_store_collection from app.crud import DocumentCollectionCrud from app.models import Collection, Document from app.services.collections.helpers import get_service_name @@ -37,20 +34,20 @@ def link_document_to_collection( return document -def test_collection_info_returns_assistant_collection_with_docs( +def test_collection_info_returns_collection_with_docs( client: TestClient, db: Session, user_api_key_header: dict[str, str], ) -> None: """ Happy path: - - Assistant-style collection (get_assistant_collection) + - Vector store collection - include_docs = True (default) - At least one document linked """ project = get_project(db, "Dalgo") - collection = get_assistant_collection(db, project) + collection = get_vector_store_collection(db, project) document = link_document_to_collection(db, collection) @@ -86,7 +83,7 @@ def test_collection_info_include_docs_false_returns_no_docs( When include_docs=false, the endpoint should not populate the documents list. """ project = get_project(db, "Dalgo") - collection = get_assistant_collection(db, project) + collection = get_vector_store_collection(db, project) link_document_to_collection(db, collection) @@ -102,8 +99,8 @@ def test_collection_info_include_docs_false_returns_no_docs( payload = data["data"] assert payload["id"] == str(collection.id) - assert payload["llm_service_name"] == "gpt-4o" - assert payload["llm_service_id"] == collection.llm_service_id + assert payload["knowledge_base_provider"] == collection.llm_service_name + assert payload["knowledge_base_id"] == collection.llm_service_id assert payload["documents"] is None @@ -117,7 +114,7 @@ def test_collection_info_pagination_skip_and_limit( We create multiple document links and then request a paginated slice. """ project = get_project(db, "Dalgo") - collection = get_assistant_collection(db, project) + collection = get_vector_store_collection(db, project) documents = db.exec( select(Document).where(Document.deleted_at.is_(None)).limit(2) @@ -205,7 +202,7 @@ def test_collection_info_include_docs_and_url( the endpoint returns documents with their URLs. """ project = get_project(db, "Dalgo") - collection = get_assistant_collection(db, project) + collection = get_vector_store_collection(db, project) document = link_document_to_collection(db, collection) diff --git a/backend/app/tests/api/routes/collections/test_collection_job_info.py b/backend/app/tests/api/routes/collections/test_collection_job_info.py index ad95b6769..1c7f3855c 100644 --- a/backend/app/tests/api/routes/collections/test_collection_job_info.py +++ b/backend/app/tests/api/routes/collections/test_collection_job_info.py @@ -5,7 +5,7 @@ from app.core.config import settings from app.tests.utils.utils import get_project -from app.tests.utils.collection import get_assistant_collection, get_collection_job +from app.tests.utils.collection import get_vector_store_collection, get_collection_job from app.models import ( CollectionActionType, CollectionJobStatus, @@ -41,7 +41,7 @@ def test_collection_info_create_successful( headers = user_api_key_header project = get_project(db, "Dalgo") - collection = get_assistant_collection(db, project) + collection = get_vector_store_collection(db, project) collection_job = get_collection_job( db, project, collection_id=collection.id, status=CollectionJobStatus.SUCCESSFUL @@ -61,8 +61,8 @@ def test_collection_info_create_successful( assert data["collection"] is not None col = data["collection"] assert col["id"] == str(collection.id) - assert col["llm_service_id"] == collection.llm_service_id - assert col["llm_service_name"] == "gpt-4o" + assert col["knowledge_base_id"] == collection.llm_service_id + assert col["knowledge_base_provider"] == collection.llm_service_name def test_collection_info_create_failed( @@ -101,7 +101,7 @@ def test_collection_info_delete_successful( headers = user_api_key_header project = get_project(db, "Dalgo") - collection = get_assistant_collection(db, project) + collection = get_vector_store_collection(db, project) collection_job = get_collection_job( db, @@ -133,7 +133,7 @@ def test_collection_info_delete_failed( headers = user_api_key_header project = get_project(db, "Dalgo") - collection = get_assistant_collection(db, project) + collection = get_vector_store_collection(db, project) collection_job = get_collection_job( db, diff --git a/backend/app/tests/api/routes/collections/test_collection_list.py b/backend/app/tests/api/routes/collections/test_collection_list.py index e1cf3d589..2f746117b 100644 --- a/backend/app/tests/api/routes/collections/test_collection_list.py +++ b/backend/app/tests/api/routes/collections/test_collection_list.py @@ -3,10 +3,7 @@ from app.core.config import settings from app.tests.utils.utils import get_project -from app.tests.utils.collection import ( - get_assistant_collection, - get_vector_store_collection, -) +from app.tests.utils.collection import get_vector_store_collection from app.services.collections.helpers import get_service_name @@ -34,14 +31,13 @@ def test_list_collections_returns_api_response( assert isinstance(data["data"], list) -def test_list_collections_includes_assistant_collection( +def test_list_collections_includes_new_collection( db: Session, client: TestClient, user_api_key_header: dict[str, str], ) -> None: """ - Ensure that a newly created assistant-style collection (get_assistant_collection) - appears in the list for the current project. + Ensure that a newly created collection appears in the list for the current project. """ project = get_project(db, "Dalgo") @@ -52,7 +48,7 @@ def test_list_collections_includes_assistant_collection( ) assert response_before.status_code == 200 - collection = get_assistant_collection(db, project) + collection = get_vector_store_collection(db, project) response_after = client.get( f"{settings.API_V1_STR}/collections", diff --git a/backend/app/tests/api/routes/collections/test_create_collections.py b/backend/app/tests/api/routes/collections/test_create_collections.py index b51631939..5934394b3 100644 --- a/backend/app/tests/api/routes/collections/test_create_collections.py +++ b/backend/app/tests/api/routes/collections/test_create_collections.py @@ -11,14 +11,9 @@ from app.models.collection import CreationRequest -def _extract_metadata(body: dict) -> dict | None: - return body.get("metadata") or body.get("meta") - - def _create_test_document( db: Session, project_id: int, file_size: float = 1 ) -> Document: - """Helper to create a test document.""" doc = Document( id=uuid4(), fname="test_document.txt", @@ -33,20 +28,16 @@ def _create_test_document( @patch("app.api.routes.collections.create_service.start_job") -def test_collection_creation_with_assistant_calls_start_job_and_returns_job( +def test_collection_creation_calls_start_job_and_returns_job( mock_start_job: Any, client: TestClient, user_api_key_header: dict[str, str], user_api_key: TestAuthContext, db: Session, ) -> None: - # Create a test document in the database doc = _create_test_document(db, user_api_key.project_id, file_size=2) creation_data = CreationRequest( - model="gpt-4o", - instructions="string", - temperature=0.000001, documents=[doc.id], callback_url=None, ) @@ -65,14 +56,12 @@ def test_collection_creation_with_assistant_calls_start_job_and_returns_job( assert data["job_inserted_at"] assert data["job_updated_at"] - assert _extract_metadata(body) in (None, {}) - mock_start_job.assert_called_once() kwargs = mock_start_job.call_args.kwargs assert "db" in kwargs assert kwargs["project_id"] == user_api_key.project_id assert kwargs["organization_id"] == user_api_key.organization_id - assert kwargs["with_assistant"] is True + assert "with_assistant" not in kwargs returned_job_id = UUID(data["job_id"]) assert kwargs["collection_job_id"] == returned_job_id @@ -80,82 +69,3 @@ def test_collection_creation_with_assistant_calls_start_job_and_returns_job( assert kwargs["request"].model_dump(mode="json") == creation_data.model_dump( mode="json" ) - - -@patch("app.api.routes.collections.create_service.start_job") -def test_collection_creation_vector_only_adds_metadata_and_sets_with_assistant_false( - mock_start_job: Any, - client: TestClient, - user_api_key_header: dict[str, str], - user_api_key: TestAuthContext, - db: Session, -) -> None: - # Create a test document in the database - doc = _create_test_document(db, user_api_key.project_id, file_size=5) - - creation_data = CreationRequest( - temperature=0.000001, - documents=[doc.id], - callback_url=None, - ) - - resp = client.post( - f"{settings.API_V1_STR}/collections", - json=creation_data.model_dump(mode="json"), - headers=user_api_key_header, - ) - - assert resp.status_code == 200 - body = resp.json() - - data = body["data"] - assert data["status"] == CollectionJobStatus.PENDING - - meta = _extract_metadata(body) - assert isinstance(meta, dict) - assert "vector store only" in meta.get("note", "").lower() - - mock_start_job.assert_called_once() - kwargs = mock_start_job.call_args.kwargs - assert kwargs["project_id"] == user_api_key.project_id - assert kwargs["organization_id"] == user_api_key.organization_id - assert kwargs["with_assistant"] is False - assert kwargs["collection_job_id"] == UUID(data["job_id"]) - assert kwargs["request"].model_dump(mode="json") == creation_data.model_dump( - mode="json" - ) - - -def test_collection_creation_vector_only_request_validation_error( - client: TestClient, - user_api_key_header: dict[str, str], - user_api_key: TestAuthContext, - db: Session, -) -> None: - # Create a test document in the database - doc = _create_test_document(db, user_api_key.project_id) - - payload = { - "model": "gpt-4o", - "temperature": 0.000001, - "documents": [str(doc.id)], - "callback_url": None, - } - - resp = client.post( - f"{settings.API_V1_STR}/collections", - json=payload, - headers=user_api_key_header, - ) - - assert resp.status_code == 422 - body = resp.json() - assert body["success"] is False - assert body["data"] is None - assert body["metadata"] is None - assert body["errors"] - assert any( - "To create an Assistant, provide BOTH 'model' and 'instructions'" - in e["message"] - for e in body["errors"] - ) diff --git a/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py b/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py index c0751fe72..64abc4385 100644 --- a/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py +++ b/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py @@ -4,29 +4,17 @@ from app.crud import CollectionCrud from app.models import APIKey, Collection, ProviderType -from app.crud.rag import OpenAIAssistantCrud +from app.crud.rag import OpenAIVectorStoreCrud from app.tests.utils.utils import get_project from app.tests.utils.document import DocumentStore -def get_assistant_collection_for_delete( - db: Session, client=None, project_id: int = None -) -> Collection: - project = get_project(db) - if client is None: - client = OpenAI(api_key="test_api_key") - +def get_vector_store_collection(client: OpenAI, project_id: int) -> Collection: vector_store = client.vector_stores.create() - assistant = client.beta.assistants.create( - model="gpt-4o", - tools=[{"type": "file_search"}], - tool_resources={"file_search": {"vector_store_ids": [vector_store.id]}}, - ) - return Collection( project_id=project_id, - llm_service_id=assistant.id, - llm_service_name="gpt-4o", + llm_service_id=vector_store.id, + llm_service_name="openai vector store", provider=ProviderType.openai, ) @@ -39,26 +27,24 @@ def test_delete_marks_deleted(self, db: Session) -> None: project = get_project(db) client = OpenAI(api_key="sk-test-key") - assistant = OpenAIAssistantCrud(client) - collection = get_assistant_collection_for_delete( - db, client, project_id=project.id - ) + v_crud = OpenAIVectorStoreCrud(client) + collection = get_vector_store_collection(client, project_id=project.id) crud = CollectionCrud(db, collection.project_id) - collection_ = crud.delete(collection, assistant) + collection_ = crud.delete(collection, v_crud) assert collection_.deleted_at is not None @openai_responses.mock() def test_delete_follows_insert(self, db: Session) -> None: + project = get_project(db) client = OpenAI(api_key="sk-test-key") - assistant = OpenAIAssistantCrud(client) - project = get_project(db) - collection = get_assistant_collection_for_delete(db, project_id=project.id) + v_crud = OpenAIVectorStoreCrud(client) + collection = get_vector_store_collection(client, project_id=project.id) crud = CollectionCrud(db, collection.project_id) - collection_ = crud.delete(collection, assistant) + collection_ = crud.delete(collection, v_crud) assert collection_.inserted_at <= collection_.deleted_at @@ -76,15 +62,13 @@ def test_delete_document_deletes_collections(self, db: Session) -> None: client = OpenAI(api_key="sk-test-key") resources = [] for _ in range(self._n_collections): - coll = get_assistant_collection_for_delete( - db, client, project_id=project.id - ) + coll = get_vector_store_collection(client, project_id=project.id) crud = CollectionCrud(db, project_id=project.id) collection = crud.create(coll, documents) resources.append((crud, collection)) ((crud, _), *_) = resources - assistant = OpenAIAssistantCrud(client) - crud.delete(documents[0], assistant) + v_crud = OpenAIVectorStoreCrud(client) + crud.delete(documents[0], v_crud) assert all(y.deleted_at for (_, y) in resources) diff --git a/backend/app/tests/crud/collections/collection/test_crud_collection_read_all.py b/backend/app/tests/crud/collections/collection/test_crud_collection_read_all.py index 0382b8830..56cde1ca2 100644 --- a/backend/app/tests/crud/collections/collection/test_crud_collection_read_all.py +++ b/backend/app/tests/crud/collections/collection/test_crud_collection_read_all.py @@ -1,29 +1,24 @@ -from openai_responses import OpenAIMock -from openai import OpenAI from sqlmodel import Session, delete from app.crud import CollectionCrud from app.models import Collection from app.tests.utils.document import DocumentStore from app.tests.utils.utils import get_project -from app.tests.utils.collection import get_assistant_collection +from app.tests.utils.collection import get_vector_store_collection -def create_collections(db: Session, n: int) -> Collection: +def create_collections(db: Session, n: int) -> int: crud = None project = get_project(db) - openai_mock = OpenAIMock() - with openai_mock.router: - client = OpenAI(api_key="sk-test-key") - for _ in range(n): - collection = get_assistant_collection(db, project=project) - store = DocumentStore(db, project_id=collection.project_id) - documents = store.fill(1) - if crud is None: - crud = CollectionCrud(db, collection.project_id) - crud.create(collection, documents) - - return crud.project_id + for _ in range(n): + collection = get_vector_store_collection(db, project=project) + store = DocumentStore(db, project_id=collection.project_id) + documents = store.fill(1) + if crud is None: + crud = CollectionCrud(db, collection.project_id) + crud.create(collection, documents) + + return crud.project_id class TestCollectionReadAll: diff --git a/backend/app/tests/crud/collections/collection/test_crud_collection_read_one.py b/backend/app/tests/crud/collections/collection/test_crud_collection_read_one.py index 2fc5f7676..a85b06ccb 100644 --- a/backend/app/tests/crud/collections/collection/test_crud_collection_read_one.py +++ b/backend/app/tests/crud/collections/collection/test_crud_collection_read_one.py @@ -1,6 +1,4 @@ import pytest -from openai import OpenAI -from openai_responses import OpenAIMock from fastapi import HTTPException from sqlmodel import Session @@ -8,18 +6,16 @@ from app.models import Collection from app.tests.utils.document import DocumentStore from app.tests.utils.utils import get_project -from app.tests.utils.collection import get_assistant_collection +from app.tests.utils.collection import get_vector_store_collection def mk_collection(db: Session) -> Collection: - openai_mock = OpenAIMock() project = get_project(db) - with openai_mock.router: - collection = get_assistant_collection(db, project=project) - store = DocumentStore(db, project_id=collection.project_id) - documents = store.fill(1) - crud = CollectionCrud(db, collection.project_id) - return crud.create(collection, documents) + collection = get_vector_store_collection(db, project=project) + store = DocumentStore(db, project_id=collection.project_id) + documents = store.fill(1) + crud = CollectionCrud(db, collection.project_id) + return crud.create(collection, documents) class TestDatabaseReadOne: diff --git a/backend/app/tests/services/collections/providers/test_openai_provider.py b/backend/app/tests/services/collections/providers/test_openai_provider.py index b21577d49..cedf5ed2f 100644 --- a/backend/app/tests/services/collections/providers/test_openai_provider.py +++ b/backend/app/tests/services/collections/providers/test_openai_provider.py @@ -1,8 +1,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch +from uuid import uuid4 import pytest +from openai import OpenAIError +from app.crud.rag.open_ai import OpenAIVectorStoreCrud from app.services.collections.providers.openai import OpenAIProvider from app.models.collection import Collection from app.services.collections.helpers import get_service_name @@ -12,18 +15,10 @@ ) -def test_create_openai_vector_store_only() -> None: +def test_create_openai_vector_store() -> None: client = get_mock_openai_client_with_vector_store() provider = OpenAIProvider(client=client) - collection_request = SimpleNamespace( - documents=["doc1", "doc2"], - model=None, - instructions=None, - temperature=None, - ) - - storage = MagicMock() documents = [ SimpleNamespace(file_size_kb=10), SimpleNamespace(file_size_kb=20), @@ -35,104 +30,430 @@ def test_create_openai_vector_store_only() -> None: ) as vector_store_crud_cls: vector_store_crud = vector_store_crud_cls.return_value vector_store_crud.create.return_value = MagicMock(id=vector_store_id) - vector_store_crud.update.return_value = iter([None]) + vector_store_crud.update.return_value = None - collection = provider.create( - collection_request, - storage, - documents, - ) + collection = provider.create(documents) assert isinstance(collection, Collection) assert collection.llm_service_id == vector_store_id assert collection.llm_service_name == get_service_name("openai") -def test_create_openai_with_assistant() -> None: - client = get_mock_openai_client_with_vector_store() +def test_delete_openai_vector_store() -> None: + client = MagicMock() provider = OpenAIProvider(client=client) - collection_request = SimpleNamespace( - documents=["doc1"], - model="gpt-4o", - instructions="You are helpful", - temperature=0.7, + collection = Collection( + llm_service_id=generate_openai_id("vs_"), + llm_service_name=get_service_name("openai"), ) - storage = MagicMock() - documents = [SimpleNamespace(file_size_kb=10)] - vector_store_id = generate_openai_id("vs_") - assistant_id = generate_openai_id("asst_") - with patch( "app.services.collections.providers.openai.OpenAIVectorStoreCrud" - ) as vector_store_crud_cls, patch( - "app.services.collections.providers.openai.OpenAIAssistantCrud" - ) as assistant_crud_cls: + ) as vector_store_crud_cls: vector_store_crud = vector_store_crud_cls.return_value - vector_store_crud.create.return_value = MagicMock(id=vector_store_id) - vector_store_crud.update.return_value = iter([None]) + provider.delete(collection) + + vector_store_crud.delete.assert_called_once_with(collection.llm_service_id) - assistant_crud = assistant_crud_cls.return_value - assistant_crud.create.return_value = MagicMock(id=assistant_id) - collection = provider.create( - collection_request, - storage, - documents, - ) +# --------------------------------------------------------------------------- +# upload_files +# --------------------------------------------------------------------------- - assert collection.llm_service_id == assistant_id - assert collection.llm_service_name == "gpt-4o" +def _make_doc(*, openai_file_id=None, file_size_kb=None): + return SimpleNamespace( + id=uuid4(), + fname="test.md", + object_store_url="s3://bucket/test.md", + openai_file_id=openai_file_id, + file_size_kb=file_size_kb, + ) -def test_delete_openai_assistant() -> None: + +def _patch_session_and_crud(): + """Patches Session and DocumentCrud used inside upload_files.""" + session_patcher = patch("app.services.collections.providers.openai.Session") + crud_patcher = patch("app.services.collections.providers.openai.DocumentCrud") + return session_patcher, crud_patcher + + +def test_upload_files_skips_doc_with_existing_openai_file_id() -> None: client = MagicMock() provider = OpenAIProvider(client=client) + storage = MagicMock() + doc = _make_doc(openai_file_id="file-already-exists", file_size_kb=10.0) - collection = Collection( - llm_service_id=generate_openai_id("asst_"), - llm_service_name="gpt-4o", - provider="openai", - project_id=1, + session_p, crud_p = _patch_session_and_crud() + with session_p, crud_p: + provider.upload_files(storage, [doc], project_id=1) + + storage.get.assert_not_called() + client.files.create.assert_not_called() + + +def test_upload_files_uploads_doc_and_sets_openai_file_id() -> None: + client = MagicMock() + client.files.create.return_value = MagicMock(id="file-new-abc") + provider = OpenAIProvider(client=client) + + storage = MagicMock() + storage.get.return_value = b"file content" + + doc = _make_doc(file_size_kb=10.0) + + mock_crud = MagicMock() + session_p, crud_p = _patch_session_and_crud() + with session_p as MockSession, crud_p as MockDocCrud: + MockSession.return_value.__enter__.return_value = MagicMock() + MockSession.return_value.__exit__.return_value = False + MockDocCrud.return_value = mock_crud + + provider.upload_files(storage, [doc], project_id=1) + + assert doc.openai_file_id == "file-new-abc" + client.files.create.assert_called_once() + _, kwargs = client.files.create.call_args + assert kwargs.get("purpose") == "assistants" + mock_crud.update.assert_called_once() + + +def test_upload_files_sets_file_size_kb_when_none() -> None: + """file_size_kb should be computed from content length if not already set.""" + client = MagicMock() + client.files.create.return_value = MagicMock(id="file-xyz") + provider = OpenAIProvider(client=client) + + content = b"x" * 2048 # 2 KB + storage = MagicMock() + storage.get.return_value = content + + doc = _make_doc(file_size_kb=None) + + session_p, crud_p = _patch_session_and_crud() + with session_p as MockSession, crud_p: + MockSession.return_value.__enter__.return_value = MagicMock() + MockSession.return_value.__exit__.return_value = False + + provider.upload_files(storage, [doc], project_id=1) + + assert doc.file_size_kb == round(len(content) / 1024, 2) + + +def test_upload_files_preserves_existing_file_size_kb() -> None: + """file_size_kb should not be overwritten if already set.""" + client = MagicMock() + client.files.create.return_value = MagicMock(id="file-xyz") + provider = OpenAIProvider(client=client) + + storage = MagicMock() + storage.get.return_value = b"x" * 4096 + + doc = _make_doc(file_size_kb=99.0) + + session_p, crud_p = _patch_session_and_crud() + with session_p as MockSession, crud_p: + MockSession.return_value.__enter__.return_value = MagicMock() + MockSession.return_value.__exit__.return_value = False + + provider.upload_files(storage, [doc], project_id=1) + + assert doc.file_size_kb == 99.0 + + +def test_upload_files_updates_db_with_file_id_and_size() -> None: + client = MagicMock() + client.files.create.return_value = MagicMock(id="file-db-check") + provider = OpenAIProvider(client=client) + + storage = MagicMock() + storage.get.return_value = b"content" + + doc = _make_doc(file_size_kb=5.0) + mock_db_doc = MagicMock() + mock_crud = MagicMock() + mock_crud.read_one.return_value = mock_db_doc + + session_p, crud_p = _patch_session_and_crud() + with session_p as MockSession, crud_p as MockDocCrud: + MockSession.return_value.__enter__.return_value = MagicMock() + MockSession.return_value.__exit__.return_value = False + MockDocCrud.return_value = mock_crud + + provider.upload_files(storage, [doc], project_id=42) + + MockDocCrud.assert_called_once_with( + MockSession.return_value.__enter__.return_value, 42 ) + mock_crud.read_one.assert_called_once_with(doc.id) + assert mock_db_doc.openai_file_id == "file-db-check" + assert mock_db_doc.file_size_kb == 5.0 + mock_crud.update.assert_called_once_with(mock_db_doc) - with patch( - "app.services.collections.providers.openai.OpenAIAssistantCrud" - ) as assistant_crud_cls: - assistant_crud = assistant_crud_cls.return_value - provider.delete(collection) - assistant_crud.delete.assert_called_once_with(collection.llm_service_id) +def test_upload_files_raises_on_storage_failure() -> None: + client = MagicMock() + provider = OpenAIProvider(client=client) + + storage = MagicMock() + storage.get.side_effect = RuntimeError("S3 error") + + doc = _make_doc() + session_p, crud_p = _patch_session_and_crud() + with session_p, crud_p: + with pytest.raises(RuntimeError, match="S3 error"): + provider.upload_files(storage, [doc], project_id=1) -def test_delete_openai_vector_store() -> None: + client.files.create.assert_not_called() + + +def test_upload_files_raises_on_openai_failure() -> None: client = MagicMock() + client.files.create.side_effect = RuntimeError("OpenAI error") provider = OpenAIProvider(client=client) - collection = Collection( - llm_service_id=generate_openai_id("vs_"), - llm_service_name=get_service_name("openai"), + storage = MagicMock() + storage.get.return_value = b"content" + + doc = _make_doc() + + session_p, crud_p = _patch_session_and_crud() + with session_p, crud_p: + with pytest.raises(RuntimeError, match="OpenAI error"): + provider.upload_files(storage, [doc], project_id=1) + + +def test_upload_files_mixed_skips_uploaded_uploads_new() -> None: + """Docs with openai_file_id are skipped; others are uploaded.""" + client = MagicMock() + client.files.create.return_value = MagicMock(id="file-new") + provider = OpenAIProvider(client=client) + + storage = MagicMock() + storage.get.return_value = b"content" + + already_uploaded = _make_doc(openai_file_id="file-exists", file_size_kb=5.0) + new_doc = _make_doc(file_size_kb=5.0) + + session_p, crud_p = _patch_session_and_crud() + with session_p as MockSession, crud_p: + MockSession.return_value.__enter__.return_value = MagicMock() + MockSession.return_value.__exit__.return_value = False + + provider.upload_files(storage, [already_uploaded, new_doc], project_id=1) + + assert already_uploaded.openai_file_id == "file-exists" + assert new_doc.openai_file_id == "file-new" + client.files.create.assert_called_once() + storage.get.assert_called_once_with(new_doc.object_store_url) + + +def test_upload_files_empty_docs_is_noop() -> None: + client = MagicMock() + provider = OpenAIProvider(client=client) + storage = MagicMock() + + session_p, crud_p = _patch_session_and_crud() + with session_p, crud_p: + provider.upload_files(storage, [], project_id=1) + + storage.get.assert_not_called() + client.files.create.assert_not_called() + + +def test_upload_files_file_object_name_matches_doc_fname() -> None: + """The BytesIO passed to OpenAI must carry the original filename.""" + from io import BytesIO + + client = MagicMock() + client.files.create.return_value = MagicMock(id="file-abc") + provider = OpenAIProvider(client=client) + + storage = MagicMock() + storage.get.return_value = b"data" + + doc = _make_doc(file_size_kb=1.0) + doc.fname = "report.pdf" + + session_p, crud_p = _patch_session_and_crud() + with session_p as MockSession, crud_p: + MockSession.return_value.__enter__.return_value = MagicMock() + MockSession.return_value.__exit__.return_value = False + provider.upload_files(storage, [doc], project_id=1) + + _, kwargs = client.files.create.call_args + f_obj = kwargs["file"] + assert isinstance(f_obj, BytesIO) + assert f_obj.name == "report.pdf" + + +def test_upload_files_raises_on_db_update_failure() -> None: + client = MagicMock() + client.files.create.return_value = MagicMock(id="file-ok") + provider = OpenAIProvider(client=client) + + storage = MagicMock() + storage.get.return_value = b"content" + + doc = _make_doc(file_size_kb=1.0) + mock_crud = MagicMock() + mock_crud.read_one.return_value = MagicMock() + mock_crud.update.side_effect = RuntimeError("DB write failed") + + session_p, crud_p = _patch_session_and_crud() + with session_p as MockSession, crud_p as MockDocCrud: + MockSession.return_value.__enter__.return_value = MagicMock() + MockSession.return_value.__exit__.return_value = False + MockDocCrud.return_value = mock_crud + + with pytest.raises(RuntimeError, match="DB write failed"): + provider.upload_files(storage, [doc], project_id=1) + + client.files.delete.assert_called_once_with("file-ok") + assert doc.openai_file_id is None + + +def test_upload_files_db_failure_rollback_delete_error_still_raises_original() -> None: + """If both DB persistence and the rollback delete fail, the original DB error propagates.""" + client = MagicMock() + client.files.create.return_value = MagicMock(id="file-ok") + client.files.delete.side_effect = RuntimeError("delete failed") + provider = OpenAIProvider(client=client) + + storage = MagicMock() + storage.get.return_value = b"content" + + doc = _make_doc(file_size_kb=1.0) + mock_crud = MagicMock() + mock_crud.read_one.return_value = MagicMock() + mock_crud.update.side_effect = RuntimeError("DB write failed") + + session_p, crud_p = _patch_session_and_crud() + with session_p as MockSession, crud_p as MockDocCrud: + MockSession.return_value.__enter__.return_value = MagicMock() + MockSession.return_value.__exit__.return_value = False + MockDocCrud.return_value = mock_crud + + with pytest.raises(RuntimeError, match="DB write failed"): + provider.upload_files(storage, [doc], project_id=1) + + client.files.delete.assert_called_once_with("file-ok") + assert doc.openai_file_id is None + + +def test_upload_files_first_failure_stops_remaining_docs() -> None: + """If the first doc raises, subsequent docs are never attempted.""" + client = MagicMock() + client.files.create.side_effect = RuntimeError("quota exceeded") + provider = OpenAIProvider(client=client) + + storage = MagicMock() + storage.get.return_value = b"content" + + doc1 = _make_doc(file_size_kb=1.0) + doc2 = _make_doc(file_size_kb=1.0) + + session_p, crud_p = _patch_session_and_crud() + with session_p, crud_p: + with pytest.raises(RuntimeError, match="quota exceeded"): + provider.upload_files(storage, [doc1, doc2], project_id=1) + + client.files.create.assert_called_once() + assert storage.get.call_count == 1 + + +# --------------------------------------------------------------------------- +# OpenAIVectorStoreCrud.update +# --------------------------------------------------------------------------- + + +def _make_batch(completed: int, failed: int) -> MagicMock: + batch = MagicMock() + batch.file_counts.completed = completed + batch.file_counts.failed = failed + return batch + + +def _make_openai_doc(file_id: str = "file-abc") -> MagicMock: + doc = MagicMock() + doc.openai_file_id = file_id + return doc + + +def test_vector_store_update_skips_when_no_docs() -> None: + client = MagicMock() + crud = OpenAIVectorStoreCrud(client) + crud.update("vs_123", []) + client.vector_stores.file_batches.upload_and_poll.assert_not_called() + + +def test_vector_store_update_succeeds_with_no_failures() -> None: + client = MagicMock() + client.vector_stores.file_batches.upload_and_poll.return_value = _make_batch( + completed=3, failed=0 ) + crud = OpenAIVectorStoreCrud(client) + crud.update("vs_123", [_make_openai_doc() for _ in range(3)]) + client.vector_stores.file_batches.upload_and_poll.assert_called_once() - with patch( - "app.services.collections.providers.openai.OpenAIVectorStoreCrud" - ) as vector_store_crud_cls: - vector_store_crud = vector_store_crud_cls.return_value - provider.delete(collection) - vector_store_crud.delete.assert_called_once_with(collection.llm_service_id) +def test_vector_store_update_raises_on_openai_error() -> None: + client = MagicMock() + client.vector_stores.file_batches.upload_and_poll.side_effect = OpenAIError( + "rate limit" + ) + crud = OpenAIVectorStoreCrud(client) + with pytest.raises(OpenAIError, match="rate limit"): + crud.update("vs_123", [_make_openai_doc()]) -def test_create_propagates_exception() -> None: - provider = OpenAIProvider(client=MagicMock()) - collection_request = SimpleNamespace( - documents=["doc1"], - model=None, - instructions=None, - temperature=None, +def test_vector_store_update_raises_on_partial_failures() -> None: + client = MagicMock() + client.vector_stores.file_batches.upload_and_poll.return_value = _make_batch( + completed=2, failed=1 + ) + crud = OpenAIVectorStoreCrud(client) + + with pytest.raises(RuntimeError, match="1 failed file"): + crud.update("vs_123", [_make_openai_doc() for _ in range(3)]) + + +def test_vector_store_update_raises_on_all_failures() -> None: + client = MagicMock() + client.vector_stores.file_batches.upload_and_poll.return_value = _make_batch( + completed=0, failed=2 + ) + crud = OpenAIVectorStoreCrud(client) + + with pytest.raises(RuntimeError, match="2 failed file"): + crud.update("vs_123", [_make_openai_doc() for _ in range(2)]) + + +def test_vector_store_update_passes_file_ids_to_openai() -> None: + client = MagicMock() + client.vector_stores.file_batches.upload_and_poll.return_value = _make_batch( + completed=2, failed=0 ) + crud = OpenAIVectorStoreCrud(client) + docs = [_make_openai_doc("file-1"), _make_openai_doc("file-2")] + + crud.update("vs_abc", docs) + + _, kwargs = client.vector_stores.file_batches.upload_and_poll.call_args + assert kwargs["vector_store_id"] == "vs_abc" + assert kwargs["file_ids"] == ["file-1", "file-2"] + + +# --------------------------------------------------------------------------- +# create (existing tests below) +# --------------------------------------------------------------------------- + + +def test_create_propagates_exception() -> None: + provider = OpenAIProvider(client=MagicMock()) with patch( "app.services.collections.providers.openai.OpenAIVectorStoreCrud" @@ -140,8 +461,4 @@ def test_create_propagates_exception() -> None: vector_store_crud_cls.return_value.create.side_effect = RuntimeError("boom") with pytest.raises(RuntimeError): - provider.create( - collection_request, - MagicMock(), - [SimpleNamespace(file_size_kb=10)], - ) + provider.create([SimpleNamespace(file_size_kb=10)]) diff --git a/backend/app/tests/services/collections/test_create_collection.py b/backend/app/tests/services/collections/test_create_collection.py index d8ca2829b..c17e794cb 100644 --- a/backend/app/tests/services/collections/test_create_collection.py +++ b/backend/app/tests/services/collections/test_create_collection.py @@ -1,26 +1,27 @@ from typing import Any import os -from pathlib import Path from unittest.mock import patch, MagicMock -from urllib.parse import urlparse import uuid from uuid import UUID, uuid4 +from celery.exceptions import SoftTimeLimitExceeded from gevent import Timeout import pytest -from moto import mock_aws from sqlmodel import Session -from app.core.cloud import AmazonCloudStorageClient from app.core.config import settings from app.crud import CollectionCrud, CollectionJobCrud, DocumentCollectionCrud from app.models import CollectionJobStatus, CollectionJob, CollectionActionType, Project from app.models.collection import CreationRequest -from app.services.collections.create_collection import start_job, execute_job +from app.services.collections.create_collection import ( + start_job, + execute_setup_job, + execute_batch_job, +) from app.tests.utils.llm_provider import get_mock_provider from app.tests.utils.utils import get_project -from app.tests.utils.collection import get_collection_job, get_assistant_collection +from app.tests.utils.collection import get_collection_job from app.tests.utils.document import DocumentStore @@ -33,30 +34,33 @@ def aws_credentials() -> Any: os.environ["AWS_DEFAULT_REGION"] = settings.AWS_DEFAULT_REGION -def create_collection_job_for_create( - db: Session, - project: Project, - job_id: UUID, -) -> CollectionJob: - """Pre-create a CREATE job with the given id so start_job can update it.""" - return CollectionJobCrud(db, project.id).create( - CollectionJob( - id=job_id, - action_type=CollectionActionType.CREATE, - project_id=project.id, - collection_id=None, - status=CollectionJobStatus.PENDING, - ) - ) +def _mock_provider_with_size(llm_service_id: str, llm_service_name: str): + """Returns a mock provider whose upload_files sets file_size_kb=10.0 on each doc.""" + mock_provider = get_mock_provider(llm_service_id, llm_service_name) + + def _set_file_size(storage, docs, project_id): + for doc in docs: + doc.file_size_kb = 10.0 + + mock_provider.upload_files.side_effect = _set_file_size + return mock_provider + + +def _patch_session(db: Session): + """Context manager that routes all Session(engine) calls to the test db.""" + patcher = patch("app.services.collections.create_collection.Session") + mock_ctor = patcher.start() + mock_ctor.return_value.__enter__.return_value = db + mock_ctor.return_value.__exit__.return_value = False + return patcher + + +# --------------------------------------------------------------------------- +# start_job +# --------------------------------------------------------------------------- def test_start_job_creates_collection_job_and_schedules_task(db: Session) -> None: - """ - start_job should: - - update an existing CollectionJob (status=PENDING, action=CREATE) - - call start_create_collection_job with the correct kwargs - - return the job UUID (same one that was passed in) - """ project = get_project(db) request = CreationRequest( documents=[UUID("f3e86a17-1e6f-41ec-b020-5b08eebef928")], @@ -65,7 +69,7 @@ def test_start_job_creates_collection_job_and_schedules_task(db: Session) -> Non ) job_id = uuid4() - _ = get_collection_job( + get_collection_job( db, project, job_id=job_id, @@ -75,7 +79,7 @@ def test_start_job_creates_collection_job_and_schedules_task(db: Session) -> Non ) with patch( - "app.services.collections.create_collection.start_create_collection_job" + "app.services.collections.create_collection.start_collection_setup_job" ) as mock_schedule: mock_schedule.return_value = "fake-task-id" @@ -84,425 +88,552 @@ def test_start_job_creates_collection_job_and_schedules_task(db: Session) -> Non request=request, project_id=project.id, collection_job_id=job_id, - with_assistant=True, organization_id=project.organization_id, ) - assert returned_job_id == job_id + assert returned_job_id == job_id + mock_schedule.assert_called_once() + kwargs = mock_schedule.call_args.kwargs + assert kwargs["project_id"] == project.id + assert kwargs["organization_id"] == project.organization_id + assert kwargs["job_id"] == str(job_id) + assert kwargs["request"] == request.model_dump(mode="json") - job = CollectionJobCrud(db, project.id).read_one(job_id) - assert job.id == job_id - assert job.project_id == project.id - assert job.status == CollectionJobStatus.PENDING - assert job.action_type in ( - CollectionActionType.CREATE, - CollectionActionType.CREATE.value, - ) - assert job.collection_id is None - mock_schedule.assert_called_once() - kwargs = mock_schedule.call_args.kwargs - assert kwargs["project_id"] == project.id - assert kwargs["organization_id"] == project.organization_id - assert kwargs["job_id"] == str(job_id) - assert kwargs["request"] == request.model_dump(mode="json") +# --------------------------------------------------------------------------- +# execute_setup_job +# --------------------------------------------------------------------------- -@pytest.mark.usefixtures("aws_credentials") -@mock_aws +@patch("app.services.collections.create_collection.get_cloud_storage") @patch("app.services.collections.create_collection.get_llm_provider") -def test_execute_job_success_flow_updates_job_and_creates_collection( - mock_get_llm_provider: MagicMock, db: Session +@patch("app.services.collections.create_collection.start_collection_batch_job") +def test_execute_setup_job_marks_processing_and_queues_first_batch( + mock_queue_batch: MagicMock, + mock_get_provider: MagicMock, + mock_get_storage: MagicMock, + db: Session, ) -> None: - """ - execute_job should: - - set task_id on the CollectionJob - - ingest documents into a vector store - - create an OpenAI assistant - - create a Collection with llm fields filled - - link the CollectionJob -> collection_id, set status=successful - - create DocumentCollection links - """ project = get_project(db) - - aws = AmazonCloudStorageClient() - aws.create() - store = DocumentStore(db=db, project_id=project.id) - document = store.put() - s3_key = Path(urlparse(document.object_store_url).path).relative_to("/") - aws.client.put_object(Bucket=settings.AWS_S3_BUCKET, Key=str(s3_key), Body=b"test") - - sample_request = CreationRequest( - documents=[document.id], callback_url=None, provider="openai" - ) + doc = store.put() - mock_get_llm_provider.return_value = get_mock_provider( - llm_service_id="mock_vector_store_id", llm_service_name="openai vector store" + mock_get_provider.return_value = _mock_provider_with_size( + "vs_123", "openai vector store" ) - job_id = uuid4() - _ = get_collection_job( + job = get_collection_job( db, project, - job_id=job_id, action_type=CollectionActionType.CREATE, status=CollectionJobStatus.PENDING, - collection_id=None, ) + request = CreationRequest(documents=[doc.id], provider="openai", callback_url=None) + task_id = str(uuid4()) - task_id = uuid4() - - with patch("app.services.collections.create_collection.Session") as SessionCtor: - SessionCtor.return_value.__enter__.return_value = db - SessionCtor.return_value.__exit__.return_value = False - - execute_job( - request=sample_request.model_dump(), + patcher = _patch_session(db) + try: + execute_setup_job( + request=request.model_dump(mode="json"), project_id=project.id, organization_id=project.organization_id, - task_id=str(task_id), - with_assistant=True, - job_id=str(job_id), + task_id=task_id, + job_id=str(job.id), task_instance=None, ) + finally: + patcher.stop() - updated_job = CollectionJobCrud(db, project.id).read_one(job_id) - assert updated_job.task_id == str(task_id) - assert updated_job.status == CollectionJobStatus.SUCCESSFUL - assert updated_job.collection_id is not None - - created_collection = CollectionCrud(db, project.id).read_one( - updated_job.collection_id - ) - assert created_collection.llm_service_id == "mock_vector_store_id" - assert created_collection.llm_service_name == "openai vector store" - assert created_collection.updated_at is not None + updated_job = CollectionJobCrud(db, project.id).read_one(job.id) + assert updated_job.status == CollectionJobStatus.PROCESSING + assert updated_job.task_id == task_id - docs = DocumentCollectionCrud(db).read(created_collection, skip=0, limit=10) - assert len(docs) == 1 - assert docs[0].fname == document.fname + mock_queue_batch.assert_called_once() + kw = mock_queue_batch.call_args.kwargs + assert kw["batch_number"] == 1 + assert kw["vector_store_id"] is None + assert str(doc.id) in kw["batch_doc_ids"] + assert kw["remaining_batches"] == [] -@pytest.mark.usefixtures("aws_credentials") -@mock_aws +@patch("app.services.collections.create_collection.get_cloud_storage") @patch("app.services.collections.create_collection.get_llm_provider") -def test_execute_job_assistant_create_failure_marks_failed_and_deletes_collection( - mock_get_llm_provider: MagicMock, db +@patch("app.services.collections.create_collection.start_collection_batch_job") +def test_execute_setup_job_failure_marks_job_failed_and_raises( + mock_queue_batch: MagicMock, + mock_get_provider: MagicMock, + mock_get_storage: MagicMock, + db: Session, ) -> None: project = get_project(db) + store = DocumentStore(db=db, project_id=project.id) + doc = store.put() + + mock_provider = _mock_provider_with_size("vs_123", "openai vector store") + mock_provider.upload_files.side_effect = RuntimeError("S3 upload failed") + mock_get_provider.return_value = mock_provider job = get_collection_job( db, project, - job_id=uuid4(), action_type=CollectionActionType.CREATE, status=CollectionJobStatus.PENDING, - collection_id=None, ) + request = CreationRequest(documents=[doc.id], provider="openai", callback_url=None) - req = CreationRequest(documents=[], callback_url=None, provider="openai") - - mock_provider = get_mock_provider( - llm_service_id="vs_123", llm_service_name="openai vector store" - ) - mock_get_llm_provider.return_value = mock_provider - - with patch( - "app.services.collections.create_collection.Session" - ) as SessionCtor, patch( - "app.services.collections.create_collection.CollectionCrud" - ) as MockCrud: - SessionCtor.return_value.__enter__.return_value = db - SessionCtor.return_value.__exit__.return_value = False - - MockCrud.return_value.create.side_effect = Exception("DB constraint violation") - - task_id = str(uuid4()) - with pytest.raises(Exception, match="DB constraint violation"): - execute_job( - request=req.model_dump(), + patcher = _patch_session(db) + try: + with pytest.raises(RuntimeError, match="S3 upload failed"): + execute_setup_job( + request=request.model_dump(mode="json"), project_id=project.id, organization_id=project.organization_id, - task_id=task_id, - with_assistant=True, + task_id=str(uuid4()), job_id=str(job.id), task_instance=None, ) + finally: + patcher.stop() - mock_provider.delete.assert_called_once() + updated_job = CollectionJobCrud(db, project.id).read_one(job.id) + assert updated_job.status == CollectionJobStatus.FAILED + assert "S3 upload failed" in (updated_job.error_message or "") + mock_queue_batch.assert_not_called() -@pytest.mark.usefixtures("aws_credentials") -@mock_aws -@patch("app.services.collections.create_collection.get_llm_provider") @patch("app.services.collections.create_collection.send_callback") -def test_execute_job_success_flow_callback_job_and_creates_collection( +@patch("app.services.collections.create_collection.get_cloud_storage") +@patch("app.services.collections.create_collection.get_llm_provider") +@patch("app.services.collections.create_collection.start_collection_batch_job") +def test_execute_setup_job_failure_sends_callback( + mock_queue_batch: MagicMock, + mock_get_provider: MagicMock, + mock_get_storage: MagicMock, mock_send_callback: MagicMock, - mock_get_llm_provider: MagicMock, - db, + db: Session, ) -> None: - """ - execute_job should: - - set task_id on the CollectionJob - - ingest documents into a vector store - - create an OpenAI assistant - - create a Collection with llm fields filled - - link the CollectionJob -> collection_id, set status=successful - - create DocumentCollection links - """ project = get_project(db) + store = DocumentStore(db=db, project_id=project.id) + doc = store.put() + + mock_provider = _mock_provider_with_size("vs_123", "openai vector store") + mock_provider.upload_files.side_effect = RuntimeError("upload error") + mock_get_provider.return_value = mock_provider + + callback_url = "https://example.com/callback" + job = get_collection_job( + db, + project, + action_type=CollectionActionType.CREATE, + status=CollectionJobStatus.PENDING, + ) + request = CreationRequest( + documents=[doc.id], provider="openai", callback_url=callback_url + ) - aws = AmazonCloudStorageClient() - aws.create() + patcher = _patch_session(db) + try: + with pytest.raises(RuntimeError): + execute_setup_job( + request=request.model_dump(mode="json"), + project_id=project.id, + organization_id=project.organization_id, + task_id=str(uuid4()), + job_id=str(job.id), + task_instance=None, + ) + finally: + patcher.stop() + + mock_send_callback.assert_called_once() + cb_url, payload = mock_send_callback.call_args.args + assert str(cb_url) == callback_url + assert payload["success"] is False + assert payload["data"]["status"] == CollectionJobStatus.FAILED + +@patch("app.services.collections.create_collection.get_cloud_storage") +@patch("app.services.collections.create_collection.get_llm_provider") +@patch("app.services.collections.create_collection.start_collection_batch_job") +def test_execute_setup_job_timeout_marks_failed_and_reraises( + mock_queue_batch: MagicMock, + mock_get_provider: MagicMock, + mock_get_storage: MagicMock, + db: Session, +) -> None: + project = get_project(db) store = DocumentStore(db=db, project_id=project.id) - document = store.put() - s3_key = Path(urlparse(document.object_store_url).path).relative_to("/") - aws.client.put_object(Bucket=settings.AWS_S3_BUCKET, Key=str(s3_key), Body=b"test") + doc = store.put() - callback_url = "https://example.com/collections/create-success" + mock_provider = _mock_provider_with_size("vs_123", "openai vector store") + mock_provider.upload_files.side_effect = Timeout(300) + mock_get_provider.return_value = mock_provider - sample_request = CreationRequest( - documents=[document.id], - callback_url=callback_url, - provider="openai", + job = get_collection_job( + db, + project, + action_type=CollectionActionType.CREATE, + status=CollectionJobStatus.PENDING, ) + request = CreationRequest(documents=[doc.id], provider="openai", callback_url=None) - mock_get_llm_provider.return_value = get_mock_provider( - llm_service_id="mock_vector_store_id", llm_service_name="openai vector store" - ) + patcher = _patch_session(db) + try: + with pytest.raises(Timeout): + execute_setup_job( + request=request.model_dump(mode="json"), + project_id=project.id, + organization_id=project.organization_id, + task_id=str(uuid4()), + job_id=str(job.id), + task_instance=None, + ) + finally: + patcher.stop() + + updated_job = CollectionJobCrud(db, project.id).read_one(job.id) + assert updated_job.status == CollectionJobStatus.FAILED + assert "soft time limit" in (updated_job.error_message or "") + + +@patch("app.services.collections.create_collection.get_cloud_storage") +@patch("app.services.collections.create_collection.get_llm_provider") +@patch("app.services.collections.create_collection.start_collection_batch_job") +def test_execute_setup_job_soft_time_limit_marks_failed_and_reraises( + mock_queue_batch: MagicMock, + mock_get_provider: MagicMock, + mock_get_storage: MagicMock, + db: Session, +) -> None: + project = get_project(db) + store = DocumentStore(db=db, project_id=project.id) + doc = store.put() - job_id = uuid.uuid4() - _ = get_collection_job( + mock_provider = _mock_provider_with_size("vs_123", "openai vector store") + mock_provider.upload_files.side_effect = SoftTimeLimitExceeded() + mock_get_provider.return_value = mock_provider + + job = get_collection_job( db, project, - job_id=job_id, action_type=CollectionActionType.CREATE, status=CollectionJobStatus.PENDING, - collection_id=None, ) + request = CreationRequest(documents=[doc.id], provider="openai", callback_url=None) + + patcher = _patch_session(db) + try: + with pytest.raises(SoftTimeLimitExceeded): + execute_setup_job( + request=request.model_dump(mode="json"), + project_id=project.id, + organization_id=project.organization_id, + task_id=str(uuid4()), + job_id=str(job.id), + task_instance=None, + ) + finally: + patcher.stop() + + updated_job = CollectionJobCrud(db, project.id).read_one(job.id) + assert updated_job.status == CollectionJobStatus.FAILED + assert "soft time limit" in (updated_job.error_message or "") - task_id = uuid.uuid4() - with patch("app.services.collections.create_collection.Session") as SessionCtor: - SessionCtor.return_value.__enter__.return_value = db - SessionCtor.return_value.__exit__.return_value = False +# --------------------------------------------------------------------------- +# execute_batch_job +# --------------------------------------------------------------------------- - mock_send_callback.return_value = MagicMock(status_code=403) - execute_job( - request=sample_request.model_dump(), +@patch("app.services.collections.create_collection.get_llm_provider") +@patch("app.services.collections.create_collection.start_collection_batch_job") +def test_execute_batch_job_non_final_queues_next_batch( + mock_queue_batch: MagicMock, + mock_get_provider: MagicMock, + db: Session, +) -> None: + project = get_project(db) + store = DocumentStore(db=db, project_id=project.id) + doc1 = store.put() + doc2 = store.put() + + mock_get_provider.return_value = get_mock_provider("vs_123", "openai vector store") + + job = get_collection_job( + db, + project, + action_type=CollectionActionType.CREATE, + status=CollectionJobStatus.PROCESSING, + ) + request = CreationRequest( + documents=[doc1.id, doc2.id], provider="openai", callback_url=None + ) + task_id = str(uuid4()) + + patcher = _patch_session(db) + try: + execute_batch_job( + request=request.model_dump(mode="json"), project_id=project.id, organization_id=project.organization_id, - task_id=str(task_id), - with_assistant=True, - job_id=str(job_id), + task_id=task_id, + job_id=str(job.id), task_instance=None, + vector_store_id="vs_123", + batch_number=1, + batch_doc_ids=[str(doc1.id)], + remaining_batches=[[str(doc2.id)]], ) + finally: + patcher.stop() - updated_job = CollectionJobCrud(db, project.id).read_one(job_id) - collection = CollectionCrud(db, project.id).read_one(updated_job.collection_id) + mock_queue_batch.assert_called_once() + kw = mock_queue_batch.call_args.kwargs + assert kw["batch_number"] == 2 + assert kw["batch_doc_ids"] == [str(doc2.id)] + assert kw["remaining_batches"] == [] + assert kw["vector_store_id"] == "vs_123" - mock_send_callback.assert_called_once() - cb_url_arg, payload_arg = mock_send_callback.call_args.args - assert str(cb_url_arg) == callback_url - assert payload_arg["success"] is True - assert payload_arg["data"]["status"] == CollectionJobStatus.SUCCESSFUL - assert payload_arg["data"]["collection"]["id"] == str(collection.id) - assert uuid.UUID(payload_arg["data"]["job_id"]) == job_id + updated_job = CollectionJobCrud(db, project.id).read_one(job.id) + assert updated_job.current_batch_number == 1 + assert str(doc1.id) in (updated_job.documents_uploaded or []) -@pytest.mark.usefixtures("aws_credentials") -@mock_aws @patch("app.services.collections.create_collection.get_llm_provider") -@patch("app.services.collections.create_collection.send_callback") -def test_execute_job_success_creates_collection_with_callback( - mock_send_callback: MagicMock, - mock_get_llm_provider: MagicMock, - db, +@patch("app.services.collections.create_collection.start_collection_batch_job") +def test_execute_batch_job_final_batch_creates_collection_and_marks_successful( + mock_queue_batch: MagicMock, + mock_get_provider: MagicMock, + db: Session, ) -> None: - """ - execute_job should: - - set task_id on the CollectionJob - - ingest documents into a vector store - - create an OpenAI assistant - - create a Collection with llm fields filled - - link the CollectionJob -> collection_id, set status=successful - - create DocumentCollection links - """ project = get_project(db) - - aws = AmazonCloudStorageClient() - aws.create() - store = DocumentStore(db=db, project_id=project.id) - document = store.put() - s3_key = Path(urlparse(document.object_store_url).path).relative_to("/") - aws.client.put_object(Bucket=settings.AWS_S3_BUCKET, Key=str(s3_key), Body=b"test") + doc = store.put() - callback_url = "https://example.com/collections/create-success" - - sample_request = CreationRequest( - documents=[document.id], - callback_url=callback_url, - provider="openai", - ) - - mock_get_llm_provider.return_value = get_mock_provider( - llm_service_id="mock_vector_store_id", llm_service_name="gpt-4o" + mock_get_provider.return_value = get_mock_provider( + "vs_final", "openai vector store" ) - job_id = uuid.uuid4() - _ = get_collection_job( + job = get_collection_job( db, project, - job_id=job_id, action_type=CollectionActionType.CREATE, - status=CollectionJobStatus.PENDING, - collection_id=None, + status=CollectionJobStatus.PROCESSING, ) + request = CreationRequest(documents=[doc.id], provider="openai", callback_url=None) - task_id = uuid.uuid4() - - with patch("app.services.collections.create_collection.Session") as SessionCtor: - SessionCtor.return_value.__enter__.return_value = db - SessionCtor.return_value.__exit__.return_value = False - - mock_send_callback.return_value = MagicMock(status_code=403) - - execute_job( - request=sample_request.model_dump(), + patcher = _patch_session(db) + try: + execute_batch_job( + request=request.model_dump(mode="json"), project_id=project.id, organization_id=project.organization_id, - task_id=str(task_id), - with_assistant=True, - job_id=str(job_id), + task_id=str(uuid4()), + job_id=str(job.id), task_instance=None, + vector_store_id=None, + batch_number=1, + batch_doc_ids=[str(doc.id)], + remaining_batches=[], ) + finally: + patcher.stop() + + updated_job = CollectionJobCrud(db, project.id).read_one(job.id) + assert updated_job.status == CollectionJobStatus.SUCCESSFUL + assert updated_job.collection_id is not None - updated_job = CollectionJobCrud(db, project.id).read_one(job_id) collection = CollectionCrud(db, project.id).read_one(updated_job.collection_id) + assert collection.llm_service_id == "vs_final" - mock_send_callback.assert_called_once() - cb_url_arg, payload_arg = mock_send_callback.call_args.args - assert str(cb_url_arg) == callback_url - assert payload_arg["success"] is True - assert payload_arg["data"]["status"] == CollectionJobStatus.SUCCESSFUL - assert payload_arg["data"]["collection"]["id"] == str(collection.id) - assert uuid.UUID(payload_arg["data"]["job_id"]) == job_id + linked_docs = DocumentCollectionCrud(db).read(collection, skip=0, limit=10) + assert len(linked_docs) == 1 + assert linked_docs[0].id == doc.id + + mock_queue_batch.assert_not_called() -@pytest.mark.usefixtures("aws_credentials") -@mock_aws -@patch("app.services.collections.create_collection.get_llm_provider") @patch("app.services.collections.create_collection.send_callback") -@patch("app.services.collections.create_collection.CollectionCrud") -def test_execute_job_failure_flow_callback_job_and_marks_failed( - MockCollectionCrud, +@patch("app.services.collections.create_collection.get_llm_provider") +@patch("app.services.collections.create_collection.start_collection_batch_job") +def test_execute_batch_job_final_batch_sends_success_callback( + mock_queue_batch: MagicMock, + mock_get_provider: MagicMock, mock_send_callback: MagicMock, - mock_get_llm_provider: MagicMock, db: Session, ) -> None: - """ - When creation fails, the job should be marked as FAILED, an error should be logged, - and a failure callback with the error message should be triggered. - """ project = get_project(db) + store = DocumentStore(db=db, project_id=project.id) + doc = store.put() + + mock_get_provider.return_value = get_mock_provider( + "vs_final", "openai vector store" + ) - collection = get_assistant_collection(db, project, assistant_id="asst_123") + callback_url = "https://example.com/success" job = get_collection_job( db, project, action_type=CollectionActionType.CREATE, - status=CollectionJobStatus.PENDING, - collection_id=None, + status=CollectionJobStatus.PROCESSING, + ) + request = CreationRequest( + documents=[doc.id], provider="openai", callback_url=callback_url ) - mock_get_llm_provider.return_value = MagicMock() + patcher = _patch_session(db) + try: + execute_batch_job( + request=request.model_dump(mode="json"), + project_id=project.id, + organization_id=project.organization_id, + task_id=str(uuid4()), + job_id=str(job.id), + task_instance=None, + vector_store_id=None, + batch_number=1, + batch_doc_ids=[str(doc.id)], + remaining_batches=[], + ) + finally: + patcher.stop() - callback_url = "https://example.com/collections/create-failure" + mock_send_callback.assert_called_once() + cb_url, payload = mock_send_callback.call_args.args + assert str(cb_url) == callback_url + assert payload["success"] is True + assert payload["data"]["status"] == CollectionJobStatus.SUCCESSFUL + assert payload["data"]["collection"] is not None - collection_crud_instance = MockCollectionCrud.return_value - collection_crud_instance.read_one.return_value = collection - sample_request = CreationRequest( - documents=[uuid.uuid4()], - callback_url=callback_url, - provider="openai", - ) +@patch("app.services.collections.create_collection.get_llm_provider") +def test_execute_batch_job_provider_failure_marks_failed_and_raises( + mock_get_provider: MagicMock, + db: Session, +) -> None: + project = get_project(db) + store = DocumentStore(db=db, project_id=project.id) + doc = store.put() - task_id = uuid.uuid4() + mock_provider = get_mock_provider("vs_123", "openai vector store") + mock_provider.create.side_effect = RuntimeError("vector store error") + mock_get_provider.return_value = mock_provider - with patch("app.services.collections.create_collection.Session") as SessionCtor: - SessionCtor.return_value.__enter__.return_value = db - SessionCtor.return_value.__exit__.return_value = False + job = get_collection_job( + db, + project, + action_type=CollectionActionType.CREATE, + status=CollectionJobStatus.PROCESSING, + ) + request = CreationRequest(documents=[doc.id], provider="openai", callback_url=None) - with pytest.raises( - ValueError, match="Requested atleast 1 document retrieved 0" - ): - execute_job( - request=sample_request.model_dump(), + patcher = _patch_session(db) + try: + with pytest.raises(RuntimeError, match="vector store error"): + execute_batch_job( + request=request.model_dump(mode="json"), project_id=project.id, organization_id=project.organization_id, - task_id=str(task_id), - with_assistant=True, + task_id=str(uuid4()), job_id=str(job.id), task_instance=None, + vector_store_id=None, + batch_number=1, + batch_doc_ids=[str(doc.id)], + remaining_batches=[], ) + finally: + patcher.stop() updated_job = CollectionJobCrud(db, project.id).read_one(job.id) - assert updated_job.status == CollectionJobStatus.FAILED - assert "Requested atleast 1 document retrieved 0" in ( - updated_job.error_message or "" - ) - - mock_send_callback.assert_called_once() - cb_url_arg, payload_arg = mock_send_callback.call_args.args - assert str(cb_url_arg) == callback_url - assert payload_arg["success"] is False - assert "Requested atleast 1 document retrieved 0" in (payload_arg["error"] or "") - assert payload_arg["data"]["status"] == CollectionJobStatus.FAILED - assert payload_arg["data"]["collection"] is None - assert uuid.UUID(payload_arg["data"]["job_id"]) == job.id + assert "vector store error" in (updated_job.error_message or "") @patch("app.services.collections.create_collection.get_llm_provider") -def test_execute_job_timeout_marks_job_failed( - mock_get_llm_provider: MagicMock, db: Session +@patch("app.services.collections.create_collection.CollectionCrud") +def test_execute_batch_job_cleanup_called_when_provider_create_succeeds_but_db_fails( + MockCollectionCrud: MagicMock, + mock_get_provider: MagicMock, + db: Session, ) -> None: + """provider.delete should be called if create() succeeded but finalization fails.""" project = get_project(db) + store = DocumentStore(db=db, project_id=project.id) + doc = store.put() + + mock_provider = get_mock_provider("vs_123", "openai vector store") + mock_get_provider.return_value = mock_provider + + MockCollectionCrud.return_value.create.side_effect = Exception("DB write failed") job = get_collection_job( db, project, - job_id=uuid4(), action_type=CollectionActionType.CREATE, - status=CollectionJobStatus.PENDING, - collection_id=None, + status=CollectionJobStatus.PROCESSING, ) + request = CreationRequest(documents=[doc.id], provider="openai", callback_url=None) - mock_provider = MagicMock() - mock_provider.create.side_effect = Timeout(300) - mock_get_llm_provider.return_value = mock_provider + patcher = _patch_session(db) + try: + with pytest.raises(Exception, match="DB write failed"): + execute_batch_job( + request=request.model_dump(mode="json"), + project_id=project.id, + organization_id=project.organization_id, + task_id=str(uuid4()), + job_id=str(job.id), + task_instance=None, + vector_store_id=None, + batch_number=1, + batch_doc_ids=[str(doc.id)], + remaining_batches=[], + ) + finally: + patcher.stop() + + mock_provider.delete.assert_called_once() + + +@patch("app.services.collections.create_collection.get_llm_provider") +def test_execute_batch_job_timeout_marks_failed_and_reraises( + mock_get_provider: MagicMock, + db: Session, +) -> None: + project = get_project(db) + store = DocumentStore(db=db, project_id=project.id) + doc = store.put() - req = CreationRequest(documents=[], callback_url=None, provider="openai") + mock_provider = get_mock_provider("vs_123", "openai vector store") + mock_provider.create.side_effect = Timeout(300) + mock_get_provider.return_value = mock_provider - with patch("app.services.collections.create_collection.Session") as SessionCtor: - SessionCtor.return_value.__enter__.return_value = db - SessionCtor.return_value.__exit__.return_value = False + job = get_collection_job( + db, + project, + action_type=CollectionActionType.CREATE, + status=CollectionJobStatus.PROCESSING, + ) + request = CreationRequest(documents=[doc.id], provider="openai", callback_url=None) + patcher = _patch_session(db) + try: with pytest.raises(Timeout): - execute_job( - request=req.model_dump(), + execute_batch_job( + request=request.model_dump(mode="json"), project_id=project.id, organization_id=project.organization_id, task_id=str(uuid4()), - with_assistant=False, job_id=str(job.id), task_instance=None, + vector_store_id=None, + batch_number=1, + batch_doc_ids=[str(doc.id)], + remaining_batches=[], ) + finally: + patcher.stop() updated_job = CollectionJobCrud(db, project.id).read_one(job.id) assert updated_job.status == CollectionJobStatus.FAILED @@ -510,50 +641,96 @@ def test_execute_job_timeout_marks_job_failed( @patch("app.services.collections.create_collection.get_llm_provider") -@patch("app.services.collections.create_collection.send_callback") -def test_execute_job_timeout_sends_failure_callback( - mock_send_callback: MagicMock, - mock_get_llm_provider: MagicMock, +def test_execute_batch_job_soft_time_limit_marks_failed_and_reraises( + mock_get_provider: MagicMock, db: Session, ) -> None: project = get_project(db) - callback_url = "https://example.com/collections/timeout" + store = DocumentStore(db=db, project_id=project.id) + doc = store.put() + + mock_provider = get_mock_provider("vs_123", "openai vector store") + mock_provider.create.side_effect = SoftTimeLimitExceeded() + mock_get_provider.return_value = mock_provider job = get_collection_job( db, project, - job_id=uuid4(), action_type=CollectionActionType.CREATE, - status=CollectionJobStatus.PENDING, - collection_id=None, + status=CollectionJobStatus.PROCESSING, ) + request = CreationRequest(documents=[doc.id], provider="openai", callback_url=None) - mock_provider = MagicMock() - mock_provider.create.side_effect = Timeout(300) - mock_get_llm_provider.return_value = mock_provider + patcher = _patch_session(db) + try: + with pytest.raises(SoftTimeLimitExceeded): + execute_batch_job( + request=request.model_dump(mode="json"), + project_id=project.id, + organization_id=project.organization_id, + task_id=str(uuid4()), + job_id=str(job.id), + task_instance=None, + vector_store_id=None, + batch_number=1, + batch_doc_ids=[str(doc.id)], + remaining_batches=[], + ) + finally: + patcher.stop() - req = CreationRequest(documents=[], callback_url=callback_url, provider="openai") + updated_job = CollectionJobCrud(db, project.id).read_one(job.id) + assert updated_job.status == CollectionJobStatus.FAILED + assert "soft time limit" in (updated_job.error_message or "") - with patch("app.services.collections.create_collection.Session") as SessionCtor: - SessionCtor.return_value.__enter__.return_value = db - SessionCtor.return_value.__exit__.return_value = False - with pytest.raises(Timeout): - execute_job( - request=req.model_dump(), +@patch("app.services.collections.create_collection.send_callback") +@patch("app.services.collections.create_collection.get_llm_provider") +def test_execute_batch_job_failure_sends_callback( + mock_get_provider: MagicMock, + mock_send_callback: MagicMock, + db: Session, +) -> None: + project = get_project(db) + store = DocumentStore(db=db, project_id=project.id) + doc = store.put() + + mock_provider = get_mock_provider("vs_123", "openai vector store") + mock_provider.create.side_effect = RuntimeError("batch failed") + mock_get_provider.return_value = mock_provider + + callback_url = "https://example.com/failure" + job = get_collection_job( + db, + project, + action_type=CollectionActionType.CREATE, + status=CollectionJobStatus.PROCESSING, + ) + request = CreationRequest( + documents=[doc.id], provider="openai", callback_url=callback_url + ) + + patcher = _patch_session(db) + try: + with pytest.raises(RuntimeError): + execute_batch_job( + request=request.model_dump(mode="json"), project_id=project.id, organization_id=project.organization_id, task_id=str(uuid4()), - with_assistant=False, job_id=str(job.id), task_instance=None, + vector_store_id=None, + batch_number=1, + batch_doc_ids=[str(doc.id)], + remaining_batches=[], ) + finally: + patcher.stop() mock_send_callback.assert_called_once() - cb_url_arg, payload_arg = mock_send_callback.call_args.args - assert str(cb_url_arg) == callback_url - assert payload_arg["success"] is False - assert "soft time limit" in (payload_arg["error"] or "") - assert payload_arg["data"]["status"] == CollectionJobStatus.FAILED - assert payload_arg["data"]["collection"] is None - assert uuid.UUID(payload_arg["data"]["job_id"]) == job.id + cb_url, payload = mock_send_callback.call_args.args + assert str(cb_url) == callback_url + assert payload["success"] is False + assert "batch failed" in (payload["error"] or "") + assert payload["data"]["status"] == CollectionJobStatus.FAILED diff --git a/backend/app/tests/services/collections/test_helpers.py b/backend/app/tests/services/collections/test_helpers.py index 7cddaf305..8caa61901 100644 --- a/backend/app/tests/services/collections/test_helpers.py +++ b/backend/app/tests/services/collections/test_helpers.py @@ -122,14 +122,34 @@ def test_batch_documents_mixed_size_batching() -> None: assert len(batches[2]) == 1 # 15 MB total -def test_batch_documents_with_none_file_size() -> None: - """Test that documents with None file_size are treated as 0 bytes.""" - docs = create_fake_documents(10, file_size_kb=None) +def test_batch_documents_zero_size_files_batches_by_count() -> None: + """Zero-size docs contribute nothing to size, so only the 200-doc count limit applies.""" + docs = create_fake_documents(250, file_size_kb=0) + batches = helpers.batch_documents(docs) + + assert len(batches) == 2 + assert len(batches[0]) == 200 + assert len(batches[1]) == 50 + + +def test_batch_documents_doc_exactly_at_size_limit_stays_in_same_batch() -> None: + """A doc whose size exactly equals MAX_BATCH_SIZE_KB should not trigger a new batch + on its own — the split only happens when adding it would *exceed* the limit.""" + from app.services.collections.helpers import MAX_BATCH_SIZE_KB + + docs = create_fake_documents(1, file_size_kb=MAX_BATCH_SIZE_KB) batches = helpers.batch_documents(docs) - # All files with None/0 size should fit in one batch (under both limits) assert len(batches) == 1 - assert len(batches[0]) == 10 + assert len(batches[0]) == 1 + + +def test_batch_documents_with_none_file_size_raises() -> None: + """Test that documents with None file_size raise TypeError — sizes must be backfilled before batching.""" + docs = create_fake_documents(10, file_size_kb=None) + + with pytest.raises(TypeError): + helpers.batch_documents(docs) def test_batch_documents_empty_input() -> None: @@ -219,47 +239,15 @@ def test_to_collection_public_vector_store() -> None: result = to_collection_public(collection) - # For vector store, should map to knowledge_base fields assert result.id == collection.id assert result.knowledge_base_id == "vs_123" assert result.knowledge_base_provider == "openai vector store" - assert result.llm_service_id is None - assert result.llm_service_name is None assert result.project_id == 1 assert result.inserted_at == collection.inserted_at assert result.updated_at == collection.updated_at assert result.deleted_at is None -def test_to_collection_public_assistant() -> None: - """Test conversion of assistant collection to public model.""" - collection = Collection( - id=uuid4(), - project_id=2, - provider=ProviderType.openai, - llm_service_id="asst_456", - llm_service_name="gpt-4", # Does NOT match vector store name - name="Assistant Collection", - description="Assistant description", - inserted_at=now(), - updated_at=now(), - deleted_at=None, - ) - - result = to_collection_public(collection) - - # For assistant, should map to llm_service fields - assert result.id == collection.id - assert result.llm_service_id == "asst_456" - assert result.llm_service_name == "gpt-4" - assert result.knowledge_base_id is None - assert result.knowledge_base_provider is None - assert result.project_id == 2 - assert result.inserted_at == collection.inserted_at - assert result.updated_at == collection.updated_at - assert result.deleted_at is None - - def test_to_collection_public_with_deleted_at() -> None: """Test that deleted_at field is properly included when set.""" deleted_time = now() diff --git a/backend/app/tests/utils/collection.py b/backend/app/tests/utils/collection.py index f1844cb96..94231c999 100644 --- a/backend/app/tests/utils/collection.py +++ b/backend/app/tests/utils/collection.py @@ -25,32 +25,6 @@ def uuid_increment(value: UUID) -> UUID: return UUID(int=inc) -def get_assistant_collection( - db: Session, - project: Project, - *, - assistant_id: Optional[str] = None, - model: str = "gpt-4o", - collection_id: Optional[UUID] = None, -) -> Collection: - """ - Create a Collection configured for the Assistant path. - execute_job will treat this as `is_vector = False` and use assistant id. - """ - if assistant_id is None: - assistant_id = f"asst_{uuid4().hex}" - - collection = Collection( - id=collection_id or uuid4(), - project_id=project.id, - organization_id=project.organization_id, - llm_service_name=model, - llm_service_id=assistant_id, - provider=ProviderType.openai, - ) - return CollectionCrud(db, project.id).create(collection) - - def get_vector_store_collection( db: Session, project: Project,