#!/usr/bin/env python

# This file is part of Window-Switch.
# Copyright (c) 2009-2013 Antoine Martin <antoine@nagafix.co.uk>
# Window-Switch is released under the terms of the GNU GPL v3

import os
import time

from winswitch.virt.ssh_command_monitor import SSHCommandMonitor
from winswitch.objects.session import Session
from winswitch.objects.client_session import ClientSession
from winswitch.objects.server_command import ServerCommand
from winswitch.consts import SSH_TYPE, NOTIFY_ERROR
from winswitch.globals import WIN32
from winswitch.util.common import is_valid_file, generate_UUID
from winswitch.util.file_io import get_ssh_server_control_path
from winswitch.util.process_util import subprocess_terminate, LineProcessProtocolWrapper
from winswitch.util.ssh_util import get_tunnel_command
from winswitch.virt.client_util_base import ClientUtilBase
from winswitch.virt.options_common import COMPRESSION


# When True, SSH_X_Command.get_command() obtains a per-server control path from
# get_ssh_server_control_path() and passes it to get_tunnel_command() (with master="no"),
# presumably so each session's ssh process re-uses a shared connection to the server
# rather than opening a new one - TODO confirm against get_tunnel_command().
# When False, no control path is passed at all.
SHARE_SSH_TUNNEL = True

class	SSH_X_Command(SSHCommandMonitor):
	"""
	Runs one remote command through ssh with X11 forwarding enabled and keeps
	the status of the associated client session in sync with the ssh process.
	"""

	def __init__(self, ssh_command, server, session, command, name, server_timeout, notify, ask):
		"""
		ssh_command: the local ssh command, passed to get_tunnel_command()
		server: the server we connect to
		session: the client session whose status is driven by this process
		command: the command line to execute on the remote host
		name, server_timeout, notify, ask: passed on to SSHCommandMonitor
		"""
		self.ssh_command = ssh_command
		self.session = session
		self.remote_command = command
		#callers may override this after construction, it is passed as-is to get_tunnel_command()
		self.compression = None
		SSHCommandMonitor.__init__(self, server, name, server_timeout, notify, ask)

	def connectionMade(self):
		""" The ssh process is now talking to us: flag the session as connected. """
		SSHCommandMonitor.connectionMade(self)
		#not really connected yet... just means the process is talking to us, but better still than just setting CONNECTED on start
		self.session.set_status(Session.STATUS_CONNECTED)

	def exit(self):
		"""
		Exit handler: marks the session as closed and, if the ssh process failed
		(non-zero return code) without having been terminated first and without
		the session already being closed, notifies the user with some
		troubleshooting hints.
		"""
		self.slog("return code=%s" % self.returncode)
		#only show warning if the process/session was not terminated when exit() was called:
		warn = not self.terminated and self.session.status!=Session.STATUS_CLOSED
		SSHCommandMonitor.exit(self)		#this will set terminated=True in any case
		self.session.set_status(Session.STATUS_CLOSED)
		if not warn or self.returncode==0 or not self.notify:
			return
		secs = int(time.time()-self.start_time)
		msg = "Is a firewall blocking it?\nIs the SSH server running?\nIs it allowing X11 Forwarding?"
		if secs<60:
			#a very short-lived session most likely never managed to connect properly
			msg = ("It lasted just %s seconds.\n" % secs)+msg
		self.notify("SSH Session %s failed" % self.session.name, msg,
				notification_type=NOTIFY_ERROR)

	def make_env(self):
		"""
		Environment for the ssh process: the wrapper's default environment with
		DISPLAY pointing at the session's local display, which is the display
		ssh forwards the remote X11 clients to.
		"""
		env = LineProcessProtocolWrapper.make_env(self)
		env["DISPLAY"] = self.session.local_display
		return env

	def get_command(self):
		"""
		Builds the ssh command line for the remote command: X11 forwarding on,
		interactive, master="no", and the per-server control path when
		SHARE_SSH_TUNNEL is set.
		"""
		control_path = None
		if SHARE_SSH_TUNNEL:
			control_path = get_ssh_server_control_path(self.server, True)

		cmd = get_tunnel_command(self.ssh_command, self.server, self.remote_command,
								X_forwarding=True, compression=self.compression, interactive=True,
								master="no",
								control_path=control_path)
		#log the remote command we were given: this class stores it in 'remote_command',
		#'self.command' is never set here (and may not be set yet by the base class)
		self.slog("ssh_command(%s)=%s" % (self.remote_command, cmd))
		return cmd

	def handle(self, line):
		""" Override so we can ignore anything that comes in when the session is closed """
		if self.session.status!=Session.STATUS_CLOSED:
			SSHCommandMonitor.handle(self, line)
		else:
			self.sdebug(None, line)

	def do_handle(self, line):
		""" Output lines are only logged at debug level. """
		self.sdebug(None, line)


