/*
 * Decompiled with CFR 0.152.
 */
package com.linkedin.tony;

import com.google.common.annotations.VisibleForTesting;
import com.linkedin.tony.TaskMonitor;
import com.linkedin.tony.TonyConfigurationKeys;
import com.linkedin.tony.rpc.MetricsRpc;
import com.linkedin.tony.rpc.impl.ApplicationRpcClient;
import com.linkedin.tony.util.Utils;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.ipc.RPC;
import org.apache.hadoop.yarn.api.ApplicationConstants;
import org.apache.hadoop.yarn.api.records.ContainerId;

public class TaskExecutor {
    private static final Log LOG = LogFactory.getLog(TaskExecutor.class);
    private static final int MAX_NUM_FAILED_HB_ATTEMPTS = 5;
    @VisibleForTesting
    protected Configuration tonyConf = new Configuration(false);
    private ServerSocket rpcSocket;
    private int rpcPort;
    private ServerSocket tbSocket;
    private int tbPort;
    private int timeOut;
    private String amHost;
    private int amPort;
    private MetricsRpc metricsProxy;
    private int metricsRPCPort;
    private int metricsIntervalMs;
    private String taskCommand;
    private String clusterSpec;
    private String jobName;
    private int taskIndex;
    private String taskId;
    private int numTasks;
    private boolean isChief;
    private Configuration yarnConf = new Configuration(false);
    private Configuration hdfsConf = new Configuration(false);
    private ApplicationRpcClient proxy;
    private Map<String, String> shellEnv = new HashMap<String, String>();
    private int hbInterval;
    private final ScheduledExecutorService scheduledThreadPool = Executors.newScheduledThreadPool(2);
    private int numFailedHBAttempts = 0;
    private TonyConfigurationKeys.MLFramework framework;

    protected TaskExecutor() {
    }

    private void setupPorts() throws IOException {
        this.rpcSocket = new ServerSocket(0);
        this.rpcPort = this.rpcSocket.getLocalPort();
        LOG.info((Object)("Reserved rpcPort: " + this.rpcPort));
        if (this.isChief) {
            this.tbSocket = new ServerSocket(0);
            this.tbPort = this.tbSocket.getLocalPort();
            this.registerTensorBoardUrl();
            this.shellEnv.put("TB_PORT", String.valueOf(this.tbPort));
            LOG.info((Object)("Reserved tbPort: " + this.tbPort));
        }
    }

    private void releasePorts() throws IOException {
        try {
            if (this.rpcSocket != null) {
                this.rpcSocket.close();
            }
        }
        finally {
            if (this.tbSocket != null) {
                this.tbSocket.close();
            }
        }
    }

