package org.simantics.scl.runtime.generation;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.net.URISyntaxException;
import java.net.URL;

/**
 * Source generator for the curried function runtime in {@link #PACKAGE}.
 *
 * <p>{@link #main(String[])} writes {@code Function1..Function<MAX_ARITY>}, {@code FunctionN},
 * {@code Function}, {@code FunctionImpl1..FunctionImpl<MAX_ARITY>}, {@code FunctionImplN},
 * {@code UnsaturatedFunction1..UnsaturatedFunction<MAX_ARITY>} and {@code UnsaturatedFunctionN}
 * into the package directory below the nearest {@code src} folder. Every {@code generate*}
 * method emits exactly one compilation unit to the given stream.
 */
public class GenerateFunctions {

    /** Package of the generated sources. */
    public static final String PACKAGE = "org.simantics.scl.runtime.function";

    /** Largest arity that gets its own {@code FunctionK} interface and {@code apply} overload. */
    public static final int MAX_ARITY = 8;

    /**
     * Comment placed at the top of every generated file. Its lines are separated by CRLF;
     * the generators re-emit them one by one with the stream's own line separator so that
     * generated files do not mix line endings.
     */
    public static final String HEADER =
            "/**\r\n"
          + " * This code is generated in " + GenerateFunctions.class.getName() + ".\r\n"
          + " * Do not edit manually!\r\n"
          + " */"
            ;

    /**
     * Prints the generated-code header, the package declaration and a blank line,
     * optionally followed by {@code import java.util.Arrays;} and another blank line.
     */
    private static void printPreamble(PrintStream p, boolean importArrays) {
        // println per line instead of println(HEADER): HEADER embeds CRLF, which would
        // otherwise be mixed with the platform separator used for the rest of the file.
        for(String line : HEADER.split("\r\n"))
            p.println(line);
        p.println("package " + PACKAGE + ";");
        p.println();
        if(importArrays) {
            p.println("import java.util.Arrays;");
            p.println();
        }
    }

    /**
     * Prints {@code template} once for every index {@code i} in {@code [from, to)}, with each
     * {@code '#'} replaced by {@code i}, separating consecutive items with {@code ", "}.
     * Prints nothing when the range is empty.
     */
    private static void printList(PrintStream p, String template, int from, int to) {
        for(int i=from;i<to;++i) {
            if(i > from)
                p.print(", ");
            p.print(template.replace("#", Integer.toString(i)));
        }
    }

    /**
     * Generates interface {@code Function<n>} with type parameters {@code P0..P<n-1>, R}
     * and the single method {@code R apply(P0 p0, ..., P<n-1> p<n-1>)}.
     *
     * @param p stream receiving the generated source
     * @param n number of parameters of {@code apply}
     */
    public static void generateFunctionN(PrintStream p, int n) {
        printPreamble(p, false);
        p.print("public interface Function"+n+"<");
        for(int i=0;i<n;++i)
            p.print("P" + i + ",");
        p.println("R> {");
        p.print("    R apply(");
        printList(p, "P# p#", 0, n);
        p.println(");");
        p.println("}");
    }

    /**
     * Generates interface {@code FunctionN} declaring {@code Object applyArray(Object ... ps)}.
     *
     * @param p stream receiving the generated source
     */
    public static void generateFunctionN(PrintStream p) {
        printPreamble(p, false);
        p.println("public interface FunctionN {");
        p.println("    Object applyArray(Object ... ps);");
        p.println("}");
    }

    /**
     * Generates interface {@code Function}, which extends {@code Function1..Function<MAX_ARITY>}
     * and {@code FunctionN}. It shares the parameter types {@code P0..P<MAX_ARITY-1>} between
     * all of them and has one result type {@code R<k>} per arity {@code k}.
     *
     * @param p stream receiving the generated source
     */
    public static void generateFunction(PrintStream p) {
        printPreamble(p, false);
        p.print("public interface Function<");
        for(int k=0;k<MAX_ARITY;++k)
            p.print("P"+k+",");
        for(int k=1;k<=MAX_ARITY;++k) {
            p.print("R"+k);
            if(k < MAX_ARITY)
                p.print(",");
        }
        p.println("> extends");
        for(int k=1;k<=MAX_ARITY;++k) {
            p.print("    Function" + k + "<");
            for(int i=0;i<k;++i)
                p.print("P"+i+",");
            p.println("R"+k+">,");
        }
        p.println("    FunctionN {");
        p.println("}");
    }

