#!/usr/bin/python

# Copyright (C) 2007 SPARTA, Inc.
# This software is licensed under the GPLv3 license, included in
# ./GPLv3-LICENSE.txt in the source distribution

import time
import os
import signal
import re
import sys
import logging
import pprint
import socket
import seer
from distributions import *

class Agent(object):
	"""
	Base agent: typed configuration storage plus basic START/STOP handling.

	Variables are kept in self.storage under either a bare key (the group
	wide value) or "<nodename>.<key>" (the value for one specific node).
	Implementations register the keys they accept with addVarType() and may
	override configDone() and launchProgram().

	NOTE(review): agentInit() reads self.agenttype, which this class does
	not define - presumably implementations set it as a class attribute.
	"""

	def __init__(self):
		self.statedump = {}  # Original string form of every accepted variable, for dumping state quickly
		self.storage = {}  # Converted state variables
		self.vartypes = {}  # Maps each accepted variable name to its type name
		self.pids = []  # PIDs of the subprocesses we have started
		self.addVarType('NODES', 'array', [])

	def agentInit(self, group, whiteboard, mystate, pidkill):
		""" Used by main program to initialize common variables.
			We do this in a subfunction so that actual implementations
			don't have to deal with it in their init function
		"""
		self.type = self.agenttype  # pull from static, maybe just use static?
		self.group = group
		self.name = mystate.node
		self.whiteboard = whiteboard  # Shared storage location
		self.mystate = mystate  # link to MyState object
		self.log = logging.getLogger(self.type+"/"+self.group)
		self.pidkill = pidkill  # Object used by handleSTOP to signal our subprocesses

	def setVar(self, key, val):
		""" Always set local (ME.key) """
		self.storage[self.name+"."+key] = val

	def getVar(self, key):
		""" Attempt ME.key first, then default to key.  None if neither is set. """
		return self.storage.get(self.name+"."+key, self.storage.get(key))

	def getGroupVar(self, key):
		""" Just look up key in storage """
		return self.storage.get(key)

	def getNodeVar(self, key):
		""" Just look up ME.key in storage """
		return self.storage.get(self.name+"."+key)

	def getOtherNodeVar(self, node, key):
		""" Look up node.key then default to key """
		return self.storage.get(node+"."+key, self.storage.get(key))

	def myNodeMemberOf(self, arrayname):
		""" Check if my node name is in the particular list.

			The comparison is case insensitive and an entry of '*' matches
			every node.  A variable that is unset (or None) is treated as
			an empty list, so the answer is False rather than a TypeError.
		"""
		array = self.getVar(arrayname) or []
		compare = self.name.lower()

		for node in array:
			if node == '*' or node.lower() == compare:
				return True
		return False

	def myIPMemberOf(self, arrayname):
		""" Check if one of my ip addrs is in the particular list and return them.

			Returns a list (without duplicates) of the addresses from
			self.mystate.GetIPList() that also appear in the named array.
			An unset variable yields an empty list.
		"""
		wanted = set(self.getVar(arrayname) or [])
		matched = []

		for myip in self.mystate.GetIPList():
			if myip in wanted and myip not in matched:
				matched.append(myip)
		return matched

	def addVarType(self, key, type, default):
		""" Register key as an accepted variable of the given type and store its default.

			Type names understood by processArgs are 'array', 'float',
			'int', 'bool', 'proto' and 'cidr'; anything else is stored as
			the unconverted value.
		"""
		self.vartypes[key] = type
		self.storage[key] = default

	def _convertValue(self, vartype, val):
		""" Convert the string val according to vartype, raising if it can't be converted """
		if vartype == 'array':
			# Separated by whitespace and/or commas.  Leading/trailing
			# separators would otherwise yield '' entries that later get
			# treated as node names, so drop them.
			return [item for item in re.split(r'[\s,]+', val) if item]
		if vartype == 'float':
			return float(val)
		if vartype == 'int':
			return int(val)
		if vartype == 'bool':
			# Only the integer 1 counts as true
			return int(val) == 1
		if vartype == 'proto':
			# Either a protocol name (starts with a letter) or a protocol number
			if re.match('[a-zA-Z]+', val):
				return socket.getprotobyname(val)
			return int(val)
		if vartype == 'cidr':
			return seer.CIDR(inputstr=val)
		return val

	def processArgs(self, args):
		""" Insert all of the variables, verifying/converting each.

			args maps variable names to string values.  A name may carry a
			"node." prefix; only the part after the last '.' is used to
			find the expected type, while the full name is the storage key.
			Unknown names and values that fail to convert are logged and
			skipped.  configDone() is called once at the end.
		"""
		for k in args:
			singlekey = k[k.rfind('.')+1:]
			if singlekey not in self.vartypes:
				self.log.error("Ignoring unknown key (%s)" % (k))
				continue

			# Find out what type of variable we expect this to be
			vartype = self.vartypes.get(singlekey)
			val = args[k]

			try:
				converted = self._convertValue(vartype, val)
			except Exception:
				self.log.error("Can't convert value (%s) to %s " % (val, vartype), exc_info=1)
				continue

			self.storage[k] = converted
			# We accepted the value, put its string in the statedump hash
			self.statedump[k] = val

		self.configDone()

	def configDone(self):
		""" Called when all variables in an event have been processed """
		pass

	def launchProgram(self):
		""" Base function that is called when a basic agent receives a START """
		pass

	def handleSTART(self):
		""" Call launchProgram() if nothing is running yet and this node is listed in NODES """
		if self.pids:
			self.log.info("Already running, not restarting")
			return

		if self.myNodeMemberOf('NODES'):
			self.launchProgram()

	def handleSTOP(self):
		""" Send SIGTERM to every recorded subprocess and forget them """
		for pid in self.pids:
			self.pidkill.kill(pid, signal.SIGTERM)
		self.pids = []

	def __repr__(self):
		return "Agent(%s, %s, %s)\n%s" % (self.type, self.group, self.name, pprint.pformat(self.storage, 6))



