/*
 * Decompiled with CFR 0.152.
 */
package com.linkedin.tony;

import com.google.common.annotations.VisibleForTesting;
import com.linkedin.tony.TonyConfigurationKeys;
import com.linkedin.tony.tensorflow.JobContainerRequest;
import com.linkedin.tony.tensorflow.TonySession;
import com.linkedin.tony.util.Utils;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.yarn.api.records.FinalApplicationStatus;
import org.apache.hadoop.yarn.api.records.LocalResource;
import org.apache.hadoop.yarn.api.records.Priority;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.client.api.AMRMClient;
import org.apache.hadoop.yarn.client.api.async.AMRMClientAsync;

/**
 * Turns the job types of a {@link TonySession} into YARN container requests while honoring
 * the {@code dependsOn} ordering between job types.
 *
 * <p>A job type with no outstanding dependencies is requested from the ResourceManager as soon
 * as {@link #scheduleTasks()} runs. A job type that depends on others is held back. It is
 * requested once {@link #registerDependencyCompleted(String)} has been called once per instance
 * of every job type it depends on.
 *
 * <p>{@link #scheduleTasks()} and {@link #registerDependencyCompleted(String)} are both
 * {@code synchronized} on this instance, because both end up mutating the same scheduling maps.
 */
public class TaskScheduler {
    private static final Log LOG = LogFactory.getLog(TaskScheduler.class);

    private final TonySession session;
    private final AMRMClientAsync<AMRMClient.ContainerRequest> amRMClient;
    private final FileSystem resourceFs;
    private final Configuration tonyConf;
    /** Waiting job request -> (job type it depends on -> completions of that type still outstanding). */
    private final Map<JobContainerRequest, Map<String, Integer>> taskDependencyMap = new HashMap<>();
    /** Resources copied into every job type's resource map by {@link #getContainerResources(String)}. */
    private final Map<String, LocalResource> localResources;
    /** Job type -> container asks already handed to the AMRM client for that type. */
    private final Map<String, List<AMRMClient.ContainerRequest>> jobTypeToContainerRequestsMap = new HashMap<>();
    /** Map supplied by the caller; filled with job type -> resources the first time a type is scheduled. */
    private final Map<String, Map<String, LocalResource>> jobTypeToContainerResources;
    /** Cleared when the dependency graph is rejected; the session is then marked FAILED. */
    boolean dependencyCheckPassed = true;

    public TaskScheduler(TonySession session, AMRMClientAsync<AMRMClient.ContainerRequest> amRMClient, Map<String, LocalResource> localResources, FileSystem resourceFs, Configuration tonyConf, Map<String, Map<String, LocalResource>> jobTypeToContainerResources) {
        this.session = session;
        this.amRMClient = amRMClient;
        this.localResources = localResources;
        this.resourceFs = resourceFs;
        this.tonyConf = tonyConf;
        this.jobTypeToContainerResources = jobTypeToContainerResources;
    }

    /**
     * Validates the dependency graph and requests containers for every job type that has no
     * outstanding dependencies.
     *
     * <p>If the graph contains a cycle, or a {@code dependsOn} entry names a job type the session
     * has no request for, this method takes three actions and returns. It sets the session's
     * final status to FAILED, clears {@link #dependencyCheckPassed}, and requests no containers.
     */
    public synchronized void scheduleTasks() {
        List<JobContainerRequest> requests = session.getContainersRequests();
        if (!isDAG(requests)) {
            failDependencyCheck("TonY execution graph does not form a DAG, exiting.",
                "App failed due to it not being a DAG.");
            return;
        }
        if (!buildTaskDependencyGraph(requests)) {
            // buildTaskDependencyGraph has already logged and recorded the failure.
            return;
        }
        for (JobContainerRequest request : requests) {
            if (checkDependencySatisfied(request)) {
                scheduleJob(request);
            }
        }
    }

    /** Logs {@code logMessage}, marks the session FAILED with {@code finalStatusMessage}, and clears the check flag. */
    private void failDependencyCheck(String logMessage, String finalStatusMessage) {
        LOG.error(logMessage);
        session.setFinalStatus(FinalApplicationStatus.FAILED, finalStatusMessage);
        dependencyCheckPassed = false;
    }

    /**
     * Populates {@link #taskDependencyMap}. For each request, it records how many instances of
     * each job type the request depends on must complete before the request may be scheduled.
     *
     * @return {@code false} (after failing the dependency check) if a dependency names a job type
     *     the session has no request for; {@code true} otherwise
     */
    private boolean buildTaskDependencyGraph(List<JobContainerRequest> requests) {
        for (JobContainerRequest request : requests) {
            for (String dependsOn : request.getDependsOn()) {
                if (dependsOn.isEmpty()) {
                    continue; // blank entries are ignored
                }
                JobContainerRequest dependency = session.getContainerRequestForType(dependsOn);
                if (dependency == null) {
                    // isDAG() only follows names present in the request list, so an unknown name
                    // gets this far; treat it as a configuration error rather than dereferencing null.
                    failDependencyCheck(
                        "Job type '" + request.getJobName() + "' depends on unknown job type '"
                            + dependsOn + "', exiting.",
                        "App failed due to a dependency on unknown job type " + dependsOn + ".");
                    return false;
                }
                int instances = dependency.getNumInstances();
                if (instances <= 0) {
                    // No instance of this type will ever complete, so waiting on it would block
                    // the dependent job forever; the dependency is trivially satisfied.
                    continue;
                }
                taskDependencyMap.computeIfAbsent(request, key -> new HashMap<>()).put(dependsOn, instances);
            }
        }
        return true;
    }