    /**
     * Generates abstract class {@code FunctionImpl<n>}, the base for functions of arity
     * {@code n}. Subclasses implement {@code R apply(P0 p0, ..., P<n-1> p<n-1>)}.
     * <ul>
     * <li>The generated {@code apply} overloads return an {@code UnsaturatedFunction<k>}
     * when called with {@code k < n} arguments.</li>
     * <li>When called with more than {@code n} arguments, they apply the first {@code n},
     * cast the result to {@code Function} and apply it to the remaining arguments.</li>
     * <li>{@code applyArray} dispatches the same way on the array length.</li>
     * </ul>
     *
     * @param p stream receiving the generated source
     * @param n arity of the generated class, in {@code [1, MAX_ARITY]}
     * @throws IllegalArgumentException if {@code n} is outside {@code [1, MAX_ARITY]}; the
     *         generated class would not compile for such an arity
     */
    public static void generateFunctionImplN(PrintStream p, int n) {
        if(n < 1 || n > MAX_ARITY)
            throw new IllegalArgumentException(
                    "Arity must be between 1 and " + MAX_ARITY + ", got " + n);
        printPreamble(p, true);
        p.println("@SuppressWarnings(\"all\")");
        p.print("public abstract class FunctionImpl" + n + "<");
        for(int i=0;i<n;++i)
            p.print("P" + i + ",");
        p.print("R> implements Function<");
        // Parameter types: our own P0..P<n-1>, Object for the rest.
        for(int i=0;i<n;++i)
            p.print("P" + i + ",");
        for(int i=n;i<MAX_ARITY;++i)
            p.print("Object,");
        // Result types: R for arity n, Object for every other arity.
        for(int i=1;i<=MAX_ARITY;++i) {
            p.print(i == n ? "R" : "Object");
            if(i < MAX_ARITY)
                p.print(",");
        }
        p.println("> {");

        // Fewer arguments than the arity: capture them in a partial application.
        for(int k=1;k<n;++k) {
            p.println("    @Override");
            p.print("    public Object apply(");
            printList(p, "Object p#", 0, k);
            p.println(") {");
            p.print("        return new UnsaturatedFunction" + k + "(this, ");
            printList(p, "p#", 0, k);
            p.println(");");
            p.println("    }");
            p.println();
        }

        // Exactly the arity: the method subclasses implement.
        p.println("    @Override");
        p.print("    public abstract R apply(");
        printList(p, "P# p#", 0, n);
        p.println(");");
        p.println();

        // More arguments than the arity: apply the first n, then apply the result to the rest.
        for(int k=n+1;k<=MAX_ARITY;++k) {
            p.println("    @Override");
            p.print("    public Object apply(");
            printList(p, "Object p#", 0, k);
            p.println(") {");
            p.println("        try {");
            p.print("            return ((Function)apply(");
            printList(p, "(P#)p#", 0, n);
            p.print(")).apply(");
            printList(p, "p#", n, k);
            p.println(");");
            p.println("        } catch(ClassCastException e) {");
            p.println("            throw new CalledWithTooManyParameters();");
            p.println("        }");
            p.println("    }");
            p.println();
        }

        p.println("    @Override");
        p.println("    public Object applyArray(Object ... ps) {");
        p.println("        switch(ps.length) {");
        for(int k=0;k<=MAX_ARITY+n;++k) {
            p.println("        case " + k + ":");
            if(k == 0)
                p.println("            return this;");
            else if(k < n) {
                p.print("            return new UnsaturatedFunction" + k + "(this, ");
                printList(p, "ps[#]", 0, k);
                p.println(");");
            }
            else if(k == n) {
                p.print("            return apply(");
                printList(p, "(P#)ps[#]", 0, k);
                p.println(");");
            }
            else {
                p.println("            try {");
                p.print("                return ((Function)apply(");
                printList(p, "(P#)ps[#]", 0, n);
                p.print(")).apply(");
                printList(p, "ps[#]", n, k);
                p.println(");");
                p.println("            } catch(ClassCastException e) {");
                p.println("                throw new CalledWithTooManyParameters();");
                p.println("            }");
            }
        }
        p.println("        default:");
        p.println("            try {");
        p.print("                return ((Function)apply(");
        printList(p, "(P#)ps[#]", 0, n);
        // applyArray, not apply: Function declares no varargs apply, so apply(Object[])
        // would bind to apply(Object) and pass the whole array as one argument.
        p.println(")).applyArray(Arrays.copyOfRange(ps, "+n+", ps.length));");
        p.println("            } catch(ClassCastException e) {");
        p.println("                throw new CalledWithTooManyParameters();");
        p.println("            }");
        p.println("        }");
        p.println("    }");
        p.println("}");
    }

