package org.apache.druid.server.coordinator.loading;

import com.google.common.collect.Sets;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collectors;
import javax.annotation.concurrent.NotThreadSafe;
import org.apache.druid.client.DruidServer;
import org.apache.druid.server.coordinator.DruidCluster;
import org.apache.druid.server.coordinator.ServerHolder;
import org.apache.druid.server.coordinator.balancer.BalancerStrategy;
import org.apache.druid.server.coordinator.rules.SegmentActionHandler;
import org.apache.druid.server.coordinator.stats.CoordinatorRunStats;
import org.apache.druid.server.coordinator.stats.CoordinatorStat;
import org.apache.druid.server.coordinator.stats.Dimension;
import org.apache.druid.server.coordinator.stats.RowKey;
import org.apache.druid.server.coordinator.stats.Stats;
import org.apache.druid.timeline.DataSegment;
import org.apache.druid.timeline.SegmentId;

/**
 * Carries out the segment actions of {@link SegmentActionHandler} (replicate, broadcast, delete),
 * as well as segment moves, on the servers of a {@link DruidCluster}.
 *
 * <p>Loads, drops and moves are queued through the {@link SegmentLoadQueueManager}. Target servers
 * are chosen by a {@link RoundRobinServerSelector} when round-robin assignment is enabled in the
 * {@link SegmentLoadingConfig}, and by the {@link BalancerStrategy} otherwise. Outcomes and skip
 * reasons are recorded in the supplied {@link CoordinatorRunStats}, and required/current replica
 * counts per segment and tier are tracked in a {@link SegmentReplicaCountMap}.
 *
 * <p>All state is held in plain, unsynchronized collections, so an instance must be confined to a
 * single thread.
 */
@NotThreadSafe
public class StrategicSegmentAssigner implements SegmentActionHandler {
    private final SegmentLoadQueueManager loadQueueManager;
    private final DruidCluster cluster;
    private final CoordinatorRunStats stats;
    private final SegmentReplicaCountMap replicaCountMap;
    private final ReplicationThrottler replicationThrottler;
    /** Non-null only when round-robin segment assignment is enabled. */
    private final RoundRobinServerSelector serverSelector;
    private final BalancerStrategy strategy;
    private final boolean useRoundRobinAssignment;

    /** Datasource -> tiers requested in {@link #replicateSegment} that are not tiers of the cluster. */
    private final Map<String, Set<String>> datasourceToInvalidLoadTiers = new HashMap<>();
    /** Tier -> number of historicals in that tier when this assigner was created. */
    private final Map<String, Integer> tierToHistoricalCount = new HashMap<>();
    /** Datasource -> ids of segments passed to {@link #deleteSegment}. */
    private final Map<String, Set<SegmentId>> segmentsToDelete = new HashMap<>();
    /** Datasource -> segments whose total required replica count across all tiers is zero or less. */
    private final Map<String, Set<DataSegment>> segmentsWithZeroRequiredReplicas = new HashMap<>();
    /** Segments passed to {@link #broadcastSegment}. */
    private final Set<DataSegment> broadcastSegments = new HashSet<>();

    public StrategicSegmentAssigner(
            SegmentLoadQueueManager segmentLoadQueueManager,
            DruidCluster druidCluster,
            BalancerStrategy balancerStrategy,
            SegmentLoadingConfig segmentLoadingConfig,
            CoordinatorRunStats coordinatorRunStats) {
        this.stats = coordinatorRunStats;
        this.cluster = druidCluster;
        this.strategy = balancerStrategy;
        this.loadQueueManager = segmentLoadQueueManager;
        this.replicaCountMap = SegmentReplicaCountMap.create(druidCluster);
        this.replicationThrottler = createReplicationThrottler(druidCluster, segmentLoadingConfig);
        this.useRoundRobinAssignment = segmentLoadingConfig.isUseRoundRobinSegmentAssignment();
        this.serverSelector = useRoundRobinAssignment ? new RoundRobinServerSelector(druidCluster) : null;
        druidCluster.getHistoricals().forEach(
                (tier, historicals) -> tierToHistoricalCount.put(tier, historicals.size()));
    }

    /** Builds a replication status snapshot from the replica counts tracked so far. */
    public SegmentReplicationStatus getReplicationStatus() {
        return replicaCountMap.toReplicationStatus();
    }

