import time
import paramiko

##########################################################################
###                      texpect class section                         ###
##########################################################################
class texpect:
    def __init__(self, channel, expect: bytes | str, stimulate: bytes | str=b"", response: bytes | str=b"", timeout: float=60.0, complications: list[str | bytes]=[], verbose: bool=False) -> int:
        self.timeout = timeout
        self.start_time = 0.0
        self.end_time = 0.0
        self.expect = expect
        self.stimulate = stimulate
        self.response = response
        self.buffer = b""
        self.channel = channel
        self.verbose = verbose
        self.complications = complications
    def getTimeElapsed(self) -> float:
        return self.end_time - self.start_time
    def run(self) -> int:
        self.start_time = self.end_time = time.time()
        self.buffer = b""
        if self.verbose: print(f"Looking for match ({self.expect})")
        for x in (self.expect, self.stimulate):
            if type(x) is str:
                x = x.encode()
        if len(self.stimulate):
            self.channel.send(self.stimulate)
            if self.verbose: print(f"Sent stimulus ({self.stimulate})")
        while(self.timeout >= self.end_time - self.start_time):
            if self.channel.recv_ready():
                self.buffer += self.channel.recv(65535)
                self.end_time = time.time()
                if self.expect in self.buffer:
                    if self.verbose: print(f"Found match ({self.expect})")
                    if len(self.response):
                        self.channel.send(self.response)
                        if self.verbose: print(f"Sent response ({self.response})")
                    return 0
                for y in self.complications:
                    if y in self.buffer:
                        if self.verbose: print(f"Caught complication ({y})")
                        raise CaughtComplication(f"Caught complication ({y})")
                        return -1
        raise TimeoutError("Timed out waiting for match")
    
    
class TimeoutError(Exception):
    """Raised by ``texpect.run`` when the expected pattern does not arrive in time.

    NOTE(review): this name shadows the builtin ``TimeoutError`` for the whole
    module, so ``except TimeoutError`` in this module catches this class, not
    the builtin (which derives from ``OSError``).
    """


class CaughtComplication(Exception):
    """Raised by ``texpect.run`` when a ``complications`` pattern shows up in
    the channel output before the expected pattern does."""

##########################################################################
###                      DUT class                                     ###
##########################################################################
class DUT:
    """Device under test reached through an SSH console server.

    Stores console/device credentials and result dictionaries, and manages at
    most one paramiko SSH client plus one interactive shell channel at a time.
    """
    def __init__(self, console_address: str, console_port: int, console_user: str, console_pass: str, root_pass: str, boot_device: str, verbose: bool=False):
        self.boot_device = boot_device
        self.verbose = verbose
        self.root_pass = root_pass
        self.console_pass = console_pass
        self.console_user = console_user
        self.console_address = console_address
        self.console_port = console_port
        self.ssh_client = None
        self.ssh_channel = None
        self.info = {} # Dict of SEEPROM platform contents
        self.setupresults = {} # Dict of utility results
        self.testresults = {} # Dict of raw test results

    def openConnection(self) -> int:
        """Connect to the console server and open an interactive shell.

        On success ``self.ssh_client`` and ``self.ssh_channel`` are set. On any
        failure both stay ``None`` and the client is closed, so a later call
        can retry.

        Returns:
            0 on success.

        Raises:
            ConnectionError: a connection is already open.
            SystemExit: (code -1) no valid connection could be made, or the
                shell could not be invoked.
            Other exceptions from ``SSHClient.connect`` (e.g. authentication
            failures) propagate unchanged after the client is closed.
        """
        if self.ssh_client:
            raise ConnectionError("Connection already exists!")
        client = paramiko.SSHClient()
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        if self.verbose: print(f"Attempting connections to {self.console_address}:{self.console_port}")
        try:
            client.connect(hostname=self.console_address, port=self.console_port, username=self.console_user, password=self.console_pass)
        except paramiko.ssh_exception.NoValidConnectionsError:
            if self.verbose: print("Connection failed: No Valid Connections")
            client.close()
            # Same effect as the interactive-only quit(-1), without relying on
            # the `site` module being loaded.
            raise SystemExit(-1)
        except Exception:
            # Release the socket, then let the caller see the real error.
            client.close()
            raise
        if self.verbose: print("SSH connection successful!")
        try:
            channel = client.invoke_shell()
        except (paramiko.ssh_exception.SSHException, ConnectionResetError):
            print("SSH shell invocation failed!")
            client.close()
            raise SystemExit(-1)
        # Publish the handles only after both steps succeed, so a failed
        # attempt never leaves a stale client that blocks the next call.
        self.ssh_client = client
        self.ssh_channel = channel
        return 0

    def closeConnection(self) -> None:
        """Close the shell channel and SSH client if open; safe to call repeatedly.

        Both attributes are reset to ``None`` so ``openConnection`` can be
        called again afterwards.
        """
        if self.ssh_channel is not None:
            self.ssh_channel.close()
            self.ssh_channel = None
        if self.ssh_client is not None:
            self.ssh_client.close()
            self.ssh_client = None