Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added iOS Text Searcher and Basic Objective C Test #898

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions tensorflow_lite_support/ios/task/text/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package(
default_visibility = ["//tensorflow_lite_support:internal"],
licenses = ["notice"], # Apache 2.0
)

objc_library(
name = "TFLTextSearcher",
srcs = [
"sources/TFLTextSearcher.mm",
],
hdrs = [
"sources/TFLTextSearcher.h",
],
copts = [
"-ObjC++",
"-std=c++17",
],
features = ["-layering_check"],
module_name = "TFLTextSearcher",
deps = [
"//tensorflow_lite_support/cc/task/text:text_searcher",
"//tensorflow_lite_support/ios:TFLCommonUtils",
"//tensorflow_lite_support/ios/task/core:TFLBaseOptionsCppHelpers",
"//tensorflow_lite_support/ios/task/processor:TFLEmbeddingOptionsHelpers",
"//tensorflow_lite_support/ios/task/processor:TFLSearchOptionsHelpers",
"//tensorflow_lite_support/ios/task/processor:TFLSearchResultHelpers",
],
)
104 changes: 104 additions & 0 deletions tensorflow_lite_support/ios/task/text/sources/TFLTextSearcher.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#import <Foundation/Foundation.h>

#import "tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.h"
#import "tensorflow_lite_support/ios/task/processor/sources/TFLEmbeddingOptions.h"
#import "tensorflow_lite_support/ios/task/processor/sources/TFLSearchOptions.h"
#import "tensorflow_lite_support/ios/task/processor/sources/TFLSearchResult.h"

NS_ASSUME_NONNULL_BEGIN

/**
* Options to configure TFLTextSearcher.
*/
NS_SWIFT_NAME(TextSearcherOptions)
@interface TFLTextSearcherOptions : NSObject

/**
* Base options for configuring the TextSearcher. This specifies the TFLite
* model to use for embedding extraction, as well as hardware acceleration
* options to use as inference time.
*/
@property(nonatomic, copy) TFLBaseOptions *baseOptions;

/**
* Options controlling the behavior of the embedding model specified in the
* base options.
*/
@property(nonatomic, copy) TFLEmbeddingOptions *embeddingOptions;

/**
* Options specifying the index to search into and controlling the search behavior.
*/
@property(nonatomic, copy) TFLSearchOptions *searchOptions;

/**
* Initializes a new `TFLTextSearcherOptions` with the absolute path to the model file
* stored locally on the device, set to the given the model path.
*
* @discussion The external model file must be a single standalone TFLite file. It could be packed
* with TFLite Model Metadata[1] and associated files if they exist. Failure to provide the
* necessary metadata and associated files might result in errors. Check the [documentation]
* (https://www.tensorflow.org/lite/convert/metadata) for each task about the specific requirement.
*
* @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device.
*
* @return An instance of `TFLTextSearcherOptions` initialized to the given model path.
*/
- (instancetype)initWithModelPath:(NSString *)modelPath;

@end

/**
* A TensorFlow Lite Task Text Searcher.
*/
NS_SWIFT_NAME(TextSearcher)
@interface TFLTextSearcher : NSObject

/**
* Creates a new instance of `TFLTextSearcher` from the given `TFLTextSearcherOptions`.
*
* @param options The options to use for configuring the `TFLTextSearcher`.
* @param error An optional error parameter populated when there is an error in initializing
* the text searcher.
*
* @return A new instance of `TextSearcher` with the given options. `nil` if there is an error
* in initializing the text searcher.
*/
+ (nullable instancetype)textSearcherWithOptions:(TFLTextSearcherOptions *)options
error:(NSError **)error
NS_SWIFT_NAME(searcher(options:));

+ (instancetype)new NS_UNAVAILABLE;

/**
* Performs embedding extraction on the given text, followed by nearest-neighbor search in the
* index.
*
* @param text An string on which embedding extraction is to be performed, followed by
* nearest-neighbor search in the index.
*
* @return A `TFLSearchResult`. `nil` if there is an error encountered during embedding extraction
* and nearest neighbor search. Please see `TFLSearchResult` for more details.
*/
- (nullable TFLSearchResult *)searchWithText:(NSString *)text
error:(NSError **)error NS_SWIFT_NAME(search(text:));