    /** Returns the live, mutable map of datasource to segment ids queued for deletion. */
    public Map<String, Set<SegmentId>> getSegmentsToDelete() {
        return segmentsToDelete;
    }

    /** Returns the live, mutable map of datasource to segments that require no replicas in total. */
    public Map<String, Set<DataSegment>> getSegmentsWithZeroRequiredReplicas() {
        return segmentsWithZeroRequiredReplicas;
    }

    /** Returns the live, mutable map of datasource to requested tiers that are absent from the cluster. */
    public Map<String, Set<String>> getDatasourceToInvalidLoadTiers() {
        return datasourceToInvalidLoadTiers;
    }

    /**
     * Tries to move {@code segment} off {@code sourceServer} onto one of {@code destinationServers}.
     *
     * <p>Only candidates in the source server's tier that can load the segment are considered. A
     * non-decommissioning source is itself offered to the strategy, so the move is skipped when the
     * segment is already optimally placed.
     *
     * @return true if a move was queued
     */
    public boolean moveSegment(DataSegment segment, ServerHolder sourceServer, List<ServerHolder> destinationServers) {
        final String tier = sourceServer.getServer().getTier();
        // Collect into an explicit ArrayList: we add to it below, and Collectors.toList()
        // makes no guarantee that the returned list is mutable.
        final List<ServerHolder> eligibleServers = destinationServers.stream()
                .filter(server -> server.getServer().getTier().equals(tier))
                .filter(server -> server.canLoadSegment(segment))
                .collect(Collectors.toCollection(ArrayList::new));
        if (eligibleServers.isEmpty()) {
            incrementSkipStat(Stats.Segments.MOVE_SKIPPED, "No eligible server", segment, tier);
            return false;
        }

        // Offer the source as a candidate only when it is not decommissioning, so that a
        // decommissioning source can never be chosen as the destination.
        if (!sourceServer.isDecommissioning()) {
            eligibleServers.add(sourceServer);
        }

        final ServerHolder destination =
                strategy.findDestinationServerToMoveSegment(segment, sourceServer, eligibleServers);
        if (destination == null || destination.getServer().equals(sourceServer.getServer())) {
            incrementSkipStat(Stats.Segments.MOVE_SKIPPED, "Optimally placed", segment, tier);
            return false;
        }
        if (moveSegment(segment, sourceServer, destination)) {
            incrementStat(Stats.Segments.MOVED, segment, tier, 1L);
            return true;
        }
        incrementSkipStat(Stats.Segments.MOVE_SKIPPED, "Encountered error", segment, tier);
        return false;
    }

    /**
     * Queues the actual move from {@code sourceServer} to {@code destinationServer}.
     *
     * <p>If the segment is still loading on the source, that load is cancelled and re-issued on the
     * destination instead. If the source serves the segment, a move is queued via the load queue
     * manager. Otherwise nothing happens.
     */
    private boolean moveSegment(DataSegment segment, ServerHolder sourceServer, ServerHolder destinationServer) {
        final String tier = sourceServer.getServer().getTier();
        if (sourceServer.isLoadingSegment(segment)) {
            if (!sourceServer.cancelOperation(SegmentAction.LOAD, segment)) {
                // Load could not be cancelled; let it finish on the source and count as not moved.
                return false;
            }
            final int loadedOnTier = replicaCountMap.get(segment.getId(), tier).loadedNotDropping();
            // With a copy already loaded on the tier, this counts as a (throttled) replica.
            return loadedOnTier >= 1
                    ? replicateSegment(segment, destinationServer)
                    : loadSegment(segment, destinationServer);
        }
        if (sourceServer.isServingSegment(segment)) {
            return loadQueueManager.moveSegment(segment, sourceServer, destinationServer);
        }
        return false;
    }

