Example

This example demonstrates the whole process from initial atomic structure to training, evaluation and prediction. It includes:

  1. Read input atomic structures (saved as extxyz files) and create descriptors and their derivatives.

  2. Read inputs and outputs into a Data object.

  3. Create a TensorFlow dataset for training.

  4. Train the potential and apply it for prediction.

  5. Save the trained model and then load it for retraining or prediction.

The code has been tested with TensorFlow 2.5 and 2.6.

[1]:
import atomdnn

# 'float64' is used by default for reading data and for training
atomdnn.data_type = 'float64'

# force and stress are evaluated by default;
# to compute only the potential energy, set compute_force to False
atomdnn.compute_force = True

# the default value converts eV/A^3 to GPa
# note: a positive predicted stress means tension and a negative one means compression
stress_unit_convert = 160.2176

import numpy as np
import pickle
import tensorflow as tf
from atomdnn import data
from atomdnn.data import Data
from atomdnn.data import *
from atomdnn.descriptor import *
from atomdnn import network
from atomdnn.network import Network
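
The conversion factor 160.2176 used above can be checked from the SI definitions of the electronvolt and the angstrom. The following is a minimal sketch in plain Python; the constant and function names are illustrative and are not part of atomdnn.

```python
# Check the eV/A^3 -> GPa conversion factor from SI definitions.
EV_IN_JOULE = 1.602176634e-19      # 1 eV in joules (exact SI value)
ANGSTROM_IN_METER = 1.0e-10        # 1 angstrom in meters


def ev_per_A3_to_GPa(value=1.0):
    """Convert a stress or pressure from eV/A^3 to GPa."""
    pascal = value * EV_IN_JOULE / ANGSTROM_IN_METER ** 3   # J/m^3 = Pa
    return pascal / 1.0e9                                    # Pa -> GPa


factor = ev_per_A3_to_GPa(1.0)
print(f"1 eV/A^3 = {factor:.4f} GPa")
```

The result agrees with 160.2176 to within rounding of the last digit.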

Create descriptors

Read input atomic structures (saved as extxyz files) and create descriptors and their derivatives

[2]:
descriptor = {'name': 'acsf',
              'cutoff': 6.5,
              'etaG2':[0.01,0.1,1,5,10],
              'etaG4': [0.01],
              'zeta': [0.08,1.0,10.0,100.0],
              'lambda': [1.0, -1.0]}


# define the LAMMPS executable (serial or MPI)
# LAMMPS has to be compiled with the added compute and dump_local subroutines (inside atomdnn/lammps)
lmpexe = 'lmp_serial'
#lmpexe = 'mpirun -np 2 lmp_mpi'  # can be mpi version


elements = ['C']
xyzfile_path='extxyz'
xyzfile_name = 'example_extxyz.*' # a series of files like example_extxyz.1, example_extxyz.2, ...example_extxyz.n
descriptors_path = './descriptors'

descriptor_filename = 'dump_fp.*' # a series of dump_fp.* files will be created
der_filename ='dump_der.*'

print('total number of fingerprints = %i'%get_num_fingerprints(descriptor,elements))
total number of fingerprints = 14
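
To give a feel for what the 'etaG2' and 'cutoff' entries control, here is a minimal sketch of a radial (G2) atom-centered symmetry function in its common Behler-Parrinello form. It is an illustration only: atomdnn computes its descriptors through LAMMPS, and the exact functional form, the cosine cutoff, and the shift `rs = 0` used here are assumptions, as are the function names and the neighbor distances.

```python
import math


def cosine_cutoff(r, rc):
    """Decay smoothly from 1 at r = 0 to 0 at r = rc; zero beyond rc."""
    if r > rc:
        return 0.0
    return 0.5 * (math.cos(math.pi * r / rc) + 1.0)


def g2(distances, eta, rc, rs=0.0):
    """Radial symmetry function G2 of one atom, given its neighbor distances."""
    return sum(math.exp(-eta * (r - rs) ** 2) * cosine_cutoff(r, rc)
               for r in distances)


# Hypothetical environment: three neighbors at 1.42 A
neighbor_distances = [1.42, 1.42, 1.42]
cutoff = 6.5
eta_g2 = [0.01, 0.1, 1, 5, 10]   # same values as descriptor['etaG2']

# One fingerprint per eta value
radial_fingerprints = [g2(neighbor_distances, eta, cutoff) for eta in eta_g2]
for eta, value in zip(eta_g2, radial_fingerprints):
    print(f"eta = {eta:<5} G2 = {value:.6f}")
```

Each eta value yields one fingerprint; a larger eta makes the Gaussian narrower, so the same neighbors contribute less.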
[3]:
# this will create a series of files for descriptors and their derivatives inside descriptors_path
create_descriptors(lmpexe,
                   elements,
                   xyzfile_path,
                   xyzfile_name,
                   descriptors_path,
                   descriptor,
                   descriptor_filename,
                   der_filename)
Start creating fingerprints and derivatives for 50 files ...
  so far finished for 10 images ...
  so far finished for 20 images ...
  so far finished for 30 images ...
  so far finished for 40 images ...
  so far finished for 50 images ...
Finish creating descriptors and derivatives for total 50 images.
It took 10.55 seconds.

Read inputs&outputs

Read inputs and outputs into a Data object

[4]:
# create a Data object
mydata = Data()
[5]:
# read inputs: descriptors and their derivatives
mydata.read_inputdata(descriptors_path, descriptor_filename, der_filename)
Start reading fingerprints from 'dump_fp.*' for total 50 files ...
  so far read 50 images ...
  Finish reading fingerprints from total 50 images.


Start reading derivatives from 'dump_der.*' for total 50 files ...
  This may take a while for large data set ...
  so far read 50 images ...
  Finish reading dGdr derivatives from total 50 images.

  It took 1.02 seconds to read the derivatives data.

---------- input dataset information ----------
total images = 50
max number of atoms = 4
number of fingerprints = 14
number of atom types = 1
max number of derivative pairs = 200
------------------------------------------------
[6]:
# read outputs: potential energy, force and stress from extxyz files
mydata.read_outputdata(xyzfile_path, xyzfile_name)
Reading outputs from extxyz files ...
  so far read 50 images ...
  Finish reading outputs from total 50 images.


---------- output dataset information ------------
total images = 50
max number of atoms = 4
read_force = True
read_stress = True
---------------------------------------------------

Create TFdataset

Create a TensorFlow dataset for training

[7]:
# convert data to tensors
mydata.convert_data_to_tensor()
Conversion may take a while for large datasets...
It took 0.0046 second.
[8]:
# create tensorflow dataset
tf_dataset = tf.data.Dataset.from_tensor_slices((mydata.input_dict,mydata.output_dict))
[9]:
dataset_path = './example_tfdataset'

# save the dataset
tf.data.experimental.save(tf_dataset, dataset_path)

# save the element_spec to disk for future loading; this is only needed for TensorFlow versions lower than 2.6
with open(dataset_path + '/element_spec', 'wb') as out_:
    pickle.dump(tf_dataset.element_spec, out_)

Note: The three steps above only need to be done once per dataset; training only uses the saved TensorFlow dataset.

Training

Load the dataset and train the model

[10]:
# load the TensorFlow dataset; for TensorFlow versions lower than 2.6, element_spec must be specified

with open(dataset_path + '/element_spec', 'rb') as in_:
    element_spec = pickle.load(in_)

dataset = tf.data.experimental.load(dataset_path,element_spec=element_spec)
[11]:
# split the data into training, validation and test sets

train_dataset, val_dataset, test_dataset = split_dataset(dataset,0.7,0.2,0.1,shuffle=True)
Traning data: 35 images
Validation data: 10 images
Test data: 5 images
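
The 35/10/5 split above follows directly from applying the fractions 0.7, 0.2 and 0.1 to 50 images. The sketch below reproduces that bookkeeping on plain index lists. It is not atomdnn's `split_dataset`; the function name, the rounding rule and the seed are assumptions made for illustration.

```python
import random


def split_indices(n_items, train_frac, val_frac, test_frac, shuffle=True, seed=0):
    """Return three index lists whose sizes follow the given fractions."""
    if abs(train_frac + val_frac + test_frac - 1.0) > 1e-8:
        raise ValueError("fractions must sum to 1")
    indices = list(range(n_items))
    if shuffle:
        random.Random(seed).shuffle(indices)   # reproducible shuffle
    n_train = round(train_frac * n_items)
    n_val = round(val_frac * n_items)
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]           # the remainder is the test set
    return train, val, test


train_idx, val_idx, test_idx = split_indices(50, 0.7, 0.2, 0.1)
print(len(train_idx), len(val_idx), len(test_idx))   # prints: 35 10 5
```

Shuffling before slicing keeps the three subsets disjoint while avoiding any ordering bias in the original files.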
[12]:
# Build the network
# See the 'Training' section for a detailed description of the Network object.

act_fun = 'relu' # activation function
nfp = get_num_fingerprints(descriptor,elements) # number of fingerprints (or descriptors) from dataset
arch = [10,10] # NN layers

model = Network(elements, nfp, arch, act_fun)
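
As a rough size check for this architecture, the sketch below counts the trainable parameters of a fully connected network with 14 inputs and two hidden layers of 10 neurons. It assumes that `arch` lists hidden-layer widths, that every layer has a bias, and that the network ends in a single scalar output per atom; these are assumptions, and the actual Network object may be organized differently. The helper name is hypothetical.

```python
def count_dense_parameters(n_inputs, hidden_layers, n_outputs=1):
    """Count weights and biases of a fully connected feed-forward network."""
    sizes = [n_inputs] + list(hidden_layers) + [n_outputs]
    total = 0
    for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
        total += fan_in * fan_out + fan_out   # weight matrix + bias vector
    return total


n_params = count_dense_parameters(n_inputs=14, hidden_layers=[10, 10])
print(n_params)   # prints: 271
```

Under these assumptions the model has 271 parameters (150 + 110 + 11), which is small compared with the amount of force and stress data in 50 images.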
[13]:
# Train the model

opt = 'Adam' # optimizer
loss_fun = 'mae' # loss function
scaling = 'std' # scale the training data by standardization
lr = 0.01 # learning rate
loss_weights = {'pe' : 1, 'force' : 1, 'stress': 0.1} # the weights in loss function

model.train(train_dataset,val_dataset,
            optimizer=opt,
            loss_fun = loss_fun,
            batch_size=30,
            lr=lr,
            epochs=50,
            scaling=scaling,
            loss_weights=loss_weights,
            compute_all_loss=True,
            shuffle=True,
            append_loss=True)
Forces are used for training.
Stresses are used for training.
Scaling factors are computed using training dataset.
Training dataset are standardized.
Validation dataset are standardized.
Training dataset will be shuffled during training.

===> Epoch 1/50 - 0.287s/epoch
     training_loss    - pe_loss: 31.190 - force_loss: 356.384 - stress_loss: 39990.078 - total_loss: 4386.582
     validation_loss  - pe_loss: 32.375 - force_loss: 327.956 - stress_loss: 38602.905 - total_loss: 4220.621

===> Epoch 2/50 - 0.255s/epoch
     training_loss    - pe_loss: 29.309 - force_loss: 331.239 - stress_loss: 33570.114 - total_loss: 3717.558
     validation_loss  - pe_loss: 28.926 - force_loss: 277.891 - stress_loss: 31481.089 - total_loss: 3454.926

===> Epoch 3/50 - 0.243s/epoch
     training_loss    - pe_loss: 19.546 - force_loss: 299.680 - stress_loss: 23171.993 - total_loss: 2636.426
     validation_loss  - pe_loss: 26.623 - force_loss: 241.708 - stress_loss: 25293.141 - total_loss: 2797.644

===> Epoch 4/50 - 0.267s/epoch
     training_loss    - pe_loss: 17.683 - force_loss: 256.099 - stress_loss: 22112.295 - total_loss: 2485.011
     validation_loss  - pe_loss: 24.879 - force_loss: 226.670 - stress_loss: 20505.446 - total_loss: 2302.093

===> Epoch 5/50 - 0.206s/epoch
     training_loss    - pe_loss: 19.918 - force_loss: 292.258 - stress_loss: 19392.847 - total_loss: 2251.461
     validation_loss  - pe_loss: 23.761 - force_loss: 227.353 - stress_loss: 14071.244 - total_loss: 1658.238

===> Epoch 6/50 - 0.217s/epoch
     training_loss    - pe_loss: 18.440 - force_loss: 294.655 - stress_loss: 18934.791 - total_loss: 2206.574
     validation_loss  - pe_loss: 23.181 - force_loss: 230.540 - stress_loss: 10633.239 - total_loss: 1317.044

===> Epoch 7/50 - 0.250s/epoch
     training_loss    - pe_loss: 27.071 - force_loss: 287.616 - stress_loss: 17824.723 - total_loss: 2097.160
     validation_loss  - pe_loss: 22.483 - force_loss: 233.838 - stress_loss: 10535.882 - total_loss: 1309.910

===> Epoch 8/50 - 0.213s/epoch
     training_loss    - pe_loss: 19.139 - force_loss: 262.362 - stress_loss: 17680.789 - total_loss: 2049.580
     validation_loss  - pe_loss: 21.706 - force_loss: 232.758 - stress_loss: 10097.616 - total_loss: 1264.227

===> Epoch 9/50 - 0.247s/epoch
     training_loss    - pe_loss: 16.393 - force_loss: 290.298 - stress_loss: 19667.135 - total_loss: 2273.404
     validation_loss  - pe_loss: 20.813 - force_loss: 250.224 - stress_loss: 12001.393 - total_loss: 1471.177

===> Epoch 10/50 - 0.211s/epoch
     training_loss    - pe_loss: 19.249 - force_loss: 265.476 - stress_loss: 15017.614 - total_loss: 1786.486
     validation_loss  - pe_loss: 19.828 - force_loss: 247.004 - stress_loss: 13313.849 - total_loss: 1598.217

===> Epoch 11/50 - 0.247s/epoch
     training_loss    - pe_loss: 17.911 - force_loss: 275.333 - stress_loss: 15653.270 - total_loss: 1858.571
     validation_loss  - pe_loss: 18.833 - force_loss: 240.327 - stress_loss: 13014.488 - total_loss: 1560.609

===> Epoch 12/50 - 0.239s/epoch
     training_loss    - pe_loss: 17.082 - force_loss: 217.358 - stress_loss: 11507.945 - total_loss: 1385.234
     validation_loss  - pe_loss: 17.622 - force_loss: 221.982 - stress_loss: 11865.776 - total_loss: 1426.182

===> Epoch 13/50 - 0.211s/epoch
     training_loss    - pe_loss: 15.717 - force_loss: 214.068 - stress_loss: 13056.794 - total_loss: 1535.465
     validation_loss  - pe_loss: 16.475 - force_loss: 230.185 - stress_loss: 12524.101 - total_loss: 1499.071

===> Epoch 14/50 - 0.212s/epoch
     training_loss    - pe_loss: 14.213 - force_loss: 265.868 - stress_loss: 18321.999 - total_loss: 2112.281
     validation_loss  - pe_loss: 15.383 - force_loss: 223.565 - stress_loss: 12250.728 - total_loss: 1464.020

===> Epoch 15/50 - 0.256s/epoch
     training_loss    - pe_loss: 15.886 - force_loss: 223.352 - stress_loss: 12262.024 - total_loss: 1465.440
     validation_loss  - pe_loss: 14.432 - force_loss: 212.618 - stress_loss: 13039.163 - total_loss: 1530.967

===> Epoch 16/50 - 0.255s/epoch
     training_loss    - pe_loss: 10.931 - force_loss: 230.472 - stress_loss: 14679.957 - total_loss: 1709.399
     validation_loss  - pe_loss: 14.192 - force_loss: 206.057 - stress_loss: 13255.603 - total_loss: 1545.809

===> Epoch 17/50 - 0.254s/epoch
     training_loss    - pe_loss: 14.032 - force_loss: 214.895 - stress_loss: 12529.994 - total_loss: 1481.926
     validation_loss  - pe_loss: 14.426 - force_loss: 217.929 - stress_loss: 11939.140 - total_loss: 1426.269

===> Epoch 18/50 - 0.258s/epoch
     training_loss    - pe_loss: 13.977 - force_loss: 199.781 - stress_loss: 10151.721 - total_loss: 1228.930
     validation_loss  - pe_loss: 14.610 - force_loss: 208.643 - stress_loss: 9057.785 - total_loss: 1129.031

===> Epoch 19/50 - 0.237s/epoch
     training_loss    - pe_loss: 14.160 - force_loss: 190.762 - stress_loss: 9122.232 - total_loss: 1117.146
     validation_loss  - pe_loss: 15.200 - force_loss: 214.031 - stress_loss: 7970.364 - total_loss: 1026.267

===> Epoch 20/50 - 0.567s/epoch
     training_loss    - pe_loss: 12.521 - force_loss: 187.481 - stress_loss: 8241.700 - total_loss: 1024.172
     validation_loss  - pe_loss: 15.501 - force_loss: 207.533 - stress_loss: 7092.256 - total_loss: 932.260

===> Epoch 21/50 - 0.235s/epoch
     training_loss    - pe_loss: 12.524 - force_loss: 218.186 - stress_loss: 11744.765 - total_loss: 1405.186
     validation_loss  - pe_loss: 15.301 - force_loss: 203.474 - stress_loss: 7162.949 - total_loss: 935.070

===> Epoch 22/50 - 0.240s/epoch
     training_loss    - pe_loss: 14.047 - force_loss: 213.436 - stress_loss: 10236.759 - total_loss: 1251.159
     validation_loss  - pe_loss: 14.761 - force_loss: 199.951 - stress_loss: 7072.908 - total_loss: 922.003

===> Epoch 23/50 - 0.208s/epoch
     training_loss    - pe_loss: 9.837 - force_loss: 248.349 - stress_loss: 11092.987 - total_loss: 1367.484
     validation_loss  - pe_loss: 14.310 - force_loss: 194.913 - stress_loss: 6726.697 - total_loss: 881.892

===> Epoch 24/50 - 0.244s/epoch
     training_loss    - pe_loss: 11.016 - force_loss: 184.585 - stress_loss: 8991.769 - total_loss: 1094.778
     validation_loss  - pe_loss: 14.106 - force_loss: 187.860 - stress_loss: 6918.478 - total_loss: 893.814

===> Epoch 25/50 - 0.211s/epoch
     training_loss    - pe_loss: 15.097 - force_loss: 184.636 - stress_loss: 8829.165 - total_loss: 1082.650
     validation_loss  - pe_loss: 13.765 - force_loss: 186.991 - stress_loss: 7435.635 - total_loss: 944.320

===> Epoch 26/50 - 0.242s/epoch
     training_loss    - pe_loss: 12.930 - force_loss: 203.051 - stress_loss: 9268.753 - total_loss: 1142.856
     validation_loss  - pe_loss: 13.263 - force_loss: 178.193 - stress_loss: 6942.370 - total_loss: 885.693

===> Epoch 27/50 - 0.266s/epoch
     training_loss    - pe_loss: 10.075 - force_loss: 180.792 - stress_loss: 9319.836 - total_loss: 1122.851
     validation_loss  - pe_loss: 12.740 - force_loss: 172.928 - stress_loss: 6968.925 - total_loss: 882.560

===> Epoch 28/50 - 0.303s/epoch
     training_loss    - pe_loss: 9.762 - force_loss: 177.376 - stress_loss: 9268.017 - total_loss: 1113.940
     validation_loss  - pe_loss: 12.235 - force_loss: 166.857 - stress_loss: 7379.211 - total_loss: 917.013

===> Epoch 29/50 - 0.252s/epoch
     training_loss    - pe_loss: 10.697 - force_loss: 178.297 - stress_loss: 6721.237 - total_loss: 861.117
     validation_loss  - pe_loss: 11.725 - force_loss: 160.584 - stress_loss: 6881.283 - total_loss: 860.437

===> Epoch 30/50 - 0.197s/epoch
     training_loss    - pe_loss: 12.150 - force_loss: 168.624 - stress_loss: 8436.945 - total_loss: 1024.468
     validation_loss  - pe_loss: 11.199 - force_loss: 166.258 - stress_loss: 7778.549 - total_loss: 955.312

===> Epoch 31/50 - 0.232s/epoch
     training_loss    - pe_loss: 13.997 - force_loss: 164.493 - stress_loss: 6873.374 - total_loss: 865.827
     validation_loss  - pe_loss: 10.699 - force_loss: 160.313 - stress_loss: 7252.115 - total_loss: 896.223

===> Epoch 32/50 - 0.218s/epoch
     training_loss    - pe_loss: 11.401 - force_loss: 160.173 - stress_loss: 7200.905 - total_loss: 891.664
     validation_loss  - pe_loss: 10.229 - force_loss: 154.167 - stress_loss: 7178.661 - total_loss: 882.263

===> Epoch 33/50 - 0.246s/epoch
     training_loss    - pe_loss: 9.099 - force_loss: 184.592 - stress_loss: 5416.165 - total_loss: 735.307
     validation_loss  - pe_loss: 9.814 - force_loss: 150.473 - stress_loss: 6569.242 - total_loss: 817.212

===> Epoch 34/50 - 0.228s/epoch
     training_loss    - pe_loss: 8.762 - force_loss: 160.241 - stress_loss: 5852.148 - total_loss: 754.218
     validation_loss  - pe_loss: 9.374 - force_loss: 145.594 - stress_loss: 6108.906 - total_loss: 765.858

===> Epoch 35/50 - 0.240s/epoch
     training_loss    - pe_loss: 12.406 - force_loss: 159.299 - stress_loss: 8944.486 - total_loss: 1066.153
     validation_loss  - pe_loss: 8.900 - force_loss: 141.091 - stress_loss: 5858.747 - total_loss: 735.865

===> Epoch 36/50 - 0.213s/epoch
     training_loss    - pe_loss: 9.840 - force_loss: 149.606 - stress_loss: 7718.608 - total_loss: 931.307
     validation_loss  - pe_loss: 8.535 - force_loss: 131.647 - stress_loss: 5924.420 - total_loss: 732.624

===> Epoch 37/50 - 0.245s/epoch
     training_loss    - pe_loss: 8.595 - force_loss: 147.061 - stress_loss: 6082.042 - total_loss: 763.860
     validation_loss  - pe_loss: 8.216 - force_loss: 132.321 - stress_loss: 5114.304 - total_loss: 651.967

===> Epoch 38/50 - 0.223s/epoch
     training_loss    - pe_loss: 6.672 - force_loss: 129.401 - stress_loss: 5902.543 - total_loss: 726.327
     validation_loss  - pe_loss: 7.922 - force_loss: 139.757 - stress_loss: 6516.377 - total_loss: 799.317

===> Epoch 39/50 - 0.263s/epoch
     training_loss    - pe_loss: 7.650 - force_loss: 133.333 - stress_loss: 6576.271 - total_loss: 798.610
     validation_loss  - pe_loss: 7.689 - force_loss: 137.120 - stress_loss: 6384.103 - total_loss: 783.219

===> Epoch 40/50 - 0.237s/epoch
     training_loss    - pe_loss: 6.882 - force_loss: 129.057 - stress_loss: 5120.016 - total_loss: 647.940
     validation_loss  - pe_loss: 7.396 - force_loss: 136.088 - stress_loss: 7042.040 - total_loss: 847.688

===> Epoch 41/50 - 0.206s/epoch
     training_loss    - pe_loss: 6.412 - force_loss: 109.322 - stress_loss: 6859.636 - total_loss: 801.697
     validation_loss  - pe_loss: 7.084 - force_loss: 130.286 - stress_loss: 5301.032 - total_loss: 667.474

===> Epoch 42/50 - 0.233s/epoch
     training_loss    - pe_loss: 8.292 - force_loss: 124.199 - stress_loss: 5566.544 - total_loss: 689.145
     validation_loss  - pe_loss: 6.817 - force_loss: 125.865 - stress_loss: 5427.751 - total_loss: 675.457

===> Epoch 43/50 - 0.241s/epoch
     training_loss    - pe_loss: 7.371 - force_loss: 111.575 - stress_loss: 4294.890 - total_loss: 548.434
     validation_loss  - pe_loss: 6.633 - force_loss: 117.875 - stress_loss: 5009.373 - total_loss: 625.446

===> Epoch 44/50 - 0.222s/epoch
     training_loss    - pe_loss: 6.747 - force_loss: 115.083 - stress_loss: 5000.162 - total_loss: 621.846
     validation_loss  - pe_loss: 6.580 - force_loss: 115.144 - stress_loss: 4611.756 - total_loss: 582.899

===> Epoch 45/50 - 0.239s/epoch
     training_loss    - pe_loss: 6.703 - force_loss: 162.029 - stress_loss: 6026.879 - total_loss: 771.421
     validation_loss  - pe_loss: 6.608 - force_loss: 113.569 - stress_loss: 4733.484 - total_loss: 593.526

===> Epoch 46/50 - 0.236s/epoch
     training_loss    - pe_loss: 6.483 - force_loss: 147.566 - stress_loss: 7066.115 - total_loss: 860.661
     validation_loss  - pe_loss: 6.658 - force_loss: 114.772 - stress_loss: 3948.175 - total_loss: 516.247

===> Epoch 47/50 - 0.223s/epoch
     training_loss    - pe_loss: 9.509 - force_loss: 158.753 - stress_loss: 5859.098 - total_loss: 754.172
     validation_loss  - pe_loss: 6.646 - force_loss: 108.553 - stress_loss: 3693.647 - total_loss: 484.564

===> Epoch 48/50 - 0.221s/epoch
     training_loss    - pe_loss: 8.387 - force_loss: 109.751 - stress_loss: 5225.576 - total_loss: 640.695
     validation_loss  - pe_loss: 6.606 - force_loss: 100.762 - stress_loss: 3702.446 - total_loss: 477.613

===> Epoch 49/50 - 0.251s/epoch
     training_loss    - pe_loss: 6.458 - force_loss: 109.626 - stress_loss: 5447.562 - total_loss: 660.840
     validation_loss  - pe_loss: 6.442 - force_loss: 100.152 - stress_loss: 3516.438 - total_loss: 458.238

===> Epoch 50/50 - 0.231s/epoch
     training_loss    - pe_loss: 8.750 - force_loss: 105.299 - stress_loss: 3383.525 - total_loss: 452.402
     validation_loss  - pe_loss: 6.207 - force_loss: 94.140 - stress_loss: 4186.339 - total_loss: 518.981

End of training, elapsed time:  00:00:12
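
The log above states that the scaling factors are computed from the training dataset and then applied to both the training and validation data, and the loss function is 'mae'. The sketch below illustrates these two ideas on toy numbers. The data, the function names and the use of the population standard deviation are assumptions; this is not atomdnn's implementation.

```python
import statistics


def fit_standardizer(values):
    """Return (mean, std) computed from the training values only."""
    mean = statistics.fmean(values)
    std = statistics.pstdev(values)    # population std; an assumption
    return mean, std


def standardize(values, mean, std):
    """Shift and scale values with previously fitted factors."""
    return [(v - mean) / std for v in values]


def mean_absolute_error(predicted, target):
    """Average absolute difference between predictions and targets."""
    return sum(abs(p - t) for p, t in zip(predicted, target)) / len(target)


train_values = [1.0, 2.0, 3.0, 4.0]     # hypothetical training quantity
val_values = [2.5, 5.0]                 # hypothetical validation quantity

mean, std = fit_standardizer(train_values)
train_scaled = standardize(train_values, mean, std)
val_scaled = standardize(val_values, mean, std)   # reuse the training factors

mae = mean_absolute_error([1.0, 2.0, 4.0], [1.5, 2.0, 3.0])
print(mae)   # prints: 0.5
```

Reusing the training mean and standard deviation for the validation data keeps information about the validation set out of the preprocessing.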
[14]:
# plot the training loss

model.plot_loss(start_epoch=1)
[Output: four loss-curve figures (images not included)]
[15]:
# Evaluate using the first 5 images in the test dataset

model.evaluate(test_dataset.take(5),return_prediction=False)
Evaluation loss is:
        pe_loss:       8.3364e+00
     force_loss:       9.8733e+01
    stress_loss:       4.5703e+03
     total_loss:       5.6410e+02
The total loss is computed using the loss weights - pe: 1.00 - force: 1.00 - stress: 0.10
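
The reported total loss is the weighted sum of the three individual losses. The short sketch below recomputes it from the printed values; the helper name `total_loss` is illustrative and not part of atomdnn.

```python
def total_loss(pe_loss, force_loss, stress_loss, weights):
    """Weighted sum of the individual loss terms."""
    return (weights['pe'] * pe_loss
            + weights['force'] * force_loss
            + weights['stress'] * stress_loss)


loss_weights = {'pe': 1, 'force': 1, 'stress': 0.1}

# Values printed by model.evaluate above
evaluation_total = total_loss(8.3364, 98.733, 4570.3, loss_weights)

# Training losses printed for epoch 1 of the first training run
epoch1_total = total_loss(31.190, 356.384, 39990.078, loss_weights)

print(f"{evaluation_total:.2f} {epoch1_total:.3f}")
```

Both results match the logged totals (5.6410e+02 and 4386.582). The stress weight of 0.1 still leaves the stress term as the largest contribution, because the raw stress loss is several orders of magnitude larger than the energy loss.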
[16]:
# predict using the first 5 images of the full dataset (use test_dataset.take(5) to predict on test images only)

input_dict = get_input_dict(dataset.take(5))
model.predict(input_dict)
[16]:
{'pe': array([  4.98966932, -15.53297727, -21.9046552 , -29.27170519,
        -26.450315  ]),
 'force': array([[[ -75.06978466, -172.92211148,  229.67395629],
         [  74.41928761,  -38.89317527, -129.60283767],
         [ 291.48027163,   79.89212325,  108.91411508],
         [-290.82977488,  131.92316363, -208.98523383]],

        [[  -1.12284797,  -46.38003397,  -56.78587395],
         [  37.89161011,   60.28494093,  -68.44166021],
         [ -45.4923701 ,  -90.1904247 ,  132.87016702],
         [   8.7236079 ,   76.2855178 ,   -7.64263283]],

        [[-106.47649583,  -12.34103851,  377.1129094 ],
         [  80.46425249,  217.20247428, -316.56428174],
         [ 150.11888887, -261.81965841,  112.99040383],
         [-124.10664501,   56.95822197, -173.53903144]],

        [[-263.93606985,   84.63016038, -102.77425547],
         [-179.60252309,   15.28498485,  -69.12516193],
         [ 317.26096084,  -41.53343411,  -15.18248703],
         [ 126.27763224,  -58.38171139,  187.08190441]],

        [[ 254.3949484 ,  107.85178033,  304.67510746],
         [ -52.12582621,   17.22681417, -255.89791842],
         [-182.44827854,  -27.31303965,   94.44380832],
         [ -19.82084327,  -97.76555481, -143.22099734]]]),
 'stress': array([[-6.20371789e+03,  1.13627470e+01, -2.37123066e+01,
          1.13627484e+01, -6.39380530e+03, -1.62881455e+00,
         -2.37123045e+01, -1.62881573e+00,  1.20551161e+00],
        [-8.50114487e+03,  1.84839986e+00,  2.08877457e-01,
          1.84840042e+00, -8.65504055e+03, -4.07797059e-01,
          2.08877844e-01, -4.07797449e-01, -4.92153189e+00],
        [ 3.63854203e+04, -1.97266425e+01, -9.44855164e+00,
         -1.97266432e+01,  3.63329050e+04,  2.73247795e+00,
         -9.44855559e+00,  2.73248313e+00,  1.26945233e+01],
        [-6.59333928e+03,  5.05246045e+01,  1.00709374e+01,
          5.05246048e+01, -6.35512008e+03, -2.80925787e+00,
          1.00709363e+01, -2.80925581e+00,  7.70519751e+00],
        [ 4.09253355e+03,  3.87355845e+00, -4.18005947e+00,
          3.87355820e+00,  4.34530492e+03, -1.30450751e+00,
         -4.18006231e+00, -1.30450772e+00, -8.47212142e+00]])}
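
Each row of the predicted 'stress' array holds nine numbers, which can be read as a flattened 3x3 stress tensor. The sketch below reshapes the first row and checks that it is symmetric. The row-major ordering (xx, xy, xz, yx, ...) is an assumption, although for a symmetric tensor row-major and column-major readings give the same matrix. The helper names are illustrative.

```python
def to_3x3(flat):
    """Reshape 9 stress components (assumed row-major) into a 3x3 nested list."""
    if len(flat) != 9:
        raise ValueError("expected 9 stress components")
    return [list(flat[3 * i:3 * i + 3]) for i in range(3)]


def is_symmetric(matrix, tol=1e-4):
    """Check matrix[i][j] == matrix[j][i] within a tolerance."""
    return all(abs(matrix[i][j] - matrix[j][i]) <= tol
               for i in range(3) for j in range(3))


# First row of the predicted 'stress' array above
stress_flat = [-6.20371789e+03, 1.13627470e+01, -2.37123066e+01,
               1.13627484e+01, -6.39380530e+03, -1.62881455e+00,
               -2.37123045e+01, -1.62881573e+00, 1.20551161e+00]

stress = to_3x3(stress_flat)
for row in stress:
    print(row)
print("symmetric:", is_symmetric(stress))
```

The off-diagonal pairs agree to about six significant digits, so the prediction is symmetric up to numerical noise.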

Save/load model

Save the trained model

[17]:
# the descriptor is repeated here to emphasize that it must be the same one defined above
descriptor = {'name': 'acsf',
              'cutoff': 6.5,
              'etaG2':[0.01,0.1,1,5,10],
              'etaG4': [0.01],
              'zeta': [0.08,1.0,10.0,100.0],
              'lambda': [1.0, -1.0]}

save_dir = 'example.tfdnn'
network.save(model,save_dir,descriptor=descriptor)
INFO:tensorflow:Assets written to: example.tfdnn/assets
Network signatures and descriptor are written to example.tfdnn/parameters for LAMMPS simulation.

Load the trained model for continued training and prediction

[18]:
imported_model = network.load(save_dir)

# Re-train the model
loss_weights = {'pe' : 1, 'force' : 1, 'stress': 0.1}

opt = 'Adam'
loss_fun = 'mae'
scaling = 'std'

imported_model.train(train_dataset, val_dataset,
            optimizer=opt,
            loss_fun = loss_fun,
            batch_size=30,
            lr=0.02,
            epochs=50,
            scaling=scaling,
            loss_weights=loss_weights,
            compute_all_loss=True,
            shuffle=True,
            append_loss=True)
Network has been inflated! self.built: True
Forces are used for training.
Stresses are used for training.
Scaling factors are computed using training dataset.
Training dataset are standardized.
Validation dataset are standardized.
Training dataset will be shuffled during training.

===> Epoch 1/50 - 0.286s/epoch
     training_loss    - pe_loss: 6.225 - force_loss: 107.097 - stress_loss: 5896.093 - total_loss: 702.930
     validation_loss  - pe_loss: 5.153 - force_loss: 81.355 - stress_loss: 3708.718 - total_loss: 457.380

===> Epoch 2/50 - 0.291s/epoch
     training_loss    - pe_loss: 6.683 - force_loss: 85.014 - stress_loss: 3550.454 - total_loss: 446.742
     validation_loss  - pe_loss: 6.016 - force_loss: 72.757 - stress_loss: 3547.104 - total_loss: 433.484

===> Epoch 3/50 - 0.208s/epoch
     training_loss    - pe_loss: 4.510 - force_loss: 71.721 - stress_loss: 3633.668 - total_loss: 439.598
     validation_loss  - pe_loss: 6.587 - force_loss: 69.578 - stress_loss: 3321.606 - total_loss: 408.326

===> Epoch 4/50 - 0.313s/epoch
     training_loss    - pe_loss: 4.918 - force_loss: 72.456 - stress_loss: 3999.166 - total_loss: 477.291
     validation_loss  - pe_loss: 7.014 - force_loss: 61.233 - stress_loss: 3738.862 - total_loss: 442.133

===> Epoch 5/50 - 0.269s/epoch
     training_loss    - pe_loss: 4.612 - force_loss: 68.645 - stress_loss: 3438.338 - total_loss: 417.091
     validation_loss  - pe_loss: 7.047 - force_loss: 63.952 - stress_loss: 2774.278 - total_loss: 348.426

===> Epoch 6/50 - 0.219s/epoch
     training_loss    - pe_loss: 6.011 - force_loss: 61.255 - stress_loss: 3458.065 - total_loss: 413.073
     validation_loss  - pe_loss: 6.830 - force_loss: 63.454 - stress_loss: 3492.039 - total_loss: 419.488

===> Epoch 7/50 - 0.216s/epoch
     training_loss    - pe_loss: 5.938 - force_loss: 64.808 - stress_loss: 3637.062 - total_loss: 434.452
     validation_loss  - pe_loss: 6.210 - force_loss: 57.891 - stress_loss: 3011.305 - total_loss: 365.231

===> Epoch 8/50 - 0.225s/epoch
     training_loss    - pe_loss: 4.245 - force_loss: 76.674 - stress_loss: 2692.130 - total_loss: 350.132
     validation_loss  - pe_loss: 5.141 - force_loss: 56.055 - stress_loss: 2189.114 - total_loss: 280.108

===> Epoch 9/50 - 0.227s/epoch
     training_loss    - pe_loss: 3.599 - force_loss: 51.576 - stress_loss: 2436.548 - total_loss: 298.829
     validation_loss  - pe_loss: 4.396 - force_loss: 58.601 - stress_loss: 2538.165 - total_loss: 316.814

===> Epoch 10/50 - 0.258s/epoch
     training_loss    - pe_loss: 3.344 - force_loss: 46.560 - stress_loss: 2197.354 - total_loss: 269.640
     validation_loss  - pe_loss: 4.293 - force_loss: 47.936 - stress_loss: 2566.444 - total_loss: 308.873

===> Epoch 11/50 - 0.239s/epoch
     training_loss    - pe_loss: 3.545 - force_loss: 47.598 - stress_loss: 2464.916 - total_loss: 297.635
     validation_loss  - pe_loss: 4.367 - force_loss: 45.184 - stress_loss: 2279.156 - total_loss: 277.467

===> Epoch 12/50 - 0.245s/epoch
     training_loss    - pe_loss: 3.830 - force_loss: 45.659 - stress_loss: 2423.599 - total_loss: 291.849
     validation_loss  - pe_loss: 4.177 - force_loss: 40.014 - stress_loss: 1761.722 - total_loss: 220.364

===> Epoch 13/50 - 0.258s/epoch
     training_loss    - pe_loss: 3.439 - force_loss: 55.449 - stress_loss: 2754.962 - total_loss: 334.384
     validation_loss  - pe_loss: 3.599 - force_loss: 38.619 - stress_loss: 1264.196 - total_loss: 168.638

===> Epoch 14/50 - 0.244s/epoch
     training_loss    - pe_loss: 3.371 - force_loss: 43.619 - stress_loss: 1720.694 - total_loss: 219.059
     validation_loss  - pe_loss: 3.033 - force_loss: 37.956 - stress_loss: 964.631 - total_loss: 137.452

===> Epoch 15/50 - 0.216s/epoch
     training_loss    - pe_loss: 2.820 - force_loss: 43.306 - stress_loss: 1832.361 - total_loss: 229.362
     validation_loss  - pe_loss: 2.787 - force_loss: 41.718 - stress_loss: 1062.309 - total_loss: 150.736

===> Epoch 16/50 - 0.260s/epoch
     training_loss    - pe_loss: 2.301 - force_loss: 37.596 - stress_loss: 2066.742 - total_loss: 246.571
     validation_loss  - pe_loss: 2.631 - force_loss: 34.549 - stress_loss: 942.406 - total_loss: 131.421

===> Epoch 17/50 - 0.205s/epoch
     training_loss    - pe_loss: 2.490 - force_loss: 33.295 - stress_loss: 1390.026 - total_loss: 174.788
     validation_loss  - pe_loss: 2.603 - force_loss: 28.231 - stress_loss: 1015.503 - total_loss: 132.384

===> Epoch 18/50 - 0.223s/epoch
     training_loss    - pe_loss: 2.607 - force_loss: 40.722 - stress_loss: 1264.683 - total_loss: 169.797
     validation_loss  - pe_loss: 2.482 - force_loss: 28.160 - stress_loss: 1063.177 - total_loss: 136.959

===> Epoch 19/50 - 0.243s/epoch
     training_loss    - pe_loss: 2.484 - force_loss: 36.703 - stress_loss: 1246.367 - total_loss: 163.824
     validation_loss  - pe_loss: 2.171 - force_loss: 25.673 - stress_loss: 646.808 - total_loss: 92.525

===> Epoch 20/50 - 0.245s/epoch
     training_loss    - pe_loss: 2.196 - force_loss: 37.554 - stress_loss: 1010.980 - total_loss: 140.848
     validation_loss  - pe_loss: 2.050 - force_loss: 27.034 - stress_loss: 989.504 - total_loss: 128.035

===> Epoch 21/50 - 0.313s/epoch
     training_loss    - pe_loss: 1.950 - force_loss: 40.177 - stress_loss: 1302.699 - total_loss: 172.397
     validation_loss  - pe_loss: 2.259 - force_loss: 28.655 - stress_loss: 923.259 - total_loss: 123.240

===> Epoch 22/50 - 0.273s/epoch
     training_loss    - pe_loss: 1.939 - force_loss: 33.098 - stress_loss: 920.506 - total_loss: 127.088
     validation_loss  - pe_loss: 2.598 - force_loss: 27.771 - stress_loss: 920.784 - total_loss: 122.447

===> Epoch 23/50 - 0.273s/epoch
     training_loss    - pe_loss: 2.114 - force_loss: 30.801 - stress_loss: 1053.435 - total_loss: 138.258
     validation_loss  - pe_loss: 2.808 - force_loss: 23.787 - stress_loss: 796.039 - total_loss: 106.198

===> Epoch 24/50 - 0.284s/epoch
     training_loss    - pe_loss: 2.508 - force_loss: 28.795 - stress_loss: 798.401 - total_loss: 111.143
     validation_loss  - pe_loss: 2.946 - force_loss: 23.228 - stress_loss: 654.580 - total_loss: 91.632

===> Epoch 25/50 - 0.224s/epoch
     training_loss    - pe_loss: 2.822 - force_loss: 28.372 - stress_loss: 639.680 - total_loss: 95.162
     validation_loss  - pe_loss: 2.937 - force_loss: 22.818 - stress_loss: 562.909 - total_loss: 82.047

===> Epoch 26/50 - 0.268s/epoch
     training_loss    - pe_loss: 2.591 - force_loss: 28.649 - stress_loss: 561.209 - total_loss: 87.362
     validation_loss  - pe_loss: 2.668 - force_loss: 23.794 - stress_loss: 591.417 - total_loss: 85.604

===> Epoch 27/50 - 0.254s/epoch
     training_loss    - pe_loss: 2.355 - force_loss: 23.397 - stress_loss: 483.879 - total_loss: 74.140
     validation_loss  - pe_loss: 2.109 - force_loss: 24.381 - stress_loss: 475.684 - total_loss: 74.059

===> Epoch 28/50 - 0.254s/epoch
     training_loss    - pe_loss: 2.215 - force_loss: 24.191 - stress_loss: 349.811 - total_loss: 61.387
     validation_loss  - pe_loss: 1.648 - force_loss: 23.037 - stress_loss: 360.378 - total_loss: 60.722

===> Epoch 29/50 - 0.313s/epoch
     training_loss    - pe_loss: 1.653 - force_loss: 26.394 - stress_loss: 387.579 - total_loss: 66.805
     validation_loss  - pe_loss: 1.442 - force_loss: 18.368 - stress_loss: 334.823 - total_loss: 53.291

===> Epoch 30/50 - 0.303s/epoch
     training_loss    - pe_loss: 1.154 - force_loss: 20.671 - stress_loss: 327.221 - total_loss: 54.548
     validation_loss  - pe_loss: 1.302 - force_loss: 16.569 - stress_loss: 340.470 - total_loss: 51.917

===> Epoch 31/50 - 0.226s/epoch
     training_loss    - pe_loss: 1.310 - force_loss: 22.966 - stress_loss: 326.334 - total_loss: 56.910
     validation_loss  - pe_loss: 1.202 - force_loss: 16.213 - stress_loss: 318.304 - total_loss: 49.245

===> Epoch 32/50 - 0.233s/epoch
     training_loss    - pe_loss: 1.443 - force_loss: 20.132 - stress_loss: 277.640 - total_loss: 49.338
     validation_loss  - pe_loss: 1.136 - force_loss: 16.502 - stress_loss: 280.441 - total_loss: 45.682

===> Epoch 33/50 - 0.254s/epoch
     training_loss    - pe_loss: 1.371 - force_loss: 20.072 - stress_loss: 294.634 - total_loss: 50.907
     validation_loss  - pe_loss: 1.115 - force_loss: 15.626 - stress_loss: 218.226 - total_loss: 38.564

===> Epoch 34/50 - 0.234s/epoch
     training_loss    - pe_loss: 1.158 - force_loss: 18.625 - stress_loss: 217.164 - total_loss: 41.499
     validation_loss  - pe_loss: 1.145 - force_loss: 15.217 - stress_loss: 211.943 - total_loss: 37.557

===> Epoch 35/50 - 0.262s/epoch
     training_loss    - pe_loss: 0.966 - force_loss: 16.210 - stress_loss: 171.115 - total_loss: 34.287
     validation_loss  - pe_loss: 1.496 - force_loss: 16.229 - stress_loss: 217.624 - total_loss: 39.487

===> Epoch 36/50 - 0.207s/epoch
     training_loss    - pe_loss: 1.585 - force_loss: 15.893 - stress_loss: 223.773 - total_loss: 39.856
     validation_loss  - pe_loss: 1.950 - force_loss: 15.536 - stress_loss: 227.235 - total_loss: 40.210

===> Epoch 37/50 - 0.226s/epoch
     training_loss    - pe_loss: 2.034 - force_loss: 15.639 - stress_loss: 162.887 - total_loss: 33.961
     validation_loss  - pe_loss: 1.884 - force_loss: 14.791 - stress_loss: 193.579 - total_loss: 36.033

===> Epoch 38/50 - 0.251s/epoch
     training_loss    - pe_loss: 1.547 - force_loss: 16.658 - stress_loss: 135.980 - total_loss: 31.803
     validation_loss  - pe_loss: 1.340 - force_loss: 14.072 - stress_loss: 157.810 - total_loss: 31.193

===> Epoch 39/50 - 0.244s/epoch
     training_loss    - pe_loss: 0.891 - force_loss: 13.698 - stress_loss: 172.006 - total_loss: 31.789
     validation_loss  - pe_loss: 1.055 - force_loss: 13.231 - stress_loss: 140.377 - total_loss: 28.324

===> Epoch 40/50 - 0.219s/epoch
     training_loss    - pe_loss: 1.238 - force_loss: 15.531 - stress_loss: 200.949 - total_loss: 36.864
     validation_loss  - pe_loss: 1.002 - force_loss: 12.882 - stress_loss: 101.479 - total_loss: 24.033

===> Epoch 41/50 - 0.223s/epoch
     training_loss    - pe_loss: 0.755 - force_loss: 13.338 - stress_loss: 130.370 - total_loss: 27.130
     validation_loss  - pe_loss: 0.930 - force_loss: 12.227 - stress_loss: 97.618 - total_loss: 22.919

===> Epoch 42/50 - 0.226s/epoch
     training_loss    - pe_loss: 0.833 - force_loss: 14.639 - stress_loss: 129.654 - total_loss: 28.437
     validation_loss  - pe_loss: 0.882 - force_loss: 12.062 - stress_loss: 80.547 - total_loss: 20.998

===> Epoch 43/50 - 0.246s/epoch
     training_loss    - pe_loss: 1.131 - force_loss: 13.162 - stress_loss: 110.646 - total_loss: 25.357
     validation_loss  - pe_loss: 0.868 - force_loss: 13.379 - stress_loss: 71.995 - total_loss: 21.447

===> Epoch 44/50 - 0.250s/epoch
     training_loss    - pe_loss: 0.952 - force_loss: 13.554 - stress_loss: 79.879 - total_loss: 22.494
     validation_loss  - pe_loss: 0.932 - force_loss: 14.400 - stress_loss: 78.787 - total_loss: 23.211

===> Epoch 45/50 - 0.242s/epoch
     training_loss    - pe_loss: 0.758 - force_loss: 12.906 - stress_loss: 117.809 - total_loss: 25.445
     validation_loss  - pe_loss: 0.988 - force_loss: 14.545 - stress_loss: 90.461 - total_loss: 24.579

===> Epoch 46/50 - 0.186s/epoch
     training_loss    - pe_loss: 0.701 - force_loss: 13.074 - stress_loss: 74.397 - total_loss: 21.215
     validation_loss  - pe_loss: 0.977 - force_loss: 12.150 - stress_loss: 55.447 - total_loss: 18.672

===> Epoch 47/50 - 0.242s/epoch
     training_loss    - pe_loss: 0.743 - force_loss: 14.006 - stress_loss: 51.716 - total_loss: 19.921
     validation_loss  - pe_loss: 0.940 - force_loss: 11.309 - stress_loss: 76.686 - total_loss: 19.918

===> Epoch 48/50 - 0.227s/epoch
     training_loss    - pe_loss: 0.777 - force_loss: 12.847 - stress_loss: 64.845 - total_loss: 20.109
     validation_loss  - pe_loss: 0.931 - force_loss: 10.905 - stress_loss: 75.985 - total_loss: 19.434

===> Epoch 49/50 - 0.238s/epoch
     training_loss    - pe_loss: 0.679 - force_loss: 12.772 - stress_loss: 54.894 - total_loss: 18.940
     validation_loss  - pe_loss: 0.953 - force_loss: 10.538 - stress_loss: 63.520 - total_loss: 17.843

===> Epoch 50/50 - 0.252s/epoch
     training_loss    - pe_loss: 0.789 - force_loss: 11.463 - stress_loss: 58.099 - total_loss: 18.062
     validation_loss  - pe_loss: 0.978 - force_loss: 10.639 - stress_loss: 46.868 - total_loss: 16.304

End of training, elapsed time:  00:00:12
[19]:
model.plot_loss(start_epoch=1)
[Output: four loss-history plots, getstarted_example_32_0.png to getstarted_example_32_3.png]
[20]:
imported_model.evaluate(test_dataset.take(5),return_prediction=False)
Evaluation loss is:
        pe_loss:       8.3364e+00
     force_loss:       9.8733e+01
    stress_loss:       4.5703e+03
     total_loss:       5.6410e+02
The total loss is computed using the loss weights - pe: 1.00 - force: 1.00 - stress: 0.10
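
The last line of the output says the total loss is a weighted sum of the three component losses. As a quick sanity check, the sketch below recomputes it from the printed numbers. The dictionary names are illustrative only and are not part of the atomdnn API.

```python
# Component losses copied from the evaluation output above.
losses = {"pe": 8.3364e+00, "force": 9.8733e+01, "stress": 4.5703e+03}

# Loss weights as reported in the last line of the output.
weights = {"pe": 1.00, "force": 1.00, "stress": 0.10}

# Weighted sum: 1.0 * pe_loss + 1.0 * force_loss + 0.1 * stress_loss
total_loss = sum(weights[key] * losses[key] for key in losses)

print(f"total_loss: {total_loss:.4e}")  # prints total_loss: 5.6410e+02
```

The same weighting reproduces the per-epoch totals in the training log, for example 0.789 + 11.463 + 0.1 * 58.099 = 18.062 for the training loss of epoch 50.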
[21]:
input_dict = get_input_dict(test_dataset.take(5))
imported_model.predict(input_dict)
[21]:
{'pe': array([-35.19343363, -30.5018231 , -13.10211931, -40.24163573,
        -34.76370812]),
 'force': array([[[  62.99615578,  -13.02417413,   35.15568535],
         [ 125.17564353,   56.29076835,    4.00421797],
         [-120.00388   ,  -31.82880756,    4.69251285],
         [ -68.16791941,  -11.43778665,  -43.85241614]],

        [[ -10.02454801,  -75.68597036,  -84.857218  ],
         [  76.18432679,   28.68981993,  116.42437951],
         [ 104.90041358,   37.58853954,  216.42067892],
         [-171.06019292,    9.40761135, -247.98784042]],

        [[-187.60105157,   72.76025473,   95.32019806],
         [ 115.33638883,  -71.88966868,  -41.80044277],
         [  69.78280937,  152.38949361,   73.49249946],
         [   2.48185339, -153.26007932, -127.01225473]],

        [[ -78.6992566 , -152.26434996,  -34.44163215],
         [ 217.72729506,   42.42075491,  330.17770106],
         [  61.01884481, -136.88389852,   -9.8914536 ],
         [-200.04688344,  246.72749333, -285.84461536]],

        [[ 343.49811374,   -6.76913846,   74.74235097],
         [-171.71298205,   96.87952106,  -14.96776701],
         [-198.50906892, -136.20298223,   16.80680101],
         [  26.7239375 ,   46.09259955,  -76.581385  ]]]),
 'stress': array([[-1.96553956e+03, -3.84383326e+01,  5.27824788e+00,
         -3.84383324e+01, -2.24581615e+03,  1.04064708e+00,
          5.27824870e+00,  1.04064694e+00,  1.57429504e+00],
        [-1.04472838e+04, -3.39653871e+01, -1.10222099e+01,
         -3.39653872e+01, -1.05141466e+04,  3.61939603e-01,
         -1.10222056e+01,  3.61936112e-01, -1.71367644e+01],
        [ 4.39121460e+03,  4.09626905e+01, -1.07343792e+01,
          4.09626917e+01,  4.26569843e+03, -5.34677434e+00,
         -1.07343794e+01, -5.34677692e+00, -5.61152547e-01],
        [ 3.66647618e+04,  3.74010913e+01,  1.02510058e+01,
          3.74010908e+01,  3.65224082e+04,  9.64309282e-01,
          1.02510072e+01,  9.64311133e-01,  1.47602297e+01],
        [-1.49339234e+04, -1.26911479e+02, -8.38920551e+00,
         -1.26911479e+02, -1.47875442e+04,  1.52792284e+00,
         -8.38920762e+00,  1.52792340e+00, -2.73513556e+00]])}
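
In the printed output, 'force' holds one 3-component vector per atom (5 structures with 4 atoms each), and 'stress' holds 9 numbers per structure. The sketch below is one way to inspect such arrays with NumPy, using values copied from the output above. It assumes the 9 stress components are the row-major flattening of a 3x3 tensor; the printed values are symmetric to within rounding, so the opposite ordering would give the same tensor. The variable names are illustrative and not part of atomdnn.

```python
import numpy as np

# First two rows of the 'stress' array printed above (copied verbatim).
stress_flat = np.array([
    [-1.96553956e+03, -3.84383326e+01,  5.27824788e+00,
     -3.84383324e+01, -2.24581615e+03,  1.04064708e+00,
      5.27824870e+00,  1.04064694e+00,  1.57429504e+00],
    [-1.04472838e+04, -3.39653871e+01, -1.10222099e+01,
     -3.39653872e+01, -1.05141466e+04,  3.61939603e-01,
     -1.10222056e+01,  3.61936112e-01, -1.71367644e+01],
])

# Assumption: each row is a 3x3 tensor flattened in row-major order.
stress_tensor = stress_flat.reshape(-1, 3, 3)

# Largest absolute difference between each tensor and its transpose.
asymmetry = np.abs(stress_tensor - stress_tensor.transpose(0, 2, 1)).max()

# Mean normal stress (trace / 3) of each structure.
mean_normal_stress = np.trace(stress_tensor, axis1=1, axis2=2) / 3.0

# Forces on the 4 atoms of the first structure (copied verbatim).
force_first = np.array([
    [  62.99615578,  -13.02417413,   35.15568535],
    [ 125.17564353,   56.29076835,    4.00421797],
    [-120.00388   ,  -31.82880756,    4.69251285],
    [ -68.16791941,  -11.43778665,  -43.85241614],
])

# Sum of the per-atom forces, one value per Cartesian component.
net_force = force_first.sum(axis=0)

print("stress tensor shape:", stress_tensor.shape)
print("max asymmetry:", asymmetry)
print("mean normal stress:", mean_normal_stress)
print("net force on first structure:", net_force)
```

For these values, the off-diagonal stress components agree with their transposes to within about 1e-5, and the forces on the first structure sum to nearly zero in each direction.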