import logging
import syslog
from os import times
import socket
from select import select
import threading
import serial
import time

class SerialGw:
    NOGW   = 0
    CLIENT = 1
    SERVER = 2

    activeIP = None
    toSend=""

    def __init__(self,serial:str,mode:int):
        """
        Create a gateway for one serial port.
        serial - name of the serial port; it is later assigned to the
                 pyserial port attribute in start()
        mode   - SerialGw.NOGW, SerialGw.CLIENT or SerialGw.SERVER
        Raises Exception if mode is none of the three constants.
        """
        # NOTE: the parameter name "serial" shadows the imported serial
        # module inside this method only; the module is not used here.
        self.logger = logging.getLogger("serial-gw")
        self.lock   = threading.Lock()

        if mode not in (SerialGw.NOGW, SerialGw.CLIENT, SerialGw.SERVER):
            self.logger.error("invalid mode: {}".format(mode))
            raise Exception("mode is invalid")

        self.serialPortName = serial
        self.mode       = mode
        self.peerHost   = None
        self.peerPort   = None

        # Serial parameters stay None until setSerialParams() is called
        self.baud       = None
        self.dataBits   = None
        self.parity     = None
        self.stopBits   = None

        self.running    = False
        self.stopGw     = False

        # Created by start(). Defining them here lets shutdown() and other
        # methods test them without risking an AttributeError (e.g. in
        # NOGW mode, where no worker thread is ever created).
        self.tty            = None
        self.workerThread   = None


    def setSerialParams(self,baud:int,dataBits:int = 8,parity:str = "N",stopBits:int = 1):
        """
        Set parameters of the serial port
        baud - 9600, 19200, 38400, 57600 or 115200
        dataBits - 7 or 8
        parity - "N"/"NONE", "E"/"EVEN" or "O"/"ODD"
        stopBits - 1 or 2
        """
        try:
            self.lock.acquire()
            if self.running:
                raise Exception("gateway is already running")

            if baud not in [9600,19200,38400,7600,115200]:
                raise Exception("Invalid baud rate")

            if dataBits not in [7,8]:
                raise Exception("invalid data bits")

            if parity not in ["N","E","O","NONE","EVEN","ODD"]:
                raise Exception("Invalid parity")

            if stopBits not in [1,2]:
                raise Exception("Invalid stop bits")

            self.baud       = baud
            self.dataBits   = dataBits
            self.parity     = parity
            self.stopBits   = stopBits

        finally:
            self.lock.release()

    def setPeer(self,port:int,host:str=""):
        """
        Set the host and peer for gateway operation.
        Ignored is mode is set to MODE_DISABLED
        host - address or hostname of the peer (MODE_CLIENT only)
               ignored in other modes
        port - Listening (SERVER) or destination port (CLIENT)
        """
        try:
            self.lock.acquire()
            if self.running:
                raise Exception("gateway is already running")

            if port<1 or port>65535:
                raise Exception("Port is invalid")

            if self.mode==SerialGw.CLIENT and len(host)==0:
                raise Exception("cannot have empty host in CLIENT mode")

            self.peerHost   = host
            self.peerPort   = port
        finally:
            self.lock.release()


    def handleConnection(self, sktCon: socket.socket, stop):
        sktCon.settimeout(0.01)
        sktCon.setsockopt(socket.SOL_SOCKET,socket.SO_KEEPALIVE,1)
        sktCon.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 15)
        sktCon.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 2)
        sktCon.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 5)
        while not stop():
            if len(self.toSend)>0:
                sktCon.sendall(self.toSend.encode("utf-8"))
                self.toSend=""

            avail       = self.tty.in_waiting
            if avail>0:
                ttyData     = self.tty.read(avail)
                if len(ttyData)>0:
#                    print("From serial: "+ttyData.decode())
                    r = sktCon.sendall(ttyData)

                    if r is not None:
                        self.logger.error("Serial GW: failed to send bytes through socket. Disconnecting.")
                        sktCon.close()
                        return

            # read from sktCon
            # # @catch socket.timeout --> pass
            # # @catch any other exception return (don't close TTY!!)
