-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathjax_utils.py
413 lines (297 loc) · 18.7 KB
/
jax_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
from typing import Callable, Dict, Tuple, List
import numpy as np
from ase.atoms import Atoms
from ase.build import bulk
import ase.io
from ase.calculators.lj import LennardJones
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution, Stationary
from jax import vmap, random
from jax.api import grad, jacfwd, jit
from jax.interpreters.xla import DeviceArray
from jax_md import energy
from jax_md.energy import DisplacementFn, NeighborList
from jax_md.simulate import NVEState
from os import environ
from enum import Enum
import warnings
from jax_md import space, quantity
import jax.numpy as jnp
# Callable taking (space.Array, energy.NeighborList) and returning a space.Array.
# presumably (positions, neighbor list) -> energies; verify against callers
EnergyFn = Callable[[space.Array, energy.NeighborList], space.Array]
# Callable taking a space.Array and returning a tuple of properties.
# NOTE(review): this alias declares five entries (the last two being None), while the
# PotentialProperties alias right below declares four — confirm which one is intended.
PotentialFn = Callable[[space.Array], Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, None, None]]
# Four-entry property tuple; the potentials defined directly below return
# (total_energy, atomwise_energies, forces, stress).
# NOTE: both PotentialFn and PotentialProperties are assigned a second time further down this file.
PotentialProperties = Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]
def strain_nl_energy_fn(total_energy_fn: EnergyFn, box: jnp.array):
    """Build strain-aware energy and stress functions around ``total_energy_fn``.

    Args:
        total_energy_fn: Energy function invoked as
            ``total_energy_fn(R, *args, box=<strained box>, **kwargs)``.
        box: Reference box the deformation is applied to.

    Returns:
        A pair ``(strained_total_energy_fn, stress_fn)``. Both are called as
        ``fn(R, deformation, *args, **kwargs)``. ``stress_fn`` is the gradient of
        the strained energy with respect to ``deformation`` divided by
        ``det(box)``.
    """
    # Open question kept from the original author: does the deformation have to be
    # an explicit argument to be differentiable, or could it be wrapped at creation time?

    def strained_box(deformation):
        # Only the symmetric part of the deformation tensor is applied to the box.
        return space.transform(jnp.eye(3) + (deformation + deformation.T) * 0.5, box)

    def strained_total_energy_fn(R, deformation, *args, **kwargs):
        return total_energy_fn(R, *args, box=strained_box(deformation), **kwargs)

    def stress_fn(R, deformation, *args, **kwargs):
        # argnums=1 selects the deformation argument of strained_total_energy_fn.
        energy_gradient_fn = grad(strained_total_energy_fn, argnums=1)
        return energy_gradient_fn(R, deformation, *args, **kwargs) / jnp.linalg.det(box)

    return strained_total_energy_fn, stress_fn
# TODO for all NL potentials: neighbors are only wrapped at potential creation time, but never passed during evaluations!
def strained_neighbor_list_potential(energy_fn, neighbors, box: jnp.ndarray) -> PotentialFn:
    """Create a potential that evaluates energies, forces and stress via a strained box.

    Args:
        energy_fn: Called as ``energy_fn(R, box=..., neighbor=neighbors)``; its
            result is summed to obtain the total energy.
        neighbors: Neighbor list captured here and reused for every evaluation.
        box: Reference box the (zero) deformation is applied to.

    Returns:
        A function ``potential(R)`` returning
        ``(total_energy, atomwise_energies, forces, stress)``.
    """

    def potential(R: space.Array) -> PotentialProperties:
        # The deformation is identically zero; it only exists so that the energy
        # can be differentiated with respect to it.
        zero_strain = jnp.zeros_like(box)

        def strained_box(strain):
            # Symmetrize the deformation tensor before applying it to the box.
            return space.transform(jnp.eye(3) + (strain + strain.T) * 0.5, box)

        def atomwise_energy(strain, positions):
            return energy_fn(positions, box=strained_box(strain), neighbor=neighbors)

        def summed_energy(strain, positions):
            return jnp.sum(atomwise_energy(strain, positions))

        # Stress: gradient w.r.t. the deformation (argument 0), divided by det(box).
        stress = grad(summed_energy, argnums=0)(zero_strain, R) / jnp.linalg.det(box)
        total_energy = summed_energy(zero_strain, R)
        atomwise_energies = atomwise_energy(zero_strain, R)
        # Forces: negative gradient w.r.t. the positions (argument 1).
        forces = grad(summed_energy, argnums=1)(zero_strain, R) * -1
        return total_energy, atomwise_energies, forces, stress

    return potential
