/*
 * Decompiled with CFR 0.152.
 */
package org.apache.druid.msq.kernel.controller;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import it.unimi.dsi.fastutil.ints.Int2IntAVLTreeMap;
import it.unimi.dsi.fastutil.ints.Int2IntMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.IntIterator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.druid.frame.key.ClusterByPartitions;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.msq.exec.QueryValidator;
import org.apache.druid.msq.indexing.error.MSQFault;
import org.apache.druid.msq.input.InputSpecSlicer;
import org.apache.druid.msq.input.InputSpecSlicerFactory;
import org.apache.druid.msq.input.stage.ReadablePartitions;
import org.apache.druid.msq.kernel.ExtraInfoHolder;
import org.apache.druid.msq.kernel.QueryDefinition;
import org.apache.druid.msq.kernel.StageDefinition;
import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.kernel.WorkOrder;
import org.apache.druid.msq.kernel.WorkerAssignmentStrategy;
import org.apache.druid.msq.kernel.controller.ControllerStagePhase;
import org.apache.druid.msq.kernel.controller.ControllerStageTracker;
import org.apache.druid.msq.kernel.controller.WorkerInputs;
import org.apache.druid.msq.statistics.CompleteKeyStatisticsInformation;
import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation;

/**
 * Controller-side kernel for a multi-stage query.
 *
 * <p>Holds one {@link ControllerStageTracker} per stage that has been created so far, plus the stage dependency
 * graph derived from the {@link QueryDefinition}. The graph is used to decide:
 * <ul>
 *   <li>which stages are ready to run: every input stage has been transitioned to
 *   {@link ControllerStagePhase#RESULTS_READY};</li>
 *   <li>which stages are "effectively finished": every stage that reads from them has been transitioned to a
 *   phase for which {@link ControllerStagePhase#isPostReadingPhase} returns true.</li>
 * </ul>
 *
 * <p>This class performs no internal synchronization.
 */
public class ControllerQueryKernel {
    private final QueryDefinition queryDef;

    // One tracker per stage that has been created by createNewKernels(); keyed by stage id.
    private final Map<StageId, ControllerStageTracker> stageTracker = new HashMap<StageId, ControllerStageTracker>();

    // Immutable: stage -> the stages it reads from.
    private final ImmutableMap<StageId, Set<StageId>> inflowMap;

    // Immutable: stage -> the stages that read from it.
    private final ImmutableMap<StageId, Set<StageId>> outflowMap;

    // Mutable copy of the inflow graph. Input stages are removed from a stage's set as they reach RESULTS_READY;
    // once the set is empty the stage moves to readyToRunStages and its entry is dropped.
    private final Map<StageId, Set<StageId>> pendingInflowMap;

    // Mutable copy of the outflow graph. Consumer stages are removed from a stage's set as they reach a
    // post-reading phase; once the set is empty the stage moves to effectivelyFinishedStages and its entry is dropped.
    private final Map<StageId, Set<StageId>> pendingOutflowMap;

    // Stages whose inputs are all RESULTS_READY but for which no tracker has been created yet.
    private final Set<StageId> readyToRunStages = new HashSet<StageId>();

    // Stages whose consumers have all reached a post-reading phase and that have not yet been finished.
    private final Set<StageId> effectivelyFinishedStages = new HashSet<StageId>();

    /**
     * Creates a kernel for the given query and seeds the set of ready-to-run stages with every stage
     * that has no input stages.
     *
     * @param queryDef the query whose stages this kernel tracks
     */
    public ControllerQueryKernel(QueryDefinition queryDef) {
        this.queryDef = queryDef;
        this.inflowMap = ImmutableMap.copyOf(ControllerQueryKernel.computeStageInflowMap(queryDef));
        this.outflowMap = ImmutableMap.copyOf(ControllerQueryKernel.computeStageOutflowMap(queryDef));
        // Separate computations, so the pending maps own sets that are distinct from the immutable maps' sets.
        this.pendingInflowMap = ControllerQueryKernel.computeStageInflowMap(queryDef);
        this.pendingOutflowMap = ControllerQueryKernel.computeStageOutflowMap(queryDef);
        this.initializeReadyToRunStages();
    }

