MSE-CNN Implementation 1
Codebase with the implementation of MSE-CNN, from the paper 'DeepQTMT: A Deep Learning Approach for Fast QTMT-based CU Partition of Intra-mode VVC'
Namespaces
    namespace train_stg6
Functions
    train_stg6.train (dataloader, model, loss_fn, optimizer, device)
        If the batch size is 1, training is stochastic gradient descent (SGD); otherwise, it is mini-batch gradient descent.
    train_stg6.test (dataloader, model, loss_fn, device, loss_name)
    train_stg6.train_test (train_dataloader, test_dataloader, model, loss_fn, optimizer, device, epochs, lr_sch)
    train_stg6.main ()
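The note on `train_stg6.train` distinguishes SGD from mini-batch gradient descent purely by batch size. The following is a minimal, standard-library-only sketch of that distinction, not the repository's PyTorch code: the toy linear model, the noise-free data, and the helper names `make_batches`, `run_epoch`, and `train` are hypothetical. The same update loop performs one parameter update per sample when `batch_size=1` (SGD) and one update per batch otherwise (mini-batch gradient descent).

```python
import random


def make_batches(data, batch_size, rng):
    """Shuffle the samples and split them into consecutive batches."""
    data = data[:]  # copy so the caller's list is not reordered
    rng.shuffle(data)
    return [data[i:i + batch_size] for i in range(0, len(data), batch_size)]


def run_epoch(w, b, data, batch_size, lr, rng):
    """One epoch of gradient descent on y = w*x + b with mean squared error.

    batch_size == 1 -> stochastic gradient descent (one update per sample)
    batch_size > 1  -> mini-batch gradient descent (one update per batch)
    Returns the updated parameters and the number of updates performed.
    """
    updates = 0
    for batch in make_batches(data, batch_size, rng):
        n = len(batch)
        # Gradient of the mean squared error, averaged over the batch
        grad_w = sum(2 * (w * x + b - y) * x for x, y in batch) / n
        grad_b = sum(2 * (w * x + b - y) for x, y in batch) / n
        w -= lr * grad_w
        b -= lr * grad_b
        updates += 1
    return w, b, updates


def train(data, batch_size, epochs=2000, lr=0.1, seed=0):
    """Run several epochs and return (w, b, total number of updates)."""
    rng = random.Random(seed)
    w, b, total_updates = 0.0, 0.0, 0
    for _ in range(epochs):
        w, b, n_updates = run_epoch(w, b, data, batch_size, lr, rng)
        total_updates += n_updates
    return w, b, total_updates


if __name__ == "__main__":
    # Toy, noise-free data generated from y = 2x + 1
    data = [(x / 10, 2 * (x / 10) + 1) for x in range(10)]
    for bs in (1, 4, len(data)):
        w, b, n = train(data, batch_size=bs)
        print(f"batch_size={bs}: w={w:.3f}, b={b:.3f}, updates={n}")
```

All three settings reach the same fit on this toy problem; what changes is how often the parameters are updated per epoch (10, 3, and 1 times for batch sizes 1, 4, and 10), which is the trade-off the batch size controls.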
Variables
    train_stg6.parser = argparse.ArgumentParser(description=constants.script_description)
    train_stg6.type
    train_stg6.args = parser.parse_args()
    train_stg6.beta = args.b
    train_stg6.learning_rate = args.lr
    train_stg6.loss_threshold = float("-inf")
    int train_stg6.QP = 32
    train_stg6.batch_size = args.batch
    train_stg6.iterations = args.i
    float train_stg6.decay = 0.01
    train_stg6.decay_controler = args.dcontr
    train_stg6.device = args.dev
    train_stg6.num_workers = args.workers
    train_stg6.n_mod = args.nmod
    train_stg6.l_path_train = args.labelsTrain
    train_stg6.l_path_test = args.labelsTest
    train_stg6.writer = SummaryWriter("runs/MSECNN_"+n_mod)
    str train_stg6.files_mod_name_stats = "_multi_batch_iter_{ite}_batch_{batch}_QP_{QP}_beta_{be}_lr_{lr}_{n_mod}".format(ite=iterations, batch=batch_size, QP=QP, be=beta, lr=learning_rate, n_mod=n_mod)
    int train_stg6.cnt_train = 0
    int train_stg6.cnt_test_train = 0
    int train_stg6.cnt_test_test = 0
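Most module-level variables are read from the parsed command line (`args.b`, `args.lr`, `args.batch`, `args.i`, ...) and then combined into `files_mod_name_stats` and the TensorBoard log directory. The sketch below reproduces that wiring with `argparse` under stated assumptions: the flag spellings (`--b`, `--lr`, ...), their types, the help texts, and the description string are hypothetical, inferred only from the attribute names listed above, and the helper names `build_parser` and `stats_name` are illustrative. The format string is copied from `files_mod_name_stats`.

```python
import argparse

QP = 32  # fixed value listed in the Variables section


def build_parser():
    # Flag names mirror the attributes read in train_stg6 (args.b, args.lr, ...).
    # The exact spellings, types, and defaults are assumptions for illustration.
    parser = argparse.ArgumentParser(description="Train MSE-CNN (illustrative sketch)")
    parser.add_argument("--b", type=float, help="beta")
    parser.add_argument("--lr", type=float, help="learning rate")
    parser.add_argument("--batch", type=int, help="batch size")
    parser.add_argument("--i", type=int, help="number of iterations")
    parser.add_argument("--dcontr", type=int, help="decay controller")
    parser.add_argument("--dev", type=str, help="device, e.g. cuda:0 or cpu")
    parser.add_argument("--workers", type=int, help="number of data-loading workers")
    parser.add_argument("--nmod", type=str, help="model/run name")
    parser.add_argument("--labelsTrain", type=str, help="path to the training labels")
    parser.add_argument("--labelsTest", type=str, help="path to the test labels")
    return parser


def stats_name(args):
    # Same format string as files_mod_name_stats in train_stg6
    return "_multi_batch_iter_{ite}_batch_{batch}_QP_{QP}_beta_{be}_lr_{lr}_{n_mod}".format(
        ite=args.i, batch=args.batch, QP=QP, be=args.b, lr=args.lr, n_mod=args.nmod)


if __name__ == "__main__":
    args = build_parser().parse_args(
        ["--b", "0.5", "--lr", "0.001", "--batch", "32", "--i", "50", "--nmod", "demo"])
    print(stats_name(args))  # _multi_batch_iter_50_batch_32_QP_32_beta_0.5_lr_0.001_demo
    print("runs/MSECNN_" + args.nmod)  # directory that would be passed to SummaryWriter
```

Encoding every hyperparameter in the file name keeps the statistics from different runs apart without a separate metadata file; flags that are not passed simply appear as `None` in the name.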