/*
 * Decompiled with CFR 0.152.
 */
package org.apache.druid.msq.statistics;

import com.google.common.math.LongMath;
import com.google.common.primitives.Ints;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.NoSuchElementException;
import javax.annotation.Nullable;
import org.apache.datasketches.quantiles.ItemsSketch;
import org.apache.datasketches.quantiles.ItemsUnion;
import org.apache.druid.frame.key.ClusterByPartition;
import org.apache.druid.frame.key.ClusterByPartitions;
import org.apache.druid.frame.key.RowKey;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.msq.statistics.KeyCollector;

/**
 * {@link KeyCollector} backed by a DataSketches quantiles {@link ItemsSketch} over serialized row keys
 * ({@code byte[]}), ordered by {@link #comparator}.
 *
 * <p>Alongside the sketch, this collector maintains {@code averageKeyLength}: the average estimated key size in
 * bytes, weighted by the sketch's total count {@code N}. It is used by {@link #estimatedRetainedBytes()} to
 * estimate the footprint of the items the sketch retains.
 *
 * <p>NOTE(review): no synchronization is visible here; presumably instances are confined to a single thread —
 * confirm against callers.
 */
public class QuantilesSketchKeyCollector
implements KeyCollector<QuantilesSketchKeyCollector> {
    private final Comparator<byte[]> comparator;
    private ItemsSketch<byte[]> sketch;
    private double averageKeyLength;

    /**
     * @param comparator       ordering of serialized keys; also used to build the union in {@link #addAll}
     * @param sketch           backing sketch. NOTE(review): declared {@code @Nullable}, yet every method below
     *                         dereferences it unconditionally — callers presumably always pass a non-null
     *                         sketch; confirm.
     * @param averageKeyLength initial average key size in bytes, weighted by {@code sketch.getN()}
     */
    QuantilesSketchKeyCollector(Comparator<byte[]> comparator, @Nullable ItemsSketch<byte[]> sketch, double averageKeyLength) {
        this.comparator = comparator;
        this.sketch = sketch;
        this.averageKeyLength = averageKeyLength;
    }

    /**
     * Inserts {@code key} into the sketch {@code weight} times and updates the running average key size.
     *
     * <p>A non-positive weight inserts nothing and therefore leaves the collector unchanged.
     *
     * @param key    key to add
     * @param weight number of times the key is inserted
     */
    @Override
    public void add(RowKey key, long weight) {
        if (weight <= 0L) {
            // Nothing goes into the sketch, so the average must not move either. Previously a zero weight on an
            // empty sketch yielded 0/0 = NaN, and a negative weight shrank the byte total without removing items.
            return;
        }
        double estimatedTotalSketchSizeInBytes = this.averageKeyLength * (double) this.sketch.getN();
        estimatedTotalSketchSizeInBytes += (double) ((long) key.estimatedObjectSizeBytes() * weight);
        final byte[] keyBytes = key.array();
        // Insert the same key repeatedly to make it "heavier". A long counter is required: an int counter would
        // overflow (and never terminate) for weights above Integer.MAX_VALUE.
        for (long i = 0L; i < weight; i++) {
            this.sketch.update(keyBytes);
        }
        // weight > 0 guarantees getN() > 0 here, so the division is well defined.
        this.averageKeyLength = estimatedTotalSketchSizeInBytes / (double) this.sketch.getN();
    }

    /**
     * Merges {@code other} into this collector. The sketch is replaced by the union of both sketches, built with
     * the larger of the two {@code k} values. The average key size becomes the N-weighted mean of both averages.
     *
     * @param other collector to merge in; its own sketch is read but not modified by this method
     */
    @Override
    public void addAll(QuantilesSketchKeyCollector other) {
        final long thisN = this.sketch.getN();
        final long otherN = other.sketch.getN();
        final long combinedN = thisN + otherN;
        if (combinedN > 0L) {
            final double sketchBytesCount = this.averageKeyLength * (double) thisN;
            final double otherBytesCount = other.averageKeyLength * (double) otherN;
            this.averageKeyLength = (sketchBytesCount + otherBytesCount) / (double) combinedN;
        }
        // else: both sides are empty; keep the current average instead of computing 0/0 = NaN.

        final ItemsUnion<byte[]> union =
            ItemsUnion.getInstance(Math.max(this.sketch.getK(), other.sketch.getK()), this.comparator);
        union.update(this.sketch);
        union.update(other.sketch);
        this.sketch = union.getResultAndReset();
    }

    /** @return whether the backing sketch reports itself empty */
    @Override
    public boolean isEmpty() {
        return this.sketch.isEmpty();
    }

    /** @return the sketch's total count {@code N}, i.e. the sum of all weights added */
    @Override
    public long estimatedTotalWeight() {
        return this.sketch.getN();
    }

    /** @return the average key size multiplied by the number of retained items, rounded to a long */
    @Override
    public long estimatedRetainedBytes() {
        return Math.round(this.averageKeyLength * (double) this.estimatedRetainedKeys());
    }

    /** @return number of items currently retained by the sketch */
    @Override
    public int estimatedRetainedKeys() {
        return this.sketch.getRetainedItems();
    }

    /**
     * Halves the sketch's {@code k} to reduce its footprint.
     *
     * @return {@code true} if the sketch holds at most one item (nothing to shrink) or was downsampled;
     *         {@code false} if {@code k} is already 2 and cannot be reduced further
     */
    @Override
    public boolean downSample() {
        if (this.sketch.getN() <= 1L) {
            return true;
        }
        if (this.sketch.getK() == 2) {
            return false;
        }
        this.sketch = this.sketch.downSample(this.sketch.getK() / 2);
        return true;
    }

    /**
     * @return the smallest key seen by the sketch
     * @throws NoSuchElementException if the sketch reports no minimum value
     */
    @Override
    public RowKey minKey() {
        final byte[] minValue = this.sketch.getMinValue();
        if (minValue == null) {
            throw new NoSuchElementException();
        }
        return RowKey.wrap(minValue);
    }

    /**
     * Splits the key space into roughly {@code ceil(N / targetWeight)} contiguous partitions using evenly spaced
     * sketch quantiles as boundaries. The final partition is open-ended (null end). Partitions whose start and end
     * quantiles compare as non-increasing are treated as empty and skipped.
     *
     * @param targetWeight desired weight per partition; must be positive
     * @return one universal partition if the sketch is empty, otherwise the computed partitions
     * @throws IAE if {@code targetWeight} is not positive
     */
    @Override
    public ClusterByPartitions generatePartitionsWithTargetWeight(long targetWeight) {
        if (targetWeight <= 0L) {
            throw new IAE("targetPartitionWeight must be positive, but was [%d]", targetWeight);
        }
        if (this.sketch.getN() == 0L) {
            return ClusterByPartitions.oneUniversalPartition();
        }
        final int numPartitions =
            Ints.checkedCast(LongMath.divide(this.sketch.getN(), targetWeight, RoundingMode.CEILING));
        // numPartitions + 1 points, because the last quantile (the max) is never used as a partition start.
        final byte[][] quantiles = this.sketch.getQuantiles(numPartitions + 1);
        final ArrayList<ClusterByPartition> partitions = new ArrayList<>();
        for (int i = 0; i < numPartitions; ++i) {
            if (i == numPartitions - 1) {
                partitions.add(new ClusterByPartition(RowKey.wrap(quantiles[i]), null));
            } else if (this.comparator.compare(quantiles[i], quantiles[i + 1]) < 0) {
                partitions.add(new ClusterByPartition(RowKey.wrap(quantiles[i]), RowKey.wrap(quantiles[i + 1])));
            }
            // else: equal adjacent quantiles would form an empty partition; skip it.
        }
        return new ClusterByPartitions(partitions);
    }

    /** @return the backing sketch (package-private, for callers such as serializers/tests) */
    ItemsSketch<byte[]> getSketch() {
        return this.sketch;
    }

    /** @return the current N-weighted average key size in bytes */
    double getAverageKeyLength() {
        return this.averageKeyLength;
    }
}