- (instancetype)init NS_UNAVAILABLE;

@end

NS_ASSUME_NONNULL_END
114 changes: 114 additions & 0 deletions tensorflow_lite_support/ios/task/text/sources/TFLTextSearcher.mm
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#import "tensorflow_lite_support/ios/task/text/sources/TFLTextSearcher.h"
#import "tensorflow_lite_support/ios/sources/TFLCommon.h"
#import "tensorflow_lite_support/ios/sources/TFLCommonUtils.h"
#import "tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+CppHelpers.h"
#import "tensorflow_lite_support/ios/task/processor/sources/TFLEmbeddingOptions+Helpers.h"
#import "tensorflow_lite_support/ios/task/processor/sources/TFLSearchOptions+Helpers.h"
#import "tensorflow_lite_support/ios/task/processor/sources/TFLSearchResult+Helpers.h"

#include "tensorflow_lite_support/cc/task/text/text_searcher.h"

namespace {
using TextSearcherCpp = ::tflite::task::text::TextSearcher;
using TextSearcherOptionsCpp = ::tflite::task::text::TextSearcherOptions;
using SearchResultCpp = ::tflite::task::processor::SearchResult;
using ::tflite::support::StatusOr;
} // namespace

@interface TFLTextSearcher () {
/** TextSearcher backed by C++ API */
std::unique_ptr<TextSearcherCpp> _cppTextSearcher;
}
@end

@implementation TFLTextSearcherOptions

- (instancetype)init {
self = [super init];
if (self) {
_baseOptions = [[TFLBaseOptions alloc] init];
_embeddingOptions = [[TFLEmbeddingOptions alloc] init];
_searchOptions = [[TFLSearchOptions alloc] init];
}
return self;
}

- (instancetype)initWithModelPath:(NSString *)modelPath {
self = [self init];
if (self) {
_baseOptions.modelFile.filePath = modelPath;
}
return self;
}

- (TextSearcherOptionsCpp)cppOptions {
TextSearcherOptionsCpp cppOptions = {};
[self.baseOptions copyToCppOptions:cppOptions.mutable_base_options()];
[self.embeddingOptions copyToCppOptions:cppOptions.mutable_embedding_options()];
[self.searchOptions copyToCppOptions:cppOptions.mutable_search_options()];

return cppOptions;
}

@end

@implementation TFLTextSearcher

- (nullable instancetype)initWithCppTextSearcherOptions:(TextSearcherOptionsCpp)cppOptions {
self = [super init];
if (self) {
StatusOr<std::unique_ptr<TextSearcherCpp>> cppTextSearcher =
TextSearcherCpp::CreateFromOptions(cppOptions);
if (cppTextSearcher.ok()) {
_cppTextSearcher = std::move(cppTextSearcher.value());
} else {
return nil;
}
}
return self;
}

+ (nullable instancetype)textSearcherWithOptions:(TFLTextSearcherOptions *)options
error:(NSError **)error {
if (!options) {
[TFLCommonUtils createCustomError:error
withCode:TFLSupportErrorCodeInvalidArgumentError
description:@"TFLTextSearcherOptions argument cannot be nil."];
return nil;
}

TextSearcherOptionsCpp cppOptions = [options cppOptions];

return [[TFLTextSearcher alloc] initWithCppTextSearcherOptions:cppOptions];
}

- (nullable TFLSearchResult *)searchWithText:(NSString *)text error:(NSError **)error {
if (!text) {
[TFLCommonUtils createCustomError:error
withCode:TFLSupportErrorCodeInvalidArgumentError
description:@"GMLImage argument cannot be nil."];
return nil;
}

std::string cppTextToBeSearched = std::string(text.UTF8String, [text lengthOfBytesUsingEncoding:NSUTF8StringEncoding]);
StatusOr<SearchResultCpp> cppSearchResultStatus = _cppTextSearcher->Search(
cppTextToBeSearched);

return [TFLSearchResult searchResultWithCppResult:cppSearchResultStatus error:error];
}

