#!/usr/bin/env python
#
# -*- Mode: Python; tab-width: 4 -*-
# vim:sw=4:expandtab:
#
# $Id: rsipclient.py,v 0.20 2003/09/08 13:17:52 jose Exp $
#
# RSIP client daemon
#
# Author: Jose Vasconcellos <jvasco@bellatlantic.net>
# Copyright (C) 2002 Jose Vasconcellos
#
# Fixes :
#         Cedric de Launois  : mtu option added
#         Michael Scherer    : shutdown when SIGQUIT received
#
# This file may be distributed under the terms of the
# GNU Public Licence version 2
#

# This daemon manages the RSIP connection and sets up the
# tunnel interface. It supports both RSA-IP and RSAP-IP
# methods of RFC3103. If RSAP-IP is used, the client assumes
# that a macro-flow policy is used, and it will request a
# block of addresses. This range is then used for port
# allocation for all ports.

# this program depends on the following external commands:
#
# - slptool : can be obtained from http://www.openslp.org
# - iptunnel
# - ifconfig
# - route

# 
import sys, signal, socket, select, os, string, getopt, struct
from syslog import *
import sched, time
import binascii

##############################
# RSIP constants

# Message types: the value of the second byte (data[1]) of every RSIP
# message exchanged with the server (RFC 3103).
ERROR_RESPONSE = 1
REGISTER_REQUEST = 2
REGISTER_RESPONSE = 3
DEREGISTER_REQUEST = 4
DEREGISTER_RESPONSE = 5
ASSIGN_REQUEST_RSA_IP = 6
ASSIGN_RESPONSE_RSA_IP = 7
ASSIGN_REQUEST_RSAP_IP = 8
ASSIGN_RESPONSE_RSAP_IP = 9
EXTEND_REQUEST = 10
EXTEND_RESPONSE = 11
FREE_REQUEST = 12
FREE_RESPONSE = 13
QUERY_REQUEST = 14
QUERY_RESPONSE = 15
LISTEN_REQUEST = 16
LISTEN_RESPONSE = 17

# Parameter types: the first byte of each parameter in a message body.
# It is followed by a 16-bit big-endian value length and that many value
# bytes (the layout walked by find_param).
Address = 1
Ports = 2
Lease_Time = 3
Client_ID = 4
Bind_ID = 5
Tunnel_Type = 6
RSIP_Method = 7
Error = 8
Flow_Policy = 9
Indicator = 10
Message_Counter = 11
Vendor_Specific_Parameter = 12
SPI = 22

# Tunnel types (values carried in a Tunnel_Type parameter)
IPIP = 1
GRE = 2

# RSIP Methods (values carried in an RSIP_Method parameter).
# RSIP_Methods maps a method value to the name used in log messages.
RSA_IP = 1
RSAP_IP = 2
RSIPSEC = 3
RSIP_Methods = ("Invalid", "RSA", "RSAP", "RSIPSEC")

# Flow policies (values in the Flow_Policy parameter; the client reads
# them from the REGISTER_RESPONSE but does not act on them)
MacroFlows = 1
MicroFlows = 2
NoPolicy = 3

##############################
# prebuilt messages
# Fixed-layout requests built as raw byte strings.  Each starts with a
# 4-byte header: a leading byte 1, the message type, and the 16-bit total
# message length.  Where present, the header is followed by a Client_ID
# parameter (type 4, length 4) and a Bind_ID parameter (type 5, length 4).
# The %s slots take the 4-byte Client ID / Bind ID strings.
#   REGISTER_REQUEST   (0x02): header only,                  length 0x04
#   DEREGISTER_REQUEST (0x04): header + Client_ID,           length 0x0b
#   EXTEND_REQUEST     (0x0a): header + Client_ID + Bind_ID, length 0x12
REGISTER_REQUEST_STR   = '\x01\x02\x00\x04'
DEREGISTER_REQUEST_STR  = '\x01\x04\x00\x0b' '\x04\x00\x04%s'
EXTEND_REQUEST_STR      = '\x01\x0a\x00\x12' '\x04\x00\x04%s' '\x05\x00\x04%s'

##############################
# configuration

