/*
 * Decompiled with CFR 0.152.
 */
package com.modelengineers.MoRe_elk.alg.layered.p3order.mes;

import com.google.common.collect.Lists;
import com.google.ortools.sat.BoolVar;
import com.google.ortools.sat.Literal;
import com.modelengineers.MoRe_elk.alg.layered.graph.LGraph;
import com.modelengineers.MoRe_elk.alg.layered.graph.LNode;
import com.modelengineers.MoRe_elk.alg.layered.graph.Layer;
import com.modelengineers.MoRe_elk.alg.layered.ortools.CpSolverWrapper;
import com.modelengineers.MoRe_elk.alg.layered.p3order.mes.CEdge;
import com.modelengineers.MoRe_elk.alg.layered.p3order.mes.ConstraintLikeObjective;
import com.modelengineers.MoRe_elk.alg.layered.p3order.mes.IObjective;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

public final class BinaryOrderProblem {
    private static final int COMPLEXITY_THRESHOLD = 1500;
    private static final int NUM_SOLVER_SEARCH_WORKERS = 8;
    private static final int NUM_NODES_BELOW_WHICH_TO_USE_ONE_SOLVER_SEARCH_WORKER = 20;
    private LGraph graph;
    private CpSolverWrapper solverWrapper = new CpSolverWrapper();
    private List<BoolVar[][]> isBelow = Lists.newArrayList();
    private List<IObjective> objectives = Lists.newArrayList();
    private int complexity = 0;

    public BinaryOrderProblem(LGraph graph) {
        this.graph = graph;
        graph.getLayers().stream().forEach(layer -> this.addIsBelowVariables((Layer)layer));
    }

    public CpSolverWrapper getSolverWrapper() {
        return this.solverWrapper;
    }

    private void addIsBelowVariables(Layer layer) {
        BoolVar[][] isBelowInLayer = this.initIsBelowInLayer(layer);
        this.ensureTransitivity(isBelowInLayer);
        this.isBelow.add(isBelowInLayer);
    }

    private BoolVar[][] initIsBelowInLayer(Layer layer) {
        int numOfNodesInLayer = layer.getNodes().size();
        BoolVar[][] isBelowInLayer = new BoolVar[numOfNodesInLayer][numOfNodesInLayer];
        int i = 0;
        while (i < numOfNodesInLayer - 1) {
            LNode nodeI = layer.getNodes().get(i);
            int j = i + 1;
            while (j < numOfNodesInLayer) {
                LNode nodeJ = layer.getNodes().get(j);
                isBelowInLayer[i][j] = this.initIsBelow(nodeI, nodeJ);
                ++j;
            }
            ++i;
        }
        return isBelowInLayer;
    }

    private BoolVar initIsBelow(LNode nodeI, LNode nodeJ) {
        return this.solverWrapper.newBoolVar(String.valueOf(nodeI.toString()) + "_isBelow_" + nodeJ.toString());
    }

    private void ensureTransitivity(BoolVar[][] isBelowInLayer) {
        int numOfNodes = isBelowInLayer.length;
        int i = 0;
        while (i < numOfNodes) {
            int j = i + 1;
            while (j < numOfNodes) {
                int k = j + 1;
                while (k < numOfNodes) {
                    this.ensureTransitivity(isBelowInLayer, i, j, k);
                    ++k;
                }
                ++j;
            }
            ++i;
        }
    }

    private void ensureTransitivity(BoolVar[][] isBelowInLayer, int i, int j, int k) {
        this.solverWrapper.addOr(new Literal[]{isBelowInLayer[i][j].not(), isBelowInLayer[j][k].not(), isBelowInLayer[i][k]});
        this.solverWrapper.addOr(new Literal[]{isBelowInLayer[i][j], isBelowInLayer[j][k], isBelowInLayer[i][k].not()});
    }

    public void setFirstNodeDirectlyAboveSecondNode(LNode upperNode, LNode lowerNode) {
        this.setFirstNodeAboveSecondNode(upperNode, lowerNode);
        this.ensureNoOtherNodeIsBetween(upperNode, lowerNode);
    }

    public void setFirstNodeAboveSecondNode(LNode upperNode, LNode lowerNode) {
        Literal lowerBelowUpperNode = this.isBelow(lowerNode, upperNode);
        this.solverWrapper.addEquality(lowerBelowUpperNode, 1L);
    }

    public void setFirstNodeAboveSecondNodeInfeasibleSafe(LNode upperNode, LNode lowerNode) {
        Literal lowerBelowUpperNode = this.isBelow(upperNode, lowerNode);
        this.addObjective(new ConstraintLikeObjective(lowerBelowUpperNode));
    }

    public Literal isBelow(LNode nodeA, LNode nodeB) {
        BoolVar[][] isBelowInLayer = this.getIsBelowInLayer(this.getLayer(nodeA));
        return this.isBelow(nodeA, nodeB, isBelowInLayer);
    }

    private BoolVar[][] getIsBelowInLayer(Layer layer) {
        return this.isBelow.get(layer.getIndex());
    }

    private Layer getLayer(LNode node) {
        return this.graph.getLayers().stream().filter(l -> l.getNodes().contains(node)).findFirst().orElse(null);
    }

    private Literal isBelow(LNode nodeA, LNode nodeB, BoolVar[][] isBelowInLayer) {
        int ixB;
        int ixA = this.getNodeIndex(nodeA);
        return ixA < (ixB = this.getNodeIndex(nodeB)) ? isBelowInLayer[ixA][ixB] : isBelowInLayer[ixB][ixA].not();
    }

