/*
 * Decompiled with CFR 0.152.
 */
package org.apache.druid.msq.exec;

import com.google.common.primitives.Ints;
import com.google.inject.Injector;
import it.unimi.dsi.fastutil.ints.IntSet;
import java.util.Objects;
import org.apache.druid.frame.processor.Bouncer;
import org.apache.druid.indexing.worker.config.WorkerConfig;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.msq.indexing.error.MSQException;
import org.apache.druid.msq.indexing.error.NotEnoughMemoryFault;
import org.apache.druid.msq.indexing.error.TooManyWorkersFault;
import org.apache.druid.msq.input.InputSpecs;
import org.apache.druid.msq.kernel.QueryDefinition;
import org.apache.druid.query.lookup.LookupExtractor;
import org.apache.druid.query.lookup.LookupExtractorFactoryContainer;
import org.apache.druid.query.lookup.LookupReferencesManager;
import org.apache.druid.segment.realtime.appenderator.AppenderatorsManager;
import org.apache.druid.segment.realtime.appenderator.UnifiedIndexerAppenderatorsManager;

/**
 * Memory-derived tuning parameters for an MSQ worker (the controller uses the same computation).
 *
 * <p>Usable heap is split into two parts:
 * <ul>
 *   <li>Per-worker memory, all of which is devoted to partition statistics.</li>
 *   <li>Per-bundle memory, divided equally among every worker and every processing thread in the JVM.
 *       Bundle memory funds input channels (one standard frame per input worker, plus one), the super-sorter,
 *       the appenderator and broadcast joins.</li>
 * </ul>
 */
public class WorkerMemoryParameters {
    private static final Logger log = new Logger(WorkerMemoryParameters.class);

    /** Fraction of (max heap minus lookup footprint) considered usable. */
    static final double USABLE_MEMORY_FRACTION = 0.75;
    /** Fraction of a bundle's processing memory given to the appenderator. */
    static final double APPENDERATOR_MEMORY_FRACTION = 0.67;
    private static final int STANDARD_FRAME_SIZE = 1000000;
    private static final int LARGE_FRAME_SIZE = 8000000;
    /** Minimum bundle memory (after input channels) required for processing. */
    public static final long PROCESSING_MINIMUM_BYTES = 25000000L;
    private static final int MAX_SUPER_SORTER_PROCESSORS = 4;
    private static final int MIN_SUPER_SORTER_FRAMES = 3;
    private static final int APPENDERATOR_MERGE_ROUGH_MEMORY_PER_COLUMN = 3000;
    private static final double PARTITION_STATS_MEMORY_MAX_FRACTION = 0.1;
    private static final long PARTITION_STATS_MEMORY_MAX_BYTES = 300000000L;
    /** Fraction of a bundle's processing memory given to broadcast joins. */
    static final double BROADCAST_JOIN_MEMORY_FRACTION = 0.3;
    /**
     * Upper bound on the worker count suggested in a {@link TooManyWorkersFault}.
     * NOTE(review): presumably mirrors the MSQ global worker limit (Limits.MAX_WORKERS) — confirm and reuse it.
     */
    private static final int MAX_REPORTED_WORKERS = 1000;

    private final int superSorterMaxActiveProcessors;
    private final int superSorterMaxChannelsPerProcessor;
    private final long appenderatorMemory;
    private final long broadcastJoinMemory;
    private final int partitionStatisticsMaxRetainedBytes;

    WorkerMemoryParameters(
            int superSorterMaxActiveProcessors,
            int superSorterMaxChannelsPerProcessor,
            long appenderatorMemory,
            long broadcastJoinMemory,
            int partitionStatisticsMaxRetainedBytes
    ) {
        this.superSorterMaxActiveProcessors = superSorterMaxActiveProcessors;
        this.superSorterMaxChannelsPerProcessor = superSorterMaxChannelsPerProcessor;
        this.appenderatorMemory = appenderatorMemory;
        this.broadcastJoinMemory = broadcastJoinMemory;
        this.partitionStatisticsMaxRetainedBytes = partitionStatisticsMaxRetainedBytes;
    }

