diff --git a/UTSR/UTSR_ModS2.cpp b/UTSR/UTSR_ModS2.cpp new file mode 100644 index 0000000..6576043 --- /dev/null +++ b/UTSR/UTSR_ModS2.cpp @@ -0,0 +1,307 @@ +#include +#include +#include + +struct ScanData { + std::vector> AscanValues; + std::vector timeScale; + struct CscanData { + std::string WaveType; + double Cl; + double Cs; + double Frequency; + double ProbeDiameter; + std::vector X; + std::vector Y; + } CscanData; +}; + +struct UTSRResult { + std::vector> output; + std::vector roiX; + std::vector roiZ; + std::string methodMsg; + std::string idTitle; + double Wavelength; + double alpha; + double UTSR_tol; + double timeElapsed; + int numIter; + std::vector cg_iter; + std::vector cg_tol_iter; + std::vector errUTSR; + std::vector difUTSR; + std::vector Jf; + std::vector beta_iter; + std::vector errOrig; +}; + +UTSRResult UTSR_ModS2(ScanData scanData, std::vector varargin) { + UTSRResult structOut; + // UTSR image reconstruction algorithm + auto Model = [&](std::vector x, std::vector W, double dt, double dx, double ERMv, double alpha, std::vector szData, double tau0, std::vector> Mod, std::vector> Filt, int Rx) { + // Applies direct operator + std::vector> g = ModS2(reshape(x, szData), dt, dx, ERMv, tau0, Mod, Rx); + + // Applies adjoint operator + std::vector y = ModS2T(g , dt, dx, ERMv, tau0, Filt, Rx); + + // Applies regularization + if(alpha != 0) { + // L1 norm if W != I and L = I, + // Tikhonov if W = I and L = I. + y = y + alpha*alpha*W*x; + } else { + // No regularization, least squares solution + y = y; + } + return y; + }; + + // Parser verification functions + auto scanDataValidationFcn = [&](ScanData x) { + return (x.AscanValues.size() > 0) && (x.AscanValues[0].size() > 0) && (x.timeScale.size() > 0) && (x.CscanData.X.size() > 0) && (x.CscanData.Y.size() > 0); + }; + std::vector validOutput = {"raw", "normalized", "dB"}; + auto checkOutput = [&](std::string x) { + for (auto& s : validOutput) { + if (s == x) { + return true; + } + } + return false; + }; + std::vector validPostproc = {"none","rectified","env"}; + auto checkPostproc = [&](std::string x) { + for (auto& s : validPostproc) { + if (s == x) { + return true; + } + } + return false; + }; + std::vector validPreproc = {"none", "env"}; + auto checkPreproc = [&](std::string x) { + for (auto& s : validPreproc) { + if (s == x) { + return true; + } + } + return false; + }; + std::vector validTrFreqSp = {"gaussian","cossquare", "610060"}; + auto checkTrFreqSp = [&](std::string x) { + for (auto& s : validTrFreqSp) { + if (s == x) { + return true; + } + } + return false; + }; + auto checkSysModel = [&](std::vector> x) { + return (x.size() > 0) && (x[0].size() > 0); + }; + auto checkfOrig = [&](std::vector> x) { + return (x.size() > 0) && (x[0].size() > 0); + }; + + // Verify scanData parameter + assert(scanDataValidationFcn(scanData)); + + // Get scan data informations + double c; + if (scanData.CscanData.WaveType == "L") { + c = scanData.CscanData.Cl; + } else { + c = scanData.CscanData.Cs; + } + double ERMv = c/2; + + std::vector scanX = scanData.CscanData.X; + std::vector scanY = scanData.CscanData.Y; + int tamX = scanX.size(); + int tamY = scanY.size(); + double deltaScanX = scanX[1] - scanX[0]; + + std::vector scanZ = scanData.timeScale; + double deltaScanZ = scanZ[1] - scanZ[0]; + + double apertAngle = asin(0.51*c/(scanData.CscanData.Frequency*1e6*scanData.CscanData.ProbeDiameter*1e-3)); + double tanApertAngle = tan(apertAngle); + + // Create parser to process arguments + // TODO: Implement parser + + // Define ROI + // TODO: Implement ROI definition + + // Start timer + std::cout << "$$$$$$$ Starting UTSR algorithm (alpha = " << inpP.Results.alpha << ") $$$$$$$" << std::endl; + std::cout << "---- Title: " << inpP.Results.title << " ----" << std::endl; + auto tAlg = std::chrono::steady_clock::now(); + + // Starting UTSR algorithm + int ci = ((yi - 1)*tamX); + std::vector> S_t_u = scanData.AscanValues[idxMinZ:idxMaxZ, (colIniX:colFinX)+ci]; + std::vector> fOrig; + if (scanData.AscanValues.size() == inpP.Results.fOrig.size()) { + fOrig = inpP.Results.fOrig[idxMinZ:idxMaxZ, (colIniX:colFinX)+ci]; + } else { + fOrig = std::vector>(); + } + double tau0 = scanData.timeScale[idxMinZ]*1e-6; + switch (inpP.Results.Preproc) { + case "env": + S_t_u = abs(hilbert(S_t_u)); + break; + } + + // Parameters needes for Stolt Transform + std::vector szData = {M, N}; + double dt = deltaScanZ/ERMv; + double dx = deltaRoiX; + + // Generating model and match filter + scanData.AscanValues = S_t_u; + scanData.CscanData.X = scanData.CscanData.X[(colIniX:colFinX)+ci]; + scanData.timeScale = scanData.timeScale[idxMinZ:idxMaxZ]; + std::vector> Mod; + std::vector> Filt; + if (inpP.Results.SysModel.empty()) { + std::tie(Mod, std::ignore) = GenerateModelFilter(scanData, "TrFreqSp", inpP.Results.TrFreqSp, Rx); + } else { + Mod = inpP.Results.SysModel; + } + Filt = conj(Mod); + + // Reconstruction parameters + double alpha = inpP.Results.alpha; + double beta = 1; + double tol = inpP.Results.tol; + double cg_tol = inpP.Results.cg_tol; + + // Initializing algorithm + std::vector HTg = ModS2T(S_t_u, dt, dx, ERMv, tau0, Filt, Rx); + std::vector fTemp = std::vector(HTg.size(), 0); + int numIter = 1; + int stagCount = 0; + int maxIter = inpP.Results.maxIter; + + while (true) { + if (numIter == 1) { + std::vector W = std::vector(fTemp.size(), 1); + } else { + if (!std::isinf(beta)) { + while (true) { + std::vector W = 1./(abs(fTemp) + beta); + double n1fTemp = norm(fTemp, 1); + double n1fTempAppr = norm(sqrt(W).*fTemp)^2; + double n1Err = (n1fTemp-n1fTempAppr)/n1fTemp; + if (abs(n1Err-cg_tol) > cg_tol/10) { + beta = beta / (n1Err/cg_tol); + } else { + break; + } + } + } + } + + beta_iter[numIter] = beta; + std::cout << "cg_tol = " << cg_tol << " | "; + cg_tol_iter[numIter] = cg_tol; + std::vector f; + int flag; + int iter; + std::tie(f, flag, std::ignore, iter) = pcg( + [&](std::vector x) { return Model(x, W, dt, dx, ERMv, alpha, szData, tau0, Mod, Filt, Rx); }, + HTg, cg_tol, M*N, std::vector(), std::vector(), fTemp); + cg_iter[numIter] = iter; + + if (inpP.Results.NegCutOff) { + for (int i = 0; i < f.size(); i++) { + if (f[i] < 0) { + f[i] = 0; + } + } + } + + std::vector Hf = ModS2(reshape(f, szData), dt, dx, ERMv, tau0, Mod, Rx); + tL2[numIter] = (1/2)*(norm(S_t_u(:) - Hf(:),2))^2; + tL1[numIter] = (alpha^2)*norm(f(:),1); + J_f[numIter] = tL2[numIter] + tL1[numIter]; + + errUTSR[numIter] = norm(f - fTemp); + + if (numIter > 1) { + difUTSR[numIter] = abs(J_f[numIter] - J_f[numIter-1])/J_f[numIter-1]; + } else { + difUTSR[numIter] = 1; + } + + if (fOrig.size() > 0) { + errOrig[numIter] = norm(f - fOrig(:)); + } else { + errOrig[numIter] = errUTSR[numIter]; + } + + std::cout << "errUTSR = " << errUTSR[numIter] << " | difUTSR = " << difUTSR[numIter] << " | J_f = " << J_f[numIter] << " | iter = " << iter << " | flag = " << flag << std::endl; + + if (alpha == 0) { + break; + } + + if (std::isinf(beta)) { + break; + } + + if (errUTSR[numIter] < tol) { + break; + } + + // Tolerância UTSR foi maior que na iteração anterior + if (numIter > 1 && (errUTSR[numIter] >= errUTSR[numIter-1])) { + stagCount = stagCount + 1; + if (stagCount >= inpP.Results.maxStagCount) { + break; + } + } + + if ((maxIter > 0) && (numIter > maxIter)) { + break; + } + + fTemp = f; + numIter = numIter + 1; + } + + // Reshape found solution as a matrix + structOut.output = reshape(f, szData); + + // Stop timer + auto timeElapsed = std::chrono::duration_cast(std::chrono::steady_clock::now() - tAlg).count(); + structOut.output(isnan(structOut.output)) = 0; + + // Creates output structures + structOut.roiX = roiX; + structOut.roiZ = roiZ; + structOut.methodMsg = "UTSR Algorithm"; + structOut.idTitle = inpP.Results.title; + structOut.Wavelength = scanData.CscanData.Wavelength; + // Method parameters + structOut.alpha = alpha; + structOut.UTSR_tol = tol; + // Method results + structOut.timeElapsed = timeElapsed; + structOut.numIter = numIter; + structOut.cg_iter = cg_iter; + structOut.cg_tol_iter = cg_tol_iter; + structOut.errUTSR = errUTSR; + structOut.difUTSR = difUTSR; + structOut.Jf = J_f; + structOut.beta_iter = beta_iter; + structOut.errOrig = errOrig; + + return structOut; +} + +