Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-16319 Implement KNN API #16474

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions h2o-algos/src/main/java/hex/api/RegisterAlgos.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ public void registerEndPoints(RestApiContext context) {
new hex.isotonic .IsotonicRegression(true),
new hex.tree.dt .DT (true),
new hex.hglm .HGLM (true),
new hex.adaboost. AdaBoost (true)
//new hex.knn .KNN (true) will be implement in different PR
new hex.adaboost. AdaBoost (true),
new hex.knn .KNN (true)
};

// "Word2Vec", "Example", "Grep"
Expand Down
15 changes: 7 additions & 8 deletions h2o-algos/src/main/java/hex/knn/KNN.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ class KNNDriver extends Driver {
public void computeImpl() {
KNNModel model = null;
Frame result = new Frame(Key.make("KNN_distances"));
Frame tmpResult = null;
try {
init(true); // Initialize parameters
if (error_count() > 0) {
Expand All @@ -71,13 +70,16 @@ public void computeImpl() {
for (int j = 0; j < nCols; j++) {
query[j] = train.vec(j).chunkForChunkIdx(i).deepCopy();
}
KNNDistanceTask task = new KNNDistanceTask(_parms._k, query, _parms._distance, idColumnIndex, idColumn, idType, responseColumnIndex, responseColumn);
tmpResult = task.doAll(train).outputFrame();
KNNDistanceTask task = new KNNDistanceTask(_parms._k, query, KNNDistanceFactory.createDistance(_parms._distance), idColumnIndex, idColumn, idType, responseColumnIndex, responseColumn);
Frame tmpResult = task.doAll(train).outputFrame();
Scope.untrack(tmpResult);

// merge result from a chunk
result = result.add(tmpResult);
}
DKV.put(result._key, result);
model._output.setDistancesKey(result._key);
Key<Frame> key = result._key;
DKV.put(key, result);
model._output.setDistancesKey(key);
Scope.untrack(result);

model.update(_job);
Expand All @@ -90,9 +92,6 @@ public void computeImpl() {
if (model != null) {
model.unlock(_job);
}
if (tmpResult != null) {
tmpResult.remove();
}
}
}
}
Expand Down
21 changes: 21 additions & 0 deletions h2o-algos/src/main/java/hex/knn/KNNDistanceFactory.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package hex.knn;

import hex.DistanceType;
import water.H2O;

public class KNNDistanceFactory {

public static KNNDistance createDistance(DistanceType type) {
switch (type) {
case EUCLIDEAN:
return new EuclideanDistance();
case MANHATTAN:
return new ManhattanDistance();
case COSINE:
return new CosineDistance();
default:
throw H2O.unimpl("Try to get "+type+" which is not supported.");
}
}

}
8 changes: 6 additions & 2 deletions h2o-algos/src/main/java/hex/knn/KNNModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public String javaName() {
}

public int _k = 3;
public KNNDistance _distance;
public DistanceType _distance;
public boolean _compute_metrics;

@Override
Expand Down Expand Up @@ -53,6 +53,10 @@ public void setDistancesKey(Key<Frame> _distances_key) {
public Frame getDistances(){
return DKV.get(_distances_key).get();
}

public Key<Frame> getDistancesKey() {
return _distances_key;
}
}

public KNNModel(Key<KNNModel> selfKey, KNNModel.KNNParameters parms, KNNModel.KNNOutput output) {
Expand All @@ -76,7 +80,7 @@ protected double[] score0(double[] data, double[] preds) {
int idIndex = train.find(_parms._id_column);
int responseIndex = train.find(_parms._response_column);
byte idType = train.types()[idIndex];
preds = new KNNScoringTask(data, _parms._k, _output.nclasses(), _parms._distance, idIndex, idType,
preds = new KNNScoringTask(data, _parms._k, _output.nclasses(), KNNDistanceFactory.createDistance(_parms._distance), idIndex, idType,
responseIndex).doAll(train).score();
Scope.untrack(train);
return preds;
Expand Down
28 changes: 28 additions & 0 deletions h2o-algos/src/main/java/hex/schemas/KNNModelV3.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package hex.schemas;

import hex.knn.KNNModel;
import water.api.API;
import water.api.schemas3.ModelOutputSchemaV3;
import water.api.schemas3.ModelSchemaV3;

public class KNNModelV3 extends ModelSchemaV3<KNNModel, KNNModelV3, KNNModel.KNNParameters, KNNV3.KNNParametersV3, KNNModel.KNNOutput, KNNModelV3.KNNModelOutputV3> {

public static final class KNNModelOutputV3 extends ModelOutputSchemaV3<KNNModel.KNNOutput, KNNModelOutputV3> {
@API(help="Key of frame with calculated distances.")
public String distances;

@Override public KNNModelOutputV3 fillFromImpl(KNNModel.KNNOutput impl) {
KNNModelOutputV3 knnv3 = super.fillFromImpl(impl);
knnv3.distances = impl.getDistancesKey().toString();
return knnv3;
}
}

public KNNV3.KNNParametersV3 createParametersSchema() { return new KNNV3.KNNParametersV3(); }
public KNNModelOutputV3 createOutputSchema() { return new KNNModelOutputV3(); }

@Override public KNNModel createImpl() {
KNNModel.KNNParameters parms = parameters.createImpl();
return new KNNModel( model_id.key(), parms, null );
}
}
46 changes: 46 additions & 0 deletions h2o-algos/src/main/java/hex/schemas/KNNV3.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package hex.schemas;

import hex.DistanceType;
import hex.knn.KNN;
import hex.knn.KNNModel;
import water.api.API;
import water.api.schemas3.FrameV3;
import water.api.schemas3.ModelParametersSchemaV3;


public class KNNV3 extends ModelBuilderSchema<KNN, KNNV3, KNNV3.KNNParametersV3> {
public static final class KNNParametersV3 extends ModelParametersSchemaV3<KNNModel.KNNParameters, KNNParametersV3> {
static public String[] fields = new String[]{
"model_id",
"training_frame",
"response_column",
"id_column",
"ignored_columns",
"ignore_const_cols",
"seed",
"max_runtime_secs",
"categorical_encoding",
"distribution",
"custom_metric_func",
"gainslift_bins",
"auc_type",
"k",
"distance"
};

@API(level = API.Level.critical, direction = API.Direction.INOUT, gridable = true,
is_member_of_frames = {"training_frame"},
is_mutually_exclusive_with = {"ignored_columns"},
help = "Identify each record column.")
public FrameV3.ColSpecifierV3 id_column;

@API(help = "RNG Seed", level = API.Level.secondary, gridable = true)
public long seed;

@API(help = "Number of nearest neighbours", level = API.Level.secondary, gridable = true)
public int k;

@API(help = "Distance type", level = API.Level.secondary, gridable = true, values = { "AUTO", "euclidean", "manhattan", "cosine"})
public DistanceType distance;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,7 @@ hex.schemas.AdaBoostModelV3
hex.schemas.AdaBoostModelV3$AdaBoostModelOutputV3
hex.schemas.AdaBoostV3
hex.schemas.AdaBoostV3$AdaBoostParametersV3
hex.schemas.KNNModelV3
hex.schemas.KNNModelV3$KNNModelOutputV3
hex.schemas.KNNV3
hex.schemas.KNNV3$KNNParametersV3
14 changes: 6 additions & 8 deletions h2o-algos/src/test/java/hex/knn/KNNTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public void testIris() {
KNNModel.KNNParameters parms = new KNNModel.KNNParameters();
parms._train = fr._key;
parms._k = 3;
parms._distance = new EuclideanDistance();
parms._distance = DistanceType.EUCLIDEAN;
parms._response_column = response;
parms._id_column = idColumn;
parms._auc_type = MultinomialAucType.MACRO_OVR;
Expand All @@ -55,9 +55,7 @@ public void testIris() {
ModelMetricsMultinomial mm1 = (ModelMetricsMultinomial) knn._output._training_metrics;
Assert.assertEquals(mm.auc(), mm1.auc(), 0);

// test after KNN API will be ready
//knn.testJavaScoring(fr, preds, 0);

knn.testJavaScoring(fr, preds, 0);
} finally {
if (knn != null){
knn.delete();
Expand Down Expand Up @@ -90,7 +88,7 @@ public void testSimpleFrameEuclidean() {
KNNModel.KNNParameters parms = new KNNModel.KNNParameters();
parms._train = fr._key;
parms._k = 2;
parms._distance = new EuclideanDistance();
parms._distance = DistanceType.EUCLIDEAN;
parms._response_column = response;
parms._id_column = idColumn;
parms._auc_type = MultinomialAucType.MACRO_OVR;
Expand Down Expand Up @@ -165,7 +163,7 @@ public void testSimpleFrameManhattan() {
KNNModel.KNNParameters parms = new KNNModel.KNNParameters();
parms._train = fr._key;
parms._k = 2;
parms._distance = new ManhattanDistance();
parms._distance = DistanceType.MANHATTAN;
parms._response_column = response;
parms._id_column = idColumn;
parms._auc_type = MultinomialAucType.MACRO_OVR;
Expand Down Expand Up @@ -240,7 +238,7 @@ public void testSimpleFrameCosine() {
KNNModel.KNNParameters parms = new KNNModel.KNNParameters();
parms._train = fr._key;
parms._k = 2;
parms._distance = new CosineDistance();
parms._distance = DistanceType.COSINE;
parms._response_column = response;
parms._id_column = idColumn;
parms._auc_type = MultinomialAucType.MACRO_OVR;
Expand Down Expand Up @@ -332,7 +330,7 @@ public void testIdColumnIsNotDefined() {
KNNModel.KNNParameters parms = new KNNModel.KNNParameters();
parms._train = fr._key;
parms._k = 2;
parms._distance = new EuclideanDistance();
parms._distance = DistanceType.EUCLIDEAN;
parms._response_column = "class";
parms._id_column = null;

Expand Down
37 changes: 37 additions & 0 deletions h2o-bindings/bin/custom/R/gen_knn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
extensions = dict(
extra_params=[('verbose', 'FALSE')],
required_params=['x', 'y', 'training_frame', 'id_column', 'response_column'],
skip_default_set_params_for=['training_frame', 'ignored_columns', 'response_column', 'offset_column'],
set_required_params="""
parms$training_frame <- training_frame
args <- .verify_dataxy(training_frame, x, y)
if (!missing(id_column)) {
parms$id_column <- id_column
} else {
stop("ID column is required.")
}
parms$ignored_columns <- args$x_ignore
parms$response_column <- args$y
"""
)


doc = dict(
preamble="""
Build a KNN model

Builds a K-nearest neighbour model on an H2OFrame.
""",
params=dict(
verbose="""
\code{Logical}. Print scoring history to the console. Defaults to FALSE.
"""
),
returns="""
Creates a \linkS4class{H2OModel} object of the right type.
""",
seealso="""
\code{\link{predict.H2OModel}} for prediction
""",
examples=""""""
)
7 changes: 7 additions & 0 deletions h2o-core/src/main/java/hex/DistanceType.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package hex;

public enum DistanceType {
EUCLIDEAN,
MANHATTAN,
COSINE
}
6 changes: 6 additions & 0 deletions h2o-py/docs/modeling.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ Supervised
:show-inheritance:
:members:

:mod:`H2OKnnEstimator`
----------------------
.. autoclass:: h2o.estimators.knn.H2OKnnEstimator
:show-inheritance:
:members:

:mod:`H2OModelSelectionEstimator`
---------------------------------
.. autoclass:: h2o.estimators.model_selection.H2OModelSelectionEstimator
Expand Down
3 changes: 2 additions & 1 deletion h2o-py/h2o/estimators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from .isolation_forest import H2OIsolationForestEstimator
from .isotonicregression import H2OIsotonicRegressionEstimator
from .kmeans import H2OKMeansEstimator
from .knn import H2OKnnEstimator
from .model_selection import H2OModelSelectionEstimator
from .naive_bayes import H2ONaiveBayesEstimator
from .pca import H2OPrincipalComponentAnalysisEstimator
Expand Down Expand Up @@ -67,7 +68,7 @@ def create_estimator(algo, **params):
"H2OExtendedIsolationForestEstimator", "H2OGeneralizedAdditiveEstimator", "H2OGradientBoostingEstimator",
"H2OGenericEstimator", "H2OGeneralizedLinearEstimator", "H2OGeneralizedLowRankEstimator", "H2OHGLMEstimator",
"H2OInfogram", "H2OIsolationForestEstimator", "H2OIsotonicRegressionEstimator", "H2OKMeansEstimator",
"H2OModelSelectionEstimator", "H2ONaiveBayesEstimator", "H2OPrincipalComponentAnalysisEstimator",
"H2OKnnEstimator", "H2OModelSelectionEstimator", "H2ONaiveBayesEstimator", "H2OPrincipalComponentAnalysisEstimator",
"H2OSupportVectorMachineEstimator", "H2ORandomForestEstimator", "H2ORuleFitEstimator",
"H2OStackedEnsembleEstimator", "H2OSingularValueDecompositionEstimator", "H2OTargetEncoderEstimator",
"H2OUpliftRandomForestEstimator", "H2OWord2vecEstimator", "H2OXGBoostEstimator"
Expand Down
Loading
Loading