"""
Implementation of Compositional Pattern Producing Networks in TensorFlow
https://en.wikipedia.org/wiki/Compositional_pattern-producing_network
@hardmaru, 2016

Sampler Class

This file is meant to be run inside an IPython session, since it is intended
for interactive experimentation.
It shouldn't be hard to adapt bits of this code to a normal command-line
environment if you want to use it outside of IPython.

usage:

%run -i sampler.py
dataset = read_data_sets(train_dir='data/mnist')
model = CPPNVAE()
model.load_model('save/mnist')
sampler = Sampler(model, dataset)
"""
import math
import random
import matplotlib.pyplot as plt
import numpy as np
import PIL
import pylab
import tensorflow as tf
from PIL import Image
import dataset
import images2gif
from model import CPPNVAE
mgc = get_ipython().magic
mgc(u"matplotlib inline")
mgc(u"run -i dataset.py")
pylab.rcParams["figure.figsize"] = (10.0, 10.0)
class Sampler:
def __init__(self, model, _dataset=None): # @look
self.model = model
if _dataset:
self.dataset = _dataset
else:
self.dataset = dataset.read_data_sets()
self.z = self.generate_z()
def get_random_mnist(self, with_label=False):
        if with_label:
data, label = self.dataset.next_batch(1, with_label)
return data[0], label[0]
return self.dataset.next_batch(1)[0]
    def get_random_specific_mnist(self, label=2):
        # keep drawing random digits until one with the requested label turns up
        # (gives up and returns the last draw after 100 attempts)
        m, l = self.get_random_mnist(with_label=True)
        for i in range(100):
if l == label:
break
m, l = self.get_random_mnist(with_label=True)
return m
def generate_random_label(self, label):
m = self.get_random_specific_mnist(label)
self.show_image(m)
self.show_image_from_z(self.encode(m))
def generate_z(self):
z = np.random.normal(size=self.model.z_dim).astype(np.float32)
return z
def encode(self, mnist_data):
new_shape = [1] + list(mnist_data.shape)
return self.model.encode(np.reshape(mnist_data, new_shape))
def generate(self, z=None, x_dim=512, y_dim=512, scale=8.0):
        if z is None:
            z = self.generate_z()
        # model.generate expects a batch axis, so always pass z with shape (1, z_dim)
        z = np.reshape(z, (1, self.model.z_dim))
        self.z = z
return self.model.generate(z, x_dim, y_dim, scale)[0]
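    # A small usage sketch (assumes a Sampler instance named `sampler`, set up as
    # in the module docstring). Since the CPPN renders each pixel from its
    # coordinates, the same z can be drawn at an arbitrary resolution by changing
    # x_dim / y_dim / scale; the values below are just illustrative.
    #
    #   img = sampler.generate(x_dim=1080, y_dim=1080, scale=10.0)
    #   sampler.show_image(img)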
def show_image(self, image_data):
"""
        image_data is a numpy array of shape [height, width, depth],
        NOT a PIL.Image instance
"""
plt.subplot(1, 1, 1)
y_dim = image_data.shape[0]
x_dim = image_data.shape[1]
c_dim = self.model.c_dim
if c_dim > 1:
plt.imshow(image_data, interpolation="nearest")
else:
plt.imshow(
image_data.reshape(y_dim, x_dim), cmap="Greys", interpolation="nearest"
)
plt.axis("off")
plt.show()
def show_image_from_z(self, z):
self.show_image(self.generate(z))
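    # Another interactive sketch (assumes `sampler` exists): perturb the most
    # recently used latent vector self.z with a small random offset and re-render
    # it. The 0.3 step size is an arbitrary choice for illustration.
    #
    #   z = sampler.z + 0.3 * np.random.normal(size=sampler.model.z_dim)
    #   sampler.show_image_from_z(z)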
    def save_png(self, image_data, filename, specific_size=None):
        # same conversion as to_image(), then written straight to disk
        im = self.to_image(image_data, specific_size)
        im.save(filename)
    def to_image(self, image_data, specific_size=None):
        # convert a numpy array with values in [0, 1] into a PIL.Image,
        # inverting intensities (1 - x) before scaling to 8-bit
        img_data = np.array(1 - image_data)
        y_dim = image_data.shape[0]
        x_dim = image_data.shape[1]
        c_dim = self.model.c_dim
        if c_dim > 1:
            img_data = np.array(
                img_data.reshape((y_dim, x_dim, c_dim)) * 255.0, dtype=np.uint8
            )
        else:
            img_data = np.array(
                img_data.reshape((y_dim, x_dim)) * 255.0, dtype=np.uint8
            )
        im = Image.fromarray(img_data)
        if specific_size is not None:
            im = im.resize(specific_size)
        return im
def morph(
self, z1, z2, n_total_frame=10, x_dim=512, y_dim=512, scale=8.0, sinusoid=False
):
"""
        returns a list of img_data arrays representing a morph from z1 to z2
        defaults to linear interpolation; with sinusoid=True the blend factor
        follows sin(t * pi / 2), which eases the motion so more frames land near z2
        n_total_frame must be >= 2, since the first and last frames are z1 and z2
"""
delta_z = 1.0 / (n_total_frame - 1)
diff_z = z2 - z1
img_data_array = []
for i in range(n_total_frame):
percentage = delta_z * float(i)
factor = percentage
            if sinusoid:
factor = np.sin(percentage * np.pi / 2)
z = z1 + diff_z * factor
print("processing image ", i)
img_data_array.append(self.generate(z, x_dim, y_dim, scale))
return img_data_array
def save_anim_gif(self, img_data_array, filename, duration=0.1):
"""
this saves an animated gif given a list of img_data (numpy arrays)
"""
images = []
        for img_data in img_data_array:
            images.append(self.to_image(img_data))
images2gif.writeGif(filename, images, duration=duration)
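# Example session (a sketch, assuming the setup from the module docstring, i.e. a
# trained model loaded from 'save/mnist' and `sampler = Sampler(model, dataset)`):
# encode two random digits, morph between their latent codes, and write the frames
# out as an animated gif. The digit labels, frame count, and filename are arbitrary
# choices for illustration.
#
#   z1 = sampler.encode(sampler.get_random_specific_mnist(3))
#   z2 = sampler.encode(sampler.get_random_specific_mnist(8))
#   frames = sampler.morph(z1, z2, n_total_frame=20, sinusoid=True)
#   sampler.save_anim_gif(frames, 'morph_3_to_8.gif', duration=0.1)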