package org.springframework.web.socket.messaging;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.security.Principal;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.commons.configuration2.tree.DefaultExpressionEngineSymbols;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.aspectj.weaver.Constants;
import org.springframework.context.ApplicationEvent;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.ApplicationEventPublisherAware;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.SimpAttributes;
import org.springframework.messaging.simp.SimpAttributesContextHolder;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.stomp.BufferingStompDecoder;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompDecoder;
import org.springframework.messaging.simp.stomp.StompEncoder;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.simp.user.DestinationUserNameProvider;
import org.springframework.messaging.simp.user.UserSessionRegistry;
import org.springframework.messaging.support.AbstractMessageChannel;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.ImmutableMessageChannelInterceptor;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.messaging.support.MessageHeaderInitializer;
import org.springframework.util.Assert;
import org.springframework.util.MimeTypeUtils;
import org.springframework.web.socket.BinaryMessage;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.SessionLimitExceededException;
import org.springframework.web.socket.handler.WebSocketSessionDecorator;
import org.springframework.web.socket.sockjs.transport.SockJsSession;

/* loaded from: input_file:WEB-INF/lib/spring-websocket-4.3.18.RELEASE.jar:org/springframework/web/socket/messaging/StompSubProtocolHandler.class */
/**
 * A {@link SubProtocolHandler} for STOMP, supporting the {@code v10.stomp},
 * {@code v11.stomp} and {@code v12.stomp} sub-protocols.
 *
 * <p>Incoming WebSocket messages are decoded into STOMP frames (buffering partial
 * frames per session), enriched with session headers and sent to the output
 * {@link MessageChannel}. Outgoing messages are encoded and written to the
 * {@link WebSocketSession}. Lifecycle events are published when an
 * {@link ApplicationEventPublisher} is configured.
 */
