diff --git a/invokeai/app/services/session_queue/session_queue_common.py b/invokeai/app/services/session_queue/session_queue_common.py index 87d026f51ff..009691b9f1c 100644 --- a/invokeai/app/services/session_queue/session_queue_common.py +++ b/invokeai/app/services/session_queue/session_queue_common.py @@ -406,58 +406,143 @@ class IsFullResult(BaseModel): # region Util -def create_graph_nfv_tuples(batch: Batch, maximum: int) -> Generator[tuple[str, str, list[dict]], None, None]: +def create_session_nfv_tuples(batch: Batch, maximum: int) -> Generator[tuple[str, str, str], None, None]: """ - Create all graph permutations from the given batch data and graph. Yields tuples - of the form (graph, batch_data_items) where batch_data_items is the list of BatchDataItems - that was applied to the graph. + Given a batch and a maximum number of sessions to create, generate a tuple of session_id, session_json, and + field_values_json for each session. + + The batch has a "source" graph and a data property. The data property is a list of lists of BatchDatum objects. + Each BatchDatum has a field identifier (e.g. a node id and field name), and a list of values to substitute into + the field. + + This structure allows us to create a new graph for every possible permutation of BatchDatum objects: + - Each BatchDatum can be "expanded" into a dict of node-field-value tuples - one for each item in the BatchDatum. + - Zip each inner list of expanded BatchDatum objects together. Call this a "batch_data_list". + - Take the cartesian product of all zipped batch_data_lists, resulting in a list of permutations of BatchDatum + - Take the cartesian product of all zipped batch_data_lists, resulting in a list of lists of BatchDatum objects. + Each inner list now represents the substitution values for a single permutation (session). 
+ - For each permutation, substitute the values into the graph + + This function is optimized for performance, as it is used to generate a large number of sessions at once. + + Args: + batch: The batch to generate sessions from + maximum: The maximum number of sessions to generate + + Returns: + A generator that yields tuples of session_id, session_json, and field_values_json for each session. The + generator will stop early if the maximum number of sessions is reached. """ # TODO: Should this be a class method on Batch? data: list[list[tuple[dict]]] = [] batch_data_collection = batch.data if batch.data is not None else [] - graph_as_dict = batch.graph.model_dump(warnings=False, exclude_none=True) - session_dict = GraphExecutionState(graph=Graph()).model_dump(warnings=False, exclude_none=True) for batch_datum_list in batch_data_collection: - # each batch_datum_list needs to be convered to NodeFieldValues and then zipped - node_field_values_to_zip: list[list[dict]] = [] + # Expand each BatchDatum into a list of dicts - one for each item in the BatchDatum for batch_datum in batch_datum_list: node_field_values = [ + # Note: A tuple here is slightly faster than a dict, but we need the object in dict form to be inserted + # in the session_queue table anyways. So, overall creating NFVs as dicts is faster. {"node_path": batch_datum.node_path, "field_name": batch_datum.field_name, "value": item} for item in batch_datum.items ] node_field_values_to_zip.append(node_field_values) + # Zip the dicts together to create a list of dicts for each permutation data.append(list(zip(*node_field_values_to_zip, strict=True))) # type: ignore [arg-type] - # create generator to yield session,nfv tuples + # We serialize the graph and session once, then mutate the graph dict in place for each session. + # + # This sounds scary, but it's actually fine. + # + # The batch prep logic injects field values into the same fields for each generated session. 
+ # + # For example, after the product operation, we'll end up with a list of node-field-value tuples like this: + # [ + # ( + # {"node_path": "1", "field_name": "a", "value": 1}, + # {"node_path": "2", "field_name": "b", "value": 2}, + # {"node_path": "3", "field_name": "c", "value": 3}, + # ), + # ( + # {"node_path": "1", "field_name": "a", "value": 4}, + # {"node_path": "2", "field_name": "b", "value": 5}, + # {"node_path": "3", "field_name": "c", "value": 6}, + # ) + # ] + # + # Note that each tuple has the same length, and each tuple substitutes values in for exactly the same node fields. + # No matter the complexity of the batch, this property holds true. + # + # This means each permutation's substitution can be done in-place on the same graph dict, because it overwrites the + # previous mutation. We only need to serialize the graph once, and then we can mutate it in place for each session. + # + # Previously, we had created new Graph objects for each session, but this was very slow for large (1k+ session + # batches). We then tried dumping the graph to dict and using deep-copy to create a new dict for each session, + # but this was also slow. + # + # Overall, we achieved a 100x speedup by mutating the graph dict in place for each session over creating new Graph + # objects for each session. + # + # We will also mutate the session dict in place, setting a new ID for each session and setting the mutated graph + # dict as the session's graph. + + # Dump the batch's graph to a dict once + graph_as_dict = batch.graph.model_dump(warnings=False, exclude_none=True) + + # We must provide a Graph object when creating the "dummy" session dict, but we don't actually use it. It will be + # overwritten for each session by the mutated graph_as_dict. + session_dict = GraphExecutionState(graph=Graph()).model_dump(warnings=False, exclude_none=True) + + # Now we can create a generator that yields the session_id, session_json, and field_values_json for each session. 
count = 0 + + # Each batch may have multiple runs, so we need to generate the same number of sessions for each run. The total is + # still limited by the maximum number of sessions. for _ in range(batch.runs): for d in product(*data): if count >= maximum: + # We've reached the maximum number of sessions we may generate return + + # Flatten the list of lists of dicts into a single list of dicts + # TODO(psyche): Is there a more efficient way to do this? flat_node_field_values = list(chain.from_iterable(d)) - # The fields that are injected for each the same for all graphs. Therefore, we can mutate the graph dict - # in place and then serialize it to json for each session. It's functionally the same as creating a new - # graph dict for each session, but is more efficient. + # Need a fresh ID for each session session_id = uuid_string() + + # Mutate the session dict in place session_dict["id"] = session_id - for item in flat_node_field_values: - graph_as_dict["nodes"][item["node_path"]][item["field_name"]] = item["value"] + # Substitute the values into the graph + for nfv in flat_node_field_values: + graph_as_dict["nodes"][nfv["node_path"]][nfv["field_name"]] = nfv["value"] + # Mutate the session dict in place session_dict["graph"] = graph_as_dict - yield (session_id, json.dumps(session_dict, default=to_jsonable_python), flat_node_field_values) + + # Serialize the session and field values + # Note the use of pydantic's to_jsonable_python to handle serialization of any python object, including sets.
+ session_json = json.dumps(session_dict, default=to_jsonable_python) + field_values_json = json.dumps(flat_node_field_values, default=to_jsonable_python) + + # Yield the session_id, session_json, and field_values_json + yield (session_id, session_json, field_values_json) + + # Increment the count so we know when to stop count += 1 def calc_session_count(batch: Batch) -> int: """ - Calculates the number of sessions that would be created by the batch, without incurring - the overhead of actually generating them. Adapted from `create_sessions(). + Calculates the number of sessions that would be created by the batch, without incurring the overhead of actually + creating them, as is done in `create_session_nfv_tuples()`. + + The count is used to communicate to the user how many sessions were _requested_ to be created, as opposed to how + many were _actually_ created (which may be less due to the maximum number of sessions). """ # TODO: Should this be a class method on Batch? if not batch.data: @@ -473,20 +558,63 @@ def calc_session_count(batch: Batch) -> int: return len(data_product) * batch.runs -def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new_queue_items: int) -> list[tuple]: - values_to_insert: list[tuple] = [] +ValueToInsertTuple: TypeAlias = tuple[ + str, # queue_id + str, # session (as stringified JSON) + str, # session_id + str, # batch_id + str | None, # field_values (optional, as stringified JSON) + int, # priority + str | None, # workflow (optional, as stringified JSON) + str | None, # origin (optional) + str | None, # destination (optional) + str | None, # retried_from_item_id (optional, this is always None for new items) +] +"""A type alias for the tuple of values to insert into the session queue table.""" + + +def prepare_values_to_insert( + queue_id: str, batch: Batch, priority: int, max_new_queue_items: int +) -> list[ValueToInsertTuple]: + """ + Given a batch, prepare the values to insert into the session queue table. 
The list of tuples can be used with an + `executemany` statement to insert multiple rows at once. + + Args: + queue_id: The ID of the queue to insert the items into + batch: The batch to prepare the values for + priority: The priority of the queue items + max_new_queue_items: The maximum number of queue items to insert + + Returns: + A list of tuples to insert into the session queue table. Each tuple contains the following values: + - queue_id + - session (as stringified JSON) + - session_id + - batch_id + - field_values (optional, as stringified JSON) + - priority + - workflow (optional, as stringified JSON) + - origin (optional) + - destination (optional) + - retried_from_item_id (optional, this is always None for new items) + """ + + # A tuple is a fast and memory-efficient way to store the values to insert. Previously, we used a NamedTuple, but + # measured a ~5% performance improvement by using a normal tuple instead. For very large batches (10k+ items), + # this difference becomes noticeable. + # + # So, despite the inferior DX with normal tuples, we use one here for performance reasons. + + values_to_insert: list[ValueToInsertTuple] = [] + # pydantic's to_jsonable_python handles serialization of any python object, including sets, which json.dumps does # not support by default. Apparently there are sets somewhere in the graph. # The same workflow is used for all sessions in the batch - serialize it once workflow_json = json.dumps(batch.workflow, default=to_jsonable_python) if batch.workflow else None - for session_id, session_json, field_values in create_graph_nfv_tuples(batch, max_new_queue_items): - # As a perf optimization, we can mutate the session_dict in place.
This is safe because we dump it to json - # as part of the tuple construction - - field_values_json = json.dumps(field_values, default=to_jsonable_python) if field_values else None - + for session_id, session_json, field_values_json in create_session_nfv_tuples(batch, max_new_queue_items): values_to_insert.append( ( queue_id, diff --git a/tests/test_session_queue.py b/tests/test_session_queue.py index 13c5634a192..a3f0b8c803a 100644 --- a/tests/test_session_queue.py +++ b/tests/test_session_queue.py @@ -9,7 +9,7 @@ BatchDatum, NodeFieldValue, calc_session_count, - create_graph_nfv_tuples, + create_session_nfv_tuples, prepare_values_to_insert, ) from invokeai.app.services.shared.graph import Graph, GraphExecutionState @@ -42,7 +42,7 @@ def batch_graph() -> Graph: def test_create_sessions_from_batch_with_runs(batch_data_collection, batch_graph): b = Batch(graph=batch_graph, data=batch_data_collection, runs=2) - t = list(create_graph_nfv_tuples(batch=b, maximum=1000)) + t = list(create_session_nfv_tuples(batch=b, maximum=1000)) # 2 list[BatchDatum] * length 2 * 2 runs = 8 assert len(t) == 8 @@ -90,28 +90,28 @@ def test_create_sessions_from_batch_with_runs(batch_data_collection, batch_graph def test_create_sessions_from_batch_without_runs(batch_data_collection, batch_graph): b = Batch(graph=batch_graph, data=batch_data_collection) - t = list(create_graph_nfv_tuples(batch=b, maximum=1000)) + t = list(create_session_nfv_tuples(batch=b, maximum=1000)) # 2 list[BatchDatum] * length 2 * 1 runs = 8 assert len(t) == 4 def test_create_sessions_from_batch_without_batch(batch_graph): b = Batch(graph=batch_graph, runs=2) - t = list(create_graph_nfv_tuples(batch=b, maximum=1000)) + t = list(create_session_nfv_tuples(batch=b, maximum=1000)) # 2 runs assert len(t) == 2 def test_create_sessions_from_batch_without_batch_or_runs(batch_graph): b = Batch(graph=batch_graph) - t = list(create_graph_nfv_tuples(batch=b, maximum=1000)) + t = list(create_session_nfv_tuples(batch=b, 
maximum=1000)) # 1 run assert len(t) == 1 def test_create_sessions_from_batch_with_runs_and_max(batch_data_collection, batch_graph): b = Batch(graph=batch_graph, data=batch_data_collection, runs=2) - t = list(create_graph_nfv_tuples(batch=b, maximum=5)) + t = list(create_session_nfv_tuples(batch=b, maximum=5)) # 2 list[BatchDatum] * length 2 * 2 runs = 8, but max is 5 assert len(t) == 5