-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmcnoise_boost.py
98 lines (76 loc) · 2.53 KB
/
mcnoise_boost.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
"""
Speed up Monte Carlo noise simulation (specifically the JIT staging time)
for circuits with general Kraus error channels by slicing the circuit layerwise.
"""
import time
import sys
sys.path.insert(0, "../")
import tensorcircuit as tc
# All tensors below are JAX arrays; jit and value_and_grad come from the JAX backend.
tc.set_backend("jax")
# Problem size (qubits, layers). The trailing "10*4" timing note at the end of the
# file refers to n=10, nlayer=4.
n, nlayer = 3, 2
def f1(key, param, n, nlayer):
    """Return real <Z> on the middle qubit of a noisy circuit built in one piece.

    This is the baseline without layerwise slicing: the whole circuit is kept as a
    single ``tc.Circuit`` and evaluated once via ``expectation``.

    :param key: PRNG key used to seed the backend random state, or ``None`` to
        leave the current backend random state untouched.
    :param param: rotation angles indexed as ``param[layer, qubit]``.
    :param n: number of qubits.
    :param nlayer: number of CNOT + noise + RX layers.
    :return: real part of the Z expectation on qubit ``int(n / 2)``.
    """
    if key is not None:
        tc.backend.set_random_state(key)
    circuit = tc.Circuit(n)
    for qubit in range(n):
        circuit.H(qubit)
    for layer in range(nlayer):
        for qubit in range(n - 1):
            circuit.cnot(qubit, qubit + 1)
            # No explicit ``status`` here; presumably the Kraus branch is drawn
            # from the backend random state set above — confirm in tensorcircuit.
            for target in (qubit, qubit + 1):
                circuit.apply_general_kraus(
                    tc.channels.phasedampingchannel(0.15), target
                )
        for qubit in range(n):
            circuit.rx(qubit, theta=param[layer, qubit])
    observable = (tc.gates.z(), [int(n / 2)])
    return tc.backend.real(circuit.expectation(observable))
def templatecnot(s, param, i):
    """Apply CNOT(i, i + 1) to the state vector ``s`` and return the new state.

    The qubit count is inferred from ``s`` (length ``2**nqubits``) instead of being
    read from the module-level ``n``. This lets callers such as ``f2`` work for
    whatever ``n`` they are given.

    :param s: flat state vector of length ``2**nqubits``.
    :param param: unused here; presumably kept so all templates share a
        similar signature.
    :param i: control qubit; the target is ``i + 1``.
    :return: the contracted state vector after the gate.
    """
    # Shapes are static under jit, so this is plain Python integer arithmetic.
    nqubits = int(s.shape[0]).bit_length() - 1
    c = tc.Circuit(nqubits, inputs=s)
    c.cnot(i, i + 1)
    return c.state()
def templatenoise(key, s, param, i):
    """Apply one sampled phase-damping channel (parameter 0.15) on qubit ``i`` of ``s``.

    The qubit count is inferred from ``s`` (length ``2**nqubits``) instead of being
    read from the module-level ``n``. This lets callers such as ``f2`` work for
    whatever ``n`` they are given.

    :param key: PRNG key consumed for this single channel application.
    :param s: flat state vector of length ``2**nqubits``.
    :param param: unused here; presumably kept so all templates share a
        similar signature.
    :param i: qubit the channel acts on.
    :return: the contracted state vector after the channel.
    """
    # Shapes are static under jit, so this is plain Python integer arithmetic.
    nqubits = int(s.shape[0]).bit_length() - 1
    c = tc.Circuit(nqubits, inputs=s)
    # A single uniform draw from ``key`` is passed as ``status``; presumably
    # apply_general_kraus uses it to pick the Kraus branch — confirm in tensorcircuit.
    status = tc.backend.stateful_randu(key)[0]
    c.apply_general_kraus(tc.channels.phasedampingchannel(0.15), i, status=status)
    return c.state()
