MSE-CNN Implementation 1
Code repository implementing the MSE-CNN from the paper 'DeepQTMT: A Deep Learning Approach for Fast QTMT-based CU Partition of Intra-mode VVC'
train_stg6 Namespace Reference

Functions

 train (dataloader, model, loss_fn, optimizer, device)
 If the batch size is equal to 1, the update rule is Stochastic Gradient Descent (SGD); otherwise, it is mini-batch gradient descent.
 
 test (dataloader, model, loss_fn, device, loss_name)
 
 train_test (train_dataloader, test_dataloader, model, loss_fn, optimizer, device, epochs, lr_sch)
 
 main ()
 

Variables

 parser = argparse.ArgumentParser(description=constants.script_description)
 
 type
 
 args = parser.parse_args()
 
 beta = args.b
 
 learning_rate = args.lr
 
 loss_threshold = float("-inf")
 
int QP = 32
 
 batch_size = args.batch
 
 iterations = args.i
 
float decay = 0.01
 
 decay_controler = args.dcontr
 
 device = args.dev
 
 num_workers = args.workers
 
 n_mod = args.nmod
 
 l_path_train = args.labelsTrain
 
 l_path_test = args.labelsTest
 
 writer = SummaryWriter("runs/MSECNN_"+n_mod)
 
str files_mod_name_stats = "_multi_batch_iter_{ite}_batch_{batch}_QP_{QP}_beta_{be}_lr_{lr}_{n_mod}".format(ite=iterations, batch=batch_size, QP=QP, be=beta, lr=learning_rate, n_mod=n_mod)
 
int cnt_train = 0
 
int cnt_test_train = 0
 
int cnt_test_test = 0
 

Detailed Description

@package docstring 

@file train_stg6.py 

@brief Training script for the sixth stage of the MSE-CNN, for the luma channel.  
 
@section libraries_train_stg6 Libraries 
- sklearn.metrics
- MSECNN
- torch.utils.data
- torch
- argparse
- torch.utils.tensorboard
- datetime
- train_model_utils
- utils
- numpy
- constants
- CustomDataset
- sys
- time
- matplotlib.pyplot

@section classes_train_stg6 Classes 
- None

@section functions_train_stg6 Functions 
- train(dataloader, model, loss_fn, optimizer, device)
- test(dataloader, model, loss_fn, device, loss_name)
- train_test(train_dataloader, test_dataloader, model, loss_fn, optimizer, device, epochs, lr_sch)
- main()
 
@section global_vars_train_stg6 Global Variables 
- learning_rate
- parser 
- args 
- loss_threshold 
- batch_size 
- QP 
- device 
- n_mod 
- num_workers 
- writer 
- beta 
- decay
- decay_controler
- iterations
- files_mod_name_stats 
- l_path_train
- l_path_test
- cnt_train
- cnt_test_train
- cnt_test_test

@section todo_train_stg6 TODO 
- None 

@section license License 
MIT License 
Copyright (c) 2022 Raul Kevin do Espirito Santo Viana
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

@section author_train_stg6 Author(s)
- Created by Raul Kevin Viana
- Last modified: 2023-01-29 22:23:10.689038

Function Documentation

◆ main()

train_stg6.main ( )

◆ test()

train_stg6.test (   dataloader,
  model,
  loss_fn,
  device,
  loss_name 
)

◆ train()

train_stg6.train (   dataloader,
  model,
  loss_fn,
  optimizer,
  device 
)

If the batch size is equal to 1, the update rule is Stochastic Gradient Descent (SGD); otherwise, it is mini-batch gradient descent.

If the batch size equals the size of the dataset, it is batch gradient descent.
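To make the distinction concrete, the sketch below shows the shape such a training pass typically takes in PyTorch. Only the parameter names come from the signature above; the body is an illustrative assumption, not the repository's actual implementation.

    def train(dataloader, model, loss_fn, optimizer, device):
        # Sketch only: one pass over the dataloader. The gradient-descent
        # flavour is decided entirely by how the dataloader was built:
        # batch_size == 1 gives SGD, 1 < batch_size < len(dataset) gives
        # mini-batch gradient descent, and batch_size == len(dataset)
        # gives batch gradient descent.
        model.train()
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)            # forward pass
            loss = loss_fn(pred, y)    # compute the loss
            optimizer.zero_grad()      # clear gradients from the previous step
            loss.backward()            # backpropagate
            optimizer.step()           # apply the update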

◆ train_test()

train_stg6.train_test (   train_dataloader,
  test_dataloader,
  model,
  loss_fn,
  optimizer,
  device,
  epochs,
  lr_sch 
)
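Given its parameters, a plausible shape for this function is an outer loop over epochs that alternates the train() and test() routines above and steps the learning-rate scheduler. This is an assumption based on the signature (including the loss_name strings), not the repository's code.

    def train_test(train_dataloader, test_dataloader, model, loss_fn,
                   optimizer, device, epochs, lr_sch):
        # Sketch only: alternate training and evaluation for a fixed
        # number of epochs.
        for epoch in range(epochs):
            train(train_dataloader, model, loss_fn, optimizer, device)
            test(train_dataloader, model, loss_fn, device, "train")  # hypothetical loss_name
            test(test_dataloader, model, loss_fn, device, "test")    # hypothetical loss_name
            lr_sch.step()  # advance the learning-rate schedule once per epoch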

Variable Documentation

◆ args

train_stg6.args = parser.parse_args()

◆ batch_size

train_stg6.batch_size = args.batch

◆ beta

train_stg6.beta = args.b

◆ cnt_test_test

int train_stg6.cnt_test_test = 0

◆ cnt_test_train

int train_stg6.cnt_test_train = 0

◆ cnt_train

int train_stg6.cnt_train = 0

◆ decay

float train_stg6.decay = 0.01

◆ decay_controler

train_stg6.decay_controler = args.dcontr

◆ device

train_stg6.device = args.dev

◆ files_mod_name_stats

str train_stg6.files_mod_name_stats = "_multi_batch_iter_{ite}_batch_{batch}_QP_{QP}_beta_{be}_lr_{lr}_{n_mod}".format(ite=iterations, batch=batch_size, QP=QP, be=beta, lr=learning_rate, n_mod=n_mod)
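As a worked example, with the default QP of 32 and hypothetical values for the remaining fields (iterations=100, batch_size=32, beta=0.5, learning_rate=0.001, n_mod="stg6"), the format call produces:

    >>> "_multi_batch_iter_{ite}_batch_{batch}_QP_{QP}_beta_{be}_lr_{lr}_{n_mod}".format(
    ...     ite=100, batch=32, QP=32, be=0.5, lr=0.001, n_mod="stg6")
    '_multi_batch_iter_100_batch_32_QP_32_beta_0.5_lr_0.001_stg6'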

◆ iterations

train_stg6.iterations = args.i

◆ l_path_test

train_stg6.l_path_test = args.labelsTest

◆ l_path_train

train_stg6.l_path_train = args.labelsTrain

◆ learning_rate

train_stg6.learning_rate = args.lr

◆ loss_threshold

train_stg6.loss_threshold = float("-inf")

◆ n_mod

train_stg6.n_mod = args.nmod

◆ num_workers

train_stg6.num_workers = args.workers

◆ parser

train_stg6.parser = argparse.ArgumentParser(description=constants.script_description)
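The attribute accesses elsewhere on this page (args.b, args.lr, args.batch, args.i, args.dcontr, args.dev, args.workers, args.nmod, args.labelsTrain, args.labelsTest) imply a parser roughly like the sketch below. Only the destination names are taken from this page; the types, help strings, and whether each flag is written with one or two dashes are assumptions.

    import argparse

    parser = argparse.ArgumentParser(description="MSE-CNN stage-6 training")  # constants.script_description in the script
    parser.add_argument("--b", type=float, help="beta")
    parser.add_argument("--lr", type=float, help="learning rate")
    parser.add_argument("--batch", type=int, help="batch size")
    parser.add_argument("--i", type=int, help="number of iterations")
    parser.add_argument("--dcontr", type=int, help="decay controller")
    parser.add_argument("--dev", type=str, help="device, e.g. 'cuda' or 'cpu'")
    parser.add_argument("--workers", type=int, help="dataloader worker processes")
    parser.add_argument("--nmod", type=str, help="model/stage name")
    parser.add_argument("--labelsTrain", type=str, help="path to training labels")
    parser.add_argument("--labelsTest", type=str, help="path to test labels")
    args = parser.parse_args()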

◆ QP

int train_stg6.QP = 32

◆ type

train_stg6.type

◆ writer

train_stg6.writer = SummaryWriter("runs/MSECNN_"+n_mod)
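For context, the writer logs training statistics to TensorBoard under runs/MSECNN_<n_mod>. A typical usage pattern is sketched below; the scalar tag and the use of cnt_train as the global step are assumptions, not taken from the script.

    from torch.utils.tensorboard import SummaryWriter

    n_mod = "stg6"                                    # placeholder value
    cnt_train = 0                                     # placeholder step counter
    writer = SummaryWriter("runs/MSECNN_" + n_mod)    # event files go to runs/MSECNN_stg6
    writer.add_scalar("Loss/train", 0.42, cnt_train)  # hypothetical tag; counter as global step
    writer.flush()
    writer.close()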