    /**
     * Creates trackers for every stage that is currently ready to run. Each tracker is created with an input
     * slicer built from the worker counts and result partitions of the trackers that already exist.
     *
     * @param slicerFactory      builds the input slicer from the result partitions of existing stages
     * @param assignmentStrategy strategy passed through to each new {@link ControllerStageTracker}
     * @return ids of all tracked stages that are in phase {@link ControllerStagePhase#NEW}. This includes stages
     *         created by earlier calls that have not been started yet.
     */
    public List<StageId> createAndGetNewStageIds(InputSpecSlicerFactory slicerFactory, WorkerAssignmentStrategy assignmentStrategy) {
        final Int2IntAVLTreeMap stageWorkerCountMap = new Int2IntAVLTreeMap();
        final Int2ObjectAVLTreeMap<ReadablePartitions> stagePartitionsMap = new Int2ObjectAVLTreeMap<>();
        for (ControllerStageTracker existing : this.stageTracker.values()) {
            final int stageNumber = existing.getStageDefinition().getStageNumber();
            stageWorkerCountMap.put(stageNumber, existing.getWorkerInputs().workerCount());
            if (existing.hasResultPartitions()) {
                stagePartitionsMap.put(stageNumber, existing.getResultPartitions());
            }
        }
        this.createNewKernels(stageWorkerCountMap, slicerFactory.makeSlicer(stagePartitionsMap), assignmentStrategy);
        return this.stageTracker.values()
                .stream()
                .filter(tracker -> tracker.getPhase() == ControllerStagePhase.NEW)
                .map(tracker -> tracker.getStageDefinition().getId())
                .collect(Collectors.toList());
    }

    /**
     * @return a snapshot of the stages whose consumers have all reached a post-reading phase and which have
     *         not been passed to {@link #finishStage} yet
     */
    public List<StageId> getEffectivelyFinishedStageIds() {
        return ImmutableList.copyOf(this.effectivelyFinishedStages);
    }

    /**
     * @return a snapshot of the ids of every stage for which a tracker has been created
     */
    public List<StageId> getActiveStages() {
        return ImmutableList.copyOf(this.stageTracker.keySet());
    }

    /**
     * Builds the {@link StageId} for a stage number within this query. Does not check whether the stage exists.
     */
    public StageId getStageId(int stageNumber) {
        return new StageId(this.queryDef.getQueryId(), stageNumber);
    }

    /**
     * @return true if the final stage's tracker is in a successful terminal phase, or if any tracked stage is
     *         {@link ControllerStagePhase#FAILED}
     */
    public boolean isDone() {
        final boolean finalStageSucceeded =
                Optional.ofNullable(this.stageTracker.get(this.queryDef.getFinalStageDefinition().getId()))
                        .filter(tracker -> ControllerStagePhase.isSuccessfulTerminalPhase(tracker.getPhase()))
                        .isPresent();
        return finalStageSucceeded
                || this.stageTracker.values().stream().anyMatch(tracker -> tracker.getPhase() == ControllerStagePhase.FAILED);
    }

    /**
     * Non-strictly finishes every tracked stage that is in a successful terminal phase and is specifically in
     * {@link ControllerStagePhase#RESULTS_READY}.
     */
    public void markSuccessfulTerminalStagesAsFinished() {
        // getActiveStages() returns a copy, so finishing stages while iterating is safe.
        for (StageId stageId : this.getActiveStages()) {
            final ControllerStagePhase phase = this.getStagePhase(stageId);
            if (ControllerStagePhase.isSuccessfulTerminalPhase(phase) && phase == ControllerStagePhase.RESULTS_READY) {
                this.finishStage(stageId, false);
            }
        }
    }

    /**
     * @return true if every stage of the query has a tracker and every tracker is in a successful terminal phase
     */
    public boolean isSuccess() {
        return this.stageTracker.size() == this.queryDef.getStageDefinitions().size()
                && this.stageTracker.values().stream().allMatch(tracker -> ControllerStagePhase.isSuccessfulTerminalPhase(tracker.getPhase()));
    }

