#!/usr/bin/python

from sys import argv, stdout, stdin, exit
from process import Process
from select import select
from time import time, sleep

import readline
import re
import os
import sys
import atexit

# Persist readline history in ~/.mhssh across sessions.  expanduser() falls
# back to the password database when $HOME is unset, where indexing
# os.environ["HOME"] directly would raise KeyError at startup.
histfile = os.path.expanduser(os.path.join("~", ".mhssh"))

try:
    readline.read_history_file(histfile)
except IOError:
    # Missing or unreadable history file: just start with an empty history.
    pass
atexit.register(readline.write_history_file, histfile)


def read_input():
	"""Return one line from stdin if it can be read right now, else None.

	select() is called with a zero timeout, so this never waits for the
	user; the main loop can keep polling the remote hosts.
	"""
	readable, _, _ = select([stdin], [], [], 0)
	if not readable:
		return None
	return stdin.readline()

def pad(s):
	"""Left-align *s* with trailing spaces to the width ``max_addr``.

	``max_addr`` is the module-level label-column width.  Strings that are
	already at least that long are returned unchanged.
	"""
	return s.ljust(max_addr)

def mhssh_write(s):
	"""Write a message from mhssh itself to stdout, highlighted in bold.

	The text is wrapped in the ANSI escapes ESC[01m (bold) and ESC[00m
	(reset).  stdout is flushed immediately so the message is shown at once.
	"""
	highlighted = "\033[01m%s\033[00m" % s
	stdout.write(highlighted)
	stdout.flush()

class Host:
	"""One remote shell session plus its most recent not-yet-printed line.

	Attributes:
	  proc      -- the Process running ssh for this host
	  addr      -- the ssh destination this session was opened to
	  last_line -- last line read from the host, "" once it has been printed
	"""

	def __init__(self, addr, max_addr=0):
		"""Start ssh to *addr* (running "bash -i") and widen the label column.

		:param addr: ssh destination for this host.
		:param max_addr: current label-column width.  It is kept for backward
		    compatibility; the module-level column width is widened to fit
		    *addr* regardless of what is passed here.
		"""
		self.proc = Process("ssh", [addr, "bash -i"])
		self.addr = addr
		self.proc.start()
		self.last_line = ""
		# pad() reads the module-level ``max_addr``.  Assigning to the
		# parameter of the same name only rebinds a local, so the column
		# never grew; update the module global explicitly instead.  The +1
		# leaves one space after the longest address.
		width = max(max_addr, len(addr) + 1)
		module_globals = globals()
		if width > module_globals.get("max_addr", 0):
			module_globals["max_addr"] = width

	def readline(self):
		"""Read one line from this host: default stream first, then stderr.

		Each stream is polled with a 0.1 s timeout.  The first line obtained
		is stored in ``last_line`` and returned.  If neither stream produced
		a line, return None and leave ``last_line`` untouched.
		"""
		# Use this host's own process; the old code read the global loop
		# variable ``host``, which only worked by accident.
		line = self.proc.readline(timeout=0.1)
		if not line:
			line = self.proc.readline(timeout=0.1, stream="stderr")
		if not line:
			return None
		self.last_line = line
		return line

def mhssh_command(c):
	"""Execute one mhssh control command (the text after a leading "%").

	Commands:
	  (empty)        send subsequent input to all hosts
	  list, l        list connected hosts with their index
	  add <h>, a <h> connect to another host <h>
	  <index>        send subsequent input only to host number <index>
	  <address>      send subsequent input only to the host with that address

	:param c: the command text, possibly with surrounding whitespace/newline.
	"""
	# The main loop reads cmd_target to decide where typed input goes.
	# Without this declaration the assignments below would only bind a local,
	# and selecting a target would silently have no effect.
	global cmd_target
	c = c.strip()
	words = c.split()
	if c == "":
		cmd_target = "(all)"
		mhssh_write("Now sending command to all hosts\n")
	elif c == "list" or c == "l":
		mhssh_write("Connected host:\n")
		for i, host in enumerate(p):
			stdout.write("%d : %s\n" % (i, host.addr))
		stdout.flush()
	elif words[0] in ("add", "a"):
		# Match the command word exactly.  A prefix test such as c[0] == "a"
		# would swallow host names like "alpha" and crash with IndexError
		# when the argument is missing.
		if len(words) < 2:
			mhssh_write("Usage: add <host>\n")
		else:
			p.append(Host(words[1]))
	elif c.isdigit() and int(c) < len(p):
		cmd_target = p[int(c)].addr
		mhssh_write("Now sending command to %s\n" % p[int(c)].addr)
	else:
		for host in p:
			if host.addr == c:
				cmd_target = c
				mhssh_write("Now sending command to %s\n" % c)
				return
		mhssh_write("Unknown host %s\n" % c)
	

# Module-level state shared by the functions above and the main loop below.
max_addr = 0  # width of the host-label column that pad() aligns to

# One Host per address given on the command line.
p = [Host(addr, max_addr) for addr in argv[1:]]

last_time = time()    # when output last came from a different host
last_host = ""        # address of the host that produced that output
cmd_target = "(all)"  # "(all)" or the address that receives typed input

while 1:
	sleep(0.01) # Make your cpu happy

	# Poll each host for at most one buffered line.  Iterate over a copy:
	# disconnected hosts are removed from ``p`` here, and removing from the
	# list being iterated would skip the following host.
	for host in p[:]:
		if not host.proc.isRunning():
			mhssh_write("Host %s disconnected.\n" % host.addr)
			p.remove(host)
			if len(p) == 0:
				exit(0)
			continue
		if host.last_line == "":
			if host.readline():
				# Restart the flush timer when output switches hosts.
				if host.addr != last_host:
					last_time = time()
					last_host = host.addr

	# write_all: at least two hosts hold different lines, so output has
	#            diverged and each line is printed under its own host label.
	# write_com: every host holds the same non-empty line, so it is printed
	#            once under "(all)".  Never true with no hosts (e.g. started
	#            without arguments), which previously crashed on p[0].
	pending = [host.last_line for host in p if host.last_line != ""]
	write_all = len(set(pending)) > 1
	write_com = len(p) > 0 and len(pending) == len(p) and not write_all

	if write_com:
		mhssh_write("%s" % pad("(all)"))
		stdout.write(" |%s" % p[0].last_line)
		stdout.flush()
		for host in p:
			host.last_line = ""

	# Flush per host when outputs diverged, or once 0.5 s have passed since
	# output last came from a different host.
	if write_all or last_time + 0.5 < time():
		for host in p:
			if host.last_line != "":
				mhssh_write("%s" % pad(host.addr))
				stdout.write(" |%s" % host.last_line)
				stdout.flush()
				host.last_line = ""

	# Input starting with "%" is an mhssh command; anything else is
	# forwarded to the selected host (or every host for "(all)").
	cmd = read_input()
	if cmd:
		if cmd[0] == "%":
			mhssh_command(cmd[1:])
		else:
			for host in p:
				if cmd_target == "(all)" or host.addr == cmd_target:
					host.proc.writeline(cmd)