class rsipdata:
    """Global configuration and runtime state of the client, kept as class
    attributes and accessed as rsipdata.<name> throughout the file."""
    # 0 = quiet (-q); each -v adds one; > 1 also logs raw message hex
    verbose = 1
    # 1 = --notunnel: external commands and the port-range file are skipped
    notunnel = 0
    # syslog facility; LOG_DAEMON (default) also means "run as a daemon",
    # -d/--nodaemon switches to LOG_USER
    facility = LOG_DAEMON
    # Tunnel_Type value requested from the server (--gre selects GRE)
    tunneltype = IPIP
    # requested RSIP method (--rsa / --rsap); may be replaced by what the
    # server offers in its REGISTER_RESPONSE
    RSIPMethod = RSAP_IP
    # iptunnel mode string, set from the server's bind response
    tunnelmode = "ipip"
    # remote tunnel endpoint IP address
    TunnelEndPoint = ""
    # Client ID from the REGISTER_RESPONSE (a 4-byte string once registered)
    Client_Id = 0
    # local IP address assigned by the server
    IPaddress = ""
    # base port of the RSAP-IP port assignment (0 = none)
    port_base = 0
    # number of ports to request (-p/--ports)
    port_num = 255
    # set after the tunnel interface was created; shutdown() then deletes it
    tunnel_created = 0
    # set after the local port range was changed; shutdown() restores it
    port_range_set = 0
    # absolute time (time.time() scale) at which the registration lease ends
    reg_lease = 0
    # scheduler event for the next extend_reg() call
    reg_evt = 0
    # tunnel MTU (-m/--mtu); create_tunnel() only sets an MTU when > 0
    mtu = 1400

# RSIP server host and port; PORT == 0 means none was given on the command
# line, in which case find_gateway() looks one up via SLP.
DEST = ''
PORT = 0

# listen_ports: packed 16-bit port numbers from -l/--listen
# Bind_Ids: bind ID -> [response message type, Ports parameter bytes,
#           lease time, scheduled extend_bind event or None]
# state: progress marker consulted by shutdown() (>= 1: socket connected,
#        >= 2: registered with the server)
listen_ports = ""
Bind_Ids = {}
state = 0

# name of the tunnel interface managed by this client
rsip_if = 'rsip'

# Linux sysctl file holding the local port range; rewritten for RSAP-IP
IP_LOCAL_PORT_RANGE = "/proc/sys/net/ipv4/ip_local_port_range"
pidfile = "/var/run/rsipclient.pid"

# syslog options: tag messages with the pid and also copy them to stderr
logopts = LOG_PID | LOG_PERROR

##############################
def makeint(str, off, cnt):
    """Decode cnt bytes of str, starting at offset off, as an unsigned
    big-endian integer.  Returns 0 when cnt <= 0."""
    value = 0
    for pos in range(off, off + cnt):
        value = (value << 8) + ord(str[pos])
    return value

def find_param(p, data, i):
    """Return the offset of the first parameter of type p at or after
    offset i in data, or -1 if there is none.

    Each parameter is a type byte, a 2-byte big-endian length and that many
    value bytes; non-matching parameters are skipped by their length."""
    pos = i
    end = len(data)
    while pos < end:
        if ord(data[pos]) == p:
            return pos
        pos = pos + 3 + makeint(data, pos + 1, 2)
    return -1 # can't find it

def getIPaddress(data, i):
    """Decode the Address parameter starting at offset i of data.

    Address type 1: the 4 value bytes as a dotted-quad string.
    Address type 4 (presumably a host name): the raw value bytes.
    Anything else: "0.0.0.0"."""
    # value length minus the leading address-type byte
    name_len = makeint(data, i+1, 2) - 1
    addr_kind = ord(data[i+3])
    if addr_kind == 1:
        return socket.inet_ntoa(data[i+4:i+8])
    elif addr_kind == 4:
        return data[i+4:i+4+name_len]
    else:
        return "0.0.0.0"

##############################

def find_bind(data):
    """Return the Bind ID carried in message data if it is one recorded in
    Bind_Ids, otherwise None."""
    pos = find_param(Bind_ID, data, 4)
    candidate = data[pos+3:pos+7]
    if candidate in Bind_Ids:
        return candidate
    return None

##############################

def finish(cause):
    """Exit the process with status `cause`.

    In daemon mode (facility LOG_DAEMON) the pid file is removed first.  A
    pid file that is already gone (or was never written) must not turn a
    controlled exit into a traceback with the wrong exit status."""
    if rsipdata.facility == LOG_DAEMON:
        try:
            os.unlink(pidfile)
        except OSError:
            # best effort: nothing to remove
            pass
    sys.exit(cause)

##############################