    private static TaskExecutor createExecutor() throws Exception {
        TaskExecutor executor = new TaskExecutor();
        executor.initConfigs();
        Utils.extractResources();
        LOG.info((Object)("Setting up application RPC client, connecting to: " + executor.amHost + ":" + executor.amPort));
        executor.proxy = ApplicationRpcClient.getInstance(executor.amHost, executor.amPort, executor.yarnConf);
        LOG.info((Object)("Setting up metrics RPC client, connecting to: " + executor.amHost + ":" + executor.metricsRPCPort));
        executor.metricsProxy = (MetricsRpc)RPC.getProxy(MetricsRpc.class, (long)RPC.getProtocolVersion(MetricsRpc.class), (InetSocketAddress)new InetSocketAddress(executor.amHost, executor.metricsRPCPort), (Configuration)executor.yarnConf);
        executor.scheduledThreadPool.scheduleAtFixedRate(new TaskMonitor(executor.jobName, executor.taskIndex, executor.yarnConf, executor.tonyConf, executor.metricsProxy), 0L, executor.metricsIntervalMs, TimeUnit.MILLISECONDS);
        executor.setupPorts();
        executor.clusterSpec = executor.registerAndGetClusterSpec();
        if (executor.clusterSpec == null) {
            LOG.error((Object)"Failed to register worker with AM.");
            throw new Exception("Failed to register worker with AM.");
        }
        LOG.info((Object)("Successfully registered and got cluster spec: " + executor.clusterSpec));
        switch (executor.framework) {
            case TENSORFLOW: {
                LOG.info((Object)"Setting up TensorFlow job...");
                executor.shellEnv.put("JOB_NAME", String.valueOf(executor.jobName));
                executor.shellEnv.put("TASK_INDEX", String.valueOf(executor.taskIndex));
                executor.shellEnv.put("CLUSTER_SPEC", String.valueOf(executor.clusterSpec));
                executor.shellEnv.put("TF_CONFIG", Utils.constructTFConfig(executor.clusterSpec, executor.jobName, executor.taskIndex));
                break;
            }
            case PYTORCH: {
                LOG.info((Object)"Setting up PyTorch job...");
                String initMethod = Utils.parseClusterSpecForPytorch(executor.clusterSpec);
                if (initMethod == null) {
                    System.exit(-1);
                }
                LOG.info((Object)("Init method is: " + initMethod));
                executor.shellEnv.put("INIT_METHOD", initMethod);
                executor.shellEnv.put("RANK", String.valueOf(executor.taskIndex));
                executor.shellEnv.put("WORLD", String.valueOf(executor.numTasks));
                break;
            }
            case MXNET: {
                LOG.info((Object)"Setting up MXNet job...");
                String[] dmlcServer = Utils.parseClusterSpecForMXNet(executor.clusterSpec);
                if (dmlcServer == null) {
                    System.exit(-1);
                }
                int numServer = executor.tonyConf.getInt(TonyConfigurationKeys.getInstancesKey("server"), 0);
                int numWorker = executor.tonyConf.getInt(TonyConfigurationKeys.getInstancesKey("worker"), 0);
                LOG.info((Object)("init DMLC is: " + dmlcServer[0] + " port: " + dmlcServer[1]));
                LOG.info((Object)("init DMLC ROLE: " + executor.jobName));
                LOG.info((Object)("init DMLC NUM_PS: " + numServer));
                LOG.info((Object)("init DMLC NUM_WORKER: " + numWorker));
                executor.shellEnv.put("DMLC_ROLE", executor.jobName);
                executor.shellEnv.put("DMLC_PS_ROOT_URI", dmlcServer[0]);
                executor.shellEnv.put("DMLC_PS_ROOT_PORT", dmlcServer[1]);
                executor.shellEnv.put("DMLC_LOCAL", "0");
                executor.shellEnv.put("DMLC_NUM_SERVER", String.valueOf(numServer));
                executor.shellEnv.put("DMLC_NUM_WORKER", String.valueOf(numWorker));
                break;
            }
            case HOROVOD: {
                break;
            }
            default: {
                throw new RuntimeException("Unsupported executor framework: " + (Object)((Object)executor.framework));
            }
        }
        return executor;
    }

    public static void main(String[] unused) throws Exception {
        LOG.info((Object)"TaskExecutor is running..");
        TaskExecutor executor = null;
        try {
            executor = Objects.requireNonNull(TaskExecutor.createExecutor());
        }
        finally {
            if (executor != null) {
                executor.releasePorts();
            }
        }
        int exitCode = Utils.executeShell(executor.taskCommand, executor.timeOut, executor.shellEnv);
        executor.skewAndHangIfTesting();
        executor.registerExecutionResult(exitCode, executor.jobName, String.valueOf(executor.taskIndex));
        LOG.info((Object)("Child process exited with exit code " + exitCode));
        System.exit(exitCode);
    }

    protected void initConfigs() {
        this.jobName = System.getenv("JOB_NAME");
        this.taskIndex = Integer.parseInt(System.getenv("TASK_INDEX"));
        this.numTasks = Integer.parseInt(System.getenv("TASK_NUM"));
        this.taskId = this.jobName + ":" + this.taskIndex;
        LOG.info((Object)("Executor is running task " + this.taskId));
        String isChiefEnvValue = System.getenv("IS_CHIEF");
        this.isChief = Boolean.parseBoolean(isChiefEnvValue);
        this.amHost = System.getenv("AM_HOST");
        this.amPort = Integer.parseInt(System.getenv("AM_PORT"));
        this.tonyConf.addResource(new Path("tony-final.xml"));
        this.timeOut = this.tonyConf.getInt("tony.worker.timeout", 0);
        this.hbInterval = this.tonyConf.getInt("tony.task.heartbeat-interval-ms", 1000);
        String[] shellEnvs = this.tonyConf.getStrings("tony.execution.envs");
        this.shellEnv = Utils.parseKeyValue(shellEnvs);
        this.taskCommand = this.tonyConf.get(TonyConfigurationKeys.getExecuteCommandKey(this.jobName), this.tonyConf.get(TonyConfigurationKeys.getContainerExecuteCommandKey()));
        if (this.taskCommand == null) {
            LOG.fatal((Object)"Task command is empty. Please set tony.[jobtype].command or pass --executes in command line");
            throw new IllegalArgumentException("Task command is empty.");
        }
        LOG.info((Object)("Task command: " + this.taskCommand));
        this.framework = TonyConfigurationKeys.MLFramework.valueOf(this.tonyConf.get("tony.application.framework", "tensorflow").toUpperCase());
        this.metricsRPCPort = Integer.parseInt(System.getenv("METRICS_RPC_PORT"));
        this.metricsIntervalMs = this.tonyConf.getInt("tony.task.metrics-interval-ms", 5000);
        Utils.initYarnConf(this.yarnConf);
        Utils.initHdfsConf(this.hdfsConf);
    }