def templaterz(s, param, j):
    """Apply the RX rotation layer ``j`` (angles ``param[j, :]``) to every qubit of ``s``.

    Despite the name, the gates applied are ``rx``. The qubit count is inferred from
    ``s`` (length ``2**nqubits``) instead of being read from the module-level ``n``.
    This means every qubit of the given state is rotated, whatever ``n`` the caller uses.

    :param s: flat state vector of length ``2**nqubits``.
    :param param: rotation angles indexed as ``param[layer, qubit]``.
    :param j: layer index selecting the row of ``param``.
    :return: the contracted state vector after the rotation layer.
    """
    # Shapes are static under jit, so this is plain Python integer arithmetic.
    nqubits = int(s.shape[0]).bit_length() - 1
    c = tc.Circuit(nqubits, inputs=s)
    for i in range(nqubits):
        c.rx(i, theta=param[j, i])
    return c.state()
def f2(key, param, n, nlayer):
    """Return real <Z> on the middle qubit, evaluating the noisy circuit layerwise.

    Each gate / noise step is applied by a template that builds a small circuit on
    the current state and immediately contracts it back to a state vector. The
    network is therefore never one large circuit (contrast with ``f1``).

    :param key: PRNG key; it is split before every noise channel so each channel
        gets its own subkey. Unlike ``f1``, ``None`` is not special-cased here.
    :param param: rotation angles indexed as ``param[layer, qubit]``.
    :param n: number of qubits.
    :param nlayer: number of CNOT + noise + RX layers.
    :return: real part of the Z expectation on qubit ``int(n / 2)``.
    """
    prep = tc.Circuit(n)
    for qubit in range(n):
        prep.H(qubit)
    state = prep.state()
    for layer in range(nlayer):
        for qubit in range(n - 1):
            state = templatecnot(state, param, qubit)
            for target in (qubit, qubit + 1):
                key, subkey = tc.backend.random_split(key)
                state = templatenoise(subkey, state, param, target)
        state = templaterz(state, param, layer)
    observable = (tc.gates.z(), [int(n / 2)])
    return tc.backend.real(tc.expectation(observable, ket=state))
# Jitted value-and-grad of each simulator with respect to ``param`` (argnum 1);
# ``n`` and ``nlayer`` (argnums 2 and 3) are passed as static arguments.
vagf1, vagf2 = (
    tc.backend.jit(tc.backend.value_and_grad(fn, argnums=1), static_argnums=(2, 3))
    for fn in (f1, f2)
)
# All rotation angles start at 1; shape (nlayer, n) matches the param[layer, qubit] indexing.
param = tc.backend.ones([nlayer, n])
def benchmark(f, tries=3):
    """Time a jitted value-and-grad function on the module-level problem.

    The first call includes tracing/compilation ("staging") plus one execution.
    The following ``tries`` calls measure the average steady-state running time.
    Every call uses the module-level ``param``, ``n`` and ``nlayer`` and the same
    PRNG key (seed 42).

    :param f: callable ``f(key, param, n, nlayer)`` returning a pair whose first
        element (the value) is printed, e.g. ``vagf1`` / ``vagf2``.
    :param tries: number of timed post-compilation calls; must be >= 1.
    :raises ValueError: if ``tries`` < 1 (the average would divide by zero).
    """
    if tries < 1:
        raise ValueError("tries must be a positive integer")
    # Build the key before timing so it is not counted as staging time.
    key = tc.backend.get_random_state(42)
    # perf_counter is monotonic and high-resolution; time.time() is wall-clock.
    time0 = time.perf_counter()
    # Printing the value forces the (asynchronously dispatched) result to be ready.
    print(f(key, param, n, nlayer)[0])
    time1 = time.perf_counter()
    for _ in range(tries):
        print(f(key, param, n, nlayer)[0])
    time2 = time.perf_counter()
    print(
        "staging time: ",
        time1 - time0,
        "running time: ",
        (time2 - time1) / tries,
    )
# Compare the monolithic circuit (vagf1) against the layerwise-sliced one (vagf2).
print("without layerwise slicing jit")
benchmark(vagf1)
print("=============================")
print("with layerwise slicing jit")
benchmark(vagf2)
# Reference numbers for n=10, nlayer=4 on the jax backend with a T4 GPU, presumably
# "staging/running" seconds: 235/0.36 without slicing vs. 26/0.04 with slicing.
# 10*4: jax*T4: 235/0.36 vs. 26/0.04