def find_gateway():
    'use SLP to find RSIP gateway'
    global DEST, PORT

    # use slptool

    try:
        slp = os.popen("slptool findsrvs service:rsip", "r")
        for line in slp.readlines():
            if 'service:rsip://' == line[:15]:
                i = line.index(':',15)
                DEST = line[15:i]
                j = line.index(',',i+1)
                PORT = int(line[i+1:j])
        slp.close()
    except IOError:
        syslog(LOG_ERR, 'Unable to find RSIP gateway via SLP')
        finish(2)

    if DEST == '':
        syslog(LOG_ERR, 'No service:rsip defined!')
        finish(2)

    if rsipdata.verbose:
        syslog('RSIP gateway: '+DEST+':'+`PORT`)
    
##############################

def dump(s, data):
    """Log a sent or received RSIP message at LOG_DEBUG.

    s is a label such as 'send ' or 'recv '; data is the raw message.
    verbose > 1: length, a timestamp and the whole message in hex.
    verbose == 1: length and the message-type byte (data[1]).
    verbose == 0: nothing is logged.
    """
    if rsipdata.verbose > 1:
        msg = "%s(%d): %s\n\t%s" % \
               (s, len(data), time.strftime("%H:%M:%S"), binascii.b2a_hex(data))
    elif rsipdata.verbose:
        # An empty or 1-byte read (e.g. the server closed the connection)
        # has no type byte; log -1 rather than raising IndexError.
        if len(data) > 1:
            mtype = ord(data[1])
        else:
            mtype = -1
        msg = "%s(%d): %d" % (s, len(data), mtype)
    else:
        return
    syslog(LOG_DEBUG, msg)
        
##############################

def runcmd(s, cause, cmd, msg):
    """Run the external command `cmd` (an argv list) and wait for it.

    s     -- value stored in the global `state` when the command succeeds
    cause -- exit status passed to shutdown() when the command fails;
             0 means the command is best-effort and failures are ignored
    msg   -- message logged by shutdown() on failure
    With --notunnel the command is only logged, never executed.
    """
    global state
    if rsipdata.verbose:
        syslog(LOG_INFO, ">>> "+string.join(cmd))
    if rsipdata.notunnel != 0:
        return
    try:
        ret = os.spawnvp(os.P_WAIT, cmd[0], cmd)
    except os.error:
        # os.spawnvp raises os.error when fork() or waiting for the child
        # fails.  Treat that like a failed command, so that shutdown() still
        # runs (or best-effort cleanup carries on) instead of the exception
        # escaping.
        syslog(LOG_ERR, "Unable to run %s" % cmd[0])
        ret = -1
    if cause != 0:
        if ret == 0:
            state = s
        else:
            shutdown(cause, msg)

##############################

def set_port_range(low, high):
    """Write "low high" into the kernel's local port range file.
    Does nothing when --notunnel was given."""
    if rsipdata.notunnel != 0:
        return
    out = open(IP_LOCAL_PORT_RANGE, "w")
    try:
        out.write("%d %d" % (low, high))
    finally:
        out.close()

def get_port_range():
    """Return the current local port range as a (low, high) tuple of ints.
    With --notunnel the file is not read and (32000, 48000) is returned."""
    if rsipdata.notunnel != 0:
        return (32000, 48000)
    src = open(IP_LOCAL_PORT_RANGE, "r")
    try:
        fields = string.split(src.readline())
    finally:
        src.close()
    return (int(fields[0]), int(fields[1]))

##############################

def send_assign_request(base, n):
    """Send an ASSIGN_REQUEST over the global socket and return its type.

    RSAP-IP (rsipdata.RSIPMethod == RSAP_IP): asks for n ports starting at
    `base` (base 0: no base given), on rsipdata.IPaddress if one is already
    known.  RSA-IP: only an Address parameter without a value is sent as the
    local part, and base/n are not used.
    """
    global sock
    # Address parameter holding only the address-type byte (1), no value
    bare_address = struct.pack("!BHB", Address, 1, 1)

    if rsipdata.RSIPMethod != RSAP_IP:
        mtype = ASSIGN_REQUEST_RSA_IP
        local_params = bare_address
    else:
        mtype = ASSIGN_REQUEST_RSAP_IP
        if rsipdata.IPaddress:
            local_params = struct.pack("!BHB4s", Address, 5, 1,
                                       socket.inet_aton(rsipdata.IPaddress))
        else:
            local_params = bare_address
        if base:
            local_params = local_params + struct.pack("!BHBH", Ports, 3, n, base)
        else:
            local_params = local_params + struct.pack("!BHB", Ports, 1, n)

    # header, Client_ID, the local parameters above, then an Address with
    # no value, a Ports parameter listing 0 ports and the tunnel type
    layout = "!BBH BH4s %ds BHB BHB BHB" % len(local_params)
    request = struct.pack(layout,
                          1, mtype, struct.calcsize(layout),
                          Client_ID, 4, rsipdata.Client_Id,
                          local_params,
                          Address, 1, 1,
                          Ports, 1, 0,
                          Tunnel_Type, 1, rsipdata.tunneltype)

    sock.send(request)
    dump('send ', request)
    return mtype