public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationEventPublisherAware {

    /**
     * The minimum text message size limit enforced on every session in
     * {@link #afterSessionStarted} (16640 bytes = 16 * 1024 + 256).
     */
    public static final int MINIMUM_WEBSOCKET_MESSAGE_SIZE = 16640;

    /** Native header added to outgoing CONNECTED frames with the session user's name. */
    public static final String CONNECTED_USER_HEADER = "user-name";

    private static final Log logger = LogFactory.getLog(StompSubProtocolHandler.class);

    private static final byte[] EMPTY_PAYLOAD = new byte[0];

    private StompSubProtocolErrorHandler errorHandler;

    private UserSessionRegistry userSessionRegistry;

    private MessageHeaderInitializer headerInitializer;

    // Cached result of detectImmutableMessageInterceptor; null until first detection.
    private Boolean immutableMessageInterceptorPresent;

    private ApplicationEventPublisher eventPublisher;

    private int messageSizeLimit = 65536;

    private StompEncoder stompEncoder = new StompEncoder();

    private StompDecoder stompDecoder = new StompDecoder();

    // Per-session buffering decoders, keyed by WebSocket session id.
    private final Map<String, BufferingStompDecoder> decoders =
            new ConcurrentHashMap<String, BufferingStompDecoder>();

    // Users attached to CONNECT headers that differ from the session principal, keyed by session id.
    private final Map<String, Principal> stompAuthentications =
            new ConcurrentHashMap<String, Principal>();

    private final Stats stats = new Stats();


    /**
     * Atomic counters for CONNECT, CONNECTED and DISCONNECT frames processed by the handler.
     */
    private static class Stats {

        private final AtomicInteger connect = new AtomicInteger();

        private final AtomicInteger connected = new AtomicInteger();

        private final AtomicInteger disconnect = new AtomicInteger();

        public void incrementConnectCount() {
            this.connect.incrementAndGet();
        }

        public void incrementConnectedCount() {
            this.connected.incrementAndGet();
        }

        public void incrementDisconnectCount() {
            this.disconnect.incrementAndGet();
        }

        @Override
        public String toString() {
            return "processed CONNECT(" + this.connect.get() + ")-CONNECTED(" +
                    this.connected.get() + ")-DISCONNECT(" + this.disconnect.get() + ")";
        }
    }


    /**
     * Configure a handler for error messages sent to clients. When none is set,
     * client-side processing errors fall back to {@link #sendErrorMessage}.
     * @param stompSubProtocolErrorHandler the error handler, or {@code null}
     */
    public void setErrorHandler(StompSubProtocolErrorHandler stompSubProtocolErrorHandler) {
        this.errorHandler = stompSubProtocolErrorHandler;
    }

    /**
     * Return the configured error handler, or {@code null} if none.
     */
    public StompSubProtocolErrorHandler getErrorHandler() {
        return this.errorHandler;
    }

    /**
     * Configure the buffer size limit passed to the {@link BufferingStompDecoder}
     * created for each new session (default 65536). The value applies only to
     * sessions started after this call.
     * @param i the size limit
     */
    public void setMessageSizeLimit(int i) {
        this.messageSizeLimit = i;
    }

    /**
     * Return the configured message size limit.
     */
    public int getMessageSizeLimit() {
        return this.messageSizeLimit;
    }

    /**
     * Configure a registry that is updated with the session id when a CONNECTED frame
     * is sent to a user, and again when the session ends.
     * @param userSessionRegistry the registry, or {@code null}
     * @deprecated retained for backward compatibility only
     */
    @Deprecated
    public void setUserSessionRegistry(UserSessionRegistry userSessionRegistry) {
        this.userSessionRegistry = userSessionRegistry;
    }

    /**
     * Return the configured user session registry, or {@code null}.
     * @deprecated retained for backward compatibility only
     */
    @Deprecated
    public UserSessionRegistry getUserSessionRegistry() {
        return this.userSessionRegistry;
    }

    /**
     * Configure the encoder used for frames sent to clients.
     */
    public void setEncoder(StompEncoder stompEncoder) {
        this.stompEncoder = stompEncoder;
    }

    /**
     * Configure the decoder used for frames received from clients. Note that
     * {@link #setHeaderInitializer} applies its initializer to the decoder that is
     * current at that time, so call this method first if both are used.
     */
    public void setDecoder(StompDecoder stompDecoder) {
        this.stompDecoder = stompDecoder;
    }

    /**
     * Configure a header initializer applied to decoded client frames (via the current
     * decoder) and to the synthetic DISCONNECT message created when a session ends.
     * @param messageHeaderInitializer the initializer, or {@code null}
     */
    public void setHeaderInitializer(MessageHeaderInitializer messageHeaderInitializer) {
        this.headerInitializer = messageHeaderInitializer;
        this.stompDecoder.setHeaderInitializer(messageHeaderInitializer);
    }

    /**
     * Return the configured header initializer, or {@code null}.
     */
    public MessageHeaderInitializer getHeaderInitializer() {
        return this.headerInitializer;
    }

    @Override
    public List<String> getSupportedProtocols() {
        return Arrays.asList("v10.stomp", "v11.stomp", "v12.stomp");
    }

    @Override
    public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) {
        this.eventPublisher = applicationEventPublisher;
    }

    /**
     * Return a summary of the CONNECT, CONNECTED and DISCONNECT frame counts.
     */
    public String getStatsInfo() {
        return this.stats.toString();
    }


    /**
     * Decode a WebSocket message into zero or more STOMP frames and send each one
     * to the given channel. Incomplete frames are buffered per session. A decoding
     * failure, or a failure while processing an individual frame, is logged and
     * reported to the client through {@link #handleError}.
     */
    @Override
    public void handleMessageFromClient(WebSocketSession webSocketSession,
            WebSocketMessage<?> webSocketMessage, MessageChannel messageChannel) {

        List<Message<byte[]>> messages;
        try {
            ByteBuffer payload;
            if (webSocketMessage instanceof TextMessage) {
                payload = ByteBuffer.wrap(((TextMessage) webSocketMessage).asBytes());
            }
            else if (webSocketMessage instanceof BinaryMessage) {
                payload = ((BinaryMessage) webSocketMessage).getPayload();
            }
            else {
                // Neither text nor binary: nothing to decode.
                return;
            }

            BufferingStompDecoder decoder = this.decoders.get(webSocketSession.getId());
            if (decoder == null) {
                throw new IllegalStateException("No decoder for session id '" + webSocketSession.getId() + "'");
            }

            messages = decoder.decode(payload);
            if (messages.isEmpty()) {
                if (logger.isTraceEnabled()) {
                    logger.trace("Incomplete STOMP frame content received in session " + webSocketSession +
                            ", bufferSize=" + decoder.getBufferSize() +
                            ", bufferSizeLimit=" + decoder.getBufferSizeLimit() + ".");
                }
                return;
            }
        }
        catch (Throwable ex) {
            if (logger.isErrorEnabled()) {
                logger.error("Failed to parse " + webSocketMessage + " in session " + webSocketSession.getId() +
                        ". Sending STOMP ERROR to client.", ex);
            }
            handleError(webSocketSession, ex, null);
            return;
        }

        // Kept outside the decode try/catch so that a failure while reporting a
        // per-frame error is not misreported as a parse failure.
        for (Message<byte[]> message : messages) {
            try {
                processClientMessage(webSocketSession, message, messageChannel);
            }
            catch (Throwable ex) {
                if (logger.isErrorEnabled()) {
                    logger.error("Failed to send client message to application via MessageChannel in session " +
                            webSocketSession.getId() + ". Sending STOMP ERROR to client.", ex);
                }
                handleError(webSocketSession, ex, message);
            }
        }
    }

    /**
     * Enrich one decoded client frame with session headers, send it to the channel and,
     * if it was accepted, record any STOMP-level user and publish the matching event.
     */
    private void processClientMessage(WebSocketSession session, Message<byte[]> message,
            MessageChannel channel) {

        StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
        Assert.state(accessor != null, "No StompHeaderAccessor");

        accessor.setSessionId(session.getId());
        accessor.setSessionAttributes(session.getAttributes());
        accessor.setUser(getUser(session));
        accessor.setHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER, accessor.getHeartbeat());
        if (!detectImmutableMessageInterceptor(channel)) {
            accessor.setImmutable();
        }

        if (logger.isTraceEnabled()) {
            logger.trace("From client: " + accessor.getShortLogMessage(message.getPayload()));
        }

        StompCommand command = accessor.getCommand();
        boolean isConnect = StompCommand.CONNECT.equals(command);
        if (isConnect) {
            this.stats.incrementConnectCount();
        }
        else if (StompCommand.DISCONNECT.equals(command)) {
            this.stats.incrementDisconnectCount();
        }

        try {
            SimpAttributesContextHolder.setAttributesFromMessage(message);
            if (!channel.send(message)) {
                return;
            }
            if (isConnect) {
                Principal user = accessor.getUser();
                // Identity comparison is intentional: only remember a user that is not the
                // principal of the WebSocket session itself.
                if (user != null && user != session.getPrincipal()) {
                    this.stompAuthentications.put(session.getId(), user);
                }
            }
            if (this.eventPublisher != null) {
                Principal user = getUser(session);
                if (isConnect) {
                    publishEvent(new SessionConnectEvent(this, message, user));
                }
                else if (StompCommand.SUBSCRIBE.equals(command)) {
                    publishEvent(new SessionSubscribeEvent(this, message, user));
                }
                else if (StompCommand.UNSUBSCRIBE.equals(command)) {
                    publishEvent(new SessionUnsubscribeEvent(this, message, user));
                }
            }
        }
        finally {
            SimpAttributesContextHolder.resetAttributes();
        }
    }

    /**
     * Return the user recorded from a STOMP CONNECT for the session, falling back to
     * the WebSocket session principal (which may be {@code null}).
     */
    private Principal getUser(WebSocketSession webSocketSession) {
        Principal principal = this.stompAuthentications.get(webSocketSession.getId());
        return (principal != null ? principal : webSocketSession.getPrincipal());
    }

    /**
     * Report a client message processing failure. Uses the configured error handler
     * if present (sending nothing if it returns {@code null}), otherwise falls back
     * to {@link #sendErrorMessage}.
     * @param message the client message being processed, or {@code null} if decoding failed
     */
    private void handleError(WebSocketSession webSocketSession, Throwable th, Message<byte[]> message) {
        if (getErrorHandler() == null) {
            sendErrorMessage(webSocketSession, th);
            return;
        }
        Message<byte[]> errorMessage = getErrorHandler().handleClientMessageProcessingError(message, th);
        if (errorMessage == null) {
            return;
        }
        StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(errorMessage, StompHeaderAccessor.class);
        Assert.state(accessor != null, "Expected STOMP headers");
        sendToClient(webSocketSession, accessor, errorMessage.getPayload());
    }

    /**
     * Send a STOMP ERROR frame whose {@code message} header is the exception message.
     * Failures to send are logged at debug level and otherwise ignored.
     * @deprecated only used when no {@link StompSubProtocolErrorHandler} is configured;
     * prefer {@link #setErrorHandler}
     */
    @Deprecated
    protected void sendErrorMessage(WebSocketSession webSocketSession, Throwable th) {
        StompHeaderAccessor headerAccessor = StompHeaderAccessor.create(StompCommand.ERROR);
        headerAccessor.setMessage(th.getMessage());
        try {
            byte[] bytes = this.stompEncoder.encode(headerAccessor.getMessageHeaders(), EMPTY_PAYLOAD);
            webSocketSession.sendMessage(new TextMessage(bytes));
        }
        catch (Throwable ex) {
            // Best effort: the session may already be unusable.
            logger.debug("Failed to send STOMP ERROR to client", ex);
        }
    }

    /**
     * Return whether the channel is an {@link AbstractMessageChannel} with an
     * {@link ImmutableMessageChannelInterceptor}. The result of the first call is cached
     * and reused for all later calls, regardless of the channel passed.
     */
    private boolean detectImmutableMessageInterceptor(MessageChannel messageChannel) {
        if (this.immutableMessageInterceptorPresent != null) {
            return this.immutableMessageInterceptorPresent;
        }
        if (messageChannel instanceof AbstractMessageChannel) {
            for (ChannelInterceptor interceptor : ((AbstractMessageChannel) messageChannel).getInterceptors()) {
                if (interceptor instanceof ImmutableMessageChannelInterceptor) {
                    this.immutableMessageInterceptorPresent = true;
                    return true;
                }
            }
        }
        this.immutableMessageInterceptorPresent = false;
        return false;
    }

    /**
     * Publish an event, logging (not propagating) any failure.
     */
    private void publishEvent(ApplicationEvent applicationEvent) {
        try {
            this.eventPublisher.publishEvent(applicationEvent);
        }
        catch (Throwable ex) {
            if (logger.isErrorEnabled()) {
                logger.error("Error publishing " + applicationEvent, ex);
            }
        }
    }

    /**
     * Encode a message with a {@code byte[]} payload as a STOMP frame and send it to
     * the client. Messages with other payload types are logged and ignored.
     */
    @Override
    @SuppressWarnings("unchecked")
    public void handleMessageToClient(WebSocketSession webSocketSession, Message<?> message) {
        if (!(message.getPayload() instanceof byte[])) {
            if (logger.isErrorEnabled()) {
                logger.error("Expected byte[] payload. Ignoring " + message + ".");
            }
            return;
        }

        StompHeaderAccessor accessor = getStompHeaderAccessor(message);
        StompCommand command = accessor.getCommand();

        if (StompCommand.MESSAGE.equals(command)) {
            if (accessor.getSubscriptionId() == null && logger.isWarnEnabled()) {
                logger.warn("No STOMP \"subscription\" header in " + message);
            }
            String origDestination = accessor.getFirstNativeHeader(SimpMessageHeaderAccessor.ORIGINAL_DESTINATION);
            if (origDestination != null) {
                accessor = toMutableAccessor(accessor, message);
                accessor.removeNativeHeader(SimpMessageHeaderAccessor.ORIGINAL_DESTINATION);
                accessor.setDestination(origDestination);
            }
        }
        else if (StompCommand.CONNECTED.equals(command)) {
            this.stats.incrementConnectedCount();
            accessor = afterStompSessionConnected(message, accessor, webSocketSession);
            if (this.eventPublisher != null) {
                try {
                    SimpAttributes simpAttributes =
                            new SimpAttributes(webSocketSession.getId(), webSocketSession.getAttributes());
                    SimpAttributesContextHolder.setAttributes(simpAttributes);
                    Principal user = getUser(webSocketSession);
                    // Safe: the byte[] payload type was checked at the top of this method.
                    publishEvent(new SessionConnectedEvent(this, (Message<byte[]>) message, user));
                }
                finally {
                    SimpAttributesContextHolder.resetAttributes();
                }
            }
        }

        byte[] payload = (byte[]) message.getPayload();
        if (StompCommand.ERROR.equals(command) && getErrorHandler() != null) {
            Message<byte[]> errorMessage = getErrorHandler().handleErrorMessageToClient((Message<byte[]>) message);
            accessor = MessageHeaderAccessor.getAccessor(errorMessage, StompHeaderAccessor.class);
            Assert.state(accessor != null, "Expected STOMP headers");
            payload = errorMessage.getPayload();
        }
        sendToClient(webSocketSession, accessor, payload);
    }

    /**
     * Encode and send a frame. A non-empty payload whose content type is compatible with
     * {@code application/octet-stream} is sent as a binary message, except on SockJS
     * sessions; everything else is sent as text. The session is closed with
     * {@link CloseStatus#PROTOCOL_ERROR} after an ERROR frame or when sending fails.
     * {@link SessionLimitExceededException} is rethrown; other send failures are
     * logged at debug level.
     */
    private void sendToClient(WebSocketSession webSocketSession, StompHeaderAccessor stompHeaderAccessor,
            byte[] bArr) {

        StompCommand command = stompHeaderAccessor.getCommand();
        try {
            byte[] bytes = this.stompEncoder.encode(stompHeaderAccessor.getMessageHeaders(), bArr);
            boolean useBinary = (bArr.length > 0 && !(webSocketSession instanceof SockJsSession) &&
                    MimeTypeUtils.APPLICATION_OCTET_STREAM.isCompatibleWith(stompHeaderAccessor.getContentType()));
            if (useBinary) {
                webSocketSession.sendMessage(new BinaryMessage(bytes));
            }
            else {
                webSocketSession.sendMessage(new TextMessage(bytes));
            }
        }
        catch (SessionLimitExceededException ex) {
            throw ex;
        }
        catch (Throwable ex) {
            if (logger.isDebugEnabled()) {
                logger.debug("Failed to send WebSocket message to client in session " + webSocketSession.getId(), ex);
            }
            // Treat a failed send like an ERROR frame so the session gets closed below.
            command = StompCommand.ERROR;
        }
        finally {
            if (StompCommand.ERROR.equals(command)) {
                try {
                    webSocketSession.close(CloseStatus.PROTOCOL_ERROR);
                }
                catch (IOException ignored) {
                    // Best effort close; nothing more to do.
                }
            }
        }
    }

    /**
     * Return a {@link StompHeaderAccessor} for an outgoing message, converting generic
     * CONNECT_ACK, DISCONNECT_ACK and HEARTBEAT messages into their STOMP equivalents.
     * Messages with no command, or a SEND command, are updated to a server-side command.
     */
    private StompHeaderAccessor getStompHeaderAccessor(Message<?> message) {
        MessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class);
        if (accessor instanceof StompHeaderAccessor) {
            return (StompHeaderAccessor) accessor;
        }

        StompHeaderAccessor stompAccessor = StompHeaderAccessor.wrap(message);
        SimpMessageType messageType = SimpMessageHeaderAccessor.getMessageType(message.getHeaders());
        if (SimpMessageType.CONNECT_ACK.equals(messageType)) {
            stompAccessor = convertConnectAcktoStompConnected(stompAccessor);
        }
        else if (SimpMessageType.DISCONNECT_ACK.equals(messageType)) {
            String receipt = getDisconnectReceipt(stompAccessor);
            if (receipt != null) {
                stompAccessor = StompHeaderAccessor.create(StompCommand.RECEIPT);
                stompAccessor.setReceiptId(receipt);
            }
            else {
                stompAccessor = StompHeaderAccessor.create(StompCommand.ERROR);
                stompAccessor.setMessage("Session closed.");
            }
        }
        else if (SimpMessageType.HEARTBEAT.equals(messageType)) {
            stompAccessor = StompHeaderAccessor.createForHeartbeat();
        }
        else if (stompAccessor.getCommand() == null || StompCommand.SEND.equals(stompAccessor.getCommand())) {
            stompAccessor.updateStompCommandAsServerMessage();
        }
        return stompAccessor;
    }

    /**
     * Build CONNECTED headers from a CONNECT_ACK, negotiating the highest version
     * accepted by the original CONNECT (1.2, then 1.1) and copying the heartbeat.
     * @throws IllegalStateException if the original CONNECT message is missing
     * @throws IllegalArgumentException if only unsupported versions were accepted
     */
    private StompHeaderAccessor convertConnectAcktoStompConnected(StompHeaderAccessor stompHeaderAccessor) {
        Message<?> connectMessage =
                (Message<?>) stompHeaderAccessor.getHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER);
        if (connectMessage == null) {
            throw new IllegalStateException("Original STOMP CONNECT not found in " + stompHeaderAccessor);
        }
        StompHeaderAccessor connectHeaders =
                MessageHeaderAccessor.getAccessor(connectMessage, StompHeaderAccessor.class);
        Assert.state(connectHeaders != null, "No StompHeaderAccessor in original CONNECT message");

        StompHeaderAccessor connectedHeaders = StompHeaderAccessor.create(StompCommand.CONNECTED);
        Set<String> acceptVersion = connectHeaders.getAcceptVersion();
        if (acceptVersion.contains("1.2")) {
            connectedHeaders.setVersion("1.2");
        }
        else if (acceptVersion.contains("1.1")) {
            connectedHeaders.setVersion("1.1");
        }
        else if (!acceptVersion.isEmpty()) {
            throw new IllegalArgumentException("Unsupported STOMP version '" + acceptVersion + "'");
        }

        long[] heartbeat = (long[]) stompHeaderAccessor.getHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER);
        if (heartbeat != null) {
            connectedHeaders.setHeartbeat(heartbeat[0], heartbeat[1]);
        }
        else {
            connectedHeaders.setHeartbeat(0L, 0L);
        }
        return connectedHeaders;
    }

    /**
     * Return the receipt id of the original DISCONNECT message, or {@code null} if there
     * is no such message, it has no STOMP headers, or it requested no receipt.
     */
    private String getDisconnectReceipt(SimpMessageHeaderAccessor simpMessageHeaderAccessor) {
        Message<?> disconnectMessage =
                (Message<?>) simpMessageHeaderAccessor.getHeader(SimpMessageHeaderAccessor.DISCONNECT_MESSAGE_HEADER);
        if (disconnectMessage == null) {
            return null;
        }
        StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(disconnectMessage, StompHeaderAccessor.class);
        return (accessor != null ? accessor.getReceipt() : null);
    }

    /**
     * Return the given accessor if it is mutable, otherwise a new accessor wrapping the message.
     */
    protected StompHeaderAccessor toMutableAccessor(StompHeaderAccessor stompHeaderAccessor, Message<?> message) {
        return (stompHeaderAccessor.isMutable() ? stompHeaderAccessor : StompHeaderAccessor.wrap(message));
    }

    /**
     * Post-process an outgoing CONNECTED frame. If the session has a user, add the
     * {@link #CONNECTED_USER_HEADER} and register the session in the (deprecated)
     * registry, if any. Disable SockJS heartbeats when the second STOMP heartbeat
     * value is positive.
     */
    private StompHeaderAccessor afterStompSessionConnected(Message<?> message, StompHeaderAccessor stompHeaderAccessor,
            WebSocketSession webSocketSession) {

        StompHeaderAccessor accessor = stompHeaderAccessor;
        Principal user = getUser(webSocketSession);
        if (user != null) {
            accessor = toMutableAccessor(accessor, message);
            accessor.setNativeHeader(CONNECTED_USER_HEADER, user.getName());
            if (this.userSessionRegistry != null) {
                this.userSessionRegistry.registerSessionId(getSessionRegistryUserName(user), webSocketSession.getId());
            }
        }
        if (accessor.getHeartbeat()[1] > 0) {
            WebSocketSession unwrapped = WebSocketSessionDecorator.unwrap(webSocketSession);
            if (unwrapped instanceof SockJsSession) {
                ((SockJsSession) unwrapped).disableHeartbeat();
            }
        }
        return accessor;
    }

    /**
     * Return the name under which the user is registered: the destination user name
     * for a {@link DestinationUserNameProvider}, otherwise {@link Principal#getName()}.
     */
    private String getSessionRegistryUserName(Principal principal) {
        if (principal instanceof DestinationUserNameProvider) {
            return ((DestinationUserNameProvider) principal).getDestinationUserName();
        }
        return principal.getName();
    }

    @Override
    public String resolveSessionId(Message<?> message) {
        return SimpMessageHeaderAccessor.getSessionId(message.getHeaders());
    }

    /**
     * Raise the session's text message size limit to at least
     * {@link #MINIMUM_WEBSOCKET_MESSAGE_SIZE} and create its buffering decoder.
     */
    @Override
    public void afterSessionStarted(WebSocketSession webSocketSession, MessageChannel messageChannel) {
        if (webSocketSession.getTextMessageSizeLimit() < MINIMUM_WEBSOCKET_MESSAGE_SIZE) {
            webSocketSession.setTextMessageSizeLimit(MINIMUM_WEBSOCKET_MESSAGE_SIZE);
        }
        this.decoders.put(webSocketSession.getId(), new BufferingStompDecoder(this.stompDecoder, getMessageSizeLimit()));
    }

    /**
     * Release per-session state, publish a {@link SessionDisconnectEvent} (if a publisher
     * is configured) and send a synthetic DISCONNECT to the channel. The recorded STOMP
     * user and the thread-bound attributes are cleared even if sending fails.
     */
    @Override
    public void afterSessionEnded(WebSocketSession webSocketSession, CloseStatus closeStatus,
            MessageChannel messageChannel) {

        this.decoders.remove(webSocketSession.getId());

        Principal principal = getUser(webSocketSession);
        if (principal != null && this.userSessionRegistry != null) {
            this.userSessionRegistry.unregisterSessionId(getSessionRegistryUserName(principal), webSocketSession.getId());
        }

        Message<byte[]> disconnectMessage = createDisconnectMessage(webSocketSession);
        SimpAttributes simpAttributes = SimpAttributes.fromMessage(disconnectMessage);
        try {
            SimpAttributesContextHolder.setAttributes(simpAttributes);
            if (this.eventPublisher != null) {
                Principal user = getUser(webSocketSession);
                publishEvent(new SessionDisconnectEvent(this, disconnectMessage,
                        webSocketSession.getId(), closeStatus, user));
            }
            messageChannel.send(disconnectMessage);
        }
        finally {
            this.stompAuthentications.remove(webSocketSession.getId());
            SimpAttributesContextHolder.resetAttributes();
            simpAttributes.sessionCompleted();
        }
    }

    /**
     * Create a DISCONNECT message with an empty payload, carrying the session id,
     * attributes and user, initialized by the header initializer if one is set.
     */
    private Message<byte[]> createDisconnectMessage(WebSocketSession webSocketSession) {
        StompHeaderAccessor headerAccessor = StompHeaderAccessor.create(StompCommand.DISCONNECT);
        if (getHeaderInitializer() != null) {
            getHeaderInitializer().initHeaders(headerAccessor);
        }
        headerAccessor.setSessionId(webSocketSession.getId());
        headerAccessor.setSessionAttributes(webSocketSession.getAttributes());
        headerAccessor.setUser(getUser(webSocketSession));
        return MessageBuilder.createMessage(EMPTY_PAYLOAD, headerAccessor.getMessageHeaders());
    }

    @Override
    public String toString() {
        return "StompSubProtocolHandler" + getSupportedProtocols();
    }
}
