
Updated experiment scripts #95

Merged
merged 47 commits on Dec 12, 2023

Commits (47)
98f7c4a
For some reason, typing extensions was not installing bundled with fl…
jackbaker1001 Oct 31, 2023
dcc0dac
Added a new solid class to handle the full BZ sampling periodic case.…
jackbaker1001 Oct 31, 2023
4d66af8
kinetic + external is correct, but I need to reconsider array typing …
jackbaker1001 Nov 1, 2023
dde47ad
Figured out that the coulomb energy needs a repulsion_tensor for ever…
jackbaker1001 Nov 2, 2023
9d6d2d3
Disabled symmetry adaptive functions for now. Will come back to this …
jackbaker1001 Nov 7, 2023
afadeeb
Non XC contribution to total energy is correct
jackbaker1001 Nov 7, 2023
357c850
density and its gradients are now computable
jackbaker1001 Nov 8, 2023
9831b35
Refactoring of functional.py to accept Solid
jackbaker1001 Nov 8, 2023
cded621
Added NotImplementedErrors for HFX methods in solids.
jackbaker1001 Nov 8, 2023
7cf95d6
Raised B3LYP tests tolerance. We should perhaps remove this test enti…
jackbaker1001 Nov 8, 2023
2fd5d1f
Taken B3LYP tests from CI
jackbaker1001 Nov 8, 2023
7006426
re-worked train.py to take Solid instances
jackbaker1001 Nov 8, 2023
973e5a7
Fixing tests
jackbaker1001 Nov 8, 2023
8a03e5d
Fixed bug in training functions
jackbaker1001 Nov 8, 2023
f4ca97b
Another bug fixed
jackbaker1001 Nov 8, 2023
290cb38
Changes to allow solid computation in train.py
jackbaker1001 Nov 9, 2023
0e1cf90
(1) changes to how the fock matrix is computed in auto diff. We only …
jackbaker1001 Nov 10, 2023
8aff961
Fix wrong kwargs being passed
jackbaker1001 Nov 10, 2023
c680dab
another wrong kwarg problem
jackbaker1001 Nov 10, 2023
357434a
MO gradients implemented. Began modifying linear mixing SCF loop to d…
jackbaker1001 Nov 10, 2023
005af27
Forgot to commit molecule changes for mo grads
jackbaker1001 Nov 10, 2023
44690a8
Whoops, shouldn't have left arguments
jackbaker1001 Nov 10, 2023
689c57b
changes to eigenproblem to allow for N_k fock matrices to be diagonal…
jackbaker1001 Nov 10, 2023
7685b3f
fixes to eigenproblem to allow vectorization along k-point dimension
jackbaker1001 Nov 15, 2023
536c80b
Crystal orbitals now are used in density calculations.
jackbaker1001 Nov 20, 2023
7200cdc
Proper handling of all ERI types in PySCF input parsing.
jackbaker1001 Nov 21, 2023
1b02768
Added some input handling
jackbaker1001 Nov 21, 2023
dce443d
minor scf refactoring
jackbaker1001 Nov 21, 2023
0ef698c
test restructure to add solids, solid training test added
jackbaker1001 Nov 22, 2023
3c2c05d
New tutorial for BZ sampling and updated the gamma only notebook
jackbaker1001 Nov 23, 2023
27d817c
non xc energy test for solid hydrogen
jackbaker1001 Nov 23, 2023
dafd90b
LDA and GGA functionals now tested for Solids
jackbaker1001 Nov 23, 2023
9821531
perhaps tf bug with linux is because pytest versions aren't consistent.
jackbaker1001 Dec 6, 2023
b533a7c
restrict tensorflow version
jackbaker1001 Dec 6, 2023
9e08c52
Changes to notebooks
jackbaker1001 Dec 6, 2023
0174903
Adding colab badges
Dec 8, 2023
7c6fb2e
New experiments: training dimers and atoms generalization
Nov 25, 2023
51a84bf
Updating the experiments
Nov 30, 2023
e73a4ce
Updated experiment scripts
Dec 3, 2023
30fe9ac
Updating train dimers non tms
Dec 4, 2023
6289906
Update train_dimers_nontransition_metals.py
Dec 4, 2023
4d8f2ba
Updating the example files
Dec 7, 2023
c5b232a
Update evaluate_noise.py
Dec 7, 2023
ca4919b
Update evaluate_noise.py
Dec 7, 2023
2218098
Update evaluate_noise.py
Dec 7, 2023
dcad380
Deleting DS_Store files
Dec 8, 2023
7138607
Revert "Deleting DS_Store files"
Dec 8, 2023
15 changes: 9 additions & 6 deletions .github/workflows/install_and_test.yml
@@ -36,10 +36,13 @@ jobs:
           pytest -v tests/unit/test_loss.py
       - name: Run integration tests
         run: |
-          pytest -v tests/integration/test_non_xc_energy.py
-          pytest -v tests/integration/test_functional_implementations.py
-          pytest -v tests/integration/test_Harris.py
-          pytest -v tests/integration/test_predict_B88.py
-          pytest -v tests/integration/test_predict_B3LYP.py
-          pytest -v tests/integration/test_training.py
+          pytest -v tests/integration/molecules/test_non_xc_energy.py
+          pytest -v tests/integration/molecules/test_functional_implementations.py
+          pytest -v tests/integration/molecules/test_Harris.py
+          pytest -v tests/integration/molecules/test_predict_B88.py
+          pytest -v tests/integration/molecules/test_training.py
+          pytest -v tests/integration/solids/test_training.py
+          pytest -v tests/integration/solids/test_non_xc_energy.py
+          pytest -v tests/integration/solids/test_functional_implementations.py


308 changes: 308 additions & 0 deletions examples/article_experiments/evaluate_atoms.py
@@ -0,0 +1,308 @@
# Copyright 2023 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from jax.random import split, PRNGKey
from jax import numpy as jnp, value_and_grad
from jax.nn import gelu
import numpy as np
from optax import adam
from tqdm import tqdm
import json
import os
from orbax.checkpoint import PyTreeCheckpointer

from grad_dft import (
    train_kernel,
    energy_predictor,
    NeuralFunctional,
    canonicalize_inputs,
    dm21_coefficient_inputs,
    dm21_densities,
    loader,
)

import jax
from jax import config

config.update("jax_enable_x64", True)
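
# Note: 64-bit precision matters here; chemical accuracy (~1.6e-3 Ha) sits far
# below float32 resolution on total energies of tens to thousands of Hartree.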

# In this example we show how to evaluate a functional trained on a subset
# of atoms, and how well it generalizes to atoms held out of training.

dirpath = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))

# Select here the file you would like to evaluate your model on
test_files = ["atoms.h5"]

# Select here the folder where the .h5 data files (and predictions.json) are stored
training_data_dirpath = os.path.normpath(dirpath + "/data/training/atoms/")


def convert(o):
    """Fallback serializer so json.dump can handle numpy float32 values."""
    if isinstance(o, np.float32):
        return float(o)
    return o

####### Model definition #######

# Then we define the functional, via a function whose output we will integrate.
n_layers = 10
width_layers = 1024
squash_offset = 1e-4
layer_widths = [width_layers] * n_layers
out_features = 4
sigmoid_scale_factor = 2.0
activation = gelu
loadcheckpoint = True #todo: change this
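
# Note: these hyperparameters must match the run that produced the checkpoint,
# or the restored parameter tree will not fit the network defined below.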


def nn_coefficients(instance, rhoinputs, *_, **__):
    x = canonicalize_inputs(rhoinputs)  # Making sure dimensions are correct

    # Initial layer: log -> dense -> tanh
    x = jnp.log(jnp.abs(x) + squash_offset)  # squash_offset = 1e-4
    instance.sow("intermediates", "log", x)
    x = instance.dense(features=layer_widths[0])(x)  # features = 1024
    instance.sow("intermediates", "initial_dense", x)
    x = jnp.tanh(x)
    instance.sow("intermediates", "tanh", x)

    # 10 residual blocks, each with a 1024-feature dense layer and layer norm
    for i, features in enumerate(layer_widths):  # layer_widths = [1024] * 10
        res = x
        x = instance.dense(features=features)(x)
        instance.sow("intermediates", "residual_dense_" + str(i), x)
        x = x + res  # nn.Dense + residual connection
        instance.sow("intermediates", "residual_residual_" + str(i), x)
        x = instance.layer_norm()(x)  # nn.LayerNorm
        instance.sow("intermediates", "residual_layernorm_" + str(i), x)
        x = activation(x)  # activation = jax.nn.gelu
        instance.sow("intermediates", "residual_elu_" + str(i), x)

    return instance.head(x, out_features, sigmoid_scale_factor)


functional = NeuralFunctional(
    coefficients=nn_coefficients,
    coefficient_inputs=dm21_coefficient_inputs,
    energy_densities=partial(dm21_densities, functional_type="MGGA"),
)
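
# A NeuralFunctional pairs the network with closed-form energy densities: at
# each grid point the network maps the coefficient inputs to `out_features`
# weights c_i(r), and the exchange-correlation energy is obtained by
# integrating sum_i c_i(r) * e_i(r) over the MGGA energy densities e_i.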

####### Initializing the functional and some parameters #######

key = PRNGKey(1) # Jax-style random seed #todo: select this

# We initialize the parameters from dummy inputs of the shape the network expects
(key,) = split(key, 1)
rhoinputs = jax.random.normal(key, shape=[2, 7])
params = functional.init(key, rhoinputs)
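
# The dummy batch shape [2, 7] matches the seven per-grid-point features
# produced by dm21_coefficient_inputs; only the shapes matter for init.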

checkpoint_step = 441 #todo: change this
learning_rate = 1e-7
momentum = 0.9
tx = adam(learning_rate=learning_rate, b1=momentum)
opt_state = tx.init(params)
cost_val = jnp.inf

orbax_checkpointer = PyTreeCheckpointer()

ckpt_dir = os.path.join(dirpath, "checkpoints/ckpts_atoms/", "checkpoint_" + str(checkpoint_step) + "/")
if loadcheckpoint:
    train_state = functional.load_checkpoint(
        tx=tx, step=checkpoint_step, orbax_checkpointer=orbax_checkpointer, ckpt_dir=ckpt_dir
    )
    params = train_state.params
    tx = train_state.tx
    opt_state = tx.init(params)
    epoch = train_state.step

########### Definition of the molecule energy prediction function #####################

# Compile the energy prediction function for the functional defined above
compute_energy = jax.jit(energy_predictor(functional))
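
# The compiled function maps (params, system) to the predicted total energy
# plus a second output (the Fock matrix), which this script discards.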


######## Predict function ########


def predict(state, test_files, training_data_dirpath):
    """Predict the energies of the systems in the given files."""
    energies = {}
    true_energies = {}
    params, _, _ = state
    for file in tqdm(test_files, "Files"):
        fpath = os.path.join(training_data_dirpath, file)
        print("Evaluating file: ", fpath, "\n")
        load = loader(fname=fpath, randomize=True, training=True, config_omegas=[])
        for _, system in tqdm(load, "Molecules/reactions per file"):
            # system.name is stored as an array of character codes; join it back into a string key
            name = "".join(chr(num) for num in list(system.name))
            true_energies[name] = float(system.energy)
            predicted_energy, _ = compute_energy(params, system)
            energies[name] = float(predicted_energy)
            del system
    return energies, true_energies
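
# Usage sketch: energies, true_energies = predict(state, test_files, training_data_dirpath)
# returns two dictionaries keyed by system name: predicted and reference energies.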

def load_energies(test_files, data_dirpath):
    """Load the reference energies of the systems in the given files."""
    true_energies = {}
    for file in tqdm(test_files, "Files"):
        fpath = os.path.join(data_dirpath, file)
        print("Loading file: ", fpath, "\n")
        load = loader(fname=fpath, randomize=True, training=True, config_omegas=[])
        for _, system in tqdm(load, "Molecules/reactions per file"):
            true_energies["".join(chr(num) for num in list(system.name))] = float(system.energy)
            del system
    return true_energies


######## Plotting the evaluation results ########

# Predictions
state = params, opt_state, cost_val
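# Note: predict() unpacks this triple as (params, _, _)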

# If there is no predictions.json file, we generate the predictions; otherwise we load them
if not os.path.isfile(os.path.join(training_data_dirpath, "predictions.json")):
    predictions, targets = predict(state, test_files, training_data_dirpath)
else:
    with open(os.path.join(training_data_dirpath, "predictions.json"), 'r') as fp:
        predictions = json.load(fp)
    targets = load_energies(test_files, training_data_dirpath)

# Keys may have the form "b'H'" (a bytes repr): strip the b' and ' characters.
# We clean both dictionaries, since freshly generated predictions carry the same prefix.
def clean_keys(d):
    return {(k[2:-1] if k.startswith("b'") else k): v for k, v in d.items()}

predictions = clean_keys(predictions)
targets = clean_keys(targets)

for k in predictions.keys():
    print(k, predictions[k], targets[k])

# Save the predictions to JSON for reuse on subsequent runs
with open(os.path.join(training_data_dirpath, "predictions.json"), 'w') as fp:
    json.dump(predictions, fp, default=convert)


from pyscf.data.elements import ELEMENTS

import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator


transition_metals = ['Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn']
training_atoms = ELEMENTS[1:19] + ["Ca", "Ge", "Se", "Kr"] + transition_metals[::2]
test_atoms = ["K", "Ga", "As", "Br"] + transition_metals[1::2]
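
# The training set covers H through Ar plus Ca, Ge, Se, Kr and every other 3d
# transition metal; the test set holds out the interleaved remainder
# (K, Ga, As, Br and the alternating transition metals), probing
# generalization across the first four rows of the periodic table.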

atoms = np.array(ELEMENTS[1:37])
training_mask = np.array([atom in training_atoms for atom in atoms])
test_mask = np.array([atom in test_atoms for atom in atoms])
print(atoms[training_mask])

# Two subplots
fig = plt.figure(figsize=(12, 7))
ax, ax2 = fig.subplots(2, 1, sharex=True)

# Plot 1
#ax.set_xlabel('Atoms', fontsize=14)
ax.tick_params(axis='y', which='major', labelsize=14, direction = 'in')
ax.tick_params(axis='y', which='minor', labelsize=14, direction = 'in')
ax.tick_params(axis='x', which='major', labelsize=14)

# Signed error per atom (prediction minus reference), ordered by atomic number
diffs = np.array([predictions[atom] for atom in atoms]) - np.array([targets[atom] for atom in atoms])

# Spread, plus mean absolute error over the training and held-out atoms (Hartree)
std = np.std(diffs)
training_mean = np.mean(abs(diffs[training_mask]))
test_mean = np.mean(abs(diffs[test_mask]))

# Scatter the signed errors, colored by whether the atom was used in training
for a, d in zip(atoms, diffs):
    if a in training_atoms: label, color = 'Training', '#192a56'
    elif a in test_atoms: label, color = 'Test', '#00a8ff'
    ax.scatter(a, d, label=label, color=color)
ax.set_ylabel('Error (Ha)', fontsize=14)

ax.text(0.03, 0.9, '(a) Error', transform=ax.transAxes, fontsize=14)

# Plot a dashed line at zero error
ax.plot(atoms, np.zeros(len(atoms)), 'k--')


# Deduplicate legend entries (each scatter call above repeats the label)
handles, labels = ax.get_legend_handles_labels()
by_label = dict(zip(labels, handles))
legend1 = ax.legend(by_label.values(), by_label.keys(), fontsize=14, loc='lower left')
ax.add_artist(legend1)


# Plot 2
# Plot the absolute errors as bars, with MAE reference lines
for a, d in zip(atoms, diffs):
    if a in training_atoms: label, color = 'Training', '#192a56'
    elif a in test_atoms: label, color = 'Test', '#00a8ff'
    ax2.bar(a, abs(d), label=label, color=color)
ax2.plot(atoms, np.ones(len(atoms)) * training_mean, '-', color='#192a56', label='Training MAE')
ax2.plot(atoms, np.ones(len(atoms)) * test_mean, '--', color='#00a8ff', label='Test MAE')
# Now we print them also in the plot
ax.text(0.4, 2., 'Training MAE: {:.1e} Ha'.format(training_mean), horizontalalignment='right', verticalalignment='center', transform=ax2.transAxes, color = '#192a56', fontsize=14)
ax.text(0.4, 1.9, 'Test MAE: {:.1e} Ha'.format(test_mean), horizontalalignment='right', verticalalignment='center', transform=ax2.transAxes, color = '#00a8ff', fontsize=14)
ax2.text(0.03, 0.9, '(b) Absolute error', transform=ax2.transAxes, fontsize=14)

handles, labels = ax2.get_legend_handles_labels()
by_label = dict(zip(labels, handles))
legend2 = ax2.legend(by_label.values(), by_label.keys(), fontsize=14, loc=(0.18, 0.025))
ax2.add_artist(legend2)

ax2.set_ylabel('Absolute error (Ha)', fontsize=14)
ax2.tick_params(axis='x', which='major', labelsize=14, rotation=90)
ax2.tick_params(axis='x', which='minor', labelsize=14, rotation=90)
ax2.tick_params(axis='y', which='major', labelsize=14, direction = 'in')
ax2.tick_params(axis='y', which='minor', labelsize=14, direction = 'in')

#fig.suptitle('Absolute errors in electronic energies, no noise', fontsize=16)

ref_mean_training = training_mean
ref_mean_test = test_mean
ref_std = std

# set log scale
ax2.set_yscale('log')

yminorLocator = MultipleLocator(0.25)
ax.yaxis.set_minor_locator(yminorLocator)

# Apply tight layout, save the figure, then display it
plt.tight_layout()
fig.savefig('checkpoints/ckpts_atoms/atoms_generalization.pdf', dpi=100)
plt.show()