def send_listen_request(ports):
    """Send a LISTEN_REQUEST over the global socket.

    ports is a string of packed 16-bit port numbers (2 bytes each).  The
    local address is rsipdata.IPaddress if known; otherwise an Address
    parameter with no value is sent.
    """
    global sock
    if not rsipdata.IPaddress:
        local_addr = struct.pack("!BHB", Address, 1, 1)
    else:
        local_addr = struct.pack("!BHB4s",
                                 Address, 5, 1, socket.inet_aton(rsipdata.IPaddress))

    nbytes = len(ports)
    layout = "!BBH BH4s %ds BHB%ds BHB BHB BHB" % (len(local_addr), nbytes)
    request = struct.pack(layout,
                          1, LISTEN_REQUEST, struct.calcsize(layout),
                          Client_ID, 4, rsipdata.Client_Id,
                          local_addr,
                          Ports, nbytes + 1, nbytes // 2, ports,
                          Address, 1, 1,
                          Ports, 1, 0,
                          Tunnel_Type, 1, rsipdata.tunneltype)

    sock.send(request)
    dump('send ', request)

##############################

def process_bind_response(data):
    """Handle ASSIGN_RESPONSE_RSA_IP, ASSIGN_RESPONSE_RSAP_IP and LISTEN_RESPONSE.

    Parses the binding granted by the server and records it:
      - sets rsipdata.IPaddress, rsipdata.tunnelmode, rsipdata.TunnelEndPoint
        (falling back to the server's address, peername[0], when the reply
        carries no further Address parameter);
      - stores Bind_Ids[bind_id] = [message type, Ports parameter bytes,
        lease, None] (the last slot is filled in later with the scheduled
        extension event);
      - calls shutdown(7) when the tunnel type is neither GRE nor IPIP.

    Parameters are located with find_param, each search starting from the
    offset of the previous one, so the lookup order below mirrors the order
    of parameters in the message.

    Returns (port_base, tports): the first port (0 if the Ports parameter
    carries none) and the port count from the local Ports parameter; both
    are 0 for ASSIGN_RESPONSE_RSA_IP.
    """
    global Bind_Ids, peername

    # extract parameters
    #
    # Bind ID
    i = find_param(Bind_ID, data, 4)
    Bind_Id = data[i+3:i+7]
    
    # Local Address
    i = find_param(Address, data, i)
    rsipdata.IPaddress = getIPaddress(data, i)
    if rsipdata.verbose:
        syslog(LOG_INFO, "IP address: %s" % rsipdata.IPaddress)
    # Local Ports (absent in RSA-IP responses)
    ports = ""
    port_base = 0
    tports = 0
    if ord(data[1]) != ASSIGN_RESPONSE_RSA_IP:
        i = find_param(Ports, data, i)
        port_len = makeint(data,i+1, 2)
        # value = port count byte, optionally followed by a 16-bit base port
        tports = ord(data[i+3])
        if port_len > 1:
            port_base = makeint(data,i+4, 2)
        # keep the whole parameter (type, length, value) for Bind_Ids
        ports = data[i:i+3+port_len]

        if rsipdata.verbose:
            syslog(LOG_INFO, "Ports: %d %d" % (port_base, tports))

    if ord(data[1]) != LISTEN_RESPONSE:
        # In ASSIGN responses the Lease Time is looked up before the Tunnel
        # Type; for LISTEN_RESPONSE it is looked up after it (see below).
        i = find_param(Lease_Time, data, i)
        lease = makeint(data, i+3, 4)

    # set tunnel type
    rsipdata.tunnelmode = "ipip"
    i = find_param(Tunnel_Type, data, i)
    if ord(data[i+3]) == GRE:
        rsipdata.tunnelmode = "gre"
    else:
        if ord(data[i+3]) != IPIP:
            shutdown(7, "Unsupported tunnel mode")

    if ord(data[1]) == LISTEN_RESPONSE:
        # Lease (follows the Tunnel Type in LISTEN_RESPONSE)
        i = find_param(Lease_Time, data, i)
        lease = makeint(data, i+3, 4)

    if rsipdata.verbose:
        syslog(LOG_INFO, "Lease: %d" % lease)

    # look for tunnel endpoint: an Address parameter after the ones above
    j = find_param(Address, data, i)
    if j != -1:
        rsipdata.TunnelEndPoint = getIPaddress(data, j)
    else:
        rsipdata.TunnelEndPoint = peername[0]
    if rsipdata.verbose:
        syslog(LOG_INFO, "Tunnel Endpoint IP address: %s" % rsipdata.TunnelEndPoint)

    # save info for later; slot 3 receives the scheduled extension event
    Bind_Ids[Bind_Id] = [ord(data[1]), ports, lease, None]
    return port_base, tports

