/*
 * Decompiled with CFR 0.152.
 */
package org.apache.druid.msq.exec;

import com.google.common.util.concurrent.ListenableFuture;
import it.unimi.dsi.fastutil.ints.IntSet;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.stream.Collectors;
import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.ClusterByPartition;
import org.apache.druid.frame.key.ClusterByPartitions;
import org.apache.druid.java.util.common.Either;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.concurrent.Execs;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.msq.exec.ClusterStatisticsMergeMode;
import org.apache.druid.msq.exec.WorkerClient;
import org.apache.druid.msq.kernel.StageDefinition;
import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import org.apache.druid.msq.statistics.CompleteKeyStatisticsInformation;

/**
 * Fetches key statistics (sketches) from the workers of a stage and merges them into the
 * {@link ClusterByPartitions} used to shuffle that stage's output.
 *
 * <p>The strategy is chosen by {@link ClusterStatisticsMergeMode}:
 * <ul>
 *   <li>{@code PARALLEL}: fetch every worker's full snapshot concurrently and merge them all into a single
 *   collector.</li>
 *   <li>{@code SEQUENTIAL}: walk the time chunks listed in {@link CompleteKeyStatisticsInformation} one at a
 *   time, fetching and merging only the snapshots for that chunk, then stitching the per-chunk partitions
 *   together.</li>
 *   <li>{@code AUTO}: choose one of the above from the stage's cluster key, maximum worker count and the
 *   retained-bytes estimate.</li>
 * </ul>
 *
 * <p>Fetches run on an internal pool of {@link #DEFAULT_THREAD_COUNT} threads that is shut down by
 * {@link #close()}.
 */
public class WorkerSketchFetcher implements AutoCloseable {
    private static final Logger log = new Logger(WorkerSketchFetcher.class);
    private static final int DEFAULT_THREAD_COUNT = 4;
    /** AUTO mode chooses SEQUENTIAL merging (for bucketed cluster keys) when the workers retain more bytes than this. */
    static final long BYTES_THRESHOLD = 1000000000L;
    /** AUTO mode chooses SEQUENTIAL merging (for bucketed cluster keys) when the stage may use more workers than this. */
    static final long WORKER_THRESHOLD = 100L;

    private final ClusterStatisticsMergeMode clusterStatisticsMergeMode;
    private final int statisticsMaxRetainedBytes;
    private final WorkerClient workerClient;
    private final ExecutorService executorService;

    /**
     * @param workerClient               client used to request statistics snapshots from workers
     * @param clusterStatisticsMergeMode strategy used by {@link #submitFetcherTask}
     * @param statisticsMaxRetainedBytes passed to {@code createResultKeyStatisticsCollector} for every merged
     *                                   collector this fetcher creates
     */
    public WorkerSketchFetcher(
        WorkerClient workerClient,
        ClusterStatisticsMergeMode clusterStatisticsMergeMode,
        int statisticsMaxRetainedBytes
    ) {
        this.workerClient = workerClient;
        this.clusterStatisticsMergeMode = clusterStatisticsMergeMode;
        this.executorService = Execs.multiThreaded(DEFAULT_THREAD_COUNT, "SketchFetcherThreadPool-%d");
        this.statisticsMaxRetainedBytes = statisticsMaxRetainedBytes;
    }