    private String registerAndGetClusterSpec() {
        ContainerId containerId = ContainerId.fromString((String)System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name()));
        String hostName = Utils.getCurrentHostName();
        LOG.info((Object)("ContainerId is: " + containerId + " HostName is: " + hostName));
        this.scheduledThreadPool.scheduleAtFixedRate(new Heartbeater(), 0L, this.hbInterval, TimeUnit.MILLISECONDS);
        LOG.info((Object)("Connecting to " + this.amHost + ":" + this.amPort + " to register worker spec: " + this.jobName + " " + this.taskIndex + " " + hostName + ":" + this.rpcPort));
        return Utils.pollTillNonNull(() -> this.proxy.registerWorkerSpec(this.jobName + ":" + this.taskIndex, hostName + ":" + this.rpcPort), 3, 0);
    }

    private void registerTensorBoardUrl() {
        String hostName = Utils.getCurrentHostName();
        String tbUrl = hostName + ":" + this.tbPort;
        LOG.info((Object)("TensorBoard address : " + tbUrl));
        String response = Utils.pollTillNonNull(() -> this.proxy.registerTensorBoardUrl(tbUrl), 1, 60);
        if (response != null) {
            LOG.info((Object)("Register TensorBoard response: " + response));
        }
    }

    private void registerExecutionResult(int exitCode, String jobName, String jobIndex) {
        String sessionId = System.getenv("SESSION_ID");
        String response = Utils.pollTillNonNull(() -> this.proxy.registerExecutionResult(exitCode, jobName, jobIndex, sessionId), 1, 60);
        if (response != null) {
            LOG.info((Object)("AM response for result execution run: " + response));
        }
    }

    private void skewAndHangIfTesting() {
        String skewInstr = System.getenv("TEST_TASK_EXECUTOR_SKEW");
        if (skewInstr != null) {
            String[] instr = skewInstr.split("#");
            try {
                if (instr.length == 3 && instr[0].equals(this.jobName) && Integer.parseInt(instr[1]) == this.taskIndex) {
                    int waitTime = Integer.parseInt(instr[2]);
                    LOG.info((Object)("Will sleep for [" + waitTime + "] as instructed to simulate skew"));
                    try {
                        Thread.sleep(waitTime);
                    }
                    catch (InterruptedException e) {
                        LOG.error((Object)"Thread interrupted while hanging..", (Throwable)e);
                    }
                }
            }
            catch (Exception e) {
                LOG.error((Object)"Got Exception while parsing skew instruction", (Throwable)e);
            }
        }
    }

    private class Heartbeater
    implements Runnable {
        int hbMissCounter = 0;
        int numHbToMiss;

        private Heartbeater() {
            String hbMissStr = System.getenv("TEST_TASK_EXECUTOR_NUM_HB_MISS");
            try {
                int numMisses = Integer.parseInt(hbMissStr);
                if (numMisses > 0) {
                    this.numHbToMiss = numMisses;
                }
            }
            catch (Exception e) {
                this.numHbToMiss = 0;
            }
        }

        @Override
        public void run() {
            try {
                if (this.hbMissCounter == 0) {
                    LOG.debug((Object)("[" + TaskExecutor.this.taskId + "] Sending Ping !!"));
                    TaskExecutor.this.proxy.taskExecutorHeartbeat(TaskExecutor.this.taskId);
                    TaskExecutor.this.numFailedHBAttempts = 0;
                    this.hbMissCounter = this.numHbToMiss;
                } else {
                    LOG.debug((Object)("[" + TaskExecutor.this.taskId + "] Skipping heartbeat for Testing !!"));
                    --this.hbMissCounter;
                }
            }
            catch (Exception e) {
                LOG.error((Object)("[" + TaskExecutor.this.taskId + "] Failed to send Heart Beat."), (Throwable)e);
                if (++TaskExecutor.this.numFailedHBAttempts > 5) {
                    LOG.error((Object)("[" + TaskExecutor.this.taskId + "] Exceeded max number of allowed failed heart beat send attempts. Going to stop heartbeating!"));
                    e.printStackTrace();
                    throw new RuntimeException(e);
                }
                LOG.warn((Object)"Will retry heartbeat..");
            }
        }
    }
}

