MSE-CNN Implementation 1
Code base implementing MSE-CNN, from the paper 'DeepQTMT: A Deep Learning Approach for Fast QTMT-based CU Partition of Intra-mode VVC'
train_stg6.py File Reference

Namespaces

namespace  train_stg6
 

Functions

 train_stg6.train (dataloader, model, loss_fn, optimizer, device)
 If the batch size is 1, training is Stochastic Gradient Descent (SGD); otherwise it is mini-batch gradient descent.
 
 train_stg6.test (dataloader, model, loss_fn, device, loss_name)
 
 train_stg6.train_test (train_dataloader, test_dataloader, model, loss_fn, optimizer, device, epochs, lr_sch)
 
 train_stg6.main ()
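
A minimal sketch of the SGD versus mini-batch distinction in the train() description. It uses plain Python rather than the PyTorch dataloader, model, and optimizer that train() receives. It fits a hypothetical 1-D linear model y ≈ w·x + b by mean squared error. The function `minibatch_gd` and its parameters are illustrative only and are not part of train_stg6.py. With batch_size=1 the parameters are updated once per sample (SGD). Larger values average the gradient over each batch before stepping.

```python
import random


def minibatch_gd(xs, ys, batch_size, lr=0.05, epochs=200, seed=0):
    """Fit y ~ w*x + b by minimising mean squared error (hypothetical example).

    batch_size == 1             -> stochastic gradient descent (one update per sample)
    1 < batch_size < len(xs)    -> mini-batch gradient descent (one update per batch)
    batch_size == len(xs)       -> full-batch gradient descent
    Returns (w, b, number_of_parameter_updates).
    """
    rng = random.Random(seed)
    w, b = 0.0, 0.0
    order = list(range(len(xs)))
    updates = 0
    for _ in range(epochs):
        rng.shuffle(order)  # reshuffle every epoch, similar to DataLoader(shuffle=True)
        for start in range(0, len(order), batch_size):
            batch = order[start:start + batch_size]  # the last batch may be smaller
            gw = gb = 0.0
            for i in batch:
                err = (w * xs[i] + b) - ys[i]
                gw += 2.0 * err * xs[i]  # d(err^2)/dw
                gb += 2.0 * err          # d(err^2)/db
            # Average the gradient over the batch, then take one step
            w -= lr * gw / len(batch)
            b -= lr * gb / len(batch)
            updates += 1
    return w, b, updates


xs = [-1.0, -0.5, 0.0, 0.5, 1.0]
ys = [3.0 * x + 1.0 for x in xs]
for bs in (1, 2, 5):
    w, b, n = minibatch_gd(xs, ys, batch_size=bs)
    print(bs, round(w, 3), round(b, 3), n)
```

All three settings reach the same fit on this noiseless data. The difference is the number of parameter updates per epoch: one per sample for SGD and one per batch otherwise.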
 

Variables

 train_stg6.parser = argparse.ArgumentParser(description=constants.script_description)
 
 train_stg6.type
 
 train_stg6.args = parser.parse_args()
 
 train_stg6.beta = args.b
 
 train_stg6.learning_rate = args.lr
 
 train_stg6.loss_threshold = float("-inf")
 
int train_stg6.QP = 32
 
 train_stg6.batch_size = args.batch
 
 train_stg6.iterations = args.i
 
float train_stg6.decay = 0.01
 
 train_stg6.decay_controler = args.dcontr
 
 train_stg6.device = args.dev
 
 train_stg6.num_workers = args.workers
 
 train_stg6.n_mod = args.nmod
 
 train_stg6.l_path_train = args.labelsTrain
 
 train_stg6.l_path_test = args.labelsTest
 
 train_stg6.writer = SummaryWriter("runs/MSECNN_"+n_mod)
 
str train_stg6.files_mod_name_stats = "_multi_batch_iter_{ite}_batch_{batch}_QP_{QP}_beta_{be}_lr_{lr}_{n_mod}".format(ite=iterations, batch=batch_size, QP=QP, be=beta, lr=learning_rate, n_mod=n_mod)
 
int train_stg6.cnt_train = 0
 
int train_stg6.cnt_test_train = 0
 
int train_stg6.cnt_test_test = 0
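
Most of the Variables above are module-level values read from argparse, and files_mod_name_stats combines several of them into a run label. The sketch below shows one way they could fit together. It assumes long options named after the listed attribute names (`--b`, `--lr`, `--batch`, `--i`, ...). The real flag spellings, types, and defaults are not given in this reference. The helpers `build_parser` and `stats_name` are hypothetical. The format string is copied from files_mod_name_stats, and QP is fixed at 32 as in the module.

```python
import argparse


def build_parser():
    # Hypothetical flags: only the attribute names (args.b, args.lr, ...) are
    # documented; the types and defaults here are assumptions.
    p = argparse.ArgumentParser(description="MSE-CNN stage 6 training (sketch)")
    p.add_argument("--b", type=float, default=1.0, help="beta")
    p.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    p.add_argument("--batch", type=int, default=32, help="batch size")
    p.add_argument("--i", type=int, default=10, help="iterations")
    p.add_argument("--dcontr", type=int, default=5, help="decay controller")
    p.add_argument("--dev", type=str, default="cuda", help="device")
    p.add_argument("--workers", type=int, default=4, help="dataloader workers")
    p.add_argument("--nmod", type=str, default="model", help="model name suffix")
    p.add_argument("--labelsTrain", type=str, default="train_labels")
    p.add_argument("--labelsTest", type=str, default="test_labels")
    return p


def stats_name(args, QP=32):
    # Same format string as train_stg6.files_mod_name_stats
    return "_multi_batch_iter_{ite}_batch_{batch}_QP_{QP}_beta_{be}_lr_{lr}_{n_mod}".format(
        ite=args.i, batch=args.batch, QP=QP, be=args.b, lr=args.lr, n_mod=args.nmod)


args = build_parser().parse_args(
    ["--b", "1.0", "--lr", "0.001", "--batch", "32", "--i", "50", "--nmod", "stg6"])
print(stats_name(args))
```

Because the label embeds the iterations, batch size, QP, beta, learning rate, and model name, runs with different hyperparameters write their statistics under distinct names.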