    private int getNodeIndex(LNode node) {
        return this.getLayer(node).getNodes().indexOf(node);
    }

    private void ensureNoOtherNodeIsBetween(LNode upperNode, LNode lowerNode) {
        List<LNode> nodesInLayer = this.getNodesInLayer(upperNode);
        for (LNode node : nodesInLayer) {
            if (node == upperNode || node == lowerNode) continue;
            Literal nodeBelowUpperNode = this.isBelow(node, upperNode);
            Literal nodeBelowLowerNode = this.isBelow(node, lowerNode);
            this.solverWrapper.addEquality(nodeBelowLowerNode, nodeBelowUpperNode);
        }
    }

    private List<LNode> getNodesInLayer(LNode node) {
        return this.getLayer(node).getNodes();
    }

    public Literal sourceOfFirstEdgeIsBelowSourceOfSecondEdge(CEdge a, CEdge b) {
        if (a.getSourceNode() == b.getSourceNode()) {
            return a.getSource().isAbove(b.getSource()) ? this.solverWrapper.getFalse() : this.solverWrapper.getTrue();
        }
        return this.firstNodeIsBelowSecondNode(a.getSourceNode(), b.getSourceNode());
    }

    public Literal sourceOfFirstEdgeIsBelowTargetOfSecondEdge(CEdge a, CEdge b) {
        if (a.getSourceNode() == b.getTargetNode()) {
            return a.getSource().isAbove(b.getTarget()) ? this.solverWrapper.getFalse() : this.solverWrapper.getTrue();
        }
        return this.firstNodeIsBelowSecondNode(a.getSourceNode(), b.getTargetNode());
    }

    public Literal targetOfFirstEdgeIsBelowSourceOfSecondEdge(CEdge a, CEdge b) {
        if (a.getTargetNode() == b.getSourceNode()) {
            return a.getTarget().isAbove(b.getSource()) ? this.solverWrapper.getFalse() : this.solverWrapper.getTrue();
        }
        return this.firstNodeIsBelowSecondNode(a.getTargetNode(), b.getSourceNode());
    }

    public Literal targetOfFirstEdgeIsBelowTargetOfSecondEdge(CEdge a, CEdge b) {
        if (a.getTargetNode() == b.getTargetNode()) {
            return a.getTarget().isAbove(b.getTarget()) ? this.solverWrapper.getFalse() : this.solverWrapper.getTrue();
        }
        return this.firstNodeIsBelowSecondNode(a.getTargetNode(), b.getTargetNode());
    }

    public Literal firstNodeIsBelowSecondNode(LNode firstNode, LNode secondNode) {
        assert (firstNode.getLayer().equals(secondNode.getLayer()));
        return this.isBelow(firstNode, secondNode);
    }

    public void addObjectives(List<IObjective> objectivesToAdd) {
        objectivesToAdd.forEach(objective -> this.addObjective((IObjective)objective));
    }

    public void addObjective(IObjective objective) {
        if (this.isOptimizableObjective(objective)) {
            this.addOptimizableObjective(objective);
        }
    }

    private boolean isOptimizableObjective(IObjective objective) {
        return objective != null && objective.getLiteral() != null;
    }

    private void addOptimizableObjective(IObjective objective) {
        this.objectives.add(objective);
        if (this.problemBecomesTooComplex(objective)) {
            throw new OrderProblemComplexityThresholdExceededException();
        }
    }

    private boolean problemBecomesTooComplex(IObjective objective) {
        if (objective.contributesToComplexity()) {
            ++this.complexity;
        }
        return this.complexity > 1500;
    }

    public void findBestOrder() throws InterruptedException {
        this.setMinimizationToWeightedSumOfAllObjectives();
        this.setSolverParameters();
        this.solverWrapper.minimize();
    }

    private void setMinimizationToWeightedSumOfAllObjectives() {
        List<Literal> variables = this.objectives.stream().map(o -> o.getLiteral()).collect(Collectors.toList());
        List<Long> coeffs = this.objectives.stream().map(o -> o.getCoeff()).collect(Collectors.toList());
        this.solverWrapper.addObjectiveFactors(variables, coeffs);
    }

    private void setSolverParameters() {
        boolean useOneWorker = this.graph.getNumTrueNodes() < 20L;
        this.solverWrapper.getParameters().setNumSearchWorkers(useOneWorker ? 1 : 8);
    }

    public void transferOrderOnNodeIndices() {
        this.graph.getLayers().forEach(l -> this.transferOrderOnNodeIndices((Layer)l));
    }

    public void transferOrderOnNodeIndices(Layer layer) {
        Map<LNode, Integer> newIndices = layer.getNodes().stream().collect(Collectors.toMap(n -> n, n -> this.getNumNodesAboveInLayer((LNode)n)));
        newIndices.keySet().forEach(n -> {
            LNode lNode = layer.getNodes().set((Integer)newIndices.get(n), (LNode)n);
        });
    }

    private int getNumNodesAboveInLayer(LNode node) {
        return (int)this.getNodesInLayer(node).stream().filter(n -> n != node).filter(n -> this.solverWrapper.getBooleanValue(this.isBelow(node, (LNode)n))).count();
    }

    public class OrderProblemComplexityThresholdExceededException
    extends RuntimeException {
        public OrderProblemComplexityThresholdExceededException() {
            super("Threshold for complexity of order problem exceeded.");
        }
    }
}