def unstrained_neighbor_list_potential(energy_fn, neighbors) -> PotentialFn:
    """Create a potential that evaluates energies and forces without any strain.

    Args:
        energy_fn: Called as ``energy_fn(R, neighbor=neighbors)``; its result is
            summed to obtain the total energy.
        neighbors: Neighbor list captured here and reused for every evaluation.

    Returns:
        A function ``potential(R)`` returning
        ``(total_energy, atomwise_energies, forces, None)``; the stress slot is
        always ``None``.
    """

    def potential(R: space.Array) -> PotentialProperties:
        def summed_energy(positions, *args, **kwargs):
            return jnp.sum(energy_fn(positions, *args, **kwargs))

        total_energy = summed_energy(R, neighbor=neighbors)
        atomwise_energies = energy_fn(R, neighbor=neighbors)
        forces = quantity.force(summed_energy)(R, neighbor=neighbors)
        return total_energy, atomwise_energies, forces, None

    return potential
def block_and_dispatch(properties: Tuple[DeviceArray, ...]):
    """Wait for all computed properties and convert them to NumPy arrays.

    Args:
        properties: Iterable of arrays exposing ``block_until_ready()``; entries
            may be ``None``.

    Returns:
        A list with one entry per input: ``None`` stays ``None``, everything
        else is converted with ``np.array``.
    """
    # First pass: wait on every available result before converting any of them.
    for pending in (entry for entry in properties if entry is not None):
        pending.block_until_ready()

    # Second pass: convert, keeping the None placeholders in position.
    host_values = []
    for entry in properties:
        host_values.append(np.array(entry) if entry is not None else None)
    return host_values
def get_initial_nve_state(atoms: Atoms) -> NVEState:
    """Build an ``NVEState`` from the current state of an ASE ``Atoms`` object.

    Args:
        atoms: Atoms with a calculator attached.

    Returns:
        ``NVEState(positions, velocities, forces, mass)`` with every entry
        converted via ``jnp.float32``.

    Raises:
        RuntimeError: If ``atoms.calc`` is ``None``.
    """
    if atoms.calc is None:
        raise RuntimeError("Atoms must have a calculator")

    to_f32 = jnp.float32
    positions = to_f32(atoms.get_positions())
    # NOTE(review): the original comment gives the unit as "Å / fs" — confirm
    # against the unit system ASE uses for velocities.
    velocities = to_f32(atoms.get_velocities())
    forces = to_f32(atoms.get_forces())
    # Only the first atom's mass is used.
    # presumably all atoms share one species — verify against callers
    mass = to_f32(atoms.get_masses()[0])
    return NVEState(positions, velocities, forces, mass)
def get_argon_lennard_jones_parameters():
    """Return the Lennard-Jones parameter set used for argon in this module.

    Returns:
        A new dict with the keys ``sigma``, ``epsilon``, ``rc`` and ``ro``.
        presumably lengths in Å and energy in eV (ASE units) — TODO confirm
    """
    return dict(sigma=3.40, epsilon=0.01042, rc=10.54, ro=6.0)
def initialize_cubic_argon(multiplier=5, temperature_K=30) -> Atoms:
    """Create a cubic argon supercell with initial velocities and an LJ calculator.

    Args:
        multiplier: Repetition count applied to the cubic unit cell along each
            of the three axes.
        temperature_K: Temperature passed to ``MaxwellBoltzmannDistribution``.

    Returns:
        The ``Atoms`` object with velocities drawn, ``Stationary`` applied and
        a smooth ``LennardJones`` calculator attached.
    """
    unit_cell = bulk("Ar", cubic=True)
    atoms = unit_cell * [multiplier] * 3

    # Draw velocities first, then apply Stationary to them.
    MaxwellBoltzmannDistribution(atoms, temperature_K=temperature_K)
    Stationary(atoms)

    # The parameter dict carries exactly the sigma/epsilon/rc/ro keywords.
    atoms.calc = LennardJones(**get_argon_lennard_jones_parameters(), smooth=True)
    return atoms