##############################

def extend_reg(*args):
    """Scheduled action for the registration lease.

    While bindings exist and an EXTEND_RESPONSE has pushed back the lease
    end (rsipdata.reg_lease is non-zero), re-arm itself one second before
    that time and clear reg_lease until the next extension is seen."""
    global evtsched
    syslog(LOG_DEBUG, "extend_reg")
    if not (Bind_Ids and rsipdata.reg_lease):
        return
    rsipdata.reg_evt = evtsched.enterabs(rsipdata.reg_lease-1, 1, extend_reg, [0])
    rsipdata.reg_lease = 0

def extend_bind(*args):
    """Scheduled action: send an EXTEND_REQUEST for a binding.
    args[0] is the socket to send on, args[1] the 4-byte Bind ID."""
    conn = args[0]
    bind_id = args[1]
    request = EXTEND_REQUEST_STR % (rsipdata.Client_Id, bind_id)
    conn.send(request)
    dump('send ', request)

##############################

def rsipwait(timeout):
    """Delay function of the event scheduler.

    Waits up to `timeout` seconds, passing each message that arrives on the
    global socket to rsipevent().  Returns early when the wait times out or
    the socket reports an exceptional condition.

    Raises socket.error when the server closes the connection.  A closed
    socket stays readable and recv() keeps returning '', so without this
    check the loop would spin at full CPU until the timeout expired.
    """
    global sock

    while timeout > 0:
        start = time.time()
        syslog(LOG_DEBUG, "poll timeout: %d" % timeout)
        readable, writable, failed = select.select([sock], [], [sock], timeout)
        timeout = timeout - (time.time() - start)
        if failed or not readable:
            return
        data = sock.recv(128)
        if not data:
            raise socket.error('connection closed by RSIP server')
        rsipevent(data)

def rsipevent(data):
    """Handle one message `data` received from the RSIP server.

    EXTEND_RESPONSE: store the new lease for the bind ID and schedule the
      next extend_bind() lease-2 seconds from now.  Push rsipdata.reg_lease
      out if the new lease ends later.
    FREE_RESPONSE: cancel the pending extension for the bind ID and forget
      the binding.  Responses for unknown bind IDs are ignored.
    Any other type raises socket.error.  DEREGISTER_RESPONSE first sets
      state to 1, so shutdown() does not deregister again.

    Returns 99 if data is too short to hold a message header, else 0.
    """
    global evtsched, state

    if len(data) < 4:
        return 99

    dump('recv ', data)
    mtype = ord(data[1])

    if mtype == EXTEND_RESPONSE:
        b = find_bind(data)
        i = find_param(Lease_Time, data, 4)
        lease = makeint(data, i+3, 4)
        if b:
            syslog(LOG_DEBUG, "Lease extension: %d" % lease)
            Bind_Ids[b][2] = lease
            Bind_Ids[b][3] = evtsched.enter(lease-2, 1, extend_bind, [sock, b])
        # check to see if registration lease needs extension
        abslease = time.time() + lease
        if abslease > rsipdata.reg_lease:
            rsipdata.reg_lease = abslease

    elif mtype == FREE_RESPONSE:
        b = find_bind(data)
        # find_bind() returns None for a bind ID we don't hold; indexing
        # Bind_Ids with it would raise KeyError
        if b and Bind_Ids[b][3]:
            try:
                evtsched.cancel(Bind_Ids[b][3])
            except ValueError:
                # the extension event already ran and is no longer queued
                pass
            del Bind_Ids[b]
            state = 2
    else:
        if mtype == DEREGISTER_RESPONSE:
            state = 1
        raise socket.error

    return 0

##############################

