-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathCrop.lua
81 lines (73 loc) · 2.13 KB
/
Crop.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
-- nn.Crop: a module that passes its input through, trimming it (centered)
-- when needed so that the windowed layers of a wrapped network tile it
-- exactly. The wrapped network is only inspected for its geometry; it is
-- never run by this module (see Crop:updateOutput).
local Crop,parent = torch.class('nn.Crop','nn.Module')
--- Output extent of a sliding window along one axis.
-- The result is deliberately left unrounded: a fractional value signals
-- that the window of size `k` with step `s` does not tile `x` exactly.
-- @param x input extent
-- @param k window (kernel) size
-- @param s step (stride)
-- @return (x - k) / s + 1, possibly fractional
local function conv(x,k,s)
   local span = x - k
   return span / s + 1
end
--- Inverse of conv along one axis: the input extent that a window of
-- size `k` and step `s` maps onto exactly `x` output positions.
-- @param x output extent
-- @param k window (kernel) size
-- @param s step (stride)
-- @return (x - 1) * s + k
local function iconv(x,k,s)
   local steps = x - 1
   return k + s * steps
end
-- Module types whose spatial geometry is fully described by the fields
-- kW/kH (window), dW/dH (step) and optional padding.
local WINDOWED_MODULES = {
   ['nn.SpatialConvolution']    = true,
   ['nn.SpatialConvolutionMM']  = true,
   ['nn.SpatialConvolutionMap'] = true,
   ['nn.SpatialSubSampling']    = true,
   ['nn.SpatialMaxPooling']     = true,
   ['nn.SpatialAveragePooling'] = true,
   ['nn.SpatialLPPooling']      = true,
}

--- Map a spatial size (width, height) through module `m`.
-- With func == conv this computes the output size produced for an input of
-- iw x ih; with func == iconv it computes the input size needed to produce
-- an output of iw x ih. Results are not rounded: a fractional value means
-- the size does not tile exactly into the module's windows.
-- Zero padding is honoured: forward adds 2*pad before windowing, inverse
-- removes it afterwards. Modules that are not recognised (activations,
-- etc.) are treated as size-preserving. nn.Sequential is walked
-- recursively, front-to-back for conv and back-to-front for iconv.
-- @param m    module to inspect
-- @param iw   width to map
-- @param ih   height to map
-- @param func conv or iconv (compared by identity)
-- @return mapped width, mapped height
local function getSize(m,iw,ih,func)
   local name = torch.typename(m)
   local ow = iw
   local oh = ih
   if WINDOWED_MODULES[name] then
      -- newer nn releases use padW/padH; older ones used a single `padding`
      local padW = m.padW or m.padding or 0
      local padH = m.padH or m.padding or 0
      if func == conv then
         ow = func(iw + 2*padW, m.kW, m.dW)
         oh = func(ih + 2*padH, m.kH, m.dH)
      else
         ow = func(iw, m.kW, m.dW) - 2*padW
         oh = func(ih, m.kH, m.dH) - 2*padH
      end
   elseif name == 'nn.Sequential' then
      if func == conv then
         for i=1,#m.modules do
            ow,oh = getSize(m:get(i),ow,oh,func)
         end
      else
         -- inverting a pipeline means walking it from the last stage back
         for i=#m.modules,1,-1 do
            ow,oh = getSize(m:get(i),ow,oh,func)
         end
      end
   end
   return ow,oh
end
--- Inverse direction of getSize: the input width/height that module `m`
-- needs in order to produce an output of ow x oh.
local function getInputSize(m,ow,oh)
   local needW, needH = getSize(m, ow, oh, iconv)
   return needW, needH
end
--- Forward direction of getSize: the (possibly fractional) output
-- width/height module `m` yields for an input of iw x ih.
local function getOutputSize(m,iw,ih)
   local outW, outH = getSize(m, iw, ih, conv)
   return outW, outH
