import logging
import syslog
from os import times
import socket
from select import select
import threading
import serial
import time
import sys

# Logical channel indices used by SerialGw; PORT_RP and PORT_XD are also the
# indices into SERIAL_PORTS and the per-port serial configuration lists.
PORT_RP, PORT_XD, PORT_TCP = 0, 1, 2

# Serial device nodes, indexed by PORT_RP / PORT_XD.
SERIAL_PORTS = [
    "/dev/ttyS1",    # PORT_RP
    "/dev/ttyACM0",  # PORT_XD
]

# Destination identifiers used in the SerialGw.FW_PORT_* forwarding tables.
MAP_PORT_RP, MAP_PORT_XD, MAP_PORT_TCP = 1, 5, 6


class SerialGw:
    """
    Gateway between the RP and XD serial ports and a single TCP client.

    Lines starting with '$' received on any channel are handed to
    ``self.device.handleAeCommand``; all other traffic is forwarded to the
    destinations listed in the FW_PORT_* table of its source channel.
    Only one client IP is served at a time (see ``activeIP``).
    """

    NOGW = 0    # serial ports only; start() launches no worker thread
    SERVER = 2  # additionally listen for a TCP client on peerPort

    # Forwarding tables: MAP_PORT_* destinations for non-command traffic.
    FW_PORT_RP = [MAP_PORT_TCP]  # RP
    FW_PORT_XD = [MAP_PORT_TCP]  # XD
    FW_PORT_TCP = [MAP_PORT_RP, MAP_PORT_XD]  # TCP socket

    # IP of the currently served TCP client.  Deliberately a class attribute
    # (always accessed as SerialGw.activeIP), so it is shared by all instances.
    activeIP = None

    # Class-level defaults only.  Every instance works on its own copies made
    # in __init__, so configuring one gateway cannot affect another one.
    tty = [None, None]
    baud = [9600, 9600]
    dataBits = [8, 8]
    parity = ["N", "N"]
    stopBits = [1, 1]

    # Baud rates accepted by setSerialParams.
    VALID_BAUD_RATES = (9600, 19200, 38400, 57600, 115200)

    def __init__(self, mode: int):
        """
        mode - SerialGw.NOGW or SerialGw.SERVER
        Raises Exception for any other mode.
        """
        self.logger = logging.getLogger("serial-gw")
        self.lock = threading.Lock()

        if mode not in (SerialGw.NOGW, SerialGw.SERVER):
            self.logger.error("invalid mode: {}".format(mode))
            raise Exception("mode is invalid")

        self.mode = mode
        self.peerHost = None
        self.peerPort = 0
        self.running = False
        self.stopGw = False
        # Must be assigned before start(); it must expose handleAeCommand(line).
        self.device = None
        # Set by start() in SERVER mode; None otherwise (checked by shutdown()).
        self.workerThread = None

        # Per-instance serial configuration, indexed by PORT_RP / PORT_XD.
        self.tty = [None, None]
        self.baud = list(SerialGw.baud)
        self.dataBits = list(SerialGw.dataBits)
        self.parity = list(SerialGw.parity)
        self.stopBits = list(SerialGw.stopBits)

        # Outbound buffers drained by the gateway loop.  NOTE(review): they are
        # presumably filled by self.device while it handles AE commands - confirm.
        self.data4skt = ""      # response for the channel in srcChannel
        self.data4XD = ""       # data to pass on to the XD port
        self.srcChannel = -1    # channel the last AE command came from (-1: none)
        self.notification = ""  # message for every channel except srcChannel

    def setSerialParams(self, port: int, baud: int, dataBits: int = 8, parity: str = "N", stopBits: int = 1):
        """
        Set parameters of the serial port; takes effect on the next start().
        port - PORT_RP or PORT_XD; None is ignored
        baud - 9600, 19200, 38400, 57600 or 115200
        dataBits - 7 or 8
        parity - "N"/"NONE", "E"/"EVEN" or "O"/"ODD"
        stopBits - 1 or 2
        Raises Exception if the gateway is running or a parameter is invalid;
        in that case the existing configuration is left untouched.
        """
        if port is None:
            return

        with self.lock:
            if self.running:
                raise Exception("gateway is already running")

            if port not in (PORT_RP, PORT_XD):
                raise Exception("Invalid serial port index")

            if baud not in SerialGw.VALID_BAUD_RATES:
                raise Exception("Invalid baud rate")

            if dataBits not in [7, 8]:
                raise Exception("invalid data bits")

            if parity not in ["N", "E", "O", "NONE", "EVEN", "ODD"]:
                raise Exception("Invalid parity")

            if stopBits not in [1, 2]:
                raise Exception("Invalid stop bits")

            # Only replace the port object once everything has been validated
            # and we know the gateway is not using it.
            tty = serial.Serial()
            tty.port = SERIAL_PORTS[port]
            self.tty[port] = tty

            self.baud[port] = baud
            self.dataBits[port] = dataBits
            self.parity[port] = parity
            self.stopBits[port] = stopBits

    def setPeer(self, port: int, host: str = ""):
        """
        Set the TCP peer parameters used in SERVER mode.
        port - TCP port the server listens on (on all interfaces), 0..65535
        host - stored in peerHost; not used by the server loop
        Raises Exception if the gateway is running or the port is out of range.
        """
        with self.lock:
            if self.running:
                raise Exception("gateway is already running")

            if not 0 <= port <= 65535:
                raise Exception("Port is invalid")

            self.peerHost = host
            self.peerPort = port

    def _dropClient(self, sktCon):
        """Close a client socket and clear activeIP so any client may connect again."""
        try:
            sktCon.close()
        finally:
            SerialGw.activeIP = None

    def _sendToClient(self, sktCon, data: bytes):
        """
        Send data to the TCP client if one is connected.
        sendall() raises on failure (it returns None on success), so a failed
        send drops the client instead of killing the gateway thread.
        Returns the client socket, or None if there is none / it was dropped.
        """
        if sktCon is None:
            return None
        try:
            sktCon.sendall(data)
            return sktCon
        except OSError as ex:
            self.logger.error("failed to send bytes through socket ({}). Disconnecting.".format(ex))
            self._dropClient(sktCon)
            return None

    def _flushPending(self, sktCon):
        """
        Deliver the outbound buffers (data4XD, notification, data4skt).
        Returns the client socket, or None if it was dropped while sending.
        """
        # Unrecognized AE command. Forward it to XD.
        if len(self.data4XD) > 0:
            if self.tty[PORT_XD] is not None:
                self.tty[PORT_XD].write(self.data4XD.encode())
            self.data4XD = ""

        # Notify every channel except the one the command came from.
        if len(self.notification) > 0:
            payload = self.notification.encode()
            if self.srcChannel != PORT_RP and self.tty[PORT_RP] is not None:
                self.tty[PORT_RP].write(payload)
            if self.srcChannel != PORT_XD and self.tty[PORT_XD] is not None:
                self.tty[PORT_XD].write(payload)
            if self.srcChannel != PORT_TCP:
                sktCon = self._sendToClient(sktCon, payload)
            self.notification = ""

        # Send the response back to the originating channel; "" and "$$$"
        # mean there is nothing to send.
        if self.srcChannel > -1:
            if self.data4skt not in ("", "$$$"):
                payload = self.data4skt.encode()
                if self.srcChannel in (PORT_RP, PORT_XD):
                    if self.tty[self.srcChannel] is not None:
                        self.tty[self.srcChannel].write(payload)
                elif self.srcChannel == PORT_TCP:
                    sktCon = self._sendToClient(sktCon, payload)
            self.data4skt = ""
            self.srcChannel = -1

        return sktCon

    def _pollSerial(self, port: int, fwTable, sktCon):
        """
        Read one line from serial port `port` if data is waiting.
        '$'-prefixed lines go to the device as AE commands; other lines are
        forwarded to the other serial port and/or the TCP client per fwTable.
        Returns the client socket, or None if it was dropped while sending.
        """
        tty = self.tty[port]
        if tty is None or tty.in_waiting <= 0:
            return sktCon

        ttyData = tty.readline()
        if len(ttyData) == 0:
            return sktCon

        # Undecodable bytes must not crash the gateway thread.
        decData = ttyData.decode(errors="replace")
        if decData.startswith('$'):
            self.srcChannel = port
            self.device.handleAeCommand(decData)
            return sktCon

        # NOTE(review): readline() may already include the '\n' terminator,
        # so this can yield a doubled line ending - confirm it is intended.
        ttyData += b'\x0d\x0a'
        if port == PORT_RP:
            self.logger.info("XD command: '{}'".format(decData))
            otherPort, otherMap = PORT_XD, MAP_PORT_XD
        else:
            otherPort, otherMap = PORT_RP, MAP_PORT_RP

        if self.tty[otherPort] is not None and otherMap in fwTable:
            self.tty[otherPort].write(ttyData)
        if MAP_PORT_TCP in fwTable:
            sktCon = self._sendToClient(sktCon, ttyData)
        return sktCon

    def _acceptClient(self, skt, sktCon):
        """
        Accept a pending TCP connection, if any.
        A client from an IP other than the active one is rejected; a client
        from the active IP (or any IP when none is active) replaces the
        current connection.  Returns the current client socket.
        """
        try:
            newCon, clientAddr = skt.accept()
        except socket.timeout:
            return sktCon
        except OSError as ex:
            self.logger.error("Unexpected exception in listening socket: {}".format(ex))
            return sktCon

        clientIP = clientAddr[0]
        # If already active connection reject attempts from other IPs
        if SerialGw.activeIP is not None and clientIP != SerialGw.activeIP:
            newCon.close()
            self.logger.info("Serial GW SERVER: connection from {} rejected".format(clientIP))
            return sktCon

        if sktCon is not None:
            sktCon.close()
        newCon.settimeout(0.01)
        SerialGw.activeIP = clientIP

        logMsg = "SERIAL GW: client connected from {}".format(clientIP)
        self.logger.info(logMsg)
        syslog.syslog(logMsg)
        return newCon

    def _pollClient(self, sktCon):
        """
        Handle data from the TCP client, waiting up to 0.1 s for it.
        '$'-prefixed data is split on CRLF and each line passed to the device
        as an AE command; other data is forwarded unchanged to the serial ports
        listed in FW_PORT_TCP.  Disconnects and errors drop the client.
        Returns the client socket, or None if it was dropped.
        """
        try:
            rdlist, _, _ = select((sktCon,), (), (), 0.1)
            if not rdlist:
                return sktCon

            sktData = sktCon.recv(100)
            if len(sktData) == 0:
                logMsg = "Serial GW: lost connection to client --> cleanup"
                self.logger.info(logMsg)
                syslog.syslog(logMsg)
                self._dropClient(sktCon)
                return None

            # Decode only to classify the data: binary payloads or multi-byte
            # characters split across recv() calls must not drop the client,
            # and the raw bytes are what gets forwarded.
            data = sktData.decode('utf-8', errors='replace')
            if data.startswith("$"):
                for line in data.strip().split('\r\n'):
                    self.srcChannel = PORT_TCP
                    self.device.handleAeCommand(line)
                    time.sleep(0.01)
            else:
                with self.lock:
                    if self.tty[PORT_RP] is not None and MAP_PORT_RP in self.FW_PORT_TCP:
                        self.tty[PORT_RP].write(sktData)
                    if self.tty[PORT_XD] is not None and MAP_PORT_XD in self.FW_PORT_TCP:
                        self.tty[PORT_XD].write(sktData)
            return sktCon
        except socket.timeout:
            return sktCon
        except Exception as ex:
            # Any other failure ends this client session but keeps the gateway alive.
            logMsg = "Serial GW: lost connection to client --> cleanup ({})".format(str(ex))
            self.logger.info(logMsg)
            syslog.syslog(logMsg)
            self._dropClient(sktCon)
            return None

    def handleConnection(self, ttt, stop):
        """
        Main loop of the TCP gateway (started in its own thread by runServer).

        Listens on 0.0.0.0:peerPort and, until stop() returns True, repeatedly
        flushes the outbound buffers, polls both serial ports, accepts a new
        client and polls the connected one.  The listening socket and the
        client socket are always closed on exit.
        ttt  - unused (runServer passes 0); kept for interface compatibility
        stop - zero-argument callable; the loop ends once it returns True
        """
        skt = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            skt.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            skt.bind(("0.0.0.0", self.peerPort))
            skt.listen(1)
            # Short accept() timeout keeps the polling loop responsive.
            skt.settimeout(0.01)
        except OSError as ex:
            self.logger.error("Serial GW SERVER: cannot listen on port {}: {}".format(self.peerPort, ex))
            skt.close()
            return

        self.srcChannel = -1
        self.data4skt = ""

        sktCon = None
        try:
            while not stop():
                sktCon = self._flushPending(sktCon)
                sktCon = self._pollSerial(PORT_RP, self.FW_PORT_RP, sktCon)
                sktCon = self._pollSerial(PORT_XD, self.FW_PORT_XD, sktCon)
                sktCon = self._acceptClient(skt, sktCon)
                if sktCon is not None:
                    sktCon = self._pollClient(sktCon)
        finally:
            self.logger.info("Serial GW: Closing active connection.")
            if sktCon is not None:
                sktCon.close()
            SerialGw.activeIP = None
            skt.close()

    def runServer(self):
        """
        Worker-thread body for SERVER mode.
        Starts handleConnection in a separate thread (retrying if the start
        fails), waits until stopGw is set, then stops and joins that thread.
        """
        logMsg = "Serial GW SERVER: listening on port {}".format(self.peerPort)
        self.logger.info(logMsg)
        syslog.syslog(logMsg)

        stopHandling = threading.Event()
        thrd = None
        while not self.stopGw:
            if thrd is None:
                try:
                    candidate = threading.Thread(target=self.handleConnection, args=(0, stopHandling.is_set))
                    candidate.start()
                    thrd = candidate
                except Exception as ex:
                    self.logger.error("Unexpected exception in listening socket: {}".format(ex))
            time.sleep(0.01)

        stopHandling.set()
        if thrd is not None:
            thrd.join()

        self.logger.info("gateway worker (SERVER) exiting.")

    def _openSerial(self, port: int):
        """
        Apply the stored configuration of `port` (PORT_RP or PORT_XD) to its
        serial.Serial object and open it.
        Raises Exception if the stored configuration is corrupt.
        """
        tty = self.tty[port]

        bytesize = {7: serial.SEVENBITS, 8: serial.EIGHTBITS}.get(self.dataBits[port])
        if bytesize is None:
            raise Exception("The impossible happened - corrupt data bits config :-(")

        parity = {
            "N": serial.PARITY_NONE, "NONE": serial.PARITY_NONE,
            "E": serial.PARITY_EVEN, "EVEN": serial.PARITY_EVEN,
            "O": serial.PARITY_ODD, "ODD": serial.PARITY_ODD,
        }.get(self.parity[port])
        if parity is None:
            raise Exception("The impossible happened - corrupt parity config :-(")

        stopbits = {1: serial.STOPBITS_ONE, 2: serial.STOPBITS_TWO}.get(self.stopBits[port])
        if stopbits is None:
            raise Exception("The impossible happened - corrupt stopbits config :-(")

        tty.baudrate = self.baud[port]
        tty.bytesize = bytesize
        tty.parity = parity
        tty.stopbits = stopbits
        tty.xonxoff = False
        tty.rtscts = False
        tty.dsrdtr = False
        tty.rs485_mode = None
        tty.timeout = 0.2
        tty.open()

    def start(self):
        """
        Start operation of the serial GW.
        Opens every serial port configured with setSerialParams and, in SERVER
        mode, launches the worker thread.  Must be called after setSerialParams
        and setPeer.
        Raises Exception if the gateway is already running.
        """
        with self.lock:
            if self.running:
                raise Exception("gateway is already running")

            for port in (PORT_RP, PORT_XD):
                if self.tty[port] is not None:
                    self._openSerial(port)
                    print("Serial port {} opened".format(port))

            if self.mode == SerialGw.SERVER:
                # Reset the stop flag so the gateway can be restarted after shutdown().
                self.stopGw = False
                self.workerThread = threading.Thread(target=self.runServer)
                self.workerThread.daemon = True
                self.workerThread.start()
            elif self.mode != SerialGw.NOGW:
                raise Exception("The impossible happened - corrupt GW mode: {}".format(self.mode))

            self.running = True

    def shutdown(self):
        """
        Shutdown operation of the serial GW: stop and join the worker thread
        (if any) and close the serial ports.  Logs a warning and does nothing
        if the gateway is not running.
        """
        with self.lock:
            if not self.running:
                self.logger.warning("gateway not running")
                return

            worker = self.workerThread
            if worker is not None:
                self.stopGw = True

            self.running = False

        if worker is not None:
            self.logger.info("waiting for worker thread to stop")
            worker.join()
            self.workerThread = None

        for tty in self.tty:
            if tty is not None:
                tty.close()

    def writeSideChannel(self, data: str) -> int:
        """
        Write data to every configured serial port (RP and XD), regardless of
        the operation mode.
        data - string to write; encoded as UTF-8
        Returns the total number of bytes the ports report as written.
        """
        payload = data.encode("utf-8")
        written = 0
        with self.lock:
            for port in (PORT_RP, PORT_XD):
                tty = self.tty[port]
                if tty is not None:
                    written += tty.write(payload) or 0
        return written
