RTM-Predictions¶
Project for working with simulated data.
Warning
The Pipeline is still work in progress and may has unexpected behaviour and/or some code flaws
Tutorial for training models¶
- Most of the work is done by the ModelTrainer class (defined in Trainer.ModelTrainer.py)
- ModelTrainer is the generic base class which takes all necessary parameters for any kind of training
- The training process is configured and started from an own dedicated script
- Examples for these dedicated scripts can be found in the model_trainer_*.py files in the root directory
Basic principles of ModelTrainer:
- Data processing is currently done by the ‘LoopingDataGenerator’ (defined in Pipeline.torch_datagenerator.py)
- LoopingDataGenerator takes a list of file paths as base paths
- The base paths are searched for .erfh5 files using the data_gather_function passed to the ModelTrainer
- After gathering the .erfh5 files, the data from these is processed using the data_processing_function passed to the ModelTrainer. An example for processing is the extraction of all pressure sensor values
- Additional work such as creating batches and shuffling data is done automatically
- The training process is implemented in ModelTrainer. This includes validation steps during training and testing on a dedicated test set after training
Steps for using the ModelTrainer in your script:
For data processing you need two functions:
- data_gather_function, a function for collecting the paths to the files from a root directory
- data_processing_function, a function that extracts the data from the collected filepaths. Must return data in following format: [(instance_1, label_1), … , (instance_n, label_n)] (examples for both functions can be found in the Pipeline.data_loader_*.py and Pipeline.data_gather.py files)
Define a PyTorch model
Instantiate the ModelTrainer: mt = ModelTrainer( … ). Pass all necessary arguments for your task. Important: You have to pass your model using lambda: YourModel()
Train your model using mt.start_training(). No additional parameters need to be passed if you have configured the ModelTrainer correctly
Testing using a dedicated test set can be done using mt.inference_on_test_set( … )
There are many more optional arguments for Master_Trainer and ERFH5_DataGenerator