    /**
     * Brings the replicas of {@code segment} on every tier of the cluster in line with
     * {@code tierToReplicaCount}.
     *
     * <ul>
     *   <li>Records the required counts and flags requested tiers absent from the cluster.</li>
     *   <li>Records the segment if it requires no replicas in total.</li>
     *   <li>Loads or drops replicas tier by tier. Tiers missing from the map require 0 replicas.</li>
     * </ul>
     *
     * <p>The total number of drops across all tiers is capped at the cluster-wide surplus of
     * loaded (not dropping) replicas over required-and-loadable replicas.
     */
    @Override
    public void replicateSegment(DataSegment segment, Map<String, Integer> tierToReplicaCount) {
        final Set<String> allTiersInCluster = Sets.newHashSet(cluster.getTierNames());

        if (tierToReplicaCount.isEmpty()) {
            // Still track the segment even though no tier requires it.
            replicaCountMap.computeIfAbsent(segment.getId(), DruidServer.DEFAULT_TIER);
        } else {
            tierToReplicaCount.forEach((tier, requiredReplicas) -> {
                reportTierCapacityStats(segment, requiredReplicas, tier);
                replicaCountMap.computeIfAbsent(segment.getId(), tier)
                        .setRequired(requiredReplicas, tierToHistoricalCount.getOrDefault(tier, 0));
                if (!allTiersInCluster.contains(tier)) {
                    datasourceToInvalidLoadTiers
                            .computeIfAbsent(segment.getDataSource(), ds -> new HashSet<>())
                            .add(tier);
                }
            });
        }

        final SegmentReplicaCount replicaCountInCluster = replicaCountMap.getTotal(segment.getId());
        if (replicaCountInCluster.required() <= 0) {
            segmentsWithZeroRequiredReplicas
                    .computeIfAbsent(segment.getDataSource(), ds -> new HashSet<>())
                    .add(segment);
        }

        final int replicaSurplus =
                replicaCountInCluster.loadedNotDropping() - replicaCountInCluster.requiredAndLoadable();
        int dropsQueued = 0;
        for (String tier : allTiersInCluster) {
            dropsQueued += updateReplicasInTier(
                    segment,
                    tier,
                    tierToReplicaCount.getOrDefault(tier, 0),
                    replicaSurplus - dropsQueued);
        }
    }

    /**
     * Reconciles the replicas of {@code segment} on one tier with {@code requiredReplicas}.
     *
     * <ul>
     *   <li>If the tier requires no replicas, pending moves to and from it are cancelled.</li>
     *   <li>On a deficit, pending drops are cancelled first and any remaining shortfall is loaded.</li>
     *   <li>On a surplus, pending loads are cancelled first and at most {@code maxReplicasToDrop}
     *       of the remaining surplus is dropped.</li>
     * </ul>
     *
     * @return number of drops queued on this tier
     */
    private int updateReplicasInTier(DataSegment segment, String tier, int requiredReplicas, int maxReplicasToDrop) {
        final SegmentReplicaCount replicaCountInTier = replicaCountMap.get(segment.getId(), tier);
        final int projectedReplicas = replicaCountInTier.loadedNotDropping() + replicaCountInTier.loading();
        final int movingReplicas = replicaCountInTier.moving();
        final boolean shouldCancelMoves = requiredReplicas == 0 && movingReplicas > 0;

        if (projectedReplicas == requiredReplicas && !shouldCancelMoves) {
            return 0;
        }

        final SegmentStatusInTier segmentStatus =
                new SegmentStatusInTier(segment, cluster.getHistoricalsByTier(tier));

        if (shouldCancelMoves) {
            cancelOperations(SegmentAction.MOVE_TO, movingReplicas, segment, segmentStatus);
            cancelOperations(SegmentAction.MOVE_FROM, movingReplicas, segment, segmentStatus);
        }

        if (projectedReplicas < requiredReplicas) {
            final int replicaDeficit = requiredReplicas - projectedReplicas;
            final int cancelledDrops =
                    cancelOperations(SegmentAction.DROP, replicaDeficit, segment, segmentStatus);
            final int numReplicasToLoad = replicaDeficit - cancelledDrops;
            if (numReplicasToLoad > 0) {
                // Replicas whose drop was just cancelled remain loaded on the tier.
                final int numLoadedReplicas = replicaCountInTier.loadedNotDropping() + cancelledDrops;
                final int numLoadsQueued =
                        loadReplicas(numReplicasToLoad, numLoadedReplicas, segment, tier, segmentStatus);
                incrementStat(Stats.Segments.ASSIGNED, segment, tier, numLoadsQueued);
            }
            return 0;
        }

        if (projectedReplicas == requiredReplicas) {
            return 0;
        }

        final int replicaSurplus = projectedReplicas - requiredReplicas;
        final int cancelledLoads =
                cancelOperations(SegmentAction.LOAD, replicaSurplus, segment, segmentStatus);
        final int numReplicasToDrop = Math.min(replicaSurplus - cancelledLoads, maxReplicasToDrop);
        if (numReplicasToDrop <= 0) {
            return 0;
        }
        final int dropsQueued = dropReplicas(numReplicasToDrop, segment, tier, segmentStatus);
        incrementStat(Stats.Segments.DROPPED, segment, tier, dropsQueued);
        return dropsQueued;
    }

