Ex Data, Scientia

Home Contact

Building a CNN from scratch in C++

#include <TMB.hpp>
template<class Type>
Type objective_function<Type>::operator() ()
{
  DATA_INTEGER(n_imgs);
  DATA_SCALAR(nrow_img);
  DATA_SCALAR(nrow_rep_1);
  DATA_SCALAR(nrow_rep_2);
  DATA_SCALAR(nrow_rep_3);
  DATA_SCALAR(n_class);
  DATA_MATRIX(true_class);
  DATA_ARRAY(inp_img);
  
  PARAMETER_ARRAY(conv_1);
  PARAMETER_ARRAY(conv_2);
  PARAMETER_ARRAY(conv_3);
  PARAMETER_ARRAY(dense_1);
  
  vector<Type> enum_softm(n_imgs);
  matrix<Type> pred_pre(n_imgs,2);
  matrix<Type> pred(n_imgs,2);
  Type loss = 0;
  
  array<Type> int_rep_1(n_imgs,5,27-2,27-2);
  array<Type> int_rep_2(n_imgs,5,27-2-4,27-2-4);
  array<Type> int_rep_3(n_imgs,5,27-2-4-6,27-2-4-6);
  
  for(int g = 0; g < 12; ++g){
    for(int i = 0; i < nrow_img-2; ++i){
      for(int j = 0; j < nrow_img-2; ++j){
        for(int m = 0; m < 5; ++m){
          for(int k = 0; k < 3; ++k){
            for(int l = 0; l < 3; ++l){
              int_rep_1(g,m,i,j) += conv_1(m,k,l) * inp_img(g,i+k,j+l);
            }
          }
        }
      }
    }
    
    for(int i = 0; i < nrow_rep_1-4; ++i){
      for(int j = 0; j < nrow_rep_1-4; ++j){
        for(int m = 0; m < 5; ++m){
          for(int k = 0; k < 5; ++k){
            for(int l = 0; l < 5; ++l){
              for(int n = 0; n < 5; ++n){
                if(int_rep_1(g,n,i+k,j+l) > 0){
                  int_rep_2(g,m,i,j) += conv_2(m,k,l) * int_rep_1(g,n,i+k,j+l);
                }
              }
            }
          }
        }
      }
    }

    for(int i = 0; i < nrow_rep_2-6; ++i){
      for(int j = 0; j < nrow_rep_2-6; ++j){
        for(int m = 0; m < 5; ++m){
          for(int k = 0; k < 7; ++k){
            for(int l = 0; l < 7; ++l){
              for(int n = 0; n < 5; ++n){
                if(int_rep_2(g,n,i+k,j+l) > 0){
                  int_rep_3(g,m,i,j) += conv_3(m,k,l) * int_rep_2(g,n,i+k,j+l);
                }
              }
            }
          }
        }
      }
    }

    for(int h = 0; h < n_class; ++h){
      for(int i = 0; i < nrow_rep_2-6; ++i){
        for(int j = 0; j < nrow_rep_2-6; ++j){
          for(int m = 0; m < 5; ++m){
            if(int_rep_3(g,m,i,j) > 0){
              pred_pre(g,h) += dense_1(h,m,i,j) * int_rep_3(g,m,i,j);
            }
          }
        }
      }
    }

    for(int h = 0; h < n_class; ++h){
      enum_softm(g) += exp(pred_pre(g,h));
    }

    for(int h = 0; h < n_class; ++h){
      pred(g,h) += exp(pred_pre(g,h)) / enum_softm(g);
    }

    for(int h = 0; h < n_class; ++h){
      loss += -(true_class(g,h) * log(pred(g,h)) + (1-true_class(g,h)) * log(1-pred(g,h)));
    }
  }
  
  ADREPORT(int_rep_1);
  ADREPORT(int_rep_2);
  ADREPORT(int_rep_3);
  ADREPORT(pred);
  
  return loss;
}