def create_tunnel():
    """Create and configure the tunnel interface rsip_if.

    Adds a tunnel (rsipdata.tunnelmode) from our socket's local address to
    rsipdata.TunnelEndPoint, brings it up with rsipdata.IPaddress and, when
    rsipdata.mtu > 0, that MTU, then routes the default route through it.
    Any failing step calls shutdown() through runcmd().
    """
    global sockname

    cmd_tunnel_add = ['iptunnel', 'add', rsip_if, 'mode', rsipdata.tunnelmode,
                      'local', sockname[0], 'remote', rsipdata.TunnelEndPoint]

    cmd_link_up = ['ifconfig', rsip_if, rsipdata.IPaddress]
    if rsipdata.mtu > 0:
        # honour -m/--mtu instead of a hard-coded 1400
        cmd_link_up = cmd_link_up + ['mtu', str(rsipdata.mtu)]

    cmd_route_add = ['route', 'add', 'default', 'dev', rsip_if]

    # setup tunnel interface
    runcmd(4, 10, cmd_tunnel_add, 'Unable to create tunnel')
    # The tunnel exists from here on.  Mark it now, so that if a later step
    # fails (runcmd -> shutdown) the interface is still deleted.
    rsipdata.tunnel_created = 1
    runcmd(5, 12, cmd_link_up, 'Unable to bring tunnel up')
    runcmd(7, 14, cmd_route_add, 'Unable to set default route')

##############################

def del_tunnel():
    """Remove the tunnel interface rsip_if.
    Best effort: runcmd() is given cause 0, so a failure is ignored."""
    runcmd(0, 0, ['iptunnel', 'del', rsip_if], '')

##############################

def shutdown(n, str, deregister_timeout=5):
    """Terminate the program, undoing what was set up so far.

    n                  -- exit status handed to finish()
    str                -- reason, logged at LOG_NOTICE
    deregister_timeout -- seconds to wait for the DEREGISTER_RESPONSE

    Restores the local port range, deletes the tunnel, deregisters from the
    server (if registered) and closes the socket.  Each step is best-effort:
    a failure is logged and does not prevent the remaining steps or the
    final exit.  Never returns.
    """
    global state, sock, evtsched

    syslog(LOG_NOTICE, str)
    if rsipdata.verbose:
        syslog(LOG_DEBUG, 'state=%d exit=%d' % (state, n))

    # restore original local port range
    if rsipdata.port_range_set:
        try:
            set_port_range(local_port_range[0], local_port_range[1])
        except IOError:
            syslog(LOG_WARNING, 'Unable to restore local port range')

    if rsipdata.tunnel_created:
        del_tunnel()

    # Do we need to do FREE_REQUEST?

    # state >= 2: registered with the server
    if state >= 2:
        msg = DEREGISTER_REQUEST_STR % rsipdata.Client_Id
        try:
            sock.send(msg)
            dump('send ', msg)
            # bounded wait: don't block forever on an unresponsive server
            readable = select.select([sock], [], [], deregister_timeout)[0]
            if readable:
                data = sock.recv(128)
                # '' means the server already closed the connection
                if data:
                    dump('recv ', data)
        except (socket.error, select.error):
            syslog(LOG_WARNING, 'Unable to deregister from RSIP server')
    # state >= 1: the socket was connected
    if state >= 1:
        sock.close()

    finish(n)

##############################

def _log_signal(name):
    "Record receipt of the signal called `name` at LOG_NOTICE."
    syslog(LOG_NOTICE, "Signal: %s received" % name)

def quitHandler(signum, frame):
    "SIGQUIT handler: log the signal, then shut down with exit status 42."
    _log_signal("SIGQUIT")
    shutdown(42,"Signal SIGQUIT received, exiting")

def hupHandler(signum, frame):
    "SIGHUP handler: only logs the signal; the client keeps running."
    _log_signal("SIGHUP")

def termHandler(signum, frame):
    """SIGTERM handler (also installed for SIGINT in daemon mode): only logs
    the signal; the client keeps running."""
    _log_signal("SIGTERM")

##############################

def usage():
    print "Usage: %s [options]" % sys.argv[0]
    print "valid options:"
    print "\t -h | --help: usage statement"
    print "\t -V | --version: print version and exit"
    print "\t -v | --verbose: turn on progress messages"
    print "\t -d | --nodaemon: don't run in daemon (background) mode"
    print "\t -q | --quiet: turn off progress messages"
    print "\t -l P | --listen=P: comma separated list of ports (for servers)"
    print "\t -s A | --server=A: use Address for RSIP server (default: use SLP)"
    print "\t                    format: hostname[:port]"
    print "\t -p N | --ports=N: request N ports"
    print "\t -m M | --mtu=M: set tunnel MTU to M"
    print "\t --gre: use GRE instead of IPIP for tunneling"
    print "\t --rsap: use RSAP-IP method (default)"
    print "\t --rsa: use RSA-IP method"
    print "\t --notunnel: don't setup tunnel"