def read_cubic_argon(file_name="geometry.in"):
    """Read a structure in ``aims`` format and attach the argon LJ calculator.

    Args:
        file_name: Path of the geometry file to read.

    Returns:
        The object returned by ``ase.io.read`` with a smooth ``LennardJones``
        calculator assigned to ``calc``.
    """
    atoms = ase.io.read(file_name, format="aims")
    # The parameter dict carries exactly the sigma/epsilon/rc/ro keywords.
    atoms.calc = LennardJones(**get_argon_lennard_jones_parameters(), smooth=True)
    return atoms
def write_cubic_argon(file_name="geometry.in", multiplier=5):
    """Generate a cubic argon supercell and write it to disk in ``aims`` format.

    Args:
        file_name: Path of the geometry file to write.
        multiplier: Forwarded to ``initialize_cubic_argon``.
    """
    supercell = initialize_cubic_argon(multiplier)
    ase.io.write(file_name, supercell, velocities=True, format="aims")
## old stuff here
class XlaMemoryFlag(Enum):
    """Memory allocation modes reported by ``get_memory_allocation_mode``.

    The ``XLA_*`` members are named after the environment variables that are
    looked up; each member's value equals its name.
    """

    # Environment-variable backed modes.
    XLA_PYTHON_CLIENT_PREALLOCATE = "XLA_PYTHON_CLIENT_PREALLOCATE"
    XLA_PYTHON_CLIENT_MEM_FRACTION = "XLA_PYTHON_CLIENT_MEM_FRACTION"
    XLA_PYTHON_CLIENT_ALLOCATOR = "XLA_PYTHON_CLIENT_ALLOCATOR"

    # Returned when none of the variables above is set.
    DEFAULT = "DEFAULT"
def get_memory_allocation_mode() -> XlaMemoryFlag:
    """Report which XLA memory allocation environment variable is set.

    Only the presence of a variable is checked; its value is not inspected.

    Returns:
        The single ``XlaMemoryFlag`` whose environment variable is set, or
        ``XlaMemoryFlag.DEFAULT`` when none is.

    Raises:
        SystemError: If more than one of the variables is set.
    """
    # DEFAULT is the "nothing configured" sentinel, not an environment variable,
    # so it must not be probed: an unrelated variable named "DEFAULT" would
    # otherwise be counted as an active allocation mode.
    active_flags = [
        flag
        for flag in XlaMemoryFlag
        if flag is not XlaMemoryFlag.DEFAULT and flag.name in environ
    ]

    if len(active_flags) > 1:
        raise SystemError("Multiple memory allocation modes enabled simultaneously.")
    if not active_flags:
        return XlaMemoryFlag.DEFAULT
    return active_flags[0]
def compute_pairwise_distances(displacement_fn: space.DisplacementFn, R: jnp.ndarray):
    """Compute the magnitudes of all pairwise displacements within ``R``.

    Args:
        displacement_fn: Displacement function for two position vectors.
        R: Positions; the same array is used for both sides of the product.

    Returns:
        The Euclidean norm of every displacement vector, obtained by mapping
        over the two leading axes of the pairwise displacement array.
    """

    def vector_norm(vector):
        return jnp.sqrt(jnp.sum(vector**2))

    # map_product lifts the two-vector displacement to all pairs of rows.
    # Per the original author's note the result is shaped (n, n, 3).
    all_pairs_displacement_fn = space.map_product(displacement_fn)
    displacements = all_pairs_displacement_fn(R, R)

    # Map over axis 0 twice so that vector_norm sees one displacement vector at a time.
    norm_over_rows = vmap(vector_norm, in_axes=0)
    norm_over_pairs = vmap(norm_over_rows, in_axes=0)
    return norm_over_pairs(displacements)