    /**
     * Builds one {@link WorkOrder} per worker assigned to the given stage. Each work order is checked with
     * {@link QueryValidator#validateWorkOrder} before it is returned.
     *
     * @param stageNumber stage to build work orders for; its tracker must already exist
     * @param extraInfos  optional per-worker extra info, looked up by worker number; may be null
     * @return work orders keyed by worker number
     * @throws IAE if the stage has no tracker
     */
    public Int2ObjectMap<WorkOrder> createWorkOrders(int stageNumber, @Nullable Int2ObjectMap<Object> extraInfos) {
        final Int2ObjectMap<WorkOrder> workOrders = new Int2ObjectAVLTreeMap<>();
        final ControllerStageTracker stageKernel = this.getStageKernelOrThrow(this.getStageId(stageNumber));
        final WorkerInputs workerInputs = stageKernel.getWorkerInputs();
        final IntIterator workerIterator = workerInputs.workers().iterator();
        while (workerIterator.hasNext()) {
            // nextInt() avoids boxing each worker number.
            final int workerNumber = workerIterator.nextInt();
            final Object extraInfo = extraInfos != null ? extraInfos.get(workerNumber) : null;
            final ExtraInfoHolder extraInfoHolder =
                    stageKernel.getStageDefinition().getProcessorFactory().makeExtraInfoHolder(extraInfo);
            final WorkOrder workOrder = new WorkOrder(
                    this.queryDef,
                    stageNumber,
                    workerNumber,
                    workerInputs.inputsForWorker(workerNumber),
                    extraInfoHolder
            );
            QueryValidator.validateWorkOrder(workOrder);
            workOrders.put(workerNumber, workOrder);
        }
        return workOrders;
    }

    /**
     * Creates a tracker for every stage in {@link #readyToRunStages}, then empties that set.
     */
    private void createNewKernels(Int2IntMap stageWorkerCountMap, InputSpecSlicer slicer, WorkerAssignmentStrategy assignmentStrategy) {
        for (StageId nextStage : this.readyToRunStages) {
            final StageDefinition stageDef = this.queryDef.getStageDefinition(nextStage);
            final ControllerStageTracker stageKernel =
                    ControllerStageTracker.create(stageDef, stageWorkerCountMap, slicer, assignmentStrategy);
            this.stageTracker.put(nextStage, stageKernel);
        }
        this.readyToRunStages.clear();
    }

    /**
     * Moves every stage with no input stages from {@link #pendingInflowMap} into {@link #readyToRunStages}.
     */
    private void initializeReadyToRunStages() {
        final Iterator<Map.Entry<StageId, Set<StageId>>> entries = this.pendingInflowMap.entrySet().iterator();
        while (entries.hasNext()) {
            final Map.Entry<StageId, Set<StageId>> entry = entries.next();
            if (entry.getValue().isEmpty()) {
                this.readyToRunStages.add(entry.getKey());
                entries.remove();
            }
        }
    }

    /** @throws IAE if the stage has no tracker */
    public StageDefinition getStageDefinition(StageId stageId) {
        return this.getStageKernelOrThrow(stageId).getStageDefinition();
    }

    /** @throws IAE if the stage has no tracker */
    public ControllerStagePhase getStagePhase(StageId stageId) {
        return this.getStageKernelOrThrow(stageId).getPhase();
    }

    /** @throws IAE if the stage has no tracker */
    public boolean doesStageHaveResultPartitions(StageId stageId) {
        return this.getStageKernelOrThrow(stageId).hasResultPartitions();
    }

    /** @throws IAE if the stage has no tracker */
    public ReadablePartitions getResultPartitionsForStage(StageId stageId) {
        return this.getStageKernelOrThrow(stageId).getResultPartitions();
    }

    /** @throws IAE if the stage has no tracker */
    public ClusterByPartitions getResultPartitionBoundariesForStage(StageId stageId) {
        return this.getStageKernelOrThrow(stageId).getResultPartitionBoundaries();
    }

    /** @throws IAE if the stage has no tracker */
    public CompleteKeyStatisticsInformation getCompleteKeyStatisticsInformation(StageId stageId) {
        return this.getStageKernelOrThrow(stageId).getCompleteKeyStatisticsInformation();
    }

    /** @throws IAE if the stage has no tracker */
    public void setClusterByPartitionBoundaries(StageId stageId, ClusterByPartitions clusterByPartitions) {
        this.getStageKernelOrThrow(stageId).setClusterByPartitionBoundaries(clusterByPartitions);
    }

    /** @throws IAE if the stage has no tracker */
    public boolean hasStageCollectorEncounteredAnyMultiValueField(StageId stageId) {
        return this.getStageKernelOrThrow(stageId).collectorEncounteredAnyMultiValueField();
    }