##############################

def do_options():
    global listen_ports, DEST, PORT
    try:
        opts, args = getopt.getopt(sys.argv[1:], "dhl:qs:p:m:vV", \
            ["rsa", "rsap", "help", "quiet", "verbose", "version", "gre", \
             "listen=", "nodaemon", "notunnel", "server=", "ports=", "mtu"])
    except getopt.GetoptError:
        # print help information and exit:
        usage()
        sys.exit(1)

    for o, a in opts:
        if o in ("-h", "--help"):
            usage()
            sys.exit(1)
        if o in ("-V", "--version"):
            print "RSIP client $Revision: 0.19 $"
            sys.exit(1)
        if o in ("-d", "--nodaemon"):
            rsipdata.facility = LOG_USER
        if o in ("-q", "--quiet"):
            rsipdata.verbose = 0
        if o in ("-v", "--verbose"):
            rsipdata.verbose = rsipdata.verbose + 1
        if o in ("-m", "--mtu"):
            rsipdata.mtu = int(a)
        if o in ("-l", "--listen"):
            listen_ports = string.join(map(lambda s: struct.pack("!H",int(s)), string.split(a,',')),'')
        if o in ("-s", "--server"):
            if string.find(a,':') <= 0:
                DEST = a
                PORT = 4555
            else:
                DEST = a[:i]
                PORT = int(a[i+1:])
            if rsipdata.verbose:
                print "Using RSIP server at %s:%d" % (DEST, PORT)
        if o in ("-p", "--ports"):
            rsipdata.port_num = int(a)
#           if not (0 < rsipdata.port_num < 256):
#               print "Invalid port range (must be 1..255)"
#               sys.exit(1)
        if o == "--rsap":
            rsipdata.RSIPMethod = RSAP_IP
        if o == "--rsa":
            rsipdata.RSIPMethod = RSA_IP
        if o == "--gre":
            rsipdata.tunneltype = GRE
        if o == "--notunnel":
            rsipdata.notunnel = 1

##############################

exename = os.path.basename(sys.argv[0])

# process command line options
do_options()

if rsipdata.RSIPMethod == RSA_IP and listen_ports:
    print "listen option ignored with RSA method"
    listen_ports = ""

local_port_range = get_port_range()

if rsipdata.facility == LOG_DAEMON:
    if os.path.exists(pidfile):
        print "File %s exists, %s already running\n" % (pidfile, exename)
        sys.exit(1)

    pid = os.fork()
    if pid:
        open(pidfile, 'w').write(`pid`+'\n')
        sys.exit(0)

    #install signal handlers
    signal.signal(signal.SIGQUIT, quitHandler)
    signal.signal(signal.SIGHUP, hupHandler)
    signal.signal(signal.SIGTERM, termHandler)
    signal.signal(signal.SIGINT, termHandler)

    sys.stdin.close()
    sys.stdout.close()
    sys.stderr.close()

# Open syslog under our program name.  Interactive runs (-d, facility
# LOG_USER) also log LOG_DEBUG messages; otherwise logging stops at LOG_INFO.
openlog(exename, logopts, rsipdata.facility)
if rsipdata.facility != LOG_USER:
    logmask = LOG_UPTO(LOG_INFO)
else:
    logmask = LOG_UPTO(LOG_DEBUG)
setlogmask(logmask)

# no server given on the command line: discover one via SLP
if PORT == 0:
    find_gateway()

# Connect to the RSIP server; state 1 tells shutdown() to close the socket.
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
    sock.connect((DEST, PORT))
except socket.error, (e,s):
    shutdown(3, 'Unable to connect to RSIP server: '+s)
state = 1

peername = sock.getpeername()
sockname = sock.getsockname()
if rsipdata.verbose:
    syslog(LOG_DEBUG, "peer addr: %s" % peername[0])
    syslog(LOG_DEBUG, "sock addr: %s" % sockname[0])

