/*
 * Decompiled with CFR 0.152.
 */
package org.apache.iotdb.confignode.procedure.impl.pipe.util;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.iotdb.common.rpc.thrift.TConsensusGroupId;
import org.apache.iotdb.common.rpc.thrift.TConsensusGroupType;
import org.apache.iotdb.commons.cluster.NodeStatus;
import org.apache.iotdb.commons.pipe.agent.task.meta.PipeStaticMeta;
import org.apache.iotdb.confignode.manager.ConfigManager;
import org.apache.iotdb.pipe.api.exception.PipeException;

/**
 * Assigns the parallel tasks of a pipe that reads from an external source to DataNodes.
 *
 * <p>The assignment algorithm is selected by name at construction time; currently only
 * {@code "proportion"} ({@link ProportionalBalanceStrategy}) is recognized.
 */
public class PipeExternalSourceLoadBalancer {
    private final BalanceStrategy strategy;

    /**
     * Creates a load balancer backed by the named strategy.
     *
     * @param balanceStrategy strategy name; only {@code "proportion"} is recognized
     * @throws IllegalArgumentException if {@code balanceStrategy} is not a known strategy name
     * @throws NullPointerException if {@code balanceStrategy} is {@code null} (String switch semantics)
     */
    public PipeExternalSourceLoadBalancer(String balanceStrategy) {
        switch (balanceStrategy) {
            case "proportion":
                this.strategy = new ProportionalBalanceStrategy();
                break;
            default:
                throw new IllegalArgumentException("Unknown load balance strategy: " + balanceStrategy);
        }
    }

    /**
     * Computes the task-to-DataNode assignment for a pipe using the configured strategy.
     *
     * @param parallelCount number of parallel tasks requested for the pipe
     * @param pipeStaticMeta static metadata of the pipe; its extractor parameters are consulted
     * @param configManager source of cluster state (region leaders and running DataNodes)
     * @return map from task id (-1, -2, ...) to the id of the DataNode that should run the task
     * @throws PipeException if the running DataNodes are needed but none is running
     */
    public Map<Integer, Integer> balance(int parallelCount, PipeStaticMeta pipeStaticMeta, ConfigManager configManager) {
        return this.strategy.balance(parallelCount, pipeStaticMeta, configManager);
    }

    /**
     * Distributes tasks over DataNodes using one of three rules:
     *
     * <ul>
     *   <li>Single mode ({@code extractor.single-mode} / {@code source.single-mode}, default
     *       {@code true}): at most one task per running DataNode, placed on the lowest-id nodes.
     *   <li>Otherwise, proportionally to the number of DataRegion leaders each DataNode holds,
     *       using largest-remainder rounding (ties broken by ascending DataNode id).
     *   <li>If no DataRegion has a usable leader, tasks are split evenly over the running
     *       DataNodes; the lowest-id nodes take one extra task when the split is uneven.
     * </ul>
     *
     * <p>Task ids are consecutive negative integers starting at -1, handed out in ascending
     * DataNode id order. NOTE(review): presumably negative so they cannot collide with real
     * region ids — confirm with the caller.
     */
    public static class ProportionalBalanceStrategy implements BalanceStrategy {
        /**
         * Leader id that is ignored when counting region leaders.
         * NOTE(review): presumably marks a region whose leader is unknown — confirm.
         */
        private static final int NO_LEADER = -1;

        @Override
        public Map<Integer, Integer> balance(int parallelCount, PipeStaticMeta pipeStaticMeta, ConfigManager configManager) {
            if (pipeStaticMeta.getExtractorParameters().getBooleanOrDefault(Arrays.asList("extractor.single-mode", "source.single-mode"), true)) {
                return assignOnePerNode(parallelCount, getSortedRunningDataNodes(configManager));
            }
            Map<Integer, Integer> dataNodeId2LeaderRegionCount =
                    countDataRegionLeadersPerDataNode(configManager.getLoadManager().getRegionLeaderMap());
            if (dataNodeId2LeaderRegionCount.isEmpty()) {
                return assignEvenly(parallelCount, getSortedRunningDataNodes(configManager));
            }
            return assignProportionally(parallelCount, dataNodeId2LeaderRegionCount);
        }

        /** Returns the ids of all running DataNodes in ascending order, throwing if there are none. */
        private static List<Integer> getSortedRunningDataNodes(ConfigManager configManager) {
            List<Integer> runningDataNodes = configManager.getLoadManager().filterDataNodeThroughStatus(NodeStatus.Running).stream().sorted().collect(Collectors.toList());
            if (runningDataNodes.isEmpty()) {
                throw new PipeException("No available datanode to assign tasks");
            }
            return runningDataNodes;
        }

