#!/usr/bin/env python
#
# Copyright (C) 2012-2013 Fanout, Inc.
#
# This file is part of Pushpin.
#
# Pushpin is free software: you can redistribute it and/or modify it under
# the terms of the GNU Affero General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option)
# any later version.
#
# Pushpin is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for
# more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import sys
import os
import ConfigParser

# Defaults, overridable from the command line:
#   --config=PATH   configuration file to read
#   --logfile=PATH  write the log to PATH instead of stdout
#   --verbose       log at DEBUG level instead of INFO
config_file = "/etc/pushpin/pushpin.conf"
log_file = None
verbose = False
for arg in sys.argv:
	if arg.startswith("--config="):
		# 9 == len("--config=")
		config_file = arg[9:]
	elif arg.startswith("--logfile="):
		# 10 == len("--logfile=")
		log_file = arg[10:]
	elif arg.startswith("--verbose"):
		verbose = True

config = ConfigParser.ConfigParser()
config.read([config_file])

# optional [global] libdir: directory that contains a "handler" subdirectory
libdir = None
if config.has_option("global", "libdir"):
	libdir = config.get("global", "libdir")

# this has to happen before the imports below so that modules found in
# <libdir>/handler take precedence on the module search path
if libdir:
	sys.path.insert(0, os.path.join(libdir, "handler"))

import time
import threading
import json
import logging
from logging.handlers import WatchedFileHandler
from base64 import b64decode
from setproctitle import setproctitle
import zmq
import tnetstring
import httpinterface
from validation import validate_publish, validate_http_publish, ValidationError
from conversion import ensure_utf8, convert_json_transport
from statusreasons import get_reason

# set the process title to "pushpin-handler"
setproctitle("pushpin-handler")

# reopen stdout file descriptor with write mode
# and 0 as the buffer size (unbuffered)
sys.stdout = os.fdopen(sys.stdout.fileno(), 'w', 0)

# log to the file given with --logfile, otherwise to the unbuffered stdout
logger = logging.getLogger('handler')
if log_file:
	logger_handler = WatchedFileHandler(log_file)
else:
	logger_handler = logging.StreamHandler(stream=sys.stdout)
# --verbose lowers the threshold of both the logger and its handler to DEBUG
if verbose:
	logger.setLevel(logging.DEBUG)
	logger_handler.setLevel(logging.DEBUG)
else:
	logger.setLevel(logging.INFO)
	logger_handler.setLevel(logging.INFO)
