package org.simantics.scl.compiler.module.coverage;

import java.util.Collection;
import java.util.Map;

import org.simantics.scl.compiler.module.Module;
import org.simantics.scl.runtime.profiling.BranchPoint;

import gnu.trove.map.hash.THashMap;

/**
 * Static helpers for computing, combining and resetting code-coverage
 * statistics of modules from their {@link BranchPoint} trees.
 */
public class CoverageUtils {

    /**
     * Computes the coverage of a single module from its branch points.
     *
     * <p>For every function, the total code size is the sum of
     * {@link BranchPoint#getCodeSize()} over its top-level branch points. The
     * covered code size is that total minus the size of every branch point
     * subtree whose root was never visited. A function is counted as covered
     * when its covered code size is greater than zero.
     *
     * @param moduleName   name stored in the returned {@link ModuleCoverage}
     * @param branchPoints map from function name to that function's top-level
     *                     branch points; a {@code null} array is treated as an
     *                     empty one (the same way {@link #resetCoverage(Module)}
     *                     tolerates {@code null} arrays)
     * @return per-function and aggregated coverage statistics for the module
     */
    public static ModuleCoverage getCoverage(String moduleName, THashMap<String, BranchPoint[]> branchPoints) {
        THashMap<String, FunctionCoverage> methodCoverages = new THashMap<String, FunctionCoverage>();
        int totalCodeSize = 0;
        int coveredCodeSize = 0;
        int totalFunctionCount = 0;
        int coveredFunctionCount = 0;
        for(Map.Entry<String, BranchPoint[]> entry : branchPoints.entrySet()) {
            int totalFunctionCodeSize = 0;
            int uncoveredFunctionCodeSize = 0;
            BranchPoint[] points = entry.getValue();
            // Guard against functions without a branch point array, mirroring resetCoverage(Module).
            if(points != null) {
                for(BranchPoint branchPoint : points) {
                    totalFunctionCodeSize += branchPoint.getCodeSize();
                    uncoveredFunctionCodeSize += uncoveredCodeSize(branchPoint);
                }
            }
            int coveredFunctionCodeSize = totalFunctionCodeSize - uncoveredFunctionCodeSize;
            String functionName = entry.getKey();
            methodCoverages.put(functionName,
                    new FunctionCoverage(functionName, totalFunctionCodeSize, coveredFunctionCodeSize));
            totalCodeSize += totalFunctionCodeSize;
            coveredCodeSize += coveredFunctionCodeSize;
            ++totalFunctionCount;
            // A function is "covered" as soon as any part of its code was executed.
            if(coveredFunctionCodeSize > 0)
                ++coveredFunctionCount;
        }
        
        return new ModuleCoverage(moduleName, methodCoverages,
                totalCodeSize, coveredCodeSize,
                totalFunctionCount, coveredFunctionCount);
    }
    
    /**
     * Computes the coverage of the given module.
     *
     * @param module the module to inspect
     * @return the module's coverage, or {@code null} if
     *         {@code module.getBranchPoints()} returns {@code null}
     *         (presumably a module compiled without coverage instrumentation
     *         — NOTE(review): confirm)
     */
    public static ModuleCoverage getCoverage(Module module) {
        THashMap<String, BranchPoint[]> branchPoints = module.getBranchPoints();
        if(branchPoints == null)
            return null;
        else
            return getCoverage(module.getName(), branchPoints);
    }
    
    /**
     * Aggregates several module coverages into one combined result by summing
     * their total/covered code sizes and total/covered function counts.
     *
     * @param moduleCoverages map from module name to that module's coverage;
     *                        the map itself is passed on to the result
     * @return the combined coverage over all given modules
     */
    public static CombinedCoverage combineCoverages(THashMap<String,ModuleCoverage> moduleCoverages) {
        int totalCodeSize = 0;
        int coveredCodeSize = 0;
        int totalFunctionCount = 0;
        int coveredFunctionCount = 0;
        for(ModuleCoverage mCov : moduleCoverages.values()) {
            totalCodeSize += mCov.getTotalCodeSize();
            coveredCodeSize += mCov.getCoveredCodeSize();
            totalFunctionCount += mCov.totalFunctionCount;
            coveredFunctionCount += mCov.coveredFunctionCount;
        }
        return new CombinedCoverage(moduleCoverages,
                totalCodeSize, coveredCodeSize,
                totalFunctionCount, coveredFunctionCount);
    }
    
    /**
     * Computes the combined coverage of a collection of modules. Modules for
     * which {@link #getCoverage(Module)} returns {@code null} are skipped.
     *
     * @param modules the modules to inspect
     * @return the combined coverage, keyed internally by module name
     */
    public static CombinedCoverage getCoverage(Collection<Module> modules) {
        THashMap<String,ModuleCoverage> moduleCoverages =
                new THashMap<String,ModuleCoverage>();
        for(Module module : modules) {
            ModuleCoverage coverage = getCoverage(module);
            if(coverage != null)
                moduleCoverages.put(module.getName(), coverage);
        }
        return combineCoverages(moduleCoverages);
    }

    /**
     * Resets the visit counters of every module in the collection.
     *
     * @param modules the modules whose coverage data is reset
     */
    public static void resetCoverage(Collection<Module> modules) {
        modules.forEach(CoverageUtils::resetCoverage);
    }

    /**
     * Resets the visit counters of all branch points of a module, recursively.
     * Does nothing if the module has no branch points; {@code null} branch
     * point arrays are skipped.
     *
     * @param module the module whose coverage data is reset
     */
    public static void resetCoverage(Module module) {
        THashMap<String, BranchPoint[]> branches = module.getBranchPoints();
        if (branches != null) {
            for (BranchPoint[] points : branches.values()) {
                if (points != null) {
                    for (BranchPoint point : points) {
                        point.resetVisitCountersRecursively();
                    }
                }
            }
        }
    }
    
    /**
     * Divides {@code a} by {@code b} in floating point.
     *
     * @return {@code a / b}, or {@code 1.0} when {@code b == 0} (so an empty
     *         denominator reads as a full ratio, presumably "fully covered")
     */
    static double safeDiv(int a, int b) {
        if(b == 0)
            return 1.0;
        else
            return ((double)a) / b;
    }

    /**
     * Returns the amount of code under {@code branchPoint} that was never
     * executed. If the branch point itself was never visited, its entire code
     * size is uncovered; otherwise the uncovered sizes of its children are
     * summed recursively.
     */
    private static int uncoveredCodeSize(BranchPoint branchPoint) {
        if(branchPoint.getVisitCounter() == 0)
            return branchPoint.getCodeSize();
        else {
            int sum = 0;
            for(BranchPoint child : branchPoint.getChildren())
                sum += uncoveredCodeSize(child);
            return sum;
        }
    }
}
