/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.spark;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Stack;
import org.apache.hadoop.hive.ql.exec.CommonJoinOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.Task;
import org.apache.hadoop.hive.ql.exec.spark.SparkTask;
import org.apache.hadoop.hive.ql.lib.DefaultGraphWalker;
import org.apache.hadoop.hive.ql.lib.DefaultRuleDispatcher;
import org.apache.hadoop.hive.ql.lib.Dispatcher;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.PreOrderWalker;
import org.apache.hadoop.hive.ql.lib.Rule;
import org.apache.hadoop.hive.ql.lib.RuleRegExp;
import org.apache.hadoop.hive.ql.optimizer.physical.PhysicalContext;
import org.apache.hadoop.hive.ql.optimizer.physical.PhysicalPlanResolver;
import org.apache.hadoop.hive.ql.optimizer.physical.SkewJoinResolver;
import org.apache.hadoop.hive.ql.optimizer.spark.SparkSkewJoinProcFactory;
import org.apache.hadoop.hive.ql.parse.ParseContext;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.ReduceWork;
import org.apache.hadoop.hive.ql.plan.SparkWork;

/**
 * Physical-plan resolver that applies the Spark skew-join rewrite.
 *
 * <p>Starting from the plan's root tasks, it walks the task graph and, for every
 * {@link SparkTask} it reaches, walks the reducer operator trees of that task's
 * {@link SparkWork}. Operators matched by the {@link CommonJoinOperator} rule are handed to
 * {@link SparkSkewJoinProcFactory#getJoinProc()}; all other operators go to
 * {@link SparkSkewJoinProcFactory#getDefaultProc()}.
 */
public class SparkSkewJoinResolver
implements PhysicalPlanResolver {

    /**
     * Runs the skew-join rewrite over every task reachable from the plan's root tasks.
     *
     * @param pctx physical context whose root tasks are walked
     * @return the same {@code pctx} instance that was passed in
     * @throws SemanticException if the walk or any processor it invokes fails
     */
    @Override
    public PhysicalContext resolve(PhysicalContext pctx) throws SemanticException {
        // Forget join operators recorded as visited by the factory (accessed statically)
        // before starting a new walk.
        SparkSkewJoinProcFactory.getVisitedJoinOp().clear();
        Dispatcher disp = new SparkSkewJoinTaskDispatcher(pctx);
        // NOTE(review): a pre-order walker is used, presumably because processing a task
        // may split it into new tasks — confirm against SparkSkewJoinProcFactory.
        PreOrderWalker ogw = new PreOrderWalker(disp);
        ArrayList<Node> topNodes = new ArrayList<Node>();
        topNodes.addAll(pctx.getRootTasks());
        ogw.startWalking(topNodes, null);
        return pctx;
    }

    /**
     * Processing context for the Spark skew-join operator walk.
     *
     * <p>In addition to the state held by {@link SkewJoinResolver.SkewJoinProcCtx}, it
     * records which {@link ReduceWork} each reducer operator was taken from.
     */
    public static class SparkSkewJoinProcCtx
    extends SkewJoinResolver.SkewJoinProcCtx {
        /** Maps each reducer root operator to the {@link ReduceWork} that owns it. */
        private final Map<Operator<?>, ReduceWork> reducerToReduceWork =
                new HashMap<Operator<?>, ReduceWork>();

        /**
         * @param task the task currently being processed
         * @param parseCtx parse context of the query being optimized
         */
        public SparkSkewJoinProcCtx(Task<? extends Serializable> task, ParseContext parseCtx) {
            super(task, parseCtx);
        }

        /**
         * Returns the live (mutable) reducer-to-{@link ReduceWork} map.
         *
         * @return the map backing this context; callers may add entries to it
         */
        public Map<Operator<?>, ReduceWork> getReducerToReduceWork() {
            return this.reducerToReduceWork;
        }
    }

    /**
     * Task-graph dispatcher. For each {@link SparkTask} it runs an operator-level walk over
     * the reducers of the task's {@link SparkWork}; every other task type is ignored.
     */
    class SparkSkewJoinTaskDispatcher
    implements Dispatcher {
        private PhysicalContext physicalContext;

        /**
         * @param context physical context supplying the parse context for each task
         */
        public SparkSkewJoinTaskDispatcher(PhysicalContext context) {
            this.physicalContext = context;
        }

        /**
         * Applies the skew-join operator walk to {@code nd} when it is a {@link SparkTask}.
         *
         * @param nd the task node being visited
         * @param stack the walker's current node stack (unused)
         * @param nodeOutputs outputs of already-visited children (unused)
         * @return always {@code null}
         * @throws SemanticException if the operator walk fails
         */
        @Override
        public Object dispatch(Node nd, Stack<Node> stack, Object ... nodeOutputs) throws SemanticException {
            if (!(nd instanceof SparkTask)) {
                return null;
            }
            SparkTask sparkTask = (SparkTask) nd;
            SparkWork sparkWork = (SparkWork) sparkTask.getWork();
            SparkSkewJoinProcCtx skewJoinProcCtx =
                    new SparkSkewJoinProcCtx(sparkTask, this.physicalContext.getParseContext());

            // Join operators go to the join processor; everything else to the default one.
            Map<Rule, NodeProcessor> opRules = new LinkedHashMap<Rule, NodeProcessor>();
            opRules.put(new RuleRegExp("R1", CommonJoinOperator.getOperatorName() + "%"),
                    SparkSkewJoinProcFactory.getJoinProc());
            Dispatcher ruleDispatcher = new DefaultRuleDispatcher(
                    SparkSkewJoinProcFactory.getDefaultProc(), opRules, skewJoinProcCtx);
            DefaultGraphWalker operatorWalker = new DefaultGraphWalker(ruleDispatcher);

            // Reverse a private copy so the list returned by SparkWork is never mutated.
            List<ReduceWork> reduceWorkList = new ArrayList<ReduceWork>(sparkWork.getAllReduceWork());
            // NOTE(review): reversed presumably so downstream reducers are walked first
            // (the rewrite may split the task) — confirm the order of getAllReduceWork().
            Collections.reverse(reduceWorkList);

            ArrayList<Node> topNodes = new ArrayList<Node>(reduceWorkList.size());
            Map<Operator<?>, ReduceWork> reducerToReduceWork = skewJoinProcCtx.getReducerToReduceWork();
            for (ReduceWork reduceWork : reduceWorkList) {
                Operator<?> reducer = reduceWork.getReducer();
                topNodes.add(reducer);
                reducerToReduceWork.put(reducer, reduceWork);
            }
            operatorWalker.startWalking(topNodes, null);
            return null;
        }

        public PhysicalContext getPhysicalContext() {
            return this.physicalContext;
        }

        public void setPhysicalContext(PhysicalContext physicalContext) {
            this.physicalContext = physicalContext;
        }
    }
}