    /**
     * Creates parameters for the controller, using this JVM's memory, worker and processor configuration.
     * The controller reads no input channels, so {@code numInputWorkers} is 0.
     */
    public static WorkerMemoryParameters createProductionInstanceForController(Injector injector) {
        return createInstance(
                Runtime.getRuntime().maxMemory(),
                computeUsableMemoryInJvm(injector),
                computeNumWorkersInJvm(injector),
                computeNumProcessorsInJvm(injector),
                0
        );
    }

    /**
     * Creates parameters for a worker running stage {@code stageNumber} of {@code queryDef}.
     * The number of input workers is the sum of the max worker counts of the stage's input stages.
     */
    public static WorkerMemoryParameters createProductionInstanceForWorker(Injector injector, QueryDefinition queryDef, int stageNumber) {
        IntSet inputStageNumbers = InputSpecs.getStageNumbers(queryDef.getStageDefinition(stageNumber).getInputSpecs());
        int numInputWorkers = inputStageNumbers.intStream()
                .map(inputStageNumber -> queryDef.getStageDefinition(inputStageNumber).getMaxWorkerCount())
                .sum();
        return createInstance(
                Runtime.getRuntime().maxMemory(),
                computeUsableMemoryInJvm(injector),
                computeNumWorkersInJvm(injector),
                computeNumProcessorsInJvm(injector),
                numInputWorkers
        );
    }

    /**
     * Computes memory parameters from explicit inputs.
     *
     * @param maxMemoryInJvm            max heap; only reported in {@link NotEnoughMemoryFault}
     * @param usableMemoryInJvm         memory available for MSQ work
     * @param numWorkersInJvm           number of workers sharing this JVM
     * @param numProcessingThreadsInJvm number of processing threads in this JVM
     * @param numInputWorkers           number of workers feeding input channels to this one
     * @throws MSQException with {@link TooManyWorkersFault} if bundle memory left after input channels is below
     *                      {@link #PROCESSING_MINIMUM_BYTES} but fewer input workers would fit; with
     *                      {@link NotEnoughMemoryFault} if not even one would fit, or if a bundle cannot hold
     *                      {@code MIN_SUPER_SORTER_FRAMES} large frames
     */
    public static WorkerMemoryParameters createInstance(long maxMemoryInJvm, long usableMemoryInJvm, int numWorkersInJvm, int numProcessingThreadsInJvm, int numInputWorkers) {
        long workerMemory = memoryPerWorker(usableMemoryInJvm, numWorkersInJvm);
        long bundleMemory = memoryPerBundle(usableMemoryInJvm, numWorkersInJvm, numProcessingThreadsInJvm);
        long bundleMemoryForInputChannels = memoryNeededForInputChannels(numInputWorkers);
        long bundleMemoryForProcessing = bundleMemory - bundleMemoryForInputChannels;

        if (bundleMemoryForProcessing < PROCESSING_MINIMUM_BYTES) {
            int maxWorkers = computeMaxWorkers(usableMemoryInJvm, numWorkersInJvm, numProcessingThreadsInJvm);
            if (maxWorkers > 0) {
                throw new MSQException(new TooManyWorkersFault(numInputWorkers, Math.min(MAX_REPORTED_WORKERS, maxWorkers)));
            }
            // Not even one input worker fits, so this is a memory shortage rather than a worker-count problem.
            throw new MSQException(new NotEnoughMemoryFault(maxMemoryInJvm, usableMemoryInJvm, numWorkersInJvm, numProcessingThreadsInJvm));
        }

        // The super-sorter works in large frames carved out of the whole bundle.
        int maxNumFramesForSuperSorter = Ints.checkedCast(bundleMemory / LARGE_FRAME_SIZE);
        // Defensive: with the current constants and non-negative numInputWorkers, passing the check above already
        // implies at least MIN_SUPER_SORTER_FRAMES large frames per bundle.
        if (maxNumFramesForSuperSorter < MIN_SUPER_SORTER_FRAMES) {
            throw new MSQException(new NotEnoughMemoryFault(maxMemoryInJvm, usableMemoryInJvm, numWorkersInJvm, numProcessingThreadsInJvm));
        }

        int superSorterMaxActiveProcessors = Math.min(
                numProcessingThreadsInJvm,
                Math.min(maxNumFramesForSuperSorter / MIN_SUPER_SORTER_FRAMES, MAX_SUPER_SORTER_PROCESSORS)
        );
        // Split frames equally among active processors, reserving one per processor (minus 1 below).
        // NOTE(review): divides by zero if numProcessingThreadsInJvm is 0; production callers presumably pass >= 1 — confirm.
        int superSorterMaxChannelsPerProcessor = maxNumFramesForSuperSorter / superSorterMaxActiveProcessors - 1;

        return new WorkerMemoryParameters(
                superSorterMaxActiveProcessors,
                superSorterMaxChannelsPerProcessor,
                (long) (bundleMemoryForProcessing * APPENDERATOR_MEMORY_FRACTION),
                (long) (bundleMemoryForProcessing * BROADCAST_JOIN_MEMORY_FRACTION),
                // All per-worker memory goes to partition statistics; it is capped at
                // PARTITION_STATS_MEMORY_MAX_BYTES, so it fits in an int.
                Ints.checkedCast(workerMemory)
        );
    }