    /**
     * Generates abstract class {@code FunctionImplN}, the base for functions whose arity is
     * passed to the constructor. Subclasses implement {@code Object doApply(Object ... ps)}.
     * The generated {@code apply} overloads and {@code applyArray} call it with exactly
     * {@code arity} arguments. Fewer arguments yield a partial application; extra arguments
     * are applied to the result.
     *
     * @param p stream receiving the generated source
     */
    public static void generateFunctionImplN(PrintStream p) {
        printPreamble(p, true);
        p.println("@SuppressWarnings(\"all\")");
        p.println("public abstract class FunctionImplN implements Function {");
        p.println("    int arity;");
        p.println();
        p.println("    public FunctionImplN(int arity) {");
        p.println("        if(arity < 1)");
        p.println("            throw new IllegalArgumentException();");
        p.println("        this.arity = arity;");
        p.println("    }");
        p.println();
        for(int k=1;k<=MAX_ARITY;++k) {
            p.println("    @Override");
            p.print("    public Object apply(");
            printList(p, "Object p#", 0, k);
            p.println(") {");
            p.println("        try {");
            p.println("            switch(arity) {");
            // Arity i < k: saturate with the first i arguments, apply the result to the rest.
            for(int i=1;i<k;++i) {
                p.print("            case " + i + ": return ((Function)doApply(");
                printList(p, "p#", 0, i);
                p.print(")).apply(");
                printList(p, "p#", i, k);
                p.println(");");
            }
            p.print("            case " + k + ": return doApply(");
            printList(p, "p#", 0, k);
            p.println(");");
            p.print("            default: return new UnsaturatedFunction" + k + "(this, ");
            printList(p, "p#", 0, k);
            p.println(");");
            p.println("            }");
            p.println("        } catch(ClassCastException e) {");
            p.println("            throw new CalledWithTooManyParameters();");
            p.println("        }");
            p.println("    }");
            p.println();
        }
        p.println("    public abstract Object doApply(Object ... ps);");
        p.println();
        p.println("    @Override");
        p.println("    public Object applyArray(Object ... ps) {");
        p.println("        if(ps.length == arity)");
        p.println("            return doApply(ps);");
        p.println("        else if(ps.length < arity)");
        p.println("            return new UnsaturatedFunctionN(this, ps);");
        p.println("        else /* ps.length > arity */ {");
        p.println("            try {");
        p.println("                return ((Function)doApply(Arrays.copyOf(ps, arity))).applyArray(Arrays.copyOfRange(ps, arity, ps.length));");
        p.println("            } catch(ClassCastException e) {");
        p.println("                throw new CalledWithTooManyParameters();");
        p.println("            }");
        p.println("        }");
        p.println("    }");
        p.println("}");
    }