def generate_R(n: int, scaling_factor: float) -> jnp.ndarray:
    """Generate ``n`` uniformly distributed 3D positions scaled by ``scaling_factor``.

    The PRNG key is always derived from seed 0, so repeated calls with the same
    arguments yield the same positions.

    Args:
        n: Number of positions (rows) to generate.
        scaling_factor: Factor multiplied onto the uniform samples.

    Returns:
        Array of shape ``(n, 3)``.
    """
    # TODO: Build a global service to manage and demand PRNGKeys for JAX-based simulations. if necessary for MD later.
    seed_key = random.PRNGKey(0)
    _, sample_key = random.split(seed_key)
    unit_samples = random.uniform(sample_key, shape=(n, 3))
    return unit_samples * scaling_factor
def get_displacement(atoms: Atoms):
    """Select the jax_md space matching the periodicity of ``atoms``.

    Args:
        atoms: Structure whose ``get_pbc()`` flags and cell are inspected.

    Returns:
        ``space.periodic_general(box, fractional_coordinates=False)`` when every
        periodic-boundary flag is set; otherwise ``space.free()`` after emitting
        a warning.
    """
    fully_periodic = all(atoms.get_pbc())
    if fully_periodic:
        # An alternative via get_real_displacement(atoms) was left disabled by the author.
        cell_matrix = atoms.get_cell().array
        return space.periodic_general(cell_matrix, fractional_coordinates=False)

    warnings.warn("Atoms object without periodic boundary conditions passed!")
    return space.free()
def get_real_displacement(atoms):
    """Build a displacement function that transforms its inputs with the inverse cell.

    Args:
        atoms: Structure with all periodic-boundary flags set.

    Returns:
        A pair ``(displacement, None)``. ``displacement(Ra, Rb, **kwargs)``
        transforms both positions with the inverse cell and passes them to the
        displacement function obtained from ``space.periodic_general(cell)``.

    Raises:
        ValueError: If any periodic-boundary flag of ``atoms`` is unset.
    """
    if not all(atoms.get_pbc()):
        raise ValueError("Atoms object without periodic boundary conditions passed!")

    cell = atoms.get_cell().array
    inverse_cell = space.inverse(cell)
    # Only the displacement part of the returned pair is needed here.
    scaled_displacement_fn, _ = space.periodic_general(cell)

    def to_scaled(position):
        return space.transform(inverse_cell, position)

    # **kwargs are passed through so that box information can be fed in.
    def displacement(Ra: space.Array, Rb: space.Array, **kwargs) -> space.Array:
        return scaled_displacement_fn(to_scaled(Ra), to_scaled(Rb), **kwargs)

    return displacement, None
def jit_if_wanted(do_jit: bool, *args) -> Tuple:
    """Optionally wrap every given callable with ``jit``.

    Args:
        do_jit: When true, each callable is wrapped with ``jit``.
        *args: The callables to process.

    Returns:
        The unchanged ``args`` tuple when ``do_jit`` is false, otherwise a tuple
        of the jitted callables in the same order.

    Raises:
        ValueError: If any positional argument is not callable.
    """
    for candidate in args:
        if not callable(candidate):
            raise ValueError("Expected a list of callables.")

    if do_jit:
        return tuple(jit(fn) for fn in args)
    return args