    /**
     * Starts fetching and merging key statistics for a stage using the configured merge mode.
     *
     * @param completeKeyStatisticsInformation summary of the statistics held by the workers; its time chunks drive
     *                                         SEQUENTIAL mode and its retained-bytes figure feeds AUTO mode
     * @param workerTaskIds                    task id of each worker, indexed by worker number
     * @param stageDefinition                  the stage whose output partitions are being computed
     * @param workersForStage                  worker numbers whose full snapshots are fetched in PARALLEL mode
     * @return a future holding either the generated partitions or, on the error side, a partition count
     *         (e.g. when the stage's maximum partition count would be exceeded); it completes exceptionally if
     *         fetching or merging fails
     * @throws IllegalStateException if the merge mode is not recognized
     */
    public CompletableFuture<Either<Long, ClusterByPartitions>> submitFetcherTask(
        CompleteKeyStatisticsInformation completeKeyStatisticsInformation,
        List<String> workerTaskIds,
        StageDefinition stageDefinition,
        IntSet workersForStage
    ) {
        switch (clusterStatisticsMergeMode) {
            case SEQUENTIAL:
                return sequentialTimeChunkMerging(completeKeyStatisticsInformation, stageDefinition, workerTaskIds);
            case PARALLEL:
                return inMemoryFullSketchMerging(stageDefinition, workerTaskIds, workersForStage);
            case AUTO: {
                final ClusterBy clusterBy = stageDefinition.getClusterBy();
                // Sequential merging is only possible when the key is bucketed (by time chunk); use it when the
                // stage is wide or the statistics are large enough that a full in-memory merge is risky.
                final boolean useSequential =
                    clusterBy.getBucketByCount() != 0
                    && (stageDefinition.getMaxWorkerCount() > WORKER_THRESHOLD
                        || completeKeyStatisticsInformation.getBytesRetained() > BYTES_THRESHOLD);
                if (useSequential) {
                    log.info(
                        "Query[%s] stage[%d] for AUTO mode: chose SEQUENTIAL mode to merge key statistics",
                        stageDefinition.getId().getQueryId(),
                        stageDefinition.getStageNumber()
                    );
                    return sequentialTimeChunkMerging(completeKeyStatisticsInformation, stageDefinition, workerTaskIds);
                }
                log.info(
                    "Query[%s] stage[%d] for AUTO mode: chose PARALLEL mode to merge key statistics",
                    stageDefinition.getId().getQueryId(),
                    stageDefinition.getStageNumber()
                );
                return inMemoryFullSketchMerging(stageDefinition, workerTaskIds, workersForStage);
            }
            default:
                throw new IllegalStateException("No fetching strategy found for mode: " + clusterStatisticsMergeMode);
        }
    }

    /**
     * Fetches the full statistics snapshot of every worker in {@code workersForStage} concurrently, merges them
     * into one collector and generates the stage's partitions once every worker has reported.
     *
     * <p>If {@code workersForStage} is empty, partitions are generated immediately from the empty collector so
     * that the returned future always completes.
     */
    CompletableFuture<Either<Long, ClusterByPartitions>> inMemoryFullSketchMerging(
        StageDefinition stageDefinition,
        List<String> workerTaskIds,
        IntSet workersForStage
    ) {
        final CompletableFuture<Either<Long, ClusterByPartitions>> partitionFuture = new CompletableFuture<>();
        final ClusterByStatisticsCollector mergedStatisticsCollector =
            stageDefinition.createResultKeyStatisticsCollector(statisticsMaxRetainedBytes);
        final int workerCount = workersForStage.size();
        // Guarded by mergedStatisticsCollector.
        final Set<Integer> finishedWorkers = new HashSet<>();

        log.info(
            "Fetching stats using %s for stage[%d] for workers[%s] ",
            ClusterStatisticsMergeMode.PARALLEL,
            stageDefinition.getStageNumber(),
            workersForStage.stream().map(Object::toString).collect(Collectors.joining(","))
        );

        if (workerCount == 0) {
            // No task would ever complete the future, so finish it here.
            try {
                partitionFuture.complete(stageDefinition.generatePartitionsForShuffle(mergedStatisticsCollector));
            }
            catch (Exception e) {
                partitionFuture.completeExceptionally(e);
            }
            return partitionFuture;
        }

        for (int workerNo : workersForStage) {
            executorService.submit(() -> {
                // The fetch call itself is inside the try: anything it throws would otherwise be swallowed by the
                // executor's Future and leave partitionFuture incomplete forever.
                try {
                    final ClusterByStatisticsSnapshot clusterByStatisticsSnapshot =
                        workerClient.fetchClusterByStatisticsSnapshot(
                            workerTaskIds.get(workerNo),
                            stageDefinition.getId().getQueryId(),
                            stageDefinition.getStageNumber()
                        ).get();
                    if (clusterByStatisticsSnapshot == null) {
                        throw new ISE("Worker %s returned null sketch, this should never happen", workerNo);
                    }
                    synchronized (mergedStatisticsCollector) {
                        mergedStatisticsCollector.addAll(clusterByStatisticsSnapshot);
                        finishedWorkers.add(workerNo);
                        if (finishedWorkers.size() == workerCount) {
                            log.debug(
                                "Query [%s] Received all statistics, generating partitions",
                                stageDefinition.getId().getQueryId()
                            );
                            partitionFuture.complete(stageDefinition.generatePartitionsForShuffle(mergedStatisticsCollector));
                        }
                    }
                }
                catch (Exception e) {
                    failPartitionFuture(partitionFuture, mergedStatisticsCollector, e);
                }
            });
        }
        return partitionFuture;
    }