    public int getSuperSorterMaxActiveProcessors() {
        return this.superSorterMaxActiveProcessors;
    }

    public int getSuperSorterMaxChannelsPerProcessor() {
        return this.superSorterMaxChannelsPerProcessor;
    }

    /** Returns half of the appenderator memory, but at least 1. */
    public long getAppenderatorMaxBytesInMemory() {
        return Math.max(1L, this.appenderatorMemory / 2L);
    }

    /**
     * Returns half of the appenderator memory divided by a rough per-column merge cost, but at least 2.
     */
    public int getAppenderatorMaxColumnsToMerge() {
        return Ints.checkedCast(Math.max(2L, this.appenderatorMemory / 2L / APPENDERATOR_MERGE_ROUGH_MEMORY_PER_COLUMN));
    }

    public int getStandardFrameSize() {
        return STANDARD_FRAME_SIZE;
    }

    public int getLargeFrameSize() {
        return LARGE_FRAME_SIZE;
    }

    public long getBroadcastJoinMemory() {
        return this.broadcastJoinMemory;
    }

    public int getPartitionStatisticsMaxRetainedBytes() {
        return this.partitionStatisticsMaxRetainedBytes;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        WorkerMemoryParameters that = (WorkerMemoryParameters) o;
        return this.superSorterMaxActiveProcessors == that.superSorterMaxActiveProcessors
                && this.superSorterMaxChannelsPerProcessor == that.superSorterMaxChannelsPerProcessor
                && this.appenderatorMemory == that.appenderatorMemory
                && this.broadcastJoinMemory == that.broadcastJoinMemory
                && this.partitionStatisticsMaxRetainedBytes == that.partitionStatisticsMaxRetainedBytes;
    }

    @Override
    public int hashCode() {
        return Objects.hash(
                this.superSorterMaxActiveProcessors,
                this.superSorterMaxChannelsPerProcessor,
                this.appenderatorMemory,
                this.broadcastJoinMemory,
                this.partitionStatisticsMaxRetainedBytes
        );
    }

    @Override
    public String toString() {
        return "WorkerMemoryParameters{superSorterMaxActiveProcessors=" + this.superSorterMaxActiveProcessors
                + ", superSorterMaxChannelsPerProcessor=" + this.superSorterMaxChannelsPerProcessor
                + ", appenderatorMemory=" + this.appenderatorMemory
                + ", broadcastJoinMemory=" + this.broadcastJoinMemory
                + ", partitionStatisticsMaxRetainedBytes=" + this.partitionStatisticsMaxRetainedBytes
                + '}';
    }

    /**
     * Returns the largest number of input workers for which bundle memory minus input channels stays at or above
     * {@link #PROCESSING_MINIMUM_BYTES}, or 0 if none fits.
     */
    static int computeMaxWorkers(long usableMemoryInJvm, int numWorkersInJvm, int numProcessingThreadsInJvm) {
        long bundleMemory = memoryPerBundle(usableMemoryInJvm, numWorkersInJvm, numProcessingThreadsInJvm);
        // Solve bundleMemory - STANDARD_FRAME_SIZE * (n + 1) >= PROCESSING_MINIMUM_BYTES for n.
        return Math.max(0, Ints.checkedCast((bundleMemory - PROCESSING_MINIMUM_BYTES) / STANDARD_FRAME_SIZE - 1L));
    }