#            rdlist, _, _ = select((sktCon,),(),(),0.1)
            rdlist=select([sktCon], [], [], 0.1)[0]
            if rdlist:
                try:
                    sktData = sktCon.recv(100)
                    if len(sktData)==0:
                        logMsg  = "Serial GW: lost connection to client unexpextedly --> cleanup"
                        self.logger.info(logMsg)
                        syslog.syslog(logMsg)
                        sktCon.close()
                        SerialGw.activeIP = None
                        return
                    else:
                        try:
                            self.lock.acquire()
                            self.logger.info("From tcp: "+sktData.decode())
                            self.tty.write(sktData)
                        finally:
                            self.lock.release()
                except socket.timeout:
                    pass
                except Exception as ex:
                    logMsg  = "Serial GW: lost connection to client --> cleanup ({})".format(str(ex))
                    self.logger.info(logMsg)
                    syslog.syslog(logMsg)
                    sktCon.close()
                    SerialGw.activeIP = None
                    return

        self.logger.info("Serial GW: Closing active connection.")
        sktCon.close()
        SerialGw.activeIP = None


    def runServer(self):
        """
        Worker loop for SERVER mode.

        Listens on self.peerPort on all IPv4 interfaces and hands every
        accepted connection to handleConnection() in its own thread.
        While SerialGw.activeIP is set, connections from other addresses
        are rejected; a new connection from the same address replaces the
        running handler.

        The loop ends once self.stopGw is true. Because accept() uses a
        30 second timeout, that may take up to 30 seconds. Before
        returning, the active handler thread is stopped and joined and
        the listening socket is closed.
        """
        logMsg  = "Serial GW SERVER: listening on port {}".format(self.peerPort)
        self.logger.info(logMsg)
        syslog.syslog(logMsg)

        skt = socket.socket(socket.AF_INET,socket.SOCK_STREAM)
        thrd = None
        stopHandling = threading.Event()
        try:
            skt.setsockopt(socket.SOL_SOCKET,socket.SO_REUSEADDR,1)
            skt.bind(("0.0.0.0",self.peerPort))
            skt.listen(1)
            skt.settimeout(30)

            while not self.stopGw:
                try:
                    (sktCon , clientAddr)  = skt.accept()

                    # If already active connection reject attempts from other IPs
                    if SerialGw.activeIP is not None and clientAddr[0] != SerialGw.activeIP:
                        sktCon.close()
                        self.logger.info("Serial GW SERVER: connection from {} rejected".format(clientAddr[0]))
                        continue

                    # Same client reconnected: retire the previous handler first
                    if thrd is not None:
                        stopHandling.set()
                        thrd.join()
                        thrd = None

                    SerialGw.activeIP=clientAddr[0]
                    # Fresh event per connection so a retired handler's
                    # stop signal never leaks into the new one.
                    stopHandling = threading.Event()

                    logMsg = "SERIAL GW: client connected from {}".format(clientAddr[0])
                    self.logger.info(logMsg)
                    syslog.syslog(logMsg)

                    # The handler also stops when the gateway itself is
                    # shutting down. It is a daemon thread, like the worker.
                    thrd = threading.Thread(
                        target=self.handleConnection,
                        args=(sktCon, lambda ev=stopHandling: ev.is_set() or self.stopGw))
                    thrd.daemon = True
                    thrd.start()
                except socket.timeout:
                    pass
                except Exception as ex:
                    self.logger.error("Unexpected exception in listening socket: {}".format(ex))
        finally:
            # Stop the handler before the caller closes the serial port,
            # and release the listening port.
            if thrd is not None:
                stopHandling.set()
                thrd.join()
            skt.close()

        self.logger.info("gateway worker (SERVER) exiting.")


    def runClient(self):
        self.logger.info("gateway worker started in CLIENT mode")
        while not self.stopGw:
            try:
                sktCon  = socket.create_connection( (self.peerHost,self.peerPort),30)
                self.handleConnection(sktCon)
            finally:
                time.sleep(10)

        self.logger.info("gateway worker (CLIENT) exiting")


    def start(self):
        """
        Start operation of the serial GW
        Must be called after setSerialParams and setPeer if
        mode is different from MODE_DISABLED
        """
        try:
            self.lock.acquire()
            if self.running:
                raise Exception("gateway is already running")

            self.tty            = serial.Serial()
            self.tty.port       = self.serialPortName
            self.tty.baudrate   = self.baud
            if self.dataBits==7:
                self.tty.bytesize   = serial.SEVENBITS
            elif self.dataBits==8:
                self.tty.bytesize   = serial.EIGHTBITS
            else:
                raise Exception("The impossible happened - corrupt data bits config :-(")

            if self.parity=="N" or self.parity=="NONE":
                self.tty.parity     = serial.PARITY_NONE
            elif self.parity=="E" or self.parity=="EVEN":
                self.tty.parity     = serial.PARITY_EVEN
            elif self.parity=="O" or self.parity=="ODD":
                self.tty.parity     = serial.PARITY_ODD
            else:
                raise Exception("The impossible happened - corrupt parity config :-(")

            if self.stopBits==1:
                self.tty.stopbits   = serial.STOPBITS_ONE
            elif self.stopBits==2:
                self.tty.stopbits   = serial.STOPBITS_TWO
            else:
                raise Exception("The impossible happened - corrupt stopbits config :-(")

            self.tty.xonxoff    = False
            self.tty.rtscts     = False
            self.tty.dsrdtr     = False
            self.tty.rs485_mode = None
            self.tty.timeout    = 0.2
            self.tty.open()

            if self.mode==SerialGw.NOGW:
                pass
            elif self.mode==SerialGw.SERVER:
                self.workerThread   = threading.Thread(target=self.runServer)
                self.workerThread.daemon    = True
                self.workerThread.start()
            elif self.mode==SerialGw.CLIENT:
                self.workerThread   = threading.Thread(target=self.runClient)
                self.workerThread.daemon    = True
                self.workerThread.start()
            else:
                raise Exception("The impossible happened - corrupt GW mode: {}".format(self.mode))

            self.running                = True
        finally:
            self.lock.release()


    def shutdown(self):
        """
        Shutdown operation of the serial GW and disconnect
        from the serial port
        """
        joinWorker  = False
        try:
            self.lock.acquire()
            if not self.running:
                self.logger.warn("gateway not running")
                return

            if self.workerThread is not None:
                joinWorker  = True
                self.stopGw = True

            self.running    = False
        finally:
            self.lock.release()

        if joinWorker:
            self.logger.info("waiting for worker thread to stop")
            self.workerThread.join()

        self.tty.close()



    def writeSideChannel(self,data:str) -> int :
        """
        Write data to the sidechannel of the serial port
        data - string to write into the serial port, regardless of
               the operation mode.
        """
        try:
            self.lock.acquire()
            if self.running:
                self.tty.write(data.encode("utf-8"))
        finally:
            self.lock.release()