# log line layout: "<LEVEL> <YYYY-MM-DD HH:MM:SS>.<msecs> <message>"
formatter = logging.Formatter(fmt='%(levelname)s %(asctime)s.%(msecs)03d %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
logger_handler.setFormatter(formatter)
logger.addHandler(logger_handler)

# value placed in the "from" field of every reply built by reply_http and
# reply_http_chunk
instance_id = "pushpin-handler"

# comma-separated list of endpoints the reply PUB sockets connect to
m2a_out_specs = config.get("handler", "m2a_out_specs").split(",")

# enabled only when the option exists and is exactly the string "true"
share_all = config.has_option("handler", "share_all") and config.get("handler", "share_all") == "true"

# response headers that are left out of Access-Control-Expose-Headers when
# cross-origin headers are generated automatically
simple_headers = set([
	"Cache-Control",
	"Content-Language",
	"Content-Type",
	"Expires",
	"Last-Modified",
	"Pragma",
])

# ZeroMQ context shared by all worker threads
ctx = zmq.Context()

class Hold(object):
	"""State kept for one HTTP request that is being held open.

	rid is the (sender, id) pair identifying the request, mode is
	"response" or "stream", and response is the dict of response
	instructions sent if the hold times out. auto_cross_origin, origin
	and jsonp_callback control how the eventual reply is decorated.
	expire_time starts out as None; it is set by the code that creates
	response-mode holds.
	"""

	def __init__(self, rid, mode, response, auto_cross_origin, origin, jsonp_callback):
		self.rid, self.mode, self.response = rid, mode, response
		self.auto_cross_origin, self.origin = auto_cross_origin, origin
		self.jsonp_callback = jsonp_callback
		# no deadline until one is assigned
		self.expire_time = None

# guards the three tables below, which are touched by several worker threads
lock = threading.Lock()
# channel name -> {rid: Hold} for requests held in "response" mode
response_channels = {}
# channel name -> id of the latest item published to it with an id
response_lastids = {}
# channel name -> {rid: Hold} for requests held in "stream" mode
stream_channels = {}

def header_get(headers, name):
	"""Look up a header value, ignoring the case of the header name.

	headers is a dict of header name -> value. Returns the value of the
	first key that equals name case-insensitively, or None if there is no
	such key.
	"""
	wanted = name.lower()
	matches = (value for key, value in headers.iteritems() if key.lower() == wanted)
	return next(matches, None)

def header_remove(headers, name):
	"""Remove a header from a dict in place, ignoring the case of the name.

	headers is a dict of header name -> value and is modified directly.
	name may be given in any case, as with header_get. Every key that
	equals name case-insensitively is removed, so no differently-cased
	duplicate of the header is left behind. Nothing happens if the header
	is absent.
	"""
	lname = name.lower()
	# iterate over a snapshot of the keys so entries can be deleted safely
	for k in list(headers):
		if k.lower() == lname:
			del headers[k]

# Templates for a raw HTTP/1.1 response, filled in with a dict that has the
# keys "code", "status" and "body" (plus "headers" for HTTP_FORMAT).
HTTP_FORMAT = (
	"HTTP/1.1 %(code)s %(status)s\r\n"
	"%(headers)s\r\n"
	"\r\n"
	"%(body)s"
)
# same, for a response without any header lines
HTTP_FORMAT_NOHEADERS = (
	"HTTP/1.1 %(code)s %(status)s\r\n"
	"\r\n"
	"%(body)s"
)

def http_response(body, code, status, headers):
	"""Build a raw HTTP/1.1 response string with a Content-Length header.

	The headers dict is modified in place: a "content-length" entry is
	removed via header_remove and "Content-Length" is set to len(body).
	Returns the status line, the header lines and the body joined according
	to HTTP_FORMAT.
	"""
	header_remove(headers, "content-length")
	headers["Content-Length"] = len(body)
	header_lines = ["%s: %s" % pair for pair in headers.items()]

	return HTTP_FORMAT % {
		"code": code,
		"status": status,
		"body": body,
		"headers": "\r\n".join(header_lines),
	}

def http_response_nolen(body, code, status, headers):
	"""Build a raw HTTP/1.1 response string without a Content-Length header.

	The headers dict is modified in place: a "content-length" entry is
	removed via header_remove. If no headers remain, the header-less
	template is used.
	"""
	header_remove(headers, "content-length")
	fields = {"code": code, "status": status, "body": body}

	if not headers:
		return HTTP_FORMAT_NOHEADERS % fields

	fields["headers"] = "\r\n".join(["%s: %s" % pair for pair in headers.items()])
	return HTTP_FORMAT % fields

def reply_http_old(sock, rid, code, status, headers, body, nolen=False):
	"""Send a complete raw HTTP response for request rid over sock.

	rid is a (sender, id) pair. The message is
	"<sender> <len(id)>:<id>, " followed by the raw HTTP response text.
	Unicode status, header names/values and body are encoded as UTF-8
	first. With nolen=True the response is built without a Content-Length
	header.
	"""
	def to_utf8(value):
		# leave anything that is not a unicode object untouched
		if isinstance(value, unicode):
			return value.encode("utf-8")
		return value

	sender, conn_id = rid[0], rid[1]
	prefix = "%s %d:%s," % (sender, len(conn_id), conn_id)

	status = to_utf8(status)
	# work on an encoded copy so the caller's dict is not modified
	encoded = dict((to_utf8(k), to_utf8(v)) for k, v in headers.iteritems())
	body = to_utf8(body)

	build = http_response_nolen if nolen else http_response
	m_raw = prefix + " " + build(body, code, status, encoded)
	logger.debug("OUT publish: %s" % m_raw)
	sock.send(m_raw)

def reply_http_chunk_old(sock, rid, content):
	"""Send a piece of raw body content for request rid over sock.

	rid is a (sender, id) pair. The message is
	"<sender> <len(id)>:<id>, " followed by content.
	"""
	sender, conn_id = rid[0], rid[1]
	prefix = "%s %d:%s," % (sender, len(conn_id), conn_id)
	sock.send(" ".join((prefix, content)))

def reply_http(sock, rid, code, reason, headers, body, nolen=False):
	"""Send a tnetstring-encoded HTTP response for request rid over sock.

	rid is a (sender, id) pair. Unicode reason, header names/values and
	body are encoded as UTF-8. Any existing content-length header is
	dropped, and unless nolen is set, Content-Length is set to the body
	length. With nolen=True the message is flagged "more", i.e. further
	body data may follow. The caller's headers dict is not modified.
	"""
	def to_utf8(value):
		# leave anything that is not a unicode object untouched
		if isinstance(value, unicode):
			return value.encode("utf-8")
		return value

	reason = to_utf8(reason)
	headers = dict((to_utf8(k), to_utf8(v)) for k, v in headers.iteritems())
	body = to_utf8(body)

	header_remove(headers, "content-length")
	if not nolen:
		headers["Content-Length"] = str(len(body))

	out = {
		"from": instance_id,
		"id": rid[1],
		"code": code,
		"reason": reason,
		# headers go out as a list of [name, value] pairs
		"headers": [[k, v] for k, v in headers.iteritems()],
	}
	# an empty body is left out of the message
	if body:
		out["body"] = body
	if nolen:
		out["more"] = True

	m_raw = rid[0] + " T" + tnetstring.dumps(out)
	logger.debug("OUT publish: %s" % m_raw)
	sock.send(m_raw)

def reply_http_chunk(sock, rid, content):
	"""Send a tnetstring-encoded piece of body data for request rid over sock.

	rid is a (sender, id) pair. The message is always flagged "more".
	"""
	out = {
		"from": instance_id,
		"id": rid[1],
		"body": content,
		"more": True,
	}

	m_raw = rid[0] + " T" + tnetstring.dumps(out)
	logger.debug("OUT publish: %s" % m_raw)
	sock.send(m_raw)

def inspect_worker():
	"""Answer inspect requests arriving on the proxy_inspect_spec endpoint.

	Every request is answered with "no-proxy": False, i.e. always proxy.
	When share_all is enabled the reply also carries a sharing key made of
	the request method and URI. Runs forever.
	"""
	sock = ctx.socket(zmq.REP)
	sock.connect(config.get("handler", "proxy_inspect_spec"))

	while True:
		req = tnetstring.loads(sock.recv())
		logger.debug("IN inspect: %s" % req)

		req_id, method, uri = req["id"], req["method"], req["uri"]

		# reply saying to always proxy
		resp = {"id": req_id, "no-proxy": False}
		if share_all:
			resp["sharing-key"] = method + '|' + uri

		logger.debug("OUT inspect: %s" % resp)
		sock.send(tnetstring.dumps(resp))

	# not reached: the loop above has no exit
	sock.close()

def accept_worker():
	"""Take over requests handed off by the proxy and register holds for them.

	Each message read from proxy_accept_in_spec carries the original
	request data ("request-data"), the origin server's response
	("response", whose body is a JSON document with the hold instructions)
	and the list of requests to hold ("requests"). Messages whose
	instructions cannot be parsed are logged and dropped.

	In "response" mode each request is stored in response_channels with an
	expiry time ("timeout" from the instructions, default 55 seconds). If
	the instructions name a prev-id that differs from the last id published
	on the channel, the whole batch is sent to proxy_retry_out_spec instead
	and none of its requests is held.

	In "stream" mode the initial response is sent right away, flagged as
	having more data to come, and the request is stored in stream_channels.

	Runs forever.
	"""
	sock = ctx.socket(zmq.PULL)
	sock.connect(config.get("handler", "proxy_accept_in_spec"))

	out_sock = ctx.socket(zmq.PUB)
	for spec in m2a_out_specs:
		out_sock.connect(spec)

	retry_sock = ctx.socket(zmq.PUSH)
	retry_sock.connect(config.get("handler", "proxy_retry_out_spec"))

	while True:
		m_raw = sock.recv()
		m = tnetstring.loads(m_raw)
		logger.debug("IN accept: %s" % m)

		# request headers arrive as a list of [name, value] pairs
		lheaders = m["request-data"]["headers"]
		headers = dict()
		for l in lheaders:
			headers[l[0]] = l[1]
		origin = header_get(headers, "Origin")
		if not origin:
			origin = "*"

		try:
			instruct = json.loads(m["response"]["body"])
			hold = instruct["hold"]
			mode = hold.get("mode")
			if mode is None:
				mode = "response"
			if mode != "response" and mode != "stream":
				raise ValueError("bad mode")
			# only the first listed channel is used
			channel = hold["channels"][0]["name"]
			prev_id = hold["channels"][0].get("prev-id")
			if "channel-prefix" in m:
				channel = m["channel-prefix"] + channel
			if "timeout" in hold:
				timeout = int(hold["timeout"])
			else:
				timeout = 55
			response = instruct.get("response")
			if response is None:
				response = dict()
				response["body"] = ""
			# body-bin is base64 and takes precedence over body
			if "body-bin" in response:
				response["body"] = b64decode(response["body-bin"])
				del response["body-bin"]
			else:
				response["body"] = response["body"].encode("utf-8")
		except Exception:
			logger.debug("failed to parse accept instructions")
			continue

		reqs = m["requests"]
		logger.debug("accepting %d requests" % len(reqs))

		for req in reqs:
			rid = (req["rid"]["sender"], req["rid"]["id"])

			h = Hold(rid, mode, response, req.get("auto-cross-origin"), origin, req.get("jsonp-callback"))

			if mode == "response":
				logger.debug("adding response hold on %s" % channel)
				h.expire_time = int(time.time()) + timeout

				# set to the recorded last id when it contradicts prev_id
				stale_id = None
				with lock:
					if prev_id is not None:
						last_id = response_lastids.get(channel)
						if last_id is not None and last_id != prev_id:
							stale_id = last_id
							# forget the id so the retried request is not rejected again
							del response_lastids[channel]
					if stale_id is None:
						hchannel = response_channels.get(channel)
						if not hchannel:
							hchannel = dict()
							response_channels[channel] = hchannel
						hchannel[rid] = h

				if stale_id is not None:
					logger.debug("lastid inconsistency (got=%s, expected=%s), retrying" % (prev_id, stale_id))
					r = dict()
					r["requests"] = m["requests"]
					r["request-data"] = m["request-data"]
					if "inspect" in m:
						r["inspect"] = m["inspect"]
					logger.debug("OUT retry: %s" % r)
					r_raw = tnetstring.dumps(r)
					retry_sock.send(r_raw)
					# the retry message covers every request of this batch, so
					# the remaining ones must not be held here as well
					break
			else: # stream
				# initial reply
				if "code" in response:
					rcode = response["code"]
				else:
					rcode = 200

				if "reason" in response:
					rreason = response["reason"]
				else:
					rreason = get_reason(rcode)

				if "headers" in response:
					rheaders = response["headers"]
				else:
					rheaders = dict()

				reply_http(out_sock, rid, rcode, rreason, rheaders, response.get("body"), True)

				logger.debug("adding stream hold on %s" % channel)

				# bind channel
				with lock:
					hchannel = stream_channels.get(channel)
					if not hchannel:
						hchannel = dict()
						stream_channels[channel] = hchannel
					hchannel[rid] = h

	# not reached: the loop above has no exit
	sock.close()

def push_in_zmq_worker():
	"""Receive publish messages over ZeroMQ and pass valid ones on.

	Messages are read from the push_in_spec endpoint, decoded as
	tnetstrings and checked with validate_publish. Valid messages are
	forwarded to the inproc://push_in endpoint. Anything that fails to
	decode or validate is logged and dropped. Runs forever.
	"""
	in_sock = ctx.socket(zmq.PULL)
	in_sock.bind(config.get("handler", "push_in_spec"))

	out_sock = ctx.socket(zmq.PUSH)
	out_sock.connect("inproc://push_in")

	while True:
		m_raw = in_sock.recv()
		try:
			try:
				m = tnetstring.loads(m_raw)
			except Exception:
				raise ValidationError("bad format (not a tnetstring)")

			m = validate_publish(m)

		except ValidationError as e:
			logger.debug("warning: %s, dropping" % e)
			# really drop it: without this the invalid message (or the
			# previous one, when decoding failed) would be forwarded
			continue

		out_sock.send(tnetstring.dumps(m))

	# not reached: the loop above has no exit
	out_sock.linger = 0
	out_sock.close()

def push_in_http_handler(context, m):
	"""Handle a publish request received over the HTTP interface.

	context is the dict given to httpinterface.run; its "out_sock" entry is
	the socket that items are forwarded on. m is the request payload, which
	is checked with validate_http_publish. Each entry of m["items"] is
	converted and sent as a separate tnetstring message.

	Returns None for success or an error string if validation fails.
	"""
	out_sock = context["out_sock"]

	try:
		m = validate_http_publish(m)
	except ValidationError as e:
		return e.message

	for item in m["items"]:
		out = dict()

		# plain string fields are copied only when present
		for field in ("channel", "id", "prev-id"):
			value = item.get(field)
			if value is not None:
				out[field] = ensure_utf8(value)

		for transport in ("http-response", "http-stream"):
			if transport in item:
				out[transport] = convert_json_transport(item[transport])

		out_sock.send(tnetstring.dumps(out))

def push_in_http_worker():
	"""Run the HTTP publish interface.

	Connects a PUSH socket to inproc://push_in and runs httpinterface on
	the configured push_in_http_addr / push_in_http_port with
	push_in_http_handler as the request handler. The socket is closed once
	httpinterface.run returns.
	"""
	out_sock = ctx.socket(zmq.PUSH)
	out_sock.connect("inproc://push_in")

	addr = config.get("handler", "push_in_http_addr")
	port = int(config.get("handler", "push_in_http_port"))
	httpinterface.run(addr, port, push_in_http_handler, {"out_sock": out_sock})

	# discard anything still queued instead of waiting for delivery
	out_sock.linger = 0
	out_sock.close()

def push_in_worker(c):
	"""Deliver published items to the requests held on their channel.

	Binds inproc://push_in and then notifies the condition c so the caller
	knows the endpoint exists. For each item received:

	- "http-response": every response-mode hold on the channel is removed
	  and answered with the published response (wrapped in a JSONP call if
	  the hold has a callback; with cross-origin headers added if the hold
	  asks for them). The item id, if any, is recorded as the channel's
	  last id.
	- "http-stream": the content is sent as a chunk to every stream-mode
	  hold on the channel; the holds stay registered.

	Runs forever.
	"""
	in_sock = ctx.socket(zmq.PULL)
	in_sock.bind("inproc://push_in")
	# signal that the bind is done
	with c:
		c.notify()

	out_sock = ctx.socket(zmq.PUB)
	for spec in m2a_out_specs:
		out_sock.connect(spec)

	while True:
		m = tnetstring.loads(in_sock.recv())
		logger.debug("IN publish: %s" % m)
		channel = m["channel"]

		if "http-response" in m:
			with lock:
				hchannel = response_channels.get(channel)
				if hchannel:
					# take the holds out of the table; each is answered once
					holds = hchannel.values()
					del response_channels[channel]
				else:
					holds = list()
				item_id = m.get("id")
				if item_id is not None:
					response_lastids[channel] = item_id
			logger.debug("relaying to %d subscribers" % len(holds))

			published = m["http-response"]
			pcode = published.get("code", 200)
			preason = published["reason"] if "reason" in published else get_reason(pcode)
			pheaders = published.get("headers", dict())
			pbody = published.get("body", "")

			for h in holds:
				if h.jsonp_callback:
					# the real response travels inside the JSON argument of the
					# callback; the outer response is always 200 OK
					inner_headers = dict(pheaders.iteritems()) if pheaders else dict()
					inner_headers["Content-Length"] = str(len(pbody))
					result = {
						"code": pcode,
						"reason": preason,
						"headers": inner_headers,
						"body": pbody,
					}

					body = h.jsonp_callback + "(" + json.dumps(result) + ");\n"
					headers = {
						"Content-Type": "application/javascript",
						"Content-Length": str(len(body)),
					}
					reply_http(out_sock, h.rid, 200, "OK", headers, body)
					continue

				headers = dict(pheaders.iteritems()) if pheaders else dict()
				if h.auto_cross_origin:
					requested = headers.get("Access-Control-Request-Headers")
					if requested:
						ace_headers = [part.strip() for part in requested.split(",") if part.strip()]
					else:
						ace_headers = []

					# expose everything that is neither a simple header nor an
					# Access-Control-* header
					expose_headers = [
						key for key in headers.keys()
						if key not in simple_headers and not key.startswith("Access-Control-")
					]

					headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE"
					if ace_headers:
						headers["Access-Control-Allow-Headers"] = ", ".join(ace_headers)
					if expose_headers:
						headers["Access-Control-Expose-Headers"] = ", ".join(expose_headers)
					headers["Access-Control-Allow-Credentials"] = "true"
					headers["Access-Control-Allow-Origin"] = h.origin

				reply_http(out_sock, h.rid, pcode, preason, headers, pbody)

		if "http-stream" in m:
			with lock:
				hchannel = stream_channels.get(channel)
				holds = hchannel.values() if hchannel else list()
			logger.debug("relaying to %d subscribers" % len(holds))
			for h in holds:
				content = m["http-stream"]["content"]
				# empty content is not sent
				if content:
					reply_http_chunk(out_sock, h.rid, content)

	# not reached: the loop above has no exit
	in_sock.close()

def timeout_worker():
	"""Answer response-mode holds whose expiry time has passed.

	Once a second, every hold in response_channels with an expire_time at
	or before the current time is removed and answered with the response
	stored in the hold (wrapped in a JSONP call if the hold has a callback;
	with cross-origin headers added if the hold asks for them). Channels
	left without holds are removed from the table. Runs forever.
	"""
	out_sock = ctx.socket(zmq.PUB)
	for spec in m2a_out_specs:
		out_sock.connect(spec)

	while True:
		now = int(time.time())
		holds = list()
		with lock:
			for channel, hchannels in response_channels.iteritems():
				expired = [h for h in hchannels.values() if h.expire_time and now >= h.expire_time]
				for h in expired:
					del hchannels[h.rid]
				holds.extend(expired)
			# collect the empty channels first, then delete them
			emptied = [name for name, hchannels in response_channels.items() if len(hchannels) == 0]
			for name in emptied:
				del response_channels[name]

		if holds:
			logger.debug("timing out %d subscribers" % len(holds))

		for h in holds:
			stored = h.response
			pcode = stored.get("code", 200)
			preason = stored["reason"] if "reason" in stored else get_reason(pcode)
			pheaders = stored.get("headers", dict())
			pbody = stored.get("body", "")

			if h.jsonp_callback:
				# the real response travels inside the JSON argument of the
				# callback; the outer response is always 200 OK
				inner_headers = dict(pheaders.iteritems()) if pheaders else dict()
				inner_headers["Content-Length"] = str(len(pbody))
				result = {
					"code": pcode,
					"reason": preason,
					"headers": inner_headers,
					"body": pbody,
				}

				body = h.jsonp_callback + "(" + json.dumps(result) + ");\n"
				headers = {
					"Content-Type": "application/javascript",
					"Content-Length": str(len(body)),
				}
				reply_http(out_sock, h.rid, 200, "OK", headers, body)
				continue

			headers = dict(pheaders.iteritems()) if pheaders else dict()
			if h.auto_cross_origin:
				requested = headers.get("Access-Control-Request-Headers")
				if requested:
					ace_headers = [part.strip() for part in requested.split(",") if part.strip()]
				else:
					ace_headers = []

				# expose everything that is neither a simple header nor an
				# Access-Control-* header
				expose_headers = [
					key for key in headers.keys()
					if key not in simple_headers and not key.startswith("Access-Control-")
				]

				headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE"
				if ace_headers:
					headers["Access-Control-Allow-Headers"] = ", ".join(ace_headers)
				if expose_headers:
					headers["Access-Control-Expose-Headers"] = ", ".join(expose_headers)
				headers["Access-Control-Allow-Credentials"] = "true"
				headers["Access-Control-Allow-Origin"] = h.origin

			reply_http(out_sock, h.rid, pcode, preason, headers, pbody)

		time.sleep(1)

# start the workers that answer inspect requests and accept held requests
inspect_thread = threading.Thread(target=inspect_worker)
inspect_thread.start()

accept_thread = threading.Thread(target=accept_worker)
accept_thread.start()

# we use a condition here to ensure the inproc bind succeeds before progressing
# (the condition is acquired before the thread is started, so the worker
# cannot notify until wait() below has released it)
c = threading.Condition()
c.acquire()
push_in_thread = threading.Thread(target=push_in_worker, args=(c,))
push_in_thread.start()
c.wait()
c.release()

# both publish inputs connect to the inproc endpoint bound by push_in_worker
push_in_zmq_thread = threading.Thread(target=push_in_zmq_worker)
push_in_zmq_thread.start()

push_in_http_thread = threading.Thread(target=push_in_http_worker)
push_in_http_thread.start()

# the timeout worker is the only thread marked as a daemon
timeout_thread = threading.Thread(target=timeout_worker)
timeout_thread.daemon = True
timeout_thread.start()

# idle in the main thread until interrupted (Ctrl-C)
try:
	while True:
		time.sleep(60)
except KeyboardInterrupt:
	pass

# shut down: stop the HTTP interface, then terminate the ZeroMQ context
httpinterface.stop()
ctx.term()
