Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pipeline/workflow/ingestion-helper/clients/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,11 @@ CREATE INDEX InEdge ON Edge(object_id, predicate, subject_id, provenance) OPTION

CREATE TABLE {{ embedding_table }} (
subject_id STRING(1024) NOT NULL,
embedding_type STRING(1024) NOT NULL,
embedding_label STRING(1024) NOT NULL,
embedding_content JSON,
node_types ARRAY<STRING(1024)>,
embeddings ARRAY<FLOAT64>(vector_length=>{{ embedding_space }})
) PRIMARY KEY(subject_id, embedding_type),
) PRIMARY KEY(subject_id, embedding_label),
INTERLEAVE IN PARENT Node ON DELETE CASCADE;

CREATE VECTOR INDEX {{ embedding_index }}
Expand Down
4 changes: 2 additions & 2 deletions pipeline/workflow/ingestion-helper/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@

_DEFAULT_EMBEDDING_SPECS = [
{
"embedding_type": "base_text_embedding",
"embedding_label": "base_text_embedding",
"model_name": "NodeEmbeddingModel",
"task_type": "RETRIEVAL_QUERY",
"node_types": ["StatisticalVariable", "Topic"]
Expand All @@ -60,7 +60,7 @@
if specs_env:
try:
parsed = json.loads(specs_env)
required_keys = {"embedding_type", "model_name", "task_type", "node_types"}
required_keys = {"embedding_label", "model_name", "task_type", "node_types"}
if isinstance(parsed, list) and all(isinstance(s, dict) and required_keys.issubset(s.keys()) for s in parsed):
EMBEDDING_SPECS = parsed
Comment thread
shixiao-coder marked this conversation as resolved.
else:
Expand Down
8 changes: 4 additions & 4 deletions pipeline/workflow/ingestion-helper/routes/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,21 +49,21 @@ def embedding_ingestion(req: EmbeddingIngestionRequest, spanner: SpannerClient =
for spec in config.EMBEDDING_SPECS:
node_types = spec["node_types"]
model_name = spec["model_name"]
embedding_type = spec["embedding_type"]
embedding_label = spec["embedding_label"]
task_type = spec["task_type"]

logging.info(f"Job started for {embedding_type}. Fetching all nodes for types: {node_types}")
logging.info(f"Job started for {embedding_label}. Fetching all nodes for types: {node_types}")
nodes = get_updated_nodes(spanner.database, timestamp, node_types, timeout=config.TIMEOUT)
# materializing generator to list if necessary, but generator works since it yields
converted_nodes = list(filter_and_convert_nodes(nodes))

logging.info(f"Generating embeddings for model {model_name} (embedding_type: {embedding_type})")
logging.info(f"Generating embeddings for model {model_name} (embedding_label: {embedding_label})")
affected_rows = generate_embeddings_partitioned(
spanner.database,
converted_nodes,
model_name=model_name,
embedding_table=spanner.embedding_table,
embedding_type=embedding_type,
embedding_label=embedding_label,
task_type=task_type,
timeout=config.TIMEOUT
)
Expand Down
12 changes: 6 additions & 6 deletions pipeline/workflow/ingestion-helper/utils/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def filter_and_convert_nodes(nodes_generator):
yield (subject_id, embedding_content, node.get("types"))


def generate_embeddings_partitioned(database, nodes_generator, model_name, embedding_table, embedding_type, task_type, timeout):
def generate_embeddings_partitioned(database, nodes_generator, model_name, embedding_table, embedding_label, task_type, timeout):
"""Generates embeddings in batches using standard transactions.
Processes nodes in chunks of 500 to avoid transaction size limits.
Accepts a generator or list to avoid loading all nodes into memory.
Expand All @@ -124,7 +124,7 @@ def generate_embeddings_partitioned(database, nodes_generator, model_name, embed
nodes_generator: An iterable yielding tuples containing (subject_id, embedding_content, types).
model_name: Name of the remote model defined in Spanner DDL.
embedding_table: Name of the embedding table.
embedding_type: Embedding type key (e.g. model ID) to insert.
embedding_label: Embedding label key (e.g. model ID) to insert.
task_type: Task type parameter for ML.PREDICT (e.g. "RETRIEVAL_QUERY").
timeout: Timeout for the spanner client to execute queries.

Expand All @@ -137,8 +137,8 @@ def generate_embeddings_partitioned(database, nodes_generator, model_name, embed
logging.info(f"Generating embeddings in batches of {_BATCH_SIZE}.")

embeddings_sql = f"""
INSERT OR UPDATE INTO {embedding_table} (subject_id, embedding_type, embedding_content, embeddings, node_types)
SELECT subject_id, @embedding_type, embedding_content, embeddings.values, node_types
INSERT OR UPDATE INTO {embedding_table} (subject_id, embedding_label, embedding_content, embeddings, node_types)
SELECT subject_id, @embedding_label, embedding_content, embeddings.values, node_types
FROM ML.PREDICT(
MODEL {model_name},
(SELECT subject_id, TO_JSON_STRING(embedding_content) AS content, embedding_content, node_types, @task_type AS task_type FROM UNNEST(@nodes))
Expand All @@ -162,12 +162,12 @@ def chunked(iterable, n):
for batch in chunked(nodes_generator, _BATCH_SIZE):
params = {
"nodes": batch,
"embedding_type": embedding_type,
"embedding_label": embedding_label,
"task_type": task_type
}
param_types = {
"nodes": Array(struct_type),
"embedding_type": STRING,
"embedding_label": STRING,
"task_type": STRING
}

Expand Down
6 changes: 3 additions & 3 deletions pipeline/workflow/ingestion-helper/utils/embeddings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,9 @@ def side_effect(func):
def mock_execute_update(*args, **kwargs):
params = kwargs.get("params", {})
self.assertIn("nodes", params)
self.assertIn("embedding_type", params)
self.assertIn("embedding_label", params)
self.assertIn("task_type", params)
self.assertEqual(params["embedding_type"], "base_text_embedding")
self.assertEqual(params["embedding_label"], "base_text_embedding")
self.assertEqual(params["task_type"], "RETRIEVAL_QUERY")
return 2

Expand All @@ -175,7 +175,7 @@ def mock_execute_update(*args, **kwargs):
nodes,
model_name="NodeEmbeddingModel",
embedding_table="NodeEmbedding",
embedding_type="base_text_embedding",
embedding_label="base_text_embedding",
task_type="RETRIEVAL_QUERY",
timeout=3600
)
Expand Down