try:
    sock.send(REGISTER_REQUEST_STR)
    dump('send ', REGISTER_REQUEST_STR)
    data = sock.recv(128)
    dump('recv ', data)

    # check for REGISTER_RESPONSE
    if len(data) < 4 or ord(data[1]) != REGISTER_RESPONSE:
        shutdown(4, 'Unable to register with RSIP server')
    state = 2

    # extract parameters
    i = find_param(Client_ID, data, 4)
    rsipdata.Client_Id = data[i+3:i+7]
    i = find_param(Lease_Time, data, 4)
    rsipdata.reg_lease = time.time() + makeint(data,i+3,4)
    i = find_param(Flow_Policy, data, i)
    FlowPolicy = ( ord(data[i+3]), ord(data[i+4]) )

    # RSIP method: keep the requested method if the first or second
    # RSIP_Method parameter offers it, otherwise use the first offer
    RequestedMethod = rsipdata.RSIPMethod
    i = find_param(RSIP_Method, data, i)
    if i != -1:
        j = ord(data[i+3])
        if j != rsipdata.RSIPMethod:
            i = i + 3 + makeint(data,i+1, 2)
            i = find_param(RSIP_Method, data, i)
            # test i before reading the value: with i == -1, data[i+3]
            # would be an unrelated header byte
            if i == -1 or ord(data[i+3]) != rsipdata.RSIPMethod:
                rsipdata.RSIPMethod = j
    if rsipdata.RSIPMethod != RequestedMethod:
        syslog(LOG_WARNING, "Using %s instead of %s" % \
              (RSIP_Methods[rsipdata.RSIPMethod], RSIP_Methods[RequestedMethod]))

    # should check flow policy

    # Check tunnel type: collect every Tunnel_Type the server offers and
    # require at least one we support (IPIP or GRE).  Each search starts
    # past the current parameter; starting at i itself would only find the
    # same parameter again.
    offered = []
    i = find_param(Tunnel_Type, data, 4)
    while i != -1:
        offered.append(ord(data[i+3]))
        i = find_param(Tunnel_Type, data, i + 3 + makeint(data, i+1, 2))
    if offered and IPIP not in offered and GRE not in offered:
        shutdown(5, "Unsupported tunnel type: %d" % offered[0])

    # create timer event for registration lease
    evtsched = sched.scheduler(time.time, rsipwait)
    rsipdata.reg_evt = evtsched.enterabs(rsipdata.reg_lease-1, 1, extend_reg, [0])

    # issue listen request
    if listen_ports:
        send_listen_request(listen_ports)
        data = sock.recv(128)
        dump('recv ', data)
        if len(data) < 4 or LISTEN_RESPONSE != ord(data[1]):
            syslog(LOG_ERR, "LISTEN_REQUEST failed")
        else:
            process_bind_response(data)

    # issue assign requests, at most 255 ports per request
    ports = rsipdata.port_num
    base = 0
    while ports > 0:
        mtype = send_assign_request(base, min(255,ports))
        data = sock.recv(128)
        dump('recv ', data)
        if len(data) < 4 or (mtype+1) != ord(data[1]):
            shutdown(6, "Unable to get port assignments!")
        state = 3

        base, alloced = process_bind_response(data)
        if rsipdata.RSIPMethod == RSA_IP or base == 0 or alloced <= 0:
            # RSA-IP assigns a whole address at once.  For RSAP-IP, a reply
            # without a port base or without ports cannot be continued;
            # stop instead of re-sending the same request forever.
            ports = 0
        else:
            if ord(data[1]) == ASSIGN_RESPONSE_RSAP_IP and not rsipdata.port_base:
                # record the start of the first block only; later blocks are
                # requested right after it, so the range starts here
                rsipdata.port_base = base
            base = base + alloced
            ports = ports - alloced

except (IOError, socket.error):
    shutdown(90, 'RSIP client shutdown')

# schedule the first lease extension for every binding obtained above
for b in Bind_Ids.keys():
    Bind_Ids[b][3] = evtsched.enter(Bind_Ids[b][2]-2, 1, extend_bind, [sock, b])

create_tunnel()

if rsipdata.port_base:
    set_port_range(rsipdata.port_base, rsipdata.port_base+rsipdata.port_num-1)
    rsipdata.port_range_set = 1
    state = 8

# Main loop: the scheduler's delay function (rsipwait) handles server
# messages while waiting for the next scheduled lease event.
cause = 0
try:
    evtsched.run()
except SystemExit:
    # shutdown() exits through finish() -> sys.exit(), e.g. from the
    # SIGQUIT handler.  It has already cleaned up, so let the exit proceed
    # instead of letting the catch-all below run shutdown() a second time.
    raise
except IOError:
    syslog(LOG_ERR, "IOError")
    cause = 0
except socket.error:
    syslog(LOG_ERR, "socket.error")
    cause = 91
except os.error, (errno, msg):
    syslog(LOG_ERR, "Error %d - %s" % (errno, msg))
    cause = 92
except:
    syslog(LOG_ERR, "unknown error")
    cause = 93

shutdown(cause, 'RSIP client shutdown')

