Wrappers for PyTorch
Models saved for tiktorch must follow this specification:
A folder containing the following files:
model.py
: Python file that defines the model.
state.nn
: state dict of the model, as obtained by `torch.save(model.state_dict(), 'state.nn')`
tiktorch_config.yml
: YAML file with metadata
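The three files work together. `model.py` provides the class named by the config's `model_class_name` key, that class is instantiated with `model_init_kwargs`, and `state.nn` supplies the weights. The sketch below shows one way to load such a folder under that reading. It is not tiktorch's own loader. `build_model` and `load_weights` are hypothetical names, and the sketch assumes `config` is `tiktorch_config.yml` already parsed into a dict.

```python
import importlib.util
from pathlib import Path


def build_model(folder, config):
    """Build the model described by `config` from `folder/model.py`.

    Hypothetical helper: `config` is tiktorch_config.yml already parsed into a dict.
    """
    model_file = Path(folder) / "model.py"
    # Import model.py as a standalone module under a private name.
    spec = importlib.util.spec_from_file_location("_tiktorch_model", str(model_file))
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    # Look up the class named in the config and call it with the stored kwargs
    # (a key left empty in YAML parses as None, hence the `or {}`).
    model_class = getattr(module, config["model_class_name"])
    return model_class(**(config["model_init_kwargs"] or {}))


def load_weights(model, folder):
    """Load the state dict in `folder/state.nn` into a built torch.nn.Module."""
    import torch  # imported here so build_model() does not require torch

    state_dict = torch.load(Path(folder) / "state.nn", map_location="cpu")
    model.load_state_dict(state_dict)
    return model
```

A typical call would be `model = load_weights(build_model(folder, config), folder)`, followed by `model.eval()` for inference. Because `model.py` is imported on its own, it may not be able to import helper files placed next to it in the folder.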
The config must contain the following keys:
input_shape
: shape of valid network input, must be either CHW (2D) or CDHW (3D)
output_shape
: shape of network output given input with `input_shape`; same format
dynamic_input_shape
: TODO explain
model_class_name
: name of the model class in `model.py`
model_init_kwargs
: keyword arguments to build model
torch_version
: torch version used to train this model
In addition, the config may contain the following keys:
description
: description of the pre-trained model
data_source
: URL of the data used for pre-training
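The key rules above can be checked mechanically. The following is a minimal sketch of such a check, with two assumptions. First, it reads "same format" as meaning `output_shape` has as many axes as `input_shape`. Second, it treats the optional keys, when present, as plain strings. `validate_config` is a hypothetical name, not part of tiktorch, and it takes the config already parsed into a dict.

```python
# Keys listed in the specification above.
REQUIRED_KEYS = (
    "input_shape",
    "output_shape",
    "dynamic_input_shape",
    "model_class_name",
    "model_init_kwargs",
    "torch_version",
)
OPTIONAL_KEYS = ("description", "data_source")


def validate_config(config):
    """Return a list of problems in a parsed tiktorch_config.yml (empty if none).

    Hypothetical helper; `config` is the YAML file already loaded into a dict.
    """
    problems = [f"missing required key: {key}" for key in REQUIRED_KEYS if key not in config]

    in_shape = config.get("input_shape")
    out_shape = config.get("output_shape")
    # CHW (2D) has 3 axes, CDHW (3D) has 4 axes.
    if in_shape is not None and len(in_shape) not in (3, 4):
        problems.append("input_shape must be CHW (3 axes) or CDHW (4 axes)")
    # Assumption: "same format" means the same number of axes as input_shape.
    if in_shape is not None and out_shape is not None and len(out_shape) != len(in_shape):
        problems.append("output_shape must have the same format as input_shape")

    # Assumption: optional keys, when present, are plain strings.
    for key in OPTIONAL_KEYS:
        if key in config and not isinstance(config[key], str):
            problems.append(f"{key} must be a string")
    return problems
```

If PyYAML is installed, the dict could come from `yaml.safe_load` on the file's contents. The check does not interpret `dynamic_input_shape`, since its meaning is not yet specified.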
TODO explain how to generate with tiktorch.
Possible extensions:
- specification for training set-up
- specifying additional Python modules necessary to run the model
- load model saved via `torch.save(model, path)` instead of state dict