    /**
     * Generates class {@code UnsaturatedFunction<n>}, a partial application that stores a
     * function {@code f} and its first {@code n} arguments {@code p0..p<n-1>}.
     * <ul>
     * <li>Later calls prepend the stored arguments and forward everything to {@code f}:
     * through {@code apply} while the total fits in {@code MAX_ARITY}, through
     * {@code applyArray} otherwise.</li>
     * <li>{@code toString}, {@code hashCode} and {@code equals} are also generated,
     * based on {@code f} and the stored arguments.</li>
     * </ul>
     *
     * @param p stream receiving the generated source
     * @param n number of stored arguments
     */
    public static void generateUnsaturatedFunctionN(PrintStream p, int n) {
        printPreamble(p, false);
        p.println("@SuppressWarnings(\"all\")");
        p.println("public class UnsaturatedFunction" + n + " implements Function {");
        p.println("    private final Function f;");
        for(int i=0;i<n;++i)
            p.println("    private final Object p" + i + ";");
        p.println();
        p.print("    public UnsaturatedFunction" + n + "(Function f");
        for(int i=0;i<n;++i)
            p.print(", Object p" + i);
        p.println(") {");
        p.println("        this.f = f;");
        for(int i=0;i<n;++i)
            p.println("        this.p" + i + " = p" + i + ";");
        p.println("    }");
        p.println();
        for(int k=1;k<=MAX_ARITY;++k) {
            p.println("    @Override");
            p.print("    public Object apply(");
            // New arguments are numbered after the stored ones: p<n>..p<n+k-1>.
            printList(p, "Object p#", n, n + k);
            p.println(") {");
            if(n + k <= MAX_ARITY)
                p.print("        return f.apply(");
            else
                p.print("        return f.applyArray(");
            printList(p, "p#", 0, n + k);
            p.println(");");
            p.println("    }");
            p.println();
        }
        {
            p.println("    @Override");
            p.println("    public Object applyArray(Object ... ps) {");
            p.println("        Object[] nps = new Object[ps.length + " + n + "];");
            for(int i=0;i<n;++i)
                p.println("        nps[" + i + "] = p" + i + ";");
            p.println("        for(int i=0;i<ps.length;++i)");
            p.println("            nps[i + " + n + "] = ps[i];");
            p.println("        return f.applyArray(nps);");
            p.println("    }");
            p.println();
        }
        {
            p.println("    @Override");
            p.println("    public String toString() {");
            p.println("        StringBuilder sb = new StringBuilder();");
            p.println("        sb.append(\"(\").append(f);");
            for(int i=0;i<n;++i)
                p.println("        sb.append(\" \").append(p"+i+");");
            p.println("        sb.append(\")\");");
            p.println("        return sb.toString();");
            p.println("    }");
            p.println();
        }
        {
            p.println("    @Override");
            p.println("    public int hashCode() {");
            p.println("        int result = f.hashCode();");
            for(int i=0;i<n;++i)
                p.println("        result = 31 * result + (p"+i+" == null ? 0 : p"+i+".hashCode());");
            p.println("        return result;");
            p.println("    }");
            p.println();
        }
        {
            p.println("    @Override");
            p.println("    public boolean equals(Object obj) {");
            p.println("        if (this == obj)");
            p.println("            return true;");
            p.println("        if (obj == null)");
            p.println("            return false;");
            p.println("        if (getClass() != obj.getClass())");
            p.println("            return false;");
            p.println("        UnsaturatedFunction"+n+" other = (UnsaturatedFunction"+n+") obj;");
            p.println("        if(!f.equals(other.f))");
            p.println("            return false;");
            for(int i=0;i<n;++i) {
                p.println("        if(p"+i+" == null) {");
                p.println("            if (other.p"+i+" != null)");
                p.println("                return false;");
                p.println("        } else if (!p"+i+".equals(other.p"+i+"))");
                p.println("            return false;");
            }
            p.println("        return true;");
            p.println("    }");
            p.println();
        }
        p.println("}");
    }

    /**
     * Generates class {@code UnsaturatedFunctionN}, a partial application that stores a
     * function {@code f} and an array of arguments. Later calls concatenate the stored and
     * new arguments and forward them to {@code f.applyArray}. {@code toString},
     * {@code hashCode} and {@code equals} are also generated.
     *
     * @param p stream receiving the generated source
     */
    public static void generateUnsaturatedFunctionN(PrintStream p) {
        printPreamble(p, true);
        p.println("@SuppressWarnings(\"all\")");
        p.println("public class UnsaturatedFunctionN implements Function {");
        p.println("    private final Function f;");
        p.println("    private final Object[] ps;");
        p.println();
        p.println("    public UnsaturatedFunctionN(Function f, Object ... ps) {");
        p.println("        this.f = f;");
        p.println("        this.ps = ps;");
        p.println("    }");
        p.println();
        for(int k=1;k<=MAX_ARITY;++k) {
            p.println("    @Override");
            p.print("    public Object apply(");
            printList(p, "Object p#", 1, k + 1);
            p.println(") {");
            p.println("        Object[] nps = new Object[ps.length + " + k + "];");
            p.println("        System.arraycopy(ps, 0, nps, 0, ps.length);");
            for(int i=1;i<=k;++i)
                p.println("        nps[ps.length+" + (i-1) + "] = p" + i + ";");
            p.println("        return f.applyArray(nps);");
            p.println("    }");
            p.println();
        }
        {
            p.println("    @Override");
            p.println("    public Object applyArray(Object ... ops) {");
            p.println("        Object[] nps = new Object[ps.length + ops.length];");
            p.println("        System.arraycopy(ps, 0, nps, 0, ps.length);");
            p.println("        System.arraycopy(ops, 0, nps, ps.length, ops.length);");
            p.println("        return f.applyArray(nps);");
            p.println("    }");
            p.println();
        }
        {
            p.println("    @Override");
            p.println("    public String toString() {");
            p.println("        StringBuilder sb = new StringBuilder();");
            p.println("        sb.append(\"(\").append(f);");
            p.println("        for (Object p : ps)");
            p.println("            sb.append(\" \").append(p);");
            p.println("        sb.append(\")\");");
            p.println("        return sb.toString();");
            p.println("    }");
            p.println();
        }
        {
            p.println("    @Override");
            p.println("    public int hashCode() {");
            p.println("        return f.hashCode() + 31 * Arrays.hashCode(ps);");
            p.println("    }");
            p.println();
        }
        {
            p.println("    @Override");
            p.println("    public boolean equals(Object obj) {");
            p.println("        if (this == obj)");
            p.println("            return true;");
            p.println("        if (obj == null)");
            p.println("            return false;");
            p.println("        if (getClass() != obj.getClass())");
            p.println("            return false;");
            p.println("        UnsaturatedFunctionN other = (UnsaturatedFunctionN) obj;");
            p.println("        return f.equals(other.f) && Arrays.equals(ps, other.ps);");
            p.println("    }");
            p.println();
        }
        p.println("}");
    }