        /** Gives task -i to the i-th lowest-id node, for i = 1..min(node count, parallelCount). */
        private static Map<Integer, Integer> assignOnePerNode(int parallelCount, List<Integer> sortedDataNodes) {
            Map<Integer, Integer> taskId2DataNodeId = new HashMap<>();
            int taskCount = Math.min(sortedDataNodes.size(), parallelCount);
            for (int i = 1; i <= taskCount; ++i) {
                taskId2DataNodeId.put(-i, sortedDataNodes.get(i - 1));
            }
            return taskId2DataNodeId;
        }

        /** Counts, per DataNode id, how many DataRegions it leads (skipping {@link #NO_LEADER}). */
        private static Map<Integer, Integer> countDataRegionLeadersPerDataNode(Map<TConsensusGroupId, Integer> regionLeaderMap) {
            Map<Integer, Integer> dataNodeId2RegionCount = new HashMap<>();
            for (Map.Entry<TConsensusGroupId, Integer> entry : regionLeaderMap.entrySet()) {
                if (entry.getKey().getType() == TConsensusGroupType.DataRegion && entry.getValue() != NO_LEADER) {
                    dataNodeId2RegionCount.merge(entry.getValue(), 1, Integer::sum);
                }
            }
            return dataNodeId2RegionCount;
        }

        /** Splits tasks evenly over the nodes; the first {@code parallelCount % n} nodes get one extra. */
        private static Map<Integer, Integer> assignEvenly(int parallelCount, List<Integer> sortedDataNodes) {
            int nodeCount = sortedDataNodes.size();
            int quotient = parallelCount / nodeCount;
            int remainder = parallelCount % nodeCount;
            Map<Integer, Integer> taskId2DataNodeId = new HashMap<>();
            int taskIndex = 1;
            for (int i = 0; i < nodeCount; ++i) {
                int tasksForNode = quotient + (i < remainder ? 1 : 0);
                int dataNodeId = sortedDataNodes.get(i);
                for (int j = 0; j < tasksForNode; ++j) {
                    taskId2DataNodeId.put(-taskIndex, dataNodeId);
                    ++taskIndex;
                }
            }
            return taskId2DataNodeId;
        }

        /** Splits tasks in proportion to leader counts using largest-remainder rounding. */
        private static Map<Integer, Integer> assignProportionally(int parallelCount, Map<Integer, Integer> dataNodeId2RegionCount) {
            int totalRegions = dataNodeId2RegionCount.values().stream().mapToInt(Integer::intValue).sum();
            Map<Integer, Double> dataNodeId2ExactShare = new HashMap<>();
            Map<Integer, Integer> dataNodeId2AssignedCount = new HashMap<>();
            for (Map.Entry<Integer, Integer> entry : dataNodeId2RegionCount.entrySet()) {
                // Multiply in double: the int product parallelCount * regionCount could overflow.
                double share = (double) parallelCount * entry.getValue() / totalRegions;
                dataNodeId2ExactShare.put(entry.getKey(), share);
                dataNodeId2AssignedCount.put(entry.getKey(), (int) Math.floor(share));
            }

            // Tasks lost to flooring go to the nodes with the largest fractional share first;
            // equal fractions fall back to ascending DataNode id so the result is deterministic.
            int remainder = parallelCount - dataNodeId2AssignedCount.values().stream().mapToInt(Integer::intValue).sum();
            List<Integer> byLargestFraction = dataNodeId2ExactShare.keySet().stream().sorted((l1, l2) -> {
                int byFraction = Double.compare(
                        fractionalPart(dataNodeId2ExactShare.get(l2)), fractionalPart(dataNodeId2ExactShare.get(l1)));
                return byFraction != 0 ? byFraction : Integer.compare(l1, l2);
            }).collect(Collectors.toList());
            for (int i = 0; i < remainder; ++i) {
                dataNodeId2AssignedCount.merge(byLargestFraction.get(i % byLargestFraction.size()), 1, Integer::sum);
            }

            // Hand out task ids in ascending DataNode id order so numbering is stable.
            List<Integer> stableLeaders = new ArrayList<>(dataNodeId2AssignedCount.keySet());
            Collections.sort(stableLeaders);
            Map<Integer, Integer> taskId2DataNodeId = new HashMap<>();
            int taskIndex = 1;
            for (Integer leader : stableLeaders) {
                int count = dataNodeId2AssignedCount.get(leader);
                for (int i = 0; i < count; ++i) {
                    taskId2DataNodeId.put(-taskIndex, leader);
                    ++taskIndex;
                }
            }
            return taskId2DataNodeId;
        }

        /** Returns {@code x - floor(x)}. */
        private static double fractionalPart(double x) {
            return x - Math.floor(x);
        }
    }

    /** An algorithm that maps the task ids of an external-source pipe to DataNode ids. */
    private interface BalanceStrategy {
        Map<Integer, Integer> balance(int parallelCount, PipeStaticMeta pipeStaticMeta, ConfigManager configManager);
    }
}

