/*
 * Decompiled with CFR 0.152.
 */
package org.apache.druid.sql.calcite.rule;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.Stack;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexSlot;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.query.LookupDataSource;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.DruidJoinQueryRel;
import org.apache.druid.sql.calcite.rel.DruidQueryRel;
import org.apache.druid.sql.calcite.rel.DruidRel;
import org.apache.druid.sql.calcite.rel.PartialDruidQuery;

public class DruidJoinRule
extends RelOptRule {
    private final boolean enableLeftScanDirect;
    private final PlannerContext plannerContext;

    private DruidJoinRule(PlannerContext plannerContext) {
        super(DruidJoinRule.operand(Join.class, (RelOptRuleOperand)DruidJoinRule.operand(DruidRel.class, (RelOptRuleOperandChildren)DruidJoinRule.any()), (RelOptRuleOperand[])new RelOptRuleOperand[]{DruidJoinRule.operand(DruidRel.class, (RelOptRuleOperandChildren)DruidJoinRule.any())}));
        this.enableLeftScanDirect = plannerContext.queryContext().getEnableJoinLeftScanDirect();
        this.plannerContext = plannerContext;
    }

    public static DruidJoinRule instance(PlannerContext plannerContext) {
        return new DruidJoinRule(plannerContext);
    }

    public boolean matches(RelOptRuleCall call) {
        Join join = (Join)call.rel(0);
        DruidRel left = (DruidRel)call.rel(1);
        DruidRel right = (DruidRel)call.rel(2);
        return this.canHandleCondition(join.getCondition(), join.getLeft().getRowType(), right) && left.getPartialDruidQuery() != null && right.getPartialDruidQuery() != null;
    }

    public void onMatch(RelOptRuleCall call) {
        DruidRel newRight;
        DruidRel newLeft;
        Filter leftFilter;
        boolean isLeftDirectAccessPossible;
        Join join = (Join)call.rel(0);
        DruidRel left = (DruidRel)call.rel(1);
        DruidRel right = (DruidRel)call.rel(2);
        RexBuilder rexBuilder = join.getCluster().getRexBuilder();
        ArrayList<Object> newProjectExprs = new ArrayList<Object>();
        ConditionAnalysis conditionAnalysis = this.analyzeCondition(join.getCondition(), join.getLeft().getRowType(), right).get();
        boolean bl = isLeftDirectAccessPossible = this.enableLeftScanDirect && left instanceof DruidQueryRel;
        if (left.getPartialDruidQuery().stage() == PartialDruidQuery.Stage.SELECT_PROJECT && (isLeftDirectAccessPossible || left.getPartialDruidQuery().getWhereFilter() == null)) {
            RelNode leftScan = left.getPartialDruidQuery().getScan();
            Project leftProject = left.getPartialDruidQuery().getSelectProject();
            leftFilter = left.getPartialDruidQuery().getWhereFilter();
            newProjectExprs.addAll(leftProject.getProjects());
            newLeft = left.withPartialQuery(PartialDruidQuery.create(leftScan));
            conditionAnalysis = conditionAnalysis.pushThroughLeftProject(leftProject);
        } else {
            for (int i = 0; i < left.getRowType().getFieldCount(); ++i) {
                newProjectExprs.add(rexBuilder.makeInputRef(((RelDataTypeField)join.getRowType().getFieldList().get(i)).getType(), i));
            }
            newLeft = left;
            leftFilter = null;
        }
        if (right.getPartialDruidQuery().stage() == PartialDruidQuery.Stage.SELECT_PROJECT && right.getPartialDruidQuery().getWhereFilter() == null && !right.getPartialDruidQuery().getSelectProject().isMapping() && conditionAnalysis.onlyUsesMappingsFromRightProject(right.getPartialDruidQuery().getSelectProject())) {
            RelNode rightScan = right.getPartialDruidQuery().getScan();
            Project rightProject = right.getPartialDruidQuery().getSelectProject();
            for (RexNode rexNode : RexUtil.shift((Iterable)rightProject.getProjects(), (int)newLeft.getRowType().getFieldCount())) {
                if (join.getJoinType().generatesNullsOnRight()) {
                    newProjectExprs.add(DruidJoinRule.makeNullableIfLiteral(rexNode, rexBuilder));
                    continue;
                }
                newProjectExprs.add(rexNode);
            }
            newRight = right.withPartialQuery(PartialDruidQuery.create(rightScan));
            conditionAnalysis = conditionAnalysis.pushThroughRightProject(rightProject);
        } else {
            for (int i = 0; i < right.getRowType().getFieldCount(); ++i) {
                newProjectExprs.add(rexBuilder.makeInputRef(((RelDataTypeField)join.getRowType().getFieldList().get(left.getRowType().getFieldCount() + i)).getType(), newLeft.getRowType().getFieldCount() + i));
            }
            newRight = right;
        }
        DruidJoinQueryRel druidJoin = DruidJoinQueryRel.create(join.copy(join.getTraitSet(), conditionAnalysis.getCondition(rexBuilder), (RelNode)newLeft, (RelNode)newRight, join.getJoinType(), join.isSemiJoinDone()), leftFilter, left.getPlannerContext());
        RelBuilder relBuilder = call.builder().push((RelNode)druidJoin).project((Iterable)RexUtil.fixUp((RexBuilder)rexBuilder, newProjectExprs, (List)RelOptUtil.getFieldTypeList((RelDataType)druidJoin.getRowType())));
        call.transformTo(relBuilder.build());
    }

    private static RexNode makeNullableIfLiteral(RexNode rexNode, RexBuilder rexBuilder) {
        if (rexNode.isA(SqlKind.LITERAL)) {
            return rexBuilder.makeLiteral((Object)RexLiteral.value((RexNode)rexNode), rexBuilder.getTypeFactory().createTypeWithNullability(rexNode.getType(), true), false);
        }
        return rexNode;
    }

    @VisibleForTesting
    boolean canHandleCondition(RexNode condition, RelDataType leftRowType, DruidRel<?> right) {
        return this.analyzeCondition(condition, leftRowType, right).isPresent();
    }

    private Optional<ConditionAnalysis> analyzeCondition(RexNode condition, RelDataType leftRowType, DruidRel<?> right) {
        long distinctRightColumns;
        DruidQueryRel druidQueryRel;
        List<RexNode> subConditions = DruidJoinRule.decomposeAnd(condition);
        ArrayList<Pair<RexNode, RexInputRef>> equalitySubConditions = new ArrayList<Pair<RexNode, RexInputRef>>();
        ArrayList<RexLiteral> literalSubConditions = new ArrayList<RexLiteral>();
        int numLeftFields = leftRowType.getFieldCount();
        HashSet<RexInputRef> rightColumns = new HashSet<RexInputRef>();
        for (RexNode subCondition : subConditions) {
            if (RexUtil.isLiteral((RexNode)subCondition, (boolean)true)) {
                if (subCondition.isA(SqlKind.CAST)) {
                    RexCall call = (RexCall)subCondition;
                    if (call.getType().getSqlTypeName().equals((Object)((RexNode)call.getOperands().get(0)).getType().getSqlTypeName())) {
                        literalSubConditions.add((RexLiteral)call.getOperands().get(0));
                        continue;
                    }
                    return Optional.empty();
                }
                literalSubConditions.add((RexLiteral)subCondition);
                continue;
            }
            if (!subCondition.isA(SqlKind.EQUALS)) {
                this.plannerContext.setPlanningError("SQL requires a join with '%s' condition that is not supported.", subCondition.getKind());
                return Optional.empty();
            }
            List operands = ((RexCall)subCondition).getOperands();
            Preconditions.checkState((operands.size() == 2 ? 1 : 0) != 0, (String)"Expected 2 operands, got[%,d]", (Object[])new Object[]{operands.size()});
            if (DruidJoinRule.isLeftExpression((RexNode)operands.get(0), numLeftFields) && DruidJoinRule.isRightInputRef((RexNode)operands.get(1), numLeftFields)) {
                equalitySubConditions.add((Pair<RexNode, RexInputRef>)Pair.of(operands.get(0), (Object)((RexInputRef)operands.get(1))));
                rightColumns.add((RexInputRef)operands.get(1));
                continue;
            }
            if (DruidJoinRule.isRightInputRef((RexNode)operands.get(0), numLeftFields) && DruidJoinRule.isLeftExpression((RexNode)operands.get(1), numLeftFields)) {
                equalitySubConditions.add((Pair<RexNode, RexInputRef>)Pair.of(operands.get(1), (Object)((RexInputRef)operands.get(0))));
                rightColumns.add((RexInputRef)operands.get(0));
                continue;
            }
            this.plannerContext.setPlanningError("SQL is resulting in a join that has unsupported operand types.", new Object[0]);
            return Optional.empty();
        }
        if (right != null && !DruidJoinQueryRel.computeRightRequiresSubquery(DruidJoinQueryRel.getSomeDruidChild(right)) && right instanceof DruidQueryRel && (druidQueryRel = (DruidQueryRel)right).getDruidTable().getDataSource() instanceof LookupDataSource && (distinctRightColumns = rightColumns.stream().map(RexSlot::getIndex).distinct().count()) > 1L) {
            this.plannerContext.setPlanningError("SQL is resulting in a join involving lookup where value column is used in the condition.", new Object[0]);
            return Optional.empty();
        }
        return Optional.of(new ConditionAnalysis(numLeftFields, equalitySubConditions, literalSubConditions));
    }

    @VisibleForTesting
    static List<RexNode> decomposeAnd(RexNode condition) {
        ArrayList<RexNode> retVal = new ArrayList<RexNode>();
        Stack<Object> stack = new Stack<Object>();
        stack.push(condition);
        while (!stack.empty()) {
            RexNode current = (RexNode)stack.pop();
            if (current.isA(SqlKind.AND)) {
                List operands = ((RexCall)current).getOperands();
                for (int i = operands.size() - 1; i >= 0; --i) {
                    stack.push(operands.get(i));
                }
                continue;
            }
            retVal.add(current);
        }
        return retVal;
    }

    private static boolean isLeftExpression(RexNode rexNode, int numLeftFields) {
        return ImmutableBitSet.range((int)numLeftFields).contains(RelOptUtil.InputFinder.bits((RexNode)rexNode));
    }

    private static boolean isRightInputRef(RexNode rexNode, int numLeftFields) {
        return rexNode.isA(SqlKind.INPUT_REF) && ((RexInputRef)rexNode).getIndex() >= numLeftFields;
    }

    @VisibleForTesting
    static class ConditionAnalysis {
        private final int numLeftFields;
        private final List<Pair<RexNode, RexInputRef>> equalitySubConditions;
        private final List<RexLiteral> literalSubConditions;

        ConditionAnalysis(int numLeftFields, List<Pair<RexNode, RexInputRef>> equalitySubConditions, List<RexLiteral> literalSubConditions) {
            this.numLeftFields = numLeftFields;
            this.equalitySubConditions = equalitySubConditions;
            this.literalSubConditions = literalSubConditions;
        }

        public ConditionAnalysis pushThroughLeftProject(Project leftProject) {
            int rhsShift = leftProject.getInput().getRowType().getFieldCount() - leftProject.getRowType().getFieldCount();
            return new ConditionAnalysis(leftProject.getInput().getRowType().getFieldCount(), this.equalitySubConditions.stream().map(equality -> Pair.of((Object)RelOptUtil.pushPastProject((RexNode)((RexNode)equality.lhs), (Project)leftProject), (Object)((RexInputRef)RexUtil.shift((RexNode)((RexNode)equality.rhs), (int)rhsShift)))).collect(Collectors.toList()), this.literalSubConditions);
        }

        public ConditionAnalysis pushThroughRightProject(Project rightProject) {
            Preconditions.checkArgument((boolean)this.onlyUsesMappingsFromRightProject(rightProject), (Object)"Cannot push through");
            return new ConditionAnalysis(this.numLeftFields, this.equalitySubConditions.stream().map(equality -> Pair.of((Object)equality.lhs, (Object)((RexInputRef)RexUtil.shift((RexNode)RelOptUtil.pushPastProject((RexNode)RexUtil.shift((RexNode)((RexNode)equality.rhs), (int)(-this.numLeftFields)), (Project)rightProject), (int)this.numLeftFields)))).collect(Collectors.toList()), this.literalSubConditions);
        }

        public boolean onlyUsesMappingsFromRightProject(Project rightProject) {
            for (Pair<RexNode, RexInputRef> equality : this.equalitySubConditions) {
                int rightIndex = ((RexInputRef)equality.rhs).getIndex() - this.numLeftFields;
                if (((RexNode)rightProject.getProjects().get(rightIndex)).isA(SqlKind.INPUT_REF)) continue;
                return false;
            }
            return true;
        }

        public RexNode getCondition(RexBuilder rexBuilder) {
            return RexUtil.composeConjunction((RexBuilder)rexBuilder, (Iterable)Iterables.concat(this.literalSubConditions, (Iterable)this.equalitySubConditions.stream().map(equality -> rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.EQUALS, new RexNode[]{(RexNode)equality.lhs, (RexNode)equality.rhs})).collect(Collectors.toList())), (boolean)false);
        }

        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            ConditionAnalysis that = (ConditionAnalysis)o;
            return Objects.equals(this.equalitySubConditions, that.equalitySubConditions) && Objects.equals(this.literalSubConditions, that.literalSubConditions);
        }

        public int hashCode() {
            return Objects.hash(this.equalitySubConditions, this.literalSubConditions);
        }

        public String toString() {
            return "ConditionAnalysis{equalitySubConditions=" + this.equalitySubConditions + ", literalSubConditions=" + this.literalSubConditions + '}';
        }
    }
}

