tidy(app): document & clean up batch prep logic
psychedelicious committed Feb 26, 2025
1 parent d1e03aa commit 047c643
Showing 2 changed files with 159 additions and 31 deletions.
178 changes: 153 additions & 25 deletions invokeai/app/services/session_queue/session_queue_common.py
@@ -406,58 +406,143 @@ class IsFullResult(BaseModel):
# region Util


def create_graph_nfv_tuples(batch: Batch, maximum: int) -> Generator[tuple[str, str, list[dict]], None, None]:
def create_session_nfv_tuples(batch: Batch, maximum: int) -> Generator[tuple[str, str, str], None, None]:
"""
Create all graph permutations from the given batch data and graph. Yields tuples
of the form (graph, batch_data_items) where batch_data_items is the list of BatchDataItems
that was applied to the graph.
Given a batch and a maximum number of sessions to create, generate a tuple of session_id, session_json, and
field_values_json for each session.
The batch has a "source" graph and a data property. The data property is a list of lists of BatchDatum objects.
Each BatchDatum has a field identifier (e.g. a node id and field name), and a list of values to substitute into
the field.
This structure allows us to create a new graph for every possible permutation of BatchDatum objects:
- Each BatchDatum can be "expanded" into a dict of node-field-value tuples - one for each item in the BatchDatum.
- Zip each inner list of expanded BatchDatum objects together. Call this a "batch_data_list".
- Take the cartesian product of all zipped batch_data_lists, resulting in a list of permutations of BatchDatum
- Take the cartesian product of all zipped batch_data_lists, resulting in a list of lists of BatchDatum objects.
Each inner list now represents the substitution values for a single permutation (session).
- For each permutation, substitute the values into the graph
This function is optimized for performance, as it is used to generate a large number of sessions at once.
Args:
batch: The batch to generate sessions from
maximum: The maximum number of sessions to generate
Returns:
A generator that yields tuples of session_id, session_json, and field_values_json for each session. The
generator will stop early if the maximum number of sessions is reached.
"""

# TODO: Should this be a class method on Batch?

data: list[list[tuple[dict]]] = []
batch_data_collection = batch.data if batch.data is not None else []
graph_as_dict = batch.graph.model_dump(warnings=False, exclude_none=True)
session_dict = GraphExecutionState(graph=Graph()).model_dump(warnings=False, exclude_none=True)

for batch_datum_list in batch_data_collection:
# each batch_datum_list needs to be converted to NodeFieldValues and then zipped

node_field_values_to_zip: list[list[dict]] = []
# Expand each BatchDatum into a list of dicts - one for each item in the BatchDatum
for batch_datum in batch_datum_list:
node_field_values = [
# Note: A tuple here is slightly faster than a dict, but we need the object in dict form to be inserted
# in the session_queue table anyways. So, overall creating NFVs as dicts is faster.
{"node_path": batch_datum.node_path, "field_name": batch_datum.field_name, "value": item}
for item in batch_datum.items
]
node_field_values_to_zip.append(node_field_values)
# Zip the dicts together to create a list of dicts for each permutation
data.append(list(zip(*node_field_values_to_zip, strict=True))) # type: ignore [arg-type]

# create generator to yield session,nfv tuples
# We serialize the graph and session once, then mutate the graph dict in place for each session.
#
# This sounds scary, but it's actually fine.
#
# The batch prep logic injects field values into the same fields for each generated session.
#
# For example, after the product operation, we'll end up with a list of node-field-value tuples like this:
# [
# (
# {"node_path": "1", "field_name": "a", "value": 1},
# {"node_path": "2", "field_name": "b", "value": 2},
# {"node_path": "3", "field_name": "c", "value": 3},
# ),
# (
# {"node_path": "1", "field_name": "a", "value": 4},
# {"node_path": "2", "field_name": "b", "value": 5},
# {"node_path": "3", "field_name": "c", "value": 6},
# )
# ]
#
# Note that each tuple has the same length, and each tuple substitutes values in for exactly the same node fields.
# No matter the complexity of the batch, this property holds true.
#
# This means each permutation's substitution can be done in-place on the same graph dict, because it overwrites the
# previous mutation. We only need to serialize the graph once, and then we can mutate it in place for each session.
#
# Previously, we had created new Graph objects for each session, but this was very slow for large (1k+ session
# batches). We then tried dumping the graph to dict and using deep-copy to create a new dict for each session,
# but this was also slow.
#
# Overall, we achieved a 100x speedup by mutating the graph dict in place for each session over creating new Graph
# objects for each session.
#
# We will also mutate the session dict in place, setting a new ID for each session and setting the mutated graph
# dict as the session's graph.
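#
# A tiny sketch of why the in-place approach is safe (toy graph; node ids "1"/"2" and fields
# "a"/"b" are hypothetical, mirroring the example above):
#
#   graph = {"nodes": {"1": {"a": None}, "2": {"b": None}}}
#   snapshots = []
#   for a, b in [(1, 2), (4, 5)]:
#       graph["nodes"]["1"]["a"] = a  # overwrites the value written for the previous permutation
#       graph["nodes"]["2"]["b"] = b
#       snapshots.append(json.dumps(graph))
#
# Every permutation writes to exactly the same node/field keys before the dict is serialized, so
# each snapshot matches what a freshly-copied graph dict would have produced.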

# Dump the batch's graph to a dict once
graph_as_dict = batch.graph.model_dump(warnings=False, exclude_none=True)

# We must provide a Graph object when creating the "dummy" session dict, but we don't actually use it. It will be
# overwritten for each session by the mutated graph_as_dict.
session_dict = GraphExecutionState(graph=Graph()).model_dump(warnings=False, exclude_none=True)

# Now we can create a generator that yields the session_id, session_json, and field_values_json for each session.
count = 0

# Each batch may have multiple runs, so we need to generate the same number of sessions for each run. The total is
# still limited by the maximum number of sessions.
for _ in range(batch.runs):
for d in product(*data):
if count >= maximum:
# We've reached the maximum number of sessions we may generate
return

# Flatten the list of lists of dicts into a single list of dicts
# TODO(psyche): Is there a more efficient way to do this?
flat_node_field_values = list(chain.from_iterable(d))

# The fields that are injected for each session are the same for all graphs. Therefore, we can mutate the graph dict
# in place and then serialize it to json for each session. It's functionally the same as creating a new
# graph dict for each session, but is more efficient.
# Need a fresh ID for each session
session_id = uuid_string()

# Mutate the session dict in place
session_dict["id"] = session_id

for item in flat_node_field_values:
graph_as_dict["nodes"][item["node_path"]][item["field_name"]] = item["value"]
# Substitute the values into the graph
for nfv in flat_node_field_values:
graph_as_dict["nodes"][nfv["node_path"]][nfv["field_name"]] = nfv["value"]

# Mutate the session dict in place
session_dict["graph"] = graph_as_dict
yield (session_id, json.dumps(session_dict, default=to_jsonable_python), flat_node_field_values)

# Serialize the session and field values
# Note the use of pydantic's to_jsonable_python to handle serialization of any python object, including sets.
session_json = json.dumps(session_dict, default=to_jsonable_python)
field_values_json = json.dumps(flat_node_field_values, default=to_jsonable_python)

# Yield the session_id, session_json, and field_values_json
yield (session_id, session_json, field_values_json)

# Increment the count so we know when to stop
count += 1


def calc_session_count(batch: Batch) -> int:
"""
Calculates the number of sessions that would be created by the batch, without incurring
the overhead of actually generating them. Adapted from `create_sessions()`.
Calculates the number of sessions that would be created by the batch, without incurring the overhead of actually
creating them, as is done in `create_session_nfv_tuples()`.
The count is used to communicate to the user how many sessions were _requested_ to be created, as opposed to how
many were _actually_ created (which may be less due to the maximum number of sessions).
"""
# TODO: Should this be a class method on Batch?
if not batch.data:
@@ -473,20 +558,63 @@ def calc_session_count(batch: Batch) -> int:
return len(data_product) * batch.runs


def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new_queue_items: int) -> list[tuple]:
values_to_insert: list[tuple] = []
ValueToInsertTuple: TypeAlias = tuple[
str, # queue_id
str, # session (as stringified JSON)
str, # session_id
str, # batch_id
str | None, # field_values (optional, as stringified JSON)
int, # priority
str | None, # workflow (optional, as stringified JSON)
str | None, # origin (optional)
str | None, # destination (optional)
str | None, # retried_from_item_id (optional, this is always None for new items)
]
"""A type alias for the tuple of values to insert into the session queue table."""


def prepare_values_to_insert(
queue_id: str, batch: Batch, priority: int, max_new_queue_items: int
) -> list[ValueToInsertTuple]:
"""
Given a batch, prepare the values to insert into the session queue table. The list of tuples can be used with an
`executemany` statement to insert multiple rows at once.
Args:
queue_id: The ID of the queue to insert the items into
batch: The batch to prepare the values for
priority: The priority of the queue items
max_new_queue_items: The maximum number of queue items to insert
Returns:
A list of tuples to insert into the session queue table. Each tuple contains the following values:
- queue_id
- session (as stringified JSON)
- session_id
- batch_id
- field_values (optional, as stringified JSON)
- priority
- workflow (optional, as stringified JSON)
- origin (optional)
- destination (optional)
- retried_from_item_id (optional, this is always None for new items)
"""

# A tuple is a fast and memory-efficient way to store the values to insert. Previously, we used a NamedTuple, but
# measured a ~5% performance improvement by using a normal tuple instead. For very large batches (10k+ items),
# this difference becomes noticeable.
#
# So, despite the inferior DX with normal tuples, we use one here for performance reasons.

values_to_insert: list[ValueToInsertTuple] = []

# pydantic's to_jsonable_python handles serialization of any python object, including sets, which json.dumps does
# not support by default. Apparently there are sets somewhere in the graph.
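# For example (illustrative only): json.dumps({"tags": {"a", "b"}}) raises a TypeError, while
# json.dumps({"tags": {"a", "b"}}, default=to_jsonable_python) serializes the set as a JSON array.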

# The same workflow is used for all sessions in the batch - serialize it once
workflow_json = json.dumps(batch.workflow, default=to_jsonable_python) if batch.workflow else None

for session_id, session_json, field_values in create_graph_nfv_tuples(batch, max_new_queue_items):
# As a perf optimization, we can mutate the session_dict in place. This is safe because we dump it to json
# as part of the tuple construction

field_values_json = json.dumps(field_values, default=to_jsonable_python) if field_values else None

for session_id, session_json, field_values_json in create_session_nfv_tuples(batch, max_new_queue_items):
values_to_insert.append(
(
queue_id,
12 changes: 6 additions & 6 deletions tests/test_session_queue.py
@@ -9,7 +9,7 @@
BatchDatum,
NodeFieldValue,
calc_session_count,
create_graph_nfv_tuples,
create_session_nfv_tuples,
prepare_values_to_insert,
)
from invokeai.app.services.shared.graph import Graph, GraphExecutionState
@@ -42,7 +42,7 @@ def batch_graph() -> Graph:

def test_create_sessions_from_batch_with_runs(batch_data_collection, batch_graph):
b = Batch(graph=batch_graph, data=batch_data_collection, runs=2)
t = list(create_graph_nfv_tuples(batch=b, maximum=1000))
t = list(create_session_nfv_tuples(batch=b, maximum=1000))
# 2 list[BatchDatum] * length 2 * 2 runs = 8
assert len(t) == 8

@@ -90,28 +90,28 @@ def test_create_sessions_from_batch_with_runs(batch_data_collection, batch_graph)

def test_create_sessions_from_batch_without_runs(batch_data_collection, batch_graph):
b = Batch(graph=batch_graph, data=batch_data_collection)
t = list(create_graph_nfv_tuples(batch=b, maximum=1000))
t = list(create_session_nfv_tuples(batch=b, maximum=1000))
# 2 list[BatchDatum] * length 2 * 1 runs = 4
assert len(t) == 4


def test_create_sessions_from_batch_without_batch(batch_graph):
b = Batch(graph=batch_graph, runs=2)
t = list(create_graph_nfv_tuples(batch=b, maximum=1000))
t = list(create_session_nfv_tuples(batch=b, maximum=1000))
# 2 runs
assert len(t) == 2


def test_create_sessions_from_batch_without_batch_or_runs(batch_graph):
b = Batch(graph=batch_graph)
t = list(create_graph_nfv_tuples(batch=b, maximum=1000))
t = list(create_session_nfv_tuples(batch=b, maximum=1000))
# 1 run
assert len(t) == 1


def test_create_sessions_from_batch_with_runs_and_max(batch_data_collection, batch_graph):
b = Batch(graph=batch_graph, data=batch_data_collection, runs=2)
t = list(create_graph_nfv_tuples(batch=b, maximum=5))
t = list(create_session_nfv_tuples(batch=b, maximum=5))
# 2 list[BatchDatum] * length 2 * 2 runs = 8, but max is 5
assert len(t) == 5