    /** Records the tier's max replication factor and adds the segment's size times replicas to its required capacity. */
    private void reportTierCapacityStats(DataSegment segment, int requiredReplicas, String tier) {
        final RowKey tierKey = RowKey.of(Dimension.TIER, tier);
        stats.updateMax(Stats.Tier.REPLICATION_FACTOR, tierKey, requiredReplicas);
        stats.add(Stats.Tier.REQUIRED_CAPACITY, tierKey, segment.getSize() * requiredReplicas);
    }

    /**
     * Loads {@code segment} on every active broadcast-target server and drops it from every
     * decommissioning one. Each tier's required count is set to the number of active
     * broadcast-target servers in that tier.
     */
    @Override
    public void broadcastSegment(DataSegment segment) {
        final Object2IntOpenHashMap<String> tierToRequiredReplicas = new Object2IntOpenHashMap<>();
        for (ServerHolder server : cluster.getAllServers()) {
            if (!server.getServer().getType().isSegmentBroadcastTarget()) {
                continue;
            }
            final String tier = server.getServer().getTier();
            if (server.isDecommissioning()) {
                if (dropBroadcastSegment(segment, server)) {
                    incrementStat(Stats.Segments.DROPPED, segment, tier, 1L);
                }
            } else {
                tierToRequiredReplicas.addTo(tier, 1);
                if (loadBroadcastSegment(segment, server)) {
                    incrementStat(Stats.Segments.ASSIGNED, segment, tier, 1L);
                }
            }
        }
        tierToRequiredReplicas.object2IntEntrySet().fastForEach(
                entry -> replicaCountMap.computeIfAbsent(segment.getId(), entry.getKey())
                        .setRequired(entry.getIntValue(), entry.getIntValue()));
        broadcastSegments.add(segment);
    }

    /** Records {@code segment} as to be deleted, keyed by its datasource. */
    @Override
    public void deleteSegment(DataSegment segment) {
        segmentsToDelete
                .computeIfAbsent(segment.getDataSource(), ds -> new HashSet<>())
                .add(segment.getId());
    }

    /**
     * Ensures a broadcast segment ends up on {@code server}.
     *
     * <ul>
     *   <li>Does nothing if the server is already serving or loading the segment.</li>
     *   <li>Cancels a pending drop if one exists.</li>
     *   <li>Otherwise queues a load when the server can take it, or records a skip reason.</li>
     * </ul>
     *
     * @return true if a drop was cancelled or a load was queued
     */
    private boolean loadBroadcastSegment(DataSegment segment, ServerHolder server) {
        if (server.isServingSegment(segment) || server.isLoadingSegment(segment)) {
            return false;
        }
        if (server.isDroppingSegment(segment)) {
            return server.cancelOperation(SegmentAction.DROP, segment);
        }
        if (server.canLoadSegment(segment)) {
            return loadSegment(segment, server);
        }

        final String skipReason;
        if (server.getAvailableSize() < segment.getSize()) {
            skipReason = "Not enough disk space";
        } else if (server.isLoadQueueFull()) {
            skipReason = "Load queue is full";
        } else {
            skipReason = "Unknown error";
        }
        incrementSkipStat(Stats.Segments.ASSIGN_SKIPPED, skipReason, segment, server.getServer().getTier());
        return false;
    }

    /** Returns the live, mutable set of segments passed to {@link #broadcastSegment}. */
    public Set<DataSegment> getBroadcastSegments() {
        return broadcastSegments;
    }