class	SSHClientBase(ClientUtilBase):
	"""
	Client utility for SSH sessions: the remote command runs through ssh with
	X11 forwarding (see SSH_X_Command), so there is no separate display to attach to.
	"""

	def	__init__(self, update_session_status, notify_callback, ask_callback):
		ClientUtilBase.__init__(self, SSH_TYPE, update_session_status, notify_callback)
		#handed to every SSH_X_Command we start, as its 'ask' callback
		self.ask_callback = ask_callback

	def client_kill_session(self, server, client, session):
		""" Marks the session as closed and terminates all of its processes. """
		self.slog(None, server, client, session)
		session.set_status(Session.STATUS_CLOSED)
		for process in session.processes:
			subprocess_terminate(process)

	def	client_start_session(self, server, client, command, screen_size, opts):
		"""
		Starts a new SSH session running 'command' on 'server'.
		Only the first word of the command line is used. Unless the server is
		running OS X, it is wrapped in a winswitch helper script
		(winswitch_ssh_Xnest for desktop commands, winswitch_ssh_session otherwise).
		The user is notified and nothing is started when there is neither a
		valid ssh key file nor a password, or when the command line is empty.
		On win32 the actual start is handed to win32_Xming_session_start().
		"""
		#check we have enough to authenticate:
		if not is_valid_file(server.ssh_keyfile) and not server.password:
			msg = "SSH sessions require valid authorization keys to be setup"
			msg += " or a password to be entered"
			msg += ", please update the connection settings for this server."
			self.notify_error("Cannot start SSH session", msg)
			return
		#an empty or blank command line would make split()[0] raise IndexError:
		argv = (command.command or "").split()
		if not argv:
			self.notify_error("Cannot start SSH session", "The command line for '%s' is empty" % command.name)
			return

		session = ClientSession()
		session.display = os.getenv("DISPLAY", ":0")
		session.ID = generate_UUID()
		session.host = server.host
		session.port = server.port
		session.name = command.name
		session.command = command.command
		session.status = Session.STATUS_CONNECTING
		session.screen_size = screen_size					#FIXME: ignored by xnest below..
		session.options = opts
		session.session_type = SSH_TYPE
		session.start_time = time.time()
		session.actor = self.settings.uuid
		session.owner = self.settings.uuid
		session.set_default_icon_data(command.get_icon_data())
		#start SSH:
		real_command = argv[0]
		if not server.platform.startswith("darwin"):		#we can't use the session wrappers on osx
			if command.type==ServerCommand.DESKTOP:
				real_command = "winswitch_ssh_Xnest %s %s %s %s" % (self.settings.uuid, session.ID, server.xnest_command, real_command)
			else:
				#use wrapper script which will tell us when the session is finished
				real_command = "winswitch_ssh_session %s %s %s" % (self.settings.uuid, session.ID, real_command)

		def do_start(*args):
			self.sdebug(None, *args)
			self.do_client_start_session(server, client, session, command, real_command)

		if WIN32:
			#on win32 we may have to wait for Xming to start:
			self.win32_Xming_session_start(session, do_start)
		else:
			session.local_display = os.environ.get("DISPLAY")
			do_start()

	def do_client_start_session(self, server, client, session, command, real_command):
		"""
		Registers the session with the server and launches the SSH_X_Command
		process which runs 'real_command' on the remote host.
		"""
		compression = session.options.get(COMPRESSION)
		self.slog("cmd=%s, display=%s, compression=%s" % (real_command, session.local_display, compression), server, client, command)
		#kill it if it is still CONNECTING when the connect check fires (timeout set by schedule_connect_check)
		self.schedule_connect_check(session, kill_it=True)
		#cleanup via client_close_session, presumably when the status goes from CONNECTED to CLOSED
		session.add_status_update_callback(Session.STATUS_CONNECTED, Session.STATUS_CLOSED,
										lambda : self.client_close_session(server, client, session), True, None)
		process = SSH_X_Command(self.settings.ssh_command, server, session, real_command, session.name, server.timeout, self.notify, self.ask_callback)
		process.compression = compression
		session.processes.append(process)
		server.add_session(session)
		process.start()


	def client_close_session(self, server, client, session):
		""" Removes the session from the server and stops its client and local display processes. """
		self.slog(None, server, client, session)
		server.remove_session(session)
		self.kill_client_processes(session)
		if session.local_display_process:
			subprocess_terminate(session.local_display_process)


	def attach(self, server, session, host, port):
		""" SSH sessions cannot be attached to: calling this is always a bug. """
		msg = "BUG this method should never be called!"
		self.serror(msg, server, session, host, port)
		raise Exception(msg)

	def get_options_defaults(self):
		""" Default session options for SSH sessions. """
		return	{
				COMPRESSION: 5,
				}
