diff --git a/textaugment/mixup.py b/textaugment/mixup.py index ad9ccb8..cc21b37 100644 --- a/textaugment/mixup.py +++ b/textaugment/mixup.py @@ -76,7 +76,7 @@ def mixup_data(self, x, y=None, alpha=0.2): return np.concatenate(output_x, axis=0) mixed_y = (y.T * lam_vector).T + (y[index].T * (1.0 - lam_vector)).T output_y.append(mixed_y) - return np.concatenate(output_x, axis=0), np.concatenate(output_y, axis=0) + return np.concatenate(output_x, axis=0), np.concatenate(output_y, axis=0) def flow(self, data, labels=None, batch_size=32, shuffle=True, runs=1): """This function implements the batch iterator and specifically calls mixup