-
Notifications
You must be signed in to change notification settings - Fork 32
/
Copy path2_1_KLDCriterion.lua
34 lines (23 loc) · 916 Bytes
/
2_1_KLDCriterion.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
-- Register nn.KLDCriterion as a torch class derived from nn.Criterion.
-- `parent` is kept so the constructor below can call the base-class __init.
local KLDCriterion, parent = torch.class('nn.KLDCriterion', 'nn.Criterion')
--- Construct the criterion.
-- @param gradCoefficient optional multiplier applied to both gradients
--   produced by updateGradInput; falls back to 1 when nil or false.
function KLDCriterion:__init(gradCoefficient)
    parent.__init(self)
    self.gradCoefficient = gradCoefficient or 1
end
--- Compute the KL-divergence term (Appendix B of the VAE paper).
-- Evaluates -0.5 * sum(1 + log_var - mean^2 - exp(log_var)), where log_var
-- plays the role of log(sigma^2) in the paper's notation.
-- The result is not scaled by self.gradCoefficient.
-- @param mean tensor of means (mu)
-- @param log_var tensor of log-variances (log(sigma^2))
-- @return the scalar loss, also stored in self.output
function KLDCriterion:updateOutput(mean, log_var)
    local mu_squared = torch.pow(mean, 2)

    -- Accumulate the per-element terms on a copy so log_var stays untouched:
    -- -exp(log_var) - mean^2 + 1 + log_var
    local terms = log_var:clone():exp():mul(-1)
    terms:add(-1, mu_squared):add(1):add(log_var)

    self.output = terms:sum() * -0.5
    return self.output
end
--- Compute the gradients of the KL term with respect to both inputs.
-- gradInput[1] = gradCoefficient * mean
-- gradInput[2] = gradCoefficient * -0.5 * (1 - exp(log_var))
-- @param mean tensor of means (mu)
-- @param log_var tensor of log-variances (log(sigma^2))
-- @return table {grad wrt mean, grad wrt log_var}, also stored in self.gradInput
function KLDCriterion:updateGradInput(mean, log_var)
    -- A fresh table is installed on every call, before either entry is filled.
    local grads = {}
    self.gradInput = grads

    grads[1] = mean:clone():mul(self.gradCoefficient)

    -- Built on the new tensor returned by torch.exp, so log_var is not modified.
    local grad_log_var = torch.exp(log_var)
    grad_log_var:mul(-1):add(1):mul(-0.5):mul(self.gradCoefficient)
    grads[2] = grad_log_var

    return self.gradInput
end