# NOTE: second assignment of both aliases in this file; the definitions below
# reference these versions.
PotentialFn = Callable[[space.Array], Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, None, None]]
# Five-entry property tuple: the potentials defined below return
# (total_energy, atomwise_energies, forces, stress, stresses).
PotentialProperties = Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]
def get_strained_pair_potential(
    box: jnp.ndarray,
    displacement_fn: DisplacementFn,
    sigma: float,
    epsilon: float,
    r_cutoff: float,
    r_onset: float,
    compute_stress: bool,
    compute_stresses: bool,
) -> PotentialFn:
    """Create a Lennard-Jones pair potential that is evaluated through a strained box.

    Args:
        box: Reference box the (zero) deformation is applied to.
        displacement_fn: Displacement function handed to ``energy.lennard_jones_pair``.
        sigma: Forwarded to ``energy.lennard_jones_pair``.
        epsilon: Forwarded to ``energy.lennard_jones_pair``.
        r_cutoff: Forwarded to ``energy.lennard_jones_pair``.
        r_onset: Forwarded to ``energy.lennard_jones_pair``.
        compute_stress: Whether to compute the total stress.
        compute_stresses: Whether to compute the per-atom stresses.

    Returns:
        A function ``strained_potential_fn(R)`` returning
        ``(total_energy, atomwise_energies, forces, stress, stresses)``; the last
        two entries are ``None`` when the respective flag is false.
    """

    def strained_potential_fn(R: space.Array) -> PotentialProperties:
        # Per-particle energy function; its box is overridden on every call below.
        energy_fn = energy.lennard_jones_pair(
            displacement_fn,
            sigma=sigma,
            epsilon=epsilon,
            r_cutoff=r_cutoff,
            r_onset=r_onset,
            per_particle=True,
        )
        # The deformation is identically zero; it only exists to be differentiated against.
        zero_strain = jnp.zeros_like(box)

        def strained_box(strain):
            # Symmetrize the deformation tensor before applying it to the box.
            return space.transform(jnp.eye(3) + (strain + strain.T) * 0.5, box)

        def atomwise_energy(strain, positions):
            return energy_fn(positions, box=strained_box(strain))

        def summed_energy(strain, positions):
            return jnp.sum(atomwise_energy(strain, positions))

        stress = None
        if compute_stress:
            # Gradient of the total energy w.r.t. the deformation, divided by det(box).
            stress = grad(summed_energy, argnums=0)(zero_strain, R) / jnp.linalg.det(box)

        stresses = None
        if compute_stresses:
            # Jacobian of the per-atom energies w.r.t. the deformation, divided by det(box).
            stresses = jacfwd(atomwise_energy, argnums=0)(zero_strain, R) / jnp.linalg.det(box)

        total_energy = summed_energy(zero_strain, R)
        atomwise_energies = atomwise_energy(zero_strain, R)
        # Forces: negative gradient w.r.t. the positions (argument 1).
        forces = grad(summed_energy, argnums=1)(zero_strain, R) * -1
        return total_energy, atomwise_energies, forces, stress, stresses

    return strained_potential_fn
def get_unstrained_pair_potential(
    displacement_fn: DisplacementFn,
    sigma: float,
    epsilon: float,
    r_cutoff: float,
    r_onset: float,
) -> PotentialFn:
    """Create a Lennard-Jones pair potential without strain (no stress output).

    Args:
        displacement_fn: Displacement function handed to ``energy.lennard_jones_pair``.
        sigma: Forwarded to ``energy.lennard_jones_pair``.
        epsilon: Forwarded to ``energy.lennard_jones_pair``.
        r_cutoff: Forwarded to ``energy.lennard_jones_pair``.
        r_onset: Forwarded to ``energy.lennard_jones_pair``.

    Returns:
        A function ``unstrained_potential_fn(R)`` returning
        ``(total_energy, atomwise_energies, forces, None, None)``.
    """

    def unstrained_potential_fn(R: space.Array) -> PotentialProperties:
        energy_fn = energy.lennard_jones_pair(
            displacement_fn,
            sigma=sigma,
            epsilon=epsilon,
            r_onset=r_onset,
            r_cutoff=r_cutoff,
            per_particle=True,
        )

        def summed_energy(positions):
            return jnp.sum(energy_fn(positions))

        total_energy = summed_energy(R)
        atomwise_energies = energy_fn(R)
        forces = quantity.force(summed_energy)(R)
        # Neither stress nor per-atom stresses are available without a deformation.
        return total_energy, atomwise_energies, forces, None, None

    return unstrained_potential_fn