    /**
     * Locates {@code src/<PACKAGE as path>}. It walks up from the directory this class was
     * loaded from until it finds a directory containing {@code src}, and creates the package
     * directory if it is missing.
     *
     * @throws IllegalStateException if the class location is unknown or no ancestor has {@code src}
     * @throws IOException if the package directory cannot be created
     */
    private static File findPackageDirectory() throws URISyntaxException, IOException {
        URL url = GenerateFunctions.class.getResource(".");
        if(url == null)
            throw new IllegalStateException(
                    "Cannot locate the class directory of " + GenerateFunctions.class.getName());
        // toURI() decodes escapes such as %20 that URL.getPath() would leave in the file name.
        File dir = new File(url.toURI());
        while(!new File(dir, "src").exists()) {
            dir = dir.getParentFile();
            if(dir == null)
                throw new IllegalStateException("No ancestor of " + url + " contains a src directory");
        }
        dir = new File(new File(dir, "src"), PACKAGE.replace('.', '/'));
        if(!dir.isDirectory() && !dir.mkdirs())
            throw new IOException("Cannot create directory " + dir);
        return dir;
    }

    /** Opens {@code fileName} in {@code dir} for writing, replacing any existing content. */
    private static PrintStream openOutput(File dir, String fileName) throws IOException {
        return new PrintStream(new FileOutputStream(new File(dir, fileName)));
    }

    /**
     * Flushes {@code ps} and fails if any write to it went wrong. PrintStream never throws
     * IOException itself; it only records failures in its error flag.
     */
    private static void checkOutput(PrintStream ps, String fileName) throws IOException {
        if(ps.checkError())
            throw new IOException("Failed to write " + fileName);
    }

    /**
     * Regenerates all function runtime sources into the package directory found by
     * walking up from this class's location to the nearest {@code src} folder.
     * Every output file is closed, and any write failure is reported as an exception.
     */
    public static void main(String[] args) throws Exception {
        File dir = findPackageDirectory();

        for(int n=1;n<=MAX_ARITY;++n) {
            String fileName = "Function" + n + ".java";
            PrintStream ps = openOutput(dir, fileName);
            try {
                generateFunctionN(ps, n);
                checkOutput(ps, fileName);
            } finally {
                ps.close();
            }
        }
        {
            String fileName = "FunctionN.java";
            PrintStream ps = openOutput(dir, fileName);
            try {
                generateFunctionN(ps);
                checkOutput(ps, fileName);
            } finally {
                ps.close();
            }
        }
        {
            String fileName = "Function.java";
            PrintStream ps = openOutput(dir, fileName);
            try {
                generateFunction(ps);
                checkOutput(ps, fileName);
            } finally {
                ps.close();
            }
        }

        for(int n=1;n<=MAX_ARITY;++n) {
            String fileName = "FunctionImpl" + n + ".java";
            PrintStream ps = openOutput(dir, fileName);
            try {
                generateFunctionImplN(ps, n);
                checkOutput(ps, fileName);
            } finally {
                ps.close();
            }
        }
        {
            String fileName = "FunctionImplN.java";
            PrintStream ps = openOutput(dir, fileName);
            try {
                generateFunctionImplN(ps);
                checkOutput(ps, fileName);
            } finally {
                ps.close();
            }
        }

        for(int n=1;n<=MAX_ARITY;++n) {
            String fileName = "UnsaturatedFunction" + n + ".java";
            PrintStream ps = openOutput(dir, fileName);
            try {
                generateUnsaturatedFunctionN(ps, n);
                checkOutput(ps, fileName);
            } finally {
                ps.close();
            }
        }
        {
            String fileName = "UnsaturatedFunctionN.java";
            PrintStream ps = openOutput(dir, fileName);
            try {
                generateUnsaturatedFunctionN(ps);
                checkOutput(ps, fileName);
            } finally {
                ps.close();
            }
        }
    }
}
