MSE-CNN Implementation 1
Code repository with the implementation of MSE-CNN, from the paper 'DeepQTMT: A Deep Learning Approach for Fast QTMT-based CU Partition of Intra-mode VVC'
train_model_utils.py File Reference

Namespaces

namespace  msecnn_raulkviana
 
namespace  msecnn_raulkviana.train_model_utils
 

Functions

 msecnn_raulkviana.train_model_utils.model_statistics (J_history, predicted, ground_truth, pred_vector, gt_vector, f1_list, recall_list, precision_list, accuracy_list, train_or_val="train")
 Evaluates the model with metrics such as accuracy and F1 score.
 
 msecnn_raulkviana.train_model_utils.right_size (CUs)
 
 msecnn_raulkviana.train_model_utils.compute_conf_matrix (predicted, ground_truth)
 Computes the confusion matrix.
 
 msecnn_raulkviana.train_model_utils.compute_top_k_accuracy (pred_vector, gt_vector, topk)
 Computes the top-k accuracy score.
 
 msecnn_raulkviana.train_model_utils.compute_num_splits_sent (pred_lst)
 Computes the number of splits that would be analyzed by the encoder.
 
 msecnn_raulkviana.train_model_utils.compute_multi_thres_performance (pred_lst, gt_lst)
 Computes multi-threshold performance.
 
 msecnn_raulkviana.train_model_utils.compute_ROC_curve (pred_vector, gt_vector, pred_num)
 Computes the ROC curve.
 
 msecnn_raulkviana.train_model_utils.model_simple_metrics (predicted, ground_truth)
 Evaluates the model with four metrics: accuracy, F1 score, recall and precision.
 
 msecnn_raulkviana.train_model_utils.obtain_best_modes (rs, pred)
 Converts a prediction into the number that identifies the best way to split (non-split, quad-tree, binary vertical tree, ...).
 
 msecnn_raulkviana.train_model_utils.obtain_mode (pred)
 Converts a prediction into the number that identifies the best way to split (non-split, quad-tree, binary vertical tree, ...).
 
 msecnn_raulkviana.train_model_utils.one_hot_enc (tensor, num_classes=6)
 Applies one-hot encoding to a tensor containing the set of split modes.
 
 msecnn_raulkviana.train_model_utils.print_parameters (model, optimizer)
 Prints the parameters from the state dictionaries of the model and optimizer.
 
 msecnn_raulkviana.train_model_utils.save_model_parameters (dir_name, f_name, model)
 Saves only the model parameters to a specific folder.
 
 msecnn_raulkviana.train_model_utils.save_model (dir_name, f_name, model, optimizer, loss, acc)
 Saves the parameters of the model and of the optimizer, and also the loss and the accuracy.
 
 msecnn_raulkviana.train_model_utils.load_model_parameters_stg (model, path, stg, dev)
 Loads all stages, making sure that stage number 'stg' has the same parameters as the previous stage.
 
 msecnn_raulkviana.train_model_utils.load_model_parameters_eval (model, path, dev)
 Loads all stages; intended for use with the eval_model script.
 
 msecnn_raulkviana.train_model_utils.load_model_stg_12_stg_3 (model, path, dev)
 This function makes it possible to load parameters from the first and second stages to the third.
 
 msecnn_raulkviana.train_model_utils.load_model_stg_3_stg_4 (model, path, dev)
 This function makes it possible to load parameters from the third stage to the fourth.
 
 msecnn_raulkviana.train_model_utils.load_model_stg_4_stg_5 (model, path, dev)
 This function makes it possible to load parameters from the fourth stage to the fifth.
 
 msecnn_raulkviana.train_model_utils.load_model_stg_5_stg_6 (model, path, dev)
 This function makes it possible to load parameters from the fifth stage to the sixth.
 
 msecnn_raulkviana.train_model_utils.print_current_time ()
 Prints current time.
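The quantities that compute_conf_matrix and model_simple_metrics are described as producing can be illustrated with a minimal pure-Python sketch. It assumes integer class labels and macro-averages precision, recall and F1 over the classes that appear in either list. The names confusion_matrix and simple_metrics are hypothetical; the real functions in train_model_utils.py work on the project's own tensors and may average differently.

```python
from typing import Dict, List, Sequence

# Hypothetical helpers illustrating compute_conf_matrix / model_simple_metrics;
# this is not the project's implementation.


def confusion_matrix(predicted: Sequence[int], ground_truth: Sequence[int],
                     num_classes: int) -> List[List[int]]:
    """Rows index the true class, columns the predicted class."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for pred, true in zip(predicted, ground_truth):
        matrix[true][pred] += 1
    return matrix


def simple_metrics(predicted: Sequence[int], ground_truth: Sequence[int],
                   num_classes: int = 6) -> Dict[str, float]:
    """Accuracy plus macro-averaged precision, recall and F1.

    Assumption: the macro average runs over classes present in either list.
    """
    cm = confusion_matrix(predicted, ground_truth, num_classes)
    total = sum(sum(row) for row in cm)
    correct = sum(cm[c][c] for c in range(num_classes))
    present = sorted(set(predicted) | set(ground_truth))
    precisions, recalls, f1s = [], [], []
    for c in present:
        tp = cm[c][c]
        fp = sum(cm[r][c] for r in range(num_classes)) - tp  # predicted c, true class differs
        fn = sum(cm[c]) - tp  # true class c, predicted something else
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)
    n = len(present) or 1
    return {
        "accuracy": correct / total if total else 0.0,
        "precision": sum(precisions) / n,
        "recall": sum(recalls) / n,
        "f1": sum(f1s) / n,
    }
```

For example, simple_metrics([0, 1, 1, 2], [0, 1, 2, 2]) yields an accuracy of 0.75, since three of the four labels match.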
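Top-k accuracy, as reported by compute_top_k_accuracy, counts a sample as correct when its true class is among the k highest-scoring classes. The sketch below assumes pred_vector holds one list of per-class scores per sample and gt_vector holds integer labels; top_k_accuracy is a hypothetical name, not the module's function.

```python
from typing import Sequence

# Hypothetical illustration of top-k accuracy; not the project's implementation.


def top_k_accuracy(pred_vector: Sequence[Sequence[float]],
                   gt_vector: Sequence[int], k: int) -> float:
    """Fraction of samples whose true class is among the k best-scoring classes."""
    if not gt_vector:
        return 0.0
    hits = 0
    for scores, true_class in zip(pred_vector, gt_vector):
        # Class indices sorted by score, highest first; the sort is stable,
        # so ties keep the lower index first.
        ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        if true_class in ranked[:k]:
            hits += 1
    return hits / len(gt_vector)
```

With k = 1 this reduces to ordinary accuracy; larger k measures how often the correct split mode is at least among the best candidates.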
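one_hot_enc and obtain_mode are described as converting between split-mode labels and prediction vectors. A minimal sketch of both directions, assuming six split modes (matching the num_classes=6 default) and that the best mode is the index with the highest score, could look like this. The names one_hot and best_mode are hypothetical, and the mapping from index to a specific split mode is not stated on this page, so the sketch does not assume one. In PyTorch the equivalent building blocks are torch.nn.functional.one_hot and torch.argmax.

```python
from typing import List, Sequence

# Hypothetical helpers; the index-to-split-mode mapping is not assumed here.
NUM_SPLIT_MODES = 6  # matches the num_classes=6 default of one_hot_enc


def one_hot(labels: Sequence[int], num_classes: int = NUM_SPLIT_MODES) -> List[List[int]]:
    """Encode each integer label as a list with a single 1 at the label's index."""
    encoded = []
    for label in labels:
        if not 0 <= label < num_classes:
            raise ValueError(f"label {label} outside [0, {num_classes})")
        row = [0] * num_classes
        row[label] = 1
        encoded.append(row)
    return encoded


def best_mode(scores: Sequence[float]) -> int:
    """Index of the highest-scoring split mode; the first index wins on ties."""
    return max(range(len(scores)), key=lambda i: scores[i])
```

Applying best_mode to a row produced by one_hot recovers the original label, which is a quick consistency check between the two conversions.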
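A ROC curve for a multi-class predictor such as the one evaluated by compute_ROC_curve is commonly built one class at a time (one-vs-rest): the score for class c is swept as a threshold, and each threshold gives a (false-positive rate, true-positive rate) point. The sketch below shows that technique under the assumption that pred_vector holds per-class scores; it is not necessarily what compute_ROC_curve does, the role of its pred_num argument is not documented here and is not modeled, and roc_points, one_vs_rest_roc and auc are hypothetical names.

```python
from typing import List, Sequence, Tuple

# Hypothetical one-vs-rest ROC sketch; not the project's implementation.


def roc_points(scores: Sequence[float], labels: Sequence[int]) -> List[Tuple[float, float]]:
    """(FPR, TPR) pairs for binary labels, using each distinct score as a threshold."""
    pos = sum(1 for y in labels if y == 1)
    neg = len(labels) - pos
    if pos == 0 or neg == 0:
        raise ValueError("need at least one positive and one negative sample")
    points = [(0.0, 0.0)]
    for thr in sorted(set(scores), reverse=True):
        tp = sum(1 for s, y in zip(scores, labels) if s >= thr and y == 1)
        fp = sum(1 for s, y in zip(scores, labels) if s >= thr and y != 1)
        points.append((fp / neg, tp / pos))
    return points


def one_vs_rest_roc(pred_vector: Sequence[Sequence[float]],
                    gt_vector: Sequence[int], cls: int) -> List[Tuple[float, float]]:
    """ROC points treating class `cls` as positive and every other class as negative."""
    scores = [row[cls] for row in pred_vector]
    labels = [1 if g == cls else 0 for g in gt_vector]
    return roc_points(scores, labels)


def auc(points: Sequence[Tuple[float, float]]) -> float:
    """Area under the curve by the trapezoidal rule."""
    return sum((x1 - x0) * (y0 + y1) / 2 for (x0, y0), (x1, y1) in zip(points, points[1:]))
```

Lowering the threshold can only add predictions, so both rates are non-decreasing along the returned list, which is what the trapezoidal area in auc relies on.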