    /**
     * Merges statistics one time chunk at a time: for each chunk, fetches that chunk's snapshot from every worker
     * holding it, generates partitions for the chunk and appends them to the running result, then moves on to the
     * next chunk.
     */
    CompletableFuture<Either<Long, ClusterByPartitions>> sequentialTimeChunkMerging(
        CompleteKeyStatisticsInformation completeKeyStatisticsInformation,
        StageDefinition stageDefinition,
        List<String> workerTaskIds
    ) {
        final SequentialFetchStage sequentialFetchStage = new SequentialFetchStage(
            stageDefinition,
            workerTaskIds,
            completeKeyStatisticsInformation.getTimeSegmentVsWorkerMap().entrySet().iterator()
        );
        log.info(
            "Fetching stats using %s for stage[%d] for tasks[%s]",
            ClusterStatisticsMergeMode.SEQUENTIAL,
            stageDefinition.getStageNumber(),
            String.join(",", workerTaskIds)
        );
        sequentialFetchStage.submitFetchingTasksForNextTimeChunk();
        return sequentialFetchStage.getPartitionFuture();
    }

    /**
     * Returns the partition count carried by {@code either}: the error value itself, or the number of
     * partitions in the value.
     */
    private static long getPartitionCountFromEither(Either<Long, ClusterByPartitions> either) {
        if (either.isError()) {
            return either.error();
        }
        return either.valueOrThrow().size();
    }

    /**
     * Completes {@code partitionFuture} exceptionally with {@code e} unless it is already done, clearing the
     * merged collector in that case. Restores the thread's interrupt flag if {@code e} is an
     * {@link InterruptedException}.
     */
    private static void failPartitionFuture(
        CompletableFuture<Either<Long, ClusterByPartitions>> partitionFuture,
        ClusterByStatisticsCollector mergedStatisticsCollector,
        Exception e
    ) {
        if (e instanceof InterruptedException) {
            Thread.currentThread().interrupt();
        }
        synchronized (mergedStatisticsCollector) {
            if (!partitionFuture.isDone()) {
                partitionFuture.completeExceptionally(e);
                mergedStatisticsCollector.clear();
            }
        }
    }

    /** Shuts down the fetcher's thread pool, interrupting any fetches still running. */
    @Override
    public void close() {
        executorService.shutdownNow();
    }

    /**
     * State for one SEQUENTIAL merge: iterates time chunks, merging each chunk's statistics and accumulating the
     * resulting partition boundaries in {@link #finalPartitionBoundries}.
     *
     * <p>Chunks are processed strictly one after another. The next chunk's tasks are only submitted after the
     * current chunk's partitions were appended, and {@code ExecutorService.submit} orders that append before
     * the next tasks run.
     */
    private class SequentialFetchStage {
        private final StageDefinition stageDefinition;
        private final List<String> workerTaskIds;
        private final Iterator<Map.Entry<Long, Set<Integer>>> timeSegmentVsWorkerIdIterator;
        private final CompletableFuture<Either<Long, ClusterByPartitions>> partitionFuture = new CompletableFuture<>();
        private final List<ClusterByPartition> finalPartitionBoundries = new ArrayList<>();

        public SequentialFetchStage(
            StageDefinition stageDefinition,
            List<String> workerTaskIds,
            Iterator<Map.Entry<Long, Set<Integer>>> timeSegmentVsWorkerIdIterator
        ) {
            this.stageDefinition = stageDefinition;
            this.workerTaskIds = workerTaskIds;
            this.timeSegmentVsWorkerIdIterator = timeSegmentVsWorkerIdIterator;
        }