@end
31 changes: 31 additions & 0 deletions tensorflow_lite_support/ios/test/task/text/text_searcher/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
load("@org_tensorflow//tensorflow/lite/ios:ios.bzl", "TFL_DEFAULT_TAGS", "TFL_DISABLED_SANITIZER_TAGS", "TFL_MINIMUM_OS_VERSION")
load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test")
load("@org_tensorflow//tensorflow/lite:special_rules.bzl", "tflite_ios_lab_runner")

package(
default_visibility = ["//visibility:private"],
licenses = ["notice"], # Apache 2.0
)

objc_library(
name = "TFLTextSearcherObjcTestLibrary",
testonly = 1,
srcs = ["TFLTextSearcherTests.m"],
data = [
"//tensorflow_lite_support/cc/test/testdata/task/text:test_searchers",
],
tags = TFL_DEFAULT_TAGS,
deps = [
"//tensorflow_lite_support/ios/task/text:TFLTextSearcher",
],
)

ios_unit_test(
name = "TFLTextSearcherObjcTest",
minimum_os_version = TFL_MINIMUM_OS_VERSION,
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
deps = [
":TFLTextSearcherObjcTestLibrary",
],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#import <XCTest/XCTest.h>

#import "tensorflow_lite_support/ios/task/text/sources/TFLTextSearcher.h"

NS_ASSUME_NONNULL_BEGIN

#define ValidateSearchResultCount(searchResult, expectedNearestNeighborsCount) \
XCTAssertEqual(searchResult.nearestNeighbors.count, expectedNearestNeighborsCount);

#define ValidateNearestNeighbor(nearestNeighbor, expectedMetadata, expectedDistance) \
XCTAssertEqualObjects(nearestNeighbor.metadata, expectedMetadata); \
XCTAssertEqualWithAccuracy(nearestNeighbor.distance, expectedDistance, 1e-6);

@interface TFLTextSearcherTests : XCTestCase
@property(nonatomic, nullable) NSString *modelPath;
@end

@implementation TFLTextSearcherTests

- (void)setUp {
[super setUp];
self.modelPath =
[[NSBundle bundleForClass:self.class] pathForResource:@"regex_searcher"
ofType:@"tflite"];
XCTAssertNotNil(self.modelPath);
}

- (TFLTextSearcher *)textSearcherWithSearcherModelPath:(NSString *)modelPath {
TFLTextSearcherOptions *textSearcherOptions =
[[TFLTextSearcherOptions alloc] initWithModelPath:self.modelPath];

TFLTextSearcher *textSearcher = [TFLTextSearcher textSearcherWithOptions:textSearcherOptions
error:nil];
XCTAssertNotNil(textSearcher);

return textSearcher;
}

- (void)validateSearchResultForInferenceWithSearchContent:(TFLSearchResult *)searchResult {
ValidateSearchResultCount(searchResult,
5 // expectedNearestNeighborsCount
);

ValidateNearestNeighbor(searchResult.nearestNeighbors[0],
@"The weather was excellent.", // expectedMetadata
0.889664649963 // expectedDistance
);
ValidateNearestNeighbor(searchResult.nearestNeighbors[1],
@"The sun was shining on that day.", // expectedMetadata
0.889667928219 // expectedDistance
);
ValidateNearestNeighbor(searchResult.nearestNeighbors[2],
@"The cat is chasing after the mouse.", // expectedMetadata
0.889669716358 // expectedDistance
);
ValidateNearestNeighbor(searchResult.nearestNeighbors[3],
@"It was a sunny day.", // expectedMetadata
0.889671087265 // expectedDistance
);
ValidateNearestNeighbor(searchResult.nearestNeighbors[4],
@"He was very happy with his newly bought car.", // expectedMetadata
0.889671683311 // expectedDistance
);
}

- (void)testSuccessfullInferenceWithSearchContentOnText {
TFLTextSearcher *textSearcher =
[self textSearcherWithSearcherModelPath:self.modelPath];

TFLSearchResult *searchResult = [textSearcher searchWithText:@"The weather was excellent." error:nil];
[self validateSearchResultForInferenceWithSearchContent:searchResult];
}

@end

NS_ASSUME_NONNULL_END
Original file line number Diff line number Diff line change
Expand Up @@ -294,4 +294,4 @@ - (void)testImageSearcherWithEmbedderModelAndInvalidIndexFileFails {

@end

NS_ASSUME_NONNULL_END
NS_ASSUME_NONNULL_END