irls.m
function [] = irls()
%% Iterative Reweighted Least Squares (IRLS) for Logistic Regression
% MAP estimation of the hyperplane w with an L2 (Gaussian prior) penalty
%
% min NLL(w) = min -sum_{i=1}^{n} [y_i log mu_i + (1-y_i) log(1-mu_i)]
% where mu_i = sigm(w'x_i)
%
% IRLS normal equations: X'S_kX w = X'S_k z_k, i.e. w = (X'S_kX)^{-1}X'S_k z_k,
% where z_k = Xw_k + S_k^{-1}(y-mu_k), X is n x d and S_k = diag(mu_ik(1-mu_ik))
%
% Equivalent Newton updates: w_{k+1} = w_k - H_k^{-1} g_k
% with g_k = X'(mu_k-y) + 2*lambda*w_k and H_k = X'S_kX + 2*lambda*I
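%
% Substituting z_k into the normal equations (regularizer aside) recovers the
% Newton step:
%   (X'S_kX)^{-1}X'S_k z_k = (X'S_kX)^{-1}(X'S_kX w_k + X'(y-mu_k))
%                          = w_k + (X'S_kX)^{-1}X'(y-mu_k)
% so each IRLS iteration is one Newton update on the penalized NLL.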
close all; clear all;
%% synthetic dataset
n = 1000; d = 4; mu0 = zeros(d,1); Sigma0=diag(ones(d,1)); X = randn(n,d);
w0 = mvnrnd(mu0, Sigma0)'; %ground truth hyperplane
z0 = randn(n,1) + X*w0; %ground truth z_i
y = sign(z0); %ground truth labels \in {+1,-1}
%% IRLS
max_num_iter=1e2; w=zeros(d,max_num_iter+1); %preallocate; column k+1 is written at iteration k
%sigmoid function
sigm = @(X,y,w) 1./(1+exp(-y.*(X*w)));
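% with y_i in {+1,-1}, sigm(X,y,w) returns p(y_i|x_i,w) = sigm(y_i*w'*x_i), the
% probability of the observed label; the NLL gradient is then -X'*((1-mu).*y),
% which the working response z_k in the loop below encodes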
%init w
w(:,1)=zeros(d,1);
%w(:,1)=randn(d,1); %may oscillate
%w(:,1)=mvnrnd(w0,0.1*Sigma0)'; %ground truth + noise
%Gaussian prior / L2 regularization (also stabilizes the solve)
lambda=1e-4; vInv = 2*lambda*eye(d); %prior precision; penalty lambda*||w||^2
for k=1:max_num_iter
    mu_k = sigm(X,y,w(:,k)); %probability of the observed label (bernoulli parameter)
    Sk = mu_k.*(1-mu_k) + eps; %IRLS weights (diagonal of S_k); eps avoids division by zero
    z_k = X*w(:,k)+(1-mu_k).*y./Sk; %working response, y \in {+1,-1}
    %w(:,k+1)=(X'*diag(Sk)*X + vInv)\(X'*diag(Sk)*z_k); %equivalent direct w update
    Xd=X'*sparse(diag(Sk)); R=chol(Xd*X+vInv); %X'*S_k and Cholesky of the regularized Hessian
    w(:,k+1)=R\(R'\(Xd*z_k)); %two triangular solves
    %check convergence
    fprintf('iter: %d, ||w_{k+1}-w_{k}||= %.6f\n', k, norm(w(:,k+1)-w(:,k),2));
    if (norm(w(:,k+1)-w(:,k),2) < 1e-6), break; end
end
w=w(:,1:k+1); %keep all iterates, including the final update
%% compute MSE: squared error of the running mean of the iterates vs the ground truth w0
MSE = (cummean(w,2)-repmat(w0,1,size(w,2))).^2;
%% plot MSE
figure; legendInfo={};
for dim=1:d
    legendInfo{dim} = ['dim = ', num2str(dim)];
    plot(1:size(w,2),MSE(dim,:),'color',rand(1,3),'linewidth',2.0); hold on; grid on;
end
xlabel('iterations'); ylabel('MSE'); legend(legendInfo);
title('MSE vs iterations for IRLS Logistic Regression');
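%% (sketch) quick sanity check: training accuracy of the final iterate,
% assuming the last column of w is the converged estimate
y_hat = sign(X*w(:,end)); %predicted labels in {+1,-1}
fprintf('training accuracy: %.3f\n', mean(y_hat == y));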
end
function y = cummean(x,dim)
if nargin==1,
    % Default to the first non-singleton dimension (as CUMSUM does)
    dim = min(find(size(x) ~= 1));
    if isempty(dim), dim = 1; end
end
siz = [size(x) ones(1, dim-ndims(x))];
n = size(x, dim);
% Permute and reshape so that DIM becomes the row dimension of a 2-D array
perm = [dim:max(length(size(x)), dim) 1:dim-1];
x = reshape(permute(x, perm), n, prod(siz)/n);
% Calculate cumulative mean
y = cumsum(x, 1)./repmat([1:n]', 1, prod(siz)/n);
% Permute and reshape back
y = ipermute(reshape(y, siz(perm)), perm);
end