def get_strained_neighbor_list_potential(
    energy_fn,
    neighbors,
    box: jnp.ndarray,
    compute_stress: bool,
    compute_stresses: bool,
) -> PotentialFn:
    """Create a neighbor-list potential that is evaluated through a strained box.

    Args:
        energy_fn: Called as ``energy_fn(R, box=..., neighbor=neighbors)``; its
            result is summed to obtain the total energy.
        neighbors: Neighbor list captured here and reused for every evaluation.
        box: Reference box the (zero) deformation is applied to.
        compute_stress: Whether to compute the total stress.
        compute_stresses: Whether to compute the per-atom stresses.

    Returns:
        A function ``strained_potential_fn(R)`` returning
        ``(total_energy, atomwise_energies, forces, stress, stresses)``; the last
        two entries are ``None`` when the respective flag is false.
    """

    def strained_potential_fn(R: space.Array) -> PotentialProperties:
        # The deformation is identically zero; it only exists to be differentiated against.
        zero_strain = jnp.zeros_like(box)

        def strained_box(strain):
            # Symmetrize the deformation tensor before applying it to the box.
            return space.transform(jnp.eye(3) + (strain + strain.T) * 0.5, box)

        def atomwise_energy(strain, positions):
            return energy_fn(positions, box=strained_box(strain), neighbor=neighbors)

        def summed_energy(strain, positions):
            return jnp.sum(atomwise_energy(strain, positions))

        stress = None
        if compute_stress:
            # Gradient of the total energy w.r.t. the deformation, divided by det(box).
            stress = grad(summed_energy, argnums=0)(zero_strain, R) / jnp.linalg.det(box)

        stresses = None
        if compute_stresses:
            # Jacobian of the per-atom energies w.r.t. the deformation, divided by det(box).
            stresses = jacfwd(atomwise_energy, argnums=0)(zero_strain, R) / jnp.linalg.det(box)

        total_energy = summed_energy(zero_strain, R)
        atomwise_energies = atomwise_energy(zero_strain, R)
        # Forces: negative gradient w.r.t. the positions (argument 1).
        forces = grad(summed_energy, argnums=1)(zero_strain, R) * -1
        return total_energy, atomwise_energies, forces, stress, stresses

    return strained_potential_fn
def get_unstrained_neighbor_list_potential(energy_fn, neighbors) -> PotentialFn:
    """Create a neighbor-list potential without strain (no stress output).

    Args:
        energy_fn: Called as ``energy_fn(R, neighbor=neighbors)``; its result is
            summed to obtain the total energy.
        neighbors: Neighbor list captured here and reused for every evaluation.

    Returns:
        A function ``unstrained_potential(R)`` returning
        ``(total_energy, atomwise_energies, forces, None, None)``.
    """

    def unstrained_potential(R: space.Array) -> PotentialProperties:
        def summed_energy(positions, *args, **kwargs):
            return jnp.sum(energy_fn(positions, *args, **kwargs))

        total_energy = summed_energy(R, neighbor=neighbors)
        atomwise_energies = energy_fn(R, neighbor=neighbors)
        forces = quantity.force(summed_energy)(R, neighbor=neighbors)
        # Neither stress nor per-atom stresses are available without a deformation.
        return total_energy, atomwise_energies, forces, None, None

    return unstrained_potential
def get_strained_gnn_potential(
    energy_fn,
    neighbors,
    params,
    box: jnp.ndarray,
    compute_stress: bool,
    compute_stresses: bool,
) -> PotentialFn:
    """Create a GNN potential that is evaluated through a strained box.

    Args:
        energy_fn: Called as ``energy_fn(params, R, neighbors, box=...)`` and
            differentiated as the total energy.
        neighbors: Neighbor list passed to ``energy_fn`` on every evaluation.
        params: Model parameters passed to ``energy_fn`` on every evaluation.
        box: Reference box the (zero) deformation is applied to.
        compute_stress: Whether to compute the total stress.
        compute_stresses: Whether to compute the per-atom stresses.

    Returns:
        A function ``strained_potential_fn(R)`` returning
        ``(total_energy, atomwise_energies, forces, stress, stresses)``; the last
        two entries are ``None`` when the respective flag is false.
    """

    def strained_potential_fn(R: space.Array) -> PotentialProperties:
        # The deformation is identically zero; it only exists to be differentiated against.
        zero_strain = jnp.zeros_like(box)

        def strained_box(strain):
            # Symmetrize the deformation tensor before applying it to the box.
            return space.transform(jnp.eye(3) + (strain + strain.T) * 0.5, box)

        def total_energy_of(model_params, positions, strain, neighbor_list):
            return energy_fn(model_params, positions, neighbor_list, box=strained_box(strain))

        # TODO: atom-wise energies + stresses with GNN?
        # Fake atomwise energy function from which the jacobian can be taken.
        # NOTE(review): this yields 1 / total_energy per atom, not a share of the
        # total energy — confirm this placeholder is what callers expect.
        def placeholder_atomwise_energy(model_params, positions, strain, neighbor_list):
            total = total_energy_of(model_params, positions, strain, neighbor_list)
            return jnp.ones((positions.shape[0], 1)) / total

        call_args = (params, R, zero_strain, neighbors)
        total_energy = total_energy_of(*call_args)
        atomwise_energies = placeholder_atomwise_energy(*call_args)
        # Forces: negative gradient w.r.t. the positions (argument 1).
        forces = grad(total_energy_of, argnums=1)(*call_args) * -1

        stress = None
        if compute_stress:
            # Gradient w.r.t. the deformation (argument 2), divided by det(box).
            stress = grad(total_energy_of, argnums=2)(*call_args) / jnp.linalg.det(box)

        stresses = None
        if compute_stresses:
            # Jacobian of the placeholder per-atom energies w.r.t. the deformation.
            stresses = jacfwd(placeholder_atomwise_energy, argnums=2)(*call_args) / jnp.linalg.det(box)

        return total_energy, atomwise_energies, forces, stress, stresses

    return strained_potential_fn