end
--- Create a Crop for the network `m`.
-- `m` is only inspected for its geometry; it is not stored as a child that
-- gets run by this module.
-- @param m    network whose windowed layers determine the crop
-- @param minw minimum output width `m` must produce (defaults to 1)
-- @param minh minimum output height `m` must produce (defaults to 1)
function Crop:__init(m,minw,minh)
   parent.__init(self)
   self.module = m
   -- without a default, a nil minimum crashes inside iconv with a
   -- nil-arithmetic error; 1x1 is the smallest output any layer can emit
   self.ominw = minw or 1
   self.ominh = minh or 1
   -- smallest input that still yields an ominw x ominh output
   self.iminw,self.iminh = getInputSize(m,self.ominw, self.ominh)
end
--- Check whether `input` is large enough for the wrapped network.
-- Width is read from the last dimension and height from the one before it.
-- For 3D input that is dims 3/2, unchanged. For batched 4D input it is
-- dims 4/3 rather than the wrong (H, C) pair.
-- Prints a diagnostic and returns false when the input is too small.
-- @param input tensor with at least 2 dimensions
-- @return true if iw >= iminw and ih >= iminh, false otherwise
function Crop:validInput(input)
   local wdim = input:dim()
   local iw = input:size(wdim)
   local ih = input:size(wdim - 1)
   if iw < self.iminw or ih < self.iminh then
      print(string.format('too small input iw=%d, ih=%d, minw=%d, minh=%d\n', iw,ih,self.iminw,self.iminh))
      return false
   end
   return true
end
--- Dimension indices used for cropping: width is the last dimension and
-- height the one before it. This gives dims 3/2 for 3D input, as
-- originally, and dims 4/3 for batched 4D input.
-- @return width dim, height dim
local function cropDims(input)
   local wdim = input:dim()
   if wdim < 2 then
      error(string.format('Crop expects an input with at least 2 dimensions, got %d', wdim))
   end
   return wdim, wdim - 1
end

--- Forward pass: copy `input` into self.output, cropping it when needed.
-- If the wrapped network's windows would not tile the input exactly (the
-- computed output size is fractional), keep only the largest centered
-- region that maps onto a whole number of output positions. Otherwise the
-- input is copied unchanged.
-- Raises an error if the input is smaller than the minimum computed in
-- __init.
-- @param input tensor (width = last dim, height = second-to-last dim)
-- @return self.output
function Crop:updateOutput(input)
   local wdim, hdim = cropDims(input)
   local iw = input:size(wdim)
   local ih = input:size(hdim)
   if iw < self.iminw or ih < self.iminh then
      error(string.format('too small input iw=%d, ih=%d, minw=%d, minh=%d\n', iw,ih,self.iminw,self.iminh))
   end
   local ow,oh = getOutputSize(self.module,iw,ih)
   if ow ~= math.floor(ow) or oh ~= math.floor(oh) then
      -- largest input that produces exactly floor(ow) x floor(oh) outputs
      local iiw,iih = getInputSize(self.module,math.floor(ow),math.floor(oh))
      local region = input:narrow(wdim,math.floor((iw-iiw)/2)+1,iiw)
      region = region:narrow(hdim,math.floor((ih-iih)/2)+1,iih)
      self.output:resizeAs(region):copy(region)
   else
      self.output:resizeAs(input):copy(input)
   end
   return self.output
end

--- Backward pass: route gradOutput back to the centered window that
-- updateOutput kept; pixels that were cropped away get zero gradient.
-- The window offset is recomputed from the sizes, with the same centering
-- formula as updateOutput, so no state from the forward pass is needed.
-- Without this override the inherited nn.Module version returns an empty
-- gradInput.
-- @param input      the tensor given to updateOutput
-- @param gradOutput gradient w.r.t. self.output
-- @return self.gradInput, same size as input
function Crop:updateGradInput(input, gradOutput)
   local wdim, hdim = cropDims(input)
   local iw, ih = input:size(wdim), input:size(hdim)
   local gw, gh = gradOutput:size(wdim), gradOutput:size(hdim)
   self.gradInput:resizeAs(input):zero()
   local window = self.gradInput:narrow(wdim,math.floor((iw-gw)/2)+1,gw)
   window = window:narrow(hdim,math.floor((ih-gh)/2)+1,gh)
   window:copy(gradOutput)
   return self.gradInput
end