"""
Trafgen extends the regular agent and provides some callbacks for starting a server and clients
"""
class TrafgenAgent(Agent):
	"""
	Agent that can run a local server and a forked client traffic loop.

	On START a node listed in 'servers' gets serverExec() called, and a node
	listed in 'NODES' forks a controller process that repeatedly calls
	clientExec() with a random source and destination.  Implementations
	override serverExec, serverStop, clientInit and clientExec.

	Variables registered here:
	  servers  - array of server node names
	  think    - string eval()'d each loop to get the seconds to sleep
	  sizes    - string eval()'d each loop to get the size passed to clientExec
	  autoquit - int; when greater than 0 the controller stops after that many seconds
	"""

	def __init__(self):
		Agent.__init__(self)
		self.runningserver = 0  # Set to 1 while serverExec() has been called without a matching serverStop()
		self.addVarType('servers', 'array', [])
		self.addVarType('think', 'string', '1')
		self.addVarType('sizes', 'string', '1')
		self.addVarType('autoquit', 'int', None)

	def handleSTART(self):
		""" Start the local server and/or the traffic controller, unless already running """
		if self.pids:
			self.log.info("Already running, not restarting")
			return

		if self.myNodeMemberOf('servers'):
			self.runningserver = 1
			self.serverExec()

		if self.myNodeMemberOf('NODES'):
			self.launchTrafficController()

	def handleSTOP(self):
		""" Stop the local server if we started it and SIGTERM every controller process group """
		if self.runningserver:
			self.serverStop()
			self.runningserver = 0

		# Each controller called setsid(), so its pid is also its process
		# group id; killpg reaches the controller and whatever it spawned.
		for pid in self.pids:
			self.pidkill.killpg(pid, signal.SIGTERM)
		self.pids = []

	def serverExec(self):
		""" To be overridden by subclass, called when the local server should be started """
		pass

	def serverStop(self):
		""" To be overridden by subclass, called when the local server should be stopped """
		pass

	def clientInit(self):
		""" To be overridden by subclass, called after fork but only once before loop starts """
		pass

	def clientExec(self, src, dst, size):
		""" To be overridden by subclass, this will exec or perform the necessary trafgen process multiple times """
		pass

	def launchTrafficController(self):
		""" Fork a controller process that runs the client loop.

			The parent records the child's pid in self.pids and returns
			immediately.  The child never returns from this method: it
			runs the loop and then leaves via os._exit(), so that it can
			not fall back into the caller's code as a second copy of the
			agent after an error.
		"""
		pid = os.fork()
		if pid > 0:
			self.pids.append(pid)
			return

		# Everything below runs only in the forked child
		status = 0
		try:
			try:
				self._trafficControllerMain()
			except Exception:
				self.log.error("error in client process", exc_info=1)
				status = 1
		finally:
			# os._exit skips normal interpreter cleanup, so push out any
			# buffered output ourselves (best effort, we are exiting anyway)
			try:
				sys.stdout.flush()
				sys.stderr.flush()
			except Exception:
				pass
			os._exit(status)

	def _trafficControllerMain(self):
		""" Body of the forked controller: set up pools and output, then loop until autoquit """
		os.setsid()  # Become leader of our own session/process group
		signal.signal(signal.SIGCHLD, signal.SIG_IGN)

		# NOTE(review): the 'FAKE.<node>' whiteboard entries are presumably
		# extra addresses for a node - confirm against seer.AddressPool
		spool = seer.AddressPool(self.whiteboard.get('FAKE.'+self.name))
		dpool = seer.AddressPool()
		for s in (self.getGroupVar("servers") or []):
			dpool.Add(s, self.whiteboard.get('FAKE.'+s))

		think = self.getVar("think")
		sizes = self.getVar("sizes")
		autoquit = self.getVar("autoquit")

		logfile = self.type+"."+self.group
		starttime = time.time()

		# Redirect stdout for exec'd applications
		try:
			self.log.info("starting launcher process at %d - logging output to %s\n" % (starttime, logfile))
			# Left open on purpose: it backs stdout/stderr for the rest of this process
			fp = open("/local/logs/%s" % (logfile), 'a', 1)
			sys.stdout = fp   # For python prints from here on (not logging)
			sys.stderr = fp
			os.dup2(fp.fileno(), 1)  # For anything we exec from here on
			os.dup2(fp.fileno(), 2)
		except Exception:
			# Best effort: keep going with the inherited output
			self.log.error("Failed to redirect output", exc_info=1)

		# Call overridden init method
		self.clientInit()

		# Loop based on wait times
		while True:
			elapsed = time.time() - starttime
			# autoquit defaults to None, meaning run until we are stopped
			if autoquit is not None and autoquit > 0 and elapsed > autoquit:
				os.killpg(0, signal.SIGTERM)  # This should kill me and my children as I forked/setsid()
				return

			self.clientOneLoop(spool, dpool, think, sizes)

	def clientOneLoop(self, spool, dpool, think, sizes):
		""" This provides an internal point to override just the loop behaviour.

			Picks a random source from spool and destination from dpool,
			calls clientExec() with them and the evaluated size, then
			sleeps for the evaluated think time.

			SECURITY NOTE(review): sizes and think are eval()'d as Python
			expressions in this module's namespace, so whoever can set
			those variables can run arbitrary code in this process.
			Confirm the configuration source is trusted.
		"""
		src = spool.Random()
		dst = dpool.Random()
		size = int(eval(sizes))

		# Check memory here
		self.clientExec(src, dst, size)  # Call overridden method

		waitfor = eval(think)
		time.sleep(waitfor)

