diff --git a/mii/legacy/client.py b/mii/legacy/client.py index 3965f3b2..d90df050 100644 --- a/mii/legacy/client.py +++ b/mii/legacy/client.py @@ -134,6 +134,13 @@ def query(self, request_dict, **query_kwargs): elif self.task == TaskType.TEXT2IMG: args = (request_dict["prompt"], request_dict.get("negative_prompt", None)) kwargs = query_kwargs + elif self.task == TaskType.INPAINTING: + negative_prompt = request_dict.get("negative_prompt", None) + args = (request_dict["prompt"], + request_dict["image"], + request_dict["mask_image"], + negative_prompt) + kwargs = query_kwargs else: args = (request_dict["query"], ) kwargs = query_kwargs diff --git a/mii/legacy/constants.py b/mii/legacy/constants.py index d305a036..254adff8 100644 --- a/mii/legacy/constants.py +++ b/mii/legacy/constants.py @@ -20,6 +20,7 @@ class TaskType(str, Enum): CONVERSATIONAL = "conversational" TEXT2IMG = "text-to-image" ZERO_SHOT_IMAGE_CLASSIFICATION = "zero-shot-image-classification" + INPAINTING = "text-to-image-inpainting" class ModelProvider(str, Enum): @@ -60,6 +61,11 @@ class ModelProvider(str, Enum): TaskType.TEXT2IMG: ["prompt"], TaskType.ZERO_SHOT_IMAGE_CLASSIFICATION: ["image", "candidate_labels"], + TaskType.INPAINTING: [ + "prompt", + "image", + "mask_image", + ] } MII_CACHE_PATH = "MII_CACHE_PATH" diff --git a/mii/legacy/grpc_related/modelresponse_server.py b/mii/legacy/grpc_related/modelresponse_server.py index b9cfdeed..c6b3df4d 100644 --- a/mii/legacy/grpc_related/modelresponse_server.py +++ b/mii/legacy/grpc_related/modelresponse_server.py @@ -126,6 +126,9 @@ def ConversationalReply(self, request, context): def ZeroShotImgClassificationReply(self, request, context): return self._run_inference("ZeroShotImgClassificationReply", request) + def InpaintingReply(self, request, context): + return self._run_inference("InpaintingReply", request) + class AtomicCounter: def __init__(self, initial_value=0): diff --git a/mii/legacy/grpc_related/proto/legacymodelresponse.proto 
b/mii/legacy/grpc_related/proto/legacymodelresponse.proto index 5ed801bc..9f8c060c 100644 --- a/mii/legacy/grpc_related/proto/legacymodelresponse.proto +++ b/mii/legacy/grpc_related/proto/legacymodelresponse.proto @@ -35,6 +35,7 @@ service ModelResponse { rpc ConversationalReply(ConversationRequest) returns (ConversationReply) {} rpc Txt2ImgReply(Text2ImageRequest) returns (ImageReply) {} rpc ZeroShotImgClassificationReply (ZeroShotImgClassificationRequest) returns (SingleStringReply) {} + rpc InpaintingReply(InpaintingRequest) returns (ImageReply) {} } message Value { @@ -114,3 +115,11 @@ message ZeroShotImgClassificationRequest { repeated string candidate_labels = 2; map<string, Value> query_kwargs = 3; } + +message InpaintingRequest { + repeated string prompt = 1; + repeated bytes image = 2; + repeated bytes mask_image = 3; + repeated string negative_prompt = 4; + map<string, Value> query_kwargs = 5; +} diff --git a/mii/legacy/grpc_related/proto/legacymodelresponse_pb2.py b/mii/legacy/grpc_related/proto/legacymodelresponse_pb2.py index 2d153cde..591ef64d 100644 --- a/mii/legacy/grpc_related/proto/legacymodelresponse_pb2.py +++ b/mii/legacy/grpc_related/proto/legacymodelresponse_pb2.py @@ -4,7 +4,6 @@ # DeepSpeed Team # Generated by the protocol buffer compiler. DO NOT EDIT!
# source: legacymodelresponse.proto -# Protobuf Python Version: 4.25.0 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool @@ -17,7 +16,7 @@ from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x19legacymodelresponse.proto\x12\x13legacymodelresponse\x1a\x1bgoogle/protobuf/empty.proto\"_\n\x05Value\x12\x10\n\x06svalue\x18\x01 \x01(\tH\x00\x12\x10\n\x06ivalue\x18\x02 \x01(\x03H\x00\x12\x10\n\x06\x66value\x18\x03 \x01(\x02H\x00\x12\x10\n\x06\x62value\x18\x04 \x01(\x08H\x00\x42\x0e\n\x0coneof_values\"\x1f\n\tSessionID\x12\x12\n\nsession_id\x18\x01 \x01(\t\"\xc7\x01\n\x13SingleStringRequest\x12\x0f\n\x07request\x18\x01 \x01(\t\x12O\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x39.legacymodelresponse.SingleStringRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\xc5\x01\n\x12MultiStringRequest\x12\x0f\n\x07request\x18\x01 \x03(\t\x12N\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x38.legacymodelresponse.MultiStringRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"S\n\x11SingleStringReply\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"R\n\x10MultiStringReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"\xc5\x01\n\tQARequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontext\x18\x02 \x01(\t\x12\x45\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32/.legacymodelresponse.QARequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 
\x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\x94\x02\n\x13\x43onversationRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x17\n\x0f\x63onversation_id\x18\x02 \x01(\t\x12\x18\n\x10past_user_inputs\x18\x03 \x03(\t\x12\x1b\n\x13generated_responses\x18\x04 \x03(\t\x12O\n\x0cquery_kwargs\x18\x05 \x03(\x0b\x32\x39.legacymodelresponse.ConversationRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\x91\x01\n\x11\x43onversationReply\x12\x17\n\x0f\x63onversation_id\x18\x01 \x01(\t\x12\x18\n\x10past_user_inputs\x18\x02 \x03(\t\x12\x1b\n\x13generated_responses\x18\x03 \x03(\t\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\"}\n\nImageReply\x12\x0e\n\x06images\x18\x01 \x03(\x0c\x12\x1d\n\x15nsfw_content_detected\x18\x02 \x03(\x08\x12\x0c\n\x04mode\x18\x03 \x01(\t\x12\x0e\n\x06size_w\x18\x04 \x01(\x03\x12\x0e\n\x06size_h\x18\x05 \x01(\x03\x12\x12\n\ntime_taken\x18\x06 \x01(\x02\"\xdb\x01\n\x11Text2ImageRequest\x12\x0e\n\x06prompt\x18\x01 \x03(\t\x12\x17\n\x0fnegative_prompt\x18\x02 \x03(\t\x12M\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32\x37.legacymodelresponse.Text2ImageRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\xf9\x01\n ZeroShotImgClassificationRequest\x12\r\n\x05image\x18\x01 \x01(\t\x12\x18\n\x10\x63\x61ndidate_labels\x18\x02 \x03(\t\x12\\\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32\x46.legacymodelresponse.ZeroShotImgClassificationRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 
\x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\x32\xb7\x08\n\rModelResponse\x12=\n\tTerminate\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12I\n\rCreateSession\x12\x1e.legacymodelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12J\n\x0e\x44\x65stroySession\x12\x1e.legacymodelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12\x62\n\x0eGeneratorReply\x12\'.legacymodelresponse.MultiStringRequest\x1a%.legacymodelresponse.MultiStringReply\"\x00\x12i\n\x13\x43lassificationReply\x12(.legacymodelresponse.SingleStringRequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12\x62\n\x16QuestionAndAnswerReply\x12\x1e.legacymodelresponse.QARequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12\x63\n\rFillMaskReply\x12(.legacymodelresponse.SingleStringRequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12n\n\x18TokenClassificationReply\x12(.legacymodelresponse.SingleStringRequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12i\n\x13\x43onversationalReply\x12(.legacymodelresponse.ConversationRequest\x1a&.legacymodelresponse.ConversationReply\"\x00\x12Y\n\x0cTxt2ImgReply\x12&.legacymodelresponse.Text2ImageRequest\x1a\x1f.legacymodelresponse.ImageReply\"\x00\x12\x81\x01\n\x1eZeroShotImgClassificationReply\x12\x35.legacymodelresponse.ZeroShotImgClassificationRequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x62\x06proto3' + b'\n\x19legacymodelresponse.proto\x12\x13legacymodelresponse\x1a\x1bgoogle/protobuf/empty.proto\"_\n\x05Value\x12\x10\n\x06svalue\x18\x01 \x01(\tH\x00\x12\x10\n\x06ivalue\x18\x02 \x01(\x03H\x00\x12\x10\n\x06\x66value\x18\x03 \x01(\x02H\x00\x12\x10\n\x06\x62value\x18\x04 \x01(\x08H\x00\x42\x0e\n\x0coneof_values\"\x1f\n\tSessionID\x12\x12\n\nsession_id\x18\x01 \x01(\t\"\xc7\x01\n\x13SingleStringRequest\x12\x0f\n\x07request\x18\x01 \x01(\t\x12O\n\x0cquery_kwargs\x18\x02 
\x03(\x0b\x32\x39.legacymodelresponse.SingleStringRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\xc5\x01\n\x12MultiStringRequest\x12\x0f\n\x07request\x18\x01 \x03(\t\x12N\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x38.legacymodelresponse.MultiStringRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"S\n\x11SingleStringReply\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"R\n\x10MultiStringReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"\xc5\x01\n\tQARequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontext\x18\x02 \x01(\t\x12\x45\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32/.legacymodelresponse.QARequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\x94\x02\n\x13\x43onversationRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x17\n\x0f\x63onversation_id\x18\x02 \x01(\t\x12\x18\n\x10past_user_inputs\x18\x03 \x03(\t\x12\x1b\n\x13generated_responses\x18\x04 \x03(\t\x12O\n\x0cquery_kwargs\x18\x05 \x03(\x0b\x32\x39.legacymodelresponse.ConversationRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\x91\x01\n\x11\x43onversationReply\x12\x17\n\x0f\x63onversation_id\x18\x01 \x01(\t\x12\x18\n\x10past_user_inputs\x18\x02 \x03(\t\x12\x1b\n\x13generated_responses\x18\x03 \x03(\t\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\"}\n\nImageReply\x12\x0e\n\x06images\x18\x01 
\x03(\x0c\x12\x1d\n\x15nsfw_content_detected\x18\x02 \x03(\x08\x12\x0c\n\x04mode\x18\x03 \x01(\t\x12\x0e\n\x06size_w\x18\x04 \x01(\x03\x12\x0e\n\x06size_h\x18\x05 \x01(\x03\x12\x12\n\ntime_taken\x18\x06 \x01(\x02\"\xdb\x01\n\x11Text2ImageRequest\x12\x0e\n\x06prompt\x18\x01 \x03(\t\x12\x17\n\x0fnegative_prompt\x18\x02 \x03(\t\x12M\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32\x37.legacymodelresponse.Text2ImageRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\xf9\x01\n ZeroShotImgClassificationRequest\x12\r\n\x05image\x18\x01 \x01(\t\x12\x18\n\x10\x63\x61ndidate_labels\x18\x02 \x03(\t\x12\\\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32\x46.legacymodelresponse.ZeroShotImgClassificationRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\"\xfe\x01\n\x11InpaintingRequest\x12\x0e\n\x06prompt\x18\x01 \x03(\t\x12\r\n\x05image\x18\x02 \x03(\x0c\x12\x12\n\nmask_image\x18\x03 \x03(\x0c\x12\x17\n\x0fnegative_prompt\x18\x04 \x03(\t\x12M\n\x0cquery_kwargs\x18\x05 \x03(\x0b\x32\x37.legacymodelresponse.InpaintingRequest.QueryKwargsEntry\x1aN\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 
\x01(\x0b\x32\x1a.legacymodelresponse.Value:\x02\x38\x01\x32\x95\t\n\rModelResponse\x12=\n\tTerminate\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12I\n\rCreateSession\x12\x1e.legacymodelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12J\n\x0e\x44\x65stroySession\x12\x1e.legacymodelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12\x62\n\x0eGeneratorReply\x12\'.legacymodelresponse.MultiStringRequest\x1a%.legacymodelresponse.MultiStringReply\"\x00\x12i\n\x13\x43lassificationReply\x12(.legacymodelresponse.SingleStringRequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12\x62\n\x16QuestionAndAnswerReply\x12\x1e.legacymodelresponse.QARequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12\x63\n\rFillMaskReply\x12(.legacymodelresponse.SingleStringRequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12n\n\x18TokenClassificationReply\x12(.legacymodelresponse.SingleStringRequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12i\n\x13\x43onversationalReply\x12(.legacymodelresponse.ConversationRequest\x1a&.legacymodelresponse.ConversationReply\"\x00\x12Y\n\x0cTxt2ImgReply\x12&.legacymodelresponse.Text2ImageRequest\x1a\x1f.legacymodelresponse.ImageReply\"\x00\x12\x81\x01\n\x1eZeroShotImgClassificationReply\x12\x35.legacymodelresponse.ZeroShotImgClassificationRequest\x1a&.legacymodelresponse.SingleStringReply\"\x00\x12\\\n\x0fInpaintingReply\x12&.legacymodelresponse.InpaintingRequest\x1a\x1f.legacymodelresponse.ImageReply\"\x00\x62\x06proto3' ) _globals = globals() @@ -38,6 +37,8 @@ _globals['_ZEROSHOTIMGCLASSIFICATIONREQUEST_QUERYKWARGSENTRY']._options = None _globals[ '_ZEROSHOTIMGCLASSIFICATIONREQUEST_QUERYKWARGSENTRY']._serialized_options = b'8\001' + _globals['_INPAINTINGREQUEST_QUERYKWARGSENTRY']._options = None + _globals['_INPAINTINGREQUEST_QUERYKWARGSENTRY']._serialized_options = b'8\001' _globals['_VALUE']._serialized_start = 79 _globals['_VALUE']._serialized_end = 174 _globals['_SESSIONID']._serialized_start 
= 176 @@ -75,6 +76,10 @@ _globals[ '_ZEROSHOTIMGCLASSIFICATIONREQUEST_QUERYKWARGSENTRY']._serialized_start = 331 _globals['_ZEROSHOTIMGCLASSIFICATIONREQUEST_QUERYKWARGSENTRY']._serialized_end = 409 - _globals['_MODELRESPONSE']._serialized_start = 2009 - _globals['_MODELRESPONSE']._serialized_end = 3088 + _globals['_INPAINTINGREQUEST']._serialized_start = 2009 + _globals['_INPAINTINGREQUEST']._serialized_end = 2263 + _globals['_INPAINTINGREQUEST_QUERYKWARGSENTRY']._serialized_start = 331 + _globals['_INPAINTINGREQUEST_QUERYKWARGSENTRY']._serialized_end = 409 + _globals['_MODELRESPONSE']._serialized_start = 2266 + _globals['_MODELRESPONSE']._serialized_end = 3439 # @@protoc_insertion_point(module_scope) diff --git a/mii/legacy/grpc_related/proto/legacymodelresponse_pb2_grpc.py b/mii/legacy/grpc_related/proto/legacymodelresponse_pb2_grpc.py index 4807d59f..f3747b78 100644 --- a/mii/legacy/grpc_related/proto/legacymodelresponse_pb2_grpc.py +++ b/mii/legacy/grpc_related/proto/legacymodelresponse_pb2_grpc.py @@ -81,6 +81,12 @@ def __init__(self, channel): SerializeToString, response_deserializer=legacymodelresponse__pb2.SingleStringReply.FromString, ) + self.InpaintingReply = channel.unary_unary( + '/legacymodelresponse.ModelResponse/InpaintingReply', + request_serializer=legacymodelresponse__pb2.InpaintingRequest. 
+ SerializeToString, + response_deserializer=legacymodelresponse__pb2.ImageReply.FromString, + ) class ModelResponseServicer(object): @@ -151,6 +157,12 @@ def ZeroShotImgClassificationReply(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def InpaintingReply(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def add_ModelResponseServicer_to_server(servicer, server): rpc_method_handlers = { @@ -231,6 +243,12 @@ def add_ModelResponseServicer_to_server(servicer, server): response_serializer=legacymodelresponse__pb2.SingleStringReply. SerializeToString, ), + 'InpaintingReply': + grpc.unary_unary_rpc_method_handler( + servicer.InpaintingReply, + request_deserializer=legacymodelresponse__pb2.InpaintingRequest.FromString, + response_serializer=legacymodelresponse__pb2.ImageReply.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( 'legacymodelresponse.ModelResponse', @@ -526,3 +544,29 @@ def ZeroShotImgClassificationReply(request, wait_for_ready, timeout, metadata) + + @staticmethod + def InpaintingReply(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/legacymodelresponse.ModelResponse/InpaintingReply', + legacymodelresponse__pb2.InpaintingRequest.SerializeToString, + legacymodelresponse__pb2.ImageReply.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata) diff --git a/mii/legacy/method_table.py b/mii/legacy/method_table.py index 0367efe0..145d6092 100644 --- a/mii/legacy/method_table.py +++ 
b/mii/legacy/method_table.py @@ -9,7 +9,7 @@ from mii.legacy.constants import TaskType from mii.legacy.grpc_related.proto import legacymodelresponse_pb2 as modelresponse_pb2 from mii.legacy.utils import kwarg_dict_to_proto, unpack_proto_query_kwargs -from mii.legacy.models.utils import ImageResponse +from mii.legacy.models.utils import ImageResponse, convert_bytes_to_pil_image def single_string_request_to_proto(self, request_dict, **query_kwargs): @@ -312,6 +312,51 @@ def run_inference(self, inference_pipeline, args, kwargs): return inference_pipeline(image, candidate_labels=candidate_labels, **kwargs) +class InpaintingMethods(Text2ImgMethods): + @property + def method(self): + return "InpaintingReply" + + def run_inference(self, inference_pipeline, args, kwargs): + prompt, image, mask_image, negative_prompt = args + return inference_pipeline(prompt=prompt, + image=image, + mask_image=mask_image, + negative_prompt=negative_prompt, + **kwargs) + + def pack_request_to_proto(self, request_dict, **query_kwargs): + prompt = request_dict["prompt"] + prompt = prompt if isinstance(prompt, list) else [prompt] + negative_prompt = request_dict.get("negative_prompt", [""] * len(prompt)) + negative_prompt = negative_prompt if isinstance(negative_prompt, + list) else [negative_prompt] + image = request_dict["image"] if isinstance(request_dict["image"], + list) else [request_dict["image"]] + mask_image = request_dict["mask_image"] if isinstance( + request_dict["mask_image"], + list) else [request_dict["mask_image"]] + + return modelresponse_pb2.InpaintingRequest( + prompt=prompt, + image=image, + mask_image=mask_image, + negative_prompt=negative_prompt, + query_kwargs=kwarg_dict_to_proto(query_kwargs), + ) + + def unpack_request_from_proto(self, request): + kwargs = unpack_proto_query_kwargs(request.query_kwargs) + + image = [convert_bytes_to_pil_image(img) for img in request.image] + mask_image = [ + convert_bytes_to_pil_image(mask_image) for mask_image in request.mask_image + ] 
+ + args = (list(request.prompt), image, mask_image, list(request.negative_prompt)) + return args, kwargs + + GRPC_METHOD_TABLE = { TaskType.TEXT_GENERATION: TextGenerationMethods(), TaskType.TEXT_CLASSIFICATION: TextClassificationMethods(), @@ -321,4 +366,5 @@ def run_inference(self, inference_pipeline, args, kwargs): TaskType.CONVERSATIONAL: ConversationalMethods(), TaskType.TEXT2IMG: Text2ImgMethods(), TaskType.ZERO_SHOT_IMAGE_CLASSIFICATION: ZeroShotImgClassificationMethods(), + TaskType.INPAINTING: InpaintingMethods(), } diff --git a/mii/legacy/models/utils.py b/mii/legacy/models/utils.py index 5298745e..f9b10769 100644 --- a/mii/legacy/models/utils.py +++ b/mii/legacy/models/utils.py @@ -3,6 +3,7 @@ # DeepSpeed Team import os +import io from mii.legacy.utils import mii_cache_path @@ -59,3 +60,13 @@ def images(self): images.append(img) self._deserialized_images = images return self._deserialized_images + + +def convert_bytes_to_pil_image(image_bytes: bytes): + """Converts bytes to a PIL Image object.""" + if not isinstance(image_bytes, bytes): + return image_bytes + + from PIL import Image + image = Image.open(io.BytesIO(image_bytes)) + return image diff --git a/mii/legacy/utils.py b/mii/legacy/utils.py index 64b7f16c..062a9524 100644 --- a/mii/legacy/utils.py +++ b/mii/legacy/utils.py @@ -190,7 +190,7 @@ def get_num_gpus(mii_config): def get_provider(model_name, task): if model_name == "gpt-neox": provider = ModelProvider.ELEUTHER_AI - elif task == TaskType.TEXT2IMG: + elif task in [TaskType.TEXT2IMG, TaskType.INPAINTING]: provider = ModelProvider.DIFFUSERS else: provider = ModelProvider.HUGGING_FACE