        /**
         * Submits fetch tasks for the next time chunk that has at least one worker. If no such chunk remains,
         * completes the future with all partitions gathered so far.
         */
        public void submitFetchingTasksForNextTimeChunk() {
            Map.Entry<Long, Set<Integer>> entry = null;
            while (entry == null && timeSegmentVsWorkerIdIterator.hasNext()) {
                final Map.Entry<Long, Set<Integer>> candidate = timeSegmentVsWorkerIdIterator.next();
                // A chunk with no workers would never finish and would stall the whole chain, so skip it.
                if (!candidate.getValue().isEmpty()) {
                    entry = candidate;
                }
            }
            if (entry == null) {
                partitionFuture.complete(Either.value(new ClusterByPartitions(finalPartitionBoundries)));
                return;
            }

            final Long timeChunk = entry.getKey();
            final Set<Integer> workerIdsWithTimeChunk = entry.getValue();
            final ClusterByStatisticsCollector mergedStatisticsCollector =
                stageDefinition.createResultKeyStatisticsCollector(statisticsMaxRetainedBytes);
            // Guarded by mergedStatisticsCollector.
            final Set<Integer> finishedWorkers = new HashSet<>();

            log.debug(
                "Query [%s]. Submitting request for statistics for time chunk %s to %s workers",
                stageDefinition.getId().getQueryId(),
                timeChunk,
                workerIdsWithTimeChunk.size()
            );

            for (int workerNo : workerIdsWithTimeChunk) {
                executorService.submit(() -> {
                    // Fetch inside the try so synchronous failures still complete partitionFuture.
                    try {
                        final ClusterByStatisticsSnapshot snapshotForTimeChunk =
                            workerClient.fetchClusterByStatisticsSnapshotForTimeChunk(
                                workerTaskIds.get(workerNo),
                                stageDefinition.getId().getQueryId(),
                                stageDefinition.getStageNumber(),
                                timeChunk
                            ).get();
                        if (snapshotForTimeChunk == null) {
                            throw new ISE(
                                "Worker %s returned null sketch for %s, this should never happen",
                                workerNo,
                                timeChunk
                            );
                        }
                        synchronized (mergedStatisticsCollector) {
                            mergedStatisticsCollector.addAll(snapshotForTimeChunk);
                            finishedWorkers.add(workerNo);
                            if (finishedWorkers.size() == workerIdsWithTimeChunk.size()) {
                                onTimeChunkMerged(timeChunk, mergedStatisticsCollector);
                            }
                        }
                    }
                    catch (Exception e) {
                        failPartitionFuture(partitionFuture, mergedStatisticsCollector, e);
                    }
                });
            }
        }

        /**
         * Called once every worker's snapshot for {@code timeChunk} has been merged. Generates the chunk's
         * partitions. It then either completes the future with an error count (the stage's maximum partition
         * count would be exceeded, or partition generation reported an error), or appends the partitions and
         * starts the next chunk.
         */
        private void onTimeChunkMerged(Long timeChunk, ClusterByStatisticsCollector mergedStatisticsCollector) {
            final Either<Long, ClusterByPartitions> chunkPartitions =
                stageDefinition.generatePartitionsForShuffle(mergedStatisticsCollector);
            log.debug(
                "Query [%s]. Received all statistics for time chunk %s, generating partitions",
                stageDefinition.getId().getQueryId(),
                timeChunk
            );
            final long totalPartitionCount = finalPartitionBoundries.size() + getPartitionCountFromEither(chunkPartitions);
            if (chunkPartitions.isError() || totalPartitionCount > stageDefinition.getMaxPartitionCount()) {
                partitionFuture.complete(Either.error(totalPartitionCount));
                mergedStatisticsCollector.clear();
                return;
            }
            abutAndAppendPartitionBoundries(finalPartitionBoundries, chunkPartitions.valueOrThrow().ranges());
            log.debug(
                "Query [%s]. Finished generating partitions for time chunk %s, total count so far %s",
                stageDefinition.getId().getQueryId(),
                timeChunk,
                finalPartitionBoundries.size()
            );
            submitFetchingTasksForNextTimeChunk();
        }

        /**
         * Appends {@code timeSketchPartitions} to {@code finalPartitionBoundries}. Before appending, the last
         * existing partition's end is replaced by the start of the first new partition so the two ranges abut.
         * An empty {@code timeSketchPartitions} leaves the list unchanged.
         */
        private void abutAndAppendPartitionBoundries(
            List<ClusterByPartition> finalPartitionBoundries,
            List<ClusterByPartition> timeSketchPartitions
        ) {
            if (timeSketchPartitions.isEmpty()) {
                return;
            }
            if (!finalPartitionBoundries.isEmpty()) {
                final ClusterByPartition lastPartition = finalPartitionBoundries.remove(finalPartitionBoundries.size() - 1);
                finalPartitionBoundries.add(
                    new ClusterByPartition(lastPartition.getStart(), timeSketchPartitions.get(0).getStart())
                );
            }
            finalPartitionBoundries.addAll(timeSketchPartitions);
        }

        public CompletableFuture<Either<Long, ClusterByPartitions>> getPartitionFuture() {
            return partitionFuture;
        }
    }
}