    /** Returns {@code true} when {@code request} has no outstanding dependency counts. */
    @VisibleForTesting
    boolean checkDependencySatisfied(JobContainerRequest request) {
        Map<String, Integer> outstanding = taskDependencyMap.get(request);
        return outstanding == null || outstanding.isEmpty();
    }

    /**
     * Submits {@code request.getNumInstances()} copies of the container ask for {@code request}.
     * It also computes the job type's resources on first use, and adds the instance count to the
     * session's expected-task total.
     */
    private void scheduleJob(JobContainerRequest request) {
        AMRMClient.ContainerRequest containerAsk = setupContainerRequestForRM(request);
        String jobName = request.getJobName();
        if (!jobTypeToContainerRequestsMap.containsKey(jobName)) {
            jobTypeToContainerRequestsMap.put(jobName, new ArrayList<>());
            jobTypeToContainerResources.put(jobName, getContainerResources(jobName));
        }
        jobTypeToContainerRequestsMap.get(jobName).add(containerAsk);
        int numInstances = request.getNumInstances();
        // The same ask object is submitted once per instance wanted.
        for (int i = 0; i < numInstances; i++) {
            amRMClient.addContainerRequest(containerAsk);
        }
        session.addNumExpectedTask(numInstances);
    }

    /** Builds the YARN container ask (priority, memory, vcores, GPU, node label) for {@code request}. */
    private AMRMClient.ContainerRequest setupContainerRequestForRM(JobContainerRequest request) {
        Priority priority = Priority.newInstance((int) request.getPriority());
        Resource capability = Resource.newInstance((int) request.getMemory(), (int) request.getVCores());
        Utils.setCapabilityGPU(capability, request.getGPU());
        // No node/rack constraints; relaxLocality = true.
        AMRMClient.ContainerRequest containerRequest = new AMRMClient.ContainerRequest(
            capability, null, null, priority, true, request.getNodeLabelsExpression());
        LOG.info("Requested container ask: " + containerRequest.toString());
        return containerRequest;
    }

    /**
     * Returns a copy of the shared {@link #localResources}. The copy is extended with the
     * resources configured for {@code jobName} and with the resources configured for all
     * containers.
     */
    private Map<String, LocalResource> getContainerResources(String jobName) {
        ConcurrentHashMap<String, LocalResource> containerResources = new ConcurrentHashMap<>(localResources);
        String[] resources = tonyConf.getStrings(TonyConfigurationKeys.getResourcesKey(jobName));
        Utils.addResources(resources, containerResources, resourceFs);
        resources = tonyConf.getStrings(TonyConfigurationKeys.getContainerResourcesKey());
        Utils.addResources(resources, containerResources, resourceFs);
        return containerResources;
    }

    /**
     * Handles one reported completion of {@code jobName}.
     *
     * <p>For every waiting request that depends on {@code jobName}, it decrements the number of
     * {@code jobName} completions still awaited. It then schedules, and stops tracking, every
     * waiting request whose dependencies are now all satisfied.
     */
    synchronized void registerDependencyCompleted(String jobName) {
        for (Map<String, Integer> outstanding : taskDependencyMap.values()) {
            // Returning null from the remapping function removes the entry once the count hits zero.
            outstanding.computeIfPresent(jobName, (type, left) -> left > 1 ? left - 1 : null);
        }
        Iterator<JobContainerRequest> waiting = taskDependencyMap.keySet().iterator();
        while (waiting.hasNext()) {
            JobContainerRequest request = waiting.next();
            if (checkDependencySatisfied(request)) {
                waiting.remove(); // safe: scheduleJob does not touch taskDependencyMap
                scheduleJob(request);
            }
        }
    }

    /**
     * Returns {@code true} when the {@code dependsOn} graph over {@code containersRequests} has no
     * cycle. Dependencies naming job types absent from the list are not followed.
     */
    static boolean isDAG(List<JobContainerRequest> containersRequests) {
        Set<JobContainerRequest> visited = new HashSet<>();
        for (JobContainerRequest containerRequest : containersRequests) {
            if (!visited.contains(containerRequest)
                && !isSubgraphDAG(containerRequest, new ArrayList<>(), containersRequests, visited)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Depth-first search from {@code node}. Returns {@code false} if a cycle is reachable.
     *
     * @param pathTrace nodes on the current DFS path (in progress)
     * @param visited every node the search has entered so far, across all roots
     */
    static boolean isSubgraphDAG(JobContainerRequest node, List<JobContainerRequest> pathTrace, List<JobContainerRequest> containerRequests, Set<JobContainerRequest> visited) {
        if (pathTrace.contains(node)) {
            return false; // back edge to a node on the current path: cycle
        }
        if (visited.contains(node)) {
            return true; // finished earlier without finding a cycle
        }
        pathTrace.add(node);
        visited.add(node);
        List<JobContainerRequest> dependencies = containerRequests.stream()
            .filter(candidate -> node.getDependsOn().contains(candidate.getJobName()))
            .collect(Collectors.toList());
        for (JobContainerRequest dependency : dependencies) {
            if (!isSubgraphDAG(dependency, pathTrace, containerRequests, visited)) {
                return false;
            }
        }
        pathTrace.remove(node);
        return true;
    }
}