def get_unstrained_gnn_potential(energy_fn, neighbors, params) -> PotentialFn:
    """Create a GNN potential without strain (no stress output).

    Args:
        energy_fn: Called as ``energy_fn(params, R, neighbors)`` and
            differentiated as the total energy.
        neighbors: Neighbor list passed to ``energy_fn`` on every evaluation.
        params: Model parameters passed to ``energy_fn`` on every evaluation.

    Returns:
        A function ``unstrained_potential_fn(R)`` returning
        ``(total_energy, atomwise_energies, forces, None, None)``.
    """

    def unstrained_potential_fn(R: space.Array) -> PotentialProperties:
        total_energy = energy_fn(params, R, neighbors)

        # TODO: atom-wise energies with GNN?
        # Fake atomwise energies, built the same way as in the strained potential.
        # NOTE(review): this yields 1 / total_energy per atom, not a share of the
        # total energy — confirm this placeholder is what callers expect.
        atom_count = R.shape[0]
        atomwise_energies = jnp.ones((atom_count, 1)) / energy_fn(params, R, neighbors)

        # Forces: negative gradient w.r.t. the positions (argument 1).
        position_gradient_fn = grad(energy_fn, argnums=1)
        forces = position_gradient_fn(params, R, neighbors) * -1
        return total_energy, atomwise_energies, forces, None, None

    return unstrained_potential_fn
# TODO: JaxCalculator
def get_state(calculator) -> Dict:
    """Return a copy of the calculator's attributes without the unpicklable entries.

    Args:
        calculator: Object whose ``__dict__`` is copied; it is not modified.

    Returns:
        A new dict with all instance attributes except the function, position
        and neighbor-list entries listed below.
    """
    unpicklable_attributes = (
        # every calculator
        '_displacement_fn',
        '_potential_fn',
        '_R',
        # neighbor list calculator
        '_energy_fn',
        '_neighbor_fn',
        '_neighbors',
        # GNN calculator
        '_init_fn',
    )
    # Work on a copy so that the live instance keeps all of its attributes.
    state = calculator.__dict__.copy()
    for attribute in unpicklable_attributes:
        state.pop(attribute, None)
    return state
def set_state(calculator, state: Dict):
    """Restore instance attributes and stub out the entries dropped by ``get_state``.

    Args:
        calculator: Object whose ``__dict__`` is updated in place.
        state: Attribute dict to restore.

    Every attribute that ``get_state`` removes is set to one shared stub. The
    stub accepts any arguments, prints a message and returns ``None``; it does
    not raise.
    """
    calculator.__dict__.update(state)

    def error_fn(*args, **kwargs):
        return print("Pickled instance cannot compute new data")

    stubbed_attributes = (
        '_displacement_fn',
        '_potential_fn',
        '_R',
        '_energy_fn',
        '_neighbor_fn',
        '_neighbors',
        '_init_fn',
    )
    for attribute in stubbed_attributes:
        setattr(calculator, attribute, error_fn)