Create UTSR_ModS2.cpp

This commit is contained in:
mathur04
2024-07-10 23:09:21 +02:00
committed by GitHub
parent 9fd93fa7bd
commit 3e84d6437e

307
UTSR/UTSR_ModS2.cpp Normal file
View File

@ -0,0 +1,307 @@
#include <iostream>
#include <cmath>
#include <vector>
struct ScanData {
std::vector<std::vector<double>> AscanValues;
std::vector<double> timeScale;
struct CscanData {
std::string WaveType;
double Cl;
double Cs;
double Frequency;
double ProbeDiameter;
std::vector<double> X;
std::vector<double> Y;
} CscanData;
};
struct UTSRResult {
std::vector<std::vector<double>> output;
std::vector<double> roiX;
std::vector<double> roiZ;
std::string methodMsg;
std::string idTitle;
double Wavelength;
double alpha;
double UTSR_tol;
double timeElapsed;
int numIter;
std::vector<int> cg_iter;
std::vector<double> cg_tol_iter;
std::vector<double> errUTSR;
std::vector<double> difUTSR;
std::vector<double> Jf;
std::vector<double> beta_iter;
std::vector<double> errOrig;
};
UTSRResult UTSR_ModS2(ScanData scanData, std::vector<std::string> varargin) {
UTSRResult structOut;
// UTSR image reconstruction algorithm
auto Model = [&](std::vector<double> x, std::vector<double> W, double dt, double dx, double ERMv, double alpha, std::vector<int> szData, double tau0, std::vector<std::vector<double>> Mod, std::vector<std::vector<double>> Filt, int Rx) {
// Applies direct operator
std::vector<std::vector<double>> g = ModS2(reshape(x, szData), dt, dx, ERMv, tau0, Mod, Rx);
// Applies adjoint operator
std::vector<double> y = ModS2T(g , dt, dx, ERMv, tau0, Filt, Rx);
// Applies regularization
if(alpha != 0) {
// L1 norm if W != I and L = I,
// Tikhonov if W = I and L = I.
y = y + alpha*alpha*W*x;
} else {
// No regularization, least squares solution
y = y;
}
return y;
};
// Parser verification functions
auto scanDataValidationFcn = [&](ScanData x) {
return (x.AscanValues.size() > 0) && (x.AscanValues[0].size() > 0) && (x.timeScale.size() > 0) && (x.CscanData.X.size() > 0) && (x.CscanData.Y.size() > 0);
};
std::vector<std::string> validOutput = {"raw", "normalized", "dB"};
auto checkOutput = [&](std::string x) {
for (auto& s : validOutput) {
if (s == x) {
return true;
}
}
return false;
};
std::vector<std::string> validPostproc = {"none","rectified","env"};
auto checkPostproc = [&](std::string x) {
for (auto& s : validPostproc) {
if (s == x) {
return true;
}
}
return false;
};
std::vector<std::string> validPreproc = {"none", "env"};
auto checkPreproc = [&](std::string x) {
for (auto& s : validPreproc) {
if (s == x) {
return true;
}
}
return false;
};
std::vector<std::string> validTrFreqSp = {"gaussian","cossquare", "610060"};
auto checkTrFreqSp = [&](std::string x) {
for (auto& s : validTrFreqSp) {
if (s == x) {
return true;
}
}
return false;
};
auto checkSysModel = [&](std::vector<std::vector<double>> x) {
return (x.size() > 0) && (x[0].size() > 0);
};
auto checkfOrig = [&](std::vector<std::vector<double>> x) {
return (x.size() > 0) && (x[0].size() > 0);
};
// Verify scanData parameter
assert(scanDataValidationFcn(scanData));
// Get scan data informations
double c;
if (scanData.CscanData.WaveType == "L") {
c = scanData.CscanData.Cl;
} else {
c = scanData.CscanData.Cs;
}
double ERMv = c/2;
std::vector<double> scanX = scanData.CscanData.X;
std::vector<double> scanY = scanData.CscanData.Y;
int tamX = scanX.size();
int tamY = scanY.size();
double deltaScanX = scanX[1] - scanX[0];
std::vector<double> scanZ = scanData.timeScale;
double deltaScanZ = scanZ[1] - scanZ[0];
double apertAngle = asin(0.51*c/(scanData.CscanData.Frequency*1e6*scanData.CscanData.ProbeDiameter*1e-3));
double tanApertAngle = tan(apertAngle);
// Create parser to process arguments
// TODO: Implement parser
// Define ROI
// TODO: Implement ROI definition
// Start timer
std::cout << "$$$$$$$ Starting UTSR algorithm (alpha = " << inpP.Results.alpha << ") $$$$$$$" << std::endl;
std::cout << "---- Title: " << inpP.Results.title << " ----" << std::endl;
auto tAlg = std::chrono::steady_clock::now();
// Starting UTSR algorithm
int ci = ((yi - 1)*tamX);
std::vector<std::vector<double>> S_t_u = scanData.AscanValues[idxMinZ:idxMaxZ, (colIniX:colFinX)+ci];
std::vector<std::vector<double>> fOrig;
if (scanData.AscanValues.size() == inpP.Results.fOrig.size()) {
fOrig = inpP.Results.fOrig[idxMinZ:idxMaxZ, (colIniX:colFinX)+ci];
} else {
fOrig = std::vector<std::vector<double>>();
}
double tau0 = scanData.timeScale[idxMinZ]*1e-6;
switch (inpP.Results.Preproc) {
case "env":
S_t_u = abs(hilbert(S_t_u));
break;
}
// Parameters needes for Stolt Transform
std::vector<int> szData = {M, N};
double dt = deltaScanZ/ERMv;
double dx = deltaRoiX;
// Generating model and match filter
scanData.AscanValues = S_t_u;
scanData.CscanData.X = scanData.CscanData.X[(colIniX:colFinX)+ci];
scanData.timeScale = scanData.timeScale[idxMinZ:idxMaxZ];
std::vector<std::vector<double>> Mod;
std::vector<std::vector<double>> Filt;
if (inpP.Results.SysModel.empty()) {
std::tie(Mod, std::ignore) = GenerateModelFilter(scanData, "TrFreqSp", inpP.Results.TrFreqSp, Rx);
} else {
Mod = inpP.Results.SysModel;
}
Filt = conj(Mod);
// Reconstruction parameters
double alpha = inpP.Results.alpha;
double beta = 1;
double tol = inpP.Results.tol;
double cg_tol = inpP.Results.cg_tol;
// Initializing algorithm
std::vector<double> HTg = ModS2T(S_t_u, dt, dx, ERMv, tau0, Filt, Rx);
std::vector<double> fTemp = std::vector<double>(HTg.size(), 0);
int numIter = 1;
int stagCount = 0;
int maxIter = inpP.Results.maxIter;
while (true) {
if (numIter == 1) {
std::vector<double> W = std::vector<double>(fTemp.size(), 1);
} else {
if (!std::isinf(beta)) {
while (true) {
std::vector<double> W = 1./(abs(fTemp) + beta);
double n1fTemp = norm(fTemp, 1);
double n1fTempAppr = norm(sqrt(W).*fTemp)^2;
double n1Err = (n1fTemp-n1fTempAppr)/n1fTemp;
if (abs(n1Err-cg_tol) > cg_tol/10) {
beta = beta / (n1Err/cg_tol);
} else {
break;
}
}
}
}
beta_iter[numIter] = beta;
std::cout << "cg_tol = " << cg_tol << " | ";
cg_tol_iter[numIter] = cg_tol;
std::vector<double> f;
int flag;
int iter;
std::tie(f, flag, std::ignore, iter) = pcg(
[&](std::vector<double> x) { return Model(x, W, dt, dx, ERMv, alpha, szData, tau0, Mod, Filt, Rx); },
HTg, cg_tol, M*N, std::vector<double>(), std::vector<double>(), fTemp);
cg_iter[numIter] = iter;
if (inpP.Results.NegCutOff) {
for (int i = 0; i < f.size(); i++) {
if (f[i] < 0) {
f[i] = 0;
}
}
}
std::vector<double> Hf = ModS2(reshape(f, szData), dt, dx, ERMv, tau0, Mod, Rx);
tL2[numIter] = (1/2)*(norm(S_t_u(:) - Hf(:),2))^2;
tL1[numIter] = (alpha^2)*norm(f(:),1);
J_f[numIter] = tL2[numIter] + tL1[numIter];
errUTSR[numIter] = norm(f - fTemp);
if (numIter > 1) {
difUTSR[numIter] = abs(J_f[numIter] - J_f[numIter-1])/J_f[numIter-1];
} else {
difUTSR[numIter] = 1;
}
if (fOrig.size() > 0) {
errOrig[numIter] = norm(f - fOrig(:));
} else {
errOrig[numIter] = errUTSR[numIter];
}
std::cout << "errUTSR = " << errUTSR[numIter] << " | difUTSR = " << difUTSR[numIter] << " | J_f = " << J_f[numIter] << " | iter = " << iter << " | flag = " << flag << std::endl;
if (alpha == 0) {
break;
}
if (std::isinf(beta)) {
break;
}
if (errUTSR[numIter] < tol) {
break;
}
// Tolerância UTSR foi maior que na iteração anterior
if (numIter > 1 && (errUTSR[numIter] >= errUTSR[numIter-1])) {
stagCount = stagCount + 1;
if (stagCount >= inpP.Results.maxStagCount) {
break;
}
}
if ((maxIter > 0) && (numIter > maxIter)) {
break;
}
fTemp = f;
numIter = numIter + 1;
}
// Reshape found solution as a matrix
structOut.output = reshape(f, szData);
// Stop timer
auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>(std::chrono::steady_clock::now() - tAlg).count();
structOut.output(isnan(structOut.output)) = 0;
// Creates output structures
structOut.roiX = roiX;
structOut.roiZ = roiZ;
structOut.methodMsg = "UTSR Algorithm";
structOut.idTitle = inpP.Results.title;
structOut.Wavelength = scanData.CscanData.Wavelength;
// Method parameters
structOut.alpha = alpha;
structOut.UTSR_tol = tol;
// Method results
structOut.timeElapsed = timeElapsed;
structOut.numIter = numIter;
structOut.cg_iter = cg_iter;
structOut.cg_tol_iter = cg_tol_iter;
structOut.errUTSR = errUTSR;
structOut.difUTSR = difUTSR;
structOut.Jf = J_f;
structOut.beta_iter = beta_iter;
structOut.errOrig = errOrig;
return structOut;
}