package org.simantics.spreadsheet.solver.formula;

import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression;
import org.simantics.spreadsheet.SpreadsheetMatrix;
import org.simantics.spreadsheet.Spreadsheets;
import org.simantics.spreadsheet.solver.formula.parser.ast.AstArgList;

/**
 * Spreadsheet {@code LINEST} function: fits a multiple linear regression
 * {@code y = m1*x1 + ... + mk*xk + b} by ordinary least squares and returns the
 * Excel-style 5-row statistics array.
 *
 * <p>Only the multi-variable case is handled here: the known x values must be a
 * matrix with more than one column (one column per regressor, one row per
 * observation). In every other case {@code null} is returned. The third and
 * fourth arguments ({@code const} and {@code stats}) are not evaluated; an
 * intercept is always fitted and the full statistics array is always produced.
 */
public class LinestFormulaFunction implements CellFormulaFunction<SpreadsheetMatrix> {

    /** Number of rows in the LINEST statistics array. */
    private static final int RESULT_ROWS = 5;

    /**
     * Evaluates {@code LINEST(known_ys, known_xs, const, stats)}.
     *
     * @param visitor visitor used to evaluate the argument expressions
     * @param args    the argument list; exactly four arguments are required
     * @return a matrix with {@code k + 1} columns and 5 rows, where {@code k} is the
     *         number of x columns, laid out as:
     *         <ol>
     *         <li>coefficients {@code m_k ... m_1, b}</li>
     *         <li>standard errors of those coefficients</li>
     *         <li>{@code r^2}, standard error of the y estimate, #N/A...</li>
     *         <li>F statistic, residual degrees of freedom, #N/A...</li>
     *         <li>regression sum of squares, residual sum of squares, #N/A...</li>
     *         </ol>
     *         or {@code null} if either data argument is not a matrix or the x
     *         matrix has only one column
     * @throws IllegalStateException if the argument count is not four
     */
    @Override
    public SpreadsheetMatrix evaluate(CellValueVisitor visitor, AstArgList args) {
        if (args.values.size() != 4)
            throw new IllegalStateException("LINEST expects 4 arguments, got " + args.values.size());

        Object ys = args.values.get(0).accept(visitor);
        Object xs = args.values.get(1).accept(visitor);

        // Only matrix inputs are supported; anything else yields no result
        // (previously a non-matrix y argument caused a ClassCastException).
        if (!(xs instanceof SpreadsheetMatrix) || !(ys instanceof SpreadsheetMatrix))
            return null;

        SpreadsheetMatrix xsm = (SpreadsheetMatrix) xs;
        SpreadsheetMatrix ysm = (SpreadsheetMatrix) ys;

        int variables = xsm.getWidth();
        if (variables <= 1)
            return null;

        // One sample per y value; row i of the x matrix holds that sample's regressors.
        int observations = ysm.values.length;
        double[] y = new double[observations];
        double[][] x = new double[observations][variables];
        for (int i = 0; i < observations; i++) {
            y[i] = Spreadsheets.asNumber(ysm.values[i]);
            for (int j = 0; j < variables; j++) {
                x[i][j] = Spreadsheets.asNumber(xsm.get(i, j));
            }
        }

        OLSMultipleLinearRegression reg = new OLSMultipleLinearRegression();
        reg.newSampleData(y, x);

        // One coefficient per x column plus the intercept.
        int width = variables + 1;
        SpreadsheetMatrix result = new SpreadsheetMatrix(width, RESULT_ROWS);

        // Commons Math orders parameters as [b, m1, ..., mk]; LINEST lists them
        // reversed as [mk, ..., m1, b], hence the mirrored column index.
        double[] pars = reg.estimateRegressionParameters();
        double[] errs = reg.estimateRegressionParametersStandardErrors();
        for (int i = 0; i < width; i++) {
            result.set(0, width - i - 1, pars[i]);
            result.set(1, width - i - 1, errs[i]);
        }

        double ssTotal = reg.calculateTotalSumOfSquares();
        double ssResid = reg.calculateResidualSumOfSquares();
        double ssReg = ssTotal - ssResid;
        double r2 = ssReg / ssTotal;
        double sey = reg.estimateRegressionStandardError();
        // Residual degrees of freedom: observations minus fitted parameters
        // (k slopes + intercept). This matches the denominator Commons Math uses
        // when computing the regression standard error above.
        double df = observations - width;
        // F = (SSreg / k) / (SSresid / df).
        double f = (ssReg / variables) / (ssResid / df);

        result.set(2, 0, r2);
        result.set(2, 1, sey);
        result.set(3, 0, f);
        result.set(3, 1, df);
        result.set(4, 0, ssReg);
        result.set(4, 1, ssResid);

        // Rows 3-5 only use the first two columns; every remaining cell is #N/A.
        String na = FormulaError2.NA.getString();
        for (int row = 2; row < RESULT_ROWS; row++) {
            for (int col = 2; col < width; col++) {
                result.set(row, col, na);
            }
        }

        return result;
    }

}