    /** Returns the worker capacity with the unified indexer appenderators manager, otherwise 1. */
    private static int computeNumWorkersInJvm(Injector injector) {
        AppenderatorsManager appenderatorsManager = injector.getInstance(AppenderatorsManager.class);
        if (appenderatorsManager instanceof UnifiedIndexerAppenderatorsManager) {
            return injector.getInstance(WorkerConfig.class).getCapacity();
        }
        return 1;
    }

    /** Returns the max count of the injected {@link Bouncer}, used as the processing thread count. */
    private static int computeNumProcessorsInJvm(Injector injector) {
        return injector.getInstance(Bouncer.class).getMaxCount();
    }

    /**
     * Returns per-worker memory: the smaller of {@code PARTITION_STATS_MEMORY_MAX_FRACTION} of usable memory and
     * {@code PARTITION_STATS_MEMORY_MAX_BYTES} per worker, divided by the worker count.
     */
    private static long memoryPerWorker(long usableMemoryInJvm, int numWorkersInJvm) {
        long memoryForWorkers = (long) Math.min(
                usableMemoryInJvm * PARTITION_STATS_MEMORY_MAX_FRACTION,
                (double) ((long) numWorkersInJvm * PARTITION_STATS_MEMORY_MAX_BYTES)
        );
        return memoryForWorkers / numWorkersInJvm;
    }

    /**
     * Returns memory per bundle: usable memory minus all per-worker memory, divided equally among workers and
     * processing threads.
     */
    private static long memoryPerBundle(long usableMemoryInJvm, int numWorkersInJvm, int numProcessingThreadsInJvm) {
        int bundleCount = numWorkersInJvm + numProcessingThreadsInJvm;
        long memoryForWorkers = (long) numWorkersInJvm * memoryPerWorker(usableMemoryInJvm, numWorkersInJvm);
        long memoryForBundles = usableMemoryInJvm - memoryForWorkers;
        return memoryForBundles / bundleCount;
    }

    /** Returns one standard frame per input worker, plus one. */
    private static long memoryNeededForInputChannels(int numInputWorkers) {
        // Widen before adding so numInputWorkers + 1 cannot overflow int.
        return (long) STANDARD_FRAME_SIZE * ((long) numInputWorkers + 1L);
    }

    /** Returns {@link #USABLE_MEMORY_FRACTION} of the max heap remaining after the estimated lookup footprint. */
    private static long computeUsableMemoryInJvm(Injector injector) {
        long maxMemory = Runtime.getRuntime().maxMemory();
        return (long) ((maxMemory - computeTotalLookupFootprint(injector)) * USABLE_MEMORY_FRACTION);
    }

    /**
     * Sums {@code estimateHeapFootprint()} over all lookups known to the {@link LookupReferencesManager}.
     * Lookups that are absent or fail to load are skipped (a failure is logged as a warning).
     */
    private static long computeTotalLookupFootprint(Injector injector) {
        LookupReferencesManager lookupManager = injector.getInstance(LookupReferencesManager.class);
        int lookupCount = 0;
        long lookupFootprint = 0L;
        for (String lookupName : lookupManager.getAllLookupNames()) {
            LookupExtractorFactoryContainer container = lookupManager.get(lookupName).orElse(null);
            if (container == null) {
                continue;
            }
            try {
                LookupExtractor extractor = container.getLookupExtractorFactory().get();
                lookupFootprint += extractor.estimateHeapFootprint();
                lookupCount++;
            }
            catch (Exception e) {
                // Best effort: leave this lookup out of the estimate instead of propagating the failure.
                log.noStackTrace().warn(e, "Failed to load lookup [%s] for size estimation. Skipping.", lookupName);
            }
        }
        log.debug("Lookup footprint: %d lookups with %,d total bytes.", lookupCount, lookupFootprint);
        return lookupFootprint;
    }
}