    /**
     * Removes a broadcast segment from {@code server}: cancels a pending load, or queues a drop
     * if the segment is being served.
     *
     * @return true if a load was cancelled or a drop was queued
     */
    private boolean dropBroadcastSegment(DataSegment segment, ServerHolder server) {
        if (server.isLoadingSegment(segment)) {
            return server.cancelOperation(SegmentAction.LOAD, segment);
        }
        if (server.isServingSegment(segment)) {
            return loadQueueManager.dropSegment(segment, server);
        }
        return false;
    }

    /**
     * Queues up to {@code numToDrop} drops of {@code segment} on the tier.
     *
     * <p>Decommissioning servers are drained first. Active servers are used only for the remainder:
     * <ul>
     *   <li>Iterated directly when round-robin assignment is enabled, or when there are no more of
     *       them than drops still needed.</li>
     *   <li>Otherwise the balancer strategy chooses the order.</li>
     * </ul>
     * Both groups are sorted in reverse of {@link ServerHolder}'s natural ordering (presumably
     * fullest first; confirm against {@code ServerHolder#compareTo}).
     *
     * @return number of drops queued
     */
    private int dropReplicas(int numToDrop, DataSegment segment, String tier, SegmentStatusInTier segmentStatus) {
        if (numToDrop <= 0) {
            return 0;
        }
        final List<ServerHolder> eligibleServers = segmentStatus.getServersEligibleToDrop();
        if (eligibleServers.isEmpty()) {
            incrementSkipStat(Stats.Segments.DROP_SKIPPED, "No eligible server", segment, tier);
            return 0;
        }

        final TreeSet<ServerHolder> liveServers = new TreeSet<>(Comparator.reverseOrder());
        final TreeSet<ServerHolder> decommissioningServers = new TreeSet<>(Comparator.reverseOrder());
        for (ServerHolder server : eligibleServers) {
            if (server.isDecommissioning()) {
                decommissioningServers.add(server);
            } else {
                liveServers.add(server);
            }
        }

        int dropsQueued = dropReplicasFromServers(numToDrop, segment, decommissioningServers.iterator(), tier);
        if (numToDrop > dropsQueued) {
            final int remainingToDrop = numToDrop - dropsQueued;
            final Iterator<ServerHolder> serverIterator =
                    (useRoundRobinAssignment || liveServers.size() <= remainingToDrop)
                            ? liveServers.iterator()
                            : strategy.findServersToDropSegment(segment, new ArrayList<>(liveServers));
            dropsQueued += dropReplicasFromServers(remainingToDrop, segment, serverIterator, tier);
        }
        return dropsQueued;
    }

    /**
     * Queues drops on servers from {@code serverIterator} until {@code numToDrop} succeed or the
     * iterator is exhausted. Records a skip stat for each failed drop.
     *
     * @return number of drops queued
     */
    private int dropReplicasFromServers(int numToDrop, DataSegment segment, Iterator<ServerHolder> serverIterator, String tier) {
        int dropsQueued = 0;
        while (numToDrop > dropsQueued && serverIterator.hasNext()) {
            if (loadQueueManager.dropSegment(segment, serverIterator.next())) {
                dropsQueued++;
            } else {
                incrementSkipStat(Stats.Segments.DROP_SKIPPED, "Encountered error", segment, tier);
            }
        }
        return dropsQueued;
    }

    /**
     * Queues up to {@code numToLoad} loads of {@code segment} on the tier.
     *
     * <p>If the tier already has at least one loaded replica ({@code numLoadedReplicas >= 1}), each
     * load is a replica and is subject to the replication throttler; nothing is queued when the
     * tier is already throttled.
     *
     * @return number of loads successfully queued
     */
    private int loadReplicas(int numToLoad, int numLoadedReplicas, DataSegment segment, String tier, SegmentStatusInTier segmentStatus) {
        final boolean isAlreadyLoadedOnTier = numLoadedReplicas >= 1;
        if (isAlreadyLoadedOnTier && replicationThrottler.isReplicationThrottledForTier(tier)) {
            return 0;
        }

        final List<ServerHolder> eligibleServers = segmentStatus.getServersEligibleToLoad();
        if (eligibleServers.isEmpty()) {
            incrementSkipStat(Stats.Segments.ASSIGN_SKIPPED, "No eligible server", segment, tier);
            return 0;
        }

        final Iterator<ServerHolder> serverIterator = useRoundRobinAssignment
                ? serverSelector.getServersInTierToLoadSegment(tier, segment)
                : strategy.findServersToLoadSegment(segment, eligibleServers);
        if (!serverIterator.hasNext()) {
            incrementSkipStat(Stats.Segments.ASSIGN_SKIPPED, "No strategic server", segment, tier);
            return 0;
        }

        int loadsQueued = 0;
        while (loadsQueued < numToLoad && serverIterator.hasNext()) {
            final ServerHolder server = serverIterator.next();
            // Parenthesized on purpose: only a successfully queued load/replicate counts.
            final boolean queued = isAlreadyLoadedOnTier
                    ? replicateSegment(segment, server)
                    : loadSegment(segment, server);
            if (queued) {
                loadsQueued++;
            }
        }
        return loadsQueued;
    }