    /** @throws IAE if the stage has no tracker */
    public Object getResultObjectForStage(StageId stageId) {
        return this.getStageKernelOrThrow(stageId).getResultObject();
    }

    /**
     * Starts a stage that is in phase {@link ControllerStagePhase#NEW} and records its transition to
     * {@link ControllerStagePhase#READING_INPUT}.
     *
     * @throws ISE if the stage is not in phase NEW
     * @throws IAE if the stage has no tracker
     */
    public void startStage(StageId stageId) {
        final ControllerStageTracker stageKernel = this.getStageKernelOrThrow(stageId);
        if (stageKernel.getPhase() != ControllerStagePhase.NEW) {
            throw new ISE("Cannot start the stage: [%s]", stageId);
        }
        stageKernel.start();
        this.transitionStageKernel(stageId, ControllerStagePhase.READING_INPUT);
    }

    /**
     * Finishes a stage and records its transition to {@link ControllerStagePhase#FINISHED}.
     *
     * @param strict if true, the stage must currently be in {@link #getEffectivelyFinishedStageIds()}
     * @throws IAE if {@code strict} is true and the stage is not effectively finished, or if the stage has no tracker
     */
    public void finishStage(StageId stageId, boolean strict) {
        if (strict && !this.effectivelyFinishedStages.contains(stageId)) {
            throw new IAE("Cannot mark the stage: [%s] finished", stageId);
        }
        this.getStageKernelOrThrow(stageId).finish();
        this.effectivelyFinishedStages.remove(stageId);
        this.transitionStageKernel(stageId, ControllerStagePhase.FINISHED);
    }

    /** @throws IAE if the stage has no tracker */
    public WorkerInputs getWorkerInputsForStage(StageId stageId) {
        return this.getStageKernelOrThrow(stageId).getWorkerInputs();
    }

    /**
     * Passes a worker's partial key statistics to the stage tracker. If the tracker reports a new phase of
     * MERGING_STATISTICS, POST_READING or FAILED, that transition is recorded in the dependency graph.
     *
     * @throws IAE if the stage has no tracker
     */
    public void addPartialKeyStatisticsForStageAndWorker(StageId stageId, int workerNumber, PartialKeyStatisticsInformation partialKeyStatisticsInformation) {
        final ControllerStageTracker stageKernel = this.getStageKernelOrThrow(stageId);
        final ControllerStagePhase newPhase =
                stageKernel.addPartialKeyStatisticsForWorker(workerNumber, partialKeyStatisticsInformation);
        switch (newPhase) {
            case MERGING_STATISTICS:
            case POST_READING:
            case FAILED:
                this.transitionStageKernel(stageId, newPhase);
                break;
            default:
                // Other phases need no graph bookkeeping here.
                break;
        }
    }

    /**
     * Records a worker's completion. If the tracker reports that the whole stage is complete, the transition to
     * {@link ControllerStagePhase#RESULTS_READY} is recorded.
     *
     * @throws IAE if the stage has no tracker
     */
    public void setResultsCompleteForStageAndWorker(StageId stageId, int workerNumber, Object resultObject) {
        if (this.getStageKernelOrThrow(stageId).setResultsCompleteForWorker(workerNumber, resultObject)) {
            this.transitionStageKernel(stageId, ControllerStagePhase.RESULTS_READY);
        }
    }

    /** @throws IAE if the stage has no tracker */
    public MSQFault getFailureReasonForStage(StageId stageId) {
        return this.getStageKernelOrThrow(stageId).getFailureReason();
    }

    /**
     * Fails the stage with the given fault and records its transition to {@link ControllerStagePhase#FAILED}.
     *
     * @throws IAE if the stage has no tracker
     */
    public void failStageForReason(StageId stageId, MSQFault fault) {
        this.getStageKernelOrThrow(stageId).failForReason(fault);
        this.transitionStageKernel(stageId, ControllerStagePhase.FAILED);
    }

    /**
     * Fails the stage and records its transition to {@link ControllerStagePhase#FAILED}.
     *
     * @throws IAE if the stage has no tracker
     */
    public void failStage(StageId stageId) {
        this.getStageKernelOrThrow(stageId).fail();
        this.transitionStageKernel(stageId, ControllerStagePhase.FAILED);
    }

