From 45daa443610db8141ae049c829a5e41aaf24dffc Mon Sep 17 00:00:00 2001 From: shixiao-coder Date: Mon, 15 Jun 2026 17:29:28 -0400 Subject: [PATCH] update the embedding_type to embedding_label --- .../workflow/ingestion-helper/clients/schema.sql | 4 ++-- pipeline/workflow/ingestion-helper/config.py | 4 ++-- .../workflow/ingestion-helper/routes/embeddings.py | 8 ++++---- .../workflow/ingestion-helper/utils/embeddings.py | 12 ++++++------ .../ingestion-helper/utils/embeddings_test.py | 6 +++--- 5 files changed, 17 insertions(+), 17 deletions(-) diff --git a/pipeline/workflow/ingestion-helper/clients/schema.sql b/pipeline/workflow/ingestion-helper/clients/schema.sql index 4dd6e1774..cecc1de8b 100644 --- a/pipeline/workflow/ingestion-helper/clients/schema.sql +++ b/pipeline/workflow/ingestion-helper/clients/schema.sql @@ -143,11 +143,11 @@ CREATE INDEX InEdge ON Edge(object_id, predicate, subject_id, provenance) OPTION CREATE TABLE {{ embedding_table }} ( subject_id STRING(1024) NOT NULL, - embedding_type STRING(1024) NOT NULL, + embedding_label STRING(1024) NOT NULL, embedding_content JSON, node_types ARRAY, embeddings ARRAY(vector_length=>{{ embedding_space }}) -) PRIMARY KEY(subject_id, embedding_type), +) PRIMARY KEY(subject_id, embedding_label), INTERLEAVE IN PARENT Node ON DELETE CASCADE; CREATE VECTOR INDEX {{ embedding_index }} diff --git a/pipeline/workflow/ingestion-helper/config.py b/pipeline/workflow/ingestion-helper/config.py index 15ac533a2..f1ac36958 100644 --- a/pipeline/workflow/ingestion-helper/config.py +++ b/pipeline/workflow/ingestion-helper/config.py @@ -49,7 +49,7 @@ _DEFAULT_EMBEDDING_SPECS = [ { - "embedding_type": "base_text_embedding", + "embedding_label": "base_text_embedding", "model_name": "NodeEmbeddingModel", "task_type": "RETRIEVAL_QUERY", "node_types": ["StatisticalVariable", "Topic"] @@ -60,7 +60,7 @@ if specs_env: try: parsed = json.loads(specs_env) - required_keys = {"embedding_type", "model_name", "task_type", "node_types"} + required_keys = {"embedding_label", "model_name", "task_type", "node_types"} if isinstance(parsed, list) and all(isinstance(s, dict) and required_keys.issubset(s.keys()) for s in parsed): EMBEDDING_SPECS = parsed else: diff --git a/pipeline/workflow/ingestion-helper/routes/embeddings.py b/pipeline/workflow/ingestion-helper/routes/embeddings.py index 7929867d3..946761e71 100644 --- a/pipeline/workflow/ingestion-helper/routes/embeddings.py +++ b/pipeline/workflow/ingestion-helper/routes/embeddings.py @@ -49,21 +49,21 @@ def embedding_ingestion(req: EmbeddingIngestionRequest, spanner: SpannerClient = for spec in config.EMBEDDING_SPECS: node_types = spec["node_types"] model_name = spec["model_name"] - embedding_type = spec["embedding_type"] + embedding_label = spec["embedding_label"] task_type = spec["task_type"] - logging.info(f"Job started for {embedding_type}. Fetching all nodes for types: {node_types}") + logging.info(f"Job started for {embedding_label}. Fetching all nodes for types: {node_types}") nodes = get_updated_nodes(spanner.database, timestamp, node_types, timeout=config.TIMEOUT) # materializing generator to list if necessary, but generator works since it yields converted_nodes = list(filter_and_convert_nodes(nodes)) - logging.info(f"Generating embeddings for model {model_name} (embedding_type: {embedding_type})") + logging.info(f"Generating embeddings for model {model_name} (embedding_label: {embedding_label})") affected_rows = generate_embeddings_partitioned( spanner.database, converted_nodes, model_name=model_name, embedding_table=spanner.embedding_table, - embedding_type=embedding_type, + embedding_label=embedding_label, task_type=task_type, timeout=config.TIMEOUT ) diff --git a/pipeline/workflow/ingestion-helper/utils/embeddings.py b/pipeline/workflow/ingestion-helper/utils/embeddings.py index fe04fbedc..0e58eb2e4 100644 --- a/pipeline/workflow/ingestion-helper/utils/embeddings.py +++ b/pipeline/workflow/ingestion-helper/utils/embeddings.py @@ -114,7 +114,7 @@ def filter_and_convert_nodes(nodes_generator): yield (subject_id, embedding_content, node.get("types")) -def generate_embeddings_partitioned(database, nodes_generator, model_name, embedding_table, embedding_type, task_type, timeout): +def generate_embeddings_partitioned(database, nodes_generator, model_name, embedding_table, embedding_label, task_type, timeout): """Generates embeddings in batches using standard transactions. Processes nodes in chunks of 500 to avoid transaction size limits. Accepts a generator or list to avoid loading all nodes into memory. @@ -124,7 +124,7 @@ def generate_embeddings_partitioned(database, nodes_generator, model_name, embed nodes_generator: An iterable yielding tuples containing (subject_id, embedding_content, types). model_name: Name of the remote model defined in Spanner DDL. embedding_table: Name of the embedding table. - embedding_type: Embedding type key (e.g. model ID) to insert. + embedding_label: Embedding label key (e.g. model ID) to insert. task_type: Task type parameter for ML.PREDICT (e.g. "RETRIEVAL_QUERY"). timeout: Timeout for the spanner client to execute queries. @@ -137,8 +137,8 @@ def generate_embeddings_partitioned(database, nodes_generator, model_name, embed logging.info(f"Generating embeddings in batches of {_BATCH_SIZE}.") embeddings_sql = f""" - INSERT OR UPDATE INTO {embedding_table} (subject_id, embedding_type, embedding_content, embeddings, node_types) - SELECT subject_id, @embedding_type, embedding_content, embeddings.values, node_types + INSERT OR UPDATE INTO {embedding_table} (subject_id, embedding_label, embedding_content, embeddings, node_types) + SELECT subject_id, @embedding_label, embedding_content, embeddings.values, node_types FROM ML.PREDICT( MODEL {model_name}, (SELECT subject_id, TO_JSON_STRING(embedding_content) AS content, embedding_content, node_types, @task_type AS task_type FROM UNNEST(@nodes)) @@ -162,12 +162,12 @@ def chunked(iterable, n): for batch in chunked(nodes_generator, _BATCH_SIZE): params = { "nodes": batch, - "embedding_type": embedding_type, + "embedding_label": embedding_label, "task_type": task_type } param_types = { "nodes": Array(struct_type), - "embedding_type": STRING, + "embedding_label": STRING, "task_type": STRING } diff --git a/pipeline/workflow/ingestion-helper/utils/embeddings_test.py b/pipeline/workflow/ingestion-helper/utils/embeddings_test.py index 340dcdb41..e92588aa7 100644 --- a/pipeline/workflow/ingestion-helper/utils/embeddings_test.py +++ b/pipeline/workflow/ingestion-helper/utils/embeddings_test.py @@ -158,9 +158,9 @@ def side_effect(func): def mock_execute_update(*args, **kwargs): params = kwargs.get("params", {}) self.assertIn("nodes", params) - self.assertIn("embedding_type", params) + self.assertIn("embedding_label", params) self.assertIn("task_type", params) - self.assertEqual(params["embedding_type"], "base_text_embedding") + self.assertEqual(params["embedding_label"], "base_text_embedding") self.assertEqual(params["task_type"], "RETRIEVAL_QUERY") return 2 @@ -175,7 +175,7 @@ def mock_execute_update(*args, **kwargs): nodes, model_name="NodeEmbeddingModel", embedding_table="NodeEmbedding", - embedding_type="base_text_embedding", + embedding_label="base_text_embedding", task_type="RETRIEVAL_QUERY", timeout=3600 )