    /** Queues a plain (unthrottled) LOAD on {@code server}; records a skip stat on failure. */
    private boolean loadSegment(DataSegment segment, ServerHolder server) {
        final String tier = server.getServer().getTier();
        final boolean queued = loadQueueManager.loadSegment(segment, server, SegmentAction.LOAD);
        if (!queued) {
            incrementSkipStat(Stats.Segments.ASSIGN_SKIPPED, "Encountered error", segment, tier);
        }
        return queued;
    }

    /**
     * Queues a REPLICATE on {@code server} unless the server's tier is replication-throttled.
     * On success, increments the tier's assigned-replica count in the throttler.
     */
    private boolean replicateSegment(DataSegment segment, ServerHolder server) {
        final String tier = server.getServer().getTier();
        if (replicationThrottler.isReplicationThrottledForTier(tier)) {
            incrementSkipStat(Stats.Segments.ASSIGN_SKIPPED, "Throttled replication", segment, tier);
            return false;
        }
        final boolean queued = loadQueueManager.loadSegment(segment, server, SegmentAction.REPLICATE);
        if (queued) {
            replicationThrottler.incrementAssignedReplicas(tier);
        } else {
            incrementSkipStat(Stats.Segments.ASSIGN_SKIPPED, "Encountered error", segment, tier);
        }
        return queued;
    }

    /**
     * Creates a throttler seeded with the number of replicas currently loading on each tier's
     * historicals, using the configured replication throttle limit.
     */
    private static ReplicationThrottler createReplicationThrottler(DruidCluster cluster, SegmentLoadingConfig loadingConfig) {
        final Map<String, Integer> tierToLoadingReplicaCount = new HashMap<>();
        cluster.getHistoricals().forEach((tier, historicals) -> tierToLoadingReplicaCount.put(
                tier, historicals.stream().mapToInt(ServerHolder::getNumLoadingReplicas).sum()));
        return new ReplicationThrottler(tierToLoadingReplicaCount, loadingConfig.getReplicationThrottleLimit());
    }

    /**
     * Cancels up to {@code maxToCancel} in-flight operations of type {@code action} for
     * {@code segment}. Servers are visited in the order returned by {@code getServersPerforming}.
     *
     * @return number of operations actually cancelled
     */
    private int cancelOperations(SegmentAction action, int maxToCancel, DataSegment segment, SegmentStatusInTier segmentStatus) {
        final List<ServerHolder> servers = segmentStatus.getServersPerforming(action);
        if (servers.isEmpty() || maxToCancel <= 0) {
            return 0;
        }
        int cancelled = 0;
        for (int idx = 0; idx < servers.size() && cancelled < maxToCancel; idx++) {
            if (servers.get(idx).cancelOperation(action, segment)) {
                cancelled++;
            }
        }
        return cancelled;
    }

    /** Adds 1 to {@code stat} under a row keyed by tier, datasource and skip description. */
    private void incrementSkipStat(CoordinatorStat stat, String reason, DataSegment segment, String tier) {
        final RowKey key = RowKey.with(Dimension.TIER, tier)
                .with(Dimension.DATASOURCE, segment.getDataSource())
                .and(Dimension.DESCRIPTION, reason);
        stats.add(stat, key, 1L);
    }

    /** Adds {@code value} to the per-tier, per-datasource segment stat. */
    private void incrementStat(CoordinatorStat stat, DataSegment segment, String tier, long value) {
        stats.addToSegmentStat(stat, tier, segment.getDataSource(), value);
    }
}