    /**
     * @return the tracker for the stage
     * @throws IAE if no tracker has been created for the stage
     */
    private ControllerStageTracker getStageKernelOrThrow(StageId stageId) {
        final ControllerStageTracker stageKernel = this.stageTracker.get(stageId);
        if (stageKernel == null) {
            throw new IAE("Cannot find kernel corresponding to stage [%s] in query [%s]", stageId, this.queryDef.getQueryId());
        }
        return stageKernel;
    }

    /**
     * Updates the dependency-graph bookkeeping for a stage entering {@code newPhase}. It does NOT change the
     * tracker's own phase; callers are expected to have done that already.
     *
     * <ul>
     *   <li>On RESULTS_READY, the stage is removed from each consumer's pending inputs. A consumer with no
     *   remaining pending inputs becomes ready to run.</li>
     *   <li>On any post-reading phase, the stage is removed from each input stage's pending consumers. An input
     *   stage with no remaining pending consumers becomes effectively finished.</li>
     * </ul>
     *
     * @throws IllegalArgumentException if the stage has no tracker
     */
    public void transitionStageKernel(StageId stageId, ControllerStagePhase newPhase) {
        Preconditions.checkArgument(this.stageTracker.containsKey(stageId), "Attempting to modify an unknown stageKernel");
        if (newPhase == ControllerStagePhase.RESULTS_READY) {
            for (StageId dependentStageId : this.outflowMap.get(stageId)) {
                // Pending sets are never null, so a null lookup means the consumer was already released.
                final Set<StageId> pendingInputs = this.pendingInflowMap.get(dependentStageId);
                if (pendingInputs == null) {
                    continue;
                }
                pendingInputs.remove(stageId);
                if (pendingInputs.isEmpty()) {
                    this.readyToRunStages.add(dependentStageId);
                    this.pendingInflowMap.remove(dependentStageId);
                }
            }
        }
        if (ControllerStagePhase.isPostReadingPhase(newPhase)) {
            for (StageId inputStage : this.inflowMap.get(stageId)) {
                final Set<StageId> pendingConsumers = this.pendingOutflowMap.get(inputStage);
                if (pendingConsumers == null) {
                    continue;
                }
                pendingConsumers.remove(stageId);
                if (pendingConsumers.isEmpty()) {
                    this.effectivelyFinishedStages.add(inputStage);
                    this.pendingOutflowMap.remove(inputStage);
                }
            }
        }
    }

    /** @return the tracker for the stage number, or null if none has been created */
    @VisibleForTesting
    ControllerStageTracker getControllerStageKernel(int stageNumber) {
        return this.stageTracker.get(this.getStageId(stageNumber));
    }

    /**
     * Builds a map from every stage to the set of stages it reads from. Stages with no inputs map to an empty set.
     */
    private static Map<StageId, Set<StageId>> computeStageInflowMap(QueryDefinition queryDefinition) {
        final Map<StageId, Set<StageId>> retVal = new HashMap<>();
        for (StageDefinition stageDef : queryDefinition.getStageDefinitions()) {
            final StageId stageId = stageDef.getId();
            final Set<StageId> inputStageIds = retVal.computeIfAbsent(stageId, ignored -> new HashSet<>());
            final IntIterator inputStageNumbers = queryDefinition.getStageDefinition(stageId).getInputStageNumbers().iterator();
            while (inputStageNumbers.hasNext()) {
                inputStageIds.add(new StageId(queryDefinition.getQueryId(), inputStageNumbers.nextInt()));
            }
        }
        return retVal;
    }

    /**
     * Builds a map from every stage to the set of stages that read from it. Stages with no consumers map to an
     * empty set.
     */
    private static Map<StageId, Set<StageId>> computeStageOutflowMap(QueryDefinition queryDefinition) {
        final Map<StageId, Set<StageId>> retVal = new HashMap<>();
        for (StageDefinition stageDef : queryDefinition.getStageDefinitions()) {
            final StageId stageId = stageDef.getId();
            // Ensure every stage has an entry, even if nothing reads from it.
            retVal.computeIfAbsent(stageId, ignored -> new HashSet<>());
            final IntIterator inputStageNumbers = queryDefinition.getStageDefinition(stageId).getInputStageNumbers().iterator();
            while (inputStageNumbers.hasNext()) {
                final StageId inputStageId = new StageId(queryDefinition.getQueryId(), inputStageNumbers.nextInt());
                retVal.computeIfAbsent(inputStageId, ignored -> new HashSet<>()).add(stageId);
            }
        }
        return retVal;
    }
}

