/*
 * Decompiled with CFR 0.152.
 */
package com.talpie.linker;

import com.talpie.linker.AES;
import com.talpie.linker.DataSocketServer;
import com.talpie.linker.Message;
import com.talpie.linker.RSA;
import com.talpie.linker.ServerService;
import com.talpie.linker.ServerStatsListener;
import com.talpie.linker.StatiCom;
import com.talpie.linker.StreamSocketServer;
import com.talpie.linker.StreamStatsListener;
import java.io.EOFException;
import java.io.InputStream;
import java.io.OutputStream;
import java.math.BigInteger;
import java.net.Socket;
import java.net.SocketTimeoutException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * Server-side handler for one client's control connection.
 *
 * <p>{@link #start()} performs an RSA/AES key exchange on the socket, then runs a receive loop on
 * {@code socketExecutor} that decrypts framed {@link Message}s. Messages of type {@code '1'}
 * complete a pending control request; all others are dispatched to {@link #handleRequest(Message)}.
 *
 * <p>Outgoing messages are written from two threads: control requests from {@code sendExecutor}
 * and request responses from the receive thread. All writes are serialized through
 * {@code writeLock} so a header line and its payload are never interleaved with another message.
 */
public class ClientHandler {
    /** Bytes moved per read/write call; also the granularity of progress reports. */
    private static final int CHUNK_SIZE = 1024;
    private static final ScheduledExecutorService SCHED = Executors.newSingleThreadScheduledExecutor();

    private final ServerService serverService;
    private final Socket socket;
    private final String clientId;
    private final ExecutorService socketExecutor = Executors.newSingleThreadExecutor();
    private final ExecutorService sendExecutor = Executors.newSingleThreadExecutor();
    /** Serializes the encrypt-and-write sequence in {@link #writeMessage(Message)}. */
    private final Object writeLock = new Object();
    /** Flipped once by {@link #stop(String)} so teardown runs exactly once. */
    private final AtomicBoolean stopped = new AtomicBoolean(false);
    private volatile boolean running = false;
    private RSA rsa;
    private AES aes;
    private String clientMachineId;
    private OutputStream out;
    private InputStream in;
    private final ConcurrentHashMap<String, CompletableFuture<Message>> pending = new ConcurrentHashMap<>();
    private final ConcurrentHashMap<String, DataSocketServer> dataSockets = new ConcurrentHashMap<>();
    private final List<String> pendingDataSockets = new ArrayList<String>();
    private final ConcurrentHashMap<String, CompletableFuture<DSRef>> dsAwaiters = new ConcurrentHashMap<>();
    private final ServerStatsListener stats = new ServerStatsListener();
    private final ConcurrentHashMap<String, Object> meta = new ConcurrentHashMap<>();
    private final StreamStatsListener streamStats = new StreamStatsListener();

    public ServerService getServerService() {
        return this.serverService;
    }

    public RSA getRsa() {
        return this.rsa;
    }

    public AES getAes() {
        return this.aes;
    }

    /**
     * Returns the live list of data-socket ids awaiting connection. This class guards every access
     * with {@code synchronized} on the list itself; callers should do the same.
     */
    public List<String> getPendingDataSockets() {
        return this.pendingDataSockets;
    }

    public String getClientId() {
        return this.clientId;
    }

    public String getClientMachineId() {
        return this.clientMachineId;
    }

    public ServerStatsListener getStats() {
        return this.stats;
    }

    public Socket getSocket() {
        return this.socket;
    }

    public ConcurrentHashMap<String, DataSocketServer> getDataSockets() {
        return this.dataSockets;
    }

    public ConcurrentHashMap<String, Object> getMeta() {
        return this.meta;
    }

    public StreamStatsListener getStreamStats() {
        return this.streamStats;
    }

    /**
     * Creates a handler for an accepted socket, registers its stats listeners, and applies socket
     * options. Socket-option failures are reported to the error listener and are not fatal.
     */
    public ClientHandler(ServerService serverService, Socket socket, String clientId) {
        this.serverService = serverService;
        this.socket = socket;
        this.clientId = clientId;
        // NOTE(review): each stats listener is registered three times. This may correspond to three
        // distinct register(...) overloads whose argument casts were lost in decompilation, or it
        // may be accidental duplication -- confirm against the register(...) API before changing.
        serverService.getListenersHandlers().register(this.stats);
        serverService.getListenersHandlers().register(this.stats);
        serverService.getListenersHandlers().register(this.stats);
        serverService.getListenersHandlers().register(this.streamStats);
        serverService.getListenersHandlers().register(this.streamStats);
        serverService.getListenersHandlers().register(this.streamStats);
        try {
            socket.setTcpNoDelay(true);
            socket.setKeepAlive(true);
            socket.setReceiveBufferSize(65536);
            socket.setSendBufferSize(65536);
        }
        catch (Exception e) {
            serverService.getListenersHandlers().error(this, e);
        }
    }

    /**
     * Opens the streams, performs the handshake and, on success, starts the receive loop.
     * If initialization or the handshake fails, the handler is stopped so the socket is not leaked.
     */
    public void start() {
        this.running = true;
        try {
            this.out = this.socket.getOutputStream();
            this.in = this.socket.getInputStream();
            this.rsa = new RSA();
            this.aes = new AES();
        }
        catch (Exception e) {
            this.serverService.getListenersHandlers().error(this, e);
            this.stop("Initialization failed");
            return;
        }
        if (this.handshake()) {
            this.socketExecutor.submit(this::socketReceiveLoop);
        } else {
            this.stop("Handshake failed");
        }
    }

    /**
     * Tears the connection down. Only the first call has any effect.
     *
     * <p>Fails every pending control request and data-socket waiter with an
     * {@link IllegalStateException}, removes this client from the server, stops both executors,
     * and closes the streams and socket. Each close is attempted independently; failures are
     * reported to the error listener.
     *
     * @param reason passed to {@code ServerService.removeClient}
     */
    public void stop(String reason) {
        if (!this.stopped.compareAndSet(false, true)) {
            return;
        }
        this.running = false;
        try {
            IllegalStateException closed = new IllegalStateException("Connection closed");
            // Remove-then-fail per key so no future can be dropped without being completed.
            for (String id : this.pending.keySet()) {
                this.failPending(id, closed);
            }
            for (String id : this.dsAwaiters.keySet()) {
                CompletableFuture<DSRef> waiter = this.dsAwaiters.remove(id);
                if (waiter != null) {
                    waiter.completeExceptionally(closed);
                }
            }
            this.serverService.removeClient(this.clientId, reason);
        }
        catch (Exception e) {
            this.serverService.getListenersHandlers().error(this, e);
        }
        this.socketExecutor.shutdownNow();
        this.sendExecutor.shutdownNow();
        this.closeAndReport(this.out);
        this.closeAndReport(this.in);
        this.closeAndReport(this.socket);
    }

    /** Closes {@code resource} if non-null, reporting (not propagating) any failure. */
    private void closeAndReport(AutoCloseable resource) {
        if (resource == null) {
            return;
        }
        try {
            resource.close();
        }
        catch (Exception e) {
            this.serverService.getListenersHandlers().error(this, e);
        }
    }

    public CompletableFuture<Message> sendRequest(String route, byte[] payload) {
        return this.sendRequest(route, payload, 0L);
    }

    /**
     * Opens a fresh data socket, sends the request over it, and deregisters the data socket once
     * the response (or failure) arrives.
     *
     * @param timeoutMillis applied to both the data-socket open and the request; {@code <= 0}
     *     disables the timeout
     */
    public CompletableFuture<Message> sendRequest(String route, byte[] data, long timeoutMillis) {
        return this.openDataSocketFromServerAsync(timeoutMillis).thenCompose(dsRef -> dsRef.ds.sendRequest(route, data, timeoutMillis).whenComplete((m, err) -> this.removeDataSocket(dsRef.socketId)));
    }

    public CompletableFuture<Message> sendControlRequest(String route, byte[] payload) {
        return this.sendControlRequest(route, payload, 0L);
    }

    /**
     * Sends a request over the control connection and returns a future for the matching response.
     *
     * <p>The future completes when the receive loop sees a type-{@code '1'} message with the same
     * message id. It fails if the write fails, if the timeout elapses, if the handler is stopped,
     * or if the handler was already stopped when this was called.
     *
     * @param timeoutMillis response timeout; {@code <= 0} waits indefinitely
     */
    public CompletableFuture<Message> sendControlRequest(String route, byte[] payload, long timeoutMillis) {
        Message req = new Message(route, payload);
        req.setPayload(payload);
        String messageId = req.getMessageId();
        CompletableFuture<Message> fut = new CompletableFuture<Message>();
        this.pending.put(messageId, fut);
        try {
            this.sendExecutor.submit(() -> {
                try {
                    this.writeMessage(req);
                }
                catch (Exception e) {
                    this.failPending(messageId, e);
                }
            });
        }
        catch (RejectedExecutionException e) {
            // sendExecutor has been shut down by stop(): fail the future rather than
            // leaving it registered in `pending` with nothing to complete it.
            this.failPending(messageId, e);
            return fut;
        }
        if (timeoutMillis > 0L) {
            SCHED.schedule(() -> this.failPending(messageId, new Exception("Timeout waiting for response: " + messageId)), timeoutMillis, TimeUnit.MILLISECONDS);
        }
        return fut;
    }

    /** Removes the pending request {@code messageId} and fails it with {@code cause} if still open. */
    private void failPending(String messageId, Throwable cause) {
        CompletableFuture<Message> p = this.pending.remove(messageId);
        if (p != null && !p.isDone()) {
            p.completeExceptionally(cause);
        }
    }

    /**
     * Key exchange, run once on the raw streams:
     * <ol>
     *   <li>read the client's RSA public key, reply with ours;</li>
     *   <li>read an RSA-encrypted value passed to {@code aes.setCounterKeyBase64}, reply with our
     *       AES key RSA-encrypted;</li>
     *   <li>read the AES-encrypted client machine id, reply with AES(server machine id + clientId).</li>
     * </ol>
     *
     * @return {@code true} on success; on failure the handshakeFailed listener is notified
     */
    private boolean handshake() {
        try {
            this.rsa.setPublicKeyBase64(StatiCom.readLineString(this.in));
            StatiCom.writeLine(this.out, this.rsa.getPublicKeyBase64());
            this.aes.setCounterKeyBase64(this.rsa.decrypt(StatiCom.readLineString(this.in)));
            StatiCom.writeLine(this.out, this.rsa.encrypt(this.aes.getKeyBase64()));
            this.clientMachineId = this.aes.decrypt(StatiCom.readLineString(this.in));
            StatiCom.writeLine(this.out, this.aes.encrypt(this.serverService.getSystemInfo().getMachineId() + this.clientId));
        }
        catch (Exception e) {
            this.serverService.getListenersHandlers().handshakeFailed(this, this, e);
            return false;
        }
        this.serverService.getListenersHandlers().handshakeCompleted(this, this);
        return true;
    }

    /**
     * Receive loop: reads and dispatches messages until {@code running} is cleared.
     * A read timeout just re-checks {@code running}; any other exception stops the handler.
     */
    private void socketReceiveLoop() {
        while (this.running) {
            try {
                Message msg = this.readMessage();
                try {
                    this.stats.totals();
                }
                catch (Throwable ignored) {
                    // Stats are best-effort; a failing stats hook must not end the receive loop.
                }
                this.stats.onMsgRx(this, msg.getRoute());
                if (msg.getType() == '1') {
                    this.serverService.getListenersHandlers().controlResponse(this, this, msg);
                    CompletableFuture<Message> fut = this.pending.remove(msg.getMessageId());
                    if (fut != null) {
                        fut.complete(msg);
                    }
                    continue;
                }
                this.handleRequest(msg);
            }
            catch (SocketTimeoutException ignored) {
                // No data within the socket timeout; loop and re-check `running`.
            }
            catch (Exception e) {
                this.stop(e.getLocalizedMessage());
            }
        }
    }

    /**
     * Reads one framed message: a Base64 line holding the AES-encrypted header, followed by
     * {@code header.length} raw bytes of encrypted payload (absent when the length is 0).
     * The encrypted header bytes are passed to the payload decryption.
     */
    private Message readMessage() throws Exception {
        String encHeaderStr = StatiCom.readLineString(this.in);
        byte[] encHeader = Base64.getDecoder().decode(encHeaderStr);
        byte[] headerPlain = this.aes.decrypt(encHeader, null);
        Message msg = Message.fromHeader(new String(headerPlain, StandardCharsets.UTF_8));
        int total = msg.getLength().intValueExact();
        byte[] plain;
        if (total > 0) {
            byte[] encPayload = this.readPayloadWithProgress(this.in, total, msg);
            plain = this.aes.decrypt(encPayload, encHeader);
        } else {
            plain = new byte[]{};
        }
        msg.setPayload(plain);
        return msg;
    }

    /**
     * Handles an inbound (non-response) request and writes the reply on the control connection.
     * Failures are reported to the error listener; no reply is sent in that case.
     */
    private void handleRequest(Message req) {
        try {
            switch (req.getRoute()) {
                case "#_SYS++DS/OPEN": {
                    this.registerPendingDataSocket(req.getPayload());
                    req.setResponse("OK".getBytes(StandardCharsets.UTF_8));
                    this.writeMessage(req);
                    break;
                }
                case "#_SYS++DS/START": {
                    String socketId = this.registerPendingDataSocket(req.getPayload());
                    req.setResponse(socketId.getBytes(StandardCharsets.UTF_8));
                    this.writeMessage(req);
                    break;
                }
                case "#_SYS++STREAM/OPEN-REQ": {
                    String streamId = new String(req.getPayload(), StandardCharsets.UTF_8);
                    this.serverService.authorizeStream(this.clientId, streamId);
                    req.setResponse("OK".getBytes(StandardCharsets.UTF_8));
                    this.writeMessage(req);
                    break;
                }
                case "#_SYS++STREAM/CLOSE": {
                    String streamId = new String(req.getPayload(), StandardCharsets.UTF_8);
                    StreamSocketServer ss = this.serverService.getActiveStream(streamId);
                    if (ss != null) {
                        ss.stop();
                    }
                    req.setResponse("CLOSE-ACK".getBytes(StandardCharsets.UTF_8));
                    this.writeMessage(req);
                    break;
                }
                case "#_SYS++PING": {
                    this.writeMessage(this.serverService.getListenersHandlers().controlPing(this, this, req));
                    break;
                }
                default: {
                    this.writeMessage(this.serverService.getListenersHandlers().controlRequest(this, this, req));
                    break;
                }
            }
        }
        catch (Exception e) {
            this.serverService.getListenersHandlers().error(this, e);
        }
    }

    /**
     * Decodes a data-socket id from a request payload, validates it as a UUID, and adds it to
     * {@code pendingDataSockets} if not already present.
     *
     * @throws IllegalArgumentException if the payload is not a UUID string
     */
    private String registerPendingDataSocket(byte[] payload) {
        String socketId = new String(payload, StandardCharsets.UTF_8);
        UUID.fromString(socketId);
        synchronized (this.pendingDataSockets) {
            if (!this.pendingDataSockets.contains(socketId)) {
                this.pendingDataSockets.add(socketId);
            }
        }
        return socketId;
    }

    /**
     * Reads exactly {@code total} bytes from {@code in} in chunks of at most {@link #CHUNK_SIZE},
     * reporting receive progress whenever the whole-number percentage changes.
     *
     * @throws EOFException if the stream ends before {@code total} bytes have been read
     */
    private byte[] readPayloadWithProgress(InputStream in, int total, Message msg) throws Exception {
        byte[] buf = new byte[total];
        int lastPercent = -1;
        int off = 0;
        while (off < total) {
            int want = Math.min(CHUNK_SIZE, total - off);
            int got = in.read(buf, off, want);
            if (got == -1) {
                throw new EOFException("EOF during payload read");
            }
            off += got;
            int percent = (int)((long)off * 100L / (long)total);
            if (percent != lastPercent) {
                this.onReceiveProgress(msg, off, total, percent);
                lastPercent = percent;
            }
        }
        return buf;
    }

    private void onReceiveProgress(Message msg, long recvd, long total, int percent) {
        this.serverService.getListenersHandlers().progressRx(this, this, msg, recvd, total, percent);
    }

    private void onSendProgress(Message msg, long sent, long total, int percent) {
        this.serverService.getListenersHandlers().progressTx(this, this, msg, sent, total, percent);
    }

    /**
     * Encrypts and writes one framed message (see {@link #readMessage()} for the frame layout).
     * A {@code null} message is ignored.
     *
     * <p>The whole encrypt-and-write sequence runs under {@code writeLock} because this method is
     * called from both the receive thread and {@code sendExecutor}. On return, {@code msg} holds
     * the encrypted payload.
     *
     * @throws IllegalStateException if the header length disagrees with the encrypted payload size
     */
    private void writeMessage(Message msg) throws Exception {
        if (msg == null) {
            return;
        }
        try {
            this.stats.onMsgTx(this, msg.getRoute());
        }
        catch (Throwable ignored) {
            // Stats are best-effort; a failing stats hook must not block the send.
        }
        synchronized (this.writeLock) {
            byte[] plainPayload = msg.getPayload();
            int encLen = this.aes.encryptedLength(plainPayload.length);
            // The header must advertise the *encrypted* length, so render it with a placeholder
            // payload of that size, then restore the plaintext.
            msg.setPayload(new byte[encLen]);
            String headerStr = msg.getHeader();
            msg.setPayload(plainPayload);
            String encHeaderStr = this.aes.encrypt(headerStr);
            StatiCom.writeLine(this.out, encHeaderStr);
            // The encrypted header bytes feed the payload encryption, mirroring readMessage().
            byte[] encHeader = Base64.getDecoder().decode(encHeaderStr);
            byte[] encPayload = this.aes.encrypt(plainPayload, encHeader);
            assert (encPayload.length == encLen);
            msg.setPayload(encPayload);
            BigInteger lenBI = msg.getLength();
            if (lenBI.signum() > 0) {
                byte[] wire = msg.getPayload();
                int total = wire.length;
                if (lenBI.compareTo(BigInteger.valueOf(total)) != 0) {
                    throw new IllegalStateException("Header length mismatch: header=" + String.valueOf(lenBI) + " bytes, actual=" + total);
                }
                long sent = 0L;
                int lastPercent = -1;
                for (int offset = 0; offset < total; offset += CHUNK_SIZE) {
                    int chunk = Math.min(CHUNK_SIZE, total - offset);
                    this.out.write(wire, offset, chunk);
                    this.out.flush();
                    sent += chunk;
                    int percent = (int)(sent * 100L / (long)total);
                    if (percent != lastPercent) {
                        this.onSendProgress(msg, sent, total, percent);
                        lastPercent = percent;
                    }
                }
            }
        }
    }

    public void removeDataSocket(String socketId) {
        this.dataSockets.remove(socketId);
    }

    /**
     * Registers a connected data socket and completes the matching waiter created by
     * {@link #openDataSocketFromServerAsync(long)}, if one is still waiting.
     */
    public void addDataSocket(String socketId, DataSocketServer dataSocket) {
        this.dataSockets.put(socketId, dataSocket);
        CompletableFuture<DSRef> w = this.dsAwaiters.remove(socketId);
        if (w != null && !w.isDone()) {
            w.complete(new DSRef(socketId, dataSocket));
        }
    }

    /**
     * Asks the client to open a new data socket and returns a future completed by
     * {@link #addDataSocket(String, DataSocketServer)} once it connects.
     *
     * <p>The future fails if the OPEN-REQ control request fails, if the timeout elapses, or if
     * the handler is stopped. On failure the socket id is removed from both {@code dsAwaiters} and
     * {@code pendingDataSockets}, so a stale id stays neither registered nor authorized.
     *
     * @param timeoutMillis overall timeout; {@code <= 0} waits indefinitely
     */
    public CompletableFuture<DSRef> openDataSocketFromServerAsync(long timeoutMillis) {
        String socketId = UUID.randomUUID().toString();
        synchronized (this.pendingDataSockets) {
            this.pendingDataSockets.add(socketId);
        }
        CompletableFuture<DSRef> waiter = new CompletableFuture<DSRef>();
        this.dsAwaiters.put(socketId, waiter);
        CompletableFuture<Message> ackFut = this.sendControlRequest("#_SYS++DS/OPEN-REQ", socketId.getBytes(StandardCharsets.UTF_8), timeoutMillis);
        CompletableFuture<DSRef> result = ackFut.thenCompose(ack -> waiter);
        if (timeoutMillis > 0L) {
            SCHED.schedule(() -> {
                if (!result.isDone()) {
                    result.completeExceptionally(new TimeoutException("DS open timeout: " + socketId));
                }
            }, timeoutMillis, TimeUnit.MILLISECONDS);
        }
        result.whenComplete((ref, err) -> {
            if (err != null) {
                this.dsAwaiters.remove(socketId);
                synchronized (this.pendingDataSockets) {
                    this.pendingDataSockets.remove(socketId);
                }
            }
        });
        return result;
    }

    /** Immutable pairing of a data-socket id with its connected server-side endpoint. */
    public static final class DSRef {
        public final String socketId;
        public final DataSocketServer ds;

        public DSRef(String socketId, DataSocketServer ds) {
            this.socketId = socketId;
            this.ds = ds;
        }
    }
}

