Skip to content

Commit

Permalink
ite working
Browse files Browse the repository at this point in the history
  • Loading branch information
dainis-boumber committed Sep 16, 2018
1 parent d657bd9 commit 5d73a0b
Show file tree
Hide file tree
Showing 59 changed files with 85 additions and 1 deletion.
86 changes: 85 additions & 1 deletion complexity.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,15 @@

# Data pre-processing and import
# from modules import mnist
from modules import mnist
from modules import mnist
from numpy.random import rand, multivariate_normal
from numpy import arange, zeros, ones
from scipy import dot
import matplotlib.pyplot as plt

from ite.cost.x_factory import co_factory
from ite.cost.x_analytical_values import analytical_value_d_mmd
from ite.cost.x_kernel import Kernel
####################################################

'''
Expand Down Expand Up @@ -171,6 +178,83 @@ def bsda_active(datasets=[], baseline_clf=SVC(), N=100):
print(baseline_clf.predict(BSDA_X_Test))
print("Classification accuracy: ", round(baseline_clf.score(BSDA_X_Test, BSDA_y_Test), 3) * 100)

def MMD():
# !/usr/bin/env python3

""" Demo for maximum mean discrepancy (MMD) estimators.
Analytical vs estimated value is illustrated for normal random variables.
"""



def main():
# parameters:
dim = 1 # dimension of the distribution
num_of_samples_v = arange(100, 3 * 1000 + 1, 100)
cost_name = 'BDMMD_UStat' # dim >= 1
# cost_name = 'BDMMD_VStat' # dim >= 1
# cost_name = 'BDMMD_UStat_IChol' # dim >= 1
# cost_name = 'BDMMD_VStat_IChol' # dim >= 1

# initialization:
distr = 'normal' # fixed
num_of_samples_max = num_of_samples_v[-1]
length = len(num_of_samples_v)
d_hat_v = zeros(length) # vector of estimated divergence values

# RBF kernel (sigma = std / bandwith parameter):
kernel = Kernel({'name': 'RBF', 'sigma': 1})
# polynomial kernel (quadratic / cubic; c = offset parameter = 1):
# kernel = Kernel({'name': 'polynomial', 'exponent': 2, 'c': 1})
# kernel = Kernel({'name': 'polynomial', 'exponent': 3, 'c': 1})

co = co_factory(cost_name, mult=True, kernel=kernel) # cost object

# distr, dim -> samples (y1<<y2), distribution parameters (par1,par2),
# analytical value (d):
if distr == 'normal':
# mean (m1,m2):
m1 = rand(dim)
m2 = rand(dim)

# (random) linear transformation applied to the data (l1,l2) ->
# covariance matrix (c1,c2):
l2 = rand(dim, dim)
l1 = rand(dim, dim)
c1 = dot(l1, l1.T)
c2 = dot(l2, l2.T)

# generate samples (y1~N(m1,c1), y2~N(m2,c2)):
y1 = multivariate_normal(m1, c1, num_of_samples_max)
y2 = multivariate_normal(m2, c2, num_of_samples_max)

par1 = {"mean": m1, "cov": c1}
par2 = {"mean": m2, "cov": c2}

else:
raise Exception('Distribution=?')

d = analytical_value_d_mmd(distr, distr, kernel, par1, par2)

# estimation:
for (tk, num_of_samples) in enumerate(num_of_samples_v):
d_hat_v[tk] = co.estimation(y1[0:num_of_samples],
y2[0:num_of_samples]) # broadcast
print("tk={0}/{1}".format(tk + 1, length))

# plot:
plt.plot(num_of_samples_v, d_hat_v, num_of_samples_v, ones(length) * d)
plt.xlabel('Number of samples')
plt.ylabel('MMD')
plt.legend(('estimation', 'analytical value'), loc='best')
plt.title("Estimator: " + cost_name)
plt.show()

if __name__ == "__main__":
main()


def main():
#baseline_clfs = [SVC(), GaussianNB(), DecisionTreeClassifier(), MLPClassifier(hidden_layer_sizes=(10,10,10,10,10,10), solver='lbfgs', alpha=2, random_state=1, activation='relu')]
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 comments on commit 5d73a0b

Please sign in to comment.