Commit 647c8c4
Merge pull request #44 from BrikerMan/develop
release v0.2.0
BrikerMan authored Mar 5, 2019
2 parents 24cb4ac + c219639 commit 647c8c4
Showing 9 changed files with 158 additions and 63 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
@@ -21,4 +21,4 @@ install:
script:
- nosetests --with-coverage --cover-html --cover-html-dir=htmlcov --cover-package="kashgari" $TEST_FILE
after_success:
- coveralls
- coveralls
3 changes: 2 additions & 1 deletion kashgari/embeddings/embeddings.py
@@ -25,6 +25,7 @@
import kashgari.macros as k
from kashgari.type_hints import *
from kashgari.utils import helper
from kashgari.layers import NonMaskingLayer

EMBEDDINGS_PATH = os.path.join(k.DATA_PATH, 'embedding')

@@ -298,7 +299,7 @@ def build(self):
model = keras_bert.load_trained_model_from_checkpoint(config_path,
check_point_path,
seq_len=self.sequence_length)
output_layer = helper.NonMaskingLayer()(model.output)
output_layer = NonMaskingLayer()(model.output)
self._model = Model(model.inputs, output_layer)

self.embedding_size = self.model.output_shape[-1]
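For context, a minimal sketch (not part of this commit) of driving the BERT embedding model built above; the constructor arguments, variable names, and the sequence length of 100 are illustrative:

from kashgari.embeddings import BERTEmbedding
import numpy as np

# illustrative constructor call; internally this runs the
# load_trained_model_from_checkpoint call shown in the diff above
embedding = BERTEmbedding('bert-base-chinese', sequence_length=100)

# zeros act as [PAD] token ids; segment ids are all zeros for
# single-segment input
token_ids = np.zeros(shape=(1, 100))
seg_ids = np.zeros(shape=(1, 100))

# embedding.model is the Model(model.inputs, output_layer) built above
vectors = embedding.model.predict([token_ids, seg_ids])
print(vectors.shape)  # (1, 100, embedding_size)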
24 changes: 24 additions & 0 deletions kashgari/layers.py
@@ -156,5 +156,29 @@ def get_config(self):
return dict(list(base_config.items()) + list(config.items()))


class NonMaskingLayer(Layer):
"""
Fix for Conv1D layers, which can't receive masked input; details: https://github.com/keras-team/keras/issues/4978
Thanks to https://github.com/jacoxu
"""

def __init__(self, **kwargs):
self.supports_masking = True
super(NonMaskingLayer, self).__init__(**kwargs)

def build(self, input_shape):
pass

def compute_mask(self, input, input_mask=None):
# do not pass the mask to the next layers
return None

def call(self, x, mask=None):
return x

def get_output_shape_for(self, input_shape):
return input_shape


if __name__ == '__main__':
print("hello, world")
10 changes: 5 additions & 5 deletions kashgari/tasks/base/base_model.py
@@ -28,7 +28,7 @@
from kashgari.embeddings import CustomEmbedding, BaseEmbedding
from kashgari.utils.crf import CRF, crf_loss, crf_accuracy
from keras_bert.bert import get_custom_objects as get_bert_custom_objects
from kashgari.layers import AttentionWeightedAverage, KMaxPooling
from kashgari.layers import AttentionWeightedAverage, KMaxPooling, NonMaskingLayer


class BaseModel(object):
@@ -114,7 +114,7 @@ def create_custom_objects(model_info):
embedding = model_info.get('embedding')

if embedding and embedding['embedding_type'] == 'bert':
custom_objects['NonMaskingLayer'] = helper.NonMaskingLayer
custom_objects['NonMaskingLayer'] = NonMaskingLayer
custom_objects.update(get_bert_custom_objects())
custom_objects['AttentionWeightedAverage'] = AttentionWeightedAverage
custom_objects['KMaxPooling'] = KMaxPooling
@@ -132,15 +132,15 @@ def load_model(cls, model_path: str):
model_info = json.load(f)
agent = cls()
custom_objects = cls.create_custom_objects(model_info)

agent.model_info = model_info['model_info']
if custom_objects:
logger.debug('prepared custom objects: {}'.format(custom_objects))

try:
agent.model = keras.models.load_model(os.path.join(model_path, 'model.model'),
custom_objects=custom_objects)
except Exception as e:
logger.warn('Error `{}` occurred when loading model directly. Trying to rebuild.'.format(e))
logger.warning('Error `{}` occurred when loading model directly. Trying to rebuild.'.format(e))
logger.debug('Load model structure from json.')
with open(os.path.join(model_path, 'struct.json'), 'r', encoding='utf-8') as f:
model_struct = f.read()
@@ -198,4 +198,4 @@ def load_model(cls, model_path: str):


if __name__ == "__main__":
print("Hello world")
print("Hello world")
111 changes: 91 additions & 20 deletions kashgari/tasks/classification/base_model.py
@@ -19,6 +19,7 @@
from keras.utils import to_categorical
from sklearn import metrics
from sklearn.utils import class_weight as class_weight_calculte
from sklearn.preprocessing import MultiLabelBinarizer

from kashgari import macros as k
from kashgari.tasks.base import BaseModel
@@ -28,8 +29,45 @@

class ClassificationModel(BaseModel):

def __init__(self, embedding: BaseEmbedding = None, hyper_parameters: Dict = None, **kwargs):
def __init__(self,
embedding: BaseEmbedding = None,
hyper_parameters: Dict = None,
multi_label: bool = False,
**kwargs):
"""
:param embedding: embedding instance (a BaseEmbedding subclass)
:param hyper_parameters: dict overriding the model's default hyper-parameters
:param multi_label: set True for multi-label classification, defaults to False
:param kwargs: passed through to BaseModel
"""
super(ClassificationModel, self).__init__(embedding, hyper_parameters, **kwargs)
self.multi_label = multi_label
self.multi_label_binarizer: MultiLabelBinarizer = None

if self.multi_label:
if not hyper_parameters or \
hyper_parameters.get('compile_params', {}).get('loss') is None:
self.hyper_parameters['compile_params']['loss'] = 'binary_crossentropy'
else:
logging.warning('binary_crossentropy loss is recommended for multi-label tasks')

if not hyper_parameters or \
hyper_parameters.get('compile_params', {}).get('metrics') is None:
self.hyper_parameters['compile_params']['metrics'] = ['categorical_accuracy']
else:
logging.warning('categorical_accuracy metric is recommended for multi-label tasks')

if not hyper_parameters or \
hyper_parameters.get('activation_layer', {}).get('activation') is None:
self.hyper_parameters['activation_layer']['activation'] = 'sigmoid'
else:
logging.warning('sigmoid activation is recommended for multi-label tasks')

def info(self):
info = super(ClassificationModel, self).info()
info['model_info']['multi_label'] = self.multi_label
return info

@property
def label2idx(self) -> Dict[str, int]:
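A short sketch of how the new flag is meant to be used; CNNModel is one of the concrete subclasses shipped with kashgari, and the example is illustrative rather than part of this commit:

from kashgari.tasks.classification import CNNModel

# multi_label=True installs the defaults set above: sigmoid activation,
# binary_crossentropy loss and the categorical_accuracy metric
model = CNNModel(multi_label=True)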
@@ -51,26 +89,42 @@ def build_model(self):
"""
raise NotImplementedError()

@classmethod
def load_model(cls, model_path: str):
agent: ClassificationModel = super(ClassificationModel, cls).load_model(model_path)
agent.multi_label = agent.model_info.get('multi_label', False)
if agent.multi_label:
keys = list(agent.label2idx.keys())
agent.multi_label_binarizer = MultiLabelBinarizer(classes=keys)
agent.multi_label_binarizer.fit(keys[0])
return agent

def build_token2id_label2id_dict(self,
x_train: List[List[str]],
y_train: List[str],
x_validate: List[List[str]] = None,
y_validate: List[str] = None):
x_data = x_train
y_data = y_train
if x_validate:
x_data += x_validate
y_data += y_validate
x_data = x_train + x_validate
y_data = y_train + y_validate
else:
x_data = x_train
y_data = y_train
self.embedding.build_token2idx_dict(x_data, 3)

label_set = set(y_data)
label2idx = {
k.PAD: 0,
}
for label in label_set:
label2idx[label] = len(label2idx)
if self.multi_label:
label_set = set()
for i in y_data:
label_set = label_set.union(list(i))
else:
label_set = set(y_data)

label2idx = {}
for idx, label in enumerate(label_set):
label2idx[label] = idx
self._label2idx = label2idx
self._idx2label = dict([(val, key) for (key, val) in label2idx.items()])
self.multi_label_binarizer = MultiLabelBinarizer(classes=list(self.label2idx.keys()))

def convert_label_to_idx(self, label: Union[List[str], str]) -> Union[List[int], int]:
if isinstance(label, str):
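To make the multi-label branch concrete, a small standalone sketch of what MultiLabelBinarizer does with the label set collected above (toy labels):

from sklearn.preprocessing import MultiLabelBinarizer

binarizer = MultiLabelBinarizer(classes=['a', 'b', 'c'])
y = binarizer.fit_transform([('a',), ('a', 'c'), ('b', 'c')])
# one indicator column per class, in the order given by `classes`:
# [[1 0 0]
#  [1 0 1]
#  [0 1 1]]
print(y)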
@@ -102,14 +156,18 @@ def get_data_generator(self,
target_y = y_data[0: batch_size]

tokenized_x = self.embedding.tokenize(target_x)
tokenized_y = self.convert_label_to_idx(target_y)

padded_x = sequence.pad_sequences(tokenized_x,
maxlen=self.embedding.sequence_length,
padding='post')
padded_y = to_categorical(tokenized_y,
num_classes=len(self.label2idx),
dtype=np.int)

if self.multi_label:
padded_y = self.multi_label_binarizer.fit_transform(target_y)
else:
tokenized_y = self.convert_label_to_idx(target_y)
padded_y = to_categorical(tokenized_y,
num_classes=len(self.label2idx),
dtype=np.int)
if is_bert:
padded_x_seg = np.zeros(shape=(len(padded_x), self.embedding.sequence_length))
x_input_data = [padded_x, padded_x_seg]
@@ -119,9 +177,9 @@

def fit(self,
x_train: List[List[str]],
y_train: List[str],
y_train: Union[List[str], List[List[str]], List[Tuple[str]]],
x_validate: List[List[str]] = None,
y_validate: List[str] = None,
y_validate: Union[List[str], List[List[str]], List[Tuple[str]]] = None,
batch_size: int = 64,
epochs: int = 5,
class_weight: bool = False,
@@ -203,12 +261,14 @@ def predict(self,
sentence: Union[List[str], List[List[str]]],
batch_size=None,
output_dict=False,
multi_label_threshold=0.6,
debug_info=False) -> Union[List[str], str, List[Dict], Dict]:
"""
predict with model
:param sentence: single sentence as List[str] or list of sentence as List[List[str]]
:param batch_size: predict batch_size
:param output_dict: return dict with result with confidence
:param multi_label_threshold: sigmoid score above which a label is predicted; only used when multi_label is True
:param debug_info: print debug info using logging.debug when True
:return:
"""
@@ -227,7 +287,15 @@
else:
x = padded_tokens
res = self.model.predict(x, batch_size=batch_size)
predict_result = res.argmax(-1)

if self.multi_label:
if debug_info:
logging.info('raw output: {}'.format(res))
res[res >= multi_label_threshold] = 1
res[res < multi_label_threshold] = 0
predict_result = res
else:
predict_result = res.argmax(-1)

if debug_info:
logging.info('input: {}'.format(x))
@@ -247,7 +315,10 @@
else:
return results[0]
else:
results = self.convert_idx_to_label(predict_result)
if self.multi_label:
results = self.multi_label_binarizer.inverse_transform(predict_result)
else:
results = self.convert_idx_to_label(predict_result)
if is_list:
return results
else:
@@ -263,4 +334,4 @@ def evaluate(self, x_data, y_data, batch_size=None, digits=4, debug_info=False)
logging.debug('x : {}'.format(x_data[index]))
logging.debug('y : {}'.format(y_data[index]))
logging.debug('y_pred : {}'.format(y_pred[index]))
return report
return report
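Putting the pieces of this diff together, an assumed end-to-end multi-label flow; the toy data, the CNNModel class, and the threshold value are illustrative:

from kashgari.tasks.classification import CNNModel

x = [['this', 'movie', 'is', 'fun'],
     ['boring', 'and', 'far', 'too', 'long'],
     ['fun', 'but', 'too', 'long']]
y = [['positive'], ['negative', 'length'], ['positive', 'length']]

# samples may carry several labels at once, hence multi_label=True
model = CNNModel(multi_label=True)
model.fit(x, y, batch_size=2, epochs=2)

# labels whose sigmoid score reaches multi_label_threshold are returned
print(model.predict([['fun', 'movie']], multi_label_threshold=0.6))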
11 changes: 3 additions & 8 deletions kashgari/tasks/classification/models.py
@@ -12,20 +12,15 @@
"""
from __future__ import absolute_import, division

import logging

import keras
#from keras import optimizers

from keras.models import Model
from keras.layers import Dense, Lambda, Flatten, Reshape
from keras.layers import Bidirectional, Conv1D
from keras.layers import Dense, Lambda, Flatten
from keras.layers import Dropout, SpatialDropout1D
from keras.layers import GlobalAveragePooling1D, GlobalMaxPooling1D, MaxPooling1D
from keras.layers import Bidirectional, Conv1D
from keras.layers import concatenate
from keras.models import Model

from kashgari.layers import AttentionWeightedAverage, KMaxPooling, LSTMLayer, GRULayer

from kashgari.tasks.classification.base_model import ClassificationModel


24 changes: 0 additions & 24 deletions kashgari/utils/helper.py
@@ -59,30 +59,6 @@ def unison_shuffled_copies(a, b):
return list(a), list(b)


class NonMaskingLayer(Layer):
"""
Fix for Conv1D layers, which can't receive masked input; details: https://github.com/keras-team/keras/issues/4978
Thanks to https://github.com/jacoxu
"""

def __init__(self, **kwargs):
self.supports_masking = True
super(NonMaskingLayer, self).__init__(**kwargs)

def build(self, input_shape):
input_shape = input_shape

def compute_mask(self, input, input_mask=None):
# do not pass the mask to the next layers
return None

def call(self, x, mask=None):
return x

def get_output_shape_for(self, input_shape):
return input_shape


def weighted_categorical_crossentropy(weights):
"""
A weighted version of keras.objectives.categorical_crossentropy
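The body of weighted_categorical_crossentropy is cut off by the collapsed diff; for context, a common implementation of the idea its docstring describes looks roughly like this (a sketch, not necessarily the exact code in helper.py):

import numpy as np
import keras.backend as K

def weighted_categorical_crossentropy(weights):
    # weights: one scalar per class, e.g. np.array([0.5, 2.0, 10.0])
    weights = K.variable(weights)

    def loss(y_true, y_pred):
        # normalize predictions so each sample's class probabilities sum to 1
        y_pred /= K.sum(y_pred, axis=-1, keepdims=True)
        # clip to avoid log(0)
        y_pred = K.clip(y_pred, K.epsilon(), 1 - K.epsilon())
        # per-class weighted cross entropy, summed over classes
        return -K.sum(y_true * K.log(y_pred) * weights, axis=-1)

    return loss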