Building a CNN from scratch in C++
#include <TMB.hpp>
template<class Type>
Type objective_function<Type>::operator() ()
{
DATA_INTEGER(n_imgs);
DATA_SCALAR(nrow_img);
DATA_SCALAR(nrow_rep_1);
DATA_SCALAR(nrow_rep_2);
DATA_SCALAR(nrow_rep_3);
DATA_SCALAR(n_class);
DATA_MATRIX(true_class);
DATA_ARRAY(inp_img);
PARAMETER_ARRAY(conv_1);
PARAMETER_ARRAY(conv_2);
PARAMETER_ARRAY(conv_3);
PARAMETER_ARRAY(dense_1);
vector<Type> enum_softm(n_imgs);
matrix<Type> pred_pre(n_imgs,2);
matrix<Type> pred(n_imgs,2);
Type loss = 0;
array<Type> int_rep_1(n_imgs,5,27-2,27-2);
array<Type> int_rep_2(n_imgs,5,27-2-4,27-2-4);
array<Type> int_rep_3(n_imgs,5,27-2-4-6,27-2-4-6);
for(int g = 0; g < 12; ++g){
for(int i = 0; i < nrow_img-2; ++i){
for(int j = 0; j < nrow_img-2; ++j){
for(int m = 0; m < 5; ++m){
for(int k = 0; k < 3; ++k){
for(int l = 0; l < 3; ++l){
int_rep_1(g,m,i,j) += conv_1(m,k,l) * inp_img(g,i+k,j+l);
}
}
}
}
}
for(int i = 0; i < nrow_rep_1-4; ++i){
for(int j = 0; j < nrow_rep_1-4; ++j){
for(int m = 0; m < 5; ++m){
for(int k = 0; k < 5; ++k){
for(int l = 0; l < 5; ++l){
for(int n = 0; n < 5; ++n){
if(int_rep_1(g,n,i+k,j+l) > 0){
int_rep_2(g,m,i,j) += conv_2(m,k,l) * int_rep_1(g,n,i+k,j+l);
}
}
}
}
}
}
}
for(int i = 0; i < nrow_rep_2-6; ++i){
for(int j = 0; j < nrow_rep_2-6; ++j){
for(int m = 0; m < 5; ++m){
for(int k = 0; k < 7; ++k){
for(int l = 0; l < 7; ++l){
for(int n = 0; n < 5; ++n){
if(int_rep_2(g,n,i+k,j+l) > 0){
int_rep_3(g,m,i,j) += conv_3(m,k,l) * int_rep_2(g,n,i+k,j+l);
}
}
}
}
}
}
}
for(int h = 0; h < n_class; ++h){
for(int i = 0; i < nrow_rep_2-6; ++i){
for(int j = 0; j < nrow_rep_2-6; ++j){
for(int m = 0; m < 5; ++m){
if(int_rep_3(g,m,i,j) > 0){
pred_pre(g,h) += dense_1(h,m,i,j) * int_rep_3(g,m,i,j);
}
}
}
}
}
for(int h = 0; h < n_class; ++h){
enum_softm(g) += exp(pred_pre(g,h));
}
for(int h = 0; h < n_class; ++h){
pred(g,h) += exp(pred_pre(g,h)) / enum_softm(g);
}
for(int h = 0; h < n_class; ++h){
loss += -(true_class(g,h) * log(pred(g,h)) + (1-true_class(g,h)) * log(1-pred(g,h)));
}
}
ADREPORT(int_rep_1);
ADREPORT(int_rep_2);
ADREPORT(int_rep_3);
ADREPORT(pred);
return loss;
}